| """
 | |
| ---
 | |
| title: Adam optimizer with warm-up
 | |
| summary: A simple PyTorch implementation/tutorial of Adam optimizer with warm-up.
 | |
| ---
 | |
| 
 | |
| # Adam Optimizer with Warmup
 | |
| 
 | |
| This extends [AMSGrad optimizer](amsgrad.html) and adds a warmup stage.
 | |
| """
 | |
| 
 | |
| from typing import Dict
 | |
| 
 | |
| from labml_nn.optimizers import WeightDecay
 | |
| from labml_nn.optimizers.amsgrad import AMSGrad
 | |
| 
 | |
| 
 | |
| class AdamWarmup(AMSGrad):
 | |
|     """
 | |
|     ## Adam Optimizer with Warmup
 | |
| 
 | |
|     This class extends from AMSGrad optimizer defined in [`amsgrad.py`](amsgrad.html).
 | |
|     """
 | |
|     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
 | |
|                  weight_decay: WeightDecay = WeightDecay(),
 | |
|                  optimized_update: bool = True,
 | |
|                  amsgrad=False, warmup=0, defaults=None):
 | |
|         """
 | |
|         ### Initialize the optimizer
 | |
| 
 | |
|         * `params` is the list of parameters
 | |
|         * `lr` is the learning rate $\alpha$
 | |
|         * `betas` is a tuple of ($\beta_1$, $\beta_2$)
 | |
|         * `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
 | |
|         * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
 | |
|         * 'optimized_update' is a flag whether to optimize the bias correction of the second moment
 | |
|           by doing it after adding $\epsilon$
 | |
|         * `amsgrad` is a flag indicating whether to use AMSGrad or fallback to plain Adam
 | |
|         * `warmup` number of warmup steps
 | |
|         * `defaults` is a dictionary of default for group values.
 | |
|          This is useful when you want to extend the class `AdamWarmup`.
 | |
|         """
 | |
| 
 | |
|         defaults = {} if defaults is None else defaults
 | |
|         defaults.update(dict(warmup=warmup))
 | |
|         super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
 | |
| 
 | |
|     def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
 | |
|         """
 | |
|         ### Get learning-rate
 | |
| 
 | |
|         $$\alpha \min \bigg(1, \frac{t}{w}\bigg)$$
 | |
|         where $w$ is the number of warmup steps.
 | |
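
        For example, with $w = 1000$ warmup steps and $\alpha = 10^{-3}$,
        the learning rate at step $t = 100$ is $\alpha \cdot \frac{t}{w} = 10^{-4}$.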
        """
        # If we are in the warmup stage
        if group['warmup'] > state['step']:
            # A linearly increasing learning rate from $0$ to $\alpha$
            return 1e-8 + state['step'] * group['lr'] / group['warmup']
        else:
            # Constant learning rate $\alpha$
            return group['lr']
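

# The following is a minimal usage sketch, not part of the original module.
# The toy model, the made-up loss, and the hyper-parameter values are
# illustrative assumptions; only the `AdamWarmup` constructor arguments
# come from the class above.
def _usage_sketch():
    import torch
    from torch import nn

    # Any model would do; a single linear layer keeps the sketch short
    model = nn.Linear(16, 2)
    # Learning rate ramps up linearly over the first 2,000 steps, then stays at 2.5e-4
    optimizer = AdamWarmup(model.parameters(), lr=2.5e-4, warmup=2_000)

    for _ in range(10):
        # A made-up loss on random data
        loss = model(torch.randn(8, 16)).square().mean()
        optimizer.zero_grad()
        loss.backward()
        # `step()` calls `get_lr`, which applies the warmup schedule
        optimizer.step()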
