මෙමක්රියාත්මක කිරීම පදනම් වී ඇත්තේ කඩදාසි නිල වශයෙන් ක්රියාත්මක කිරීම මත අනුවර්තී ඉගෙනුම් අනුපාතය සහ ඉන් ඔබ්බට විචල්යතාව මත ය.
අපගේ AMSGrad ක්රියාත්මක කිරීමේ දිගුවක් ලෙස අපි එය PyTorch හි ක්රියාත්මක කර ඇති අතර එමඟින් ක්රියාත්මක කළ යුතු වෙනස්කම් පමණක් අවශ්ය වේ.
පුහුණුවේආරම්භක අදියරවලදී ඇඩම් ප්රශස්තකරණය සමහර විට නරක දේශීය ප්රශස්තිකරණයකට අභිසාරී වේ; විශේෂයෙන් ට්රාන්ස්ෆෝමර් පුහුණු කිරීමේදී. පර්යේෂණයන් මෙය මැඩපැවැත්වීම සඳහා උණුසුම් කිරීම් භාවිතා කරයි; මූලික පුහුණු පියවර සඳහා (උණුසුම් අවධිය) ඔවුන් අඩු ඉගෙනුම් අනුපාතයක් භාවිතා කරයි. පුහුණුවේ ආරම්භක අදියරවලදී අනුවර්තී ඉගෙනුම් අනුපාතයේ ඉහළ විචලතාව මෙම ලිපිය මඟින් ගැටළුව හඳුනා ගන්නා අතර විචලතාව අඩු කිරීම සඳහා නව නිවැරදි කිරීමේ යෙදුමක් භාවිතා කරමින් එය ගණන් කරයි.
කඩදාසිවිචල්යතා අඩු කිරීමේ යාන්ත්රණ දෙකක් ද ඇගයීමට ලක් කරයි: ඇඩම්-2K: පරාමිතීන් වෙනස් නොකර හෝ ගම්යතාව ගණනය නොකර පළමු 2k පියවර තුළ අනුවර්තී ඉගෙනුම් අනුපාතය ( ආදම්හි) පමණක් ගණනය කරන්න ( ). ඇඩම්-ඊපීඑස්: ආදම් විශාල .
ගම්යතාව සහ අනුවර්තී ඉගෙනුම් අනුපාතය ගණනය කිරීම සඳහා කාර්යයන් කරමු. ආදම් සඳහා, ඔවුන්
ඝාතීයචලනය වන සාමාන්යය බෙදා හැරීම සරල චලනය වන සාමාන්යයක් ලෙස ආසන්න කළ හැකිය.
මෙන්නඅපි අවසාන ශ්රේණියේ සරල චලනය වන සාමාන්යය ගන්නෙමු. පහත සඳහන් දෑ තෘප්තිමත් කරයි,
ලබාදෙන,
ඉහළින්අපට කොහේද තිබේ . මෙහි සම්මත අපගමනය හා ගම්යතාව සඳහා වඩා වෙනස් බව සලකන්න.
පරිමාණය කරන ලද ප්රතිලෝම චි-චතුරස්රාකාර යනු සාමාන්ය බෙදාහැරීම් වල මධ්යන්යයේ ප්රතිලෝම චතුරස්රාකාර බෙදා හැරීමයි. කොහෙද .
ඔවුන්විචලනය වන විට අඩු වන බව ඔවුහු ඔප්පු කරති.
එබැවින්විචලතාව උපරිම වශයෙන් අවම කර ඇත. අවම විචලතාව වීමට ඉඩ දෙන්න
අනුවර්තීඉගෙනුම් අනුපාතයට ස්ථාවර විචලතාවයක් ඇති බව සහතික කිරීම සඳහා, අපි විචලනය නිවැරදි කරමු
ඔවුන් 🤪 පළමු පිණිස පුළුල් මත පදනම් තක්සේරු මම එය ව්යුත්පන්න කරන ආකාරය ලැබුණේ නැහැ.
බෙදා හැරීමේ සිට අප සතුව ඇත,
ලබාදෙන,
අපසතුව ඇත
කොහේද? Lt සහ පියවර වන්න , සහ පියවරෙන් පියවර නිවැරදි කිරීමේ පදය වන්න .
මෙයලබා දෙයි,
139import math
140from typing import Dict, Optional
141
142import torch
143
144from labml_nn.optimizers import WeightDecay
145from labml_nn.optimizers.amsgrad import AMSGrad
148class RAdam(AMSGrad):
params
යනු පරාමිතීන් ලැයිස්තුවයි lr
යනු ඉගෙනුම් අනුපාතයයි betas
(, ) ක tuple වේ eps
හෝ මත පදනම් වේ optimized_update
weight_decay
WeightDecay
අර්ථ දක්වා ඇති පන්තියේ අවස්ථාවකි __init__.py
optimized_update
එකතු කිරීමෙන් පසු එය කිරීමෙන් දෙවන මොහොතේ පක්ෂග්රාහීව නිවැරදි කිරීම ප්රශස්ත කිරීම සඳහා ධජයකි amsgrad
ආදම් සරල කිරීම සඳහා AMSGrad හෝ වැටීම භාවිතා කළ යුතුද යන්න දැක්වෙන ධජයකි degenerate_to_sgd
නිවැරදි කිරීමේ පදය නොසැලකිය හැකි විට sgd භාවිතා කළ යුතුද යන්න. defaults
කණ්ඩායම් අගයන් සඳහා පෙරනිමි ශබ්ද කෝෂයකි. ඔබට පන්තිය දීර් extend කිරීමට අවශ්ය විට මෙය ප්රයෝජනවත් RAdam
වේ. 155 def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
156 weight_decay: WeightDecay = WeightDecay(),
157 optimized_update: bool = True,
158 amsgrad=False,
159 degenerated_to_sgd=True, defaults=None):
175 self.degenerated_to_sgd = degenerated_to_sgd
176 super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
state
පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor) group
පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි grad
පරාමිතිය සඳහා වත්මන් ඵලය අනුක්රමික tensor වේ param
පරාමිතිය tensor වේ 178 def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
බරක්ෂය වීම ගණනය කරන්න
189 grad = self.weight_decay(param, grad, group)
ලබාගන්න සහ ; එනම් සහ පක්ෂග්රාහී නිවැරදි කිරීමකින් තොරව
192 m, v = self.get_mv(state, group, grad)
ප්රශස්තිකරණපියවර ගණන ගණනය කරන්න
195 state['step'] += 1
RadAM යාවත්කාලීන කිරීම සිදු
198 self.r_adam_update(state, group, param, m, v)
200 @staticmethod
201 def calc_rectification_term(beta2: float, step: int) -> Optional[float]:
207 beta2_t = beta2 ** step
209 rho_inf = 2 / (1 - beta2) - 1
211 rho = rho_inf - 2 * step * beta2_t / (1 - beta2_t)
විට සොයාගත හැකිය . එය ආසන්න අගයක් බැවින් අපි තව ටිකක් ගතානුගතික වෙමු
215 if rho >= 5:
217 r2 = (rho - 4) / (rho_inf - 4) * (rho - 2) / rho * rho_inf / (rho_inf - 2)
218 return math.sqrt(r2)
219 else:
220 return None
state
පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor) group
පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි param
පරාමිතිය tensor වේ m
v
සහ නිවැරදි නොකළ පළමු හා දෙවන අවස්ථා සහ ; i.e. පක්ෂග්රාහී නිවැරදි කිරීමකින් තොරව222 def r_adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
223 m: torch.Tensor, v: torch.Tensor):
ලබා ගන්න
235 beta1, beta2 = group['betas']
සඳහානැඹුරුව නිවැරදි කිරීමේ පදය ,
237 bias_correction1 = 1 - beta1 ** state['step']
සඳහානැඹුරුව නිවැරදි කිරීමේ පදය ,
239 bias_correction2 = 1 - beta2 ** state['step']
240
241 r = self.calc_rectification_term(beta2, state['step'])
ඉගෙනුම්අනුපාතය ලබා ගන්න
244 lr = self.get_lr(state, group)
ඇද ගත නොහැකි නම්
247 if r is not None:
Scalarගණනය ඒකාබද්ධ විසින් ගණනය උපරිම ඵල ලබා ගැනීම සඳහා යන්න
249 if self.optimized_update:
නිගණ්ඨයා
251 denominator = v.sqrt().add_(group['eps'])
පියවරප්රමාණය
253 step_size = lr * math.sqrt(bias_correction2) * r / bias_correction1
පරාමිතීන්යාවත්කාලීන කරන්න
256 param.data.addcdiv_(m, denominator, value=-step_size)
ප්රශස්තිකරණයකින්තොරව ගණනය කිරීම
258 else:
නිගණ්ඨයා
260 denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
පියවරප්රමාණය
262 step_size = lr * r / bias_correction1
පරාමිතීන්යාවත්කාලීන කරන්න
265 param.data.addcdiv_(m, denominator, value=-step_size)
ලබාගත නොහැකි නම් ගම්යතාව සමග SGD කරන්න
268 elif self.degenerated_to_sgd:
පියවරප්රමාණය
270 step_size = lr / bias_correction1
පරාමිතීන්යාවත්කාලීන කරන්න
273 param.data.add_(m, alpha=-step_size)
276def _test_rectification_term():
282 import matplotlib.pyplot as plt
283 import numpy as np
284
285 beta2 = [0.9999, 0.999, 0.99, 0.9, 0.8, 0.6, 0.5]
286 plt.plot(np.arange(1, 5_000), [[RAdam.calc_rectification_term(b, i) for b in beta2] for i in range(1, 5_000)])
287 plt.legend(beta2)
288 plt.title("Optimizer")
289 plt.show()
290
291
292if __name__ == '__main__':
293 _test_rectification_term()