# Mirror of
# https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
# (synced 2025-08-15 02:07:56 +08:00; 160 lines, 6.7 KiB, Python)
"""
---
title: AdaBelief optimizer
summary: A simple PyTorch implementation/tutorial of the AdaBelief optimizer.
---

# AdaBelief Optimizer

This is based on the AdaBelief
[official implementation](https://github.com/juntang-zhuang/Adabelief-Optimizer)
of the paper
[AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients](https://papers.labml.ai/paper/2010.07468).

This is implemented in [PyTorch](https://pytorch.org) as an extension to [RAdam](radam.html).

The main difference between the Adam optimizer and AdaBelief is how
the adaptive learning rate is calculated:
instead of dividing by the exponential moving average of the squared gradients,
AdaBelief divides by the exponential moving average of the variance,
i.e. of the squared difference between the gradient $g_t$ and its momentum $m_t$.

\begin{align}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
\textcolor{cyan}{s_t} &\textcolor{cyan}{\leftarrow} \textcolor{cyan}{\beta_2 s_{t-1} + (1 - \beta_2) \cdot (g_t - m_t)^2} \\
\hat{m}_t &\leftarrow \frac{m_t}{1-\beta_1^t} \\
\textcolor{cyan}{\hat{s}_t} &\textcolor{cyan}{\leftarrow} \frac{\textcolor{cyan}{s_t} + \textcolor{red}{\epsilon}}{\textcolor{cyan}{1-\beta_2^t}} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\textcolor{cyan}{\hat{s}_t}} + \epsilon}
\end{align}

🤔 The paper calculates the variance as $(g_t - m_t)^2$,
but I feel it should use the bias-corrected momentum
$(g_t - \textcolor{orange}{\hat{m}_t})^2$.
I guess this doesn't affect things much because
the bias correction factor $1 - \beta_1^t$ is $\approx 1$ after the initial training steps.
"""
|
|
|
|
from typing import Dict, Any
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from labml_nn.optimizers import WeightDecay
|
|
from labml_nn.optimizers.radam import RAdam
|
|
|
|
|
|
class AdaBelief(RAdam):
    r"""
    ## AdaBelief Optimizer

    This class extends from RAdam optimizer defined in [`radam.py`](radam.html).
    Instead of the exponential moving average of squared gradients $v_t$, it passes
    the exponential moving average of the gradient variance plus $\epsilon$,
    $s_t + \epsilon$, to the Adam / RAdam update.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
                 degenerate_to_sgd=True,
                 rectify=True, defaults=None,
                 optimized_update: bool = False):
        r"""
        ### Initialize the optimizer

        * `params` is the list of parameters
        * `lr` is the learning rate $\alpha$
        * `betas` is a tuple of ($\beta_1$, $\beta_2$)
        * `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
        * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
        * `amsgrad` is a flag indicating whether to use AMSGrad or fallback to plain Adam
        * `degenerate_to_sgd` whether to use sgd when the rectification term $r_t$ is intractable
        * `rectify` is whether to use the RAdam update; when `False` the plain Adam update is used
        * `defaults` is a dictionary of default for group values.
          This is useful when you want to extend the class `AdaBelief`.
        * `optimized_update` is a flag whether to optimize the bias correction of the second moment
          by doing it after adding $\epsilon$.
          It defaults to `False`, which is the update rule written in the module docstring
          ($\epsilon$ is added after $\sqrt{\hat{s}_t}$ has been bias corrected).
          It is the last parameter so that existing positional calls keep working.
        """
        defaults = {} if defaults is None else defaults
        # `RAdam.__init__` takes `optimized_update` between `weight_decay` and `amsgrad`
        # (the order documented above). Leaving it out shifts every later positional
        # argument by one slot: `amsgrad` would become `optimized_update`,
        # `degenerate_to_sgd` would become `amsgrad` (turning AMSGrad on by default),
        # and `defaults` would become the degenerate-to-SGD flag.
        super().__init__(params, lr, betas, eps, weight_decay,
                         optimized_update, amsgrad, degenerate_to_sgd, defaults)
        # Chooses between the Adam and the rectified (RAdam) update in `step_param`
        self.rectify = rectify

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        r"""
        ### Initialize a parameter state

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$
        """
        # Number of optimizer steps taken so far, $t$
        state['step'] = 0
        # Exponential moving average of gradient values, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        # Exponential moving average of variance, $s_t$
        state['exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)

        # If `amsgrad` flag is `True` for this parameter group, we maintain the maximum of
        # exponential moving average of variance
        if group['amsgrad']:
            # Maintains max of all exp. moving avg. of variance values, $\max(s_1, ..., s_t)$
            state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    def get_ms(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
        r"""
        ### Calculate $m_t$ and $s_t$ or $\max(s_1, s_2, ..., s_{t-1}, s_t)$

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$

        The moving averages are updated in place in `state`, and the returned
        tensors are those state tensors themselves (not copies).
        """

        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']

        # Get $m_{t-1}$ and $s_{t-1}$
        m, s = state['exp_avg'], state['exp_avg_var']

        # In-place calculation of $m_t$
        # $$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$
        m.mul_(beta1).add_(grad, alpha=1 - beta1)
        # Difference between gradient and (updated, not bias-corrected) momentum, $g_t - m_t$
        grad_residual = grad - m
        # In-place calculation of $s_t$
        # $$s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) \cdot (g_t - m_t)^2$$
        s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)

        # If this parameter group is using `amsgrad`
        if group['amsgrad']:
            # Get $\max(s_1, s_2, ..., s_{t-1})$.
            s_max = state['max_exp_avg_var']
            # Calculate $\max(s_1, s_2, ..., s_{t-1}, s_t)$ in place.
            torch.maximum(s_max, s, out=s_max)

            return m, s_max
        else:
            # $m_t$ and $s_t$ otherwise
            return m, s

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
        r"""
        ### Take an update step for a given parameter tensor

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # Calculate weight decay
        grad = self.weight_decay(param, grad, group)

        # Get $m_t$ and $s_t$ (or $\max(s_1, ..., s_t)$ when using AMSGrad)
        m, s = self.get_ms(state, group, grad)

        # Increment $t$ the number of optimizer steps
        state['step'] += 1

        # `s + eps` is out-of-place, so $\epsilon$ is not accumulated into the stored $s_t$.
        if not self.rectify:
            # Perform *Adam* update, defined in [`adam.py`](adam.html), with
            # $\textcolor{cyan}{s_t} + \textcolor{red}{\epsilon}$ in place of $v_t$.
            self.adam_update(state, group, param, m, s + group['eps'])
        else:
            # Perform *Rectified Adam* update defined in [`radam.py`](radam.html), with
            # $\textcolor{cyan}{s_t} + \textcolor{red}{\epsilon}$ in place of $v_t$.
            self.r_adam_update(state, group, param, m, s + group['eps'])
|