diff --git a/labml_nn/optimizers/__init__.py b/labml_nn/optimizers/__init__.py
index 8a259886..a6ab43d3 100644
--- a/labml_nn/optimizers/__init__.py
+++ b/labml_nn/optimizers/__init__.py
@@ -6,7 +6,7 @@ from torch.optim.optimizer import Optimizer
 
 
 class GenericAdaptiveOptimizer(Optimizer):
-    def __init__(self, params, defaults, lr: float, betas: Tuple[float, float], eps: float, ):
+    def __init__(self, params, defaults, lr: float, betas: Tuple[float, float], eps: float):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
diff --git a/labml_nn/optimizers/ada_belief.py b/labml_nn/optimizers/ada_belief.py
index 5e82588a..535582d6 100644
--- a/labml_nn/optimizers/ada_belief.py
+++ b/labml_nn/optimizers/ada_belief.py
@@ -1,14 +1,17 @@
 """
-This is forked from AdaBelief official implementation
+This is based on the AdaBelief official implementation
 https://github.com/juntang-zhuang/Adabelief-Optimizer
 """
-import math
+from typing import Dict, Any
 
 import torch
-from torch.optim.optimizer import Optimizer
+from torch import nn
+
+from labml_nn.optimizers import WeightDecay
+from labml_nn.optimizers.radam import RAdam
 
 
-class AdaBelief(Optimizer):
+class AdaBelief(RAdam):
     r"""Implements AdaBelief algorithm. Modified from Adam in PyTorch
     Arguments:
         params (iterable): iterable of parameters to optimize or dicts defining
@@ -39,125 +42,50 @@ class AdaBelief(Optimizer):
     """
 
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
-                 weight_decay=0, amsgrad=False, weight_decouple=True, fixed_decay=False, rectify=True,
-                 degenerated_to_sgd=True):
-        if not 0.0 <= lr:
-            raise ValueError("Invalid learning rate: {}".format(lr))
-        if not 0.0 <= eps:
-            raise ValueError("Invalid epsilon value: {}".format(eps))
-        if not 0.0 <= betas[0] < 1.0:
-            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
-        if not 0.0 <= betas[1] < 1.0:
-            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
-        if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+                 weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
+                 degenerated_to_sgd=True,
+                 rectify=True, defaults=None):
-        defaults = dict(lr=lr, betas=betas, eps=eps,
-                        weight_decay=weight_decay, amsgrad=amsgrad)
-        super().__init__(params, defaults)
-
-        self.degenerated_to_sgd = degenerated_to_sgd
-        self.weight_decouple = weight_decouple
+        defaults = {} if defaults is None else defaults
+        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, degenerated_to_sgd, defaults)
         self.rectify = rectify
-        self.fixed_decay = fixed_decay
 
-    def __setstate__(self, state):
-        super().__setstate__(state)
-        for group in self.param_groups:
-            group.setdefault('amsgrad', False)
+    def init_state(self, state: Dict[str, any], group: Dict[str, any], p: nn.Parameter):
+        state['step'] = 0
+        # Exponential moving average of gradient values
+        state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+        # Exponential moving average of squared gradient values
+        state['exp_avg_var'] = torch.zeros_like(p, memory_format=torch.preserve_format)
 
-    @torch.no_grad()
-    def step(self, closure=None):
-        """Performs a single optimization step.
-        Arguments:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
-        """
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
+        if group['amsgrad']:
+            # Maintains max of all exp. moving avg. of sq. grad. values
+            state['max_exp_avg_var'] = torch.zeros_like(p, memory_format=torch.preserve_format)
 
-        for group in self.param_groups:
-            for p in group['params']:
-                if p.grad is None:
-                    continue
-                grad = p.grad.data
-                if grad.is_sparse:
-                    raise RuntimeError('AdaBelief does not support sparse gradients,'
-                                       ' please consider SparseAdam instead')
+    def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
+        beta1, beta2 = group['betas']
 
-                state = self.state[p]
-                # Lazy state initialization
-                if len(state) == 0:
-                    state['step'] = 0
-                    # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
-                    # Exponential moving average of squared gradient values
-                    state['exp_avg_var'] = torch.zeros_like(p, memory_format=torch.preserve_format)
-                    if group['amsgrad']:
-                        # Maintains max of all exp. moving avg. of sq. grad. values
-                        state['max_exp_avg_var'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+        # get current state variable
+        m, v = state['exp_avg'], state['exp_avg_var']
 
-                beta1, beta2 = group['betas']
+        # Update first and second moment running average
+        m.mul_(beta1).add_(grad, alpha=1 - beta1)
+        grad_residual = grad - m
+        v.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
 
-                # get current state variable
-                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
+        if group['amsgrad']:
+            v_max = state['max_exp_avg_var']
+            torch.maximum(v_max, v, out=v_max)
 
-                state['step'] += 1
-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']
+            return m, v_max
+        else:
+            return m, v
 
-                # Update first and second moment running average
-                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
-                grad_residual = grad - exp_avg
-                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
+    def calculate(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
+        self.weight_decay(param, group)
+        m, v = self.get_mv(state, group, grad)
+        state['step'] += 1
 
-                if group['amsgrad']:
-                    max_exp_avg_var = state['max_exp_avg_var']
-                    # Maintains the maximum of all 2nd moment running avg. till now
-                    torch.max(max_exp_avg_var, exp_avg_var, out=max_exp_avg_var)
-
-                    # Use the max. for normalizing running avg. of gradient
-                    denom = ((max_exp_avg_var + group['eps']).sqrt_() / math.sqrt(bias_correction2)).add_(group['eps'])
-                else:
-                    # denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
-                    denom = (exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
-
-                # perform weight decay, check if decoupled weight decay
-                if self.weight_decouple:
-                    if not self.fixed_decay:
-                        p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
-                    else:
-                        p.data.mul_(1.0 - group['weight_decay'])
-                else:
-                    if group['weight_decay'] != 0:
-                        grad.add_(p.data, alpha=group['weight_decay'])
-
-                # update
-                if not self.rectify:
-                    # Default update
-                    step_size = group['lr'] / bias_correction1
-                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
-                else:  # Rectified update, forked from RAdam
-                    beta2_t = beta2 ** state['step']
-                    N_sma_max = 2 / (1 - beta2) - 1
-                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
-
-                    # more conservative since it's an approximated value
-                    if N_sma >= 5:
-                        step_size = math.sqrt(
-                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
-                                    N_sma_max - 2)) / (1 - beta1 ** state['step'])
-                    elif self.degenerated_to_sgd:
-                        step_size = 1.0 / (1 - beta1 ** state['step'])
-                    else:
-                        step_size = -1
-
-                    if N_sma >= 5:
-                        denom = exp_avg_var.sqrt().add_(group['eps'])
-                        p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
-                    elif step_size > 0:
-                        p.data.add_(exp_avg, alpha=-step_size * group['lr'])
-
-        return loss
+        if not self.rectify:
+            self.adam_update(state, group, param, m, v)
+        else:  # Rectified update, forked from RAdam
+            self.r_adam_update(state, group, param, m, v)
diff --git a/labml_nn/optimizers/adam.py b/labml_nn/optimizers/adam.py
index 7038b288..ecace315 100644
--- a/labml_nn/optimizers/adam.py
+++ b/labml_nn/optimizers/adam.py
@@ -1,5 +1,5 @@
 import math
-from typing import Dict
+from typing import Dict, Any
 
 import torch
 from torch import nn
@@ -9,10 +9,8 @@ from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay
 
 class Adam(GenericAdaptiveOptimizer):
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
-                 amsgrad=False,
-                 weight_decay: WeightDecay = WeightDecay()):
-        defaults = dict(amsgrad=amsgrad,
-                        buffer=[[None, None, None] for _ in range(10)])
+                 weight_decay: WeightDecay = WeightDecay(), defaults=None):
+        defaults = {} if defaults is None else defaults
         defaults.update(weight_decay.defaults())
         super().__init__(params, defaults, lr, betas, eps)
 
@@ -25,31 +23,37 @@ class Adam(GenericAdaptiveOptimizer):
         # Exponential moving average of squared gradient values
         state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
 
-        if group['amsgrad']:
-            # Maintains max of all exp. moving avg. of sq. grad. values
-            state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
-
-    def calculate(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
-        self.weight_decay(param, group)
-
+    def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
         beta1, beta2 = group['betas']
 
         # get current state variable
         m, v = state['exp_avg'], state['exp_avg_sq']
 
-        state['step'] += 1
-        bias_correction1 = 1 - beta1 ** state['step']
-        bias_correction2 = 1 - beta2 ** state['step']
-
         # Update first and second moment running average
         m.mul_(beta1).add_(grad, alpha=1 - beta1)
         v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
 
-        if group['amsgrad']:
-            v_max = state['max_exp_avg_sq']
-            torch.maximum(v_max, v, out=v_max)
-            denominator = (v_max.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
-        else:
-            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+        return m, v
+
+    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
+        return group['lr']
+
+    def adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
+                    m: torch.Tensor, v: torch.Tensor):
+        beta1, beta2 = group['betas']
+        bias_correction1 = 1 - beta1 ** state['step']
+        bias_correction2 = 1 - beta2 ** state['step']
+
+        denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+        step_size = self.get_lr(state, group) / bias_correction1
+        param.data.addcdiv_(m, denominator, value=-step_size)
+
+    def calculate(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
+        self.weight_decay(param, group)
+
+        m, v = self.get_mv(state, group, grad)
+
+        state['step'] += 1
+
+        self.adam_update(state, group, param, m, v)
 
-        param.data.addcdiv_(m, denominator, value=-group['lr'] / bias_correction1)
diff --git a/labml_nn/optimizers/adam_warmup.py b/labml_nn/optimizers/adam_warmup.py
new file mode 100644
index 00000000..df4d59f9
--- /dev/null
+++ b/labml_nn/optimizers/adam_warmup.py
@@ -0,0 +1,18 @@
+from typing import Dict
+
+from labml_nn.optimizers import WeightDecay
+from labml_nn.optimizers.amsgrad import AMSGrad
+
+
+class AdamWarmup(AMSGrad):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
+                 weight_decay: WeightDecay = WeightDecay(), amsgrad=False, warmup=0, defaults=None):
+        defaults = {} if defaults is None else defaults
+        defaults.update(dict(warmup=warmup))
+        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, defaults)
+
+    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
+        if group['warmup'] > state['step']:
+            return 1e-8 + state['step'] * group['lr'] / group['warmup']
+        else:
+            return group['lr']
diff --git a/labml_nn/optimizers/amsgrad.py b/labml_nn/optimizers/amsgrad.py
new file mode 100644
index 00000000..d9837d4c
--- /dev/null
+++ b/labml_nn/optimizers/amsgrad.py
@@ -0,0 +1,32 @@
+from typing import Dict
+
+import torch
+from torch import nn
+
+from labml_nn.optimizers import WeightDecay
+from labml_nn.optimizers.adam import Adam
+
+
+class AMSGrad(Adam):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
+                 weight_decay: WeightDecay = WeightDecay(), amsgrad=True, defaults=None):
+        defaults = {} if defaults is None else defaults
+        defaults.update(dict(amsgrad=amsgrad))
+
+        super().__init__(params, lr, betas, eps, weight_decay, defaults)
+
+    def init_state(self, state: Dict[str, any], group: Dict[str, any], p: nn.Parameter):
+        super().init_state(state, group, p)
+        # Maintains max of all exp. moving avg. of sq. grad. values
+        if group['amsgrad']:
+            state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+    def get_mv(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor):
+        m, v = super().get_mv(state, group, grad)
+        if group['amsgrad']:
+            v_max = state['max_exp_avg_sq']
+            torch.maximum(v_max, v, out=v_max)
+
+            return m, v_max
+        else:
+            return m, v
diff --git a/labml_nn/optimizers/mnist.py b/labml_nn/optimizers/mnist_experiment.py
similarity index 85%
rename from labml_nn/optimizers/mnist.py
rename to labml_nn/optimizers/mnist_experiment.py
index f0cfe307..9aa00963 100644
--- a/labml_nn/optimizers/mnist.py
+++ b/labml_nn/optimizers/mnist_experiment.py
@@ -82,11 +82,23 @@ def model(c: Configs):
 
 
 @option(OptimizerConfigs.optimizer, 'AdaBelief')
-def ada_belief(c: OptimizerConfigs):
-    from labml_nn.optimizers.ada_belief_buffer import AdaBelief
+def _ada_belief(c: OptimizerConfigs):
+    from labml_nn.optimizers.ada_belief import AdaBelief
     return AdaBelief(c.parameters, lr=c.learning_rate, betas=c.betas, eps=c.eps)
 
 
+@option(OptimizerConfigs.optimizer, 'Adam')
+def _adam(c: OptimizerConfigs):
+    from labml_nn.optimizers.adam import Adam
+    return Adam(c.parameters, lr=c.learning_rate, betas=c.betas, eps=c.eps)
+
+
+@option(OptimizerConfigs.optimizer, 'AdamWarmup')
+def _adam_warmup(c: OptimizerConfigs):
+    from labml_nn.optimizers.adam_warmup import AdamWarmup
+    return AdamWarmup(c.parameters, lr=c.learning_rate, betas=c.betas, eps=c.eps)
+
+
 @option(Configs.optimizer)
 def _optimizer(c: Configs):
     opt_conf = OptimizerConfigs()
diff --git a/labml_nn/optimizers/radam.py b/labml_nn/optimizers/radam.py
index bdb20786..d599f9d9 100644
--- a/labml_nn/optimizers/radam.py
+++ b/labml_nn/optimizers/radam.py
@@ -1,159 +1,47 @@
 """
-Forked from https://github.com/LiyuanLucasLiu/RAdam
+Based on https://github.com/LiyuanLucasLiu/RAdam
 """
 import math
+from typing import Dict
+
 import torch
-from torch.optim.optimizer import Optimizer
+
+from labml_nn.optimizers import WeightDecay
+from labml_nn.optimizers.amsgrad import AMSGrad
 
 
-class RAdam(Optimizer):
-
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
-        if not 0.0 <= lr:
-            raise ValueError("Invalid learning rate: {}".format(lr))
-        if not 0.0 <= eps:
-            raise ValueError("Invalid epsilon value: {}".format(eps))
-        if not 0.0 <= betas[0] < 1.0:
-            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
-        if not 0.0 <= betas[1] < 1.0:
-            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
-
+class RAdam(AMSGrad):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+                 weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
+                 degenerated_to_sgd=True, defaults=None):
         self.degenerated_to_sgd = degenerated_to_sgd
-        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, defaults)
 
-        super().__init__(params, defaults)
+    def calculate(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
+        self.weight_decay(param, group)
 
-    def __setstate__(self, state):
-        super().__setstate__(state)
+        m, v = self.get_mv(state, group, grad)
+        state['step'] += 1
 
-    def step(self, closure=None):
+        self.r_adam_update(state, group, param, m, v)
 
-        loss = None
-        if closure is not None:
-            loss = closure()
+    def r_adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
+                      m: torch.Tensor, v: torch.Tensor):
+        beta1, beta2 = group['betas']
+        bias_correction1 = 1 - beta1 ** state['step']
+        bias_correction2 = 1 - beta2 ** state['step']
 
-        for group in self.param_groups:
+        beta2_t = beta2 ** state['step']
+        rho_inf = 2 / (1 - beta2) - 1
+        rho = rho_inf - 2 * state['step'] * beta2_t / (1 - beta2_t)
 
-            for p in group['params']:
-                if p.grad is None:
-                    continue
-                grad = p.grad.data.float()
-                if grad.is_sparse:
-                    raise RuntimeError('RAdam does not support sparse gradients')
-
-                p_data_fp32 = p.data.float()
-
-                state = self.state[p]
-
-                if len(state) == 0:
-                    state['step'] = 0
-                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
-                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
-                else:
-                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
-                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
-
-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
-                beta1, beta2 = group['betas']
-
-                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
-                exp_avg.mul_(beta1).add_(1 - beta1, grad)
-
-                state['step'] += 1
-                beta2_t = beta2 ** state['step']
-                N_sma_max = 2 / (1 - beta2) - 1
-                N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
-
-                # more conservative since it's an approximated value
-                if N_sma >= 5:
-                    if group['weight_decay'] != 0:
-                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
-                    step_size = group['lr'] * math.sqrt(
-                        (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
-                                N_sma_max - 2)) / (1 - beta1 ** state['step'])
-                    denom = exp_avg_sq.sqrt().add_(group['eps'])
-                    p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
-                    p.data.copy_(p_data_fp32)
-                elif self.degenerated_to_sgd:
-                    if group['weight_decay'] != 0:
-                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
-                    step_size = group['lr'] / (1 - beta1 ** state['step'])
-                    p_data_fp32.add_(-step_size, exp_avg)
-                    p.data.copy_(p_data_fp32)
-
-        return loss
-
-
-class AdamW(Optimizer):
-
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0):
-        if not 0.0 <= lr:
-            raise ValueError("Invalid learning rate: {}".format(lr))
-        if not 0.0 <= eps:
-            raise ValueError("Invalid epsilon value: {}".format(eps))
-        if not 0.0 <= betas[0] < 1.0:
-            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
-        if not 0.0 <= betas[1] < 1.0:
-            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
-
-        defaults = dict(lr=lr, betas=betas, eps=eps,
-                        weight_decay=weight_decay, warmup=warmup)
-        super(AdamW, self).__init__(params, defaults)
-
-    def __setstate__(self, state):
-        super(AdamW, self).__setstate__(state)
-
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            loss = closure()
-
-        for group in self.param_groups:
-
-            for p in group['params']:
-                if p.grad is None:
-                    continue
-                grad = p.grad.data.float()
-                if grad.is_sparse:
-                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
-
-                p_data_fp32 = p.data.float()
-
-                state = self.state[p]
-
-                if len(state) == 0:
-                    state['step'] = 0
-                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
-                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
-                else:
-                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
-                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
-
-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
-                beta1, beta2 = group['betas']
-
-                state['step'] += 1
-
-                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
-                exp_avg.mul_(beta1).add_(1 - beta1, grad)
-
-                denom = exp_avg_sq.sqrt().add_(group['eps'])
-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']
-
-                if group['warmup'] > state['step']:
-                    scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
-                else:
-                    scheduled_lr = group['lr']
-
-                step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1
-
-                if group['weight_decay'] != 0:
-                    p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32)
-
-                p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
-
-                p.data.copy_(p_data_fp32)
-
-        return loss
+        # more conservative since it's an approximated value
+        if rho >= 5:
+            r2 = (rho - 4) / (rho_inf - 4) * (rho - 2) / rho * rho_inf / (rho_inf - 2)
+            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+            step_size = self.get_lr(state, group) * math.sqrt(r2) / bias_correction1
+            param.data.addcdiv_(m, denominator, value=-step_size)
+        elif self.degenerated_to_sgd:
+            step_size = self.get_lr(state, group) / bias_correction1
+            param.data.add_(m, alpha=-step_size)
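
Note on usage (not part of the patch): after this refactor the optimizers form a small hierarchy, Adam -> AMSGrad -> AdamWarmup and Adam -> AMSGrad -> RAdam -> AdaBelief, where calculate() composes weight_decay, get_mv(), get_lr(), and either adam_update() or r_adam_update(), so each subclass only overrides the piece it changes. The sketch below shows how the result might be driven; the toy model, batch, and hyper-parameters are assumptions made up for illustration, only the optimizer classes and their constructor arguments come from the diff above.

# Illustrative usage only -- a sketch, not taken from this patch.
import torch
from torch import nn

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.adam_warmup import AdamWarmup

model = nn.Linear(10, 2)
optimizer = AdamWarmup(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                       weight_decay=WeightDecay(), warmup=2_000)

x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
# step() is assumed to come from the GenericAdaptiveOptimizer base class in
# labml_nn/optimizers/__init__.py (only partially shown in this diff), which
# walks the parameters and calls init_state()/calculate() for each of them.
optimizer.step()
optimizer.zero_grad()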