これは、論文「アダム:確率的最適化の方法」に掲載された人気のオプティマイザーAdamをPyTorchで実装したものです。
アダムのアップデートは、
ここで、、およびはスカラーのハイパーパラメーターです。ファーストオーダー、セカンドオーダーの瞬間です 偏り修正されたモーメントです。ゼロエラーによる除算の修正として使われますが、勾配のばらつきに対して作用するハイパーパラメータの形式としても機能します
。有効な手順は、「This が制限される」、「いつ」、「それ以外の場合」を前提としています。そして、最も一般的なシナリオでは、
40import math
41from typing import Dict, Any, Tuple, Optional
42
43import torch
44from labml import tracker
45from torch import nn
46
47from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay50class Adam(GenericAdaptiveOptimizer):params
はパラメータのリストですlr
は学習率 betas
(,)  のタプルです eps
またはそれに基づいている optimized_update
weight_decay
WeightDecay
で定義されているクラスのインスタンスです __init__.py
optimized_update
セカンドモーメントのバイアス補正を加算してから行うことで最適化するか否かのフラグです defaults
グループ値のデフォルト辞書です。これは、クラスを拡張する場合に便利ですAdam
。58    def __init__(self, params,
59                 lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
60                 weight_decay: WeightDecay = WeightDecay(),
61                 optimized_update: bool = True,
62                 defaults: Optional[Dict[str, Any]] = None):76        defaults = {} if defaults is None else defaults
77        defaults.update(weight_decay.defaults())
78        super().__init__(params, defaults, lr, betas, eps)
79
80        self.weight_decay = weight_decay
81        self.optimized_update = optimized_update83    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):これは、パラメーターに対して実行されたオプティマイザーステップの数です。
93        state['step'] = 0勾配の指数移動平均、
95        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)二乗勾配値の指数移動平均、
97        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)99    def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):取得して
109        beta1, beta2 = group['betas']取得して
112        m, v = state['exp_avg'], state['exp_avg_sq']のインプレース計算
116        m.mul_(beta1).add_(grad, alpha=1 - beta1)のインプレース計算
119        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
120
121        return m, v123    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):131        return group['lr']state
はパラメーター (テンソル) のオプティマイザー状態ですgroup
パラメータグループのオプティマイザ属性を格納しますparam
はパラメータテンソル m
v
そして未修正の第一瞬間と第二瞬間と これにより、以下が計算されます
、はスカラーで、その他はテンソルなので、この計算を変更して計算を最適化します。
ここで、ハイパーパラメータとして指定する必要があります。
133    def adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
134                    m: torch.Tensor, v: torch.Tensor):取得して
166        beta1, beta2 = group['betas']のバイアス補正用語
168        bias_correction1 = 1 - beta1 ** state['step']のバイアス補正用語
170        bias_correction2 = 1 - beta2 ** state['step']学習率を取得
173        lr = self.get_lr(state, group)計算を最適化するかどうか
176        if self.optimized_update:178            denominator = v.sqrt().add_(group['eps'])180            step_size = lr * math.sqrt(bias_correction2) / bias_correction1183            param.data.addcdiv_(m, denominator, value=-step_size)最適化なしの計算
185        else:187            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])189            step_size = lr / bias_correction1192            param.data.addcdiv_(m, denominator, value=-step_size)state
はパラメーター (テンソル) のオプティマイザー状態ですgroup
パラメータグループのオプティマイザ属性を格納しますgrad
パラメータの現在の勾配テンソルです param
はパラメータテンソル 194    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):体重減少の計算
205        grad = self.weight_decay(param, grad, group)取得して
208        m, v = self.get_mv(state, group, grad)オプティマイザーのステップ数を増やす
211        state['step'] += 1Adam アップデートを実行
214        self.adam_update(state, group, param, m, v)