This is a PyTorch implementation of the popular optimizer Adam from the paper Adam: A Method for Stochastic Optimization.
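The Adam update is,

$$
\begin{aligned}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\
\hat{m}_t &\leftarrow \frac{m_t}{1-\beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1-\beta_2^t} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}
$$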
where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper-parameters and $g_t$ is the gradient at step $t$. $m_t$ and $v_t$ are the first and second order moments, and $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moments. $\epsilon$ prevents division by zero, but it also acts as a hyper-parameter that counteracts variance in the gradients.
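The Adam update can be traced on a single scalar parameter. Below is a minimal pure-Python sketch that does not depend on PyTorch. The function name `adam_step_scalar` and its default hyper-parameter values are illustrative assumptions, and `t` is the 1-based step count.

```python
import math


def adam_step_scalar(theta, g, m, v, t, alpha=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update on a scalar parameter (hypothetical helper for illustration).

    `t` is the 1-based step number, so the bias corrections 1 - beta^t are non-zero.
    Returns the new (theta, m, v).
    """
    # First and second moment estimates (exponential moving averages)
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    # Bias-corrected moments
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # Parameter update
    theta = theta - alpha * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


# First step from m_0 = v_0 = 0 with a large gradient
theta, m, v = adam_step_scalar(theta=1.0, g=250.0, m=0.0, v=0.0, t=1, alpha=0.01)
print(theta)  # approximately 0.99
```

At $t = 1$ the step has magnitude almost exactly $\alpha$, even though the gradient is 250. Bias correction undoes the scaling that initializing $m_0 = v_0 = 0$ would otherwise introduce.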
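The effective step taken, assuming $\epsilon = 0$, is

$$\Delta_t = \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t}}$$

This is bounded by

$$\vert \Delta_t \vert \le \alpha \cdot \frac{1 - \beta_1}{\sqrt{1-\beta_2}}$$

when $1-\beta_1 \gt \sqrt{1-\beta_2}$, and by

$$\vert \Delta_t \vert \le \alpha$$

otherwise. In most common scenarios,

$$\vert \Delta_t \vert \approx \alpha$$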
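The bound can be checked numerically. The sketch below is pure Python with hypothetical helper names. It evaluates the bound for given hyper-parameters and computes $\vert \Delta_t \vert = \alpha \vert \hat{m}_t \vert / \sqrt{\hat{v}_t}$ (with $\epsilon = 0$) along a gradient sequence. It assumes the first gradient is non-zero, since otherwise $\hat{v}_1 = 0$.

```python
import math


def effective_step_bound(alpha, beta1, beta2):
    """Upper bound on |Delta_t| for epsilon = 0 (hypothetical helper)."""
    if 1 - beta1 > math.sqrt(1 - beta2):
        return alpha * (1 - beta1) / math.sqrt(1 - beta2)
    return alpha


def effective_steps(grads, alpha, beta1, beta2):
    """|Delta_t| = alpha * |m_hat_t| / sqrt(v_hat_t) for each step of a gradient sequence."""
    m = v = 0.0
    steps = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        steps.append(alpha * abs(m_hat) / math.sqrt(v_hat))
    return steps


# Default betas: 1 - beta1 = 0.1 > sqrt(1 - beta2) ~= 0.0316, so the first case applies
print(effective_step_bound(1e-3, 0.9, 0.999))

# A constant gradient gives m_hat_t = g and v_hat_t = g^2, so each step is alpha
print(effective_steps([0.5] * 5, 1e-3, 0.9, 0.999))
```

For the common defaults $\beta_1 = 0.9$ and $\beta_2 = 0.999$, the bound is $\sqrt{10}\,\alpha \approx 3.16\alpha$. A steady gradient gives steps of $\alpha$ up to rounding.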
```python
import math
from typing import Dict, Any, Tuple, Optional

import torch
from labml import tracker
from torch import nn

from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay
```

We extend the class `GenericAdaptiveOptimizer` defined in `__init__.py` to implement the Adam optimizer.
```python
class Adam(GenericAdaptiveOptimizer):
```

* `params` is the list of parameters
* `lr` is the learning rate $\alpha$
* `betas` is a tuple of ($\beta_1$, $\beta_2$)
* `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
* `weight_decay` is an instance of class `WeightDecay` defined in `__init__.py`
* `optimized_update` is a flag whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$
* `defaults` is a dictionary of defaults for group values. This is useful when you want to extend the class `Adam`.

```python
    def __init__(self, params,
                 lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
        defaults = {} if defaults is None else defaults
        defaults.update(weight_decay.defaults())
        super().__init__(params, defaults, lr, betas, eps)

        self.weight_decay = weight_decay
        self.optimized_update = optimized_update
```

* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `param` is the parameter tensor $\theta_{t-1}$

```python
    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        # This is the number of optimizer steps taken on the parameter, $t$
        state['step'] = 0
        # Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        # Exponential moving average of squared gradient values, $v_t$
        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
```

* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$

```python
    def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']

        # Get $m_{t-1}$ and $v_{t-1}$
        m, v = state['exp_avg'], state['exp_avg_sq']

        # In-place calculation of $m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$
        m.mul_(beta1).add_(grad, alpha=1 - beta1)
        # In-place calculation of $v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$
        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        return m, v
```

This returns the modified learning rate based on the state. For Adam, this is just the specified learning rate for the parameter group, $\alpha$.
```python
    def get_lr(self, state: Dict[str, Any], group: Dict[str, Any]):
        return group['lr']
```

* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `param` is the parameter tensor $\theta_{t-1}$
* `m` and `v` are the uncorrected first and second moments $m_t$ and $v_t$

This computes the following:

$$\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
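Since $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalars and the others are tensors, we rearrange this calculation so that the bias corrections are applied to scalars instead of whole tensors:

$$
\begin{aligned}
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{m_t / (1-\beta_1^t)}{\sqrt{v_t / (1-\beta_2^t)} + \epsilon} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}
\end{aligned}
$$

where

$$\hat{\epsilon} = \sqrt{1-\beta_2^t} \, \epsilon$$

is what we should specify as the hyper-parameter. Because $\hat{\epsilon}$ is held constant, the equivalent $\epsilon = \hat{\epsilon} / \sqrt{1-\beta_2^t}$ varies slightly with $t$. The `optimized_update` flag chooses between the two forms.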
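The rearrangement is exact when $\hat{\epsilon} = \sqrt{1-\beta_2^t}\,\epsilon$. Below is a minimal pure-Python check on scalars. The function names and the sample values of $m_t$, $v_t$ and $t$ are arbitrary assumptions chosen for illustration.

```python
import math


def update_plain(theta, m, v, t, alpha, beta1, beta2, eps):
    """theta_t with explicit bias correction of m_t and v_t (hypothetical helper)."""
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return theta - alpha * m_hat / (math.sqrt(v_hat) + eps)


def update_optimized(theta, m, v, t, alpha, beta1, beta2, eps_hat):
    """theta_t with both bias corrections folded into a scalar step size (hypothetical helper)."""
    step_size = alpha * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
    return theta - step_size * m / (math.sqrt(v) + eps_hat)


alpha, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8
# Arbitrary sample values for theta_{t-1}, m_t, v_t and t
theta, m, v, t = 0.5, 0.02, 3e-4, 7

# Choosing eps_hat = sqrt(1 - beta2^t) * eps makes the two forms identical
eps_hat = math.sqrt(1 - beta2 ** t) * eps
a = update_plain(theta, m, v, t, alpha, beta1, beta2, eps)
b = update_optimized(theta, m, v, t, alpha, beta1, beta2, eps_hat)
print(a, b)
```

The optimized form saves a tensor-wide division of $\sqrt{v_t}$ by $\sqrt{1-\beta_2^t}$ at every step. The cost is that the effective $\epsilon$ depends on $t$.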
```python
    def adam_update(self, state: Dict[str, Any], group: Dict[str, Any], param: torch.nn.Parameter,
                    m: torch.Tensor, v: torch.Tensor):
        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']
        # Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$
        bias_correction1 = 1 - beta1 ** state['step']
        # Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$
        bias_correction2 = 1 - beta2 ** state['step']

        # Get learning rate
        lr = self.get_lr(state, group)

        # Whether to optimize the computation
        if self.optimized_update:
            # $\sqrt{v_t} + \hat{\epsilon}$
            denominator = v.sqrt().add_(group['eps'])
            # $\alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$
            step_size = lr * math.sqrt(bias_correction2) / bias_correction1
            # $\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
            param.data.addcdiv_(m, denominator, value=-step_size)
        # Computation without optimization
        else:
            # $\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$
            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
            # $\frac{\alpha}{1-\beta_1^t}$
            step_size = lr / bias_correction1
            # $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$
            param.data.addcdiv_(m, denominator, value=-step_size)
```

* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
* `param` is the parameter tensor $\theta_{t-1}$

```python
    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
        # Calculate weight decay
        grad = self.weight_decay(param, grad, group)

        # Get $m_t$ and $v_t$
        m, v = self.get_mv(state, group, grad)

        # Increment $t$, the number of optimizer steps
        state['step'] += 1

        # Perform Adam update
        self.adam_update(state, group, param, m, v)
```
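The order of operations in `step_param` runs as follows: update the moments, increment $t$, then apply the bias-corrected update. To see it end to end, here is a hypothetical pure-Python `ScalarAdam` that works on a single float. It is a sketch under simplifying assumptions: weight decay is omitted, and it does not use PyTorch or labml.

```python
import math


class ScalarAdam:
    """Hypothetical pure-Python Adam for one float, following the step_param order."""

    def __init__(self, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, optimized_update=True):
        self.lr, self.betas, self.eps = lr, betas, eps
        self.optimized_update = optimized_update
        # Counterpart of init_state: t, m_t and v_t all start at zero
        self.state = {'step': 0, 'exp_avg': 0.0, 'exp_avg_sq': 0.0}

    def step(self, param, grad):
        beta1, beta2 = self.betas
        s = self.state
        # Counterpart of get_mv: update m_t and v_t
        s['exp_avg'] = beta1 * s['exp_avg'] + (1 - beta1) * grad
        s['exp_avg_sq'] = beta2 * s['exp_avg_sq'] + (1 - beta2) * grad * grad
        # Increment t before computing the bias corrections
        s['step'] += 1
        bc1 = 1 - beta1 ** s['step']
        bc2 = 1 - beta2 ** s['step']
        m, v = s['exp_avg'], s['exp_avg_sq']
        if self.optimized_update:
            # eps here plays the role of eps_hat
            step_size = self.lr * math.sqrt(bc2) / bc1
            return param - step_size * m / (math.sqrt(v) + self.eps)
        # Unoptimized form: explicit bias correction of v_t
        step_size = self.lr / bc1
        return param - step_size * m / (math.sqrt(v) / math.sqrt(bc2) + self.eps)


# Minimize f(x) = (x - 3)^2, whose gradient is 2 * (x - 3)
opt = ScalarAdam(lr=0.1)
x = 0.0
for _ in range(500):
    x = opt.step(x, 2 * (x - 3))
print(x)  # close to 3
```

Incrementing $t$ before the update means the first call uses $1 - \beta_1^1$ and $1 - \beta_2^1$ as the bias corrections, so there is never a division by zero.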