此 MNIST 示例使用了这些优化器。
这个文件定义了 Adam 的通用基类及其扩展。由于可重用性,基类有助于以最少的代码实现其他优化器。
我们还为 L2 权重衰减定义了一个特殊的类,这样我们就不必在每个优化器中实现它,并且可以在不更改优化器的情况下轻松扩展到其他权重衰减,例如 L1。
以下是关于 PyTorch 优化器的一些概念:
PyTorch 优化器将参数分组到名为组的集合中。每个组可以有自己的超参数,例如学习率。
在大多数情况下,只有一组。这是你使用初始化优化器的时候,
Optimizer(model.parameters())在初始化优化器时,可以定义多个参数组:
Optimizer([{'params': model1.parameters()}, {'params': model2.parameters(), 'lr': 2}])在这里,我们传递一个组列表。每个组都是一个字典,其参数位于键 “params” 下。您也可以指定任何超参数。如果未定义 hyper 参数,它们将默认为优化程序级别的默认值。
您可以使用访问(甚至更改)这些组及其超参数optimizer.param_groups
。我遇到的大多数学习率计划实现都访问了这个并更改了 “lr”。
Optimizer 在字典中维护每个参数(张量)的状态(字典)optimizer.state
。这是优化器维护指数平均值之类的东西的地方。
62from typing import Dict, Tuple, Any
63
64import torch
65from torch import nn
66from torch.optim.optimizer import Optimizer69class GenericAdaptiveOptimizer(Optimizer):74    def __init__(self, params, defaults: Dict[str, Any], lr: float, betas: Tuple[float, float], eps: float):检查超参数
86        if not 0.0 <= lr:
87            raise ValueError(f"Invalid learning rate: {lr}")
88        if not 0.0 <= eps:
89            raise ValueError(f"Invalid epsilon value: {eps}")
90        if not 0.0 <= betas[0] < 1.0:
91            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
92        if not 0.0 <= betas[1] < 1.0:
93            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")将超参数添加到默认值
96        defaults.update(dict(lr=lr, betas=betas, eps=eps))初始化 PyTorch 优化器。这将使用默认的超参数创建参数组
99        super().__init__(params, defaults)101    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):108        pass110    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.Tensor):119        pass121    @torch.no_grad()
122    def step(self, closure=None):133        loss = None
134        if closure is not None:
135            with torch.enable_grad():
136                loss = closure()遍历参数组
139        for group in self.param_groups:遍历参数组中的参数
141            for param in group['params']:如果参数没有渐变,则跳过
143                if param.grad is None:
144                    continue获取梯度张量
146                grad = param.grad.data我们不处理稀疏渐变
148                if grad.is_sparse:
149                    raise RuntimeError('GenericAdaptiveOptimizer does not support sparse gradients,'
150                                       ' please consider SparseAdam instead')获取参数的状态
153                state = self.state[param]如果状态未初始化,则初始化状态
156                if len(state) == 0:
157                    self.init_state(state, group, param)对参数采取优化步骤
160                self.step_param(state, group, grad, param)返回从闭包计算得出的损失
163        return loss166class WeightDecay:weight_decay
是衰减系数weight_decouple
是一个标志,指示是将权重衰减添加到梯度还是直接从参数中衰减。如果添加到渐变中,它将通过普通的优化器更新。absolute
此标志指示权重衰减系数是否为绝对值。当直接对参数执行衰减时,这适用。如果此值为假,则实际衰减为weight_decay
learning_rate
。171    def __init__(self, weight_decay: float = 0., weight_decouple: bool = True, absolute: bool = False):检查超参数
184        if not 0.0 <= weight_decay:
185            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
186
187        self.absolute = absolute
188        self.weight_decouple = weight_decouple
189        self.weight_decay = weight_decay返回参数组的默认值
191    def defaults(self):195        return dict(weight_decay=self.weight_decay)197    def __call__(self, param: torch.nn.Parameter, grad: torch.Tensor, group: Dict[str, any]):如果我们直接对参数进行衰减
203        if self.weight_decouple:如果权重衰减系数为绝对值
205            if self.absolute:
206                param.data.mul_(1.0 - group['weight_decay'])否则,
208            else:
209                param.data.mul_(1.0 - group['lr'] * group['weight_decay'])返回未修改的渐变
211            return grad
212        else:
213            if group['weight_decay'] != 0:将权重衰减添加到渐变并返回修改后的渐变
215                return grad.add(param.data, alpha=group['weight_decay'])
216            else:
217                return grad