| """
 | |
| ---
 | |
| title: Noam optimizer from Attention is All You Need paper
 | |
| summary: >
 | |
|   This is a tutorial/implementation of Noam optimizer.
 | |
|   Noam optimizer has a warm-up period and then an exponentially decaying learning rate.
 | |
| ---
 | |
| 
 | |
| # Noam Optimizer
 | |
| 
 | |
| This is the [PyTorch](https://pytorch.org) implementation of optimizer introduced in the paper
 | |
| [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
 | |
| """
 | |
from typing import Dict

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad


class Noam(AMSGrad):
    """
    ## Noam Optimizer

    This class extends from the Adam optimizer defined in [`adam.py`](adam.html).
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=False,
                 warmup=0, d_model=512, defaults=None):
        """
        ### Initialize the optimizer

        * `params` is the list of parameters
        * `lr` is the learning rate $\alpha$
        * `betas` is a tuple of ($\beta_1$, $\beta_2$)
        * `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
        * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
        * `optimized_update` is a flag indicating whether to optimize the bias correction of the
          second moment by doing it after adding $\epsilon$
        * `amsgrad` is a flag indicating whether to use AMSGrad or fall back to plain Adam
        * `warmup` is the number of warm-up steps
        * `d_model` is the model size; i.e. the number of dimensions in the transformer
        * `defaults` is a dictionary of default values for parameter groups.
          This is useful when you want to extend the class `Noam`.
        """
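
        # For example (an illustrative note, not part of the original file), a schedule close to
        # the one used for the base transformer in the paper, with $d_{model} = 512$ and $4000$
        # warm-up steps, would be set up as `Noam(model.parameters(), lr=1., warmup=4000, d_model=512)`
        # (the paper additionally uses $\beta_2 = 0.98$ and $\epsilon = 10^{-9}$).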

        defaults = {} if defaults is None else defaults
        defaults.update(dict(warmup=warmup))
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
        self.d_model = d_model

    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
        """
        ### Get learning rate

        $$\alpha \frac{1}{\sqrt{d_{model}}} \min \bigg(\frac{1}{\sqrt{t}}, \frac{t}{w^{3/2}}\bigg)$$
        where $w$ is the number of warm-up steps.
        """
        # $$\min \bigg(\frac{1}{\sqrt{t}}, \frac{t}{w^{3/2}}\bigg)$$
        factor = min(state['step'] ** (-0.5), state['step'] * group['warmup'] ** (-1.5))
        # $$\alpha \frac{1}{\sqrt{d_{model}}} \min \bigg(\frac{1}{\sqrt{t}}, \frac{t}{w^{3/2}}\bigg)$$
        return group['lr'] * self.d_model ** (-0.5) * factor
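        # As an illustrative note (not part of the original file): the two terms inside the
        # $\min$ are equal at $t = w$, so the schedule peaks at the end of warm-up with
        # $$\alpha \frac{1}{\sqrt{d_{model} \cdot w}}$$
        # For instance, with $\alpha = 1$, $d_{model} = 512$ and $w = 4000$ the peak learning
        # rate is about $1 / \sqrt{512 \cdot 4000} \approx 7 \times 10^{-4}$.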


def _test_noam_lr():
    """
    ### Plot learning rate for different warmups and model sizes
    """
    import matplotlib.pyplot as plt
    import numpy as np
    from torch import nn

    # A dummy model to provide parameters for the optimizers
    model = nn.Linear(10, 10)
    # Optimizers with different model sizes and numbers of warm-up steps
    opts = [Noam(model.parameters(), d_model=512, warmup=4000, lr=1),
            Noam(model.parameters(), d_model=512, warmup=8000, lr=1),
            Noam(model.parameters(), d_model=2048, warmup=2000, lr=1)]
    # Plot the learning rate of each optimizer over 20,000 steps
    plt.plot(np.arange(1, 20000),
             [[opt.get_lr({'step': i}, opt.defaults) for opt in opts] for i in range(1, 20000)])
    plt.legend(["512:4000", "512:8000", "2048:2000"])
    plt.title("Learning Rate")
    plt.show()
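

# A minimal usage sketch (not part of the original file): it assumes a toy regression model
# and random data, and plugs `Noam` in where a standard PyTorch optimizer would go.
# The name `_example_training_loop` and all hyper-parameters here are illustrative only.
def _example_training_loop():
    import torch
    from torch import nn

    # A toy model and some random data, just to drive the optimizer
    model = nn.Linear(16, 1)
    x, y = torch.randn(64, 16), torch.randn(64, 1)
    # `Noam` is used like any other PyTorch optimizer
    optimizer = Noam(model.parameters(), lr=1., d_model=16, warmup=10)
    for _ in range(100):
        optimizer.zero_grad()
        loss = ((model(x) - y) ** 2).mean()
        loss.backward()
        optimizer.step()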


if __name__ == '__main__':
    _test_noam_lr()
