これは、「Adabeliefオプティマイザー:観測された勾配を信じてステップサイズを調整する」という論文のAdableLief公式実装に基づいています。
これは RadAM の拡張機能として PyTorch に実装されています。
Adam オプティマイザーと Adabelief の主な違いは、適応型学習率の計算方法にあります。Adabelief では、勾配の 2 乗の指数移動平均で割るのではなく、指数関数的分散平均で除算されます。
🤔 論文では分散を次のように計算していますが、バイアス補正されたモメンタムを使用すべきだと思います。バイアス補正は最初のトレーニングステップの後に行われるので、これはあまり影響しないと思います。
36from typing import Dict, Any
37
38import torch
39from torch import nn
40
41from labml_nn.optimizers import WeightDecay
42from labml_nn.optimizers.radam import RAdamparams
はパラメータのリストですlr
は学習率 betas
(,)  のタプルです eps
またはそれに基づいている optimized_update
weight_decay
WeightDecay
で定義されているクラスのインスタンスです __init__.py
optimized_update
セカンドモーメントのバイアス補正を加算してから行うことで最適化するか否かのフラグです amsgrad
amsGradを使用するか、プレーンなAdamにフォールバックするかを示すフラグですdegenerate_to_sgd
修正項が扱いにくい場合に sgd を使うかどうか rectify
RadAMアップデートを使用するかどうかですdefaults
グループ値のデフォルト辞書です。これは、クラスを拡張する場合に便利ですAdaBelief
。52    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
53                 weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
54                 degenerate_to_sgd=True,
55                 rectify=True, defaults=None):73        defaults = {} if defaults is None else defaults
74        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, degenerate_to_sgd, defaults)
75        self.rectify = rectify77    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):85        state['step'] = 0勾配値の指数移動平均
87        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)指数移動平均偏差
89        state['exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)amsgrad
True
このパラメータグループにフラグを指定すると、指数移動平均の最大分散値が維持されます。
93        if group['amsgrad']:すべての許容偏差移動平均値の最大値を維持
95            state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)97    def get_ms(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):取得して
107        beta1, beta2 = group['betas']取得して
110        m, s = state['exp_avg'], state['exp_avg_var']のインプレース計算
114        m.mul_(beta1).add_(grad, alpha=1 - beta1)勾配と運動量の違い
116        grad_residual = grad - mのインプレース計算
119        s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)このパラメータグループが使用している場合 amsgrad
122        if group['amsgrad']:取得。
124            s_max = state['max_exp_avg_var']計算。
126            torch.maximum(s_max, s, out=s_max)
127
128            return m, s_max
129        else:それ以外は
131            return m, sstate
はパラメーター (テンソル) のオプティマイザー状態ですgroup
パラメータグループのオプティマイザ属性を格納しますgrad
パラメータの現在の勾配テンソルです param
はパラメータテンソル 133    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):体重減少の計算
144        grad = self.weight_decay(param, grad, group)取得して
147        m, s = self.get_ms(state, group, grad)オプティマイザーのステップ数を増やす
150        state['step'] += 1
151
152        if not self.rectify:155            self.adam_update(state, group, param, m, s + group['eps'])
156        else: