මෙයපදනම් වී ඇත්තේ AadaBelief කඩදාසි නිල වශයෙන් ක්රියාත්මක කිරීමෙනි AadaBelief Optimizer: නිරීක්ෂණය කරන ලද ශ්රේණියේ විශ්වාසය අනුව පියවර අනුවර්තනය කිරීම .
මෙය RADAM හි දිගුවක් ලෙස PyTorch හි ක්රියාත්මක වේ.
ආදම්ප්රශස්තකරණය සහ ඇඩබලිෆ් අතර ඇති ප්රධාන වෙනස නම්, එය අනුවර්තී ඉගෙනුම් අනුපාතය ගණනය කරන්නේ කෙසේද යන්නයි; අනුක්රමික වර්ග වල on ාතීය චලනය වන සාමාන්යයෙන් බෙදීම වෙනුවට, ඇඩබීලීෆ් විචලනය වන on ාතීය මධ්යන්යයෙන් බෙදේ.
🤔කඩදාසි විචලතාව ගණනය කරයි , නමුත් එය නැඹුරුව නිවැරදි කළ ගම්යතාව භාවිතා කළ යුතු යැයි මට හැඟේ . නැඹුරුව නිවැරදි කිරීම මූලික පුහුණු පියවරවලින් පසුව වන බැවින් මෙය බොහෝ දේට බලපාන්නේ නැතැයි මම සිතමි.
36from typing import Dict, Any
37
38import torch
39from torch import nn
40
41from labml_nn.optimizers import WeightDecay
42from labml_nn.optimizers.radam import RAdam45class AdaBelief(RAdam):params
 යනු පරාමිතීන් ලැයිස්තුවයි lr
 යනු ඉගෙනුම් අනුපාතයයි  betas
 (, ) ක tuple වේ eps
  හෝ මත  පදනම් වේ optimized_update
 weight_decay
 WeightDecay
 අර්ථ දක්වා ඇති පන්තියේ අවස්ථාවකි __init__.py
 optimized_update
 එකතු කිරීමෙන් පසු එය කිරීමෙන් දෙවන මොහොතේ පක්ෂග්රාහීව නිවැරදි කිරීම ප්රශස්ත කිරීම සඳහා ධජයකි  amsgrad
 ආදම් සරල කිරීම සඳහා AMSGrad හෝ වැටීම භාවිතා කළ යුතුද යන්න දැක්වෙන ධජයකි degenerate_to_sgd
 නිවැරදි කිරීමේ පදය  සොයාගත නොහැකි විට sgd භාවිතා කළ යුතුද යන්න rectify
 RaDam යාවත්කාලීන කිරීම භාවිතා කළ යුතුද යන්න defaults
 කණ්ඩායම් අගයන් සඳහා පෙරනිමි ශබ්ද කෝෂයකි. ඔබට පන්තිය දීර් extend කිරීමට අවශ්ය විට මෙය ප්රයෝජනවත් AdaBelief
වේ. 52    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
53                 weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
54                 degenerate_to_sgd=True,
55                 rectify=True, defaults=None):73        defaults = {} if defaults is None else defaults
74        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, degenerate_to_sgd, defaults)
75        self.rectify = rectifystate
 පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor) group
 පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි param
 පරාමිතිය tensor වේ 77    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):85        state['step'] = 0ඵලයඅනුක්රමික වටිනාකම් ඝාතීය වෙනස්වන සාමාන්යය
87        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)විචලතාවයේඝාතීය චලනය වන සාමාන්යය
89        state['exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)amsgrad
 ධජය මෙම පරාමිතිය පිරිසක් True
 සඳහා නම්, අපි විචලතාව ඝාතීය වෙනස්වන සාමාන්ය උපරිම පවත්වා 
93        if group['amsgrad']:සියලුඑක්ස්ප්රස් උපරිම පවත්වාගෙන යයි. ගමන් Avg. වර්ග. ශ්රේණියේ. අගයන්
95            state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)state
 පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor) group
 පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි grad
 පරාමිතිය  සඳහා වත්මන් ඵලය අනුක්රමික tensor වේ 97    def get_ms(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):ලබා ගන්න
107        beta1, beta2 = group['betas']ලබා ගන්න
110        m, s = state['exp_avg'], state['exp_avg_var']ස්ථානයෙහිගණනය කිරීම
114        m.mul_(beta1).add_(grad, alpha=1 - beta1)ශ්රේණියසහ ගම්යතාව අතර වෙනස
116        grad_residual = grad - mස්ථානයෙහිගණනය කිරීම
119        s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)මෙමපරාමිති කණ්ඩායම භාවිතා කරන්නේ නම් amsgrad
 
122        if group['amsgrad']:ලබාගන්න .
124            s_max = state['max_exp_avg_var']ගණනයකරන්න .
126            torch.maximum(s_max, s, out=s_max)
127
128            return m, s_max
129        else:සහ වෙනත් ආකාරයකින්
131            return m, sstate
 පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor) group
 පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි grad
 පරාමිතිය  සඳහා වත්මන් ඵලය අනුක්රමික tensor වේ  param
 පරාමිතිය tensor වේ 133    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):බරක්ෂය වීම ගණනය කරන්න
144        grad = self.weight_decay(param, grad, group)ලබා ගන්න
147        m, s = self.get_ms(state, group, grad)ප්රශස්තිකරණපියවර ගණන වැඩි කරන්න
150        state['step'] += 1
151
152        if not self.rectify:155            self.adam_update(state, group, param, m, s + group['eps'])
156        else: