"""
|
|
---
|
|
title: AMSGrad Optimizer
|
|
summary: A simple PyTorch implementation/tutorial of AMSGrad optimizer.
|
|
---
|
|
|
|
# AMSGrad
|
|
|
|
This is a [PyTorch](https://pytorch.org) implementation of the paper
|
|
[On the Convergence of Adam and Beyond](https://arxiv.org/abs/1904.09237).
|
|
|
|
We implement this as an extension to our [Adam optimizer implementation](adam.html).
|
|
The implementation it self is really small since it's very similar to Adam.
|
|
|
|
We also have an implementation of the synthetic example described in the paper where Adam fails to converge.
|
|
"""

from typing import Dict

import torch
from torch import nn

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.adam import Adam


class AMSGrad(Adam):
    """
    ## AMSGrad Optimizer

    This class extends from the Adam optimizer defined in [`adam.py`](adam.html).
    The Adam optimizer extends the class `GenericAdaptiveOptimizer`
    defined in [`__init__.py`](index.html).

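    The only change from *Adam* is in the second moment term: we maintain a running
    element-wise maximum of $v_t$ and use it in place of $v_t$ in the update.
    As a sketch (following the paper's AMSGrad algorithm, and ignoring the bias
    correction that the base class applies),

    \begin{align}
    m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
    v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\
    v^{\max}_t &\leftarrow \max(v^{\max}_{t-1}, v_t) \\
    \theta_t &\leftarrow \theta_{t-1} - \alpha \frac{m_t}{\sqrt{v^{\max}_t} + \epsilon}
    \end{align}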
"""
|
|
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=True, defaults=None):
        """
        ### Initialize the optimizer

        * `params` is the list of parameters
        * `lr` is the learning rate $\alpha$
        * `betas` is a tuple of ($\beta_1$, $\beta_2$)
        * `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
        * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
        * `optimized_update` is a flag whether to optimize the bias correction of the second moment
          by doing it after adding $\epsilon$
        * `amsgrad` is a flag indicating whether to use AMSGrad or fall back to plain Adam
        * `defaults` is a dictionary of defaults for group values.
          This is useful when you want to extend the class `AMSGrad`.
        """
        defaults = {} if defaults is None else defaults
        defaults.update(dict(amsgrad=amsgrad))

        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)

    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
        """
        ### Initialize a parameter state

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # Call `init_state` of the Adam optimizer which we are extending
        super().init_state(state, group, param)

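        # Adam's `init_state` has already set up the step counter and the moving-average
        # accumulators for $m_t$ and $v_t$; below we only add the extra state that AMSGrad needs.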
        # If the `amsgrad` flag is `True` for this parameter group, we maintain the maximum of
        # the exponential moving average of the squared gradient
        if group['amsgrad']:
            state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    def get_mv(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor):
        """
        ### Calculate $m_t$ and $v_t$ or $\max(v_1, v_2, ..., v_{t-1}, v_t)$

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        """

        # Get $m_t$ and $v_t$ from *Adam*
        m, v = super().get_mv(state, group, grad)

        # If this parameter group is using `amsgrad`
        if group['amsgrad']:
            # Get $\max(v_1, v_2, ..., v_{t-1})$.
            #
            # 🗒 The paper uses the notation $\hat{v}_t$ for this, which we don't use
            # here because it would be confused with Adam's use of the same notation
            # for the bias-corrected exponential moving average.
            v_max = state['max_exp_avg_sq']
            # Calculate $\max(v_1, v_2, ..., v_{t-1}, v_t)$.
            #
            # 🤔 I feel you should be taking/maintaining the max of the bias-corrected
            # second exponential moving average of the squared gradient.
            # But this is how it's
            # [implemented in PyTorch also](https://github.com/pytorch/pytorch/blob/19f4c5110e8bcad5e7e75375194262fca0a6293a/torch/optim/functional.py#L90).
            # I guess it doesn't really matter, since the bias correction only increases the value
            # and it only makes an actual difference during the first few steps of training.
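            # (Recall that Adam's bias-corrected second moment estimate is
            # $\hat{v}_t = \frac{v_t}{1 - \beta_2^t}$, and $1 - \beta_2^t \to 1$ as $t$ grows,
            # so the corrected and uncorrected values quickly become almost identical.)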
            torch.maximum(v_max, v, out=v_max)

            return m, v_max
        else:
            # Fall back to *Adam* if the parameter group is not using `amsgrad`
            return m, v


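def _usage_sketch():
    """
    ### Usage sketch

    A minimal, hypothetical example of using `AMSGrad` as a drop-in replacement for
    `Adam`. The model, data, and hyper-parameters here are made up for illustration
    and are not part of the paper or of the synthetic experiment below.
    """
    # A toy model and a random batch of inputs and targets
    model = nn.Linear(4, 1)
    data, targets = torch.randn(16, 4), torch.randn(16, 1)
    # Create the optimizer; these hyper-parameters are just illustrative
    optimizer = AMSGrad(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
    # One training step: forward pass, loss, backward pass, update, and clear gradients
    loss = ((model(data) - targets) ** 2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

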
def _synthetic_experiment(is_adam: bool):
    """
    ## Synthetic Experiment

    This is the synthetic experiment described in the paper,
    which shows a scenario where *Adam* fails.

    The paper (and Adam) formulates the optimization problem as
    minimizing the expected value of a function, $\mathbb{E}[f(\theta)]$,
    with respect to the parameters $\theta$.
    In the stochastic training setting we do not get hold of the function $f$
    itself; that is,
    when you are optimizing a neural network, $f$ would be the function on the
    entire batch of data.
    What we actually evaluate is a mini-batch, so the actual function is a
    realization of the stochastic $f$.
    This is why we are talking about an expected value.
    So let the function realizations be $f_1, f_2, ..., f_T$ for each time step
    of training.

    We measure the performance of the optimizer as the regret,
    $$R(T) = \sum_{t=1}^T \big[ f_t(\theta_t) - f_t(\theta^*) \big]$$
    where $\theta_t$ is the parameters at time step $t$, and $\theta^*$ is the
    optimal parameters that minimize $\mathbb{E}[f(\theta)]$.

    Now let's define the synthetic problem,
    \begin{align}
    f_t(x) =
    \begin{cases}
    1010 x, & \text{for $t \mod 101 = 1$} \\
    -10 x, & \text{otherwise}
    \end{cases}
    \end{align}
    where $-1 \le x \le +1$.
    The optimal solution is $x = -1$.

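    To see why, note that any window of $101$ consecutive steps contributes
    $$1010 x - 100 \cdot 10 x = 10 x$$
    to the objective, so in the long run the loss increases with $x$ and is minimized
    at the boundary $x = -1$, even though $100$ out of every $101$ steps push $x$ towards $+1$.
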
    This code will try running *Adam* and *AMSGrad* on this problem.
    """

    # Define the $x$ parameter
    x = nn.Parameter(torch.tensor([.0]))
    # Optimal, $x^* = -1$
    x_star = nn.Parameter(torch.tensor([-1.]), requires_grad=False)

    def func(t: int, x_: nn.Parameter):
        """
        ### $f_t(x)$
        """
        if t % 101 == 1:
            return (1010 * x_).sum()
        else:
            return (-10 * x_).sum()

    # Initialize the relevant optimizer
    if is_adam:
        optimizer = Adam([x], lr=1e-2, betas=(0.9, 0.99))
    else:
        optimizer = AMSGrad([x], lr=1e-2, betas=(0.9, 0.99))
    # $R(T)$
    total_regret = 0

    from labml import monit, tracker, experiment

    # Create an experiment to record results
    with experiment.record(name='synthetic', comment='Adam' if is_adam else 'AMSGrad'):
        # Run for $10^7$ steps
        for step in monit.loop(10_000_000):
            # $f_t(\theta_t) - f_t(\theta^*)$
            regret = func(step, x) - func(step, x_star)
            # $R(T) = \sum_{t=1}^T \big[ f_t(\theta_t) - f_t(\theta^*) \big]$
            total_regret += regret.item()
            # Track results every 1,000 steps
            if (step + 1) % 1000 == 0:
                tracker.save(loss=regret, x=x, regret=total_regret / (step + 1))
            # Calculate gradients
            regret.backward()
            # Optimize
            optimizer.step()
            # Clear gradients
            optimizer.zero_grad()

            # Make sure $-1 \le x \le +1$
            x.data.clamp_(-1., +1.)


if __name__ == '__main__':
    # Run the synthetic experiment with *Adam*.
    # [Here are the results](https://app.labml.ai/run/61ebfdaa384411eb94d8acde48001122).
    # You can see that Adam converges to $x = +1$.
    _synthetic_experiment(True)
    # Run the synthetic experiment with *AMSGrad*.
    # [Here are the results](https://app.labml.ai/run/uuid=bc06405c384411eb8b82acde48001122).
    # You can see that AMSGrad converges to the true optimum, $x = -1$.
    _synthetic_experiment(False)