Rectified Adam (RAdam) optimizer

This implementation is based on the official implementation of the paper On the Variance of the Adaptive Learning Rate and Beyond.

We have implemented it as an extension to our AMSGrad implementation, so we only need to implement the RAdam-specific modifications here.

The Adam optimizer sometimes converges to bad local optima during the initial stages of training, especially when training transformers. Researchers use warmups to counter this: for the initial training steps (the warm-up stage) they use a low learning rate. This paper identifies the problem as the high variance of the adaptive learning rate during the initial stages of training, and counters it with a new rectification term that reduces the variance.

The paper also evaluates two variance reduction mechanisms:

  • Adam-2k: Only compute the adaptive learning rate ($v_t$ in Adam) during the first 2k steps, without changing parameters or calculating momentum ($m_t$).
  • Adam-eps: Adam with a large $\epsilon \approx 10^{-4}$.

Rectified Adam

Let $\sigma(g_1, \dots, g_t)$ and $\psi(g_1, \dots, g_t)$ be the functions to calculate the momentum and the adaptive learning rate. For Adam, they are
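Following the paper's definitions (with the bias corrections folded in),

$$\sigma(g_1, \dots, g_t) = \frac{(1 - \beta_1) \sum_{i=1}^{t} \beta_1^{t-i} g_i}{1 - \beta_1^{t}}$$

$$\psi(g_1, \dots, g_t) = \sqrt{\frac{1 - \beta_2^{t}}{(1 - \beta_2) \sum_{i=1}^{t} \beta_2^{t-i} g_i^{2}}}$$

so that the update at step $t$ is $\theta_t = \theta_{t-1} - \alpha_t \, \psi(g_1, \dots, g_t) \, \sigma(g_1, \dots, g_t)$.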

Exponential moving average as simple moving average

The distribution of the exponential moving average can be approximated as that of a simple moving average. Here we take the simple moving average of the last $f(t,\beta_2)$ gradients, where $f(t,\beta_2)$ satisfies the following condition.
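The paper picks $f(t, \beta_2)$ so that the simple moving average has the same "center of mass" as the exponential moving average, i.e.

$$\frac{(1-\beta_2) \sum_{i=1}^{t} \beta_2^{t-i} \cdot i}{1 - \beta_2^{t}} = \frac{\sum_{i=t+1-f(t,\beta_2)}^{t} i}{f(t, \beta_2)}$$

which gives

$$f(t, \beta_2) = \frac{2}{1 - \beta_2} - 1 - \frac{2 t \beta_2^{t}}{1 - \beta_2^{t}}$$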

Scaled inverse chi-squared

From the above we have the approximation below, where $g_i \sim \mathcal{N}(0, \sigma^2)$. Note that $\sigma$ here is the standard deviation of the gradients and is different from the $\sigma(.)$ used for the momentum.
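That is, approximating the exponential moving average inside $\psi^2(.)$ by the simple moving average of the last $f(t,\beta_2)$ squared gradients,

$$\psi^2(g_1, \dots, g_t) \approx \frac{f(t,\beta_2)}{\sum_{i=t+1-f(t,\beta_2)}^{t} g_i^2}$$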

Scaled inverse chi-squared is the distribution of the reciprocal of the mean of $p$ squared normal variables: if $X_1, \dots, X_p \sim \mathcal{N}(0, \sigma^2)$, then $\frac{p}{\sum_i X_i^2} \sim \text{Scale-inv-}\chi^2\big(p, \frac{1}{\sigma^2}\big)$. Therefore $\psi^2(.) \sim \text{Scale-inv-}\chi^2\big(\rho, \frac{1}{\sigma^2}\big)$, where $\rho = f(t,\beta_2)$.

Rectification

They prove that the variance of $\psi(.)$ decreases monotonically as $\rho$ increases, when $\psi^2(.) \sim \text{Scale-inv-}\chi^2\big(\rho,\frac{1}{\sigma^2}\big)$.

Therefore the variance is minimized at the maximal $\rho$, which is $\rho_{\infty} = \frac{2}{1-\beta_2} - 1$. Let the minimum variance be $C_{\text{var}}$.

In order to ensure that the adaptive learning rate $\psi(.)$ has consistent variance, we rectify the variance with $r$.
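That is, $r$ is chosen so that the rectified term $r \, \psi(.)$ has the constant variance $C_{\text{var}}$:

$$r = \sqrt{\frac{C_{\text{var}}}{Var\big[\psi(.)\big]}}$$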

Approximating $Var[\psi(.)]$

They estimate $Var\big[\psi(.)\big] \approx \frac{Var\big[\psi^2(.)\big]}{4 \mathbb{E}\big[\psi^2(.)\big]}$ based on the first-order Taylor expansion of $\sqrt{\psi^2(.)}$ around $\mathbb{E}\big[\psi^2(.)\big]$ (the delta method): $Var\big[\sqrt{X}\big] \approx \Big(\frac{1}{2\sqrt{\mathbb{E}[X]}}\Big)^2 Var[X] = \frac{Var[X]}{4\,\mathbb{E}[X]}$.

From the $\text{Scale-inv-}\chi^2$ distribution we have the following.
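For $\text{Scale-inv-}\chi^2\big(\rho, \frac{1}{\sigma^2}\big)$ these are the standard moments (the variance is finite for $\rho > 4$):

$$\mathbb{E}\big[\psi^2(.)\big] = \frac{\rho}{(\rho - 2)\,\sigma^2}, \qquad Var\big[\psi^2(.)\big] = \frac{2 \rho^2}{(\rho - 2)^2 (\rho - 4)\,\sigma^4}$$

which, with the first-order approximation above, gives

$$Var\big[\psi(.)\big] \approx \frac{\rho}{2 (\rho - 2)(\rho - 4)\,\sigma^2}$$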

Rectification term

We have
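Using the approximation of $Var\big[\psi(.)\big]$ above, evaluated at $\rho$ and at $\rho_\infty$ (the $\sigma^2$ factors cancel),

$$\frac{C_{\text{var}}}{Var\big[\psi(.)\big]} \approx \frac{\rho_\infty}{2 (\rho_\infty - 2)(\rho_\infty - 4)\,\sigma^2} \cdot \frac{2 (\rho - 2)(\rho - 4)\,\sigma^2}{\rho} = \frac{(\rho - 2)(\rho - 4)\,\rho_\infty}{(\rho_\infty - 2)(\rho_\infty - 4)\,\rho}$$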

where $C_{\text{var}}$ is $Var\big[\psi(.)\big]$ for $\rho_\infty$. Let $\rho$ at step $t$ be $\rho_t$, and let $r_t$ be the rectification term at step $t$.

This gives,
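This matches the computation in calc_rectification_term below:

$$r_t = \sqrt{\frac{(\rho_t - 4)(\rho_t - 2)\,\rho_\infty}{(\rho_\infty - 4)(\rho_\infty - 2)\,\rho_t}}$$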

import math
from typing import Dict, Optional

import torch

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad

Rectified Adam Optimizer

This class extends the AMSGrad optimizer defined in amsgrad.py.

class RAdam(AMSGrad):

Initialize the optimizer

  • params is the list of parameters
  • lr is the learning rate $\alpha$
  • betas is a tuple of ($\beta_1$, $\beta_2$)
  • eps is $\hat{\epsilon}$ or $\epsilon$ based on optimized_update
  • weight_decay is an instance of class WeightDecay defined in __init__.py
  • optimized_update is a flag whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$
  • amsgrad is a flag indicating whether to use AMSGrad or fallback to plain Adam
  • degenerated_to_sgd is a flag whether to fall back to SGD when the rectification term $r_t$ is intractable
  • defaults is a dictionary of defaults for group values. This is useful when you want to extend the class RAdam.
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=False,
                 degenerated_to_sgd=True, defaults=None):
        self.degenerated_to_sgd = degenerated_to_sgd
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
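As an illustrative usage sketch (not part of radam.py): this assumes the standard torch.optim.Optimizer interface inherited through the AMSGrad base class, that the module is importable as labml_nn.optimizers.radam, and a hypothetical model and random data.

import torch
from labml_nn.optimizers.radam import RAdam

# Hypothetical model and random data, for illustration only
model = torch.nn.Linear(16, 4)
optimizer = RAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

x, y = torch.randn(64, 16), torch.randn(64, 4)
for _ in range(10):
    optimizer.zero_grad()                             # clear accumulated gradients
    loss = torch.nn.functional.mse_loss(model(x), y)  # compute the loss
    loss.backward()                                   # compute gradients g_t
    optimizer.step()                                  # RAdam update for each parameter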

Take an update step for a given parameter tensor

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
  • param is the parameter tensor $\theta_{t-1}$
    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):

Calculate weight decay

        grad = self.weight_decay(param, grad, group)

Get $m_t$ and $v_t$; i.e. $\sigma(.)$ and $\psi(.)$ without bias correction

        m, v = self.get_mv(state, group, grad)

Increment $t$, the number of optimizer steps

        state['step'] += 1

Perform RAdam update

        self.r_adam_update(state, group, param, m, v)

Calculate rectification term $r_t$

    @staticmethod
    def calc_rectification_term(beta2: float, step: int) -> Optional[float]:

$\beta_2^t$

        beta2_t = beta2 ** step

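$\rho_\infty = \frac{2}{1 - \beta_2} - 1$, the maximum length of the approximated SMA
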
        rho_inf = 2 / (1 - beta2) - 1

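$\rho_t = \rho_\infty - \frac{2 t \beta_2^t}{1 - \beta_2^t}$, the length of the approximated SMA at step $t$
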
        rho = rho_inf - 2 * step * beta2_t / (1 - beta2_t)

$r_t$ is tractable when $\rho_t \geq 4$. We are being a little more conservative since it's an approximated value.

        if rho >= 5:

            r2 = (rho - 4) / (rho_inf - 4) * (rho - 2) / rho * rho_inf / (rho_inf - 2)
            return math.sqrt(r2)
        else:
            return None

Do the RAdam parameter update

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • param is the parameter tensor $\theta_{t-1}$
  • m and v are the uncorrected first and second moments $m_t$ and $v_t$; i.e. $\sigma(.)$ and $\psi(.)$ without bias correction
    def r_adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
                      m: torch.Tensor, v: torch.Tensor):

Get $\beta_1$ and $\beta_2$

        beta1, beta2 = group['betas']

Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$

        bias_correction1 = 1 - beta1 ** state['step']

Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$

        bias_correction2 = 1 - beta2 ** state['step']

        r = self.calc_rectification_term(beta2, state['step'])

Get learning rate

        lr = self.get_lr(state, group)

If $r_t$ is tractable

        if r is not None:

Whether to optimize the computation by combining scalar computations

            if self.optimized_update:

Denominator $\sqrt{v_t} + \hat{\epsilon}$

                denominator = v.sqrt().add_(group['eps'])

Step size $\alpha r_t \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$

                step_size = lr * math.sqrt(bias_correction2) * r / bias_correction1

Update parameters $\theta_t \leftarrow \theta_{t-1} - \alpha r_t \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$

                param.data.addcdiv_(m, denominator, value=-step_size)

Computation without optimization

            else:

Denominator $\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$

                denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

Step size $\frac{\alpha r_t}{1-\beta_1^t}$

                step_size = lr * r / bias_correction1

Update parameters $\theta_t \leftarrow \theta_{t-1} - \alpha r_t \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$

                param.data.addcdiv_(m, denominator, value=-step_size)

If $r_t$ is intractable, fall back to SGD with momentum

        elif self.degenerated_to_sgd:

Step size $\frac{\alpha}{1-\beta_1^t}$

            step_size = lr / bias_correction1

Update parameters $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \hat{m}_t$

            param.data.add_(m, alpha=-step_size)

Plot $r_t$ against $t$ for various $\beta_2$


def _test_rectification_term():
    import matplotlib.pyplot as plt
    import numpy as np

    beta2 = [0.9999, 0.999, 0.99, 0.9, 0.8, 0.6, 0.5]
    plt.plot(np.arange(1, 5_000), [[RAdam.calc_rectification_term(b, i) for b in beta2] for i in range(1, 5_000)])
    plt.legend(beta2)
    plt.title("Optimizer")
    plt.show()


if __name__ == '__main__':
    _test_rectification_term()