Files
2022-08-11 15:44:13 +05:30

137 lines
5.4 KiB
Python

"""
---
title: Adam Optimizer for Half Precision Training
summary: A simple PyTorch implementation/tutorial of Adam optimizer for half precision (FP16) training
---
# Adam Optimizer for Half Precision Training
"""
from typing import Dict, Tuple, Optional, Any
import torch
from torch import nn
from torch.optim import Optimizer
from torch.cuda.amp import grad_scaler
from collections import defaultdict, abc
from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.adam import Adam
class AdamFP16(Adam):
    """
    ## Adam Optimizer for Half Precision Training

    We extend [Adam Optimizer](adam.html) but keep an FP32 copy of the parameters
    and use FP32 for the gradients and the moments.
    The update is computed in FP32 and then written back to the parameter in its own dtype.
    """

    def __init__(self, params, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: Optional[WeightDecay] = None, optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
        """
        ### Initialize the optimizer

        All arguments are forwarded unchanged to the [Adam Optimizer](adam.html) initializer.
        If `weight_decay` is `None` a new `WeightDecay()` is created for this optimizer,
        so instances never share one default object.
        """
        # Dictionary mapping each parameter to its 32 bit gradient.
        # This gets populated by `GradScalerFP16` and consumed by `step_param`.
        self.grad_fp32 = {}
        # Avoid a shared mutable default: build a fresh `WeightDecay` per optimizer
        if weight_decay is None:
            weight_decay = WeightDecay()
        # Call the [Adam Optimizer](adam.html) initializer
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        r"""
        ### Initialize a parameter state

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$

        All the state tensors use FP32.
        """
        # This is the number of optimizer steps taken on the parameter, $t$
        state['step'] = 0
        # Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)
        # Exponential moving average of squared gradient values, $v_t$
        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)
        # Maintain a FP32 copy of the parameters.
        # `detach()` keeps the copy out of the autograd graph even if this runs with grad enabled.
        # For a parameter that is already FP32, `.to()` does not copy, so the "copy" shares its storage.
        state['fp32_copy'] = param.detach().to(torch.float)

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
        r"""
        ### Take an update step for a given parameter tensor

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        * `param` is the parameter tensor $\theta_{t-1}$

        The update is applied to the FP32 copy, which is then cast back into `param`.
        """
        # Get the FP32 parameters
        param_fp32 = state['fp32_copy']
        # Take (and remove) the FP32 gradient stored by `GradScalerFP16`, if there is one;
        # otherwise convert the given gradient to FP32
        grad_fp32 = self.grad_fp32.pop(param, None)
        grad = grad_fp32 if grad_fp32 is not None else grad.to(torch.float)
        # Calculate weight decay
        grad = self.weight_decay(param_fp32, grad, group)
        # Get $m_t$ and $v_t$
        m, v = self.get_mv(state, group, grad)
        # Increment $t$ the number of optimizer steps
        state['step'] += 1
        # Perform *Adam* update on the FP32 copy
        self.adam_update(state, group, param_fp32, m, v)
        # Write the updated values back into the parameter in its original dtype
        param.data = param_fp32.to(param.dtype)
class GradScalerFP16(grad_scaler.GradScaler):
    """
    ## Gradient Scaler with half precision gradients

    We extend the PyTorch gradient scaler so that, when the optimizer is `AdamFP16`,
    gradients are converted to FP32 before being unscaled. The FP32 gradients are
    handed to the optimizer through `optimizer.grad_fp32`.
    """

    def _unscale_grads_(self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor,
                        allow_fp16: bool) -> Dict[torch.device, torch.Tensor]:
        """
        Unscale the gradients of `optimizer` in place and check them for inf/nan.

        * `optimizer` is the optimizer whose parameter gradients are unscaled
        * `inv_scale` is the inverse of the current loss scale
        * `found_inf` is the tensor that records whether a non-finite value was found
        * `allow_fp16` mirrors the parent's flag. It applies only to optimizers other than
          `AdamFP16`, because `AdamFP16` gradients are always converted to FP32 first.

        Returns the per-device `found_inf` tensors.

        Raises `NotImplementedError` for sparse gradients. Raises `ValueError` for FP16
        gradients of a non-`AdamFP16` optimizer when `allow_fp16` is false.
        """
        per_device_inv_scale = grad_scaler._MultiDeviceReplicator(inv_scale)
        per_device_found_inf = grad_scaler._MultiDeviceReplicator(found_inf)

        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]

        use_fp32_grads = isinstance(optimizer, AdamFP16)
        if use_fp32_grads:
            # Drop FP32 gradients left from an earlier call whose optimizer step was skipped
            # (e.g. because inf/nan was found), so a stale gradient is never consumed later
            optimizer.grad_fp32.clear()

        with torch.no_grad():
            # Loop through parameters
            for group in optimizer.param_groups:
                for param in group["params"]:
                    # Skip parameters without gradients
                    if param.grad is None:
                        continue
                    # Not implemented for sparse tensors
                    if param.grad.is_sparse:
                        raise NotImplementedError("GradScalerFP16 does not support sparse gradients")
                    if use_fp32_grads:
                        # For `AdamFP16`, unscale an FP32 copy and hand it to the optimizer
                        grad = param.grad.to(torch.float)
                        optimizer.grad_fp32[param] = grad
                    else:
                        # Otherwise, unscale the gradients in place in their own dtype.
                        # Same guard as PyTorch's `GradScaler`: unscaling in FP16 can underflow.
                        if (not allow_fp16) and param.grad.dtype == torch.float16:
                            raise ValueError("Attempting to unscale FP16 gradients.")
                        grad = param.grad
                    per_device_and_dtype_grads[grad.device][grad.dtype].append(grad)

            # Unscale all the gradients and record whether any value is non-finite
            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    torch._amp_foreach_non_finite_check_and_unscale_(grads,
                                                                     per_device_found_inf.get(device),
                                                                     per_device_inv_scale.get(device))

        return per_device_found_inf._per_device_tensors