Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-10-31 02:39:16 +08:00)
unoptimized adam
@@ -168,6 +168,7 @@ class Adam(GenericAdaptiveOptimizer):
         # Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$
         bias_correction2 = 1 - beta2 ** state['step']
 
+        # Whether to optimize the computation
         if self.optimized_update:
             # $\sqrt{v_t} + \hat{\epsilon}$
             denominator = v.sqrt().add_(group['eps'])
@@ -176,6 +177,7 @@ class Adam(GenericAdaptiveOptimizer):
             # $\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
             #  \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
             param.data.addcdiv_(m, denominator, value=-step_size)
+        # Computation without optimization
         else:
             # $\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$
             denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
@@ -185,7 +187,6 @@ class Adam(GenericAdaptiveOptimizer):
             #  \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$
             param.data.addcdiv_(m, denominator, value=-step_size)
 
-
     def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
         """
         ### Take an update step for a given parameter tensor
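For reference, the two branches this diff separates compute the same parameter update: the optimized path folds the bias corrections $\frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$ into the scalar step size, so the per-element tensor work is a single sqrt, add, and divide, at the cost of $\epsilon$ effectively becoming $\hat{\epsilon} = \epsilon \sqrt{1-\beta_2^t}$. Below is a minimal standalone sketch of that equivalence; it is not the repository's code, and the tensors m, v and the hyperparameters are made-up stand-ins for the optimizer state of the same names.

import math
import torch

torch.manual_seed(0)

# Made-up stand-ins for the Adam state: first/second moment
# estimates m_t, v_t and the usual hyperparameters.
m = torch.randn(5)
v = torch.rand(5)
beta1, beta2 = 0.9, 0.999
lr, eps, step = 1e-3, 1e-16, 100

bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step

# Optimized path: bias corrections folded into the scalar step size.
denominator = v.sqrt() + eps                                # sqrt(v_t) + eps_hat
step_size = lr * math.sqrt(bias_correction2) / bias_correction1
update_optimized = step_size * m / denominator

# Unoptimized (textbook) path: bias-correct v_t inside the denominator.
denominator = v.sqrt() / math.sqrt(bias_correction2) + eps  # sqrt(v_hat) + eps
step_size = lr / bias_correction1
update_plain = step_size * m / denominator

# The two agree up to reinterpreting eps as eps_hat = eps * sqrt(1 - beta2^t);
# for a small eps the remaining gap is rounding noise.
print(torch.max((update_optimized - update_plain).abs()))

Both branches then apply their update with the same call, param.data.addcdiv_(m, denominator, value=-step_size), so the only behavioural difference between them is whether eps or eps_hat sits next to sqrt(v_t).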