| """
 | |
| ---
 | |
| title: Adam Optimizer for Half Precision Training
 | |
| summary: A simple PyTorch implementation/tutorial of Adam optimizer
 | |
| ---
 | |
| 
 | |
| # Adam Optimizer for Half Precision Training
 | |
| """
 | |
| 
 | |
| from typing import Dict, Tuple, Optional, Any
 | |
| 
 | |
| import torch
 | |
| from torch import nn
 | |
| from torch.optim import Optimizer
 | |
| from torch.cuda.amp import grad_scaler
 | |
| from collections import defaultdict, abc
 | |
| 
 | |
| from labml_nn.optimizers import WeightDecay
 | |
| from labml_nn.optimizers.adam import Adam
 | |
| 
 | |
| 
 | |


class AdamFP16(Adam):
    """
    ## Adam Optimizer for Half Precision Training

    We extend the [Adam Optimizer](adam.html) but use FP32 to store the gradients and moments.
    """

    def __init__(self, params, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: WeightDecay = WeightDecay(), optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
        # Dictionary to store the 32-bit gradients. This gets populated by the `GradScaler` defined below.
        self.grad_fp32 = {}
        # Call the [Adam Optimizer](adam.html) initializer
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        """
        ### Initialize a parameter state

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$

        All the state tensors use FP32.
        """

        # This is the number of optimizer steps taken on the parameter, $t$
        state['step'] = 0
        # Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)
        # Exponential moving average of squared gradient values, $v_t$
        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)
        # Maintain an FP32 copy of the parameters
        state['fp32_copy'] = param.to(torch.float)

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
        """
        ### Take an update step for a given parameter tensor

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # Get the FP32 copy of the parameters
        param_fp32 = state['fp32_copy']
        # Get the FP32 gradients if available
        grad_fp32 = self.grad_fp32.get(param, None)
        if grad_fp32 is not None:
            del self.grad_fp32[param]
            grad = grad_fp32
        else:
            # Otherwise, convert the gradients to FP32
            grad = grad.to(torch.float)

        # Calculate weight decay
        grad = self.weight_decay(param_fp32, grad, group)

        # Get $m_t$ and $v_t$
        m, v = self.get_mv(state, group, grad)

        # Increment $t$, the number of optimizer steps
        state['step'] += 1

        # Perform the *Adam* update
        self.adam_update(state, group, param_fp32, m, v)

        # Set the parameters by casting the FP32 copy back to the parameter's dtype
        param.data = param_fp32.to(param.dtype)
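
# A rough illustration (not part of the original implementation) of why the FP32
# copies above matter: near $1$ the gap between adjacent representable FP16 values
# is $2^{-10} \approx 10^{-3}$, so a smaller update is rounded away entirely in FP16
# but survives in FP32. For example,
#
#     torch.tensor(1., dtype=torch.float16) + 1e-4  # still 1.0 in FP16
#     torch.tensor(1., dtype=torch.float32) + 1e-4  # 1.0001 in FP32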


class GradScalerFP16(grad_scaler.GradScaler):
    """
    ## Gradient Scaler with half precision gradients

    We extend the PyTorch gradient scaler to unscale the gradients in FP32 and pass
    them to the `AdamFP16` optimizer.
    """

    def _unscale_grads_(self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor,
                        allow_fp16: bool) -> Dict[torch.device, torch.Tensor]:
        per_device_inv_scale = grad_scaler._MultiDeviceReplicator(inv_scale)
        per_device_found_inf = grad_scaler._MultiDeviceReplicator(found_inf)

        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]

        with torch.no_grad():
            # Loop through the parameters
            for group in optimizer.param_groups:
                for param in group["params"]:
                    # Skip parameters that have no gradients
                    if param.grad is None:
                        continue
                    # Not implemented for sparse tensors
                    if param.grad.is_sparse:
                        raise NotImplementedError

                    # If we are using the `AdamFP16` optimizer, set `optimizer.grad_fp32[param]` to the FP32 gradients
                    if isinstance(optimizer, AdamFP16):
                        grad = param.grad.to(torch.float)
                        optimizer.grad_fp32[param] = grad
                    # Otherwise, keep the gradients in their original dtype
                    else:
                        grad = param.grad

                    per_device_and_dtype_grads[grad.device][grad.dtype].append(grad)

            # Unscale all the gradients
            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    torch._amp_foreach_non_finite_check_and_unscale_(grads,
                                                                     per_device_found_inf.get(device),
                                                                     per_device_inv_scale.get(device))
        #
        return per_device_found_inf._per_device_tensors
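

# A minimal sketch (not part of the original implementation) of how `AdamFP16` and
# `GradScalerFP16` fit together in a half precision training loop; `model`,
# `data_loader`, and `loss_func` are hypothetical placeholders assumed to already
# be on the GPU, and `lr=3e-4` is an arbitrary example value.
def _fp16_training_loop_sketch(model: nn.Module, data_loader, loss_func):
    # Keep the model weights in FP16; `AdamFP16` maintains FP32 copies internally
    model = model.half()
    # Optimizer with FP32 moments and FP32 parameter copies
    optimizer = AdamFP16(model.parameters(), lr=3e-4)
    # Scaler that hands unscaled FP32 gradients to `AdamFP16`
    scaler = GradScalerFP16()

    for data, target in data_loader:
        optimizer.zero_grad()
        # Forward pass with autocasting
        with torch.cuda.amp.autocast():
            loss = loss_func(model(data), target)
        # Back-propagate the scaled loss to get scaled FP16 gradients
        scaler.scale(loss).backward()
        # Unscale the gradients (this populates `optimizer.grad_fp32`) and take an optimizer step
        scaler.step(optimizer)
        # Update the scale factor for the next iteration
        scaler.update()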