Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-08-14 17:41:37 +08:00)
♻️ adam+ optimizers
@@ -1,5 +1,5 @@
 import math
-from typing import Dict
+from typing import Dict, Any

 import torch
 from torch import nn
@@ -9,10 +9,8 @@ from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay

 class Adam(GenericAdaptiveOptimizer):
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
-                 amsgrad=False,
-                 weight_decay: WeightDecay = WeightDecay()):
-        defaults = dict(amsgrad=amsgrad,
-                        buffer=[[None, None, None] for _ in range(10)])
+                 weight_decay: WeightDecay = WeightDecay(), defaults=None):
+        defaults = {} if defaults is None else defaults
         defaults.update(weight_decay.defaults())
         super().__init__(params, defaults, lr, betas, eps)
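Editor's note: the new `defaults=None` parameter gives subclasses a way to thread extra hyper-parameters into the parameter groups without re-declaring the whole constructor. A minimal sketch of that pattern, with an illustrative AMSGrad-style subclass (the class name and `amsgrad` flag are assumptions, not part of this diff):

class AMSGrad(Adam):
    # Illustrative subclass: forwards an extra `amsgrad` flag through the
    # `defaults` dictionary introduced by this commit.
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 amsgrad=True, defaults=None):
        defaults = {} if defaults is None else defaults
        defaults.update(dict(amsgrad=amsgrad))
        super().__init__(params, lr=lr, betas=betas, eps=eps,
                         weight_decay=weight_decay, defaults=defaults)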
@@ -25,31 +23,37 @@ class Adam(GenericAdaptiveOptimizer):
         # Exponential moving average of squared gradient values
         state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

-        if group['amsgrad']:
-            # Maintains max of all exp. moving avg. of sq. grad. values
-            state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
-
-    def calculate(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
-        self.weight_decay(param, group)
-
+    def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
         beta1, beta2 = group['betas']

         # get current state variable
         m, v = state['exp_avg'], state['exp_avg_sq']

-        state['step'] += 1
-        bias_correction1 = 1 - beta1 ** state['step']
-        bias_correction2 = 1 - beta2 ** state['step']
-
         # Update first and second moment running average
         m.mul_(beta1).add_(grad, alpha=1 - beta1)
         v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

-        if group['amsgrad']:
-            v_max = state['max_exp_avg_sq']
-            torch.maximum(v_max, v, out=v_max)
-            denominator = (v_max.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
-        else:
-            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
-
-        param.data.addcdiv_(m, denominator, value=-group['lr'] / bias_correction1)
+        return m, v
+
+    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
+        return group['lr']
+
+    def adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
+                    m: torch.Tensor, v: torch.Tensor):
+        beta1, beta2 = group['betas']
+        bias_correction1 = 1 - beta1 ** state['step']
+        bias_correction2 = 1 - beta2 ** state['step']
+
+        denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+        step_size = self.get_lr(state, group) / bias_correction1
+        param.data.addcdiv_(m, denominator, value=-step_size)
+
+    def calculate(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
+        self.weight_decay(param, group)
+
+        m, v = self.get_mv(state, group, grad)
+
+        state['step'] += 1
+
+        self.adam_update(state, group, param, m, v)
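Editor's note: `get_mv` together with `adam_update` implements the standard Adam step. With g_t the gradient, m_t and v_t the two moment estimates, and \alpha the learning rate returned by `get_lr`, the code above computes:

    m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
    v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
    \hat{m}_t = m_t / (1 - \beta_1^t), \quad \hat{v}_t = v_t / (1 - \beta_2^t)
    \theta_t = \theta_{t-1} - \alpha \, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)

The bias corrections are folded into `denominator` and `step_size`, which is algebraically the same update.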
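Editor's note: the `amsgrad` branch removed here fits naturally behind the new hooks: a subclass can override `get_mv` to keep the running element-wise maximum of v and leave `adam_update` untouched. A sketch of that idea only; the subclass below is illustrative and not code from this commit:

class AMSGradVariant(Adam):
    # Illustrative only: reproduces the removed max_exp_avg_sq logic on top of
    # the new get_mv/adam_update hooks.
    def get_mv(self, state, group, grad):
        m, v = super().get_mv(state, group, grad)
        if 'max_exp_avg_sq' not in state:
            # Maintains max of all exp. moving avg. of sq. grad. values
            state['max_exp_avg_sq'] = torch.zeros_like(v, memory_format=torch.preserve_format)
        v_max = state['max_exp_avg_sq']
        torch.maximum(v_max, v, out=v_max)
        # Use the running maximum instead of the current second moment
        return m, v_max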
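Editor's note: assuming `GenericAdaptiveOptimizer` follows the standard `torch.optim.Optimizer` interface (not shown in this diff), the refactored class drops into an ordinary training loop. A small usage sketch with a toy model; the import path is assumed:

import torch
from torch import nn

from labml_nn.optimizers.adam import Adam  # module path assumed

model = nn.Linear(4, 1)
optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-16)

x, y = torch.randn(32, 4), torch.randn(32, 1)
for _ in range(100):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()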