This is based on the official AdaBelief implementation from the paper AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients.
This is implemented in PyTorch as an extension to RAdam.
The main difference between the Adam optimizer and AdaBelief is how the adaptive learning rate is calculated: instead of dividing by the exponential moving average of the squared gradients, AdaBelief divides by the exponential moving average of the variance, that is, of the squared difference between the gradient and its moving average.
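To make the difference concrete, here is a minimal scalar sketch in plain Python (no PyTorch). It is an illustration only, not the implementation on this page: the helper `second_moments` is hypothetical, and bias correction and $\epsilon$ are left out. It feeds the same gradients through Adam's average of $g_t^2$ and AdaBelief's average of $(g_t - m_t)^2$.

```python
def second_moments(grads, beta1=0.9, beta2=0.999):
    """Return (v, s) after feeding scalar gradients through both estimators.

    v is Adam's moving average of g^2.
    s is AdaBelief's moving average of (g - m)^2.
    Bias correction and epsilon are omitted to keep the comparison short.
    """
    m = v = s = 0.0
    for g in grads:
        # Moving average of the gradient (updated first, then used in the residual)
        m = beta1 * m + (1 - beta1) * g
        # Adam: moving average of the squared gradient
        v = beta2 * v + (1 - beta2) * g * g
        # AdaBelief: moving average of the squared residual
        s = beta2 * s + (1 - beta2) * (g - m) ** 2
    return v, s


# A steady gradient: each observation agrees with the running mean,
# so AdaBelief's denominator ends up far smaller than Adam's.
v_steady, s_steady = second_moments([1.0] * 200)

# A gradient that flips sign every step: the running mean stays near zero,
# so the two denominators stay of similar size.
v_noisy, s_noisy = second_moments([1.0, -1.0] * 100)
```

With a small denominator the steady case takes a larger step under AdaBelief than under Adam, while the noisy case behaves much the same under both.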
🤔 The paper calculates variance as $(g_t - m_t)^2$, but I feel it should use the bias-corrected momentum $(g_t - \color{orange}{\hat{m}_t})^2$. I guess this doesn't affect things much because the bias correction factor is $\approx 1$ after the initial training steps.
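As a rough check on that guess, the sketch below evaluates the standard Adam bias correction factor $1 / (1 - \beta_1^t)$ at a few step counts. The helper `bias_correction` is hypothetical and used only for illustration; $\beta_1 = 0.9$ is taken from the default `betas` of this class.

```python
def bias_correction(beta, step):
    """Factor 1 / (1 - beta^t) that scales m_t into the bias-corrected estimate."""
    return 1.0 / (1.0 - beta ** step)


beta1 = 0.9  # default first-moment decay rate

# Correction factor at steps 1, 10 and 100
factors = {step: bias_correction(beta1, step) for step in (1, 10, 100)}
```

The factor is large at the first step and decays quickly towards 1, so $m_t$ and $\hat{m}_t$ only differ noticeably during the first few dozen steps.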
from typing import Dict, Any

import torch
from torch import nn

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.radam import RAdam

class AdaBelief(RAdam):

* `params` is the list of parameters
* `lr` is the learning rate $\alpha$
* `betas` is a tuple of ($\beta_1$, $\beta_2$)
* `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
* `weight_decay` is an instance of class `WeightDecay` defined in `__init__.py`
* `amsgrad` is a flag indicating whether to use AMSGrad or fall back to plain Adam
* `degenerate_to_sgd` is whether to use SGD when the rectification term $r_t$ is intractable
* `defaults` is a dictionary of default values for parameter groups. This is useful when you want to extend the class `AdaBelief`.

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
                 degenerate_to_sgd=True,
                 rectify=True, defaults=None):
        defaults = {} if defaults is None else defaults
        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, degenerate_to_sgd, defaults)
        self.rectify = rectify

* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `param` is the parameter tensor $\theta_{t-1}$

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        state['step'] = 0

Exponential moving average of gradient values
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)

Exponential moving average of variance

        state['exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)

If the `amsgrad` flag is `True` for this parameter group, we maintain the maximum of the exponential moving average of variance

        if group['amsgrad']:

Maintains the maximum of all exponential moving averages of the variance

            state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)

* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$

    def get_ms(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):

Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']

Get $m_{t-1}$ and $s_{t-1}$

        m, s = state['exp_avg'], state['exp_avg_var']

In-place calculation of $m_t$

        m.mul_(beta1).add_(grad, alpha=1 - beta1)

Difference between gradient and momentum

        grad_residual = grad - m

In-place calculation of $s_t$

        s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)

If this parameter group is using `amsgrad`

        if group['amsgrad']:

Get $\max(s_1, s_2, \dots, s_{t-1})$.

            s_max = state['max_exp_avg_var']

Calculate $\max(s_1, s_2, \dots, s_{t-1}, s_t)$.

            torch.maximum(s_max, s, out=s_max)

            return m, s_max
        else:

$m_t$ and $s_t$ otherwise

            return m, s

* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
* `param` is the parameter tensor $\theta_{t-1}$

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):

Calculate weight decay
        grad = self.weight_decay(param, grad, group)

Get $m_t$ and $s_t$

        m, s = self.get_ms(state, group, grad)

Increment $t$, the number of optimizer steps

        state['step'] += 1

        if not self.rectify:

Perform the Adam update, defined in `adam.py`, with $\color{cyan}{s_t} + \color{red}{\epsilon}$ in place of $v_t$.

            self.adam_update(state, group, param, m, s + group['eps'])
        else: