mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-30 10:18:50 +08:00
adam comments
@@ -67,14 +67,13 @@ class WeightDecay:
     def defaults(self):
         return dict(weight_decay=self.weight_decay)
 
-    def __call__(self, param: torch.nn.Parameter, group: Dict[str, any]):
-        grad = param.grad.data
-
+    def __call__(self, param: torch.nn.Parameter, grad: torch.Tensor, group: Dict[str, any]):
         if self.weight_decouple:
             if not self.absolute:
                 param.data.mul_(1.0 - group['lr'] * group['weight_decay'])
             else:
                 param.data.mul_(1.0 - group['weight_decay'])
+            return grad
         else:
             if group['weight_decay'] != 0:
-                grad.add_(param.data, alpha=group['weight_decay'])
+                return grad.add(param.data, alpha=group['weight_decay'])
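To make the new contract concrete, here is a minimal, self-contained sketch of the two code paths in the hunk above. The class name `WeightDecaySketch`, the constructor arguments, and the final `return grad` for the zero-decay case are illustrative additions, not code from the commit; only the body of `__call__` mirrors the diff.

from typing import Dict

import torch


class WeightDecaySketch:
    # A stand-in for the WeightDecay helper: it now receives the gradient
    # explicitly and returns the (possibly modified) gradient.
    def __init__(self, weight_decay: float = 0.01, weight_decouple: bool = True, absolute: bool = False):
        self.weight_decay = weight_decay
        self.weight_decouple = weight_decouple
        self.absolute = absolute

    def __call__(self, param: torch.nn.Parameter, grad: torch.Tensor, group: Dict[str, any]) -> torch.Tensor:
        if self.weight_decouple:
            # Decoupled (AdamW-style) decay: shrink the parameter in place
            # and hand the gradient back unchanged.
            if not self.absolute:
                param.data.mul_(1.0 - group['lr'] * group['weight_decay'])
            else:
                param.data.mul_(1.0 - group['weight_decay'])
            return grad
        else:
            # Classic L2 regularization: fold the decay term into the gradient.
            # Note the out-of-place `add`, so `param.grad` itself is not mutated.
            if group['weight_decay'] != 0:
                return grad.add(param.data, alpha=group['weight_decay'])
            # Added for completeness; the hunk above leaves this case implicit.
            return grad

Because `__call__` no longer reads `param.grad.data` itself, every caller has to pass the gradient in and rebind it, i.e. `grad = self.weight_decay(param, grad, group)`, which is exactly the change the remaining hunks make in `Adam`, `RAdam` and `AdaBelief`.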
@@ -81,7 +81,7 @@ class AdaBelief(RAdam):
         return m, v
 
     def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
-        self.weight_decay(param, group)
+        grad = self.weight_decay(param, grad, group)
         m, v = self.get_mv(state, group, grad)
         state['step'] += 1
 
@@ -6,6 +6,33 @@ This is an implementation of popular optimizer *Adam* from paper
 
 We extend the class `GenericAdaptiveOptimizer` defined in [__init__.py](index.html)
 to implement the Adam optimizer.
+
+*Adam* update is,
+
+\begin{align}
+m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
+v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\
+\hat{m}_t &\leftarrow \frac{m_t}{1-\beta_1^t} \\
+\hat{v}_t &\leftarrow \frac{v_t}{1-\beta_2^t} \\
+\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
+\end{align}
+
+where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper-parameters.
+$m_t$ and $v_t$ are the first and second order moments.
+$\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moments.
+$\epsilon$ is used as a fix for division by zero error, but also acts as a form of a hyper-parameter
+that acts against variance in gradients.
+
+The effective step taken assuming $\epsilon = 0$ is,
+$$\Delta t = \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t}}$$
+This is bounded by,
+$$\vert \Delta t \vert \le \alpha \cdot \frac{1 - \beta_1}{\sqrt{1-\beta_2}}$$
+when $1-\beta_1 \gt \sqrt{1-\beta_2}$,
+and
+$$\vert \Delta t \vert \le \alpha$$
+otherwise.
+And in most common scenarios,
+$$\vert \Delta t \vert \approx \alpha$$
 """
 
 import math
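As a quick cross-check of the equations added to the docstring, here is a standalone sketch of one Adam update written directly from them. The function name, default hyper-parameter values and the sample tensors are illustrative and not part of the commit.

import torch


def adam_update(theta: torch.Tensor, g: torch.Tensor, m: torch.Tensor, v: torch.Tensor,
                t: int, alpha: float = 1e-3, beta1: float = 0.9, beta2: float = 0.999,
                eps: float = 1e-8):
    # m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t
    m = beta1 * m + (1 - beta1) * g
    # v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t^2
    v = beta2 * v + (1 - beta2) * g ** 2
    # Bias-corrected moments m_hat_t and v_hat_t
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # theta_t <- theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + eps)
    theta = theta - alpha * m_hat / (v_hat.sqrt() + eps)
    return theta, m, v


# First step from zero-initialized moments: the update is roughly
# -alpha * sign(g), i.e. |Delta t| is approximately alpha.
theta, m, v = adam_update(torch.zeros(3), torch.tensor([0.1, -0.2, 0.3]),
                          torch.zeros(3), torch.zeros(3), t=1)

For the common defaults $\beta_1 = 0.9$ and $\beta_2 = 0.999$ we have $1-\beta_1 = 0.1 \gt \sqrt{1-\beta_2} \approx 0.0316$, so the first bound applies and gives $\vert \Delta t \vert \le \alpha \cdot 0.1 / 0.0316 \approx 3.16 \alpha$; in practice $\hat{m}_t / \sqrt{\hat{v}_t}$ stays around $\pm 1$ when gradients are consistent, which is where the $\vert \Delta t \vert \approx \alpha$ rule of thumb comes from.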
@@ -28,6 +55,7 @@ class Adam(GenericAdaptiveOptimizer):
         * `params` is the list of parameters
         * `lr` is the learning rate $\alpha$
         * `betas` is a tuple of ($\beta_1$, $\beta_2$)
+        * `eps` is $\hat{\epsilon}$
         * `weight_decay` is an instance of class `WeightDecay` defined in [__init__.py](index.html)
         * `defaults` is a dictionary of defaults for group values.
         This is useful when you want to extend the class `Adam`.
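A hedged usage sketch for the constructor arguments listed in the hunk above; the import paths, hyper-parameter values and `WeightDecay` keyword arguments are inferred from the rest of the repository and may differ slightly.

import torch

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.adam import Adam

model = torch.nn.Linear(10, 2)
# `weight_decay` takes a WeightDecay instance rather than a float;
# `weight_decouple=True` gives AdamW-style decoupled decay.
optimizer = Adam(model.parameters(),
                 lr=2.5e-4,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=WeightDecay(weight_decay=0.01, weight_decouple=True))

loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
optimizer.step()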
@@ -101,8 +129,6 @@ class Adam(GenericAdaptiveOptimizer):
         This computes the following
 
         \begin{align}
-        \hat{m}_t &\leftarrow \frac{m_t}{1-\beta_1^t} \\
-        \hat{v}_t &\leftarrow \frac{v_t}{1-\beta_2^t} \\
         \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
         \end{align}
 
@@ -114,12 +140,12 @@ class Adam(GenericAdaptiveOptimizer):
         \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot
         \frac{m_t / (1-\beta_1^t)}{\sqrt{v_t/(1-\beta_2^t)} + \epsilon} \\
         \theta_t &\leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
-        \frac{m_t}{\sqrt{v_t} + \epsilon'} \\
+        \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}} \\
         \end{align}
 
         where
-        $$\epsilon' = \sqrt{1-\beta_2^t} \cdot \epsilon \approx \epsilon$$
-        since $\beta_2 \approx 1$
+        $$\hat{\epsilon} = \sqrt{1-\beta_2^t} \cdot \epsilon$$
+        is what we should specify as the hyper-parameter.
         """
 
         # Get $\beta_1$ and $\beta_2$
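The rearranged form in the hunk above follows by multiplying numerator and denominator by $\sqrt{1-\beta_2^t}$. The short derivation below is an editorial addition (it is also why $\hat{\epsilon}$ carries a $\sqrt{1-\beta_2^t}$ factor rather than $1-\beta_2^t$):

\begin{align}
\frac{m_t / (1-\beta_1^t)}{\sqrt{v_t/(1-\beta_2^t)} + \epsilon}
&= \frac{m_t / (1-\beta_1^t)}{\bigl(\sqrt{v_t} + \epsilon \sqrt{1-\beta_2^t}\bigr) / \sqrt{1-\beta_2^t}} \\
&= \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}},
\qquad \hat{\epsilon} = \sqrt{1-\beta_2^t} \cdot \epsilon
\end{align}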
@@ -134,7 +160,7 @@ class Adam(GenericAdaptiveOptimizer):
         # $\alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$
         step_size = self.get_lr(state, group) * math.sqrt(bias_correction2) / bias_correction1
         # $\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
-        # \frac{m_t}{\sqrt{v_t} + \epsilon}$
+        # \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
         param.data.addcdiv_(m, denominator, value=-step_size)
 
     def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
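Folding both bias corrections into `step_size` is a pure algebraic rearrangement, which can be checked numerically. The snippet below is a self-contained sketch with made-up tensors (not code from the repository); it verifies that the `addcdiv_`-style update above equals the textbook update when the textbook $\epsilon$ is taken as $\hat{\epsilon} / \sqrt{1-\beta_2^t}$.

import math

import torch

torch.manual_seed(0)
beta1, beta2, lr, eps_hat, t = 0.9, 0.999, 1e-3, 1e-8, 10
theta = torch.randn(5)
m = torch.randn(5)   # first moment (any sign)
v = torch.rand(5)    # second moment (non-negative)

bias_correction1 = 1 - beta1 ** t
bias_correction2 = 1 - beta2 ** t

# Rearranged form used above: fold both corrections into the step size.
step_size = lr * math.sqrt(bias_correction2) / bias_correction1
rearranged = theta - step_size * m / (v.sqrt() + eps_hat)

# Textbook form: correct the moments explicitly.
m_hat = m / bias_correction1
v_hat = v / bias_correction2
eps = eps_hat / math.sqrt(bias_correction2)
textbook = theta - lr * m_hat / (v_hat.sqrt() + eps)

assert torch.allclose(rearranged, textbook)

Since `eps` is added to $\sqrt{v_t}$ rather than $\sqrt{\hat{v}_t}$, the constructor's `eps` argument plays the role of $\hat{\epsilon}$, which is what the new docstring bullet states.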
@@ -148,7 +174,7 @@ class Adam(GenericAdaptiveOptimizer):
         """
 
         # Calculate weight decay
-        self.weight_decay(param, group)
+        grad = self.weight_decay(param, grad, group)
 
         # Get $m_t$ and $v_t$
         m, v = self.get_mv(state, group, grad)
@@ -19,7 +19,7 @@ class RAdam(AMSGrad):
         super().__init__(params, lr, betas, eps, weight_decay, amsgrad, defaults)
 
     def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
-        self.weight_decay(param, group)
+        grad = self.weight_decay(param, grad, group)
 
         m, v = self.get_mv(state, group, grad)
         state['step'] += 1