この MNIST の例では、これらのオプティマイザーを使用しています。
このファイルは、Adam の共通基本クラスとその拡張を定義しています。基本クラスは、再利用が可能なため、最小限のコードで他のオプティマイザを実装するのに役立ちます
。また、L2の重み減衰用の特別なクラスを定義しているので、各オプティマイザー内に実装する必要がなく、オプティマイザーを変更せずにL1のような他の重み減衰にも簡単に拡張できます。
PyTorch オプティマイザの概念は次のとおりです。
PyTorch オプティマイザーは、パラメーターをグループと呼ばれるセットにグループ化します。各グループには、学習率などの独自のハイパーパラメータを設定できます
。たいていの場合、グループが 1 つしかありません。このとき、オプティマイザを次のように初期化します
。Optimizer(model.parameters())オプティマイザを初期化するときに、複数のパラメータグループを定義できます。
Optimizer([{'params': model1.parameters()}, {'params': model2.parameters(), 'lr': 2}])ここにグループのリストを渡します。各グループは辞書で、パラメータは 'params' です。任意のハイパーパラメータも指定します。ハイパーパラメータが定義されていない場合は、デフォルトでオプティマイザレベルのデフォルトになります
。を使用してこれらのグループとそのハイパーパラメータにアクセスしたり、変更したりすることができます。optimizer.param_groups
私が出会ったほとんどの学習率スケジュールの実装は、これにアクセスして「lr」を変更します
オプティマイザーは、各パラメーター (テンソル) の状態 (辞書) を辞書に保持します。optimizer.state
ここで、オプティマイザーは指数平均などを管理します
62from typing import Dict, Tuple, Any
63
64import torch
65from torch import nn
66from torch.optim.optimizer import Optimizer69class GenericAdaptiveOptimizer(Optimizer):params
パラメータのコレクションまたはパラメータグループのセットです。defaults
デフォルトのハイパーパラメータの辞書lr
は学習率 betas
はタプルです eps
は 74    def __init__(self, params, defaults: Dict[str, Any], lr: float, betas: Tuple[float, float], eps: float):ハイパーパラメータを確認
86        if not 0.0 <= lr:
87            raise ValueError(f"Invalid learning rate: {lr}")
88        if not 0.0 <= eps:
89            raise ValueError(f"Invalid epsilon value: {eps}")
90        if not 0.0 <= betas[0] < 1.0:
91            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
92        if not 0.0 <= betas[1] < 1.0:
93            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")ハイパーパラメータをデフォルトに追加
96        defaults.update(dict(lr=lr, betas=betas, eps=eps))PyTorch オプティマイザーを初期化します。これにより、デフォルトのハイパーパラメータを使用してパラメータグループが作成されます
99        super().__init__(params, defaults)state
これをオーバーライドしてパラメータを初期化するコードを使うべきです。param
group
param
が属するパラメータグループディクショナリです。
101    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):108        passこれをオーバーライドして、param
テンソルで最適化ステップを実行する必要があります。ここでgrad
、はそのパラメーターの勾配、はそのパラメーターのオプティマイザー状態ディクショナリ、state
group
はディクショナリが属するパラメーターグループです。 param
110    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.Tensor):119        pass121    @torch.no_grad()
122    def step(self, closure=None):損失を計算します。
🤔 いつこれが必要なのかわかりません。自分で呼び出すのではなく、loss.backward
損失を計算して損失を出して返す関数を定義すれば、その関数を渡せると思いますoptimizer.step
。🤷♂️
133        loss = None
134        if closure is not None:
135            with torch.enable_grad():
136                loss = closure()パラメータグループを繰り返し処理する
139        for group in self.param_groups:パラメータグループ内のパラメータを繰り返し処理します
141            for param in group['params']:パラメータにグラデーションがない場合はスキップ
143                if param.grad is None:
144                    continue勾配テンソルを取得
146                grad = param.grad.dataスパースグラデーションは扱いません
148                if grad.is_sparse:
149                    raise RuntimeError('GenericAdaptiveOptimizer does not support sparse gradients,'
150                                       ' please consider SparseAdam instead')パラメータの状態を取得
153                state = self.state[param]状態が初期化されていない場合は状態を初期化します
156                if len(state) == 0:
157                    self.init_state(state, group, param)パラメータの最適化手順を実行してください
160                self.step_param(state, group, grad, param)決済から計算した損失額を返金
163        return loss166class WeightDecay:weight_decay
は減衰係数weight_decouple
グラデーションにウェイトディケイを追加するか、パラメータから直接ディケイを加えるかを示すフラグです。グラデーションに追加すると、通常のオプティマイザーの更新が行われますabsolute
このフラグは重量減衰係数が絶対値かどうかを示します。これは、ディケイをパラメータに直接適用する場合に適用できます。これが false の場合、実際の減衰は weight_decay
learning_rate
。171    def __init__(self, weight_decay: float = 0., weight_decouple: bool = True, absolute: bool = False):ハイパーパラメータをチェック
184        if not 0.0 <= weight_decay:
185            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
186
187        self.absolute = absolute
188        self.weight_decouple = weight_decouple
189        self.weight_decay = weight_decayパラメータグループのデフォルト値を返す
191    def defaults(self):195        return dict(weight_decay=self.weight_decay)197    def __call__(self, param: torch.nn.Parameter, grad: torch.Tensor, group: Dict[str, any]):パラメータで直接ディケイを行う場合
203        if self.weight_decouple:重量減衰係数が絶対値の場合
205            if self.absolute:
206                param.data.mul_(1.0 - group['weight_decay'])それ以外の場合は、
208            else:
209                param.data.mul_(1.0 - group['lr'] * group['weight_decay'])変更されていないグラデーションを返す
211            return grad
212        else:
213            if group['weight_decay'] != 0:グラデーションにウェイトディケイを追加し、変更したグラデーションを返します。
215                return grad.add(param.data, alpha=group['weight_decay'])
216            else:
217                return grad