10from typing import Dict, Tuple, Optional, Any
11
12import torch
13from torch import nn
14from torch.optim import Optimizer
15from torch.cuda.amp import grad_scaler
16from collections import defaultdict, abc
17
18from labml_nn.optimizers import WeightDecay
19from labml_nn.optimizers.adam import Adam
22class AdamFP16(Adam):
29 def __init__(self, params, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
30 weight_decay: WeightDecay = WeightDecay(), optimized_update: bool = True,
31 defaults: Optional[Dict[str, Any]] = None):
32 ビットのグラデーションを格納するパラメーター。GradScaler
これには以下の定義が入力されます
33 self.grad_fp32 = {}
35 super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)
state
はパラメーター (テンソル) のオプティマイザー状態ですgroup
パラメータグループのオプティマイザ属性を格納しますparam
はパラメータテンソル すべてのステートテンソルは FP32 を使用します。
37 def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
これは、パラメーターに対して実行されたオプティマイザーステップの数です。
49 state['step'] = 0
勾配の指数移動平均、
51 state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)
二乗勾配値の指数移動平均、
53 state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)
パラメータの FP32 コピーを管理
55 state['fp32_copy'] = param.to(torch.float)
state
はパラメーター (テンソル) のオプティマイザー状態ですgroup
パラメータグループのオプティマイザ属性を格納しますgrad
パラメータの現在の勾配テンソルです param
はパラメータテンソル 57 def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
FP32 パラメータを取得
68 param_fp32 = state['fp32_copy']
可能な場合は FP32 のグラデーションを取得
70 grad_fp32 = self.grad_fp32.get(param, None)
71 if grad_fp32 is not None:
72 del self.grad_fp32[param]
73 grad = grad_fp32
74 else:
それ以外の場合は、グラデーションを FP32 に変換します。
76 grad = grad.to(torch.float)
体重減少の計算
79 grad = self.weight_decay(param_fp32, grad, group)
取得して
82 m, v = self.get_mv(state, group, grad)
オプティマイザーのステップ数を増やす
85 state['step'] += 1
Adam アップデートを実行
88 self.adam_update(state, group, param_fp32, m, v)
パラメータを設定
91 param.data = param_fp32.to(param.dtype)
94class GradScalerFP16(grad_scaler.GradScaler):
101 def _unscale_grads_(self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor,
102 allow_fp16: bool) -> Dict[torch.device, torch.Tensor]:
103 per_device_inv_scale = grad_scaler._MultiDeviceReplicator(inv_scale)
104 per_device_found_inf = grad_scaler._MultiDeviceReplicator(found_inf)
105
106 per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
107
108 with torch.no_grad():
ループスルーパラメータ
110 for group in optimizer.param_groups:
111 for param in group["params"]:
トレーニング不可のパラメータをスキップ
113 if param.grad is None:
114 continue
スパステンソルには実装されていません
116 if param.grad.is_sparse:
117 raise NotImplementedError
FP32 AdamFP16
optimizer.grad_fp32[param]
グラデーションに設定されたオプティマイザーを使用している場合
120 if isinstance(optimizer, AdamFP16):
121 grad = param.grad.to(torch.float)
122 optimizer.grad_fp32[param] = grad
それ以外の場合は、グラデーションを FP32 に変換しないでください。
124 else:
125 grad = param.grad
126
127 per_device_and_dtype_grads[grad.device][grad.dtype].append(grad)
すべてのグラデーションをスケール解除
130 for device, per_dtype_grads in per_device_and_dtype_grads.items():
131 for grads in per_dtype_grads.values():
132 torch._amp_foreach_non_finite_check_and_unscale_(grads,
133 per_device_found_inf.get(device),
134 per_device_inv_scale.get(device))
136 return per_device_found_inf._per_device_tensors