අඩක්නිරවද්යතාව පුහුණුව සඳහා ආදම් ප්රශස්තකරණය

10from typing import Dict, Tuple, Optional, Any
11
12import torch
13from torch import nn
14from torch.optim import Optimizer
15from torch.cuda.amp import grad_scaler
16from collections import defaultdict, abc
17
18from labml_nn.optimizers import WeightDecay
19from labml_nn.optimizers.adam import Adam

අඩක්නිරවද්යතාව පුහුණුව සඳහා ආදම් ප්රශස්තකරණය

අපි ඇඩම් ඔප්ටිමයිසර් දීර් extend කරන නමුත් ශ්රේණි සහ මොහොත ගබඩා කිරීම සඳහා FP32 භාවිතා කරමු.

22class AdamFP16(Adam):
29    def __init__(self, params, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
30                 weight_decay: WeightDecay = WeightDecay(), optimized_update: bool = True,
31                 defaults: Optional[Dict[str, Any]] = None):

බිට්අනුක්රමික 32 ක් ගබඩා කිරීමේ පරාමිතිය. පහත GradScaler අර්ථ දක්වා ඇති පරිදි මෙය ජනාකීර්ණ වේ.

33        self.grad_fp32 = {}

ඇඩම් ප්රශස්තිකරණ ආරම්භකය අමතන්න

35        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)

පරාමිතිතත්වයක් ආරම්භ කරන්න

  • state පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor)
  • group පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි
  • param පරාමිතිය tensor වේ
  • සියලුමරාජ්ය ආතතීන් FP32 භාවිතා කරයි.

    37    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):

    පරාමිතියමත ගෙන ඇති ප්රශස්තිකරණ පියවර ගණන මෙයයි,

    49        state['step'] = 0

    අනුක්රමිකක ඝාතීය වෙනස්වන සාමාන්යය,

    51        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)

    වර්ගඵලය අනුක්රමික වටිනාකම් ඝාතීය වෙනස්වන සාමාන්යය,

    53        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)

    පරාමිතීන්ගේFP32 පිටපතක් පවත්වා ගන්න

    55        state['fp32_copy'] = param.to(torch.float)

    දීඇති පරාමිති ටෙන්සරයක් සඳහා යාවත්කාලීන පියවරක් ගන්න

    • state පරාමිතිය ප්රශස්තකරණය රාජ්ය වේ (tensor)
    • group පරාමිති කණ්ඩායමේ ප්රශස්තිකරණ ගුණාංග ගබඩා කරයි
    • grad පරාමිතිය සඳහා වත්මන් ඵලය අනුක්රමික tensor වේ
  • param පරාමිතිය tensor වේ
  • 57    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):

    FP32පරාමිතීන් ලබා ගන්න

    68        param_fp32 = state['fp32_copy']

    ලබාගත හැකි නම් FP32 අනුක්රමික ලබා ගන්න

    70        grad_fp32 = self.grad_fp32.get(param, None)
    71        if grad_fp32 is not None:
    72            del self.grad_fp32[param]
    73            grad = grad_fp32
    74        else:

    එසේනොමැතිනම්, ශ්රේණිය FP32 බවට පරිවර්තනය කරන්න

    76            grad = grad.to(torch.float)

    බරක්ෂය වීම ගණනය කරන්න

    79        grad = self.weight_decay(param_fp32, grad, group)

    ලබා ගන්න

    82        m, v = self.get_mv(state, group, grad)

    ප්රශස්තිකරණපියවර ගණන වැඩි කරන්න

    85        state['step'] += 1

    ආදම් යාවත්කාලීන කිරීම සිදු

    88        self.adam_update(state, group, param_fp32, m, v)

    පරාමිතීන්සකසන්න

    91        param.data = param_fp32.to(param.dtype)

    අර්ධනිරවද්යතා අනුක්රමික සහිත ග්රේඩියන්ට් පරිමාණය

    FP32ශ්රේණියේ භාවිතා කිරීම සඳහා අපි පයිටෝච් ශ්රේණියේ පරිමාණය දිගු කරමු.

    94class GradScalerFP16(grad_scaler.GradScaler):
    101    def _unscale_grads_(self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor,
    102                        allow_fp16: bool) -> Dict[torch.device, torch.Tensor]:
    103        per_device_inv_scale = grad_scaler._MultiDeviceReplicator(inv_scale)
    104        per_device_found_inf = grad_scaler._MultiDeviceReplicator(found_inf)
    105
    106        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]
    107
    108        with torch.no_grad():

    පරාමිතීන්හරහා ලූප්

    110            for group in optimizer.param_groups:
    111                for param in group["params"]:

    පුහුණුකළ නොහැකි පරාමිතීන් මඟ හරින්න

    113                    if param.grad is None:
    114                        continue

    විරලආතතීන් සඳහා ක්රියාත්මක නොවේ

    116                    if param.grad.is_sparse:
    117                        raise NotImplementedError

    අපිFP32 ශ්රේණියේ AdamFP16 ප්රශස්තිකරණ කට්ටලය optimizer.grad_fp32[param] භාවිතා කරන්නේ නම්

    120                    if isinstance(optimizer, AdamFP16):
    121                        grad = param.grad.to(torch.float)
    122                        optimizer.grad_fp32[param] = grad

    එසේනොමැතිනම්, අනුක්රමික FP32 බවට පරිවර්තනය නොකරන්න

    124                    else:
    125                        grad = param.grad
    126
    127                    per_device_and_dtype_grads[grad.device][grad.dtype].append(grad)

    සියලුමඅනුක්රමික පරිමාණය කරන්න

    130            for device, per_dtype_grads in per_device_and_dtype_grads.items():
    131                for grads in per_dtype_grads.values():
    132                    torch._amp_foreach_non_finite_check_and_unscale_(grads,
    133                                                                     per_device_found_inf.get(device),
    134                                                                     per_device_inv_scale.get(device))

    136        return per_device_found_inf._per_device_tensors