This MNIST example uses these optimizers.
This file defines a common base class for Adam and extensions of it. The base class helps us implement other optimizers with minimal code by reusing the shared logic.
We also define a special class for L2 weight decay, so that we don’t have to implement it inside each of the optimizers, and can easily extend to other weight decays like L1 without changing the optimizers.
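The following is a minimal sketch of that extension point. It shows an L1 decay that follows the same call convention as the WeightDecay class defined in this file: a callable that takes the parameter, its gradient and the parameter group dictionary, and returns the gradient to use. The class name L1WeightDecay is hypothetical and is not part of this file. The decoupled branch scales the decay by the learning rate, which is one reasonable choice, not the only one.

```python
from typing import Any, Dict

import torch


class L1WeightDecay:
    # Hypothetical L1 counterpart to WeightDecay, with the same defaults()/__call__ interface
    def __init__(self, weight_decay: float = 0., weight_decouple: bool = True):
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        self.weight_decay = weight_decay
        self.weight_decouple = weight_decouple

    def defaults(self):
        return dict(weight_decay=self.weight_decay)

    @torch.no_grad()
    def __call__(self, param: torch.Tensor, grad: torch.Tensor, group: Dict[str, Any]) -> torch.Tensor:
        # The L1 penalty weight_decay * |theta| has (sub)gradient weight_decay * sign(theta)
        penalty = torch.sign(param) * group['weight_decay']
        if self.weight_decouple:
            # Decay the parameter directly, scaled by the learning rate; gradient is unchanged
            param.sub_(penalty * group['lr'])
            return grad
        # Otherwise fold the penalty into the gradient so the optimizer update handles it
        return grad + penalty
```

Because it exposes the same interface, an optimizer written against WeightDecay should be able to take this class without any changes to the optimizer itself.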
Here are some concepts on PyTorch optimizers:
PyTorch optimizers group parameters into sets called groups. Each group can have its own hyper-parameters, such as the learning rate.
In the most common case there is only one group. This is what you get when you initialize your optimizer with:
Optimizer(model.parameters())
You can define multiple parameter groups when initializing the optimizer:
Optimizer([{'params': model1.parameters()}, {'params': model2.parameters(), 'lr': 2}])
Here we pass a list of groups. Each group is a dictionary with its parameters under the key ‘params’. You can specify hyper-parameters for a group as well; any hyper-parameter a group does not define falls back to the optimizer-level default.
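The two initializations above can be tried with a concrete optimizer. This sketch uses torch.optim.SGD as a stand-in for the generic Optimizer. The two nn.Linear layers and the optimizer-level lr=0.1 are arbitrary choices for illustration.

```python
import torch
from torch import nn

model1 = nn.Linear(4, 2)
model2 = nn.Linear(2, 1)

# Single group: all parameters share the optimizer-level hyper-parameters
single = torch.optim.SGD(model1.parameters(), lr=0.1)

# Two groups: the second overrides lr, the first falls back to the default lr=0.1
optimizer = torch.optim.SGD(
    [{'params': model1.parameters()},
     {'params': model2.parameters(), 'lr': 2}],
    lr=0.1,
)

for i, group in enumerate(optimizer.param_groups):
    # Each nn.Linear contributes two tensors: weight and bias
    print(i, group['lr'], len(group['params']))
```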
You can access (and even change) these groups, and their hyper-parameters with optimizer.param_groups.
Most learning rate schedule implementations I’ve come across do access this and change ‘lr’.
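A minimal sketch of such a schedule, written by hand, follows. The helper set_lr and the "halve every 10 epochs" rule are hypothetical illustrations. The optimizer is torch.optim.SGD on an arbitrary layer.

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def set_lr(optimizer: torch.optim.Optimizer, lr: float) -> None:
    # Hypothetical helper: overwrite 'lr' in every parameter group
    for group in optimizer.param_groups:
        group['lr'] = lr


# Step decay: halve the learning rate every 10 epochs
for epoch in range(30):
    set_lr(optimizer, 0.1 * 0.5 ** (epoch // 10))
    # ... forward pass, loss.backward() and optimizer.step() would go here ...
```

PyTorch's built-in torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) produces the same schedule by updating param_groups in the same way.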
The optimizer maintains a state (a dictionary) for each parameter (a tensor) in the dictionary optimizer.state.
This is where the optimizer maintains things like exponential averages.
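To see this state, you can take one step with PyTorch's built-in torch.optim.Adam and look inside optimizer.state. The key names exp_avg and exp_avg_sq belong to PyTorch's Adam implementation; the optimizers in this file may name their state entries differently. The model and loss are arbitrary.

```python
import torch
from torch import nn

model = nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# State is created lazily, so it is empty before the first step
n_before = len(optimizer.state)

x = torch.randn(8, 3)
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()

# One state dictionary per parameter tensor, keyed by the tensor itself
state = optimizer.state[model.weight]
print(sorted(state.keys()))
```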
from typing import Dict, Tuple, Any

import torch
from torch import nn
from torch.optim.optimizer import Optimizer


class GenericAdaptiveOptimizer(Optimizer):
    # params is the collection of parameters or set of parameter groups.
    # defaults is a dictionary of default hyper-parameters.
    # lr is the learning rate.
    # betas is the tuple $(\beta_1, \beta_2)$.
    # eps is $\epsilon$.
    def __init__(self, params, defaults: Dict[str, Any], lr: float, betas: Tuple[float, float], eps: float):
        # Check the hyper-parameters
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")

        # Add the hyper-parameters to the defaults
        defaults.update(dict(lr=lr, betas=betas, eps=eps))

        # Initialize the PyTorch optimizer.
        # This will create parameter groups with the default hyper-parameters.
        super().__init__(params, defaults)

    # This should be overridden with code to initialize state for parameter param.
    # group is the parameter group dictionary to which param belongs.
    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        pass

    # This should be overridden to take the optimization step on the param tensor $\theta$,
    # where grad is the gradient for that parameter, $g_t$,
    # state is the optimizer state dictionary for that parameter, and
    # group is the parameter group dictionary param belongs to.
    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.Tensor):
        pass

    # We have created a template method that does the common work every Adam-based optimizer needs.
    @torch.no_grad()
    def step(self, closure=None):
        # Calculate the loss.
        # closure is an optional function that re-evaluates the model, calls loss.backward(),
        # and returns the loss. Instead of calling it yourself, you can pass it to
        # optimizer.step(closure). Optimizers that evaluate the loss several times per
        # step, such as torch.optim.LBFGS, require it.
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Iterate through the parameter groups
        for group in self.param_groups:
            # Iterate through the parameters in the parameter group
            for param in group['params']:
                # Skip if the parameter has no gradient
                if param.grad is None:
                    continue
                # Get the gradient tensor
                grad = param.grad.data
                # We don't handle sparse gradients
                if grad.is_sparse:
                    raise RuntimeError('GenericAdaptiveOptimizer does not support sparse gradients,'
                                       ' please consider SparseAdam instead')

                # Get the state for the parameter
                state = self.state[param]

                # Initialize the state if it is uninitialized
                if len(state) == 0:
                    self.init_state(state, group, param)

                # Take the optimization step on the parameter
                self.step_param(state, group, grad, param)

        # Return the loss, calculated from closure
        return loss


class WeightDecay:
    # L2 weight decay
    # weight_decay is the decay coefficient.
    # weight_decouple is a flag indicating whether to decay the parameter directly (True)
    # or to add the weight decay to the gradient (False). If it is added to the gradient,
    # it goes through the normal optimizer update.
    # absolute indicates whether the weight decay coefficient is absolute. This only applies
    # when the decay is performed directly on the parameter. If it is False, the actual
    # decay is weight_decay * learning_rate.
    def __init__(self, weight_decay: float = 0., weight_decouple: bool = True, absolute: bool = False):
        # Check hyper-parameters
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        self.absolute = absolute
        self.weight_decouple = weight_decouple
        self.weight_decay = weight_decay

    # Return defaults for parameter groups
    def defaults(self):
        return dict(weight_decay=self.weight_decay)

    def __call__(self, param: torch.nn.Parameter, grad: torch.Tensor, group: Dict[str, Any]):
        # If we are doing the decay on the parameter directly
        if self.weight_decouple:
            # If the weight decay coefficient is absolute
            if self.absolute:
                param.data.mul_(1.0 - group['weight_decay'])
            # Otherwise, the decay is scaled by the learning rate
            else:
                param.data.mul_(1.0 - group['lr'] * group['weight_decay'])
            # Return the unmodified gradient
            return grad
        else:
            if group['weight_decay'] != 0:
                # Add the weight decay to the gradient and return the modified gradient
                return grad.add(param.data, alpha=group['weight_decay'])
            else:
                return grad
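The following is a minimal sketch of how the template is meant to be used: a plain Adam optimizer that only implements init_state and step_param. So that the block runs on its own, it includes a condensed copy of GenericAdaptiveOptimizer with the hyper-parameter checks and sparse-gradient handling dropped. The class name SketchAdam and its state keys are hypothetical. The update follows the standard Adam rule with bias correction and no weight decay; it is not necessarily the exact code of the Adam implementation built on this file. The usage part also shows passing a closure to step().

```python
import math
from typing import Any, Dict, Tuple

import torch
from torch import nn
from torch.optim.optimizer import Optimizer


class GenericAdaptiveOptimizer(Optimizer):
    # Condensed copy of the base class above (checks and sparse handling dropped)
    def __init__(self, params, defaults: Dict[str, Any], lr: float, betas: Tuple[float, float], eps: float):
        defaults.update(dict(lr=lr, betas=betas, eps=eps))
        super().__init__(params, defaults)

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        pass

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.Tensor):
        pass

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for param in group['params']:
                if param.grad is None:
                    continue
                state = self.state[param]
                if len(state) == 0:
                    self.init_state(state, group, param)
                self.step_param(state, group, param.grad, param)
        return loss


class SketchAdam(GenericAdaptiveOptimizer):
    # Hypothetical subclass: plain Adam with bias correction, no weight decay
    def __init__(self, params, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8):
        super().__init__(params, {}, lr, betas, eps)

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        # Step counter and exponential moving averages of g_t and g_t^2
        state['step'] = 0
        state['exp_avg'] = torch.zeros_like(param)
        state['exp_avg_sq'] = torch.zeros_like(param)

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.Tensor):
        beta1, beta2 = group['betas']
        state['step'] += 1
        state['exp_avg'].mul_(beta1).add_(grad, alpha=1 - beta1)
        state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        bias_correction1 = 1 - beta1 ** state['step']
        bias_correction2 = 1 - beta2 ** state['step']
        # theta <- theta - lr * m_hat / (sqrt(v_hat) + eps)
        denom = (state['exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
        param.addcdiv_(state['exp_avg'], denom, value=-group['lr'] / bias_correction1)


# Usage: two identical layers, one trained with the sketch and one with torch.optim.Adam
torch.manual_seed(0)
a = nn.Linear(3, 1)
b = nn.Linear(3, 1)
b.load_state_dict(a.state_dict())
x = torch.randn(16, 3)
ours = SketchAdam(a.parameters(), lr=0.01)
ref = torch.optim.Adam(b.parameters(), lr=0.01)


def make_closure(model: nn.Module, opt: Optimizer):
    # Closure: clear gradients, recompute the loss, backpropagate, return the loss
    def closure():
        opt.zero_grad()
        loss = model(x).pow(2).mean()
        loss.backward()
        return loss
    return closure


for _ in range(5):
    ours.step(make_closure(a, ours))
    ref.step(make_closure(b, ref))
```

Because the sketch applies the same update rule as torch.optim.Adam, the two layers should agree up to floating-point rounding. Only the two hooks differ from one Adam variant to the next; the loop over groups, the state initialization and the closure handling stay in the base class.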