This implementation is based on the official implementation of the paper *On the Variance of the Adaptive Learning Rate and Beyond*. We implement it in PyTorch as an extension to our AMSGrad implementation, so we only need to add the RAdam-specific modifications.
The Adam optimizer sometimes converges to a bad local optimum during the initial stages of training, especially when training transformers. Researchers use warm-ups to counter this: for the initial training steps (the warm-up stage) they use a low learning rate. This paper identifies the problem as the high variance of the adaptive learning rate during the early stages of training, and counters it with a new rectification term that reduces the variance.
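For reference, the warm-up heuristic looks roughly like the following; a minimal sketch using PyTorch's built-in `Adam` and `LambdaLR` (the model and the warm-up length are placeholders, not values from the paper):

```python
import torch

# Linearly ramp the learning rate over the first `warmup_steps` updates,
# then keep it constant. This is the heuristic RAdam aims to make unnecessary.
model = torch.nn.Linear(16, 1)                      # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
warmup_steps = 2_000                                # illustrative warm-up length
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps))

# In the training loop, call optimizer.step() and then scheduler.step().
```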
The paper also evaluates two variance reduction mechanisms:

* Adam-2k: Only compute the adaptive learning rate ($v_t$ in Adam) during the first 2k steps, without changing parameters or calculating momentum ($m_t$).
* Adam-eps: Adam with a large $\epsilon \approx 10^{-4}$ (see the sketch below).
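Of these, only Adam-eps needs no changes to the optimizer's internals; a minimal sketch, assuming plain PyTorch and a placeholder model:

```python
import torch

model = torch.nn.Linear(16, 1)                      # placeholder model
# Adam-eps: plain Adam with a large epsilon, which bounds the adaptive
# learning rate and so reduces its variance in the early steps.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-4)
```

Adam-2k, in contrast, requires modifying the update loop so that the first 2k steps update $v_t$ only, without touching the parameters or $m_t$.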
Let $\sigma(g_1, \dots, g_t)$ and $\psi(g_1, \dots, g_t)$ be the functions that compute the momentum and the adaptive learning rate. For Adam, they are

$$\sigma(g_1, \dots, g_t) = \frac{(1 - \beta_1) \sum_{i=1}^t \beta_1^{t-i} g_i}{1 - \beta_1^t}$$

$$\psi(g_1, \dots, g_t) = \sqrt{\frac{1 - \beta_2^t}{(1 - \beta_2) \sum_{i=1}^t \beta_2^{t-i} g_i^2}}$$
The distribution of the exponential moving average can be approximated by the distribution of a simple moving average:

$$p\Bigg(\frac{(1-\beta_2) \sum_{i=1}^t \beta_2^{t-i} g_i^2}{1 - \beta_2^t}\Bigg) \approx p\Bigg(\frac{\sum_{i=t+1-f(t,\beta_2)}^{t} g_i^2}{f(t,\beta_2)}\Bigg)$$

Here we are taking the simple moving average of the last $f(t,\beta_2)$ gradients. $f(t,\beta_2)$ is chosen so that the two averages have the same center of mass,

$$\frac{(1-\beta_2) \sum_{i=1}^t \beta_2^{t-i} \cdot i}{1 - \beta_2^t} = \frac{\sum_{i=t+1-f(t,\beta_2)}^{t} i}{f(t,\beta_2)}$$

which gives

$$f(t,\beta_2) = \frac{2}{1-\beta_2} - 1 - \frac{2 t \beta_2^t}{1 - \beta_2^t}$$
From the above we have

$$\psi^2(g_1, \dots, g_t) \approx \frac{f(t,\beta_2)}{\sum_{i=t+1-f(t,\beta_2)}^{t} g_i^2}$$

where $g_i \sim \mathcal{N}(0, \sigma^2)$. Note that $\sigma$ here is the standard deviation, not the momentum function $\sigma(.)$ above.
The scaled inverse chi-squared distribution is the distribution of the reciprocal of the mean of $p$ squared normal random variables. Therefore,

$$\psi^2(g_1, \dots, g_t) \sim \text{Scale-inv-}\chi^2\Big(\rho, \frac{1}{\sigma^2}\Big)$$

where $\rho = f(t,\beta_2)$.
They prove that the variance of $\psi(.)$ decreases as $\rho$ increases, when $\psi^2(.) \sim \text{Scale-inv-}\chi^2\big(\rho, \frac{1}{\sigma^2}\big)$.
Therefore the variance is minimized at the maximal $\rho$, which is $\rho_{\infty} = \frac{2}{1-\beta_2} - 1$ (the limit of $f(t,\beta_2)$ as $t \to \infty$). Let the minimum variance be $C_{\text{var}}$.
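A quick Monte-Carlo sketch of this effect (not part of the implementation), assuming $g_i \sim \mathcal{N}(0, 1)$ and $\beta_2 = 0.999$: the empirical variance of $\psi(.)$ is large for small $t$ (for $\rho_t \leq 4$ it is in fact unbounded) and shrinks as $t$ grows:

```python
import numpy as np

rng = np.random.default_rng(0)
beta2, runs = 0.999, 10_000

def psi_samples(t: int) -> np.ndarray:
    """Draw `runs` samples of psi(g_1, ..., g_t) with g_i ~ N(0, 1)."""
    g = rng.normal(size=(runs, t))
    weights = beta2 ** np.arange(t - 1, -1, -1)     # beta2^(t-i) for i = 1..t
    v = (1 - beta2) * (g ** 2 * weights).sum(axis=1)
    return np.sqrt((1 - beta2 ** t) / v)

for t in (5, 10, 100, 1_000):
    print(f"t={t:>5}  Var[psi] ~ {psi_samples(t).var():.4f}")
```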
In order to ensure that the adaptive learning rate $\psi(.)$ has consistent variance, we rectify the variance with a term $r$:

$$r = \sqrt{\frac{C_{\text{var}}}{Var\big[\psi(.)\big]}}$$
They estimate $Var\big[\psi(.)\big] \approx \frac{Var\big[\psi^2(.)\big]}{4 \mathbb{E}\big[\psi^2(.)\big]}$ based on a first-order expansion of $\sqrt{\psi^2(.)}$: expanding $\sqrt{x}$ around $\mathbb{E}[x]$ gives $\sqrt{x} \approx \sqrt{\mathbb{E}[x]} + \frac{x - \mathbb{E}[x]}{2\sqrt{\mathbb{E}[x]}}$, and taking the variance of both sides yields $Var[\sqrt{x}] \approx \frac{Var[x]}{4 \mathbb{E}[x]}$.
From the $\text{Scale-inv-}\chi^2$ distribution we have

$$\mathbb{E}\big[\psi^2(.)\big] = \frac{\rho / \sigma^2}{\rho - 2} \qquad Var\big[\psi^2(.)\big] = \frac{2 \rho^2 / \sigma^4}{(\rho - 2)^2 (\rho - 4)}$$

which gives

$$Var\big[\psi(.)\big] \approx \frac{\rho}{2 (\rho - 2)(\rho - 4) \sigma^2}$$
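A small numerical check of this approximation (not from the paper; $\rho$ and $\sigma^2$ are arbitrary illustrative values), sampling $\psi^2 \sim \text{Scale-inv-}\chi^2(\rho, 1/\sigma^2)$ as $(\rho/\sigma^2)/X$ with $X \sim \chi^2_\rho$:

```python
import numpy as np

rng = np.random.default_rng(0)
rho, sigma2, n = 100.0, 1.0, 1_000_000
psi2 = (rho / sigma2) / rng.chisquare(rho, size=n)   # Scale-inv-chi2(rho, 1/sigma2)

print(np.sqrt(psi2).var())                           # empirical Var[psi]
print(rho / (2 * (rho - 2) * (rho - 4) * sigma2))    # first-order approximation
```

The two agree to within a few percent for moderately large $\rho$; the approximation is rougher when $\rho$ is close to $4$.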
We have

$$r = \sqrt{\frac{C_{\text{var}}}{Var\big[\psi(.)\big]}}$$

where $C_{\text{var}}$ is $Var\big[\psi(.)\big]$ for $\rho_\infty$. Let $\rho$ at step $t$ be $\rho_t$, and let $r_t$ be the rectification term at step $t$.
This gives

$$r_t = \sqrt{\frac{(\rho_t - 4)(\rho_t - 2) \rho_\infty}{(\rho_\infty - 4)(\rho_\infty - 2) \rho_t}}$$
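As a standalone sanity check of these formulas (separate from the implementation below; $\beta_2$ and the steps are illustrative), $r_t$ is undefined for the first few steps and approaches $1$ as $t \to \infty$:

```python
import math

def rectification(beta2: float, t: int):
    """rho_t and r_t from the closed-form expressions above."""
    rho_inf = 2 / (1 - beta2) - 1
    rho_t = rho_inf - 2 * t * beta2 ** t / (1 - beta2 ** t)
    if rho_t <= 4:
        # Var[psi(.)] is unbounded here, so the rectification term is undefined
        return rho_t, None
    r_t = math.sqrt((rho_t - 4) * (rho_t - 2) * rho_inf
                    / ((rho_inf - 4) * (rho_inf - 2) * rho_t))
    return rho_t, r_t

for t in (1, 4, 5, 10, 100, 10_000):
    print(t, rectification(0.999, t))
```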
import math
from typing import Dict, Optional

import torch

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad
class RAdam(AMSGrad):
* `params` is the list of parameters
* `lr` is the learning rate $\alpha$
* `betas` is a tuple of ($\beta_1$, $\beta_2$)
* `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
* `weight_decay` is an instance of class `WeightDecay` defined in `__init__.py`
* `amsgrad` is a flag indicating whether to use AMSGrad or fall back to plain Adam
* `degenerated_to_sgd` is a flag indicating whether to use SGD when the rectification term $r_t$ is intractable
* `defaults` is a dictionary of default values for group parameters. This is useful when you want to extend the class `RAdam`.
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=False,
                 degenerated_to_sgd=True, defaults=None):
        self.degenerated_to_sgd = degenerated_to_sgd
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
* `param` is the parameter tensor $\theta_{t-1}$

    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
Calculate weight decay

        grad = self.weight_decay(param, grad, group)

Get $m_t$ and $v_t$; i.e. $\sigma(.)$ and $\psi(.)$ without bias correction

        m, v = self.get_mv(state, group, grad)

Calculate $t$, the number of optimizer steps

        state['step'] += 1

Perform RAdam update

        self.r_adam_update(state, group, param, m, v)
Calculate the rectification term $r_t$ for a given $\beta_2$ and step $t$; returns `None` when it is intractable.

    @staticmethod
    def calc_rectification_term(beta2: float, step: int) -> Optional[float]:

$\beta_2^t$

        beta2_t = beta2 ** step

$\rho_\infty = \frac{2}{1-\beta_2} - 1$

        rho_inf = 2 / (1 - beta2) - 1

$\rho_t = \rho_\infty - \frac{2 t \beta_2^t}{1 - \beta_2^t}$

        rho = rho_inf - 2 * step * beta2_t / (1 - beta2_t)
$r_t$ is tractable when $\rho_t \geq 4$. We are being a little more conservative since it is an approximated value.

        if rho >= 5:

$r_t^2 = \frac{(\rho_t-4)(\rho_t-2)\rho_\infty}{(\rho_\infty-4)(\rho_\infty-2)\rho_t}$

            r2 = (rho - 4) / (rho_inf - 4) * (rho - 2) / rho * rho_inf / (rho_inf - 2)

Return $r_t$

            return math.sqrt(r2)

Otherwise return `None`; $r_t$ is intractable

        else:
            return None
* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `param` is the parameter tensor $\theta_{t-1}$
* `m` and `v` are the uncorrected first and second moments $m_t$ and $v_t$; i.e. $\sigma(.)$ and $\psi(.)$ without bias correction

    def r_adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
                      m: torch.Tensor, v: torch.Tensor):
Get $\beta_1$ and $\beta_2$

        beta1, beta2 = group['betas']

Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$

        bias_correction1 = 1 - beta1 ** state['step']

Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$

        bias_correction2 = 1 - beta2 ** state['step']

Calculate the rectification term $r_t$; it is `None` when intractable

        r = self.calc_rectification_term(beta2, state['step'])

Get learning rate

        lr = self.get_lr(state, group)
If $r_t$ is tractable

        if r is not None:

Whether to optimize the computation by combining scalar computations

            if self.optimized_update:

Denominator $\sqrt{v_t} + \hat{\epsilon}$

                denominator = v.sqrt().add_(group['eps'])

Step size $\alpha r_t \cdot \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$

                step_size = lr * math.sqrt(bias_correction2) * r / bias_correction1

Update parameters $\theta_t \leftarrow \theta_{t-1} - \alpha r_t \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$

                param.data.addcdiv_(m, denominator, value=-step_size)
Computation without optimization

            else:

Denominator $\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$

                denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

Step size $\frac{\alpha r_t}{1-\beta_1^t}$

                step_size = lr * r / bias_correction1

Update parameters $\theta_t \leftarrow \theta_{t-1} - \alpha r_t \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$

                param.data.addcdiv_(m, denominator, value=-step_size)
If $r_t$ is intractable, do an SGD update with momentum

        elif self.degenerated_to_sgd:

Step size $\frac{\alpha}{1-\beta_1^t}$

            step_size = lr / bias_correction1

Update parameters $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \hat{m}_t$

            param.data.add_(m, alpha=-step_size)
Plot the rectification term $r_t$ against the number of steps $t$ for various $\beta_2$ values.

def _test_rectification_term():
    import matplotlib.pyplot as plt
    import numpy as np

    beta2 = [0.9999, 0.999, 0.99, 0.9, 0.8, 0.6, 0.5]
    plt.plot(np.arange(1, 5_000), [[RAdam.calc_rectification_term(b, i) for b in beta2] for i in range(1, 5_000)])
    plt.legend(beta2)
    plt.title("Rectification term $r_t$")
    plt.show()


if __name__ == '__main__':
    _test_rectification_term()
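Finally, a minimal usage sketch (not part of the annotated source). It assumes the class above is importable as `labml_nn.optimizers.radam.RAdam` and that the base optimizer provides the usual `step`/`zero_grad` interface; the model and data are placeholders:

```python
import torch
from labml_nn.optimizers.radam import RAdam      # assumed module path

model = torch.nn.Linear(8, 1)                    # placeholder model
optimizer = RAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

x, y = torch.randn(64, 8), torch.randn(64, 1)
for _ in range(100):
    loss = torch.nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()                             # no warm-up schedule needed
```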