This implementation is based on the official implementation of the paper On the Variance of the Adaptive Learning Rate and Beyond.
We have implemented it as an extension to our AMSGrad implementation, so we only need to implement the modifications that RAdam makes.
The Adam optimizer sometimes converges to a bad local optimum during the initial stages of training, especially when training transformers. Researchers use warm-ups to counter this: for the initial training steps (the warm-up stage) they use a low learning rate. This paper identifies the problem as the high variance of the adaptive learning rate during the initial stages of training, and counters it with a new rectification term that reduces the variance.
The paper also evaluates two variance reduction mechanisms:

* Adam-2k: Only compute the adaptive learning rate ($v_t$ in Adam) during the first 2k steps, without changing parameters or calculating momentum ($m_t$).
* Adam-eps: Adam with large $\epsilon \approx 10^{-4}$.
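Adam-eps needs no new optimizer code; it is plain Adam with a larger $\epsilon$. A minimal sketch using PyTorch's built-in optimizer (the linear model here is just a placeholder):

import torch

model = torch.nn.Linear(10, 1)  # placeholder model
# Adam-eps: standard Adam with a large epsilon, which keeps the adaptive
# learning rate (and its variance) small during the early steps
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-4)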
Let $\sigma(g_1, \dots, g_t)$ and $\psi(g_1, \dots, g_t)$ be the functions to calculate momentum and adaptive learning rate. For Adam, they are

$$\sigma(g_1, \dots, g_t) = \frac{(1 - \beta_1) \sum_{i=1}^t \beta_1^{t-i} g_i}{1 - \beta_1^t}$$

$$\psi(g_1, \dots, g_t) = \sqrt{\frac{1 - \beta_2^t}{(1-\beta_2) \sum_{i=1}^t \beta_2^{t-i} g_i^2}}$$
The distribution of the exponential moving average can be approximated as a simple moving average,

$$p\Bigg(\frac{(1-\beta_2) \sum_{i=1}^t \beta_2^{t-i} g_i^2}{1 - \beta_2^t} \Bigg) \approx p\Bigg(\frac{\sum_{i=1}^{f(t,\beta_2)} g_{t+1-i}^2}{f(t,\beta_2)} \Bigg)$$

Here we are taking the simple moving average of the last $f(t,\beta_2)$ gradients. $f(t,\beta_2)$ satisfies the following,

$$\frac{(1-\beta_2) \sum_{i=1}^t \beta_2^{t-i} \cdot i}{1 - \beta_2^t} = \frac{\sum_{i=1}^{f(t,\beta_2)} (t+1-i)}{f(t,\beta_2)}$$

which gives,

$$f(t,\beta_2) = \frac{2}{1-\beta_2} - 1 - \frac{2 t \beta_2^t}{1 - \beta_2^t}$$
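As a quick check of this formula (not from the paper, just substitution): at $t = 1$ the effective window is a single gradient, and as $t \to \infty$ the last term vanishes, so the window approaches its maximum length,

$$f(1,\beta_2) = \frac{2}{1-\beta_2} - 1 - \frac{2\beta_2}{1-\beta_2} = 1, \qquad \lim_{t \to \infty} f(t,\beta_2) = \frac{2}{1-\beta_2} - 1$$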
From the above we have

$$\psi^2(g_1, \dots, g_t) \approx \frac{f(t,\beta_2)}{\sum_{i=1}^{f(t,\beta_2)} g_{t+1-i}^2}$$

where $g_i \sim \mathcal{N}(0, \sigma^2)$. Note that $\sigma$ here is the standard deviation and different from $\sigma(.)$ for momentum.
Scaled inverse chi-squared is the distribution of the reciprocal of the mean of the squares of $p$ normal random variables. Hence,

$$\psi^2(g_1, \dots, g_t) \sim \text{Scale-inv} \mathcal{X}^2\Big(\rho, \frac{1}{\sigma^2}\Big)$$

where $\rho = f(t,\beta_2)$.
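To see why this holds: since $g_i \sim \mathcal{N}(0, \sigma^2)$, the sum $\sum_{i=1}^{\rho} g_{t+1-i}^2 / \sigma^2$ follows a $\mathcal{X}^2$ distribution with $\rho$ degrees of freedom (treating $\rho$ as an integer for this argument), so

$$\psi^2(.) \approx \frac{\rho}{\sum_{i=1}^{\rho} g_{t+1-i}^2} = \frac{1}{\sigma^2} \cdot \frac{\rho}{\sum_{i=1}^{\rho} g_{t+1-i}^2 / \sigma^2} \sim \text{Scale-inv} \mathcal{X}^2\Big(\rho, \frac{1}{\sigma^2}\Big)$$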
They prove that the variance of $\psi(.)$ decreases with $\rho$ when $\psi^2(.) \sim \text{Scale-inv} \mathcal{X}^2(\rho,\frac{1}{\sigma^2})$.
Therefore the variance is minimized at the maximal $\rho$, which is $\rho_{\infty} = \frac{2}{1-\beta_2} - 1$. Let the minimum variance be $C_{\text{var}}$.
In order to ensure that the adaptive learning rate $\psi(.)$ has consistent variance, we rectify the variance with $r$

$$r = \sqrt{\frac{C_{\text{var}}}{Var\big[\psi(.)\big]}}$$
They estimate $Var[\psi(.)] \approx \frac{Var[\psi^2(.)]}{4 \mathbb{E}[\psi^2(.)]}$ based on the first-order expansion of $\sqrt{\psi^2(.)}$.
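This is the standard first-order (delta method) argument: expanding $\sqrt{x}$ around $\mu = \mathbb{E}\big[\psi^2(.)\big]$,

$$\psi(.) \approx \sqrt{\mu} + \frac{\psi^2(.) - \mu}{2\sqrt{\mu}} \quad \implies \quad Var\big[\psi(.)\big] \approx \frac{Var\big[\psi^2(.)\big]}{4 \, \mathbb{E}\big[\psi^2(.)\big]}$$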
From the $\text{Scale-inv} \mathcal{X}^2$ distribution we have,

$$\mathbb{E}\big[\psi^2(.)\big] = \frac{\rho / \sigma^2}{\rho-2}$$

$$Var\big[\psi^2(.)\big] = \frac{2 \rho^2 / \sigma^4}{(\rho-2)^2 (\rho - 4)}$$

which gives,

$$Var\big[\psi(.)\big] \approx \frac{\rho}{2(\rho-2)(\rho-4)\sigma^2}$$
We have

$$r = \sqrt{\frac{C_{\text{var}}}{Var\big[\psi(.)\big]}}$$

where $C_{\text{var}}$ is $Var\big[\psi(.)\big]$ for $\rho_\infty$. Let $\rho$ at step $t$ be $\rho_t$, and $r_t$ be the rectification term at step $t$.

$$C_{\text{var}} \approx \frac{\rho_\infty}{2(\rho_\infty-2)(\rho_\infty-4)\sigma^2}$$

$$Var\big[\psi(.)\big] \approx \frac{\rho_t}{2(\rho_t-2)(\rho_t-4)\sigma^2}$$
This gives,

$$r_t = \sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_\infty}{(\rho_\infty-4)(\rho_\infty-2) \rho_t}}$$
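Two properties worth noting (they follow directly from this formula): $\rho_t \to \rho_\infty$ as $t \to \infty$, so the rectification fades away after warm-up,

$$\lim_{t \to \infty} r_t = 1$$

while during the first few steps $\rho_t$ is small (e.g. $\rho_1 = 1$ for any $\beta_2$), the variance is unbounded and $r_t$ is undefined; the implementation below falls back to SGD with momentum in that case.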
import math
from typing import Dict, Optional

import torch

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad


class RAdam(AMSGrad):
* params is the list of parameters
* lr is the learning rate $\alpha$
* betas is a tuple of ($\beta_1$, $\beta_2$)
* eps is $\hat{\epsilon}$ or $\epsilon$ based on optimized_update
* weight_decay is an instance of class WeightDecay defined in __init__.py
* amsgrad is a flag indicating whether to use AMSGrad or fallback to plain Adam
* degenerated_to_sgd is a flag indicating whether to use SGD when the rectification term $r_t$ is intractable
* defaults is a dictionary of default values for parameter groups. This is useful when you want to extend the class RAdam.
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=False,
                 degenerated_to_sgd=True, defaults=None):
        self.degenerated_to_sgd = degenerated_to_sgd
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
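A minimal usage sketch (assuming this module is importable as labml_nn.optimizers.radam; the linear model and random data are just placeholders):

import torch
from labml_nn.optimizers.radam import RAdam

# Placeholder model and data
model = torch.nn.Linear(10, 1)
optimizer = RAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

x, y = torch.randn(16, 10), torch.randn(16, 1)
loss = torch.nn.functional.mse_loss(model(x), y)

# Standard PyTorch optimizer loop: zero gradients, backpropagate, take a step
optimizer.zero_grad()
loss.backward()
optimizer.step()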
* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
* param is the parameter tensor $\theta_{t-1}$

    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
Calculate weight decay
        grad = self.weight_decay(param, grad, group)
Get $m_t$ and $v_t$; i.e. $\sigma(.)$ and $\psi(.)$ without bias correction
        m, v = self.get_mv(state, group, grad)
Calculate $t$, the number of optimizer steps
        state['step'] += 1
Perform RAdam update
        self.r_adam_update(state, group, param, m, v)
Calculate the rectification term $r_t$
    @staticmethod
    def calc_rectification_term(beta2: float, step: int) -> Optional[float]:
$\beta_2^t$
        beta2_t = beta2 ** step
$\rho_\infty = \frac{2}{1-\beta_2} - 1$
        rho_inf = 2 / (1 - beta2) - 1
$\rho_t = \rho_\infty - \frac{2 t \beta_2^t}{1 - \beta_2^t}$
        rho = rho_inf - 2 * step * beta2_t / (1 - beta2_t)
$r_t$ is tractable when $\rho_t \geq 4$. We are being a little more conservative since it’s an approximated value
        if rho >= 5:
$r_t^2 = \frac{(\rho_t-4)(\rho_t-2)\rho_\infty}{(\rho_\infty-4)(\rho_\infty-2)\rho_t}$
            r2 = (rho - 4) / (rho_inf - 4) * (rho - 2) / rho * rho_inf / (rho_inf - 2)
            return math.sqrt(r2)
        else:
            return None
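A quick way to see the warm-up behaviour (illustrative; assumes the RAdam class above is in scope): the term is None for the first few steps and then rises towards $1$.

for step in (1, 4, 100, 1_000, 10_000):
    # None while rho_t < 5; afterwards the rectification term increases towards 1
    print(step, RAdam.calc_rectification_term(0.999, step))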
* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* param is the parameter tensor $\theta_{t-1}$
* m and v are the uncorrected first and second moments $m_t$ and $v_t$; i.e. $\sigma(.)$ and $\psi(.)$ without bias correction

    def r_adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
                      m: torch.Tensor, v: torch.Tensor):
Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']
Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$
        bias_correction1 = 1 - beta1 ** state['step']
Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$
        bias_correction2 = 1 - beta2 ** state['step']

Calculate the rectification term $r_t$
        r = self.calc_rectification_term(beta2, state['step'])
Get learning rate
        lr = self.get_lr(state, group)
If $r_t$ is tractable
        if r is not None:
Whether to optimize the computation by combining scalar computations
            if self.optimized_update:
Denominator $\sqrt{v_t} + \hat{\epsilon}$
                denominator = v.sqrt().add_(group['eps'])
Step size $\alpha r_t \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$
                step_size = lr * math.sqrt(bias_correction2) * r / bias_correction1
Update parameters $\theta_t \leftarrow \theta_{t-1} - \alpha r_t \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
                param.data.addcdiv_(m, denominator, value=-step_size)
Computation without optimization
            else:
Denominator $\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$
                denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
Step size $\frac{\alpha r_t}{1-\beta_1^t}$
                step_size = lr * r / bias_correction1
Update parameters $\theta_t \leftarrow \theta_{t-1} - \alpha r_t \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$
                param.data.addcdiv_(m, denominator, value=-step_size)
If $r_t$ is intractable, do an SGD update with momentum
        elif self.degenerated_to_sgd:
Step size $\frac{\alpha}{1-\beta_1^t}$
            step_size = lr / bias_correction1
Update parameters $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \hat{m}_t$
            param.data.add_(m, alpha=-step_size)
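Note that the two rectified branches above compute the same update; with $\hat{m}_t = \frac{m_t}{1-\beta_1^t}$ and $\hat{v}_t = \frac{v_t}{1-\beta_2^t}$,

$$\alpha r_t \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}} = \alpha r_t \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \frac{\hat{\epsilon}}{\sqrt{1-\beta_2^t}}}$$

so the optimized path folds the scalar factors into a single step size, and the two branches differ only in whether $\hat{\epsilon}$ or $\epsilon = \frac{\hat{\epsilon}}{\sqrt{1-\beta_2^t}}$ is added to the denominator.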
Plot the rectification term $r_t$ against the number of steps $t$ for different values of $\beta_2$

def _test_rectification_term():
    import matplotlib.pyplot as plt
    import numpy as np

    beta2 = [0.9999, 0.999, 0.99, 0.9, 0.8, 0.6, 0.5]
    plt.plot(np.arange(1, 5_000), [[RAdam.calc_rectification_term(b, i) for b in beta2] for i in range(1, 5_000)])
    plt.legend(beta2)
    plt.title("Optimizer")
    plt.show()


if __name__ == '__main__':
    _test_rectification_term()