This MNIST example uses these optimizers.
This file defines a common base class for Adam and its extensions. The base class helps us implement other optimizers with minimal code, since most of the logic can be reused.
We also define a special class for L2 weight decay, so that we don't have to implement it inside each optimizer, and we can easily extend to other weight decay schemes such as L1 without changing the optimizers.
Here are a few concepts about PyTorch optimizers:
PyTorch optimizers group parameters into sets called groups. Each group can have its own hyper-parameters, such as learning rates.
In the most common case there is only one group. This is when you initialize your optimizer with:

```python
Optimizer(model.parameters())
```

You can define multiple parameter groups when initializing the optimizer:

```python
Optimizer([
    {'params': model1.parameters()},
    {'params': model2.parameters(), 'lr': 2},
])
```

Here we pass a list of groups. Each group is a dictionary with its parameters under the key 'params'. You can specify any hyper-parameters as well. If a hyper-parameter is not defined for a group, it defaults to the optimizer-level value.
You can access (and even change) these groups, and their hyper-parameters, with optimizer.param_groups. Most learning rate schedule implementations access this and change 'lr'.
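For example, here is a minimal sketch (the models and values are made up for illustration) of creating two parameter groups and then changing their learning rates by hand, which is essentially what a learning rate schedule does:

```python
import torch
from torch import nn

model1 = nn.Linear(16, 16)
model2 = nn.Linear(16, 4)

# Two groups: model2's parameters get their own learning rate
optimizer = torch.optim.SGD([
    {'params': model1.parameters()},
    {'params': model2.parameters(), 'lr': 1e-2},
], lr=1e-3)

# Halve every group's learning rate, as a scheduler would
for group in optimizer.param_groups:
    group['lr'] *= 0.5
```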
The optimizer maintains state (a dictionary) for each parameter (a tensor), in the dictionary optimizer.state. This is where the optimizer keeps things like exponential averages.
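For instance, after a step of torch.optim.Adam you can peek at this state; the keys shown here ('step', 'exp_avg', 'exp_avg_sq') are what Adam stores, and other optimizers keep different entries:

```python
import torch
from torch import nn

model = nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model(torch.randn(4, 8)).sum().backward()
optimizer.step()

# Per-parameter state, keyed by the parameter tensor
state = optimizer.state[next(model.parameters())]
print(sorted(state.keys()))  # ['exp_avg', 'exp_avg_sq', 'step']
```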
```python
from typing import Dict, Tuple, Any

import torch
from torch import nn
from torch.optim.optimizer import Optimizer


class GenericAdaptiveOptimizer(Optimizer):
```

- params is the collection of parameters, or a list of parameter groups.
- defaults is a dictionary of default hyper-parameters.
- lr is the learning rate, α.
- betas is the tuple of coefficients (β₁, β₂) for the exponential moving averages.
- eps is the small constant ε used for numerical stability.

```python
    def __init__(self, params, defaults: Dict[str, Any], lr: float, betas: Tuple[float, float], eps: float):
```

Check the hyper-parameters:
```python
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
```

Add the hyper-parameters to the defaults:

```python
        defaults.update(dict(lr=lr, betas=betas, eps=eps))
```

Initialize the PyTorch optimizer. This will create parameter groups with the default hyper-parameters:
```python
        super().__init__(params, defaults)
```

This should be overridden with code to initialize state for a parameter param. group is the parameter group dictionary to which param belongs.

```python
    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        pass
```

This should be overridden to take the optimization step on a parameter tensor param, where grad is the gradient for that parameter, state is the optimizer state dictionary for that parameter, and group is the parameter group dictionary that param belongs to.

```python
    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.Tensor):
        pass
```

We have created a template method that does the common work every Adam-based optimizer needs.
```python
    @torch.no_grad()
    def step(self, closure=None):
```

Calculate the loss.

🤔 I'm not sure when you need this. I guess it's for when you define a function that calculates the loss, calls loss.backward, and returns the loss; instead of calling it yourself, you pass it to optimizer.step. 🤷‍♂️

```python
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
```

Iterate through the parameter groups:
```python
        for group in self.param_groups:
```

Iterate through the parameters in the parameter group:

```python
            for param in group['params']:
```

Skip if the parameter has no gradient:

```python
                if param.grad is None:
                    continue
```

Get the gradient tensor:

```python
                grad = param.grad.data
```

We don't handle sparse gradients:

```python
                if grad.is_sparse:
                    raise RuntimeError('GenericAdaptiveOptimizer does not support sparse gradients,'
                                       ' please consider SparseAdam instead')
```

Get the state for the parameter:

```python
                state = self.state[param]
```

Initialize the state if it is uninitialized:

```python
                if len(state) == 0:
                    self.init_state(state, group, param)
```

Take the optimization step on the parameter:

```python
                self.step_param(state, group, grad, param)
```

Return the loss, calculated from the closure:
```python
        return loss
```

```python
class WeightDecay:
```

- weight_decay is the decay coefficient.
- weight_decouple is a flag indicating whether to add the weight decay to the gradient, or to decay the parameter directly. If it is added to the gradient it goes through the normal optimizer update.
- absolute indicates whether the weight decay coefficient is absolute. This is applicable only when the decay is performed directly on the parameter. If this is false, the actual decay is weight_decay * learning_rate.

```python
    def __init__(self, weight_decay: float = 0., weight_decouple: bool = True, absolute: bool = False):
```

Check hyper-parameters:
```python
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        self.absolute = absolute
        self.weight_decouple = weight_decouple
        self.weight_decay = weight_decay
```

Return defaults for parameter groups:

```python
    def defaults(self):
        return dict(weight_decay=self.weight_decay)

    def __call__(self, param: torch.nn.Parameter, grad: torch.Tensor, group: Dict[str, Any]):
```

If we are doing the decay on the parameter directly:
```python
        if self.weight_decouple:
```

If the weight decay coefficient is absolute:

```python
            if self.absolute:
                param.data.mul_(1.0 - group['weight_decay'])
```

Otherwise, scale it by the learning rate:

```python
            else:
                param.data.mul_(1.0 - group['lr'] * group['weight_decay'])
```

Return the unmodified gradient:

```python
            return grad
        else:
            if group['weight_decay'] != 0:
```

Add the weight decay to the gradient and return the modified gradient:

```python
                return grad.add(param.data, alpha=group['weight_decay'])
            else:
                return grad
```
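To see how these pieces fit together, here is a minimal sketch of a concrete optimizer built on GenericAdaptiveOptimizer and WeightDecay. The class name SimpleAdam and the simplified update rule (no bias correction) are my own for illustration; the real Adam implementations that extend this base class add bias correction and other details.

```python
class SimpleAdam(GenericAdaptiveOptimizer):
    def __init__(self, params, lr: float = 1e-3, betas=(0.9, 0.999), eps: float = 1e-8,
                 weight_decay: WeightDecay = WeightDecay()):
        self.weight_decay = weight_decay
        # WeightDecay.defaults() contributes 'weight_decay' to the parameter group defaults
        super().__init__(params, weight_decay.defaults(), lr, betas, eps)

    def init_state(self, state, group, param):
        # Exponential moving averages of the gradient and the squared gradient
        state['step'] = 0
        state['exp_avg'] = torch.zeros_like(param.data)
        state['exp_avg_sq'] = torch.zeros_like(param.data)

    def step_param(self, state, group, grad, param):
        # Apply (possibly decoupled) weight decay; returns the gradient to use
        grad = self.weight_decay(param, grad, group)

        beta1, beta2 = group['betas']
        m, v = state['exp_avg'], state['exp_avg_sq']

        # Update the moving averages in place
        m.mul_(beta1).add_(grad, alpha=1 - beta1)
        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        state['step'] += 1

        # Parameter update (bias correction omitted to keep the sketch short)
        param.data.addcdiv_(m, v.sqrt().add_(group['eps']), value=-group['lr'])
```

Using it looks exactly like using any other PyTorch optimizer:

```python
model = nn.Linear(4, 1)
optimizer = SimpleAdam(model.parameters(), lr=1e-3)

loss = model(torch.randn(32, 4)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

Because weight_decouple defaults to True and weight_decay defaults to 0, the decoupled decay above is a no-op unless you pass WeightDecay(weight_decay=...) explicitly.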