මෙමක්රියාත්මක කිරීම පදනම් වී ඇත්තේ කඩදාසි නිල වශයෙන් ක්රියාත්මක කිරීම මත අනුවර්තී ඉගෙනුම් අනුපාතය සහ ඉන් ඔබ්බට විචල්යතාව මත ය.
අපගේ AMSGrad ක්රියාත්මක කිරීමේ දිගුවක් ලෙස අපි එය PyTorch හි ක්රියාත්මක කර ඇති අතර එමඟින් ක්රියාත්මක කළ යුතු වෙනස්කම් පමණක් අවශ්ය වේ.
පුහුණුවේආරම්භක අදියරවලදී ඇඩම් ප්රශස්තකරණය සමහර විට නරක දේශීය ප්රශස්තිකරණයකට අභිසාරී වේ; විශේෂයෙන් ට්රාන්ස්ෆෝමර් පුහුණු කිරීමේදී. පර්යේෂණයන් මෙය මැඩපැවැත්වීම සඳහා උණුසුම් කිරීම් භාවිතා කරයි; මූලික පුහුණු පියවර සඳහා (උණුසුම් අවධිය) ඔවුන් අඩු ඉගෙනුම් අනුපාතයක් භාවිතා කරයි. පුහුණුවේ ආරම්භක අදියරවලදී අනුවර්තී ඉගෙනුම් අනුපාතයේ ඉහළ විචලතාව මෙම ලිපිය මඟින් ගැටළුව හඳුනා ගන්නා අතර විචලතාව අඩු කිරීම සඳහා නව නිවැරදි කිරීමේ යෙදුමක් භාවිතා කරමින් එය ගණන් කරයි.
කඩදාසිවිචල්යතා අඩු කිරීමේ යාන්ත්රණ දෙකක් ද ඇගයීමට ලක් කරයි: ඇඩම්-2K: පරාමිතීන් වෙනස් නොකර හෝ ගම්යතාව ගණනය නොකර පළමු 2k පියවර තුළ අනුවර්තී ඉගෙනුම් අනුපාතය ( ආදම්හි) පමණක් ගණනය කරන්න ( ). ඇඩම්-ඊපීඑස්: ආදම් විශාල .
ගම්යතාව සහ අනුවර්තී ඉගෙනුම් අනුපාතය ගණනය කිරීම සඳහා කාර්යයන් කරමු. ආදම් සඳහා, ඔවුන්
ඝාතීයචලනය වන සාමාන්යය බෙදා හැරීම සරල චලනය වන සාමාන්යයක් ලෙස ආසන්න කළ හැකිය.
මෙන්නඅපි අවසාන ශ්රේණියේ සරල චලනය වන සාමාන්යය ගන්නෙමු. පහත සඳහන් දෑ තෘප්තිමත් කරයි,
ලබාදෙන,
ඉහළින්අපට කොහේද තිබේ . මෙහි සම්මත අපගමනය හා ගම්යතාව සඳහා වඩා වෙනස් බව සලකන්න.
පරිමාණය කරන ලද ප්රතිලෝම චි-චතුරස්රාකාර යනු සාමාන්ය බෙදාහැරීම් වල මධ්යන්යයේ ප්රතිලෝම චතුරස්රාකාර බෙදා හැරීමයි. කොහෙද .
ඔවුන්විචලනය වන විට අඩු වන බව ඔවුහු ඔප්පු කරති.
එබැවින්විචලතාව උපරිම වශයෙන් අවම කර ඇත. අවම විචලතාව වීමට ඉඩ දෙන්න
අනුවර්තීඉගෙනුම් අනුපාතයට ස්ථාවර විචලතාවයක් ඇති බව සහතික කිරීම සඳහා, අපි විචලනය නිවැරදි කරමු
ඔවුන් 🤪 පළමු පිණිස පුළුල් මත පදනම් තක්සේරු මම එය ව්යුත්පන්න කරන ආකාරය ලැබුණේ නැහැ.
බෙදා හැරීමේ සිට අප සතුව ඇත,
ලබාදෙන,
අපසතුව ඇත
කොහේද? Lt සහ පියවර වන්න , සහ පියවරෙන් පියවර නිවැරදි කිරීමේ පදය වන්න .
මෙයලබා දෙයි,
139import math
140from typing import Dict, Optional
141
142import torch
143
144from labml_nn.optimizers import WeightDecay
145from labml_nn.optimizers.amsgrad import AMSGrad148class RAdam(AMSGrad):params
 යනු පරාමිතීන් ලැයිස්තුවයි lr
 යනු ඉගෙනුම් අනුපාතයයි  betas
 (, ) ක tuple වේ eps
  හෝ මත  පදනම් වේ optimized_update
 weight_decay
 WeightDecay
 අර්ථ දක්වා ඇති පන්තියේ අවස්ථාවකි __init__.py
 optimized_update
 එකතු කිරීමෙන් පසු එය කිරීමෙන් දෙවන මොහොතේ පක්ෂග්රාහීව නිවැරදි කිරීම ප්රශස්ත කිරීම සඳහා ධජයකි  amsgrad
 ආදම් සරල කිරීම සඳහා AMSGrad හෝ වැටීම භාවිතා කළ යුතුද යන්න දැක්වෙන ධජයකි degenerate_to_sgd
 නිවැරදි කිරීමේ පදය  නොසැලකිය හැකි විට sgd භාවිතා කළ යුතුද යන්න. defaults
 කණ්ඩායම් අගයන් සඳහා පෙරනිමි ශබ්ද කෝෂයකි. ඔබට පන්තිය දීර් extend කිරීමට අවශ්ය විට මෙය ප්රයෝජනවත් RAdam
වේ. 155    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
156                 weight_decay: WeightDecay = WeightDecay(),
157                 optimized_update: bool = True,
158                 amsgrad=False,
159                 degenerated_to_sgd=True, defaults=None):175        self.degenerated_to_sgd = degenerated_to_sgd
176        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)state
 පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor) group
 පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි grad
 පරාමිතිය  සඳහා වත්මන් ඵලය අනුක්රමික tensor වේ  param
 පරාමිතිය tensor වේ 178    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):බරක්ෂය වීම ගණනය කරන්න
189        grad = self.weight_decay(param, grad, group)ලබාගන්න සහ ; එනම් සහ පක්ෂග්රාහී නිවැරදි කිරීමකින් තොරව
192        m, v = self.get_mv(state, group, grad)ප්රශස්තිකරණපියවර ගණන ගණනය කරන්න
195        state['step'] += 1RadAM යාවත්කාලීන කිරීම සිදු
198        self.r_adam_update(state, group, param, m, v)200    @staticmethod
201    def calc_rectification_term(beta2: float, step: int) -> Optional[float]:207        beta2_t = beta2 ** step209        rho_inf = 2 / (1 - beta2) - 1211        rho = rho_inf - 2 * step * beta2_t / (1 - beta2_t)විට සොයාගත හැකිය . එය ආසන්න අගයක් බැවින් අපි තව ටිකක් ගතානුගතික වෙමු
215        if rho >= 5:217            r2 = (rho - 4) / (rho_inf - 4) * (rho - 2) / rho * rho_inf / (rho_inf - 2)
218            return math.sqrt(r2)
219        else:
220            return Nonestate
 පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor) group
 පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි param
 පරාමිතිය tensor වේ  m
 v
 සහ නිවැරදි නොකළ පළමු හා දෙවන අවස්ථා  සහ ; i.e.   පක්ෂග්රාහී නිවැරදි කිරීමකින් තොරව222    def r_adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
223                      m: torch.Tensor, v: torch.Tensor):ලබා ගන්න
235        beta1, beta2 = group['betas']සඳහානැඹුරුව නිවැරදි කිරීමේ පදය ,
237        bias_correction1 = 1 - beta1 ** state['step']සඳහානැඹුරුව නිවැරදි කිරීමේ පදය ,
239        bias_correction2 = 1 - beta2 ** state['step']
240
241        r = self.calc_rectification_term(beta2, state['step'])ඉගෙනුම්අනුපාතය ලබා ගන්න
244        lr = self.get_lr(state, group)ඇද ගත නොහැකි නම්
247        if r is not None:Scalarගණනය ඒකාබද්ධ විසින් ගණනය උපරිම ඵල ලබා ගැනීම සඳහා යන්න
249            if self.optimized_update:නිගණ්ඨයා
251                denominator = v.sqrt().add_(group['eps'])පියවරප්රමාණය
253                step_size = lr * math.sqrt(bias_correction2) * r / bias_correction1පරාමිතීන්යාවත්කාලීන කරන්න
256                param.data.addcdiv_(m, denominator, value=-step_size)ප්රශස්තිකරණයකින්තොරව ගණනය කිරීම
258            else:නිගණ්ඨයා
260                denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])පියවරප්රමාණය
262                step_size = lr * r / bias_correction1පරාමිතීන්යාවත්කාලීන කරන්න
265                param.data.addcdiv_(m, denominator, value=-step_size)ලබාගත නොහැකි නම් ගම්යතාව සමග SGD කරන්න
268        elif self.degenerated_to_sgd:පියවරප්රමාණය
270            step_size = lr / bias_correction1පරාමිතීන්යාවත්කාලීන කරන්න
273            param.data.add_(m, alpha=-step_size)276def _test_rectification_term():282    import matplotlib.pyplot as plt
283    import numpy as np
284
285    beta2 = [0.9999, 0.999, 0.99, 0.9, 0.8, 0.6, 0.5]
286    plt.plot(np.arange(1, 5_000), [[RAdam.calc_rectification_term(b, i) for b in beta2] for i in range(1, 5_000)])
287    plt.legend(beta2)
288    plt.title("Optimizer")
289    plt.show()
290
291
292if __name__ == '__main__':
293    _test_rectification_term()