Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-10-29 09:38:56 +08:00)
Zero3 memory optimizations (#140)
labml_nn/neox/utils/finetune.py (new file, 55 lines)
@@ -0,0 +1,55 @@
from typing import List, Dict

import torch
from torch import nn

from labml_nn.neox.model import TransformerLayer, NeoXModule


class FineTuner:
    """
    Base class for fine-tuning helpers that train only a subset of the model's parameters.
    """

    def __init__(self, layers: List[NeoXModule]):
        # Layers of the model to fine-tune
        self.layers = layers

    def get_trainable_params(self) -> Dict[str, nn.Parameter]:
        # Collect the trainable parameters from all layers, keyed by a unique name
        params = {}
        for i, layer in enumerate(self.layers):
            params.update(self.get_layer_trainable_params(layer, prefix=f'layer_{i:02d}'))

        return params

    def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
        # Return the trainable parameters of a single layer; implemented by subclasses
        raise NotImplementedError

    def set_trainable_params(self):
        for layer in self.layers:
            # Set `requires_grad` to `False` for the entire layer.
            layer.requires_grad_(False)
        # Then enable gradients only for the selected trainable parameters.
        for p in self.get_trainable_params().values():
            p.requires_grad_(True)

    def state_dict(self):
        # Only the trainable parameters need to be saved
        return {n: p.data.cpu() for n, p in self.get_trainable_params().items()}

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
        params = self.get_trainable_params()
        for n, p in params.items():
            p.data[:] = state_dict[n].to(p.data.device)

        # Make sure the loaded state dict has no parameters we do not know about
        for n in state_dict.keys():
            assert n in params, n


class FineTuneBiases(FineTuner):
    """
    Fine-tune only the biases of the transformer layers.
    """

    def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
        params = {}

        if isinstance(layer, TransformerLayer):
            # No need to train the MLP output bias because it is added together with the attention output
            params[f'{prefix}.attention.output.bias'] = layer.attention.output.bias
            params[f'{prefix}.attention.qkv_lin.bias'] = layer.attention.qkv_lin.bias
            params[f'{prefix}.ffn.dense_h_h4.bias'] = layer.ffn.dense_h_h4.bias
        else:
            # Other module types contribute no trainable parameters for bias-only fine-tuning
            pass

        return params
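The sketch below is not part of the commit; it shows one way these classes might be wired up for bias-only fine-tuning. The `layers` list, the learning rate, and the checkpoint filename are hypothetical, and the optimizer setup is just one reasonable choice.

# Usage sketch: `layers` is assumed to be the list of NeoX modules loaded elsewhere.
import torch

fine_tuner = FineTuneBiases(layers)

# Freeze the whole model, then re-enable gradients only for the selected biases.
fine_tuner.set_trainable_params()

# Optimize only the trainable (bias) parameters; the learning rate is illustrative.
optimizer = torch.optim.Adam(fine_tuner.get_trainable_params().values(), lr=1e-4)

# ... training loop ...

# Save and later restore only the fine-tuned parameters (hypothetical filename).
torch.save(fine_tuner.state_dict(), 'fine_tuned_biases.pt')
fine_tuner.load_state_dict(torch.load('fine_tuned_biases.pt'))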