mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-03 05:46:16 +08:00
unoptimized adam
@@ -53,9 +53,11 @@ class Adam(GenericAdaptiveOptimizer):
     We extend the class `GenericAdaptiveOptimizer` defined in [`__init__.py`](index.html)
     to implement the Adam optimizer.
     """
 
     def __init__(self, params,
                  lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                  weight_decay: WeightDecay = WeightDecay(),
+                 optimized_update: bool = True,
                  defaults: Optional[Dict[str, Any]] = None):
         """
         ### Initialize the optimizer
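For readers trying the change out, here is a minimal usage sketch of the constructor as it reads after this commit. It is not part of the commit; the import path `labml_nn.optimizers.adam` and the toy model are assumptions for illustration.

```python
import torch.nn as nn

# Assumed import path for the annotated implementation; adjust if the package layout differs.
from labml_nn.optimizers.adam import Adam

model = nn.Linear(4, 1)

# Default behaviour: optimized update, so `eps` plays the role of epsilon-hat.
opt = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-16)

# The flag added in this commit: fall back to the textbook Adam update,
# where the second-moment bias correction is applied before adding epsilon.
opt_plain = Adam(model.parameters(), lr=1e-3, optimized_update=False)
```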
@@ -63,8 +65,10 @@ class Adam(GenericAdaptiveOptimizer):
         * `params` is the list of parameters
         * `lr` is the learning rate $\alpha$
         * `betas` is a tuple of ($\beta_1$, $\beta_2$)
-        * `eps` is $\hat{\epsilon}$
+        * `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
         * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
+        * `optimized_update` is a flag indicating whether to optimize the bias correction of the second moment
+          by doing it after adding $\epsilon$
         * `defaults` is a dictionary of defaults for group values.
           This is useful when you want to extend the class `Adam`.
         """
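To make the `eps` bullet concrete: with `optimized_update=True` the value passed as `eps` is used as $\hat{\epsilon}$, added to the uncorrected $\sqrt{v_t}$, whereas with `optimized_update=False` it is the textbook $\epsilon$ of Adam, added after bias-correcting the second moment. In the notation of the code comments further down:

$$\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}} \quad \text{(optimized)}$$

$$\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1-\beta_2^t}, \quad \theta_t \leftarrow \theta_{t-1} - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \quad \text{(unoptimized)}$$

The two coincide exactly when $\hat{\epsilon} = \epsilon \sqrt{1-\beta_2^t}$; the optimized form simply folds the $\sqrt{1-\beta_2^t}$ factor into the scalar step size instead of dividing the whole tensor by it.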
@@ -73,6 +77,7 @@ class Adam(GenericAdaptiveOptimizer):
         super().__init__(params, defaults, lr, betas, eps)
 
         self.weight_decay = weight_decay
+        self.optimized_update = optimized_update
 
     def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
         """
@@ -163,13 +168,23 @@ class Adam(GenericAdaptiveOptimizer):
         # Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$
         bias_correction2 = 1 - beta2 ** state['step']
 
-        # $\sqrt{v_t} + \epsilon$
-        denominator = v.sqrt().add_(group['eps'])
-        # $\alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$
-        step_size = self.get_lr(state, group) * math.sqrt(bias_correction2) / bias_correction1
-        # $\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
-        #  \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
-        param.data.addcdiv_(m, denominator, value=-step_size)
+        if self.optimized_update:
+            # $\sqrt{v_t} + \hat{\epsilon}$
+            denominator = v.sqrt().add_(group['eps'])
+            # $\alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$
+            step_size = self.get_lr(state, group) * math.sqrt(bias_correction2) / bias_correction1
+            # $\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
+            #  \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
+            param.data.addcdiv_(m, denominator, value=-step_size)
+        else:
+            # $\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$
+            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+            # $\frac{\alpha}{1-\beta_1^t}$
+            step_size = self.get_lr(state, group) / bias_correction1
+            # $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot
+            #  \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$
+            param.data.addcdiv_(m, denominator, value=-step_size)
 
     def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
         """
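As a sanity check on the two branches above, here is a small standalone sketch (not part of the commit) that recomputes both updates on a toy tensor, with variable names mirroring the diff; the tensors, `lr`, and `step` values are made up for illustration.

```python
import math
import torch

torch.manual_seed(0)

# Toy stand-ins for the optimizer state used in the diff above.
m = torch.randn(5)          # first moment estimate m_t
v = torch.rand(5)           # second moment estimate v_t (non-negative)
lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-16
step = 10

bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step

# optimized_update=True branch: add eps to sqrt(v_t) first and fold
# sqrt(1 - beta2^t) into the scalar step size.
denominator = v.sqrt() + eps
step_size = lr * math.sqrt(bias_correction2) / bias_correction1
update_optimized = -step_size * m / denominator

# optimized_update=False branch: textbook Adam, bias-correct v_t before adding eps.
denominator = v.sqrt() / math.sqrt(bias_correction2) + eps
step_size = lr / bias_correction1
update_plain = -step_size * m / denominator

# The branches differ only in where eps enters the denominator
# (eps versus eps * sqrt(1 - beta2^t)), so with a tiny eps they agree closely.
print(torch.allclose(update_optimized, update_plain))
```

This is why the `eps` bullet in the docstring says the same argument means $\hat{\epsilon}$ in one mode and $\epsilon$ in the other.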