"""
|
|
---
|
|
title: Noam optimizer from Attention is All You Need paper
|
|
summary: >
|
|
This is a tutorial/implementation of Noam optimizer.
|
|
Noam optimizer has a warm-up period and then an exponentially decaying learning rate.
|
|
---
|
|
|
|
# Noam Optimizer
|
|
|
|
This is the [PyTorch](https://pytorch.org) implementation of optimizer introduced in the paper
|
|
[Attention Is All You Need](https://papers.labml.ai/paper/1706.03762).
|
|
"""

from typing import Any, Dict

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad


class Noam(AMSGrad):
    """
    ## Noam Optimizer

    This class extends the AMSGrad optimizer, which builds on the Adam optimizer defined in [`adam.py`](adam.html).
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=False,
                 warmup=0, d_model=512, defaults=None):
        """
        ### Initialize the optimizer

        * `params` is the list of parameters
        * `lr` is the learning rate $\alpha$
        * `betas` is a tuple of ($\beta_1$, $\beta_2$)
        * `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
        * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
        * `optimized_update` is a flag indicating whether to optimize the bias correction of the second moment
          by doing it after adding $\epsilon$
        * `amsgrad` is a flag indicating whether to use AMSGrad or fall back to plain Adam
        * `warmup` is the number of warmup steps
        * `d_model` is the model size; i.e. the number of dimensions in the transformer
        * `defaults` is a dictionary of default values for parameter groups.
          This is useful when you want to extend the class `Noam`
          (see the extension sketch at the end of this file).
        """
        defaults = {} if defaults is None else defaults
        defaults.update(dict(warmup=warmup))
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
        self.d_model = d_model

    def get_lr(self, state: Dict[str, Any], group: Dict[str, Any]):
        """
        ### Get learning-rate

        $$\alpha \frac{1}{\sqrt{d_{model}}} \min \bigg(\frac{1}{\sqrt{t}}, \frac{t}{w^{3/2}}\bigg)$$

        where $w$ is the number of warmup steps.
        The learning rate grows linearly for the first $w$ steps, peaks at step $t = w$,
        and then decays proportionally to $\frac{1}{\sqrt{t}}$.
        """
        # $$\min \bigg(\frac{1}{\sqrt{t}}, \frac{t}{w^{3/2}}\bigg)$$
        factor = min(state['step'] ** (-0.5), state['step'] * group['warmup'] ** (-1.5))
        # $$\alpha \frac{1}{\sqrt{d_{model}}} \min \bigg(\frac{1}{\sqrt{t}}, \frac{t}{w^{3/2}}\bigg)$$
        return group['lr'] * self.d_model ** (-0.5) * factor
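

# The following is a minimal usage sketch, not part of the original implementation: it shows
# how `Noam` can drive a toy training loop. The model, data, and loss below are made-up
# placeholders; only the `Noam(...)` construction and the `zero_grad()`/`step()` calls rely
# on the optimizer defined above. `lr=1` and `warmup=4000` follow the setup from the paper.
def _usage_sketch():
    import torch
    from torch import nn

    # Toy model and random data, purely for illustration
    model = nn.Linear(512, 512)
    data, target = torch.randn(16, 512), torch.randn(16, 512)
    loss_fn = nn.MSELoss()

    # `d_model` should match the embedding size of the transformer being trained
    opt = Noam(model.parameters(), lr=1., d_model=512, warmup=4000)

    for _ in range(10):
        opt.zero_grad()
        loss = loss_fn(model(data), target)
        loss.backward()
        opt.step()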


def _test_noam_lr():
    """
    ### Plot learning rate for different warmups and model sizes
    """
    import matplotlib.pyplot as plt
    import numpy as np
    from torch import nn

    model = nn.Linear(10, 10)
    opts = [Noam(model.parameters(), d_model=512, warmup=4000, lr=1),
            Noam(model.parameters(), d_model=512, warmup=8000, lr=1),
            Noam(model.parameters(), d_model=2048, warmup=2000, lr=1)]
    plt.plot(np.arange(1, 20000), [[opt.get_lr({'step': i}, opt.defaults) for opt in opts] for i in range(1, 20000)])
    plt.legend(["512:4000", "512:8000", "2048:2000"])
    plt.title("Learning Rate")
    plt.show()
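

# A hedged sketch, not from the original repository, of how the `defaults` dictionary lets a
# subclass register extra per-group hyper-parameters, as noted in `Noam.__init__`. The class
# `NoamMinLR` and its `min_lr` parameter are hypothetical names used only for illustration.
class NoamMinLR(Noam):
    def __init__(self, params, min_lr: float = 1e-5, defaults=None, **kwargs):
        # Store the extra hyper-parameter in the defaults so it appears in every parameter group
        defaults = {} if defaults is None else defaults
        defaults.update(dict(min_lr=min_lr))
        super().__init__(params, defaults=defaults, **kwargs)

    def get_lr(self, state: Dict[str, Any], group: Dict[str, Any]):
        # Clamp the Noam schedule from below by the group's `min_lr`
        return max(super().get_lr(state, group), group['min_lr'])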


if __name__ == '__main__':
    _test_noam_lr()