This is a PyTorch implementation of the popular Adam optimizer from the paper Adam: A Method for Stochastic Optimization.
The Adam update is,

$$\begin{align}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\
\hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1 - \beta_2^t} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{align}$$

where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper-parameters. $m_t$ and $v_t$ are the first and second order moments. $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moments. $\epsilon$ is used as a fix for division by zero, but it also acts as a form of hyper-parameter that works against variance in gradients.

The effective step taken, assuming $\epsilon = 0$, is
$$\Delta t = \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t}}$$
This is bounded by
$$\vert \Delta t \vert \le \alpha \cdot \frac{1 - \beta_1}{\sqrt{1 - \beta_2}}$$
when $1 - \beta_1 \gt \sqrt{1 - \beta_2}$, and by
$$\vert \Delta t \vert \le \alpha$$
otherwise. In most common scenarios,
$$\vert \Delta t \vert \approx \alpha$$
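To make the update concrete, here is a minimal self-contained sketch of a single Adam step on a plain tensor. The function name adam_step and the toy values are just for illustration; the class-based implementation follows below.

import torch

def adam_step(theta, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # Update the biased first and second moment estimates
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g ** 2
    # Bias-corrected moments
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # Parameter update
    theta = theta - lr * m_hat / (v_hat.sqrt() + eps)
    return theta, m, v

theta = torch.zeros(3)
m, v = torch.zeros(3), torch.zeros(3)
g = torch.tensor([0.1, -0.2, 0.3])
theta, m, v = adam_step(theta, g, m, v, t=1)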
import math
from typing import Dict, Any, Tuple, Optional

import torch
from labml import tracker
from torch import nn

from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay

We extend the class GenericAdaptiveOptimizer defined in __init__.py to implement the Adam optimizer.
class Adam(GenericAdaptiveOptimizer):

The constructor arguments are:

* params is the list of parameters
* lr is the learning rate $\alpha$
* betas is a tuple of ($\beta_1$, $\beta_2$)
* eps is $\hat{\epsilon}$ or $\epsilon$ based on optimized_update
* weight_decay is an instance of class WeightDecay defined in __init__.py
* optimized_update is a flag whether to optimize the bias correction of the second moment $\hat{v}_t$ by doing it after adding $\epsilon$
* defaults is a dictionary of defaults for group values. This is useful when you want to extend the class Adam.
    def __init__(self, params,
                 lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
        defaults = {} if defaults is None else defaults
        defaults.update(weight_decay.defaults())
        super().__init__(params, defaults, lr, betas, eps)

        self.weight_decay = weight_decay
        self.optimized_update = optimized_update
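As a usage sketch (assuming labml_nn is installed and the class is importable as labml_nn.optimizers.adam.Adam), the optimizer is constructed like any other PyTorch optimizer:

from torch import nn
from labml_nn.optimizers.adam import Adam

model = nn.Linear(10, 2)
# With optimized_update=True (the default), eps plays the role of epsilon-hat
optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-16)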
init_state initializes the optimizer state for a parameter tensor.

* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* param is the parameter tensor $\theta_{t-1}$

    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):

This is the number of optimizer steps taken on the parameter, $t$.

        state['step'] = 0

Exponential moving average of gradients, $m_t$.

        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)

Exponential moving average of squared gradient values, $v_t$.

        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
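For reference, a small sketch (hypothetical, not part of the class) mirroring the three state entries created above for a single parameter:

import torch

param = torch.nn.Parameter(torch.randn(4, 4))
state = {
    'step': 0,
    'exp_avg': torch.zeros_like(param, memory_format=torch.preserve_format),
    'exp_avg_sq': torch.zeros_like(param, memory_format=torch.preserve_format),
}
# Both moment buffers start at zero with the same shape and dtype as the parameter
assert state['exp_avg'].shape == param.shape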
get_mv computes the uncorrected first and second moments $m_t$ and $v_t$.

* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$

    def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):

Get $\beta_1$ and $\beta_2$.

        beta1, beta2 = group['betas']

Get $m_{t-1}$ and $v_{t-1}$.

        m, v = state['exp_avg'], state['exp_avg_sq']

In-place calculation of $m_t$,
$$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$

        m.mul_(beta1).add_(grad, alpha=1 - beta1)

In-place calculation of $v_t$,
$$v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2$$

        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        return m, v
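A small sketch with made-up tensors, checking that the in-place mul_/add_/addcmul_ chain above computes exactly the exponential moving averages from the formulas:

import torch

beta1, beta2 = 0.9, 0.999
grad = torch.tensor([0.5, -1.0, 2.0])
m = torch.tensor([0.1, 0.2, 0.3])
v = torch.tensor([0.01, 0.02, 0.03])

# Out-of-place reference computation of the formulas
m_ref = beta1 * m + (1 - beta1) * grad
v_ref = beta2 * v + (1 - beta2) * grad * grad

# In-place computation, as in get_mv
m.mul_(beta1).add_(grad, alpha=1 - beta1)
v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

assert torch.allclose(m, m_ref) and torch.allclose(v, v_ref)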
get_lr returns the modified learning rate based on the state. For Adam this is just the specified learning rate for the parameter group, $\alpha$.

    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
        return group['lr']
adam_update performs the Adam parameter update.

* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* param is the parameter tensor $\theta_{t-1}$
* m and v are the uncorrected first and second moments $m_t$ and $v_t$.

This computes the following,

$$\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

Since $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalars and the others are tensors, we modify this calculation to optimize the computation:

$$\begin{align}
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \\
&= \theta_{t-1} - \alpha \cdot \frac{m_t / (1-\beta_1^t)}{\sqrt{v_t / (1-\beta_2^t)} + \epsilon} \\
&= \theta_{t-1} - \alpha \cdot \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}
\end{align}$$

where $\hat{\epsilon} = \sqrt{1-\beta_2^t} \cdot \epsilon$ is what we should specify as the hyper-parameter.
    def adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
                    m: torch.Tensor, v: torch.Tensor):

Get $\beta_1$ and $\beta_2$.

        beta1, beta2 = group['betas']

Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$.

        bias_correction1 = 1 - beta1 ** state['step']

Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$.

        bias_correction2 = 1 - beta2 ** state['step']

Get the learning rate.

        lr = self.get_lr(state, group)

Whether to optimize the computation.

        if self.optimized_update:

Denominator $\sqrt{v_t} + \hat{\epsilon}$.

            denominator = v.sqrt().add_(group['eps'])

Step size $\alpha \cdot \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$.

            step_size = lr * math.sqrt(bias_correction2) / bias_correction1

Update the parameters, $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$.

            param.data.addcdiv_(m, denominator, value=-step_size)

Computation without optimization.

        else:

Denominator $\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$.

            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

Step size $\frac{\alpha}{1-\beta_1^t}$.

            step_size = lr / bias_correction1

Update the parameters, $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$.

            param.data.addcdiv_(m, denominator, value=-step_size)
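To see that the two branches agree, here is a numeric sketch (made-up values) showing that the optimized form with $\hat{\epsilon} = \sqrt{1-\beta_2^t} \cdot \epsilon$ produces the same step as the standard form with $\epsilon$:

import math
import torch

beta1, beta2, lr, step = 0.9, 0.999, 1e-3, 10
eps = 1e-8                                    # standard epsilon
eps_hat = math.sqrt(1 - beta2 ** step) * eps  # what eps means when optimized_update=True

m = torch.tensor([0.3, -0.1])
v = torch.tensor([0.04, 0.09])
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step

# Standard (unoptimized) form
update1 = (lr / bias_correction1) * m / (v.sqrt() / math.sqrt(bias_correction2) + eps)
# Optimized form
update2 = (lr * math.sqrt(bias_correction2) / bias_correction1) * m / (v.sqrt() + eps_hat)

assert torch.allclose(update1, update2)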
step_param takes an update step for a given parameter tensor.

* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
* param is the parameter tensor $\theta_{t-1}$

    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):

Calculate weight decay.

        grad = self.weight_decay(param, grad, group)

Get $m_t$ and $v_t$.

        m, v = self.get_mv(state, group, grad)

Increment $t$, the number of optimizer steps.

        state['step'] += 1

Perform the Adam update.

        self.adam_update(state, group, param, m, v)
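Finally, a hedged end-to-end sketch of training with this optimizer (assuming labml_nn is installed and that the GenericAdaptiveOptimizer base class provides the standard step()/zero_grad() interface of torch.optim.Optimizer; the data is random and only for illustration):

import torch
from torch import nn
from labml_nn.optimizers.adam import Adam

model = nn.Linear(4, 1)
optimizer = Adam(model.parameters(), lr=1e-3)

x, y = torch.randn(32, 4), torch.randn(32, 1)
for _ in range(10):
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    # step() is expected to call step_param for each parameter
    optimizer.step()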