From 10ee239a14fa13347df43eb1933cebc053954f30 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Thu, 10 Dec 2020 10:51:19 +0530
Subject: [PATCH] unoptimized adam

---
 labml_nn/optimizers/adam.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/labml_nn/optimizers/adam.py b/labml_nn/optimizers/adam.py
index e7d17c89..d1039a7d 100644
--- a/labml_nn/optimizers/adam.py
+++ b/labml_nn/optimizers/adam.py
@@ -168,6 +168,7 @@ class Adam(GenericAdaptiveOptimizer):
         # Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$
         bias_correction2 = 1 - beta2 ** state['step']
 
+        # Whether to optimize the computation
         if self.optimized_update:
             # $\sqrt{v_t} + \hat{\epsilon}$
             denominator = v.sqrt().add_(group['eps'])
@@ -176,6 +177,7 @@
             # $\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
             # \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
             param.data.addcdiv_(m, denominator, value=-step_size)
+        # Computation without optimization
         else:
             # $\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$
             denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
@@ -185,7 +187,6 @@ class Adam(GenericAdaptiveOptimizer):
             # \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$
             param.data.addcdiv_(m, denominator, value=-step_size)
 
-
     def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
         """
         ### Take an update step for a given parameter tensor
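
Note (not part of the patch): the two branches annotated above compute the same Adam update written two ways. The optimized path folds $\sqrt{1-\beta_2^t}$ into the step size and treats `eps` as the modified $\hat{\epsilon}$ added directly to $\sqrt{v_t}$; the unoptimized path divides out both bias corrections explicitly. Below is a minimal standalone sketch of that equivalence; the function names, plain-tensor arguments, and hyperparameter defaults are illustrative assumptions, not the labml_nn API.

import math

import torch


def adam_step_optimized(param, m, v, step, lr=1e-3, betas=(0.9, 0.999), eps=1e-16):
    """Optimized form: fold sqrt(1 - beta2^t) into the step size and treat
    `eps` as the modified epsilon-hat added directly to sqrt(v_t)."""
    beta1, beta2 = betas
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    # sqrt(v_t) + eps-hat
    denominator = v.sqrt() + eps
    # alpha * sqrt(1 - beta2^t) / (1 - beta1^t)
    step_size = lr * math.sqrt(bias_correction2) / bias_correction1
    return param - step_size * m / denominator


def adam_step_unoptimized(param, m, v, step, lr=1e-3, betas=(0.9, 0.999), eps=1e-16):
    """Unoptimized form: apply both bias corrections explicitly."""
    beta1, beta2 = betas
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    # sqrt(v_t) / sqrt(1 - beta2^t) + eps
    denominator = v.sqrt() / math.sqrt(bias_correction2) + eps
    # alpha / (1 - beta1^t)
    step_size = lr / bias_correction1
    return param - step_size * m / denominator


torch.manual_seed(0)
param = torch.randn(4)
grad = torch.randn(4)
m = 0.1 * grad          # illustrative first-moment estimate
v = 0.01 * grad ** 2    # illustrative second-moment estimate
step = 10

# With eps = 0 the two forms are algebraically identical; with eps > 0 they
# coincide when eps-hat = eps * sqrt(1 - beta2^t).
a = adam_step_optimized(param, m, v, step, eps=0.0)
b = adam_step_unoptimized(param, m, v, step, eps=0.0)
assert torch.allclose(a, b)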