This implementation is based on the official implementation of the paper On the Variance of the Adaptive Learning Rate and Beyond.
We have implemented it in PyTorch as an extension of our AMSGrad implementation, so only the modifications are needed here.
The Adam optimizer sometimes converges to a bad local optimum during the initial stages of training, especially when training transformers. Researchers use warm-up to counter this: for the initial training steps (the warm-up stage), they use a low learning rate. This paper identifies the problem as the high variance of the adaptive learning rate during the initial stages of training, and counters it with a new rectification term that reduces the variance.
The paper also evaluates two variance reduction mechanisms:

* Adam-2k: Only compute the adaptive learning rate ($v_t$ in Adam) during the first 2k steps, without changing parameters or calculating momentum ($m_t$).
* Adam-eps: Adam with a large $\epsilon \approx 10^{-4}$.
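A minimal scalar sketch of Adam-2k, assuming plain Python floats rather than tensors. The function `adam_2k` and its `v_only_steps` parameter are hypothetical names, and the bias-correction bookkeeping after the warm-up (counting momentum steps separately from $v_t$ steps) is our own assumption rather than something taken from the paper.

```python
import math
from typing import List, Tuple


def adam_2k(grads: List[float], theta: float, lr: float = 1e-3,
            beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8,
            v_only_steps: int = 2000) -> Tuple[float, float, float]:
    """Scalar Adam whose first `v_only_steps` steps only accumulate v_t.

    Hypothetical helper for illustration; returns the final (theta, m, v).
    """
    m, v = 0.0, 0.0
    t_m = 0  # steps that have updated m and theta (assumed bookkeeping)
    for step, g in enumerate(grads, start=1):
        # v_t is always updated, including during the warm-up
        v = beta2 * v + (1 - beta2) * g * g
        if step <= v_only_steps:
            continue  # warm-up: no momentum and no parameter change
        t_m += 1
        m = beta1 * m + (1 - beta1) * g
        m_hat = m / (1 - beta1 ** t_m)
        v_hat = v / (1 - beta2 ** step)
        theta -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v
```

Setting `v_only_steps=0` reduces the sketch to plain Adam, and `v_only_steps=0` together with `eps=1e-4` gives Adam-eps.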
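Let $\sigma(g_1, \dots, g_t)$ and $\psi(g_1, \dots, g_t)$ be the functions that calculate the momentum and the adaptive learning rate. For Adam, they are

$$\sigma(g_1, \dots, g_t) = \frac{(1-\beta_1)\sum_{i=1}^t \beta_1^{t-i} g_i}{1-\beta_1^t}$$

$$\psi(g_1, \dots, g_t) = \sqrt{\frac{1-\beta_2^t}{(1-\beta_2)\sum_{i=1}^t \beta_2^{t-i} g_i^2}}$$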
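As a sanity check, these closed forms can be compared with Adam's usual recursive updates, where $\sigma(.) = \hat{m}_t$ and $\psi(.) = 1/\sqrt{\hat{v}_t}$ (ignoring $\epsilon$), so the Adam step is $\alpha \cdot \sigma(.) \cdot \psi(.)$. A plain-Python sketch; `sigma`, `psi` and `adam_moments` are illustrative helper names.

```python
import math
from typing import List, Tuple


def sigma(grads: List[float], beta1: float) -> float:
    # Closed form of sigma(.): bias-corrected EMA of the gradients
    t = len(grads)
    s = sum(beta1 ** (t - i) * g for i, g in enumerate(grads, start=1))
    return (1 - beta1) * s / (1 - beta1 ** t)


def psi(grads: List[float], beta2: float) -> float:
    # Closed form of psi(.): inverse square root of the bias-corrected EMA of g^2
    t = len(grads)
    s = sum(beta2 ** (t - i) * g * g for i, g in enumerate(grads, start=1))
    return math.sqrt((1 - beta2 ** t) / ((1 - beta2) * s))


def adam_moments(grads: List[float], beta1: float, beta2: float) -> Tuple[float, float]:
    # Adam's recursive updates; returns (m_hat_t, 1 / sqrt(v_hat_t))
    m = v = 0.0
    for g in grads:
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
    t = len(grads)
    return m / (1 - beta1 ** t), 1 / math.sqrt(v / (1 - beta2 ** t))
```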
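The distribution of the exponential moving average can be approximated by that of a simple moving average:

$$p\Bigg(\frac{(1-\beta_2) \sum_{i=1}^t \beta_2^{t-i} g_i^2}{1 - \beta_2^t}\Bigg) \approx p\Bigg(\frac{\sum_{i=1}^{f(t,\beta_2)} g_{t+1-i}^2}{f(t,\beta_2)}\Bigg)$$

Here we are taking the simple moving average of the last $f(t,\beta_2)$ gradients. $f(t,\beta_2)$ is chosen so that both averages have the same mean step index:

$$\frac{(1-\beta_2) \sum_{i=1}^t \beta_2^{t-i} \cdot i}{1 - \beta_2^t} = \frac{\sum_{i=1}^{f(t,\beta_2)} (t+1-i)}{f(t,\beta_2)}$$

which gives

$$f(t,\beta_2) = \frac{2}{1-\beta_2} - 1 - \frac{2 t \beta_2^t}{1 - \beta_2^t}$$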
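A quick numerical check of the closed form for $f(t,\beta_2)$, as a sketch: the left-hand side is computed directly as the EMA-weighted mean of the step indices, and the right-hand side uses $\frac{1}{n}\sum_{i=1}^{n}(t+1-i) = t + \frac{1-n}{2}$ so that $n = f(t,\beta_2)$ need not be an integer. The helper names are our own.

```python
def f(t: int, beta2: float) -> float:
    # Closed form: 2 / (1 - beta2) - 1 - 2 t beta2^t / (1 - beta2^t)
    return 2 / (1 - beta2) - 1 - 2 * t * beta2 ** t / (1 - beta2 ** t)


def ema_mean_index(t: int, beta2: float) -> float:
    # Left-hand side: EMA-weighted mean of the step indices i = 1..t
    total = sum(beta2 ** (t - i) * i for i in range(1, t + 1))
    return (1 - beta2) * total / (1 - beta2 ** t)


def sma_mean_index(t: int, n: float) -> float:
    # Right-hand side: mean of t + 1 - i over i = 1..n, in closed form
    return t + (1 - n) / 2
```

Note that $f(1, \beta_2) = 1$ for any $\beta_2$: after one step, both averages contain a single gradient.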
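From the above we have

$$p\Big(\psi^2(g_1, \dots, g_t)\Big) \approx p\Bigg(\frac{f(t,\beta_2)}{\sum_{i=1}^{f(t,\beta_2)} g_{t+1-i}^2}\Bigg)$$

where $g_i \sim \mathcal{N}(0, \sigma^2)$. Note that $\sigma$ here is the standard deviation and is different from $\sigma(.)$ for momentum.
The scaled inverse chi-squared distribution is the distribution of the inverse of the mean of the squares of $\rho$ independent zero-mean normal variables, so

$$p\Big(\psi^2(g_1, \dots, g_t)\Big) \sim \text{Scale-inv} \mathcal{X}^2(\rho,\frac{1}{\sigma^2})$$

where $\rho = f(t,\beta_2)$.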
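A Monte Carlo sketch of this claim using only the standard library: draw $\rho$ gradients from $\mathcal{N}(0, \sigma^2)$, form $\psi^2 = \rho / \sum g_i^2$, and compare the sample mean with the mean of $\text{Scale-inv} \mathcal{X}^2(\nu, \tau^2)$, which is $\frac{\nu \tau^2}{\nu - 2}$ for $\nu > 2$. The helper names and the choice $\rho = 20$, $\sigma = 2$ are arbitrary.

```python
import random


def sample_psi2(rho: int, sigma: float, rng: random.Random) -> float:
    # psi^2 under the SMA approximation: rho / (sum of rho squared N(0, sigma^2) draws)
    return rho / sum(rng.gauss(0.0, sigma) ** 2 for _ in range(rho))


def scale_inv_chi2_mean(nu: float, tau2: float) -> float:
    # Mean of Scale-inv-chi^2(nu, tau^2); finite only for nu > 2
    return nu * tau2 / (nu - 2)
```

With a seeded `random.Random` and 20,000 draws, the sample mean lands within a fraction of a percent of `scale_inv_chi2_mean(20, 1 / 2.0 ** 2)`.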
They prove that the variance of $\psi(.)$ decreases as $\rho$ increases when $\psi^2(.) \sim \text{Scale-inv} \mathcal{X}^2(\rho,\frac{1}{\sigma^2})$.
Therefore the variance is minimized at the maximal $\rho$, which is $\rho_{\infty} = \lim_{t \to \infty} f(t,\beta_2) = \frac{2}{1-\beta_2} - 1$. Let the minimum variance be $C_{\text{var}}$.
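A small sketch (helper names are our own) showing that $\rho_t = f(t,\beta_2)$ grows monotonically toward $\rho_\infty$ and never exceeds it, so the minimum variance is only reached asymptotically. It also finds the first step at which $\rho_t \ge 5$, the threshold used by `calc_rectification_term` in the implementation below.

```python
from typing import Optional


def rho_inf(beta2: float) -> float:
    # Maximal rho: 2 / (1 - beta2) - 1
    return 2 / (1 - beta2) - 1


def rho_t(t: int, beta2: float) -> float:
    # rho_t = f(t, beta2)
    return rho_inf(beta2) - 2 * t * beta2 ** t / (1 - beta2 ** t)


def first_step_with_rho_at_least(threshold: float, beta2: float,
                                 max_steps: int = 100_000) -> Optional[int]:
    # Smallest t with rho_t >= threshold, or None if not reached within max_steps
    for t in range(1, max_steps + 1):
        if rho_t(t, beta2) >= threshold:
            return t
    return None
```

For $\beta_2 = 0.999$, $\rho_\infty = 1999$, while $\rho_t \approx t$ for small $t$, so the threshold of 5 is first reached at step 6.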
To ensure that the adaptive learning rate $\psi(.)$ has consistent variance, we rectify it with a factor $r$ such that

$$Var\big[r \, \psi(.)\big] = C_{\text{var}}$$
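They estimate $Var[\psi(.)] \approx \frac{Var[\psi^2(.)]}{4 \mathbb{E}[\psi^2(.)]}$ based on the first-order expansion of $\sqrt{\psi^2(.)}$. Expanding $\sqrt{x}$ around $\mu = \mathbb{E}[\psi^2(.)]$ gives $\sqrt{x} \approx \sqrt{\mu} + \frac{x - \mu}{2\sqrt{\mu}}$, and taking the variance of both sides gives $Var\big[\sqrt{\psi^2(.)}\big] \approx \frac{Var[\psi^2(.)]}{4\mu}$.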
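To see how good the first-order estimate is, a sketch can compare it with the exact variance. Writing $\psi = \sqrt{\rho}/(\sigma \chi_\rho)$, where $\chi_\rho^2$ is chi-squared with $\rho$ degrees of freedom, we have $\mathbb{E}[1/\chi_\rho] = \frac{\Gamma((\rho-1)/2)}{\sqrt{2}\,\Gamma(\rho/2)}$, so $Var[\psi(.)] = \mathbb{E}[\psi^2(.)] - \mathbb{E}[\psi(.)]^2$ has a closed form. The estimate uses the standard moments of $\text{Scale-inv} \mathcal{X}^2(\rho,\frac{1}{\sigma^2})$; the helper names are illustrative.

```python
import math


def var_psi_exact(rho: float, sigma: float) -> float:
    # psi = sqrt(rho) / (sigma * chi_rho), with chi_rho^2 ~ chi-squared(rho)
    e_psi2 = rho / (sigma ** 2 * (rho - 2))
    # E[1 / chi_rho] = Gamma((rho - 1) / 2) / (sqrt(2) * Gamma(rho / 2))
    e_inv_chi = math.exp(math.lgamma((rho - 1) / 2) - math.lgamma(rho / 2)) / math.sqrt(2)
    e_psi = math.sqrt(rho) / sigma * e_inv_chi
    return e_psi2 - e_psi ** 2


def var_psi_first_order(rho: float, sigma: float) -> float:
    # First-order estimate Var[psi^2] / (4 E[psi^2]) using Scale-inv-chi^2 moments
    e_psi2 = rho / (sigma ** 2 * (rho - 2))
    var_psi2 = 2 * rho ** 2 / (sigma ** 4 * (rho - 2) ** 2 * (rho - 4))
    return var_psi2 / (4 * e_psi2)
```

The estimate overshoots for small $\rho$ (by roughly 14% at $\rho = 20$) and becomes accurate as $\rho$ grows.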
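From the $\text{Scale-inv} \mathcal{X}^2$ distribution we have

$$\mathbb{E}\big[\psi^2(.)\big] = \frac{\rho / \sigma^2}{\rho-2}$$

$$Var\big[\psi^2(.)\big] = \frac{2 \rho^2 / \sigma^4}{(\rho-2)^2 (\rho - 4)}$$

which gives

$$Var\big[\psi(.)\big] \approx \frac{\rho}{2(\rho-2)(\rho-4)\sigma^2}$$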
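The algebra from the two moments of $\psi^2(.)$ to the simplified variance can be checked exactly with rational arithmetic. This sketch (illustrative helper names, with `sigma2` standing for $\sigma^2$) verifies the simplification over a range of $\rho$ and confirms that the estimate decreases as $\rho$ grows, as the paper claims.

```python
from fractions import Fraction


def mean_psi2(rho: Fraction, sigma2: Fraction) -> Fraction:
    # E[psi^2] of Scale-inv-chi^2(rho, 1 / sigma^2), for rho > 2
    return (rho / sigma2) / (rho - 2)


def var_psi2(rho: Fraction, sigma2: Fraction) -> Fraction:
    # Var[psi^2] of Scale-inv-chi^2(rho, 1 / sigma^2), for rho > 4
    return 2 * rho ** 2 / sigma2 ** 2 / ((rho - 2) ** 2 * (rho - 4))


def var_psi(rho: Fraction, sigma2: Fraction) -> Fraction:
    # Simplified estimate rho / (2 (rho - 2)(rho - 4) sigma^2)
    return rho / (2 * (rho - 2) * (rho - 4) * sigma2)
```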
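We have

$$r = \sqrt{\frac{C_{\text{var}}}{Var\big[\psi(.)\big]}}$$

where $C_{\text{var}}$ is $Var\big[\psi(.)\big]$ for $\rho_\infty$. Let $\rho$ at step $t$ be $\rho_t$, and let $r_t$ be the rectification term at step $t$. This gives

$$r_t = \sqrt{\frac{(\rho_t-2)(\rho_t-4)\rho_\infty}{(\rho_\infty-2)(\rho_\infty-4)\rho_t}}$$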
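A sketch that evaluates $r_t$ as $\sqrt{C_{\text{var}} / Var[\psi(.)]}$ with the first-order variance estimate and checks it against the closed form that `calc_rectification_term` in the implementation below uses. The helpers are hypothetical; $\sigma^2$ cancels in the ratio, so it defaults to 1.

```python
import math
from typing import Optional


def rho_t(t: int, beta2: float) -> float:
    # rho_t = f(t, beta2)
    rho_inf = 2 / (1 - beta2) - 1
    return rho_inf - 2 * t * beta2 ** t / (1 - beta2 ** t)


def var_psi(rho: float, sigma2: float = 1.0) -> float:
    # First-order estimate of Var[psi(.)]
    return rho / (2 * (rho - 2) * (rho - 4) * sigma2)


def rectification(t: int, beta2: float) -> Optional[float]:
    # Closed form of r_t; None when rho_t <= 4 because Var[psi] is then unbounded
    rho_inf = 2 / (1 - beta2) - 1
    rho = rho_t(t, beta2)
    if rho <= 4:
        return None
    return math.sqrt((rho - 2) * (rho - 4) * rho_inf / ((rho_inf - 2) * (rho_inf - 4) * rho))
```

Because $\rho_t < \rho_\infty$, $r_t$ stays below 1 and approaches 1 as $t$ grows, which shrinks the early Adam steps much like a learning-rate warm-up.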
import math
from typing import Any, Dict, Optional

import torch

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad


class RAdam(AMSGrad):
    """Rectified Adam, implemented as an extension of `AMSGrad`"""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=False,
                 degenerated_to_sgd=True, defaults=None):
        r"""
        * `params` is the list of parameters
        * `lr` is the learning rate $\alpha$
        * `betas` is a tuple of ($\beta_1$, $\beta_2$)
        * `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
        * `weight_decay` is an instance of class `WeightDecay` defined in `__init__.py`
        * `optimized_update` is a flag whether to optimize the computation by combining scalar computations
        * `amsgrad` is a flag indicating whether to use AMSGrad or fall back to plain Adam
        * `degenerated_to_sgd` is whether to use SGD with momentum when the rectification term $r_t$ is intractable
        * `defaults` is a dictionary of defaults for group values.
          This is useful when you want to extend the class `RAdam`.
        """
        self.degenerated_to_sgd = degenerated_to_sgd
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
        r"""
        Take an update step for a given parameter tensor

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        * `param` is the parameter tensor $\theta_{t-1}$
        """
        # Calculate weight decay
        grad = self.weight_decay(param, grad, group)
        # Get $m_t$ and $v_t$; i.e. $\sigma(.)$ and $\psi(.)$ without bias correction
        m, v = self.get_mv(state, group, grad)
        # Calculate $t$, the number of optimizer steps
        state['step'] += 1
        # Perform the RAdam update
        self.r_adam_update(state, group, param, m, v)

    @staticmethod
    def calc_rectification_term(beta2: float, step: int) -> Optional[float]:
        """Calculate the rectification term $r_t$, or `None` if it is intractable"""
        # $\beta_2^t$
        beta2_t = beta2 ** step
        # $\rho_\infty = \frac{2}{1 - \beta_2} - 1$
        rho_inf = 2 / (1 - beta2) - 1
        # $\rho_t = \rho_\infty - \frac{2 t \beta_2^t}{1 - \beta_2^t}$
        rho = rho_inf - 2 * step * beta2_t / (1 - beta2_t)
        # $r_t$ is tractable when $\rho_t > 4$. We are a little more conservative
        # and require $\rho_t \ge 5$, since $\rho_t$ is an approximation
        if rho >= 5:
            # $r_t^2 = \frac{(\rho_t-2)(\rho_t-4)\rho_\infty}{(\rho_\infty-2)(\rho_\infty-4)\rho_t}$
            r2 = (rho - 4) / (rho_inf - 4) * (rho - 2) / rho * rho_inf / (rho_inf - 2)
            return math.sqrt(r2)
        else:
            return None

    def r_adam_update(self, state: Dict[str, Any], group: Dict[str, Any], param: torch.nn.Parameter,
                      m: torch.Tensor, v: torch.Tensor):
        r"""
        Do the RAdam parameter update

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$
        * `m` and `v` are the uncorrected first and second moments $m_t$ and $v_t$;
          i.e. $\sigma(.)$ and $\psi(.)$ without bias correction
        """
        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']
        # Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$
        bias_correction1 = 1 - beta1 ** state['step']
        # Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$
        bias_correction2 = 1 - beta2 ** state['step']

        # Rectification term $r_t$, or `None` if it is intractable
        r = self.calc_rectification_term(beta2, state['step'])

        # Get learning rate
        lr = self.get_lr(state, group)

        # If $r_t$ is tractable
        if r is not None:
            # Whether to optimize the computation by combining scalar computations
            if self.optimized_update:
                # Denominator $\sqrt{v_t} + \hat{\epsilon}$
                denominator = v.sqrt().add_(group['eps'])
                # Step size $\alpha r_t \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$
                step_size = lr * math.sqrt(bias_correction2) * r / bias_correction1
                # Update parameters
                # $\theta_t \leftarrow \theta_{t-1} - \alpha r_t \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
                param.data.addcdiv_(m, denominator, value=-step_size)
            # Computation without optimization
            else:
                # Denominator $\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$
                denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                # Step size $\frac{\alpha r_t}{1-\beta_1^t}$
                step_size = lr * r / bias_correction1
                # Update parameters $\theta_t \leftarrow \theta_{t-1} - \alpha r_t \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$
                param.data.addcdiv_(m, denominator, value=-step_size)
        # If $r_t$ is intractable, do SGD with momentum
        elif self.degenerated_to_sgd:
            # Step size $\frac{\alpha}{1-\beta_1^t}$
            step_size = lr / bias_correction1
            # Update parameters $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \hat{m}_t$
            param.data.add_(m, alpha=-step_size)


def _test_rectification_term():
    r"""Plot the rectification term $r_t$ against the step $t$ for several values of $\beta_2$"""
    import matplotlib.pyplot as plt
    import numpy as np

    beta2 = [0.9999, 0.999, 0.99, 0.9, 0.8, 0.6, 0.5]
    # Intractable steps (`None`) become NaN and are left out of the plot
    r = np.array([[RAdam.calc_rectification_term(b, i) for b in beta2] for i in range(1, 5_000)], dtype=float)
    plt.plot(np.arange(1, 5_000), r)
    plt.legend(beta2)
    plt.title("Rectification term $r_t$")
    plt.show()


if __name__ == '__main__':
    _test_rectification_term()
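For readers without `labml_nn` installed, the update rule can be mirrored on a single float. This is a sketch, not the library code: `radam_minimize` is a hypothetical helper that follows `r_adam_update` with `optimized_update=False` and `degenerated_to_sgd=True`, without weight decay or AMSGrad, and it is run here on the toy objective $(\theta - 3)^2$.

```python
import math
from typing import Callable, List


def radam_minimize(grad_fn: Callable[[float], float], theta: float, steps: int,
                   lr: float = 0.1, beta1: float = 0.9, beta2: float = 0.999,
                   eps: float = 1e-8) -> List[float]:
    """Hypothetical scalar mirror of r_adam_update (optimized_update=False,
    degenerated_to_sgd=True, no weight decay, no AMSGrad).
    Returns theta after every step."""
    rho_inf = 2 / (1 - beta2) - 1
    m = v = 0.0
    history = []
    for t in range(1, steps + 1):
        g = grad_fn(theta)
        # Uncorrected first and second moments m_t and v_t
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        bias_correction1 = 1 - beta1 ** t
        bias_correction2 = 1 - beta2 ** t
        rho = rho_inf - 2 * t * beta2 ** t / bias_correction2
        if rho >= 5:
            # Rectified Adam step: alpha * r_t * m_hat / (sqrt(v_hat) + eps)
            r = math.sqrt((rho - 4) / (rho_inf - 4) * (rho - 2) / rho * rho_inf / (rho_inf - 2))
            denominator = math.sqrt(v) / math.sqrt(bias_correction2) + eps
            theta -= lr * r / bias_correction1 * m / denominator
        else:
            # r_t intractable: SGD with momentum, theta -= alpha * m_hat
            theta -= lr / bias_correction1 * m
        history.append(theta)
    return history
```

For example, `radam_minimize(lambda x: 2 * (x - 3), theta=0.0, steps=2000)` ends close to 3. During the first few steps $\rho_t < 5$, so the sketch falls back to SGD with momentum; afterwards the Adam step is scaled by $r_t$, which starts small and approaches 1.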