මෙයපදනම් වී ඇත්තේ AadaBelief කඩදාසි නිල වශයෙන් ක්රියාත්මක කිරීමෙනි AadaBelief Optimizer: නිරීක්ෂණය කරන ලද ශ්රේණියේ විශ්වාසය අනුව පියවර අනුවර්තනය කිරීම .
මෙය RADAM හි දිගුවක් ලෙස PyTorch හි ක්රියාත්මක වේ.
ආදම්ප්රශස්තකරණය සහ ඇඩබලිෆ් අතර ඇති ප්රධාන වෙනස නම්, එය අනුවර්තී ඉගෙනුම් අනුපාතය ගණනය කරන්නේ කෙසේද යන්නයි; අනුක්රමික වර්ග වල on ාතීය චලනය වන සාමාන්යයෙන් බෙදීම වෙනුවට, ඇඩබීලීෆ් විචලනය වන on ාතීය මධ්යන්යයෙන් බෙදේ.
🤔කඩදාසි විචලතාව ගණනය කරයි , නමුත් එය නැඹුරුව නිවැරදි කළ ගම්යතාව භාවිතා කළ යුතු යැයි මට හැඟේ . නැඹුරුව නිවැරදි කිරීම මූලික පුහුණු පියවරවලින් පසුව වන බැවින් මෙය බොහෝ දේට බලපාන්නේ නැතැයි මම සිතමි.
36from typing import Dict, Any
37
38import torch
39from torch import nn
40
41from labml_nn.optimizers import WeightDecay
42from labml_nn.optimizers.radam import RAdam
45class AdaBelief(RAdam):
params
යනු පරාමිතීන් ලැයිස්තුවයි lr
යනු ඉගෙනුම් අනුපාතයයි betas
(, ) ක tuple වේ eps
හෝ මත පදනම් වේ optimized_update
weight_decay
WeightDecay
අර්ථ දක්වා ඇති පන්තියේ අවස්ථාවකි __init__.py
optimized_update
එකතු කිරීමෙන් පසු එය කිරීමෙන් දෙවන මොහොතේ පක්ෂග්රාහීව නිවැරදි කිරීම ප්රශස්ත කිරීම සඳහා ධජයකි amsgrad
ආදම් සරල කිරීම සඳහා AMSGrad හෝ වැටීම භාවිතා කළ යුතුද යන්න දැක්වෙන ධජයකි degenerate_to_sgd
නිවැරදි කිරීමේ පදය සොයාගත නොහැකි විට sgd භාවිතා කළ යුතුද යන්න rectify
RaDam යාවත්කාලීන කිරීම භාවිතා කළ යුතුද යන්න defaults
කණ්ඩායම් අගයන් සඳහා පෙරනිමි ශබ්ද කෝෂයකි. ඔබට පන්තිය දීර් extend කිරීමට අවශ්ය විට මෙය ප්රයෝජනවත් AdaBelief
වේ. 52 def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
53 weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
54 degenerate_to_sgd=True,
55 rectify=True, defaults=None):
73 defaults = {} if defaults is None else defaults
74 super().__init__(params, lr, betas, eps, weight_decay, amsgrad, degenerate_to_sgd, defaults)
75 self.rectify = rectify
state
පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor) group
පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි param
පරාමිතිය tensor වේ 77 def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
85 state['step'] = 0
ඵලයඅනුක්රමික වටිනාකම් ඝාතීය වෙනස්වන සාමාන්යය
87 state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
විචලතාවයේඝාතීය චලනය වන සාමාන්යය
89 state['exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)
amsgrad
ධජය මෙම පරාමිතිය පිරිසක් True
සඳහා නම්, අපි විචලතාව ඝාතීය වෙනස්වන සාමාන්ය උපරිම පවත්වා
93 if group['amsgrad']:
සියලුඑක්ස්ප්රස් උපරිම පවත්වාගෙන යයි. ගමන් Avg. වර්ග. ශ්රේණියේ. අගයන්
95 state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)
state
පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor) group
පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි grad
පරාමිතිය සඳහා වත්මන් ඵලය අනුක්රමික tensor වේ 97 def get_ms(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
ලබා ගන්න
107 beta1, beta2 = group['betas']
ලබා ගන්න
110 m, s = state['exp_avg'], state['exp_avg_var']
ස්ථානයෙහිගණනය කිරීම
114 m.mul_(beta1).add_(grad, alpha=1 - beta1)
ශ්රේණියසහ ගම්යතාව අතර වෙනස
116 grad_residual = grad - m
ස්ථානයෙහිගණනය කිරීම
119 s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
මෙමපරාමිති කණ්ඩායම භාවිතා කරන්නේ නම් amsgrad
122 if group['amsgrad']:
ලබාගන්න .
124 s_max = state['max_exp_avg_var']
ගණනයකරන්න .
126 torch.maximum(s_max, s, out=s_max)
127
128 return m, s_max
129 else:
සහ වෙනත් ආකාරයකින්
131 return m, s
state
පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor) group
පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි grad
පරාමිතිය සඳහා වත්මන් ඵලය අනුක්රමික tensor වේ param
පරාමිතිය tensor වේ 133 def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
බරක්ෂය වීම ගණනය කරන්න
144 grad = self.weight_decay(param, grad, group)
ලබා ගන්න
147 m, s = self.get_ms(state, group, grad)
ප්රශස්තිකරණපියවර ගණන වැඩි කරන්න
150 state['step'] += 1
151
152 if not self.rectify:
155 self.adam_update(state, group, param, m, s + group['eps'])
156 else: