from typing import Tuple

import torch

from labml.configs import BaseConfigs, option, meta_config
from labml_nn.optimizers import WeightDecay


class OptimizerConfigs(BaseConfigs):
    # Optimizer
    optimizer: torch.optim.Adam

    # Weight decay
    weight_decay_obj: WeightDecay
    # Whether weight decay is decoupled;
    # i.e. weight decay is not added to the gradients
    weight_decouple: bool = True
    # Weight decay
    weight_decay: float = 0.0
    # Whether weight decay is absolute or should be multiplied by the learning rate
    weight_decay_absolute: bool = False

    # Whether the Adam update is optimized (uses a different epsilon)
    optimized_adam_update: bool = True

    # Parameters to be optimized
    parameters: any

    # Learning rate
    learning_rate: float = 0.01
    # Beta values for Adam
    betas: Tuple[float, float] = (0.9, 0.999)
    # Epsilon for Adam
    eps: float = 1e-08

    # Momentum for SGD
    momentum: float = 0.5

    # Whether to use AMSGrad
    amsgrad: bool = False

    # Number of warmup optimizer steps
    warmup: int = 2_000
    # Total number of optimizer steps (for cosine decay)
    total_steps: int = int(1e10)

    # Whether to degenerate to SGD in AdaBelief
    degenerate_to_sgd: bool = True

    # Whether to use Rectified Adam in AdaBelief
    rectify: bool = True

    # Model embedding size for the Noam optimizer
    d_model: int

    def __init__(self):
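        # Note (an added reading, based on how `labml_nn` experiments use this class,
        # not stated in this file): `_primary='optimizer'` marks `optimizer` as the
        # primary config, so an `OptimizerConfigs` instance returned from another
        # `@option` function resolves to the computed optimizer rather than to the
        # configs object itself.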
        super().__init__(_primary='optimizer')


meta_config(OptimizerConfigs.parameters)


@option(OptimizerConfigs.weight_decay_obj, 'L2')
def _weight_decay(c: OptimizerConfigs):
    return WeightDecay(c.weight_decay, c.weight_decouple, c.weight_decay_absolute)
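# Note (added): the custom optimizers below receive this `WeightDecay` helper
# (`c.weight_decay_obj`) as their `weight_decay` argument, whereas the plain
# `torch.optim.SGD` option uses the raw `c.weight_decay` float directly.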


@option(OptimizerConfigs.optimizer, 'SGD')
def _sgd_optimizer(c: OptimizerConfigs):
    return torch.optim.SGD(c.parameters, c.learning_rate, c.momentum,
                           weight_decay=c.weight_decay)


@option(OptimizerConfigs.optimizer, 'Adam')
def _adam_optimizer(c: OptimizerConfigs):
    if c.amsgrad:
        from labml_nn.optimizers.amsgrad import AMSGrad
        return AMSGrad(c.parameters,
                       lr=c.learning_rate, betas=c.betas, eps=c.eps,
                       optimized_update=c.optimized_adam_update,
                       weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad)
    else:
        from labml_nn.optimizers.adam import Adam
        return Adam(c.parameters,
                    lr=c.learning_rate, betas=c.betas, eps=c.eps,
                    optimized_update=c.optimized_adam_update,
                    weight_decay=c.weight_decay_obj)


@option(OptimizerConfigs.optimizer, 'AdamW')
def _adam_warmup_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.adam_warmup import AdamWarmup
    return AdamWarmup(c.parameters,
                      lr=c.learning_rate, betas=c.betas, eps=c.eps,
                      weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad, warmup=c.warmup)


@option(OptimizerConfigs.optimizer, 'RAdam')
def _radam_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.radam import RAdam
    return RAdam(c.parameters,
                 lr=c.learning_rate, betas=c.betas, eps=c.eps,
                 weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,
                 degenerated_to_sgd=c.degenerate_to_sgd)


@option(OptimizerConfigs.optimizer, 'AdaBelief')
def _ada_belief_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.ada_belief import AdaBelief
    return AdaBelief(c.parameters,
                     lr=c.learning_rate, betas=c.betas, eps=c.eps,
                     weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,
                     degenerate_to_sgd=c.degenerate_to_sgd,
                     rectify=c.rectify)


@option(OptimizerConfigs.optimizer, 'Noam')
def _noam_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.noam import Noam
    return Noam(c.parameters,
                lr=c.learning_rate, betas=c.betas, eps=c.eps,
                weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad, warmup=c.warmup,
                d_model=c.d_model)


@option(OptimizerConfigs.optimizer, 'AdamWarmupCosineDecay')
def _adam_warmup_cosine_decay_optimizer(c: OptimizerConfigs):
    from labml_nn.optimizers.adam_warmup_cosine_decay import AdamWarmupCosineDecay
    return AdamWarmupCosineDecay(c.parameters,
                                 lr=c.learning_rate, betas=c.betas, eps=c.eps,
                                 weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,
                                 warmup=c.warmup, total_steps=c.total_steps)
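

# Usage sketch (added; not part of the original module). In `labml_nn` experiments an
# `OptimizerConfigs` instance is typically built inside an `@option` function of the
# experiment's own configs and returned; because of `_primary='optimizer'` it resolves
# to the selected optimizer. `MyExperimentConfigs` and `c.model` below are hypothetical
# placeholders, and the exact resolution behaviour is an assumption about `labml.configs`.
#
#     @option(MyExperimentConfigs.optimizer)
#     def _optimizer(c: MyExperimentConfigs):
#         opt_conf = OptimizerConfigs()
#         opt_conf.parameters = c.model.parameters()
#         opt_conf.optimizer = 'Adam'        # pick the 'Adam' option registered above
#         opt_conf.learning_rate = 3e-4
#         opt_conf.weight_decay = 0.01
#         return opt_conf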