from typing import Dict, Union, Tuple

import torch
from torch import nn

from labml import monit, lab, logger
from labml.logger import Text, inspect
from labml.utils.download import download_file

# Parent URL for the GPT-NeoX-20B slim checkpoint files
CHECKPOINTS_URL = 'https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/'

# Local download path; prefer the 'neox_fast' location and
# fall back to 'neox' when the fast path does not exist
CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
if not CHECKPOINTS_DOWNLOAD_PATH.exists():
    CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
inspect(neox_checkpoint_path=CHECKPOINTS_DOWNLOAD_PATH)
def get_files_to_download(n_layers: int = 44):
    """
    ### Get the list of checkpoint files to download

    :param n_layers: is the number of transformer layers in the model
        (44 for GPT-NeoX-20B)
    :return: list of file names, relative to the checkpoints parent URL
    """
    layers = (
        # Embedding layer
        [0] +
        # Transformer layers
        list(range(2, 2 + n_layers)) +
        # Final normalization layer and readout layer
        [47, 48]
    )

    return (
        # Vocabulary and configs
        ['20B_tokenizer.json', 'configs/20B.yml', 'latest'] +
        # Layer checkpoints; each layer is split into two model-parallel partitions
        [f'global_step150000/layer_{i:02d}-model_{p:02d}-model_states.pt' for i in layers for p in range(2)] +
        # Empty states (not used)
        [f'global_step150000/mp_rank_{i:02d}_model_states.pt' for i in range(8)]
    )
def download(n_layers: int = 44):
    """
    ### Download all checkpoint files

    :param n_layers: is the number of transformer layers in the model
        (44 for GPT-NeoX-20B)
    """
    # Get files to download
    files = get_files_to_download(n_layers)

    # Iterate through the files
    for i, f in monit.enum('Download All', files):
        # Log progress
        logger.log(['Downloading ', (f'{i + 1 :3d}/{len(files)}', Text.meta), ': ', (f, Text.value)])
        # Download the file into the local checkpoints directory
        download_file(CHECKPOINTS_URL + f, CHECKPOINTS_DOWNLOAD_PATH / f)
def load_checkpoint_files(files: Tuple[str, str]):
    """
    ### Load a pair of checkpoint files

    :param files: pair of file names — presumably the two model-parallel
        partitions of one layer (see the `merge_params_*` helpers)
    :return: list with the loaded state of each file
    """
    checkpoint_path = CHECKPOINTS_DOWNLOAD_PATH / 'global_step150000'
    with monit.section('Load checkpoint'):
        data = [torch.load(checkpoint_path / f) for f in files]

    return data
def merge_params_dim_0(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                       p2: Dict[str, torch.Tensor]):
    """
    ### Merge a parameter by concatenating the partitions along the first dimension

    :param param: is the parameter
    :param key: is the name of the parameter
    :param p1: first partition dictionary
    :param p2: second partition dictionary
    """
    w1, w2 = p1[key], p2[key]
    # First partition fills the leading rows; second fills the rest
    param.data[:w1.shape[0]] = w1
    param.data[w1.shape[0]:] = w2
def merge_params_dim_1(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                       p2: Dict[str, torch.Tensor]):
    """
    ### Merge a parameter by concatenating the partitions along the second dimension

    :param param: is the parameter
    :param key: is the name of the parameter
    :param p1: first partition dictionary
    :param p2: second partition dictionary
    """
    w1, w2 = p1[key], p2[key]
    # First partition fills the leading columns; second fills the rest
    param.data[:, :w1.shape[1]] = w1
    param.data[:, w1.shape[1]:] = w2
def merge_params_duplicate(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                           p2: Dict[str, torch.Tensor]):
    """
    ### Merge a parameter that is duplicated across partitions

    This does a sanity check to make sure both partitions are the same,
    then sets the parameter to their average.

    :param param: is the parameter
    :param key: is the name of the parameter
    :param p1: first partition dictionary
    :param p2: second partition dictionary
    """
    w1, w2 = p1[key], p2[key]

    # Use tensor `.sum()` so the check works for parameters of any rank.
    # (The builtin `sum()` over a 2-D tensor yields a 1-D tensor,
    # and `.item()` would raise on it.)
    diff = ((w1 - w2) ** 2).sum().item()
    assert diff < 1e-4, f'The partitions do not match: {key}'

    param.data[:] = (w1 + w2) / 2.
def merge_params_sum(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                     p2: Dict[str, torch.Tensor]):
    """
    ### Merge a parameter by summing the two partitions

    :param param: is the parameter
    :param key: is the name of the parameter
    :param p1: first partition dictionary
    :param p2: second partition dictionary
    """
    w1, w2 = p1[key], p2[key]

    param.data[:] = w1 + w2
# Download the checkpoints when run as a script
if __name__ == '__main__':
    download()