from typing import Any, Dict

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad


class AdamWarmup(AMSGrad):
    """Adam / AMSGrad optimizer with a linear learning-rate warmup.

    For the first ``warmup`` steps the learning rate grows linearly towards
    ``lr``; after that it stays constant at ``lr``.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=False, warmup=0, defaults=None):
        """Initialize the optimizer.

        Args:
            params: the list of parameters.
            lr: the learning rate (alpha).
            betas: a tuple of (beta1, beta2).
            eps: epsilon-hat or epsilon, depending on ``optimized_update``.
            weight_decay: an instance of class ``WeightDecay`` defined in
                ``__init__.py``.
            optimized_update: flag forwarded to the base class; selects
                which form of epsilon ``eps`` is interpreted as.
            amsgrad: a flag indicating whether to use AMSGrad or fall back
                to plain Adam.
            warmup: number of warmup steps.
            defaults: a dictionary of defaults for group values. This is
                useful when you want to extend the class ``AdamWarmup``.
                The caller's dictionary is copied, not modified.
        """
        # NOTE(review): the default ``WeightDecay()`` is evaluated once at
        # definition time and shared by every optimizer using the default;
        # this is fine only if WeightDecay keeps no per-optimizer state — confirm.

        # Copy rather than mutate, so ``warmup`` is not injected into a dict
        # the caller (e.g. a subclass) still owns.
        defaults = {} if defaults is None else dict(defaults)
        defaults['warmup'] = warmup
        super().__init__(params, lr, betas, eps, weight_decay,
                         optimized_update, amsgrad, defaults)

    def get_lr(self, state: Dict[str, Any], group: Dict[str, Any]):
        """Return the learning rate to use for the current step.

        Args:
            state: per-parameter optimizer state; ``state['step']`` is read.
            group: parameter-group settings; ``group['lr']`` and
                ``group['warmup']`` are read.

        Returns:
            ``1e-8 + step * lr / warmup`` while ``step < warmup`` (a linear
            ramp towards ``lr``), otherwise ``lr``.
        """
        # If we are in the warmup stage
        if group['warmup'] > state['step']:
            # Linearly increasing learning rate. The 1e-8 offset keeps the
            # rate strictly positive at step 0. Here warmup > step, so the
            # division is safe for non-negative step counts.
            return 1e-8 + state['step'] * group['lr'] / group['warmup']
        # Constant learning rate once warmup is over
        return group['lr']