This is based on the official implementation of the paper AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients. It is implemented in PyTorch as an extension to RAdam.
The main difference between the Adam optimizer and AdaBelief is how the adaptive learning rate is calculated: instead of dividing by the exponential moving average of the squared gradients, AdaBelief divides by the exponential moving average of the variance of the gradients.
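To make the difference concrete, here is a sketch of the two update rules, using the bias-corrected estimates $\hat{m}_t = m_t / (1 - \beta_1^t)$, $\hat{v}_t = v_t / (1 - \beta_2^t)$ and $\hat{s}_t = s_t / (1 - \beta_2^t)$ (the exact placement of $\epsilon$ in the code depends on the Adam base class):

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$$
$$\text{Adam: } v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2, \qquad \theta_t = \theta_{t-1} - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
$$\text{AdaBelief: } s_t = \beta_2 s_{t-1} + (1 - \beta_2) (g_t - m_t)^2, \qquad \theta_t = \theta_{t-1} - \alpha \frac{\hat{m}_t}{\sqrt{\hat{s}_t} + \epsilon}$$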
🤔 The paper calculates variance as $(g_t - m_t)^2$, but I feel it should use the bias-corrected momentum $(g_t - \color{orange}{\hat{m}_t})^2$. I guess this doesn’t affect things much because the bias correction is $\approx 1$ after the initial training steps.
from typing import Dict, Any

import torch
from torch import nn

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.radam import RAdam
class AdaBelief(RAdam):
* `params` is the list of parameters
* `lr` is the learning rate $\alpha$
* `betas` is a tuple of ($\beta_1$, $\beta_2$)
* `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
* `weight_decay` is an instance of class `WeightDecay` defined in `__init__.py`
* `amsgrad` is a flag indicating whether to use AMSGrad or fall back to plain Adam
* `degenerate_to_sgd` is whether to use SGD when the rectification term $r_t$ is intractable
* `rectify` is whether to use the rectified (RAdam) update
* `defaults` is a dictionary of default values for parameter group attributes; this is useful when you want to extend the class `AdaBelief`

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
                 degenerate_to_sgd=True,
                 rectify=True, defaults=None):
        defaults = {} if defaults is None else defaults
        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, degenerate_to_sgd, defaults)
        self.rectify = rectify
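Here is a minimal usage sketch (the module path and the toy model/data are assumptions for illustration, not part of the code above):

```python
import torch
from torch import nn
from labml_nn.optimizers.ada_belief import AdaBelief  # assumed module path

# A toy model and batch, just to exercise the optimizer.
model = nn.Linear(10, 1)
optimizer = AdaBelief(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-16)

x, y = torch.randn(32, 10), torch.randn(32, 1)
loss = nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()  # clear old gradients
loss.backward()        # compute g_t for every parameter
optimizer.step()       # apply the AdaBelief update
```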
* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `param` is the parameter tensor $\theta_{t-1}$

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
This is the number of optimizer steps taken on the parameter, $t$
        state['step'] = 0
Exponential moving average of gradient values
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
Exponential moving average of variance
        state['exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)
If the `amsgrad` flag is `True` for this parameter group, we maintain the maximum of the exponential moving average of the variance.
        if group['amsgrad']:
Maintains the max of all exponential moving averages of the variance
            state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)
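A side note on `memory_format=torch.preserve_format` used above: it keeps the state buffers in the same memory layout as the parameter they track. A small standalone sketch (the tensor shape is arbitrary):

```python
import torch

# A parameter-like tensor stored in channels_last layout.
p = torch.empty(2, 3, 4, 5).to(memory_format=torch.channels_last)

# zeros_like with preserve_format gives the buffer the same layout.
buf = torch.zeros_like(p, memory_format=torch.preserve_format)
print(buf.is_contiguous(memory_format=torch.channels_last))  # True
```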
* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$

    def get_ms(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']
Get $m_{t-1}$ and $s_{t-1}$
        m, s = state['exp_avg'], state['exp_avg_var']
In-place calculation of $m_t$: $m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$
        m.mul_(beta1).add_(grad, alpha=1 - beta1)
Difference between gradient and momentum
        grad_residual = grad - m
In-place calculation of $s_t$: $s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) \cdot (g_t - m_t)^2$
        s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
If this parameter group is using amsgrad
        if group['amsgrad']:
Get $\max(s_1, s_2, …, s_{t-1})$.
            s_max = state['max_exp_avg_var']
Calculate $\max(s_1, s_2, …, s_{t-1}, s_t)$.
            torch.maximum(s_max, s, out=s_max)

            return m, s_max
        else:
$m_t$ and $s_t$ otherwise
            return m, s
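To see the two buffers evolve outside the optimizer class, here is a standalone sketch of the same $m_t$ / $s_t$ recursion on a made-up gradient (the values are arbitrary):

```python
import torch

beta1, beta2 = 0.9, 0.999
m = torch.zeros(3)                      # m_0
s = torch.zeros(3)                      # s_0
grad = torch.tensor([0.1, -0.2, 0.05])  # a made-up g_t

# The same in-place updates as in get_ms above.
m.mul_(beta1).add_(grad, alpha=1 - beta1)                              # m_t
grad_residual = grad - m                                               # g_t - m_t
s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)  # s_t
print(m, s)
```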
* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
* `param` is the parameter tensor $\theta_{t-1}$

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
Calculate weight decay
        grad = self.weight_decay(param, grad, group)
Get $m_t$ and $s_t$
        m, s = self.get_ms(state, group, grad)
Increment $t$, the number of optimizer steps
        state['step'] += 1

        if not self.rectify:
Perform the Adam update, defined in adam.py, with $\color{cyan}{s_t} + \color{red}{\epsilon}$ in place of $v_t$.
            self.adam_update(state, group, param, m, s + group['eps'])
        else: