10from typing import Tuple
11
12import torch
13
14from labml.configs import BaseConfigs, option, meta_config
15from labml_nn.optimizers import WeightDecay18class OptimizerConfigs(BaseConfigs):ප්රශස්තකරණය
26 optimizer: torch.optim.Adamසිරුරේබර ක්ෂය
29 weight_decay_obj: WeightDecayබරක්ෂය වීම දිරාපත් වේද; එනම් බර ක්ෂය වීම අනුක්රමික වලට එකතු නොවේ
32 weight_decouple: bool = Trueසිරුරේබර ක්ෂය
34 weight_decay: float = 0.0බරක්ෂය වීම නිරපේක්ෂ හෝ ඉගෙනුම් අනුපාතය ගුණ කළ යුතු ද යන්න
36 weight_decay_absolute: bool = Falseආදම්යාවත්කාලීන වීමත යන්න (විවිධ epsilon)
39 optimized_adam_update: bool = Trueප්රශස්තිකරණයකළ යුතු පරාමිතීන්
42 parameters: anyඉගෙනුම්අනුපාතය
45 learning_rate: float = 0.01ආදම් සඳහා බීටා අගයන්
47 betas: Tuple[float, float] = (0.9, 0.999)ආදම් සඳහා එප්සිලන්
49 eps: float = 1e-08SGDසඳහා ගම්යතාව
52 momentum: float = 0.5AMSGradභාවිතා කළ යුතුද යන්න
54 amsgrad: bool = Falseඋනුසුම්ප්රශස්තිකරණ පියවර ගණන
57 warmup: int = 2_000ප්රශස්තිකරණපියවර ගණන (කොසයින් ක්ෂය වීම සඳහා)
59 total_steps: int = int(1e10)Adeabeliefහි SGD වෙත පරිහානියට පත් විය යුතුද යන්න
62 degenerate_to_sgd: bool = TrueAdameliefහි නිවැරදි කරන ලද ආදම් භාවිතා කළ යුතුද යන්න
65 rectify: bool = TrueNoamප්රශස්තකරණය සඳහා ආදර්ශ කාවැද්දීම ප්රමාණය
68 d_model: int70 def __init__(self):
71 super().__init__(_primary='optimizer')
72
73
74meta_config(OptimizerConfigs.parameters)77@option(OptimizerConfigs.weight_decay_obj, 'L2')
78def _weight_decay(c: OptimizerConfigs):
79 return WeightDecay(c.weight_decay, c.weight_decouple, c.weight_decay_absolute)
80
81
82@option(OptimizerConfigs.optimizer, 'SGD')
83def _sgd_optimizer(c: OptimizerConfigs):
84 return torch.optim.SGD(c.parameters, c.learning_rate, c.momentum,
85 weight_decay=c.weight_decay)
86
87
88@option(OptimizerConfigs.optimizer, 'Adam')
89def _adam_optimizer(c: OptimizerConfigs):
90 if c.amsgrad:
91 from labml_nn.optimizers.amsgrad import AMSGrad
92 return AMSGrad(c.parameters,
93 lr=c.learning_rate, betas=c.betas, eps=c.eps,
94 optimized_update=c.optimized_adam_update,
95 weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad)
96 else:
97 from labml_nn.optimizers.adam import Adam
98 return Adam(c.parameters,
99 lr=c.learning_rate, betas=c.betas, eps=c.eps,
100 optimized_update=c.optimized_adam_update,
101 weight_decay=c.weight_decay_obj)
102
103
104@option(OptimizerConfigs.optimizer, 'AdamW')
105def _adam_warmup_optimizer(c: OptimizerConfigs):
106 from labml_nn.optimizers.adam_warmup import AdamWarmup
107 return AdamWarmup(c.parameters,
108 lr=c.learning_rate, betas=c.betas, eps=c.eps,
109 weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad, warmup=c.warmup)
110
111
112@option(OptimizerConfigs.optimizer, 'RAdam')
113def _radam_optimizer(c: OptimizerConfigs):
114 from labml_nn.optimizers.radam import RAdam
115 return RAdam(c.parameters,
116 lr=c.learning_rate, betas=c.betas, eps=c.eps,
117 weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,
118 degenerated_to_sgd=c.degenerate_to_sgd)
119
120
121@option(OptimizerConfigs.optimizer, 'AdaBelief')
122def _ada_belief_optimizer(c: OptimizerConfigs):
123 from labml_nn.optimizers.ada_belief import AdaBelief
124 return AdaBelief(c.parameters,
125 lr=c.learning_rate, betas=c.betas, eps=c.eps,
126 weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,
127 degenerate_to_sgd=c.degenerate_to_sgd,
128 rectify=c.rectify)
129
130
131@option(OptimizerConfigs.optimizer, 'Noam')
132def _noam_optimizer(c: OptimizerConfigs):
133 from labml_nn.optimizers.noam import Noam
134 return Noam(c.parameters,
135 lr=c.learning_rate, betas=c.betas, eps=c.eps,
136 weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad, warmup=c.warmup,
137 d_model=c.d_model)
138
139
140@option(OptimizerConfigs.optimizer, 'AdamWarmupCosineDecay')
141def _noam_optimizer(c: OptimizerConfigs):
142 from labml_nn.optimizers.adam_warmup_cosine_decay import AdamWarmupCosineDecay
143 return AdamWarmupCosineDecay(c.parameters,
144 lr=c.learning_rate, betas=c.betas, eps=c.eps,
145 weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,
146 warmup=c.warmup, total_steps=c.total_steps)