mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-08 00:29:29 +08:00
sophia wip
124 labml_nn/optimizers/sophia.py Normal file
@@ -0,0 +1,124 @@
"""
---
title: Sophia Optimizer
summary: A simple PyTorch implementation/tutorial of the Sophia optimizer
---

# Sophia Optimizer

This is a [PyTorch](https://pytorch.org) implementation of *Sophia-G* from the paper
[Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training](https://papers.labml.ai/paper/2305.14342).
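
The update implemented below keeps an exponential moving average $m_t$ of the gradients
and an exponential moving average $h_t$ of a (batch-size-scaled) estimate of the Hessian
diagonal, and takes a step in the direction of the sign of the momentum, with a
per-coordinate magnitude of at most $\alpha$:

\begin{align}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
h_t &\leftarrow \beta_2 h_{t-1} + (1 - \beta_2) \cdot B \cdot \hat{g}_t \odot \hat{g}_t \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \mathrm{sign}(m_t) \odot
    \min \Bigg( \frac{|m_t|}{\rho \cdot h_t + B \cdot \epsilon}, 1 \Bigg)
\end{align}

where $g_t$ is the training gradient, $\hat{g}_t$ is the gradient from a separate
Hessian-estimation pass (see `update_hessian`), and $B$ is the number of tokens per
training batch.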
"""

from typing import Dict, Any, Tuple, Optional

import torch
from torch import nn

from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay


class Sophia(GenericAdaptiveOptimizer):
    """
    ## Sophia-G Optimizer

    We extend the class `GenericAdaptiveOptimizer` defined in [`__init__.py`](index.html)
    to implement the Sophia optimizer.
    """

    def __init__(self, params,
                 lr: float = 1e-4, betas: Tuple[float, float] = (0.965, 0.99), eps: float = 1e-16,
                 rho: float = 0.04,
                 training_batch_tokens: Optional[int] = None,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
        """
        ### Initialize the optimizer

        * `params` is the list of parameters
        * `lr` is the learning rate $\alpha$
        * `betas` is a tuple of ($\beta_1$, $\beta_2$)
        * `eps` is $\epsilon$
        * `rho` is $\rho$
        * `training_batch_tokens` is the number of tokens per training batch
        * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
        * `optimized_update` is a flag indicating whether to optimize the bias correction of the
          second moment by doing it after adding $\epsilon$
        * `defaults` is a dictionary of defaults for group values.
          This is useful when you want to extend the class `Sophia`.
        """
        if training_batch_tokens is None:
            raise RuntimeError('Please set the number of tokens per training batch.')

        defaults = {} if defaults is None else defaults
        defaults.update(weight_decay.defaults())
        defaults.update(dict(rho=rho, training_batch_tokens=training_batch_tokens))
        super().__init__(params, defaults, lr, betas, eps)

        self.weight_decay = weight_decay
        self.optimized_update = optimized_update

    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
        """
        ### Initialize a parameter state

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # This is the number of optimizer steps taken on the parameter, $t$
        state['step'] = 0
        # state['hessian_updates']
        # Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        # Exponential moving average of the Hessian diagonal estimate, $h_t$
        state['hessian'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    def update_hessian(self, batch_size):
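        """
        ### Update the Hessian estimate

        * `batch_size` is the batch size $B$ used to compute the gradients currently stored
          in `p.grad`

        This maintains an exponential moving average of a Hessian diagonal estimate,
        $$h_t \leftarrow \beta_2 h_{t-1} + (1 - \beta_2) \cdot B \cdot \hat{g}_t \odot \hat{g}_t$$
        where $\hat{g}_t$ is the stored gradient.

        Following the Gauss-Newton-Bartlett estimator from the paper, these gradients are
        expected to come from a backward pass on labels sampled from the model's own output
        distribution, run once every few optimizer steps (see the usage sketch at the end of
        this file).
        """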
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]

                # Initialize the state if this parameter has not been seen before
                if len(state) == 0:
                    self.init_state(state, group, p)

                state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=(1 - beta2) * batch_size)

    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
        """
        ### Take an update step for a given parameter tensor

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # Calculate weight decay
        grad = self.weight_decay(param, grad, group)

        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']

        # Get $\rho$
        rho = group['rho']

        # Get $m_{t-1}$ and $h_{t-1}$
        m, hessian = state['exp_avg'], state['hessian']

        # In-place calculation of $m_t$
        # $$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$
        m.mul_(beta1).add_(grad, alpha=1 - beta1)

        # Increment $t$, the number of optimizer steps
        state['step'] += 1

        # Get the learning rate
        lr = group['lr']
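
        # Per-coordinate Sophia-G update: step in the direction of the sign of the
        # momentum $m_t$, with the magnitude clipped using the Hessian estimate;
        # the ratio $\frac{|m_t|}{\rho \cdot h_t + B \cdot \epsilon}$ is clamped at $1$,
        # so no coordinate moves by more than the learning rate $\alpha$.
        # Here $B$ is `training_batch_tokens`.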
        ratio = (m.abs() / (rho * hessian + group['training_batch_tokens'] * group['eps'])).clamp(None, 1)

        param.data.addcmul_(m.sign(), ratio, value=-lr)
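

def _usage_sketch():
    """
    ### Usage sketch

    A minimal sketch of how Sophia-G is typically driven, assuming a language-model-style
    setup: take a regular optimizer step every iteration, and every few steps re-estimate
    the Hessian diagonal with the Gauss-Newton-Bartlett estimator by back-propagating the
    loss on labels sampled from the model's own output distribution and then calling
    `update_hessian`.

    The toy model, the random data, and `hessian_interval = 10` are illustrative
    placeholders only.
    """
    import torch.nn.functional as F

    n_vocab, d_model, seq_len, batch_size = 128, 32, 16, 4
    # A stand-in "language model": a token embedding followed by a linear readout
    model = nn.Sequential(nn.Embedding(n_vocab, d_model), nn.Linear(d_model, n_vocab))
    # Sophia needs the number of tokens per training batch
    optimizer = Sophia(model.parameters(), lr=1e-4,
                       training_batch_tokens=batch_size * seq_len)

    hessian_interval = 10
    for step in range(100):
        # Random tokens and targets stand in for real data
        x = torch.randint(n_vocab, (batch_size, seq_len))
        y = torch.randint(n_vocab, (batch_size, seq_len))

        # Regular training step
        logits = model(x)
        loss = F.cross_entropy(logits.view(-1, n_vocab), y.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Periodically re-estimate the Hessian from labels sampled from the model itself
        if step % hessian_interval == 0:
            logits = model(x)
            y_sampled = torch.distributions.Categorical(logits=logits).sample()
            loss_sampled = F.cross_entropy(logits.view(-1, n_vocab), y_sampled.view(-1))
            optimizer.zero_grad()
            loss_sampled.backward()
            # The batch size passed here is taken to be the number of tokens in the
            # batch, matching the scale of `training_batch_tokens`
            optimizer.update_hessian(batch_size * seq_len)
            optimizer.zero_grad()


if __name__ == '__main__':
    _usage_sketch()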