mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-27 20:24:17 +08:00
56 lines
1.8 KiB
Python
56 lines
1.8 KiB
Python
from typing import List, Dict
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from labml_nn.neox.model import TransformerLayer, NeoXModule
|
|
|
|
|
|
class FineTuner:
|
|
def __init__(self, layers: List[NeoXModule]):
|
|
self.layers = layers
|
|
|
|
def get_trainable_params(self) -> Dict[str, nn.Parameter]:
|
|
params = {}
|
|
for i, layer in enumerate(self.layers):
|
|
params.update(self.get_layer_trainable_params(layer, prefix=f'layer_{i :02d}'))
|
|
|
|
return params
|
|
|
|
def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
|
|
raise NotImplementedError
|
|
|
|
def set_trainable_params(self):
|
|
for layer in self.layers:
|
|
# Set `requires_grad` to `False` for the entire layer.
|
|
layer.requires_grad_(False)
|
|
#
|
|
for p in self.get_trainable_params().values():
|
|
p.requires_grad_(True)
|
|
|
|
def state_dict(self):
|
|
return {n: p.data.cpu() for n, p in self.get_trainable_params().items()}
|
|
|
|
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
|
|
params = self.get_trainable_params()
|
|
for n, p in params.items():
|
|
p.data[:] = state_dict[n].to(p.data.device)
|
|
|
|
for n in state_dict.keys():
|
|
assert n in params, n
|
|
|
|
|
|
class FineTuneBiases(FineTuner):
|
|
def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
|
|
params = {}
|
|
|
|
if isinstance(layer, TransformerLayer):
|
|
# No need to train the mlp bias because we are adding it with attention output
|
|
params[f'{prefix}.attention.output.bias'] = layer.attention.output.bias
|
|
params[f'{prefix}.attention.qkv_lin.bias'] = layer.attention.qkv_lin.bias
|
|
params[f'{prefix}.ffn.dense_h_h4.bias'] = layer.ffn.dense_h_h4.bias
|
|
else:
|
|
pass
|
|
|
|
return params
|