This is a PyTorch implementation of the popular optimizer *Adam* from the paper *Adam: A Method for Stochastic Optimization*.
The Adam update is,

$$\begin{align}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\
\hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1 - \beta_2^t} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{align}$$

where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper-parameters, $m_t$ and $v_t$ are the first and second order moments, and $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moments. $\epsilon$ is used as a fix for division-by-zero errors, but it also acts as a form of hyper-parameter that works against the variance in gradients.

The effective step taken, assuming $\epsilon = 0$, is
$$\Delta t = \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t}}$$
This is bounded by
$$\vert \Delta t \vert \le \alpha \cdot \frac{1 - \beta_1}{\sqrt{1 - \beta_2}}$$
when $1 - \beta_1 \gt \sqrt{1 - \beta_2}$, and by $\vert \Delta t \vert \le \alpha$ otherwise. In most common scenarios,
$$\vert \Delta t \vert \approx \alpha$$
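To connect these equations to the implementation that follows, here is a minimal, self-contained sketch of a single Adam step on one tensor, written directly from the update rule above. The function `adam_step` and the toy values are illustrative only and are not part of this module.

```python
import torch


def adam_step(theta: torch.Tensor, grad: torch.Tensor,
              m: torch.Tensor, v: torch.Tensor, t: int,
              lr: float = 1e-3, beta1: float = 0.9, beta2: float = 0.999,
              eps: float = 1e-8):
    # First and second moment estimates (exponential moving averages)
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    # Bias-corrected moments
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # Parameter update
    theta = theta - lr * m_hat / (v_hat.sqrt() + eps)
    return theta, m, v


# One step on a toy tensor (t starts at 1)
theta = torch.zeros(3)
m, v = torch.zeros(3), torch.zeros(3)
theta, m, v = adam_step(theta, torch.tensor([0.5, -1.0, 2.0]), m, v, t=1)
```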
```python
import math
from typing import Dict, Any, Tuple, Optional

import torch
from labml import tracker
from torch import nn

from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay


class Adam(GenericAdaptiveOptimizer):
```

The constructor takes the following arguments (a usage sketch follows the constructor code below):

* `params` is the list of parameters
* `lr` is the learning rate $\alpha$
* `betas` is a tuple of ($\beta_1$, $\beta_2$)
* `eps` is $\hat{\epsilon}$ or $\epsilon$, depending on `optimized_update`
* `weight_decay` is an instance of the `WeightDecay` class defined in `__init__.py`
* `optimized_update` is a flag that decides whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$
* `defaults` is a dictionary of default values for parameter groups; this is useful when you want to extend the `Adam` class
```python
    def __init__(self, params,
                 lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
        defaults = {} if defaults is None else defaults
        defaults.update(weight_decay.defaults())
        super().__init__(params, defaults, lr, betas, eps)

        self.weight_decay = weight_decay
        self.optimized_update = optimized_update
```
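A hedged usage sketch for the constructor arguments above. The model, data, and hyper-parameter values are made up, and it assumes the `GenericAdaptiveOptimizer` base class supplies the usual `torch.optim` style `step()`/`zero_grad()` and that `WeightDecay` accepts a `weight_decay` value.

```python
import torch
from torch import nn

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.adam import Adam

# Toy model and batch, purely for illustration
model = nn.Linear(16, 4)
x, y = torch.randn(8, 16), torch.randn(8, 4)

# Construct the optimizer with explicit (assumed reasonable) values
optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay=WeightDecay(weight_decay=1e-4),
                 optimized_update=True)

# One training step
loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```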
`init_state` initializes the per-parameter state, `get_mv` updates the moment estimates in place, and `get_lr` returns the learning rate $\alpha$ for the parameter group:

```python
    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        # Number of optimizer steps taken on this parameter, $t$
        state['step'] = 0
        # Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        # Exponential moving average of squared gradient values, $v_t$
        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']
        # Get $m_{t-1}$ and $v_{t-1}$
        m, v = state['exp_avg'], state['exp_avg_sq']
        # In-place calculation of $m_t$
        m.mul_(beta1).add_(grad, alpha=1 - beta1)
        # In-place calculation of $v_t$
        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        return m, v

    def get_lr(self, state: Dict[str, Any], group: Dict[str, Any]):
        return group['lr']
```
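Why the bias correction divides by $1 - \beta^t$: with zero-initialized moments and a constant gradient $g$, the recurrence $m_t = \beta_1 m_{t-1} + (1 - \beta_1) g$ gives $m_t = (1 - \beta_1^t)\, g$, so dividing by $1 - \beta_1^t$ recovers $g$ exactly. A small numeric sketch, with values chosen arbitrarily for illustration:

```python
beta1 = 0.9
g = 2.0  # pretend the gradient is constant

m = 0.0
for t in range(1, 6):
    m = beta1 * m + (1 - beta1) * g
    m_hat = m / (1 - beta1 ** t)  # bias-corrected estimate
    print(t, round(m, 4), round(m_hat, 4))
# m is heavily biased toward 0 for small t, while m_hat equals g at every step
```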
`adam_update` performs the parameter update. Its arguments are:

* `state` is the optimizer state of the parameter (tensor)
* `group` stores the optimizer attributes of the parameter group
* `param` is the parameter tensor $\theta_{t-1}$
* `m` and `v` are the uncorrected first and second moments $m_t$ and $v_t$

This computes the following,

$$\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

Since $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalars and the others are tensors, we rearrange this calculation to optimize the computation:

$$\begin{align}
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{m_t / (1 - \beta_1^t)}{\sqrt{v_t / (1 - \beta_2^t)} + \epsilon} \\
\theta_t &\leftarrow \theta_{t-1} - \frac{\alpha \sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}
\end{align}$$

where $\hat{\epsilon} = \sqrt{1 - \beta_2^t}\, \epsilon$ is what we should specify as the hyper-parameter.
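The rearrangement only touches scalars, so the two code paths below (the "optimized" one and the plain one) produce the same update when the hyper-parameter is interpreted as $\hat{\epsilon} = \sqrt{1 - \beta_2^t}\, \epsilon$. A small sketch to check this numerically; the tensors and values are arbitrary:

```python
import math
import torch

lr, beta1, beta2, t, eps = 1e-3, 0.9, 0.999, 100, 1e-8
bc1 = 1 - beta1 ** t
bc2 = 1 - beta2 ** t
m, v = torch.rand(5), torch.rand(5)

# Plain formulation: alpha / (1 - beta1^t) * m / (sqrt(v / (1 - beta2^t)) + eps)
plain = lr / bc1 * m / (v.sqrt() / math.sqrt(bc2) + eps)

# Optimized formulation with eps_hat = sqrt(1 - beta2^t) * eps
eps_hat = math.sqrt(bc2) * eps
optimized = lr * math.sqrt(bc2) / bc1 * m / (v.sqrt() + eps_hat)

assert torch.allclose(plain, optimized)
```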
```python
    def adam_update(self, state: Dict[str, Any], group: Dict[str, Any], param: torch.nn.Parameter,
                    m: torch.Tensor, v: torch.Tensor):
        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']
        # Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$
        bias_correction1 = 1 - beta1 ** state['step']
        # Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$
        bias_correction2 = 1 - beta2 ** state['step']
        # Get the learning rate
        lr = self.get_lr(state, group)

        # Whether to optimize the computation
        if self.optimized_update:
            # $\sqrt{v_t} + \hat{\epsilon}$
            denominator = v.sqrt().add_(group['eps'])
            # $\frac{\alpha \sqrt{1 - \beta_2^t}}{1 - \beta_1^t}$
            step_size = lr * math.sqrt(bias_correction2) / bias_correction1
            # $\theta_t \leftarrow \theta_{t-1} - \frac{\alpha \sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
            param.data.addcdiv_(m, denominator, value=-step_size)
        # Computation without optimization
        else:
            # $\sqrt{v_t / (1 - \beta_2^t)} + \epsilon$
            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
            # $\frac{\alpha}{1 - \beta_1^t}$
            step_size = lr / bias_correction1
            # $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$
            param.data.addcdiv_(m, denominator, value=-step_size)
```

`step_param` takes an update step for a given parameter tensor: it applies weight decay, updates the moments, increments the step count $t$, and performs the Adam update.

```python
    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
        # Calculate weight decay
        grad = self.weight_decay(param, grad, group)
        # Get $m_t$ and $v_t$
        m, v = self.get_mv(state, group, grad)
        # Increment $t$, the number of optimizer steps
        state['step'] += 1
        # Perform the *Adam* update
        self.adam_update(state, group, param, m, v)
```
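The `defaults` argument exists so that subclasses can pass extra per-group values through to the base optimizer. The sketch below is a hypothetical extension that keeps a running maximum of $v_t$ (AMSGrad-style); the class name, the `use_max_v` flag, and the overridden logic are illustrative assumptions rather than part of this file, shown only to demonstrate the `defaults` hook and the `init_state`/`get_mv` extension points.

```python
from typing import Dict, Any, Optional, Tuple

import torch

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.adam import Adam


class AdamWithMaxV(Adam):
    """Hypothetical Adam variant that uses the running maximum of v_t."""

    def __init__(self, params, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999),
                 eps: float = 1e-16, weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 use_max_v: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
        # Push the extra flag into the per-group defaults
        defaults = {} if defaults is None else defaults
        defaults.update(dict(use_max_v=use_max_v))
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: torch.nn.Parameter):
        super().init_state(state, group, param)
        # Extra state: running maximum of the second moment
        state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
        m, v = super().get_mv(state, group, grad)
        if group['use_max_v']:
            # Use the element-wise maximum of v_t seen so far
            v_max = state['max_exp_avg_sq']
            torch.maximum(v_max, v, out=v_max)
            return m, v_max
        return m, v
```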