GPT-NeoX Checkpoints

11from typing import Dict, Union, Tuple
12
13import torch
14from torch import nn
15
16from labml import monit, lab, logger
17from labml.logger import Text, inspect
18from labml.utils.download import download_file

# Parent URL that all checkpoint files are fetched from
CHECKPOINTS_URL = 'https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/'

# Local download path: prefer the 'neox_fast' copy when it already exists,
# otherwise fall back to the default 'neox' location
_fast_path = lab.get_data_path() / 'neox_fast' / 'slim_weights'
CHECKPOINTS_DOWNLOAD_PATH = _fast_path if _fast_path.exists() else lab.get_data_path() / 'neox' / 'slim_weights'
inspect(neox_checkpoint_path=CHECKPOINTS_DOWNLOAD_PATH)

Get files to download

    Returns a list of files to be downloaded

def get_files_to_download(n_layers: int = 44):
    """
    ### Get files to download

    :param n_layers: number of transformer layers to include
    :return: a list of checkpoint files to be downloaded
    """
    # Layer checkpoint indices:
    #  * `0` is the embedding layer
    #  * `2 .. 2 + n_layers - 1` are the transformer layers
    #  * `47` and `48` are the final normalization layer and the readout layer
    layer_ids = [0, *range(2, 2 + n_layers), 47, 48]

    # Vocabulary and configs
    result = ['20B_tokenizer.json', 'configs/20B.yml', 'latest']
    # Layer checkpoints — two model-parallel partitions per layer
    result += [f'global_step150000/layer_{i :02d}-model_{p :02d}-model_states.pt'
               for i in layer_ids for p in range(2)]
    # Empty states (not used)
    result += [f'global_step150000/mp_rank_{i :02d}_model_states.pt' for i in range(8)]

    return result

Download all checkpoint files

def download(n_layers: int = 44):
    """
    ### Download all checkpoint files

    :param n_layers: number of transformer layers to download
    """
    # Collect the list of files to fetch
    file_names = get_files_to_download(n_layers)

    # Iterate through them with a progress monitor
    for idx, name in monit.enum('Download All', file_names):
        # Log which file is being fetched
        logger.log(['Downloading ', (f'{idx + 1 :3d}/{len(file_names)}', Text.meta), ': ', (name, Text.value)])
        # Fetch the file from the parent URL into the local checkpoint directory
        download_file(CHECKPOINTS_URL + name, CHECKPOINTS_DOWNLOAD_PATH / name)

Load a pair of checkpoint files

  • files pair of files to load
  • Returns the loaded parameter tensors

def load_checkpoint_files(files: Tuple[str, str]):
    """
    ### Load a pair of checkpoint files

    :param files: pair of files to load
    :return: the loaded parameter tensors
    """
    base = CHECKPOINTS_DOWNLOAD_PATH / 'global_step150000'
    # NOTE(review): `torch.load` without `map_location` restores tensors onto
    # the device they were saved from — presumably fine here; verify on CPU-only hosts.
    with monit.section('Load checkpoint'):
        loaded = [torch.load(base / name) for name in files]

    return loaded

Load a parameter by merging the partitions along first dimension

  • param is the parameter
  • key is the name of the parameter
  • p1 first partition dictionary
  • p2 second partition dictionary
def merge_params_dim_0(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                       p2: Dict[str, torch.Tensor]):
    """
    ### Load a parameter by merging the partitions along the first dimension

    :param param: is the parameter
    :param key: is the name of the parameter
    :param p1: first partition dictionary
    :param p2: second partition dictionary
    """
    first, second = p1[key], p2[key]
    # The first partition fills rows `0 .. split - 1`, the second fills the rest
    split = first.shape[0]
    param.data[:split] = first
    param.data[split:] = second

Load a parameter by merging the partitions along second dimension

  • param is the parameter
  • key is the name of the parameter
  • p1 first partition dictionary
  • p2 second partition dictionary
def merge_params_dim_1(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                       p2: Dict[str, torch.Tensor]):
    """
    ### Load a parameter by merging the partitions along the second dimension

    :param param: is the parameter
    :param key: is the name of the parameter
    :param p1: first partition dictionary
    :param p2: second partition dictionary
    """
    first, second = p1[key], p2[key]
    # The first partition fills columns `0 .. split - 1`, the second fills the rest
    split = first.shape[1]
    param.data[:, :split] = first
    param.data[:, split:] = second

Load an un-partitioned parameter

This does a sanity check to make sure both partitions are the same

  • param is the parameter
  • key is the name of the parameter
  • p1 first partition dictionary
  • p2 second partition dictionary
def merge_params_duplicate(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                           p2: Dict[str, torch.Tensor]):
    """
    ### Load an un-partitioned parameter

    This does a sanity check to make sure both partitions are the same.

    :param param: is the parameter
    :param key: is the name of the parameter
    :param p1: first partition dictionary
    :param p2: second partition dictionary
    """
    w1, w2 = p1[key], p2[key]

    # Squared difference between the two copies, reduced over *all* dimensions.
    # (The previous builtin `sum(...)` only reduced over the first dimension,
    # so `.item()` raised for parameters of rank >= 2.)
    diff = ((w1 - w2) ** 2).sum().item()
    assert diff < 1e-4, f'The partitions do not match: {key}'

    # Average the two copies to smooth out any floating-point noise
    param.data[:] = (w1 + w2) / 2.

Load biases that are partitioned and get summed on reduce

  • param is the parameter
  • key is the name of the parameter
  • p1 first partition dictionary
  • p2 second partition dictionary
def merge_params_sum(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                     p2: Dict[str, torch.Tensor]):
    """
    ### Load biases that are partitioned, which get added on reduce

    :param param: is the parameter
    :param key: is the name of the parameter
    :param p1: first partition dictionary
    :param p2: second partition dictionary
    """
    part1, part2 = p1[key], p2[key]

    # The full bias is the element-wise sum of the two partitions
    param.data[:] = part1 + part2

# Download all checkpoint files when run as a script
if __name__ == '__main__':
    download()