ආදම්ප්රශස්තකරණය

මෙය PyTorch ක්රියාත්මක කිරීමකි ජනප්රිය ප්රශස්තිකරණය කඩදාසි ඇඩම් වෙතින් ඇඩම්: ස්ටොචාස්ටික් ප්රශස්තිකරණය සඳහා ක්රමයක් .

ආදම් යාවත්කාලීන කිරීම,

කොහෙද , සහ පරිමාණ අධි පරාමිතීන් වේ. සහ පළමු හා දෙවන ඇණවුම් අවස්ථා වේ. සහ පක්ෂග්රාහී නිවැරදි අවස්ථාවන් වේ. ශුන්ය දෝෂයකින් බෙදීම සඳහා විසඳුමක් ලෙස භාවිතා කරයි, නමුත් අනුක්රමික විචල්යතාවයට එරෙහිව ක්රියා කරන අධි-පරාමිතියක ආකාරයක් ලෙසද ක්රියා කරයි.

උපකල්පනයකරමින් ගන්නා ලද effective ලදායී පියවර වන්නේ, මෙය මායිම් කරනු ලබන්නේ කවදාද සහ වෙනත් ආකාරයකින් ය. සහ වඩාත් පොදු අවස්ථාවන්හීදී,

40import math
41from typing import Dict, Any, Tuple, Optional
42
43import torch
44from labml import tracker
45from torch import nn
46
47from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay

ආදම්ප්රශස්තකරණය

ආදම්ප්රශස්තකරණය ක්රියාත්මක කිරීම __init__.py සඳහා අපි GenericAdaptiveOptimizer අර්ථ දක්වා ඇති පන්තිය දීර් extend කරමු.

50class Adam(GenericAdaptiveOptimizer):

ප්රශස්තකරණයආරම්භ කරන්න

  • params යනු පරාමිතීන් ලැයිස්තුවයි
  • lr යනු ඉගෙනුම් අනුපාතයයි
  • betas (, ) ක tuple වේ
  • eps හෝ මත පදනම් වේ optimized_update
  • weight_decay WeightDecay අර්ථ දක්වා ඇති පන්තියේ අවස්ථාවකි __init__.py
  • optimized_update එකතු කිරීමෙන් පසු එය කිරීමෙන් දෙවන මොහොතේ පක්ෂග්රාහීව නිවැරදි කිරීම ප්රශස්ත කිරීම සඳහා ධජයකි
  • defaults කණ්ඩායම් අගයන් සඳහා පෙරනිමි ශබ්ද කෝෂයකි. ඔබට පන්තිය දීර් extend කිරීමට අවශ්ය විට මෙය ප්රයෝජනවත් Adam වේ.
58    def __init__(self, params,
59                 lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
60                 weight_decay: WeightDecay = WeightDecay(),
61                 optimized_update: bool = True,
62                 defaults: Optional[Dict[str, Any]] = None):
76        defaults = {} if defaults is None else defaults
77        defaults.update(weight_decay.defaults())
78        super().__init__(params, defaults, lr, betas, eps)
79
80        self.weight_decay = weight_decay
81        self.optimized_update = optimized_update

පරාමිතිතත්වයක් ආරම්භ කරන්න

  • state පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor)
  • group පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි
  • param පරාමිතිය tensor වේ
  • 83    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):

    පරාමිතියමත ගෙන ඇති ප්රශස්තිකරණ පියවර ගණන මෙයයි,

    93        state['step'] = 0

    අනුක්රමිකක ඝාතීය වෙනස්වන සාමාන්යය,

    95        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    වර්ගඵලය අනුක්රමික වටිනාකම් ඝාතීය වෙනස්වන සාමාන්යය,

    97        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    ගණනයකරන්න සහ

    • state පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor)
    • group පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි
  • grad පරාමිතිය සඳහා වත්මන් ඵලය අනුක්රමික tensor වේ
  • 99    def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):

    ලබා ගන්න

    109        beta1, beta2 = group['betas']

    ලබා ගන්න

    112        m, v = state['exp_avg'], state['exp_avg_sq']

    ස්ථානයෙහිගණනය කිරීම

    116        m.mul_(beta1).add_(grad, alpha=1 - beta1)

    ස්ථානයෙහිගණනය කිරීම

    119        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    120
    121        return m, v

    ඉගෙනීම-අනුපාතයලබා ගන්න

    මෙයරාජ්යය මත පදනම්ව නවීකරණය කරන ලද ඉගෙනුම් අනුපාතය නැවත ලබා දෙයි. ආදම් සඳහා මෙය පරාමිති කණ්ඩායම සඳහා නිශ්චිත ඉගෙනුම් අනුපාතය පමණි, .

    123    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
    131        return group['lr']

    ආදම් පරාමිති යාවත්කාලීන කිරීම කරන්න

    • state පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor)
    • group පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි
    • param පරාමිතිය tensor වේ
    • m v සහ නිවැරදි නොකළ පළමු හා දෙවන අවස්ථා සහ .

    මෙයපහත සඳහන් දේ ගණනය කරයි

    සිට , සහ පරිමාණයන් වන අතර අනෙක් ඒවා අපි මෙම ගණනය කිරීම වෙනස් කරන ආතතීන් වේ ගණනය කිරීම ප්රශස්ත කරන්න.

    අධි-පරාමිතියලෙස අප සඳහන් කළ යුත්තේ කොහේද?

    133    def adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
    134                    m: torch.Tensor, v: torch.Tensor):

    ලබා ගන්න

    166        beta1, beta2 = group['betas']

    සඳහානැඹුරුව නිවැරදි කිරීමේ පදය ,

    168        bias_correction1 = 1 - beta1 ** state['step']

    සඳහානැඹුරුව නිවැරදි කිරීමේ පදය ,

    170        bias_correction2 = 1 - beta2 ** state['step']

    ඉගෙනුම්අනුපාතය ලබා ගන්න

    173        lr = self.get_lr(state, group)

    මෙමගණනය උපරිම ඵල ලබා ගැනීම සඳහා යන්න

    176        if self.optimized_update:

    178            denominator = v.sqrt().add_(group['eps'])

    180            step_size = lr * math.sqrt(bias_correction2) / bias_correction1

    183            param.data.addcdiv_(m, denominator, value=-step_size)

    ප්රශස්තිකරණයකින්තොරව ගණනය කිරීම

    185        else:

    187            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

    189            step_size = lr / bias_correction1

    192            param.data.addcdiv_(m, denominator, value=-step_size)

    දීඇති පරාමිති ටෙන්සරයක් සඳහා යාවත්කාලීන පියවරක් ගන්න

    • state පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor)
    • group පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි
    • grad පරාමිතිය සඳහා වත්මන් ඵලය අනුක්රමික tensor වේ
  • param පරාමිතිය tensor වේ
  • 194    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):

    බරක්ෂය වීම ගණනය කරන්න

    205        grad = self.weight_decay(param, grad, group)

    ලබා ගන්න

    208        m, v = self.get_mv(state, group, grad)

    ප්රශස්තිකරණපියවර ගණන වැඩි කරන්න

    211        state['step'] += 1

    ආදම් යාවත්කාලීන කිරීම සිදු

    214        self.adam_update(state, group, param, m, v)