This implementation is based on the official implementation of the paper On the Variance of the Adaptive Learning Rate and Beyond.
We have implemented it in PyTorch as an extension to our AMSGrad implementation, so only the modifications need to be implemented.
The Adam optimizer sometimes converges to a bad local optimum during the initial stages of training, especially when training transformers. Researchers use warm-up to counter this: for the initial training steps (the warm-up stage) they use a low learning rate. This paper identifies the problem as the high variance of the adaptive learning rate during the initial stages of training, and counters it with a new rectification term that reduces this variance.
The paper also evaluates two variance reduction mechanisms:

* Adam-2k: only compute the adaptive learning rate ($v_t$ in Adam) during the first 2k steps, without changing parameters or calculating momentum ($m_t$).
* Adam-eps: Adam with a large $\epsilon$.
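Let $\sigma(g_1, \dots, g_t)$ and $\psi(g_1, \dots, g_t)$ be the functions that calculate the momentum and the adaptive learning rate. For Adam, they are

$$\sigma(g_1, \dots, g_t) = \frac{(1 - \beta_1)\sum_{i=1}^t \beta_1^{t-i} g_i}{1 - \beta_1^t}$$

$$\psi(g_1, \dots, g_t) = \sqrt{\frac{1 - \beta_2^t}{(1 - \beta_2)\sum_{i=1}^t \beta_2^{t-i} g_i^2}}$$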
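These closed forms are Adam's bias-corrected moments written out in full. The minimal pure-Python sketch below checks them against Adam's usual recursive updates. The helper names `sigma`, `psi` and `adam_recursive` are hypothetical, the gradients are assumed to be scalars, and $\epsilon$ is omitted.

```python
import math
from typing import List, Tuple


def sigma(grads: List[float], beta1: float) -> float:
    """Momentum: closed-form bias-corrected EMA of the gradients."""
    t = len(grads)
    s = sum(beta1 ** (t - i) * g for i, g in enumerate(grads, start=1))
    return (1 - beta1) * s / (1 - beta1 ** t)


def psi(grads: List[float], beta2: float) -> float:
    """Adaptive learning rate: inverse square root of the bias-corrected EMA of g^2."""
    t = len(grads)
    s = sum(beta2 ** (t - i) * g * g for i, g in enumerate(grads, start=1))
    return math.sqrt((1 - beta2 ** t) / ((1 - beta2) * s))


def adam_recursive(grads: List[float], beta1: float, beta2: float) -> Tuple[float, float]:
    """Same quantities from Adam's recursive updates: (m_hat, 1 / sqrt(v_hat)), no epsilon."""
    m = v = 0.0
    for g in grads:
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
    t = len(grads)
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return m_hat, 1 / math.sqrt(v_hat)


grads = [0.5, -1.2, 0.3, 0.8, -0.1]
print(sigma(grads, 0.9), psi(grads, 0.999))
print(adam_recursive(grads, 0.9, 0.999))
```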
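The distribution of the exponential moving average can be approximated by that of a simple moving average:

$$p\left(\frac{(1-\beta_2) \sum_{i=1}^t \beta_2^{t-i} g_i^2}{1 - \beta_2^t}\right) \approx p\left(\frac{\sum_{i=1}^{f(t,\beta_2)} g_{t+1-i}^2}{f(t,\beta_2)}\right)$$

Here we are taking the simple moving average of the last $f(t,\beta_2)$ gradients. $f(t,\beta_2)$ is chosen so that both averages have the same center of mass, i.e. it satisfies

$$\frac{(1-\beta_2) \sum_{i=1}^t \beta_2^{t-i} \cdot i}{1 - \beta_2^t} = \frac{\sum_{i=1}^{f(t,\beta_2)} (t+1-i)}{f(t,\beta_2)}$$

which gives

$$f(t,\beta_2) = \frac{2}{1-\beta_2} - 1 - \frac{2 t \beta_2^t}{1 - \beta_2^t}$$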
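The closed form can be checked numerically against the defining equation. The right-hand side equals $t + \frac{1 - f}{2}$, which also makes sense for a real-valued $f$. This is a small sketch with hypothetical helper names; the $t$ and $\beta_2$ values are arbitrary.

```python
import math


def f_closed(t: int, beta2: float) -> float:
    """Closed form: f(t, beta2) = 2/(1-beta2) - 1 - 2 t beta2^t / (1 - beta2^t)."""
    beta2_t = beta2 ** t
    return 2 / (1 - beta2) - 1 - 2 * t * beta2_t / (1 - beta2_t)


def ema_mean_index(t: int, beta2: float) -> float:
    """Left-hand side: bias-corrected EMA weights applied to the step indices i."""
    weighted = sum(beta2 ** (t - i) * i for i in range(1, t + 1))
    return (1 - beta2) * weighted / (1 - beta2 ** t)


def sma_mean_index(t: int, f: float) -> float:
    """Right-hand side: mean of t+1-i for i = 1..f, i.e. t + (1 - f) / 2."""
    return t + (1 - f) / 2


for beta2 in (0.9, 0.99, 0.999):
    for t in (1, 2, 10, 100, 1000):
        lhs = ema_mean_index(t, beta2)
        rhs = sma_mean_index(t, f_closed(t, beta2))
        assert math.isclose(lhs, rhs, rel_tol=1e-9), (beta2, t, lhs, rhs)
```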
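From above we have

$$p\big(\psi^2(g_1, \dots, g_t)\big) \approx p\left(\frac{f(t,\beta_2)}{\sum_{i=1}^{f(t,\beta_2)} g_{t+1-i}^2}\right)$$

where $g_i \sim \mathcal{N}(0, \sigma^2)$. Note that $\sigma$ here is the standard deviation and is different from $\sigma(\cdot)$ for momentum.

The scaled inverse chi-squared distribution is the distribution of the inverse of the mean of the squares of $\rho$ independent zero-mean normal variables:

$$\psi^2(g_1, \dots, g_t) \sim \text{Scale-inv-}\mathcal{X}^2\left(\rho, \frac{1}{\sigma^2}\right)$$

where $\rho = f(t,\beta_2)$.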
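The following Monte Carlo sketch illustrates this claim under the same assumption of i.i.d. zero-mean normal gradients. The inverse of the mean of $\rho$ squared draws should have the known mean $\frac{\rho/\sigma^2}{\rho-2}$ and variance $\frac{2\rho^2/\sigma^4}{(\rho-2)^2(\rho-4)}$ of $\text{Scale-inv-}\mathcal{X}^2(\rho, 1/\sigma^2)$. The values $\rho = 20$ and $\sigma = 0.5$, the sample count and the seed are arbitrary choices for this sketch. An integer $\rho$ is assumed, although $f(t,\beta_2)$ is real-valued in general.

```python
import random
import statistics

rng = random.Random(0)  # fixed seed so the sketch is reproducible
rho, sigma = 20, 0.5    # illustrative window length and gradient standard deviation
n_samples = 20_000


def sample_psi_sq() -> float:
    """Inverse of the mean of rho squared N(0, sigma^2) draws."""
    return rho / sum(rng.gauss(0.0, sigma) ** 2 for _ in range(rho))


samples = [sample_psi_sq() for _ in range(n_samples)]

# Mean and variance of Scale-inv-chi^2(rho, 1 / sigma^2)
mean_theory = (rho / sigma ** 2) / (rho - 2)
var_theory = (2 * rho ** 2 / sigma ** 4) / ((rho - 2) ** 2 * (rho - 4))

print(statistics.fmean(samples), mean_theory)
print(statistics.variance(samples), var_theory)
```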
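They prove that the variance of $\psi(\cdot)$ decreases as $\rho$ increases when $\psi^2(\cdot) \sim \text{Scale-inv-}\mathcal{X}^2\left(\rho, \frac{1}{\sigma^2}\right)$.

Therefore the variance is minimized at the maximal $\rho$, which is $\rho_\infty = \frac{2}{1-\beta_2} - 1$. This is the limit of $f(t,\beta_2)$ as $t \to \infty$, since $\frac{2 t \beta_2^t}{1 - \beta_2^t} \to 0$. Let the minimum variance be $C_{\text{var}}$.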
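To see that $\rho_\infty$ is the maximum, $\rho_t = f(t, \beta_2)$ can be evaluated directly. It grows monotonically with $t$ and approaches $\rho_\infty$. This is a small sketch with hypothetical helpers `rho_inf` and `rho_t`, using $\beta_2 = 0.999$ as an example value.

```python
def rho_inf(beta2: float) -> float:
    """rho_inf = 2 / (1 - beta2) - 1, the maximal SMA length."""
    return 2 / (1 - beta2) - 1


def rho_t(t: int, beta2: float) -> float:
    """rho_t = f(t, beta2) = rho_inf - 2 t beta2^t / (1 - beta2^t)."""
    beta2_t = beta2 ** t
    return rho_inf(beta2) - 2 * t * beta2_t / (1 - beta2_t)


beta2 = 0.999
print(rho_inf(beta2))  # approximately 1999
print([round(rho_t(t, beta2), 3) for t in (1, 10, 100, 1000, 10_000)])
```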
In order to ensure that the adaptive learning rate $\psi(\cdot)$ has consistent variance, we rectify the variance with $r$:

$$r = \sqrt{\frac{C_{\text{var}}}{Var\big[\psi(\cdot)\big]}}$$
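They estimate $Var[\psi(\cdot)] \approx \frac{Var[\psi^2(\cdot)]}{4\,\mathbb{E}[\psi^2(\cdot)]}$ based on a first-order expansion of $\psi(\cdot) = \sqrt{\psi^2(\cdot)}$ around $\mu = \mathbb{E}[\psi^2(\cdot)]$. Since $\sqrt{x} \approx \sqrt{\mu} + \frac{x - \mu}{2\sqrt{\mu}}$, the variance of $\sqrt{x}$ is approximately $\frac{Var[x]}{4\mu}$.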
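How good is this first-order estimate? Since $\psi(\cdot) = \frac{\sqrt{\rho}}{\sigma \chi}$, where $\chi$ follows a chi distribution with $\rho$ degrees of freedom, $Var[\psi(\cdot)]$ can also be computed exactly. This uses $\mathbb{E}[1/\chi] = \frac{\Gamma((\rho-1)/2)}{\sqrt{2}\,\Gamma(\rho/2)}$ and $\mathbb{E}[1/\chi^2] = \frac{1}{\rho-2}$. The sketch below compares the two; the helper names are hypothetical. In this comparison the ratio of exact to approximate variance is well below 1 for small $\rho$ and approaches 1 as $\rho$ grows.

```python
import math


def var_psi_exact(rho: float, sigma: float) -> float:
    """Exact Var[psi] when psi^2 ~ Scale-inv-chi^2(rho, 1/sigma^2), i.e. psi = sqrt(rho) / (sigma * chi)."""
    # E[1/chi] via log-gamma to avoid overflow for large rho
    e_inv_chi = math.exp(math.lgamma((rho - 1) / 2) - math.lgamma(rho / 2)) / math.sqrt(2)
    # Var[psi] = (rho / sigma^2) * (E[1/chi^2] - E[1/chi]^2)
    return rho / sigma ** 2 * (1 / (rho - 2) - e_inv_chi ** 2)


def var_psi_delta(rho: float, sigma: float) -> float:
    """First-order estimate Var[psi^2] / (4 E[psi^2])."""
    mean_sq = (rho / sigma ** 2) / (rho - 2)
    var_sq = (2 * rho ** 2 / sigma ** 4) / ((rho - 2) ** 2 * (rho - 4))
    return var_sq / (4 * mean_sq)


for rho in (10, 100, 1000):
    print(rho, var_psi_exact(rho, 1.0) / var_psi_delta(rho, 1.0))
```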
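From the $\text{Scale-inv-}\mathcal{X}^2$ distribution we have

$$\mathbb{E}\big[\psi^2(\cdot)\big] = \frac{\rho / \sigma^2}{\rho-2}$$

$$Var\big[\psi^2(\cdot)\big] = \frac{2 \rho^2 / \sigma^4}{(\rho-2)^2 (\rho - 4)}$$

which gives (for $\rho > 4$)

$$Var[\psi(\cdot)] \approx \frac{\rho}{2(\rho-2)(\rho-4)\sigma^2}$$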
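The last step is plain algebra, and it can be checked exactly with rational arithmetic. For several values of $\rho > 4$ and $\sigma^2$, $\frac{Var[\psi^2(\cdot)]}{4\,\mathbb{E}[\psi^2(\cdot)]}$ equals $\frac{\rho}{2(\rho-2)(\rho-4)\sigma^2}$. The helper names are hypothetical.

```python
from fractions import Fraction


def mean_psi_sq(rho: Fraction, sigma_sq: Fraction) -> Fraction:
    """E[psi^2] for Scale-inv-chi^2(rho, 1/sigma^2)."""
    return (rho / sigma_sq) / (rho - 2)


def var_psi_sq(rho: Fraction, sigma_sq: Fraction) -> Fraction:
    """Var[psi^2] for Scale-inv-chi^2(rho, 1/sigma^2)."""
    return (2 * rho ** 2 / sigma_sq ** 2) / ((rho - 2) ** 2 * (rho - 4))


def var_psi_closed(rho: Fraction, sigma_sq: Fraction) -> Fraction:
    """Closed form rho / (2 (rho - 2)(rho - 4) sigma^2)."""
    return rho / (2 * (rho - 2) * (rho - 4) * sigma_sq)


for rho in (Fraction(5), Fraction(17, 2), Fraction(1999)):
    for sigma_sq in (Fraction(1), Fraction(1, 4), Fraction(9)):
        delta = var_psi_sq(rho, sigma_sq) / (4 * mean_psi_sq(rho, sigma_sq))
        assert delta == var_psi_closed(rho, sigma_sq)
```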
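We have

$$r = \sqrt{\frac{C_{\text{var}}}{Var\big[\psi(\cdot)\big]}}, \qquad Var[\psi(\cdot)] \approx \frac{\rho}{2(\rho-2)(\rho-4)\sigma^2}$$

where $C_{\text{var}}$ is $Var\big[\psi(\cdot)\big]$ for $\rho_\infty$. Let $\rho$ at step $t$ be $\rho_t = f(t, \beta_2)$, and let $r_t$ be the rectification term at step $t$. Then

$$C_{\text{var}} \approx \frac{\rho_\infty}{2(\rho_\infty-2)(\rho_\infty-4)\sigma^2}$$

$$Var[\psi(g_1, \dots, g_t)] \approx \frac{\rho_t}{2(\rho_t-2)(\rho_t-4)\sigma^2}$$

This gives

$$r_t = \sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_\infty}{(\rho_\infty-4)(\rho_\infty-2)\rho_t}}$$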
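The following standalone sketch makes the behaviour of $r_t$ concrete. It uses a hypothetical helper `rectification_term`, which is not the class method in the implementation. $r_t$ is undefined for the first few steps, stays strictly below 1 and increases towards 1 as $\rho_t \to \rho_\infty$. The sketch uses the mathematical condition $\rho_t > 4$, whereas the implementation uses the more conservative $\rho_t \geq 5$.

```python
import math
from typing import Optional


def rectification_term(t: int, beta2: float) -> Optional[float]:
    """r_t; returns None while rho_t <= 4, where Var[psi] is not finite."""
    rho_inf = 2 / (1 - beta2) - 1
    beta2_t = beta2 ** t
    rho_t = rho_inf - 2 * t * beta2_t / (1 - beta2_t)
    if rho_t <= 4:
        return None
    return math.sqrt((rho_t - 4) * (rho_t - 2) * rho_inf
                     / ((rho_inf - 4) * (rho_inf - 2) * rho_t))


beta2 = 0.999
for t in (1, 5, 6, 10, 100, 1000, 10_000):
    print(t, rectification_term(t, beta2))
```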
```python
import math
from typing import Any, Dict, Optional

import torch

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad


class RAdam(AMSGrad):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=False,
                 degenerated_to_sgd=True, defaults=None):
        """
        * `params` is the list of parameters
        * `lr` is the learning rate α
        * `betas` is a tuple of (β1, β2)
        * `eps` is ε̂ or ε based on `optimized_update`
        * `weight_decay` is an instance of class `WeightDecay` defined in `__init__.py`
        * `optimized_update` is a flag whether to optimize the bias correction of the
          second moment by doing it after adding ε
        * `amsgrad` is a flag indicating whether to use AMSGrad or fall back to plain Adam
        * `degenerated_to_sgd` is a flag whether to fall back to SGD with momentum when the
          rectification term r_t is intractable
        * `defaults` is a dictionary of defaults for group values.
          This is useful when you want to extend the class `RAdam`.
        """
        self.degenerated_to_sgd = degenerated_to_sgd
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor,
                   param: torch.nn.Parameter):
        """
        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor g_t for the parameter θ_{t-1}
        * `param` is the parameter tensor θ_{t-1}
        """
        # Calculate weight decay
        grad = self.weight_decay(param, grad, group)
        # Get m_t and v_t, i.e. the moments without bias correction
        m, v = self.get_mv(state, group, grad)
        # Increment the number of optimizer steps t
        state['step'] += 1
        # Perform the RAdam update
        self.r_adam_update(state, group, param, m, v)

    @staticmethod
    def calc_rectification_term(beta2: float, step: int) -> Optional[float]:
        """Calculate the rectification term r_t, or return `None` if it is intractable."""
        # β2^t
        beta2_t = beta2 ** step
        # ρ_∞ = 2 / (1 - β2) - 1
        rho_inf = 2 / (1 - beta2) - 1
        # ρ_t = ρ_∞ - 2 t β2^t / (1 - β2^t)
        rho = rho_inf - 2 * step * beta2_t / (1 - beta2_t)

        # r_t is tractable when ρ_t > 4.
        # We are being a little more conservative since ρ_t is an approximated value.
        if rho >= 5:
            # r_t^2 = (ρ_t - 4)(ρ_t - 2) ρ_∞ / ((ρ_∞ - 4)(ρ_∞ - 2) ρ_t)
            r2 = (rho - 4) / (rho_inf - 4) * (rho - 2) / rho * rho_inf / (rho_inf - 2)
            return math.sqrt(r2)
        else:
            return None

    def r_adam_update(self, state: Dict[str, Any], group: Dict[str, Any], param: torch.nn.Parameter,
                      m: torch.Tensor, v: torch.Tensor):
        """
        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor θ_{t-1}
        * `m` and `v` are the uncorrected first and second moments m_t and v_t,
          i.e. the moments without bias correction
        """
        # Get β1 and β2
        beta1, beta2 = group['betas']
        # Bias correction term for m̂_t, 1 - β1^t
        bias_correction1 = 1 - beta1 ** state['step']
        # Bias correction term for v̂_t, 1 - β2^t
        bias_correction2 = 1 - beta2 ** state['step']

        r = self.calc_rectification_term(beta2, state['step'])

        # Get learning rate
        lr = self.get_lr(state, group)

        # If r_t is tractable
        if r is not None:
            # Whether to optimize the computation by combining scalar computations
            if self.optimized_update:
                # Denominator sqrt(v_t) + ε̂
                denominator = v.sqrt().add_(group['eps'])
                # Step size α r_t sqrt(1 - β2^t) / (1 - β1^t)
                step_size = lr * math.sqrt(bias_correction2) * r / bias_correction1
                # Update parameters
                param.data.addcdiv_(m, denominator, value=-step_size)
            # Computation without optimization
            else:
                # Denominator sqrt(v_t) / sqrt(1 - β2^t) + ε
                denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                # Step size α r_t / (1 - β1^t)
                step_size = lr * r / bias_correction1
                # Update parameters
                param.data.addcdiv_(m, denominator, value=-step_size)

        # If r_t is intractable, do SGD with momentum
        elif self.degenerated_to_sgd:
            # Step size α / (1 - β1^t)
            step_size = lr / bias_correction1
            # Update parameters
            param.data.add_(m, alpha=-step_size)


def _test_rectification_term():
    """Plot r_t against the step t for various β2."""
    import matplotlib.pyplot as plt
    import numpy as np

    beta2 = [0.9999, 0.999, 0.99, 0.9, 0.8, 0.6, 0.5]

    def r_or_nan(b: float, t: int) -> float:
        # Use NaN where r_t is intractable (None) so those steps are left as gaps
        r = RAdam.calc_rectification_term(b, t)
        return np.nan if r is None else r

    plt.plot(np.arange(1, 5_000), [[r_or_nan(b, i) for b in beta2] for i in range(1, 5_000)])
    plt.legend([f"β2 = {b}" for b in beta2])
    plt.xlabel("step t")
    plt.ylabel("r_t")
    plt.title("RAdam rectification term")
    plt.show()


if __name__ == '__main__':
    _test_rectification_term()
```
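The following minimal pure-Python sketch shows the whole update rule outside PyTorch for a single scalar parameter. It mirrors the non-optimized branch of `r_adam_update` and the SGD-with-momentum fallback, but assumes no weight decay and no AMSGrad. `rectification_term` and `radam_minimize` are hypothetical helpers, not part of labml_nn, and the quadratic objective is only an example.

```python
import math
from typing import Callable, List, Optional, Tuple


def rectification_term(beta2: float, step: int) -> Optional[float]:
    """Same formula and conservative rho_t >= 5 threshold as calc_rectification_term."""
    beta2_t = beta2 ** step
    rho_inf = 2 / (1 - beta2) - 1
    rho = rho_inf - 2 * step * beta2_t / (1 - beta2_t)
    if rho < 5:
        return None
    r2 = (rho - 4) / (rho_inf - 4) * (rho - 2) / rho * rho_inf / (rho_inf - 2)
    return math.sqrt(r2)


def radam_minimize(grad_fn: Callable[[float], float], x: float, lr: float = 0.1,
                   betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8,
                   steps: int = 500) -> Tuple[float, List[int]]:
    """Scalar RAdam without weight decay or AMSGrad; returns (x, steps where r_t was used)."""
    beta1, beta2 = betas
    m = v = 0.0
    rectified_steps = []
    for t in range(1, steps + 1):
        g = grad_fn(x)
        m = beta1 * m + (1 - beta1) * g      # uncorrected first moment m_t
        v = beta2 * v + (1 - beta2) * g * g  # uncorrected second moment v_t
        bias_correction1 = 1 - beta1 ** t
        bias_correction2 = 1 - beta2 ** t
        r = rectification_term(beta2, t)
        if r is not None:
            # Rectified Adam step: alpha * r_t * m_hat / (sqrt(v_hat) + eps)
            denominator = math.sqrt(v) / math.sqrt(bias_correction2) + eps
            x -= lr * r / bias_correction1 * m / denominator
            rectified_steps.append(t)
        else:
            # r_t intractable: SGD with bias-corrected momentum
            x -= lr / bias_correction1 * m
    return x, rectified_steps


# Minimize (x - 3)^2, whose gradient is 2 (x - 3)
x_final, rectified = radam_minimize(lambda x: 2 * (x - 3), x=0.0)
print(x_final, rectified[0])
```

With $\beta_2 = 0.999$, the first five steps fall back to SGD with momentum, since $\rho_5 \approx 4.996$ is below the threshold of 5.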