from typing import Dict, Tuple, Optional, Any

import torch
from torch import nn
from torch.optim import Optimizer
from torch.cuda.amp import grad_scaler
from collections import defaultdict, abc

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.adam import Adam

We extend Adam Optimizer but use FP32 to store gradients and moments.
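A quick aside (not part of the implementation) on why half precision is not enough for the optimizer statistics: squaring a typical small gradient already underflows to zero in FP16, so a second moment kept in half precision would silently collapse. Only the torch import above is needed.

g = torch.tensor(1e-4, dtype=torch.float16)
print(g * g)                  # tensor(0., dtype=torch.float16): the squared gradient underflows
print(g.float() * g.float())  # about 1e-08, which FP32 represents without trouble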
class AdamFP16(Adam):
    def __init__(self, params, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: WeightDecay = WeightDecay(), optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):

Dictionary to store 32-bit gradients, keyed by parameter. It gets populated by the GradScaler defined below.

        self.grad_fp32 = {}

Call the Adam Optimizer initializer

        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)
* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* param is the parameter tensor $\theta_{t-1}$

All the state tensors use FP32.

    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):

This is the number of optimizer steps taken on the parameter, $t$

        state['step'] = 0

Exponential moving average of gradients, $m_t$

        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)

Exponential moving average of squared gradient values, $v_t$

        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)

Maintain an FP32 copy of the parameters

        state['fp32_copy'] = param.to(torch.float)
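A small check of the claim above, written as a sketch rather than library code (it assumes labml_nn is installed so that AdamFP16 can be constructed): even for an FP16 parameter, every tensor created by init_state is FP32.

param = nn.Parameter(torch.zeros(4, dtype=torch.float16))
opt = AdamFP16([param])
state = {}
opt.init_state(state, opt.param_groups[0], param)  # populate a fresh state dict
assert state['exp_avg'].dtype == torch.float32
assert state['exp_avg_sq'].dtype == torch.float32
assert state['fp32_copy'].dtype == torch.float32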
* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* grad is the current gradient tensor $g_t$ for the parameter
* param is the parameter tensor $\theta_{t-1}$

    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):

Get the FP32 parameters

        param_fp32 = state['fp32_copy']

Get the FP32 gradients if available

        grad_fp32 = self.grad_fp32.get(param, None)
        if grad_fp32 is not None:
            del self.grad_fp32[param]
            grad = grad_fp32
        else:

Otherwise, convert the gradients to FP32

            grad = grad.to(torch.float)

Calculate weight decay

        grad = self.weight_decay(param_fp32, grad, group)

Get $m_t$ and $v_t$

        m, v = self.get_mv(state, group, grad)

Increment $t$, the number of optimizer steps

        state['step'] += 1

Perform Adam update

        self.adam_update(state, group, param_fp32, m, v)

Set the parameters

        param.data = param_fp32.to(param.dtype)
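The reason the Adam update is applied to the FP32 copy and only cast back at the end: per-step updates are often smaller than the FP16 rounding step, so applying them directly to the FP16 parameter would round them away, while the FP32 master copy accumulates them. A toy illustration, independent of the optimizer:

p16 = torch.full((1,), 1.0, dtype=torch.float16)
p32 = p16.to(torch.float)
for _ in range(100):
    p16 = p16 + 1e-4            # each FP16 update rounds back to 1.0
    p32 = p32 + 1e-4            # the FP32 copy accumulates every update
print(p16)                      # still 1.0
print(p32.to(torch.float16))    # about 1.01 after casting back, as in param.data above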
We extend the PyTorch gradient scaler to use FP32 gradients.

class GradScalerFP16(grad_scaler.GradScaler):
    def _unscale_grads_(self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor,
                        allow_fp16: bool) -> Dict[torch.device, torch.Tensor]:
        per_device_inv_scale = grad_scaler._MultiDeviceReplicator(inv_scale)
        per_device_found_inf = grad_scaler._MultiDeviceReplicator(found_inf)

        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]

        with torch.no_grad():

Loop through parameters

            for group in optimizer.param_groups:
                for param in group["params"]:

Skip non-trainable parameters

                    if param.grad is None:
                        continue

Not implemented for sparse tensors

                    if param.grad.is_sparse:
                        raise NotImplementedError

If we are using the AdamFP16 optimizer, set optimizer.grad_fp32[param] to the FP32 gradients

                    if isinstance(optimizer, AdamFP16):
                        grad = param.grad.to(torch.float)
                        optimizer.grad_fp32[param] = grad

Otherwise, do not convert the gradients to FP32

                    else:
                        grad = param.grad

                    per_device_and_dtype_grads[grad.device][grad.dtype].append(grad)

Unscale all the gradients

            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    torch._amp_foreach_non_finite_check_and_unscale_(grads,
                                                                     per_device_found_inf.get(device),
                                                                     per_device_inv_scale.get(device))

        return per_device_found_inf._per_device_tensors
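Putting the two classes together, here is a usage sketch (not part of the original file; it assumes a CUDA device, since the gradient scaler disables itself on CPU, and a model already cast to half precision). The loop follows the standard torch.cuda.amp recipe, with GradScalerFP16 handing unscaled FP32 gradients to AdamFP16 through grad_fp32.

device = torch.device('cuda')
model = nn.Linear(16, 4).to(device).half()
optimizer = AdamFP16(model.parameters(), lr=1e-3)
scaler = GradScalerFP16()

for _ in range(10):
    x = torch.randn(32, 16, device=device, dtype=torch.float16)
    loss = model(x).float().square().mean()
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    # scaler.step first runs _unscale_grads_ above, which stores unscaled FP32
    # gradients in optimizer.grad_fp32, and then calls optimizer.step()
    scaler.step(optimizer)
    scaler.update()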