mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 02:39:16 +08:00 
			
		
		
		
	adam comments
This commit is contained in:
		| @ -67,14 +67,13 @@ class WeightDecay: | |||||||
|     def defaults(self): |     def defaults(self): | ||||||
|         return dict(weight_decay=self.weight_decay) |         return dict(weight_decay=self.weight_decay) | ||||||
|  |  | ||||||
|     def __call__(self, param: torch.nn.Parameter, group: Dict[str, any]): |     def __call__(self, param: torch.nn.Parameter, grad: torch.Tensor, group: Dict[str, any]): | ||||||
|         grad = param.grad.data |  | ||||||
|  |  | ||||||
|         if self.weight_decouple: |         if self.weight_decouple: | ||||||
|             if not self.absolute: |             if not self.absolute: | ||||||
|                 param.data.mul_(1.0 - group['lr'] * group['weight_decay']) |                 param.data.mul_(1.0 - group['lr'] * group['weight_decay']) | ||||||
|             else: |             else: | ||||||
|                 param.data.mul_(1.0 - group['weight_decay']) |                 param.data.mul_(1.0 - group['weight_decay']) | ||||||
|  |             return grad | ||||||
|         else: |         else: | ||||||
|             if group['weight_decay'] != 0: |             if group['weight_decay'] != 0: | ||||||
|                 grad.add_(param.data, alpha=group['weight_decay']) |                 return grad.add(param.data, alpha=group['weight_decay']) | ||||||
|  | |||||||
| @ -81,7 +81,7 @@ class AdaBelief(RAdam): | |||||||
|             return m, v |             return m, v | ||||||
|  |  | ||||||
|     def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter): |     def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter): | ||||||
|         self.weight_decay(param, group) |         grad = self.weight_decay(param, grad, group) | ||||||
|         m, v = self.get_mv(state, group, grad) |         m, v = self.get_mv(state, group, grad) | ||||||
|         state['step'] += 1 |         state['step'] += 1 | ||||||
|  |  | ||||||
|  | |||||||
| @ -6,6 +6,33 @@ This is an implementation of popular optimizer *Adam* from paper | |||||||
|  |  | ||||||
| We extend the class `GenericAdaptiveOptimizer` defined in [__init__.py](index.html) | We extend the class `GenericAdaptiveOptimizer` defined in [__init__.py](index.html) | ||||||
| to implement the Adam optimizer. | to implement the Adam optimizer. | ||||||
|  |  | ||||||
|  | *Adam* update is, | ||||||
|  |  | ||||||
|  | \begin{align} | ||||||
|  | m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\ | ||||||
|  | v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\ | ||||||
|  | \hat{m}_t &\leftarrow \frac{m_t}/{1-\beta_1^t} \\ | ||||||
|  | \hat{v}_t &\leftarrow \frac{v_t}/{1-\beta_2^t} \\ | ||||||
|  | \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} | ||||||
|  | \end{align} | ||||||
|  |  | ||||||
|  | where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper parameters. | ||||||
|  | $m_t$ and $v_t$ are first and second order moments. | ||||||
|  | $\hat{m}_t$  and $\hat{v}_t$ are biased corrected moments. | ||||||
|  | $\epsilon$ is used as a fix for division by zero error, but also acts as a form of a hyper-parameter | ||||||
|  | that acts against variance in gradients. | ||||||
|  |  | ||||||
|  | Effective step taken assuming $\epsilon = 0$ is, | ||||||
|  | $$\Delta t = \alpha \cdot \frac{\hat{m}_t}{\hat{v}_t}$$ | ||||||
|  | This is bounded by, | ||||||
|  | $$\vert \Delta t \vert \le \alpha \cdot \frac{1 - \beta_1}{\sqrt{1-\beta_2}}$$ | ||||||
|  | when $1-\beta_1 \gt \sqrt{1-\beta_2}$ | ||||||
|  | and | ||||||
|  | $$\vert \Delta t\vert  \le \alpha$$ | ||||||
|  | otherwise. | ||||||
|  | And in most common scenarios, | ||||||
|  | $$\vert \Delta t \vert \approx \alpha$$ | ||||||
| """ | """ | ||||||
|  |  | ||||||
| import math | import math | ||||||
| @ -28,6 +55,7 @@ class Adam(GenericAdaptiveOptimizer): | |||||||
|         * `params` is the list of parameters |         * `params` is the list of parameters | ||||||
|         * 'lr' is the learning rate $\alpha$ |         * 'lr' is the learning rate $\alpha$ | ||||||
|         * `betas` is a tuple of ($\beta_1$, $\beta_2$) |         * `betas` is a tuple of ($\beta_1$, $\beta_2$) | ||||||
|  |         * `eps` is $\hat{\epsilon}$ | ||||||
|         * `weight_decay` is an instance of class `WeightDecay` defined in [__init__.py](index.html) |         * `weight_decay` is an instance of class `WeightDecay` defined in [__init__.py](index.html) | ||||||
|         * `defaults` is a dictionary of default for group values. |         * `defaults` is a dictionary of default for group values. | ||||||
|          This is useful when you want to extend the class `Adam`. |          This is useful when you want to extend the class `Adam`. | ||||||
| @ -101,8 +129,6 @@ class Adam(GenericAdaptiveOptimizer): | |||||||
|         This computes the following |         This computes the following | ||||||
|  |  | ||||||
|         \begin{align} |         \begin{align} | ||||||
|         \hat{m}_t &\leftarrow \frac{m_t}/{1-\beta_1^t} \\ |  | ||||||
|         \hat{v}_t &\leftarrow \frac{v_t}/{1-\beta_2^t} \\ |  | ||||||
|         \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} |         \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} | ||||||
|         \end{align} |         \end{align} | ||||||
|  |  | ||||||
| @ -114,12 +140,12 @@ class Adam(GenericAdaptiveOptimizer): | |||||||
|         \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot |         \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot | ||||||
|                 \frac{m_t / (1-\beta_1^t)}{\sqrt{v_t/(1-\beta_2^t)} + \epsilon} \\ |                 \frac{m_t / (1-\beta_1^t)}{\sqrt{v_t/(1-\beta_2^t)} + \epsilon} \\ | ||||||
|         \theta_t &\leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot |         \theta_t &\leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot | ||||||
|                 \frac{m_t}{\sqrt{v_t} + \epsilon'} \\ |                 \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}} \\ | ||||||
|         \end{align} |         \end{align} | ||||||
|  |  | ||||||
|         where |         where | ||||||
|         $$\epsilon` = (1-\beta_2^t) \epsilon \approx \epsilon$$ |         $$\hat{\epsilon} = (1-\beta_2^t) \epsilon$$ | ||||||
|         since $\beta_2 \approx 1$ |         is what we should specify as the hyper-parameter. | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         # Get $\beta_1$ and $\beta_2$ |         # Get $\beta_1$ and $\beta_2$ | ||||||
| @ -134,7 +160,7 @@ class Adam(GenericAdaptiveOptimizer): | |||||||
|         # $\alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$ |         # $\alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$ | ||||||
|         step_size = self.get_lr(state, group) * math.sqrt(bias_correction2) / bias_correction1 |         step_size = self.get_lr(state, group) * math.sqrt(bias_correction2) / bias_correction1 | ||||||
|         # $\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot |         # $\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot | ||||||
|         #  \frac{m_t}{\sqrt{v_t} + \epsilon}$ |         #  \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$ | ||||||
|         param.data.addcdiv_(m, denominator, value=-step_size) |         param.data.addcdiv_(m, denominator, value=-step_size) | ||||||
|  |  | ||||||
|     def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter): |     def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter): | ||||||
| @ -148,7 +174,7 @@ class Adam(GenericAdaptiveOptimizer): | |||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         # Calculate weight decay |         # Calculate weight decay | ||||||
|         self.weight_decay(param, group) |         grad = self.weight_decay(param, grad, group) | ||||||
|  |  | ||||||
|         # Get $m_t$ and $v_t$ |         # Get $m_t$ and $v_t$ | ||||||
|         m, v = self.get_mv(state, group, grad) |         m, v = self.get_mv(state, group, grad) | ||||||
|  | |||||||
| @ -19,7 +19,7 @@ class RAdam(AMSGrad): | |||||||
|         super().__init__(params, lr, betas, eps, weight_decay, amsgrad, defaults) |         super().__init__(params, lr, betas, eps, weight_decay, amsgrad, defaults) | ||||||
|  |  | ||||||
|     def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter): |     def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter): | ||||||
|         self.weight_decay(param, group) |         grad = self.weight_decay(param, grad, group) | ||||||
|  |  | ||||||
|         m, v = self.get_mv(state, group, grad) |         m, v = self.get_mv(state, group, grad) | ||||||
|         state['step'] += 1 |         state['step'] += 1 | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri