GPT-neox 检查点

11from typing import Dict, Union, Tuple
12
13import torch
14from torch import nn
15
16from labml import monit, lab, logger
17from labml.logger import Text, inspect
18from labml.utils.download import download_file

家长网址

21CHECKPOINTS_URL = 'https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/'

下载路径

24CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
25if not CHECKPOINTS_DOWNLOAD_PATH.exists():
26    CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
27inspect(neox_checkpoint_path=CHECKPOINTS_DOWNLOAD_PATH)

获取要下载的文件

    返回要下载的文件列表

30def get_files_to_download(n_layers: int = 44):
36    layers = (

嵌入层

38            [0] +

变压器层

40            list(range(2, 2 + n_layers)) +

最终归一化层和读出层

42            [47, 48]
43    )
44
45    return (

词汇和配置

47            ['20B_tokenizer.json', 'configs/20B.yml', 'latest'] +

图层检查点

49            [f'global_step150000/layer_{i :02d}-model_{p :02d}-model_states.pt' for i in layers for p in range(2)] +

空状态(未使用)

51            [f'global_step150000/mp_rank_{i :02d}_model_states.pt' for i in range(8)]
52    )

下载所有检查点文件

55def download(n_layers: int = 44):

获取要下载的文件

61    files = get_files_to_download(n_layers)

迭代

64    for i, f in monit.enum('Download All', files):

日志

66        logger.log(['Downloading ', (f'{i + 1 :3d}/{len(files)}', Text.meta), ': ', (f, Text.value)])

下载

68        download_file(CHECKPOINTS_URL + f, CHECKPOINTS_DOWNLOAD_PATH / f)

加载一对检查点文件

  • files 一对要加载的文件
  • 返回加载的参数张量

71def load_checkpoint_files(files: Tuple[str, str]):
78    checkpoint_path = CHECKPOINTS_DOWNLOAD_PATH / 'global_step150000'
79    with monit.section('Load checkpoint'):
80        data = [torch.load(checkpoint_path / f) for f in files]
81
82    return data

通过合并沿第一维度的分区来加载参数

  • param 是参数
  • key 是参数的名称
  • p1 第一个分区字典
  • p2 第二个分区字典
85def merge_params_dim_0(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
86                       p2: Dict[str, torch.Tensor]):
95    w1, w2 = p1[key], p2[key]
96    param.data[:w1.shape[0]] = w1
97    param.data[w1.shape[0]:] = w2

通过合并第二维度的分区来加载参数

  • param 是参数
  • key 是参数的名称
  • p1 第一个分区字典
  • p2 第二个分区字典
100def merge_params_dim_1(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
101                       p2: Dict[str, torch.Tensor]):
110    w1, w2 = p1[key], p2[key]
111    param.data[:, :w1.shape[1]] = w1
112    param.data[:, w1.shape[1]:] = w2

加载未分区的参数

这会进行健全性检查,以使用两个分区是相同的

  • param 是参数
  • key 是参数的名称
  • p1 第一个分区字典
  • p2 第二个分区字典
115def merge_params_duplicate(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
116                           p2: Dict[str, torch.Tensor]):
127    w1, w2 = p1[key], p2[key]
128
129    diff = sum((w1 - w2) ** 2).item()
130    assert diff < 1e-4, f'The partitions do not match: {key}'
131
132    param.data[:] = (w1 + w2) / 2.

分区的负载偏差在 reduce 时被添加

  • param 是参数
  • key 是参数的名称
  • p1 第一个分区字典
  • p2 第二个分区字典
135def merge_params_sum(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
136                     p2: Dict[str, torch.Tensor]):
145    w1, w2 = p1[key], p2[key]
146
147    param.data[:] = w1 + w2

151if __name__ == '__main__':
152    download()