This is based on the official implementation of the paper AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients. It is implemented in PyTorch as an extension to RAdam.

The main difference between the Adam optimizer and AdaBelief is how the adaptive learning rate is calculated: instead of dividing by the exponential moving average of the squared gradients, AdaBelief divides by the exponential moving average of the variance of the gradients.

🤔 The paper calculates the variance as $(g_t - m_t)^2$, but I feel it should use the bias-corrected momentum $\hat{m}_t$. I guess this doesn't affect things much because the bias correction factor is close to $1$ after the initial training steps.
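For reference, the two second-moment accumulators differ only in what gets squared (standard notation from the paper; $g_t$ is the gradient and $m_t$ the first-moment estimate):

$$
\begin{aligned}
\text{Adam:} \quad & v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\text{AdaBelief:} \quad & s_t = \beta_2 s_{t-1} + (1 - \beta_2)\, (g_t - m_t)^2
\end{aligned}
$$

The parameter update then uses $\hat{m}_t / \big(\sqrt{\hat{s}_t} + \epsilon\big)$ in place of Adam's $\hat{m}_t / \big(\sqrt{\hat{v}_t} + \epsilon\big)$; the exact placement of $\epsilon$ follows the implementation below, which passes $s_t + \epsilon$ to the Adam/RAdam update.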
from typing import Dict, Any

import torch
from torch import nn

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.radam import RAdam

class AdaBelief(RAdam):

* `params` is the list of parameters
* `lr` is the learning rate $\alpha$
* `betas` is a tuple of ($\beta_1$, $\beta_2$)
* `eps` is $\hat{\epsilon}$ or $\epsilon$, based on `optimized_update`
* `weight_decay` is an instance of class `WeightDecay` defined in `__init__.py`
* `optimized_update` is a flag whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$
* `amsgrad` is a flag indicating whether to use AMSGrad or fall back to plain Adam
* `degenerate_to_sgd` is whether to use SGD when the rectification term $r_t$ is intractable
* `rectify` is whether to use the RAdam update
* `defaults` is a dictionary of default group values. This is useful when you want to extend the class `AdaBelief`.
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
                 degenerate_to_sgd=True,
                 rectify=True, defaults=None):
        defaults = {} if defaults is None else defaults
        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, degenerate_to_sgd, defaults)
        self.rectify = rectify
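As a quick usage sketch (not part of the annotated source; the linear model, the toy data, and the module path in the import are assumptions for illustration), AdaBelief is used like any other PyTorch optimizer:

```python
import torch
from torch import nn

# Module path is an assumption based on the repository layout
from labml_nn.optimizers.ada_belief import AdaBelief

# Toy model and data, purely for illustration
model = nn.Linear(4, 1)
optimizer = AdaBelief(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-16)

x, y = torch.randn(8, 4), torch.randn(8, 1)
for _ in range(10):
    optimizer.zero_grad()
    loss = ((model(x) - y) ** 2).mean()   # mean squared error
    loss.backward()
    optimizer.step()                      # the base optimizer applies step_param to each parameter
```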
* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `param` is the parameter tensor $\theta_{t-1}$

    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
        state['step'] = 0
        # Exponential moving average of gradient values
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        # Exponential moving average of variance
        state['exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        # If the `amsgrad` flag is `True` for this parameter group,
        # we maintain the maximum of the exponential moving average of variance
        if group['amsgrad']:
            # Maintains the max of all exp. moving averages of the variance
            state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)
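For reference, `get_ms` below computes the paper's recurrences for the first moment and the belief (variance) term that these buffers hold:

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
s_t &= \beta_2 s_{t-1} + (1 - \beta_2)\, (g_t - m_t)^2
\end{aligned}
$$

and, when `amsgrad` is set, $s_t^{\max} = \max\big(s_{t-1}^{\max}, s_t\big)$ is used in place of $s_t$.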
* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$

    def get_ms(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']
        # Get $m_{t-1}$ and $s_{t-1}$
        m, s = state['exp_avg'], state['exp_avg_var']
        # In-place calculation of $m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$
        m.mul_(beta1).add_(grad, alpha=1 - beta1)
        # Difference between the gradient and the momentum
        grad_residual = grad - m
        # In-place calculation of $s_t = \beta_2 s_{t-1} + (1 - \beta_2) (g_t - m_t)^2$
        s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
        # If this parameter group is using `amsgrad`
        if group['amsgrad']:
            # Get $s_{t-1}^{\max}$
            s_max = state['max_exp_avg_var']
            # Calculate $s_t^{\max} = \max(s_{t-1}^{\max}, s_t)$ in place
            torch.maximum(s_max, s, out=s_max)

            return m, s_max
        else:
            # Return $m_t$ and $s_t$ otherwise
            return m, s
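Here is a standalone toy sketch (not part of the annotated source) of why dividing by the belief term adapts the step size: a gradient sequence that keeps pointing the same way yields a small $s_t$ (large step), while an oscillating one yields a large $s_t$ (small step), even though Adam's $v_t$ would be identical in both cases:

```python
import torch

def belief_ema(grads, beta1=0.9, beta2=0.999):
    """EMA of (g - m)^2 -- the AdaBelief belief term -- for a 1-D gradient sequence."""
    m = torch.zeros(())
    s = torch.zeros(())
    for g in grads:
        m = beta1 * m + (1 - beta1) * g
        s = beta2 * s + (1 - beta2) * (g - m) ** 2
    return s

consistent = torch.ones(100)                    # gradient is always +1
oscillating = torch.tensor([1.0, -1.0] * 50)    # gradient flips sign every step

# Small for the consistent sequence, large for the oscillating one,
# while the EMA of g^2 (Adam's v_t) is the same for both since g^2 == 1.
print(belief_ema(consistent), belief_ema(oscillating))
```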
* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
* `param` is the parameter tensor $\theta_{t-1}$

    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
        # Calculate weight decay
        grad = self.weight_decay(param, grad, group)
        # Get $m_t$ and $s_t$
        m, s = self.get_ms(state, group, grad)
        # Increment the number of optimizer steps
        state['step'] += 1

        if not self.rectify:
            # Perform an Adam update with $s_t + \epsilon$ in place of $v_t$
            self.adam_update(state, group, param, m, s + group['eps'])
        else: