from typing import Tuple

import torch

from labml.configs import BaseConfigs, option, meta_config
from labml_nn.optimizers import WeightDecay


class OptimizerConfigs(BaseConfigs):
    # Optimizer
    optimizer: torch.optim.Adam
    # Weight decay
    weight_decay_obj: WeightDecay
    # Whether weight decay is decoupled; i.e. weight decay is not added to gradients
    weight_decouple: bool = True
    # Weight decay
    weight_decay: float = 0.0
    # Whether weight decay is absolute or should be multiplied by learning rate
    weight_decay_absolute: bool = False
    # Whether the Adam update is optimized (different epsilon)
    optimized_adam_update: bool = True
    # Parameters to be optimized
    parameters: any
    # Learning rate $\alpha$
    learning_rate: float = 0.01
    # Beta values $(\beta_1, \beta_2)$ for Adam
    betas: Tuple[float, float] = (0.9, 0.999)
    # Epsilon $\epsilon$ for Adam
    eps: float = 1e-08
    # Momentum for SGD
    momentum: float = 0.5
    # Whether to use AMSGrad
    amsgrad: bool = False
    # Number of warmup optimizer steps
    warmup: int = 2_000
    # Total number of optimizer steps (for cosine decay)
    total_steps: int = int(1e10)
    # Whether to degenerate to SGD in AdaBelief
    degenerate_to_sgd: bool = True
    # Whether to use Rectified Adam in AdaBelief
    rectify: bool = True
    # Model embedding size for Noam optimizer
    d_model: int

    def __init__(self):
        super().__init__(_primary='optimizer')


meta_config(OptimizerConfigs.parameters)


@option(OptimizerConfigs.weight_decay_obj, 'L2')
def _weight_decay(c: OptimizerConfigs):
    return WeightDecay(c.weight_decay, c.weight_decouple, c.weight_decay_absolute)
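
# Each `@option(OptimizerConfigs.optimizer, '<name>')` below registers a factory
# keyed by a string, so an optimizer is chosen simply by assigning that string to
# `OptimizerConfigs.optimizer` (and, per `_primary='optimizer'` above, the built
# optimizer is what this config set resolves to). A minimal sketch of that
# selection, assuming the usual labml configs flow; `model` here is hypothetical:
#
#     opt_conf = OptimizerConfigs()
#     opt_conf.parameters = model.parameters()
#     opt_conf.optimizer = 'Adam'  # picks `_adam_optimizer` below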


@option(OptimizerConfigs.optimizer, 'SGD')
def _sgd_optimizer(c: OptimizerConfigs):
    return torch.optim.SGD(c.parameters, c.learning_rate, c.momentum)


@option(OptimizerConfigs.optimizer, 'Adam')
def _adam_optimizer(c: OptimizerConfigs):
    if c.amsgrad:
        from labml_nn.optimizers.amsgrad import AMSGrad
        return AMSGrad(c.parameters,
                       lr=c.learning_rate, betas=c.betas, eps=c.eps,
                       optimized_update=c.optimized_adam_update,
                       weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad)
    else:
        from labml_nn.optimizers.adam import Adam
        return Adam(c.parameters,
                    lr=c.learning_rate, betas=c.betas, eps=c.eps,
                    optimized_update=c.optimized_adam_update,
                    weight_decay=c.weight_decay_obj)


@option(OptimizerConfigs.optimizer, 'AdamW')
def _adam_warmup_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.adam_warmup import AdamWarmup
    return AdamWarmup(c.parameters,
                      lr=c.learning_rate, betas=c.betas, eps=c.eps,
                      weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad, warmup=c.warmup)


@option(OptimizerConfigs.optimizer, 'RAdam')
def _radam_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.radam import RAdam
    return RAdam(c.parameters,
                 lr=c.learning_rate, betas=c.betas, eps=c.eps,
                 weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,
                 degenerated_to_sgd=c.degenerate_to_sgd)


@option(OptimizerConfigs.optimizer, 'AdaBelief')
def _ada_belief_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.ada_belief import AdaBelief
    return AdaBelief(c.parameters,
                     lr=c.learning_rate, betas=c.betas, eps=c.eps,
                     weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,
                     degenerate_to_sgd=c.degenerate_to_sgd,
                     rectify=c.rectify)


@option(OptimizerConfigs.optimizer, 'Noam')
def _noam_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.noam import Noam
    return Noam(c.parameters,
                lr=c.learning_rate, betas=c.betas, eps=c.eps,
                weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad, warmup=c.warmup,
                d_model=c.d_model)


@option(OptimizerConfigs.optimizer, 'AdamWarmupCosineDecay')
def _adam_warmup_cosine_decay_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.adam_warmup_cosine_decay import AdamWarmupCosineDecay
    return AdamWarmupCosineDecay(c.parameters,
                                 lr=c.learning_rate, betas=c.betas, eps=c.eps,
                                 weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,
                                 warmup=c.warmup, total_steps=c.total_steps)
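

# Usage sketch (an assumption about typical wiring, not part of this module):
# in labml_nn experiments an option of a larger config usually builds an
# `OptimizerConfigs`, points it at the model parameters and selects one of the
# keys registered above. `TrainerConfigs` and `_optimizer` are hypothetical
# names used only for illustration.
#
#     @option(TrainerConfigs.optimizer)
#     def _optimizer(c: TrainerConfigs):
#         opt_conf = OptimizerConfigs()
#         opt_conf.parameters = c.model.parameters()
#         opt_conf.optimizer = 'Noam'   # any key registered with `@option` above
#         opt_conf.d_model = c.d_model  # needed by the Noam learning-rate schedule
#         opt_conf.warmup = 2_000
#         return opt_conf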