Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git

Commit: unoptimized adam
@@ -53,9 +53,11 @@ class Adam(GenericAdaptiveOptimizer):
     We extend the class `GenericAdaptiveOptimizer` defined in [`__init__.py`](index.html)
     to implement the Adam optimizer.
     """

     def __init__(self, params,
                  lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                  weight_decay: WeightDecay = WeightDecay(),
+                 optimized_update: bool = True,
                  defaults: Optional[Dict[str, Any]] = None):
         """
         ### Initialize the optimizer
@@ -63,8 +65,10 @@ class Adam(GenericAdaptiveOptimizer):
         * `params` is the list of parameters
         * `lr` is the learning rate $\alpha$
         * `betas` is a tuple of ($\beta_1$, $\beta_2$)
-        * `eps` is $\hat{\epsilon}$
+        * `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
         * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
+        * `optimized_update` is a flag whether to optimize the bias correction of the second moment
+          by doing it after adding $\epsilon$
         * `defaults` is a dictionary of default for group values.
         This is useful when you want to extend the class `Adam`.
         """
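How the two meanings of `eps` relate (a sketch derived from the two update branches in this commit, not text from the repository): the optimized and unoptimized forms of the parameter update are

$$
\text{optimized: } \theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}},
\qquad
\text{unoptimized: } \theta_t \leftarrow \theta_{t-1} - \frac{\alpha}{1-\beta_1^t} \cdot \frac{m_t}{\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon}
= \theta_{t-1} - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}.
$$

Dividing the optimized form through by $\sqrt{1-\beta_2^t}$ shows that the two coincide exactly when $\epsilon = \hat{\epsilon} / \sqrt{1-\beta_2^t}$; the optimized form folds the second-moment bias correction into the scalar step size and so avoids one element-wise operation per parameter tensor.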
@@ -73,6 +77,7 @@ class Adam(GenericAdaptiveOptimizer):
         super().__init__(params, defaults, lr, betas, eps)

         self.weight_decay = weight_decay
+        self.optimized_update = optimized_update

     def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
         """
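A minimal usage sketch of the new flag; the toy model, data, and training step below are illustrative assumptions, only the `Adam` constructor signature comes from this diff:

```python
import torch
from torch import nn

from labml_nn.optimizers.adam import Adam

# Toy model and data, purely for illustration
model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

# optimized_update=False selects the textbook update, where the second-moment
# bias correction is applied before adding eps
optimizer = Adam(model.parameters(), lr=1e-3, optimized_update=False)

loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```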
@@ -163,13 +168,23 @@ class Adam(GenericAdaptiveOptimizer):
         # Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$
         bias_correction2 = 1 - beta2 ** state['step']

-        # $\sqrt{v_t} + \epsilon$
-        denominator = v.sqrt().add_(group['eps'])
-        # $\alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$
-        step_size = self.get_lr(state, group) * math.sqrt(bias_correction2) / bias_correction1
-        # $\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
-        # \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
-        param.data.addcdiv_(m, denominator, value=-step_size)
+        if self.optimized_update:
+            # $\sqrt{v_t} + \hat{\epsilon}$
+            denominator = v.sqrt().add_(group['eps'])
+            # $\alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$
+            step_size = self.get_lr(state, group) * math.sqrt(bias_correction2) / bias_correction1
+            # $\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
+            # \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
+            param.data.addcdiv_(m, denominator, value=-step_size)
+        else:
+            # $\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$
+            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+            # $\frac{\alpha}{1-\beta_1^t}$
+            step_size = self.get_lr(state, group) / bias_correction1
+            # $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot
+            # \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$
+            param.data.addcdiv_(m, denominator, value=-step_size)

     def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
         """
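A self-contained check, using plain PyTorch, that the two branches compute the same update up to how `eps` enters the denominator; all tensors and hyper-parameters below are made up for illustration:

```python
import math
import torch

beta1, beta2, eps, lr, step = 0.9, 0.999, 1e-16, 1e-3, 10
m = torch.randn(5)                      # stand-in for the first moment m_t
v = torch.rand(5)                       # stand-in for the second moment v_t
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step

# Optimized branch: add eps first, fold the second-moment correction into the scalar step size
denominator = v.sqrt().add(eps)
step_size = lr * math.sqrt(bias_correction2) / bias_correction1
update_optimized = step_size * m / denominator

# Unoptimized branch: bias-correct the second moment first, then add eps
denominator = (v.sqrt() / math.sqrt(bias_correction2)).add(eps)
step_size = lr / bias_correction1
update_unoptimized = step_size * m / denominator

# With a tiny eps the two differ only in how eps is scaled, so they match numerically
print(torch.allclose(update_optimized, update_unoptimized))
```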
@@ -13,10 +13,12 @@ from labml_nn.optimizers.amsgrad import AMSGrad

 class AdamWarmup(AMSGrad):
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
-                 weight_decay: WeightDecay = WeightDecay(), amsgrad=False, warmup=0, defaults=None):
+                 weight_decay: WeightDecay = WeightDecay(),
+                 optimized_update: bool = True,
+                 amsgrad=False, warmup=0, defaults=None):
         defaults = {} if defaults is None else defaults
         defaults.update(dict(warmup=warmup))
-        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, defaults)
+        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)

     def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
         if group['warmup'] > state['step']:
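The hunk ends at the `get_lr` override, which reduces the learning rate for the first `warmup` steps. A standalone sketch of a linear warm-up schedule of this kind; the exact expression inside the repository's `get_lr` is not shown in this hunk, so treat the constants here as assumptions:

```python
def warmup_lr(step: int, lr: float, warmup: int) -> float:
    """Linearly ramp the learning rate from roughly zero up to `lr` over `warmup` steps."""
    if warmup > step:
        return step * lr / warmup
    return lr


# e.g. with lr=1e-3 and warmup=100: step 10 gives 1e-4, any step >= 100 gives 1e-3
print(warmup_lr(10, 1e-3, 100), warmup_lr(250, 1e-3, 100))
```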
@@ -33,7 +33,9 @@ class AMSGrad(Adam):
     defined in [`__init__.py`](index.html).
     """
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
-                 weight_decay: WeightDecay = WeightDecay(), amsgrad=True, defaults=None):
+                 weight_decay: WeightDecay = WeightDecay(),
+                 optimized_update: bool = True,
+                 amsgrad=True, defaults=None):
         """
         ### Initialize the optimizer

@@ -49,7 +51,7 @@ class AMSGrad(Adam):
         defaults = {} if defaults is None else defaults
         defaults.update(dict(amsgrad=amsgrad))

-        super().__init__(params, lr, betas, eps, weight_decay, defaults)
+        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)

     def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
         """
@@ -20,6 +20,7 @@ class OptimizerConfigs(BaseConfigs):
     weight_decouple: bool = True
     weight_decay: float = 0.0
     weight_decay_absolute: bool = False
+    optimized_adam_update: bool = True

     parameters: any

@@ -58,11 +59,13 @@ def _adam_optimizer(c: OptimizerConfigs):
         from labml_nn.optimizers.amsgrad import AMSGrad
         return AMSGrad(c.parameters,
                        lr=c.learning_rate, betas=c.betas, eps=c.eps,
+                       optimized_update=c.optimized_adam_update,
                        weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad)
     else:
         from labml_nn.optimizers.adam import Adam
         return Adam(c.parameters,
                     lr=c.learning_rate, betas=c.betas, eps=c.eps,
+                    optimized_update=c.optimized_adam_update,
                     weight_decay=c.weight_decay_obj)


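A hedged sketch of how an experiment could toggle the new configuration flag. The option name `'Adam'` and the way labml resolves `conf.optimizer` into an optimizer instance are assumptions based on typical labml configs usage, not part of this diff:

```python
from torch import nn

from labml_nn.optimizers.configs import OptimizerConfigs

# Hypothetical wiring: set the fields that _adam_optimizer reads.
model = nn.Linear(4, 1)
conf = OptimizerConfigs()
conf.parameters = model.parameters()
conf.optimizer = 'Adam'              # assumed option name for _adam_optimizer
conf.optimized_adam_update = False   # request the textbook (unoptimized) update
```

When the configs are resolved, `_adam_optimizer` above forwards `optimized_update=c.optimized_adam_update` into the optimizer constructor.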
@@ -14,11 +14,13 @@ from labml_nn.optimizers.amsgrad import AMSGrad

 class Noam(AMSGrad):
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
-                 weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
+                 weight_decay: WeightDecay = WeightDecay(),
+                 optimized_update: bool = True,
+                 amsgrad=False,
                  warmup=0, d_model=512, defaults=None):
         defaults = {} if defaults is None else defaults
         defaults.update(dict(warmup=warmup))
-        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, defaults)
+        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
         self.d_model = d_model

     def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
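`Noam` pairs the optimizer with the learning-rate schedule from "Attention Is All You Need", scaled by `d_model` and a warm-up period. A standalone sketch of that schedule; the exact scaling inside the repository's `get_lr` is not shown in this hunk, so the formula below is the canonical one from the paper rather than this code:

```python
def noam_lr(step: int, d_model: int = 512, warmup: int = 4000, factor: float = 1.0) -> float:
    """Noam schedule: factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)."""
    step = max(step, 1)  # guard against step 0
    return factor * d_model ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))


# The rate grows linearly for `warmup` steps, peaks, then decays as 1/sqrt(step)
print(noam_lr(100), noam_lr(4000), noam_lr(40000))
```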
@@ -1,10 +1,28 @@
 """
 ---
-title: RAdam optimizer
+title: Rectified Adam (RAdam) optimizer
 summary: A simple PyTorch implementation/tutorial of RAdam optimizer.
 ---

-Based on https://github.com/LiyuanLucasLiu/RAdam
+# Rectified Adam (RAdam) optimizer
+
+This implementation is based on
+[the official implementation](https://github.com/LiyuanLucasLiu/RAdam)
+of the paper
+[On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265).
+
+We have implemented it as an extension to [our AMSGrad implementation](amsgrad.html),
+so only the modifications need to be implemented here.
+
+The Adam optimizer sometimes converges to a bad local optimum during the initial stages of training,
+especially when training transformers.
+Researchers use warm-ups to counter this: for the initial training steps (the warm-up stage)
+they use a low learning rate.
+This paper identifies the problem as the high variance of the adaptive learning rate
+during the initial stages of training, and counters it with a new rectification term
+that reduces the variance.
+
+
 """

 import math
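For context, the rectification the docstring refers to, as defined in the paper (reproduced here as a reference; the exact threshold and the `degenerated_to_sgd` fallback live in code outside this hunk). The length of the approximated simple moving average is

$$
\rho_\infty = \frac{2}{1 - \beta_2} - 1, \qquad
\rho_t = \rho_\infty - \frac{2 t \beta_2^t}{1 - \beta_2^t},
$$

and when $\rho_t$ is large enough (the paper uses $\rho_t > 4$) the adaptive step is scaled by the rectification term

$$
r_t = \sqrt{\frac{(\rho_t - 4)(\rho_t - 2)\,\rho_\infty}{(\rho_\infty - 4)(\rho_\infty - 2)\,\rho_t}},
$$

otherwise the update falls back to momentum SGD (or is skipped).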
@@ -21,7 +39,7 @@ class RAdam(AMSGrad):
                  weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
                  degenerated_to_sgd=True, defaults=None):
         self.degenerated_to_sgd = degenerated_to_sgd
-        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, defaults)
+        super().__init__(params, lr, betas, eps, weight_decay, False, amsgrad, defaults)

     def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
         grad = self.weight_decay(param, grad, group)