diff --git a/labml_nn/optimizers/__init__.py b/labml_nn/optimizers/__init__.py
index 8d045656..eb6db6ab 100644
--- a/labml_nn/optimizers/__init__.py
+++ b/labml_nn/optimizers/__init__.py
@@ -67,14 +67,13 @@ class WeightDecay:
     def defaults(self):
         return dict(weight_decay=self.weight_decay)
 
-    def __call__(self, param: torch.nn.Parameter, group: Dict[str, any]):
-        grad = param.grad.data
-
+    def __call__(self, param: torch.nn.Parameter, grad: torch.Tensor, group: Dict[str, any]):
         if self.weight_decouple:
             if not self.absolute:
                 param.data.mul_(1.0 - group['lr'] * group['weight_decay'])
             else:
                 param.data.mul_(1.0 - group['weight_decay'])
+            return grad
         else:
             if group['weight_decay'] != 0:
-                grad.add_(param.data, alpha=group['weight_decay'])
+                return grad.add(param.data, alpha=group['weight_decay'])
diff --git a/labml_nn/optimizers/ada_belief.py b/labml_nn/optimizers/ada_belief.py
index 0d5dec8e..27d8c848 100644
--- a/labml_nn/optimizers/ada_belief.py
+++ b/labml_nn/optimizers/ada_belief.py
@@ -81,7 +81,7 @@ class AdaBelief(RAdam):
         return m, v
 
     def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
-        self.weight_decay(param, group)
+        grad = self.weight_decay(param, grad, group)
 
         m, v = self.get_mv(state, group, grad)
         state['step'] += 1
diff --git a/labml_nn/optimizers/adam.py b/labml_nn/optimizers/adam.py
index fedcd3f3..5998fd99 100644
--- a/labml_nn/optimizers/adam.py
+++ b/labml_nn/optimizers/adam.py
@@ -6,6 +6,33 @@ This is an implementation of popular optimizer *Adam* from paper
 
 We extend the class `GenericAdaptiveOptimizer` defined in [__init__.py](index.html)
 to implement the Adam optimizer.
+
+*Adam* update is,
+
+\begin{align}
+m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
+v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\
+\hat{m}_t &\leftarrow \frac{m_t}{1-\beta_1^t} \\
+\hat{v}_t &\leftarrow \frac{v_t}{1-\beta_2^t} \\
+\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
+\end{align}
+
+where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper-parameters.
+$m_t$ and $v_t$ are the first and second order moments.
+$\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moments.
+$\epsilon$ is used as a fix for division by zero, but it also acts as a hyper-parameter
+that works against high variance in gradients.
+
+The effective step taken, assuming $\epsilon = 0$, is
+$$\vert \Delta t \vert = \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t}}$$
+This is bounded by
+$$\vert \Delta t \vert \le \alpha \cdot \frac{1 - \beta_1}{\sqrt{1-\beta_2}}$$
+when $1-\beta_1 \gt \sqrt{1-\beta_2}$,
+and by
+$$\vert \Delta t \vert \le \alpha$$
+otherwise.
+In most common scenarios,
+$$\vert \Delta t \vert \approx \alpha$$
 """
 
 import math
@@ -28,6 +55,7 @@ class Adam(GenericAdaptiveOptimizer):
         * `params` is the list of parameters
         * 'lr' is the learning rate $\alpha$
         * `betas` is a tuple of ($\beta_1$, $\beta_2$)
+        * `eps` is $\hat{\epsilon}$
         * `weight_decay` is an instance of class `WeightDecay` defined in [__init__.py](index.html)
         * `defaults` is a dictionary of default for group values.
             This is useful when you want to extend the class `Adam`.
@@ -101,8 +129,6 @@ class Adam(GenericAdaptiveOptimizer):
         This computes the following
 
         \begin{align}
-        \hat{m}_t &\leftarrow \frac{m_t}/{1-\beta_1^t} \\
-        \hat{v}_t &\leftarrow \frac{v_t}/{1-\beta_2^t} \\
         \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
         \end{align}
 
@@ -114,12 +140,12 @@ class Adam(GenericAdaptiveOptimizer):
         \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot
                 \frac{m_t / (1-\beta_1^t)}{\sqrt{v_t/(1-\beta_2^t)} + \epsilon} \\
         \theta_t &\leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
-                \frac{m_t}{\sqrt{v_t} + \epsilon'} \\
+                \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}} \\
         \end{align}
 
         where
-        $$\epsilon` = (1-\beta_2^t) \epsilon \approx \epsilon$$
-        since $\beta_2 \approx 1$
+        $$\hat{\epsilon} = \sqrt{1-\beta_2^t} \epsilon$$
+        is what we should specify as the hyper-parameter.
         """
 
         # Get $\beta_1$ and $\beta_2$
@@ -134,7 +160,7 @@ class Adam(GenericAdaptiveOptimizer):
         # $\alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$
         step_size = self.get_lr(state, group) * math.sqrt(bias_correction2) / bias_correction1
         # $\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
-        #       \frac{m_t}{\sqrt{v_t} + \epsilon}$
+        #       \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
         param.data.addcdiv_(m, denominator, value=-step_size)
 
     def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
@@ -148,7 +174,7 @@ class Adam(GenericAdaptiveOptimizer):
         """
 
         # Calculate weight decay
-        self.weight_decay(param, group)
+        grad = self.weight_decay(param, grad, group)
 
         # Get $m_t$ and $v_t$
         m, v = self.get_mv(state, group, grad)
diff --git a/labml_nn/optimizers/radam.py b/labml_nn/optimizers/radam.py
index 662f5765..a9796209 100644
--- a/labml_nn/optimizers/radam.py
+++ b/labml_nn/optimizers/radam.py
@@ -19,7 +19,7 @@ class RAdam(AMSGrad):
         super().__init__(params, lr, betas, eps, weight_decay, amsgrad, defaults)
 
     def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
-        self.weight_decay(param, group)
+        grad = self.weight_decay(param, grad, group)
 
         m, v = self.get_mv(state, group, grad)
         state['step'] += 1
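Taken together, these hunks change the contract of `WeightDecay`: `__call__` no longer reads `param.grad.data` itself, it receives the gradient tensor and returns the gradient to use, and every `step_param` rebinds `grad` to that return value. Below is a minimal, self-contained sketch of that convention, not the full labml class; the constructor arguments are inferred from the attribute names used in the diff and are assumptions.

```python
from typing import Dict

import torch


class WeightDecay:
    """Sketch of the refactored weight decay helper; constructor defaults are assumed, not from the diff."""

    def __init__(self, weight_decay: float = 0., weight_decouple: bool = True, absolute: bool = False):
        self.weight_decay = weight_decay
        self.weight_decouple = weight_decouple
        self.absolute = absolute

    def __call__(self, param: torch.nn.Parameter, grad: torch.Tensor, group: Dict[str, any]) -> torch.Tensor:
        if self.weight_decouple:
            # Decoupled weight decay: shrink the parameter in place, leave the gradient untouched
            if not self.absolute:
                param.data.mul_(1.0 - group['lr'] * group['weight_decay'])
            else:
                param.data.mul_(1.0 - group['weight_decay'])
            return grad
        else:
            # L2 regularization: return a new tensor instead of mutating `grad` in place
            if group['weight_decay'] != 0:
                return grad.add(param.data, alpha=group['weight_decay'])
            # The diff falls through here (implicitly returning None) when weight_decay == 0;
            # returning `grad` keeps this sketch safe in that case.
            return grad


# Example: the call convention every `step_param` in the diff now follows
param = torch.nn.Parameter(torch.randn(3))
grad = torch.randn(3)
group = {'lr': 1e-3, 'weight_decay': 1e-2}

wd = WeightDecay(weight_decay=1e-2, weight_decouple=False)
grad = wd(param, grad, group)  # grad is rebound to the returned tensor before being passed to get_mv
```

Returning the gradient rather than mutating it in place means the L2 branch no longer modifies `param.grad` as a side effect, which is what lets optimizers such as AdaBelief, Adam and RAdam simply write `grad = self.weight_decay(param, grad, group)`.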
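The renaming of $\epsilon'$ to $\hat{\epsilon}$ in `calc_step` is easier to follow with the bias-correction algebra written out. The derivation below is added for reference (it is not part of the diff); it uses only symbols already defined in the docstrings above and matches the `step_size` computation in `calc_step`.

```latex
\begin{align}
\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
  &= \frac{m_t / (1-\beta_1^t)}{\sqrt{v_t / (1-\beta_2^t)} + \epsilon} \\
  &= \frac{m_t}{1-\beta_1^t} \cdot
     \frac{\sqrt{1-\beta_2^t}}{\sqrt{v_t} + \sqrt{1-\beta_2^t}\,\epsilon} \\
  &= \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
     \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}
  \qquad \text{with } \hat{\epsilon} = \sqrt{1-\beta_2^t}\,\epsilon
\end{align}
```

Since the code builds the denominator as $\sqrt{v_t} + \texttt{eps}$ and folds $\frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$ into `step_size`, the `eps` passed to the optimizer plays the role of $\hat{\epsilon}$ rather than the $\epsilon$ of the standard Adam update, which is why the docstring calls it the hyper-parameter to specify.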