This MNIST example uses these optimizers.
This file defines a common base class for Adam and its extensions. The base class lets us implement other optimizers with minimal code, because most of the logic is reusable.
We also define a special class for L2 weight decay, so that we don’t have to implement it inside each of the optimizers, and can easily extend to other weight decays like L1 without changing the optimizers.
Here are some concepts on PyTorch optimizers:
PyTorch optimizers group parameters into sets called groups. Each group can have its own hyper-parameters, such as the learning rate.
In the most common case there is only one group. This is what you get when you initialize your optimizer with:
Optimizer(model.parameters())
You can define multiple parameter groups when initializing the optimizer:
Optimizer([{'params': model1.parameters()}, {'params': model2.parameters(), 'lr': 2}])
Here we pass a list of groups. Each group is a dictionary with its parameters under the key 'params'. You can specify any hyper-parameters as well. If a hyper-parameter is not defined for a group, it defaults to the optimizer-level default.
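As a small illustration of the fallback to optimizer-level defaults, here is a sketch using `torch.optim.SGD`. The names `encoder` and `head` are made up for this example and stand in for any two sets of parameters.

```python
import torch
from torch import nn

# Two small stand-in modules; `encoder` and `head` are illustrative names only.
encoder = nn.Linear(4, 3)
head = nn.Linear(3, 1)

# The first group does not set 'lr', so it falls back to the optimizer-level
# default (0.1 here). The second group overrides it.
optimizer = torch.optim.SGD(
    [{'params': encoder.parameters()},
     {'params': head.parameters(), 'lr': 0.01}],
    lr=0.1,
)

for i, group in enumerate(optimizer.param_groups):
    print(i, group['lr'], len(group['params']))
```

Each entry of `optimizer.param_groups` is a plain dictionary that holds the parameters together with the resolved hyper-parameters for that group.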
You can access (and even change) these groups and their hyper-parameters through optimizer.param_groups.
Most learning rate schedule implementations I've come across access this and change 'lr'.
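To make the idea concrete, a hand-rolled schedule might look like the sketch below. The helper `decay_lr` is hypothetical and is not part of PyTorch. It only shows that writing to `group['lr']` is enough to change the learning rate the optimizer uses on the next step.

```python
import torch
from torch import nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def decay_lr(optimizer, factor):
    """Hypothetical helper: multiply the 'lr' of every parameter group by `factor`."""
    for group in optimizer.param_groups:
        group['lr'] *= factor


lrs = []
for epoch in range(3):
    # Record the learning rate in effect for this epoch, then halve it.
    lrs.append(optimizer.param_groups[0]['lr'])
    decay_lr(optimizer, 0.5)

print(lrs)
```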
The optimizer maintains a state (a dictionary) for each parameter (a tensor), in the dictionary optimizer.state.
This is where the optimizer maintains things like exponential averages.
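The sketch below inspects this state for PyTorch's built-in `torch.optim.Adam` after a single step. The key names `exp_avg` and `exp_avg_sq` are the ones used by the PyTorch versions I am aware of, so treat them as an assumption rather than a guarantee.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Take one optimization step so that the optimizer creates its state.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# The parameter tensor itself is the key into `optimizer.state`.
weight_state = optimizer.state[model.weight]
print(sorted(weight_state.keys()))
print(weight_state['exp_avg'].shape)
```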
from typing import Dict, Tuple, Any

import torch
from torch import nn
from torch.optim.optimizer import Optimizer
class GenericAdaptiveOptimizer(Optimizer):
params is the collection of parameters or set of parameter groups.
defaults is a dictionary of default hyper-parameters.
lr is the learning rate.
betas is the tuple $(\beta_1, \beta_2)$.
eps is $\epsilon$.
    def __init__(self, params, defaults: Dict[str, Any], lr: float, betas: Tuple[float, float], eps: float):
Check the hyper-parameters
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
Add the hyper-parameters to the defaults
        defaults.update(dict(lr=lr, betas=betas, eps=eps))
Initialize the PyTorch optimizer. This will create parameter groups with the default hyper-parameters
        super().__init__(params, defaults)
This should be overridden with code to initialize the state for the parameter param. group is the parameter group dictionary to which param belongs.
    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        pass
This should be overridden to take the optimization step on the param tensor $\theta$, where grad is the gradient for that parameter, $g_t$, state is the optimizer state dictionary for that parameter, and group is the parameter group dictionary param belongs to.
    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.Tensor):
        pass
We have created a template method that does the common work every Adam-based optimizer needs.
    @torch.no_grad()
    def step(self, closure=None):
Calculate the loss.
The closure is an optional function that calculates the loss, calls loss.backward, and returns the loss. Instead of calling it on your own, you can pass it to optimizer.step. This is useful for optimizers that need to re-evaluate the loss within a single step.
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
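A minimal sketch of how a closure is typically passed to `step`, shown here with the built-in `torch.optim.SGD`. The model and data are random placeholders. Built-in optimizers such as `torch.optim.LBFGS` require a closure because they evaluate the loss more than once per step.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
x, y = torch.randn(16, 4), torch.randn(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def closure():
    # Clear old gradients, recompute the loss, and backpropagate.
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss


# `step` calls the closure itself and returns the loss it computed.
loss = optimizer.step(closure)
print(loss.item())
```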
Iterate through the parameter groups
        for group in self.param_groups:
Iterate through the parameters in the parameter group
            for param in group['params']:
Skip if the parameter has no gradient
                if param.grad is None:
                    continue
Get the gradient tensor
                grad = param.grad.data
We don't handle sparse gradients
                if grad.is_sparse:
                    raise RuntimeError('GenericAdaptiveOptimizer does not support sparse gradients,'
                                       ' please consider SparseAdam instead')
Get the state for the parameter
                state = self.state[param]
Initialize the state if state is uninitialized
                if len(state) == 0:
                    self.init_state(state, group, param)
Take the optimization step on the parameter
                self.step_param(state, group, grad, param)
Return the loss, calculated from closure
        return loss
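To show how the two hooks fit together, here is a sketch of a subclass that implements plain Adam with bias correction. The base class is a condensed copy of the one above, with validation and the sparse-gradient check omitted so the block runs on its own. `SimpleAdam` is a hypothetical name and not the author's implementation.

```python
import math
from typing import Any, Dict, Tuple

import torch
from torch import nn
from torch.optim.optimizer import Optimizer


class GenericAdaptiveOptimizer(Optimizer):
    """Condensed copy of the base class (no validation, no sparse check)."""

    def __init__(self, params, defaults: Dict[str, Any], lr: float,
                 betas: Tuple[float, float], eps: float):
        defaults.update(dict(lr=lr, betas=betas, eps=eps))
        super().__init__(params, defaults)

    def init_state(self, state, group, param):
        pass

    def step_param(self, state, group, grad, param):
        pass

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for param in group['params']:
                if param.grad is None:
                    continue
                state = self.state[param]
                if len(state) == 0:
                    self.init_state(state, group, param)
                self.step_param(state, group, param.grad, param)
        return loss


class SimpleAdam(GenericAdaptiveOptimizer):
    """Hypothetical subclass: plain Adam with bias correction."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        super().__init__(params, {}, lr, betas, eps)

    def init_state(self, state, group, param):
        state['step'] = 0
        state['exp_avg'] = torch.zeros_like(param)
        state['exp_avg_sq'] = torch.zeros_like(param)

    def step_param(self, state, group, grad, param):
        beta1, beta2 = group['betas']
        state['step'] += 1
        m, v = state['exp_avg'], state['exp_avg_sq']
        # Update the exponential moving averages in place.
        m.mul_(beta1).add_(grad, alpha=1 - beta1)
        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        # Bias corrections for the first and second moments.
        bias1 = 1 - beta1 ** state['step']
        bias2 = 1 - beta2 ** state['step']
        denom = (v.sqrt() / math.sqrt(bias2)).add_(group['eps'])
        param.addcdiv_(m, denom, value=-group['lr'] / bias1)


# Fit a tiny linear model on random data to exercise the optimizer.
torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = SimpleAdam(model.parameters(), lr=0.01)
x, y = torch.randn(32, 4), torch.randn(32, 1)
losses = []
for _ in range(100):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

The subclass contains only the state layout and the per-parameter update. The loop over groups, the closure handling, and the lazy state initialization all stay in the base class.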
class WeightDecay:
weight_decay is the decay coefficient.
weight_decouple is a flag indicating whether to add the weight decay to the gradient or to decay the parameter directly. If added to the gradient, it goes through the normal optimizer update.
absolute is a flag indicating whether the weight decay coefficient is absolute. This applies when the decay is performed directly on the parameter. If it is false, the actual decay is weight_decay * learning_rate.
    def __init__(self, weight_decay: float = 0., weight_decouple: bool = True, absolute: bool = False):
Check hyper-parameters
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        self.absolute = absolute
        self.weight_decouple = weight_decouple
        self.weight_decay = weight_decay
Return defaults for parameter groups
    def defaults(self):
        return dict(weight_decay=self.weight_decay)
    def __call__(self, param: torch.nn.Parameter, grad: torch.Tensor, group: Dict[str, Any]):
If we are doing the decay on the parameter directly
        if self.weight_decouple:
If the weight decay coefficient is absolute
            if self.absolute:
                param.data.mul_(1.0 - group['weight_decay'])
Otherwise,
            else:
                param.data.mul_(1.0 - group['lr'] * group['weight_decay'])
Return the unmodified gradient
            return grad
        else:
            if group['weight_decay'] != 0:
Add the weight decay to the gradient and return the modified gradient
                return grad.add(param.data, alpha=group['weight_decay'])
            else:
                return grad
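The three branches above can be checked with a small numeric sketch. It reproduces only the arithmetic, not the WeightDecay class itself, and the values of `lr`, `weight_decay`, `param`, and `grad` are arbitrary choices for illustration.

```python
import torch

lr, weight_decay = 0.1, 0.01
param = torch.tensor([1.0, -2.0])
grad = torch.tensor([0.5, 0.5])

# Decoupled, relative to the learning rate: theta <- theta * (1 - lr * wd).
decoupled = param * (1.0 - lr * weight_decay)

# Decoupled, absolute: theta <- theta * (1 - wd).
absolute = param * (1.0 - weight_decay)

# Coupled (L2): the parameter is untouched and the gradient becomes g + wd * theta.
l2_grad = grad.add(param, alpha=weight_decay)

print(decoupled, absolute, l2_grad)
```

In the decoupled cases the gradient is passed to the optimizer unchanged, so the decay does not get rescaled by the adaptive terms. In the coupled case the decay term goes through the same update as the rest of the gradient.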