""" --- title: Adam optimizer with warm-up summary: A simple PyTorch implementation/tutorial of Adam optimizer with warm-up. --- # Adam Optimizer with Warmup This extends [AMSGrad optimizer](amsgrad.html) and adds a warmup stage. """ from typing import Dict from labml_nn.optimizers import WeightDecay from labml_nn.optimizers.amsgrad import AMSGrad class AdamWarmup(AMSGrad): """ ## Adam Optimizer with Warmup This class extends from AMSGrad optimizer defined in [`amsgrad.py`](amsgrad.html). """ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay: WeightDecay = WeightDecay(), optimized_update: bool = True, amsgrad=False, warmup=0, defaults=None): """ ### Initialize the optimizer * `params` is the list of parameters * `lr` is the learning rate $\alpha$ * `betas` is a tuple of ($\beta_1$, $\beta_2$) * `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update` * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html) * 'optimized_update' is a flag whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$ * `amsgrad` is a flag indicating whether to use AMSGrad or fallback to plain Adam * `warmup` number of warmup steps * `defaults` is a dictionary of default for group values. This is useful when you want to extend the class `AdamWarmup`. """ defaults = {} if defaults is None else defaults defaults.update(dict(warmup=warmup)) super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults) def get_lr(self, state: Dict[str, any], group: Dict[str, any]): """ ### Get learning-rate $$\alpha \min \bigg(1, \frac{t}{w}\bigg)$$ where $w$ is the number of warmup steps. """ # If we are in warmup stage if group['warmup'] > state['step']: # A linearly increasing learning rate from $0$ to $\alpha$ return 1e-8 + state['step'] * group['lr'] / group['warmup'] else: # Constant learning rate $\alpha$ return group['lr']