ADABeliefප්රශස්තකරණය

මෙයපදනම් වී ඇත්තේ AadaBelief කඩදාසි නිල වශයෙන් ක්රියාත්මක කිරීමෙනි AadaBelief Optimizer: නිරීක්ෂණය කරන ලද ශ්රේණියේ විශ්වාසය අනුව පියවර අනුවර්තනය කිරීම .

මෙය RADAM හි දිගුවක් ලෙස PyTorch හි ක්රියාත්මක වේ.

ආදම්ප්රශස්තකරණය සහ ඇඩබලිෆ් අතර ඇති ප්රධාන වෙනස නම්, එය අනුවර්තී ඉගෙනුම් අනුපාතය ගණනය කරන්නේ කෙසේද යන්නයි; අනුක්රමික වර්ග වල on ාතීය චලනය වන සාමාන්යයෙන් බෙදීම වෙනුවට, ඇඩබීලීෆ් විචලනය වන on ාතීය මධ්යන්යයෙන් බෙදේ.

🤔කඩදාසි විචලතාව ගණනය කරයි , නමුත් එය නැඹුරුව නිවැරදි කළ ගම්යතාව භාවිතා කළ යුතු යැයි මට හැඟේ . නැඹුරුව නිවැරදි කිරීම මූලික පුහුණු පියවරවලින් පසුව වන බැවින් මෙය බොහෝ දේට බලපාන්නේ නැතැයි මම සිතමි.

36from typing import Dict, Any
37
38import torch
39from torch import nn
40
41from labml_nn.optimizers import WeightDecay
42from labml_nn.optimizers.radam import RAdam

ADABeliefප්රශස්තකරණය

මෙමපන්තිය අර්ථ දක්වා ඇති RadAM ප්රශස්තකරණයෙන් විහිදේ radam.py .

45class AdaBelief(RAdam):

ප්රශස්තකරණයආරම්භ කරන්න

  • params යනු පරාමිතීන් ලැයිස්තුවයි
  • lr යනු ඉගෙනුම් අනුපාතයයි
  • betas (, ) ක tuple වේ
  • eps හෝ මත පදනම් වේ optimized_update
  • weight_decay WeightDecay අර්ථ දක්වා ඇති පන්තියේ අවස්ථාවකි __init__.py
  • optimized_update එකතු කිරීමෙන් පසු එය කිරීමෙන් දෙවන මොහොතේ පක්ෂග්රාහීව නිවැරදි කිරීම ප්රශස්ත කිරීම සඳහා ධජයකි
  • amsgrad ආදම් සරල කිරීම සඳහා AMSGrad හෝ වැටීම භාවිතා කළ යුතුද යන්න දැක්වෙන ධජයකි
  • degenerate_to_sgd නිවැරදි කිරීමේ පදය සොයාගත නොහැකි විට sgd භාවිතා කළ යුතුද යන්න
  • rectify RaDam යාවත්කාලීන කිරීම භාවිතා කළ යුතුද යන්න
  • defaults කණ්ඩායම් අගයන් සඳහා පෙරනිමි ශබ්ද කෝෂයකි. ඔබට පන්තිය දීර් extend කිරීමට අවශ්ය විට මෙය ප්රයෝජනවත් AdaBelief වේ.
52    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
53                 weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
54                 degenerate_to_sgd=True,
55                 rectify=True, defaults=None):
73        defaults = {} if defaults is None else defaults
74        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, degenerate_to_sgd, defaults)
75        self.rectify = rectify

පරාමිතිතත්වයක් ආරම්භ කරන්න

  • state පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor)
  • group පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි
  • param පරාමිතිය tensor වේ
  • 77    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
    85        state['step'] = 0

    ඵලයඅනුක්රමික වටිනාකම් ඝාතීය වෙනස්වන සාමාන්යය

    87        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    විචලතාවයේඝාතීය චලනය වන සාමාන්යය

    89        state['exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    amsgrad ධජය මෙම පරාමිතිය පිරිසක් True සඳහා නම්, අපි විචලතාව ඝාතීය වෙනස්වන සාමාන්ය උපරිම පවත්වා

    93        if group['amsgrad']:

    සියලුඑක්ස්ප්රස් උපරිම පවත්වාගෙන යයි. ගමන් Avg. වර්ග. ශ්රේණියේ. අගයන්

    95            state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    ගණනයකරන්න සහ හෝ

    • state පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor)
    • group පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි
  • grad පරාමිතිය සඳහා වත්මන් ඵලය අනුක්රමික tensor වේ
  • 97    def get_ms(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):

    ලබා ගන්න

    107        beta1, beta2 = group['betas']

    ලබා ගන්න

    110        m, s = state['exp_avg'], state['exp_avg_var']

    ස්ථානයෙහිගණනය කිරීම

    114        m.mul_(beta1).add_(grad, alpha=1 - beta1)

    ශ්රේණියසහ ගම්යතාව අතර වෙනස

    116        grad_residual = grad - m

    ස්ථානයෙහිගණනය කිරීම

    119        s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)

    මෙමපරාමිති කණ්ඩායම භාවිතා කරන්නේ නම් amsgrad

    122        if group['amsgrad']:

    ලබාගන්න .

    124            s_max = state['max_exp_avg_var']

    ගණනයකරන්න .

    126            torch.maximum(s_max, s, out=s_max)
    127
    128            return m, s_max
    129        else:

    සහ වෙනත් ආකාරයකින්

    131            return m, s

    දීඇති පරාමිති ටෙන්සරයක් සඳහා යාවත්කාලීන පියවරක් ගන්න

    • state පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor)
    • group පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි
    • grad පරාමිතිය සඳහා වත්මන් ඵලය අනුක්රමික tensor වේ
  • param පරාමිතිය tensor වේ
  • 133    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):

    බරක්ෂය වීම ගණනය කරන්න

    144        grad = self.weight_decay(param, grad, group)

    ලබා ගන්න

    147        m, s = self.get_ms(state, group, grad)

    ප්රශස්තිකරණපියවර ගණන වැඩි කරන්න

    150        state['step'] += 1
    151
    152        if not self.rectify:

    ආදම් යාවත්කාලීන කිරීම සිදු කරන්න adam.py , අර්ථ දක්වා ඇත, වෙනුවට .

    155            self.adam_update(state, group, param, m, s + group['eps'])
    156        else:

    නිවැරදිකරන ලද ආදම් යාවත්කාලීන කිරීම සිදු කරන්න radam.py , වෙනුවට අර්ථ දක්වා ඇත.

    159            self.r_adam_update(state, group, param, m, s + group['eps'])