"""
|
|
---
|
|
title: Adam Optimizer for Half Precision Training
|
|
summary: A simple PyTorch implementation/tutorial of Adam optimizer
|
|
---
|
|
|
|
# Adam Optimizer for Half Precision Training
|
|
"""
|
|
|
|
from typing import Dict, Tuple, Optional, Any
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.optim import Optimizer
|
|
from torch.cuda.amp import grad_scaler
|
|
from collections import defaultdict, abc
|
|
|
|
from labml_nn.optimizers import WeightDecay
|
|
from labml_nn.optimizers.adam import Adam
|
|
|
|
|
|
class AdamFP16(Adam):
    """
    ## Adam Optimizer for Half Precision Training

    We extend the [Adam Optimizer](adam.html) but use FP32 to store the gradients and moments.
    A short usage sketch follows this class.
    """

    def __init__(self, params, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: WeightDecay = WeightDecay(), optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
        # Dictionary to store the 32-bit gradients. It gets populated by the `GradScaler` defined below.
        self.grad_fp32 = {}
        # Call the [Adam Optimizer](adam.html) initializer
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        """
        ### Initialize a parameter state

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$

        All the state tensors use FP32.
        """

        # This is the number of optimizer steps taken on the parameter, $t$
        state['step'] = 0
        # Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)
        # Exponential moving average of squared gradient values, $v_t$
        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)
        # Maintain an FP32 copy of the parameters
        state['fp32_copy'] = param.to(torch.float)

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
        """
        ### Take an update step for a given parameter tensor

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # Get the FP32 copy of the parameters
        param_fp32 = state['fp32_copy']
        # Get the FP32 gradients if available
        grad_fp32 = self.grad_fp32.get(param, None)
        if grad_fp32 is not None:
            del self.grad_fp32[param]
            grad = grad_fp32
        else:
            # Otherwise, convert the gradients to FP32
            grad = grad.to(torch.float)

        # Calculate weight decay
        grad = self.weight_decay(param_fp32, grad, group)

        # Get $m_t$ and $v_t$
        m, v = self.get_mv(state, group, grad)

        # Increment $t$, the number of optimizer steps
        state['step'] += 1

        # Perform the *Adam* update on the FP32 copy
        self.adam_update(state, group, param_fp32, m, v)

        # Set the parameters by casting the updated FP32 copy back to the parameter's dtype (e.g. FP16)
        param.data = param_fp32.to(param.dtype)


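# The function below is a small, illustrative construction sketch; it is not part of the original
# implementation. It assumes a model whose weights are cast to FP16 before training; `AdamFP16`
# keeps its own FP32 copies of the parameters, gradients, and moments.
def _adam_fp16_construction_sketch(model: nn.Module) -> AdamFP16:
    # Keep the model weights in half precision
    model = model.half()
    # Construct the optimizer with the usual Adam hyper-parameters (shown explicitly for clarity)
    return AdamFP16(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

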
class GradScalerFP16(grad_scaler.GradScaler):
    """
    ## Gradient Scaler with half precision gradients

    We extend the PyTorch gradient scaler to convert the gradients to FP32 before unscaling them,
    and to pass these FP32 gradients on to the `AdamFP16` optimizer.
    A training-loop sketch that uses this scaler together with `AdamFP16` is included at the end of this file.
    """

    def _unscale_grads_(self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor,
                        allow_fp16: bool) -> Dict[torch.device, torch.Tensor]:
        per_device_inv_scale = grad_scaler._MultiDeviceReplicator(inv_scale)
        per_device_found_inf = grad_scaler._MultiDeviceReplicator(found_inf)

        # Gradients are grouped by device and dtype for the fused unscale kernel
        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]

        with torch.no_grad():
            # Loop through the parameters
            for group in optimizer.param_groups:
                for param in group["params"]:
                    # Skip non-trainable parameters
                    if param.grad is None:
                        continue
                    # Not implemented for sparse tensors
                    if param.grad.is_sparse:
                        raise NotImplementedError

                    # If we are using the `AdamFP16` optimizer, set `optimizer.grad_fp32[param]` to the FP32 gradients
                    if isinstance(optimizer, AdamFP16):
                        grad = param.grad.to(torch.float)
                        optimizer.grad_fp32[param] = grad
                    # Otherwise, do not convert the gradients to FP32
                    else:
                        grad = param.grad

                    per_device_and_dtype_grads[grad.device][grad.dtype].append(grad)

            # Unscale all the gradients
            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    torch._amp_foreach_non_finite_check_and_unscale_(grads,
                                                                     per_device_found_inf.get(device),
                                                                     per_device_inv_scale.get(device))

        # Return the per-device tensors that indicate whether any `inf`/`NaN` gradients were found
        return per_device_found_inf._per_device_tensors
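

# Below is a hedged end-to-end usage sketch; it is not part of the original implementation. It assumes a
# half-precision `model`, a `data_loader` that yields `(data, target)` batches already on the right device,
# and a `loss_fn`; all three are hypothetical stand-ins. `scaler.step(optimizer)` calls the overridden
# `_unscale_grads_`, which populates `optimizer.grad_fp32` with unscaled FP32 gradients before the
# optimizer step.
def _train_epoch_sketch(model: nn.Module, data_loader, loss_fn):
    # Optimizer that keeps FP32 copies of the parameters, gradients, and moments
    optimizer = AdamFP16(model.parameters(), lr=1e-3)
    # Gradient scaler that hands FP32 gradients to `AdamFP16`
    scaler = GradScalerFP16()

    for data, target in data_loader:
        # Reset the gradients
        optimizer.zero_grad()
        # Forward pass and loss in half precision
        loss = loss_fn(model(data), target)
        # Scale the loss and back-propagate to get scaled FP16 gradients
        scaler.scale(loss).backward()
        # Unscale the gradients and take an optimizer step; the step is skipped if any gradient is `inf`/`NaN`
        scaler.step(optimizer)
        # Update the loss scale for the next iteration
        scaler.update()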