📚 adam notes

Varuna Jayasiri
2020-12-05 11:21:11 +05:30
parent 27b3922647
commit 874c238651
5 changed files with 133 additions and 26 deletions

View File

@@ -1,3 +1,9 @@
"""
# Optimizers
* [Adam](adam.html)
"""
from typing import Dict, Tuple
import torch
@@ -21,7 +27,7 @@ class GenericAdaptiveOptimizer(Optimizer):
def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
pass
-def calculate(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.Tensor):
+def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.Tensor):
pass
@torch.no_grad()
@@ -45,7 +51,7 @@ class GenericAdaptiveOptimizer(Optimizer):
if len(state) == 0:
self.init_state(state, group, p)
-self.calculate(state, group, grad, p)
+self.step_param(state, group, grad, p)
return loss
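The base class works as a small template: `step()` lazily calls `init_state` the first time it encounters a parameter, then hands the actual update to `step_param`. Below is a minimal, hypothetical sketch (not part of this commit) of a subclass that plugs into those two hooks with plain gradient descent; the base-class constructor signature `(params, defaults, lr, betas, eps)` is assumed from how `Adam` calls it later in this diff.

```python
# Hypothetical sketch, not part of this commit: a subclass that only uses the two hooks.
from typing import Any, Dict

import torch
from torch import nn

from labml_nn.optimizers import GenericAdaptiveOptimizer


class PlainSGD(GenericAdaptiveOptimizer):
    """Illustration only: plain gradient descent, no per-parameter state beyond a step count."""

    def __init__(self, params, lr: float = 1e-2):
        # `betas` and `eps` are unused here; they are passed only to satisfy the
        # assumed base-class signature (params, defaults, lr, betas, eps).
        super().__init__(params, {}, lr, (0.9, 0.999), 1e-16)

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        # Called once per parameter, when its state dict is still empty
        state['step'] = 0

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any],
                   grad: torch.Tensor, param: torch.Tensor):
        # Called for every parameter that has a gradient, on each optimizer.step()
        state['step'] += 1
        param.data.add_(grad, alpha=-group['lr'])
```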

View File

@@ -50,16 +50,16 @@ class AdaBelief(RAdam):
super().__init__(params, lr, betas, eps, weight_decay, amsgrad, degenerated_to_sgd, defaults)
self.rectify = rectify
-def init_state(self, state: Dict[str, any], group: Dict[str, any], p: nn.Parameter):
+def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
state['step'] = 0
# Exponential moving average of gradient values
-state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
-state['exp_avg_var'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+state['exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)
if group['amsgrad']:
# Maintains max of all exp. moving avg. of sq. grad. values
-state['max_exp_avg_var'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)
def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
beta1, beta2 = group['betas']
@@ -80,7 +80,7 @@ class AdaBelief(RAdam):
else:
return m, v
-def calculate(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
+def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
self.weight_decay(param, group)
m, v = self.get_mv(state, group, grad)
state['step'] += 1
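The body of `get_mv` is truncated in this hunk, so the AdaBelief second-moment update itself is not visible here. As a hedged, illustrative sketch only (based on the AdaBelief paper, not taken from this commit), `exp_avg_var` tracks an exponential moving average of the squared deviation of the gradient from its running mean, rather than of the squared gradient:

```python
# Illustrative sketch (assumption based on the AdaBelief paper, not this diff):
# the "belief" second moment uses (g_t - m_t)^2 instead of g_t^2.
import torch


def adabelief_moments_sketch(m: torch.Tensor, v: torch.Tensor, grad: torch.Tensor,
                             beta1: float = 0.9, beta2: float = 0.999):
    # m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    m.mul_(beta1).add_(grad, alpha=1 - beta1)
    # s_t = beta2 * s_{t-1} + (1 - beta2) * (g_t - m_t)^2
    # (the paper also adds a small epsilon to s_t each step; omitted here for brevity)
    diff = grad - m
    v.mul_(beta2).addcmul_(diff, diff, value=1 - beta2)
    return m, v
```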

View File

@@ -1,5 +1,15 @@
"""
# Adam Optimizer
This is an implementation of the popular optimizer *Adam* from the paper
[Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980v9).
We extend the class `GenericAdaptiveOptimizer` defined in [__init__.py](index.html)
to implement the Adam optimizer.
"""
import math
-from typing import Dict, Any
+from typing import Dict, Any, Tuple, Optional
import torch
from torch import nn
@@ -8,52 +18,143 @@ from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay
class Adam(GenericAdaptiveOptimizer):
-def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
-             weight_decay: WeightDecay = WeightDecay(), defaults=None):
+def __init__(self, params,
+             lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
+             weight_decay: WeightDecay = WeightDecay(),
+             defaults: Optional[Dict[str, Any]] = None):
"""
### Initialize the optimizer
* `params` is the list of parameters
* `lr` is the learning rate $\alpha$
* `betas` is a tuple of ($\beta_1$, $\beta_2$)
* `eps` is $\epsilon$, a small constant added to the denominator for numerical stability
* `weight_decay` is an instance of class `WeightDecay` defined in [__init__.py](index.html)
* `defaults` is a dictionary of default values for parameter groups.
This is useful when you want to extend the class `Adam`.
"""
defaults = {} if defaults is None else defaults
defaults.update(weight_decay.defaults())
super().__init__(params, defaults, lr, betas, eps)
self.weight_decay = weight_decay
-def init_state(self, state: Dict[str, any], group: Dict[str, any], p: nn.Parameter):
+def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
"""
### Initialize a parameter state
* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `param` is the parameter tensor $\theta_{t-1}$
"""
# This is the number of optimizer steps taken on the parameter, $t$
state['step'] = 0
-# Exponential moving average of gradient values
+# Exponential moving average of gradients, $m_t$
-state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
-# Exponential moving average of squared gradient values
+# Exponential moving average of squared gradient values, $v_t$
-state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
"""
### Calculate $m_t$ and $v_t$
* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
"""
# Get $\beta_1$ and $\beta_2$
beta1, beta2 = group['betas']
-# get current state variable
+# Get $m_{t-1}$ and $v_{t-1}$
m, v = state['exp_avg'], state['exp_avg_sq']
-# Update first and second moment running average
+# In-place calculation of $m_t$
# $$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$
m.mul_(beta1).add_(grad, alpha=1 - beta1)
# In-place calculation of $v_t$
# $$v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2$$
v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
return m, v
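The two in-place calls above are the whole moment update. A quick standalone check with made-up numbers (plain PyTorch, not part of this commit) confirms they compute $m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$ and $v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$ starting from zero state:

```python
# Standalone check of the in-place moment updates, with made-up numbers.
import torch

beta1, beta2 = 0.9, 0.999
g = torch.tensor([0.5, -2.0])   # a made-up gradient g_t
m = torch.zeros_like(g)         # m_0 = 0
v = torch.zeros_like(g)         # v_0 = 0

m.mul_(beta1).add_(g, alpha=1 - beta1)         # m_1 = (1 - beta1) * g
v.mul_(beta2).addcmul_(g, g, value=1 - beta2)  # v_1 = (1 - beta2) * g^2

assert torch.allclose(m, (1 - beta1) * g)
assert torch.allclose(v, (1 - beta2) * g * g)
```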
def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
"""
### Get learning rate
This returns the modified learning rate based on the state.
For *Adam* this is just the specified learning rate for the parameter group,
$\alpha$.
"""
return group['lr']
def adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
                m: torch.Tensor, v: torch.Tensor):
"""
### Do the *Adam* parameter update
* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `param` is the parameter tensor $\theta_{t-1}$
* `m` and `v` are the uncorrected first and second moments $m_t$ and $v_t$.
This computes the following
\begin{align}
\hat{m}_t &\leftarrow \frac{m_t}{1-\beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1-\beta_2^t} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{align}
Since $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalars while the rest are tensors,
we rearrange the calculation to reduce the number of tensor operations.
\begin{align}
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot
\frac{m_t / (1-\beta_1^t)}{\sqrt{v_t/(1-\beta_2^t)} + \epsilon} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
\frac{m_t}{\sqrt{v_t} + \epsilon'}
\end{align}
where
$$\epsilon' = \sqrt{1-\beta_2^t} \cdot \epsilon \approx \epsilon$$
since $\epsilon$ is there only for numerical stability
"""
# Get $\beta_1$ and $\beta_2$
beta1, beta2 = group['betas']
# Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$
bias_correction1 = 1 - beta1 ** state['step']
# Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$
bias_correction2 = 1 - beta2 ** state['step']
-denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
-step_size = self.get_lr(state, group) / bias_correction1
+# $\sqrt{v_t} + \epsilon$
+denominator = v.sqrt().add_(group['eps'])
# $\alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$
step_size = self.get_lr(state, group) * math.sqrt(bias_correction2) / bias_correction1
# $\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
# \frac{m_t}{\sqrt{v_t} + \epsilon}$
param.data.addcdiv_(m, denominator, value=-step_size)
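A standalone check with made-up numbers (not part of this commit) that the rearranged update above matches the textbook bias-corrected step, up to the $\epsilon$ versus $\epsilon'$ approximation discussed in the docstring:

```python
# Compare the textbook bias-corrected Adam step with the rearranged form used in adam_update.
import math

import torch

lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-16
step = 10
m = torch.tensor([0.12, -0.30])  # made-up m_t
v = torch.tensor([0.04, 0.09])   # made-up v_t

bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step

# Textbook form: alpha * m_hat / (sqrt(v_hat) + eps)
m_hat = m / bias_correction1
v_hat = v / bias_correction2
textbook = lr * m_hat / (v_hat.sqrt() + eps)

# Rearranged form: alpha * sqrt(1 - beta2^t) / (1 - beta1^t) * m / (sqrt(v) + eps)
denominator = v.sqrt() + eps
step_size = lr * math.sqrt(bias_correction2) / bias_correction1
rearranged = step_size * m / denominator

assert torch.allclose(textbook, rearranged)
```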
-def calculate(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
+def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
"""
### Take an update step for a given parameter tensor
* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
* `param` is the parameter tensor $\theta_{t-1}$
"""
# Calculate weight decay
self.weight_decay(param, group)
# Get $m_t$ and $v_t$
m, v = self.get_mv(state, group, grad)
# Calculate $t$, the number of optimizer steps
state['step'] += 1
# Perform *Adam* update
self.adam_update(state, group, param, m, v)
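For context, a minimal usage sketch (not part of this commit): it assumes `labml_nn` is installed and that this file is importable as `labml_nn.optimizers.adam`; the constructor arguments are the ones shown in `__init__` above.

```python
# Minimal usage sketch; the import path is an assumption about where this file lives.
import torch
from torch import nn

from labml_nn.optimizers.adam import Adam

model = nn.Linear(4, 1)
optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-16)

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()  # runs step_param -> get_mv -> adam_update for each parameter
```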

View File

@@ -15,11 +15,11 @@ class AMSGrad(Adam):
super().__init__(params, lr, betas, eps, weight_decay, defaults)
-def init_state(self, state: Dict[str, any], group: Dict[str, any], p: nn.Parameter):
+def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
-super().init_state(state, group, p)
+super().init_state(state, group, param)
# Maintains max of all exp. moving avg. of sq. grad. values
if group['amsgrad']:
-state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
def get_mv(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor):
m, v = super().get_mv(state, group, grad)

View File

@@ -18,7 +18,7 @@ class RAdam(AMSGrad):
self.degenerated_to_sgd = degenerated_to_sgd
super().__init__(params, lr, betas, eps, weight_decay, amsgrad, defaults)
-def calculate(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
+def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
self.weight_decay(param, group)
m, v = self.get_mv(state, group, grad)