නිවැරදිකරන ලද ආදම් (RaDAM) ප්රශස්ති

මෙමක්රියාත්මක කිරීම පදනම් වී ඇත්තේ කඩදාසි නිල වශයෙන් ක්රියාත්මක කිරීම මත අනුවර්තී ඉගෙනුම් අනුපාතය සහ ඉන් ඔබ්බට විචල්යතාව මත ය.

අපගේ AMSGrad ක්රියාත්මක කිරීමේ දිගුවක් ලෙස අපි එය PyTorch හි ක්රියාත්මක කර ඇති අතර එමඟින් ක්රියාත්මක කළ යුතු වෙනස්කම් පමණක් අවශ්ය වේ.

පුහුණුවේආරම්භක අදියරවලදී ඇඩම් ප්රශස්තකරණය සමහර විට නරක දේශීය ප්රශස්තිකරණයකට අභිසාරී වේ; විශේෂයෙන් ට්රාන්ස්ෆෝමර් පුහුණු කිරීමේදී. පර්යේෂණයන් මෙය මැඩපැවැත්වීම සඳහා උණුසුම් කිරීම් භාවිතා කරයි; මූලික පුහුණු පියවර සඳහා (උණුසුම් අවධිය) ඔවුන් අඩු ඉගෙනුම් අනුපාතයක් භාවිතා කරයි. පුහුණුවේ ආරම්භක අදියරවලදී අනුවර්තී ඉගෙනුම් අනුපාතයේ ඉහළ විචලතාව මෙම ලිපිය මඟින් ගැටළුව හඳුනා ගන්නා අතර විචලතාව අඩු කිරීම සඳහා නව නිවැරදි කිරීමේ යෙදුමක් භාවිතා කරමින් එය ගණන් කරයි.

කඩදාසිවිචල්යතා අඩු කිරීමේ යාන්ත්රණ දෙකක් ද ඇගයීමට ලක් කරයි: ඇඩම්-2K: පරාමිතීන් වෙනස් නොකර හෝ ගම්යතාව ගණනය නොකර පළමු 2k පියවර තුළ අනුවර්තී ඉගෙනුම් අනුපාතය ( ආදම්හි) පමණක් ගණනය කරන්න ( ). ඇඩම්-ඊපීඑස්: ආදම් විශාල .

නිවැරදිකරන ලද ආදම්

ගම්යතාව සහ අනුවර්තී ඉගෙනුම් අනුපාතය ගණනය කිරීම සඳහා කාර්යයන් කරමු. ආදම් සඳහා, ඔවුන්

සරලවෙනස්වන සාමාන්යය ලෙස ඝාතීය වෙනස්වන සාමාන්යය

ඝාතීයචලනය වන සාමාන්යය බෙදා හැරීම සරල චලනය වන සාමාන්යයක් ලෙස ආසන්න කළ හැකිය.

මෙන්නඅපි අවසාන ශ්රේණියේ සරල චලනය වන සාමාන්යය ගන්නෙමු. පහත සඳහන් දෑ තෘප්තිමත් කරයි,

ලබාදෙන,

පරිමාණයකළ ප්රතිලෝම චි-කොටු

ඉහළින්අපට කොහේද තිබේ . මෙහි සම්මත අපගමනය හා ගම්යතාව සඳහා වඩා වෙනස් බව සලකන්න.

පරිමාණය කරන ලද ප්රතිලෝම චි-චතුරස්රාකාර යනු සාමාන්ය බෙදාහැරීම් වල මධ්යන්යයේ ප්රතිලෝම චතුරස්රාකාර බෙදා හැරීමයි. කොහෙද .

නිවැරදිකිරීම

ඔවුන්විචලනය වන විට අඩු වන බව ඔවුහු ඔප්පු කරති.

එබැවින්විචලතාව උපරිම වශයෙන් අවම කර ඇත. අවම විචලතාව වීමට ඉඩ දෙන්න

අනුවර්තීඉගෙනුම් අනුපාතයට ස්ථාවර විචලතාවයක් ඇති බව සහතික කිරීම සඳහා, අපි විචලනය නිවැරදි කරමු

ආසන්නකිරීම

ඔවුන් 🤪 පළමු පිණිස පුළුල් මත පදනම් තක්සේරු මම එය ව්යුත්පන්න කරන ආකාරය ලැබුණේ නැහැ.

බෙදා හැරීමේ සිට අප සතුව ඇත,

ලබාදෙන,

නිවැරදිකිරීමේ පදය

අපසතුව ඇත

කොහේද? Lt සහ පියවර වන්න , සහ පියවරෙන් පියවර නිවැරදි කිරීමේ පදය වන්න .

මෙයලබා දෙයි,

139import math
140from typing import Dict, Optional
141
142import torch
143
144from labml_nn.optimizers import WeightDecay
145from labml_nn.optimizers.amsgrad import AMSGrad

නිවැරදිකරන ලද ආදම් ප්රශස්තකරණය

මෙමපන්තිය ඇම්සාඩම් ප්රශස්තකරණයෙන් අර්ථ දක්වා ඇත amsadam.py .

148class RAdam(AMSGrad):

ප්රශස්තකරණයආරම්භ කරන්න

  • params යනු පරාමිතීන් ලැයිස්තුවයි
  • lr යනු ඉගෙනුම් අනුපාතයයි
  • betas (, ) ක tuple වේ
  • eps හෝ මත පදනම් වේ optimized_update
  • weight_decay WeightDecay අර්ථ දක්වා ඇති පන්තියේ අවස්ථාවකි __init__.py
  • optimized_update එකතු කිරීමෙන් පසු එය කිරීමෙන් දෙවන මොහොතේ පක්ෂග්රාහීව නිවැරදි කිරීම ප්රශස්ත කිරීම සඳහා ධජයකි
  • amsgrad ආදම් සරල කිරීම සඳහා AMSGrad හෝ වැටීම භාවිතා කළ යුතුද යන්න දැක්වෙන ධජයකි
  • degenerate_to_sgd නිවැරදි කිරීමේ පදය නොසැලකිය හැකි විට sgd භාවිතා කළ යුතුද යන්න.
  • defaults කණ්ඩායම් අගයන් සඳහා පෙරනිමි ශබ්ද කෝෂයකි. ඔබට පන්තිය දීර් extend කිරීමට අවශ්ය විට මෙය ප්රයෝජනවත් RAdam වේ.
155    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
156                 weight_decay: WeightDecay = WeightDecay(),
157                 optimized_update: bool = True,
158                 amsgrad=False,
159                 degenerated_to_sgd=True, defaults=None):
175        self.degenerated_to_sgd = degenerated_to_sgd
176        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)

දීඇති පරාමිති ටෙන්සරයක් සඳහා යාවත්කාලීන පියවරක් ගන්න

  • state පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor)
  • group පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි
  • grad පරාමිතිය සඳහා වත්මන් ඵලය අනුක්රමික tensor වේ
  • param පරාමිතිය tensor වේ
  • 178    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):

    බරක්ෂය වීම ගණනය කරන්න

    189        grad = self.weight_decay(param, grad, group)

    ලබාගන්න සහ ; එනම් සහ පක්ෂග්රාහී නිවැරදි කිරීමකින් තොරව

    192        m, v = self.get_mv(state, group, grad)

    ප්රශස්තිකරණපියවර ගණන ගණනය කරන්න

    195        state['step'] += 1

    RadAM යාවත්කාලීන කිරීම සිදු

    198        self.r_adam_update(state, group, param, m, v)

    නිවැරදිකිරීමේ පදය ගණනය කරන්න

    200    @staticmethod
    201    def calc_rectification_term(beta2: float, step: int) -> Optional[float]:

    207        beta2_t = beta2 ** step

    209        rho_inf = 2 / (1 - beta2) - 1

    211        rho = rho_inf - 2 * step * beta2_t / (1 - beta2_t)

    විට සොයාගත හැකිය . එය ආසන්න අගයක් බැවින් අපි තව ටිකක් ගතානුගතික වෙමු

    215        if rho >= 5:

    217            r2 = (rho - 4) / (rho_inf - 4) * (rho - 2) / rho * rho_inf / (rho_inf - 2)
    218            return math.sqrt(r2)
    219        else:
    220            return None

    RadAM පරාමිති යාවත්කාලීන කිරීම කරන්න

    • state පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor)
    • group පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි
    • param පරාමිතිය tensor වේ
    • m v සහ නිවැරදි නොකළ පළමු හා දෙවන අවස්ථා සහ ; i.e. පක්ෂග්රාහී නිවැරදි කිරීමකින් තොරව
    222    def r_adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
    223                      m: torch.Tensor, v: torch.Tensor):

    ලබා ගන්න

    235        beta1, beta2 = group['betas']

    සඳහානැඹුරුව නිවැරදි කිරීමේ පදය ,

    237        bias_correction1 = 1 - beta1 ** state['step']

    සඳහානැඹුරුව නිවැරදි කිරීමේ පදය ,

    239        bias_correction2 = 1 - beta2 ** state['step']
    240
    241        r = self.calc_rectification_term(beta2, state['step'])

    ඉගෙනුම්අනුපාතය ලබා ගන්න

    244        lr = self.get_lr(state, group)

    ඇද ගත නොහැකි නම්

    247        if r is not None:

    Scalarගණනය ඒකාබද්ධ විසින් ගණනය උපරිම ඵල ලබා ගැනීම සඳහා යන්න

    249            if self.optimized_update:

    නිගණ්ඨයා

    251                denominator = v.sqrt().add_(group['eps'])

    පියවරප්රමාණය

    253                step_size = lr * math.sqrt(bias_correction2) * r / bias_correction1

    පරාමිතීන්යාවත්කාලීන කරන්න

    256                param.data.addcdiv_(m, denominator, value=-step_size)

    ප්රශස්තිකරණයකින්තොරව ගණනය කිරීම

    258            else:

    නිගණ්ඨයා

    260                denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

    පියවරප්රමාණය

    262                step_size = lr * r / bias_correction1

    පරාමිතීන්යාවත්කාලීන කරන්න

    265                param.data.addcdiv_(m, denominator, value=-step_size)

    ලබාගත නොහැකි නම් ගම්යතාව සමග SGD කරන්න

    268        elif self.degenerated_to_sgd:

    පියවරප්රමාණය

    270            step_size = lr / bias_correction1

    පරාමිතීන්යාවත්කාලීන කරන්න

    273            param.data.add_(m, alpha=-step_size)

    විවිධ සඳහා කුමන්ත්රණය

    Plot of r_t

    276def _test_rectification_term():
    282    import matplotlib.pyplot as plt
    283    import numpy as np
    284
    285    beta2 = [0.9999, 0.999, 0.99, 0.9, 0.8, 0.6, 0.5]
    286    plt.plot(np.arange(1, 5_000), [[RAdam.calc_rectification_term(b, i) for b in beta2] for i in range(1, 5_000)])
    287    plt.legend(beta2)
    288    plt.title("Optimizer")
    289    plt.show()
    290
    291
    292if __name__ == '__main__':
    293    _test_rectification_term()