"""
---
title: Sophia Optimizer
summary: A simple PyTorch implementation/tutorial of Sophia optimizer
---
# Sophia Optimizer
This is a [PyTorch](https://pytorch.org) implementation of *Sophia-G* from the paper
[Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training](https://arxiv.org/abs/2305.14342).
The official implementation is available at [Liuhong99/Sophia](https://github.com/Liuhong99/Sophia).

Sophia is more adaptive to heterogeneous curvatures than Adam, more resistant
to non-convexity and rapid changes in the Hessian than Newton's method, and uses a low-cost
pre-conditioner.

Sophia maintains an exponential moving average (EMA) $h_t$ of diagonal Hessian estimates across iterations.
A fresh estimate $\hat{h}_t$ is computed every $k$ steps.
\begin{align}
h_t = \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t \ \ \ \ \text{ if } t \text{ mod } k = 1; \text{ else } h_t = h_{t-1}
\end{align}
Sophia uses an EMA of gradients $m_t$, considers only the positive entries of
the diagonal Hessian, and applies per-coordinate clipping to the update.
\begin{align}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1)g_t \\
\theta_{t + 1} &\leftarrow \theta_t - \eta \cdot \operatorname{clip} \bigg(\frac{m_t}{ \max \{h_t, \epsilon \} }, \rho \bigg)
\end{align}
where $\epsilon$ is a very small value to prevent division by $0$.
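Since the clipped term is bounded by $\rho$ in every coordinate, a single step moves each
parameter by at most
\begin{align}
\vert \theta_{t + 1} - \theta_t \vert \le \eta \rho
\end{align}
which is why the implementation below takes the product $\eta \rho$ as its learning-rate argument.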
### Gauss-Newton-Bartlett (GNB) estimator
\begin{align}
\hat{L}(\theta) &= \frac{1}{B} \sum^{B}_{b=1} \ell_{CE} \big( f(\theta, x_b), \hat{y}_b \big) \\
\hat{h}_t &= B \cdot \nabla_\theta \hat{L} (\theta) \odot \nabla_\theta \hat{L} (\theta)
\end{align}
where $x_b$ are the inputs,
$B$ is the batch size (number of inputs/tokens),
$\ell_{CE}$ is the cross-entropy loss, and
$\hat{y}_b$ are sampled from the logits $f(\theta, x_b)$.
Note that this Hessian estimate is always non-negative, so we
can replace $\max \{h_t, \epsilon \}$ with $h_t + \epsilon$.

Sophia with the Gauss-Newton-Bartlett (GNB) estimator is **Sophia-G**.
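Below is a minimal sketch of one GNB update. The names `model`, `x`, and
`gnb_hessian_update` are illustrative assumptions, not part of this file:

```python
import torch
import torch.nn.functional as F

def gnb_hessian_update(model, x, optimizer):
    # Hypothetical helper: refresh the Hessian EMA from one batch of inputs `x`,
    # assuming `model` returns logits of shape `[batch, seq, n_vocab]`
    logits = model(x)
    # Sample $\hat{y}_b$ from the model's own predictive distribution
    y_hat = torch.distributions.Categorical(logits=logits).sample()
    # Mini-batch cross-entropy loss $\hat{L}(\theta)$ on the sampled labels
    loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), y_hat.view(-1))
    # Back-propagate so that `p.grad` holds $\nabla_\theta \hat{L}(\theta)$
    optimizer.zero_grad()
    loss.backward()
    # $B$ is the number of sampled tokens in the batch
    optimizer.update_hessian(y_hat.numel())
```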
Here is an [experiment](../transformers/basic/with_sophia.html) that uses Sophia-G to train a transformer.
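A hedged sketch of the training-loop cadence follows; `data_loader`, `n_vocab`, and the
Hessian update interval `k` are illustrative choices, not part of this file:

```python
optimizer = Sophia(model.parameters(), lr=1e-4)
for t, (x, y) in enumerate(data_loader):
    # Standard training step on the real labels
    loss = F.cross_entropy(model(x).view(-1, n_vocab), y.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Refresh the Hessian EMA every `k` steps with the GNB helper sketched above
    if t % k == k - 1:
        gnb_hessian_update(model, x, optimizer)
```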
"""
from typing import Dict, Any, Tuple, Optional
import torch
from torch import nn
from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay
class Sophia(GenericAdaptiveOptimizer):
"""
## Sophia-G Optimizer
We extend the class `GenericAdaptiveOptimizer` defined in [`__init__.py`](index.html)
to implement the Sophia optimizer.
"""
def __init__(self, params,
lr: float = 1e-4, betas: Tuple[float, float] = (0.9, 0.95), eps: float = 1e-12,
rho: float = 0.03,
weight_decay: WeightDecay = WeightDecay(),
defaults: Optional[Dict[str, Any]] = None):
"""
### Initialize the optimizer
* `params` is the list of parameters
* `lr` is the maximum learning rate $\eta \rho$
* `betas` is a tuple of ($\beta_1$, $\beta_2$)
* `eps` is $\epsilon$
        * `rho` is $\rho$, the per-coordinate clipping threshold
* `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
        * `defaults` is a dictionary of defaults for group values.
         This is useful when you want to extend the class `Sophia`.
"""
defaults = {} if defaults is None else defaults
defaults.update(weight_decay.defaults())
defaults.update(dict(rho=rho))
super().__init__(params, defaults, lr, betas, eps)
self.weight_decay = weight_decay
    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
"""
### Initialize a parameter state
* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `param` is the parameter tensor $\theta_{t-1}$
"""
# This is the number of optimizer steps taken on the parameter, $t$
state['step'] = 0
# Exponential moving average of gradients, $m_t$
state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
# Exponential moving average of Hessian diagonal, $h_t$
state['hessian'] = torch.zeros_like(param, memory_format=torch.preserve_format)
    def update_hessian(self, n_tokens_training_batch: int):
"""
### Update the EMA of Hessian diagonal $h_t$
* `n_tokens_training_batch` is the number of tokens/inputs in the batch $B$
\begin{align}
\hat{h}_t &= B \cdot \nabla_\theta \hat{L} (\theta) \odot \nabla_\theta \hat{L} (\theta) \\
h_t &= \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t
\end{align}
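        This method only updates the EMA; the training loop is responsible for calling it
        once every $k$ steps (see the sketch in the module docstring).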
"""
# Iterate through parameter groups
for group in self.param_groups:
# $\beta_2$
_, beta2 = group['betas']
# Iterate through parameters
for p in group['params']:
# Skip parameters without gradients
if p.grad is None:
continue
# Get optimizer state
state = self.state[p]
# Initialize state if empty
if len(state) == 0:
self.init_state(state, group, p)
# Update EMA Hessian diagonal
#
# \begin{align}
# \hat{h}_t &= B \cdot \nabla_\theta \hat{L} (\theta) \odot \nabla_\theta \hat{L} (\theta) \\
# h_t &= \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t
# \end{align}
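                # Note: `p.grad` should hold $\nabla_\theta \hat{L}(\theta)$ from the
                # sampled-label (GNB) loss, back-propagated just before this call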
state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=(1 - beta2) * n_tokens_training_batch)
    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
"""
### Take an update step for a given parameter tensor
* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
* `param` is the parameter tensor $\theta_{t-1}$
We do the following parameter update,
\begin{align}
\theta_{t + 1} &\leftarrow \theta_t - \eta \cdot \operatorname{clip} \bigg(\frac{m_t}{h_t + \epsilon}, \rho \bigg)
\end{align}
"""
# Calculate weight decay
grad = self.weight_decay(param, grad, group)
        # Get $\beta_1$; $\beta_2$ is used only in `update_hessian`
        beta1, _ = group['betas']
# Get $\rho$
rho = group['rho']
# Get $m_{t-1}$ and $h_{t}$
m, hessian = state['exp_avg'], state['hessian']
# In-place calculation of $m_t$
# $$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$
m.mul_(beta1).add_(grad, alpha=1 - beta1)
# Increment $t$ the number of optimizer steps
state['step'] += 1
# Get maximum learning rate $\eta \rho$
lr = group['lr']
# $\eta$
eta = lr / rho
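        # Since the clipped ratio below is bounded by $\rho$, each coordinate
        # moves by at most $\eta \rho =$ `lr` per step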
# $$\operatorname{clip} \bigg(\frac{m_t}{h_t + \epsilon}, \rho \bigg)$$
ratio = (m / (hessian + group['eps'])).clamp(-rho, rho)
# $$\theta_{t + 1} \leftarrow \theta_t - \eta \cdot \operatorname{clip} \bigg(\frac{m_t}{h_t + \epsilon}, \rho \bigg)$$
param.data.add_(ratio, alpha=-eta)