මෙය PyTorch ක්රියාත්මක කිරීමකි ජනප්රිය ප්රශස්තිකරණය කඩදාසි ඇඩම් වෙතින් ඇඩම්: ස්ටොචාස්ටික් ප්රශස්තිකරණය සඳහා ක්රමයක් .
ආදම් යාවත්කාලීන කිරීම,
කොහෙද , සහ පරිමාණ අධි පරාමිතීන් වේ. සහ පළමු හා දෙවන ඇණවුම් අවස්ථා වේ. සහ පක්ෂග්රාහී නිවැරදි අවස්ථාවන් වේ. ශුන්ය දෝෂයකින් බෙදීම සඳහා විසඳුමක් ලෙස භාවිතා කරයි, නමුත් අනුක්රමික විචල්යතාවයට එරෙහිව ක්රියා කරන අධි-පරාමිතියක ආකාරයක් ලෙසද ක්රියා කරයි.
උපකල්පනයකරමින් ගන්නා ලද effective ලදායී පියවර වන්නේ, මෙය මායිම් කරනු ලබන්නේ කවදාද සහ වෙනත් ආකාරයකින් ය. සහ වඩාත් පොදු අවස්ථාවන්හීදී,
40import math
41from typing import Dict, Any, Tuple, Optional
42
43import torch
44from labml import tracker
45from torch import nn
46
47from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay
ආදම්ප්රශස්තකරණය ක්රියාත්මක කිරීම __init__.py
සඳහා අපි GenericAdaptiveOptimizer
අර්ථ දක්වා ඇති පන්තිය දීර් extend කරමු.
50class Adam(GenericAdaptiveOptimizer):
params
යනු පරාමිතීන් ලැයිස්තුවයි lr
යනු ඉගෙනුම් අනුපාතයයි betas
(, ) ක tuple වේ eps
හෝ මත පදනම් වේ optimized_update
weight_decay
WeightDecay
අර්ථ දක්වා ඇති පන්තියේ අවස්ථාවකි __init__.py
optimized_update
එකතු කිරීමෙන් පසු එය කිරීමෙන් දෙවන මොහොතේ පක්ෂග්රාහීව නිවැරදි කිරීම ප්රශස්ත කිරීම සඳහා ධජයකි defaults
කණ්ඩායම් අගයන් සඳහා පෙරනිමි ශබ්ද කෝෂයකි. ඔබට පන්තිය දීර් extend කිරීමට අවශ්ය විට මෙය ප්රයෝජනවත් Adam
වේ. 58 def __init__(self, params,
59 lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
60 weight_decay: WeightDecay = WeightDecay(),
61 optimized_update: bool = True,
62 defaults: Optional[Dict[str, Any]] = None):
76 defaults = {} if defaults is None else defaults
77 defaults.update(weight_decay.defaults())
78 super().__init__(params, defaults, lr, betas, eps)
79
80 self.weight_decay = weight_decay
81 self.optimized_update = optimized_update
state
පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor) group
පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි param
පරාමිතිය tensor වේ 83 def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
පරාමිතියමත ගෙන ඇති ප්රශස්තිකරණ පියවර ගණන මෙයයි,
93 state['step'] = 0
අනුක්රමිකක ඝාතීය වෙනස්වන සාමාන්යය,
95 state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
වර්ගඵලය අනුක්රමික වටිනාකම් ඝාතීය වෙනස්වන සාමාන්යය,
97 state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
state
පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor) group
පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි grad
පරාමිතිය සඳහා වත්මන් ඵලය අනුක්රමික tensor වේ 99 def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
ලබා ගන්න
109 beta1, beta2 = group['betas']
ලබා ගන්න
112 m, v = state['exp_avg'], state['exp_avg_sq']
ස්ථානයෙහිගණනය කිරීම
116 m.mul_(beta1).add_(grad, alpha=1 - beta1)
ස්ථානයෙහිගණනය කිරීම
119 v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
120
121 return m, v
මෙයරාජ්යය මත පදනම්ව නවීකරණය කරන ලද ඉගෙනුම් අනුපාතය නැවත ලබා දෙයි. ආදම් සඳහා මෙය පරාමිති කණ්ඩායම සඳහා නිශ්චිත ඉගෙනුම් අනුපාතය පමණි, .
123 def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
131 return group['lr']
state
පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor) group
පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි param
පරාමිතිය tensor වේ m
v
සහ නිවැරදි නොකළ පළමු හා දෙවන අවස්ථා සහ . මෙයපහත සඳහන් දේ ගණනය කරයි
සිට , සහ පරිමාණයන් වන අතර අනෙක් ඒවා අපි මෙම ගණනය කිරීම වෙනස් කරන ආතතීන් වේ ගණනය කිරීම ප්රශස්ත කරන්න.
අධි-පරාමිතියලෙස අප සඳහන් කළ යුත්තේ කොහේද?
133 def adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
134 m: torch.Tensor, v: torch.Tensor):
ලබා ගන්න
166 beta1, beta2 = group['betas']
සඳහානැඹුරුව නිවැරදි කිරීමේ පදය ,
168 bias_correction1 = 1 - beta1 ** state['step']
සඳහානැඹුරුව නිවැරදි කිරීමේ පදය ,
170 bias_correction2 = 1 - beta2 ** state['step']
ඉගෙනුම්අනුපාතය ලබා ගන්න
173 lr = self.get_lr(state, group)
මෙමගණනය උපරිම ඵල ලබා ගැනීම සඳහා යන්න
176 if self.optimized_update:
178 denominator = v.sqrt().add_(group['eps'])
180 step_size = lr * math.sqrt(bias_correction2) / bias_correction1
183 param.data.addcdiv_(m, denominator, value=-step_size)
ප්රශස්තිකරණයකින්තොරව ගණනය කිරීම
185 else:
187 denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
189 step_size = lr / bias_correction1
192 param.data.addcdiv_(m, denominator, value=-step_size)
state
පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor) group
පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි grad
පරාමිතිය සඳහා වත්මන් ඵලය අනුක්රමික tensor වේ param
පරාමිතිය tensor වේ 194 def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
බරක්ෂය වීම ගණනය කරන්න
205 grad = self.weight_decay(param, grad, group)
ලබා ගන්න
208 m, v = self.get_mv(state, group, grad)
ප්රශස්තිකරණපියවර ගණන වැඩි කරන්න
211 state['step'] += 1
ආදම් යාවත්කාලීන කිරීම සිදු
214 self.adam_update(state, group, param, m, v)