GPT-NEOX 型号

以下是 GPT-NEOX 模型层的代码和加载 20B 检查点的代码。

图层load_state 中的方法加载该层的检查点。检查点加载助手已启用 checkpoint.py

16import copy
17import math
18from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple
19
20import torch
21from torch import nn
22from torch.cuda.amp import autocast
23
24from labml import monit
25from labml_nn.neox import checkpoint
26from labml_nn.neox.utils.cache import get_cache
29class NeoXModule(nn.Module):
30    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
31        pass

嵌入层

这是一个标准的嵌入层,其中包含用于加载检查点的代码。

34class Embedding(NeoXModule):
  • n_vocab 是词汇量的大小
  • n_hidden 是嵌入的大小
41    def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
46        super().__init__()
47
48        self.emb = nn.Embedding(n_vocab, n_hidden)
  • x 是形状的令牌 ID[batch_size, seq_len]
50    def forward(self, x: torch.Tensor):
54        return self.emb(x)

加载检查点的代码

56    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
60        with monit.section('Load embedding layer'):
61            checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)

旋转位置嵌入

GPT-NEOX 使用旋转位置嵌入(RoP)

我们在这里注释了 RoPe 的实现,并附上了更多关于理论的注释。

64class RoPE(nn.Module):
  • d_rope 是 RoPe 嵌入的要素数量
  • base 是的基础,默认为
74    def __init__(self, d_rope: int, base: float = 10_000.):
79        super().__init__()

为要素存储

82        self.theta = None

缓存

84        self.cos_cached = None
85        self.sin_cached = None

基地

88        self.base = base

ROPE 的要素数量

90        self.d_rope = d_rope

旋转要素

92    @staticmethod
93    def rotate_half(x: torch.Tensor):
99        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
100        return torch.cat((-x2, x1), dim=-1)
  • x 有形状[..., seq, n_heads, d_k]
  • offset 是的起始位置x 。这是我们缓存先前位置的键和查询的时候
102    def forward(self, x: torch.Tensor, offset: int = 0):

获取实际序列长度

110        seq_len = x.shape[-3] + offset

初始化

113        if self.theta is None:

115            theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
116            self.theta = theta.to(x.device).to(x.dtype)

初始化缓存

119        if (
120                self.cos_cached is None or
121                seq_len > self.cos_cached.shape[1] or
122                self.cos_cached.device != x.device or
123                self.cos_cached.dtype != x.dtype
124        ):

获取头寸指数

126            seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)

128            idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)

连接这样我们就有 row

132            idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)

在 fp32 中计算

135            with autocast(enabled=False):
136                idx_theta2 = idx_theta2.float()

添加头部尺寸

138                self.cos_cached = idx_theta2.cos()[:, None, :]
139                self.sin_cached = idx_theta2.sin()[:, None, :]

缓存它们

142            self.cos_cached = self.cos_cached.to(x.dtype)
143            self.sin_cached = self.sin_cached.to(x.dtype)

拆分要素。我们仅将 RoPe 应用于要d_rope

146        x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]

从缓存中获取 sin 和 cos 值

149        cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]

绳索嵌入

对于

161        x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)

连接未获得 RoPe 嵌入的功能

164        return torch.cat((x_rope, x_pass), dim=-1)

注意层

167class AttentionLayer(nn.Module):
  • n_hidden 嵌入中的要素数量
  • n_heads 注意头的数量
  • rope_percentage 要添加 RoPe 嵌入的要素百分比
  • mask_fill 遮蔽注意力矩阵的填充值
172    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
173                 mask_fill: float = -10_000.0):
180        super().__init__()
181
182        self.n_heads = n_heads
183        self.mask_fill = mask_fill

用于查询、键和值的线性图层

186        self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)

最后的线性层

188        self.output = nn.Linear(n_hidden, n_hidden)

每头特征数

191        d_k = n_hidden // n_heads

绳索嵌入模块

193        self.rope = RoPE(int(d_k * rope_percentage))

注意力缩放系数

196        self.scale = 1 / math.sqrt(d_k)

缓存因果掩码

199        self.causal_mask = None

注意 softmax 模块

202        self.softmax = nn.Softmax(dim=-2)

计算因果掩码

204    def _get_mask(self, attn: torch.Tensor):

查询和密钥长度

212        nq, nk = attn.shape[1:3]

创建遮罩

215        if (
216                self.causal_mask is None or
217                self.causal_mask.shape[0] != nq or
218                self.causal_mask.shape[1] != nk or
219                self.causal_mask.device != attn.device
220        ):
221            self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)

从缓存中返回

224        return self.causal_mask[None, :, :, None]
  • x 有形状[batch_size, seq_len, n_hidden]
226    def forward(self, x: torch.Tensor):

获取查询、键和值嵌入(全部串联)。最后一个维度大小将从 n_hidden 更改为->3 x n_hidden

232        qkv = self.qkv_lin(x)

通过将形状改为分成头部[batch_size, seq_len, n_heads, 3 * d_k]

235        qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)

分为查询、键和值各形状[batch_size, seq_len, n_heads, 3 * d_k]

237        q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)

如果我们正在缓存之前令牌的状态

240        if get_cache().get('use_cache', False):

获取状态 ID。我们用它来检索以前的状态并存储下一个状态

242            prev_state_id, next_state_id = get_cache().get('state_ids')

如果有缓存

244            if prev_state_id is not None:

获取过去的键和值。这些会有形状[batch_size, prev_seq_len, n_heads, d_k]

246                k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')

当前嵌入的偏移量

248                offset = k_past.shape[1]

添加绳索嵌入

251                q = self.rope(q, offset=offset)
252                k = self.rope(k, offset=offset)

串联过去

255                k = torch.cat([k_past, k], dim=1)
256                v = torch.cat([v_past, v], dim=1)
257            else:

添加绳索嵌入

259                q = self.rope(q)
260                k = self.rope(k)

保存当前状态

263            get_cache().push(f'attn_kv_{next_state_id}', (k, v))
264        else:

没有缓存-只需添加 RoPe 嵌入即可

266            q = self.rope(q)
267            k = self.rope(k)

禁用自动投射到 fp16 以进行注意力计算

270        with autocast(enabled=False):
271            if q.dtype == torch.float16:

如果当前数据类型为 fp16,则转换为 fp32

273                attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
274            else:

不要为 bfloat 进行投射

276                attn = torch.einsum('bihk,bjhk->bijh', q, k)

缩放注意力

279            attn = attn * self.scale

获得因果口罩

282            mask = self._get_mask(attn)

涂抹面膜

284            attn.masked_fill_(mask, self.mask_fill)

注意 softmax

287            attn = self.softmax(attn)

获取注意力加权值

290        output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)

[batch_size, seq_len, n_heads, d_k] to batch_size、seq_len、n_hidden 进行重塑 `

293        output = output.reshape(*x.shape)

最后的线性层

296        return self.output(output)

前馈网络

299class FFNLayer(nn.Module):
  • n_hidden 是嵌入的大小
304    def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
308        super().__init__()
309
310        if not d_ff:
311            d_ff = n_hidden * 4

扩展线性层

314        self.dense_h_h4 = nn.Linear(n_hidden, d_ff)

GELU 激活

316        self.activation = nn.GELU()

收缩线性层

318        self.dense_h4_h = nn.Linear(d_ff, n_hidden)
  • x 有形状[batch_size, seq_len, n_hidden]
320    def forward(self, x: torch.Tensor):
324        x = self.dense_h_h4(x)
325        x = self.activation(x)
326        x = self.dense_h4_h(x)
327
328        return x

变压器层

331class TransformerLayer(NeoXModule):
  • n_hidden 是嵌入的大小
  • n_heads 是人头的数量
Ou@@

t 实现不包括辍学

336    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64):
343        super().__init__()

注意之前的图层规范化

346        self.pre_ln_attn = nn.LayerNorm(n_hidden)

FFN 之前的层标准化

348        self.pre_ln_ffn = nn.LayerNorm(n_hidden)

注意层

351        self.attention = AttentionLayer(n_hidden, n_heads)

FFN 层

353        self.ffn = FFNLayer(n_hidden)
  • x 是形状的嵌入[batch_size, seq_len, n_hidden]
355    def forward(self, x: torch.Tensor):

剩余连接

361        residual = x

NeoX 并行运行注意力和前馈网络

363        attn = self.attention(self.pre_ln_attn(x))
364        ffn = self.ffn(self.pre_ln_ffn(x))

添加它们和剩余的连接

366        return attn + ffn + residual

加载检查点的代码

368    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
372        with monit.section('Load transformer layer'):

注意力输出变换

374            checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
375            checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)

注意力查询、关键和价值转换

378            checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
379            checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)

注意之前先进行分层规范

382            checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
383            checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)

FFN 第二次变换

386            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
387            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)

FFN 首次改造

390            checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
391            checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)

FFN 之前的分层规范

394            checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
395            checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)

最终归一化层

398class FinalNorm(NeoXModule):
  • n_hidden 是嵌入的大小
403    def __init__(self, n_hidden: int = 6_144):
407        super().__init__()
408
409        self.ln = nn.LayerNorm(n_hidden)
  • x 是形状的嵌入[batch_size, seq_len, n_hidden]
411    def forward(self, x: torch.Tensor):
415        return self.ln(x)

加载检查点的代码

417    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
421        with monit.section('Load final normalization layer'):
422            checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
423            checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)

读出层

426class ReadoutLayer(NeoXModule):
  • n_hidden 是嵌入的大小
  • n_vocab 是词汇量的大小
431    def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
436        super().__init__()
437
438        self.linear = nn.Linear(n_hidden, n_vocab, bias=False)
  • x 是形状的嵌入[batch_size, seq_len, n_hidden]
440    def forward(self, x: torch.Tensor):
444        return self.linear(x)

加载检查点的代码

446    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
450        with monit.section('Load final linear layer'):
451            checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
454class LayerGenerator:
455    pre_created_layers: Dict[Any, Optional[NeoXModule]]

用于创建层的生成器

图层的生成顺序与检查点相同。

它给出了层None 何时不可用;我们使用层索引作为 NeoX,并且在我们的实现中不需要两个变换层。

  • n_vocab 是词汇表中代币的数量
  • n_hidden 是嵌入中的要素数量
  • n_layers 是变压器层的数量
  • n_heads 是注意头的数量
  • filter_layers 是要使用的图层集。如果为 None,则使用所有图层。这用于测试层数较少的模型的较小版本
  • is_clone_layers 指定是否克隆变压器层(快一点)
  • dtype 是模型的数据类型
  • device 是该型号的设备
  • is_llm_int8 指定是否使用 int8 量化
  • llm_int8_threshold用于分隔异常值要素的阈值
457    def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
458                 n_layers: int = 44, n_heads: int = 64,
459                 filter_layers: Optional[Set] = None,
460                 is_clone_layers: bool = True,
461                 dtype: torch.dtype = torch.float,
462                 device: torch.device = torch.device('cpu'),
463                 is_llm_int8: bool = False,
464                 llm_int8_threshold: float = 6.0,
465                 ):
486        if filter_layers is None:
487            filter_layers = set(range(n_layers + 3))
488
489        self.n_vocab = n_vocab
490        self.n_hidden = n_hidden
491        self.n_layers = n_layers
492        self.n_heads = n_heads
493        self.filter_layers = filter_layers
494        self.is_clone_layers = is_clone_layers
495        self.dtype = dtype
496        self.device = device
497        self.is_llm_int8 = is_llm_int8
498        self.llm_int8_threshold = llm_int8_threshold
499
500        self.pre_created_layers = dict(
501            transformer_layer=None,
502        )

准备图层以供使用

我们将图层移动到设备并将其转换为正确的数据类型

  • layer 是要准备的图层
  • 返回准备好的图层

504    def _prepare_layer(self, layer: NeoXModule):
513        return layer.to(self.device, self.dtype)

加载检查点后的图层变换

此函数在加载检查点后实现层转换。

目前,它仅应用 int8 量化。

  • layer 是要准备的图层
  • is_llm_int8 指定是否使用 int8 量化
  • device 是该型号的设备
  • llm_int8_threshold用于分隔异常值要素的阈值
  • 返回准备好的图层

515    @torch.no_grad()
516    def post_load_prepare(self, layer: NeoXModule, *,
517                          is_llm_int8: bool = None,
518                          device: torch.device = None,
519                          llm_int8_threshold: float = None,
520                          ):

如果未指定,则获取默认值

538        if is_llm_int8 is None:
539            is_llm_int8 = self.is_llm_int8
540        if device is None:
541            device = self.device
542        if llm_int8_threshold is None:
543            llm_int8_threshold = self.llm_int8_threshold

如果不使用 int8 量化则跳过

546        if not is_llm_int8:
547            return layer

仅转换变压器层中的线性层

550        if not isinstance(layer, TransformerLayer):
551            return layer

使用在实用程序make_llm_int8_linear 定义。

554        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear

转换线性图层

557        with monit.section('Convert to int8'):
558            layer.attention.output = make_llm_int8_linear(layer.attention.output,
559                                                          device=device,
560                                                          threshold=llm_int8_threshold)
561            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
562                                                           device=device,
563                                                           threshold=llm_int8_threshold)
564            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
565                                                        device=device,
566                                                        threshold=llm_int8_threshold)
567            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
568                                                        device=device,
569                                                        threshold=llm_int8_threshold)

571        return layer

创建和缓存图层

复制缓存图层比初始化新图层要快,因为初始化参数需要时间。

  • name 是层的名称
  • creator 是创建图层的函数
  • 返回创建的图层或缓存图层的副本

573    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
585        if not self.is_clone_layers:
586            return self._prepare_layer(creator())
587
588        if self.pre_created_layers[name] is None:
589            self.pre_created_layers[name] = self._prepare_layer(creator())
590
591        layer = copy.deepcopy(self.pre_created_layers[name])
592        return layer
594    def _create_transformer_layer(self):
595        return self._create_and_cache_layer(
596            'transformer_layer',
597            lambda: TransformerLayer(self.n_hidden, self.n_heads)
598        )
600    def _create_embedding_layer(self):
601        return Embedding(self.n_vocab, self.n_hidden)
603    def _create_final_norm_layer(self):
604        return FinalNorm(self.n_hidden)
606    def _create_readout_layer(self):
607        return ReadoutLayer(self.n_hidden, self.n_vocab)

获取图层的生成器

609    @torch.no_grad()
610    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:

嵌入层

615        if 0 in self.filter_layers:
616            with monit.section('Embedding layer'):
617                layer = self._prepare_layer(self._create_embedding_layer())
618            yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')

变压器层

621        for i in range(self.n_layers):

变压器层

623            if i + 1 in self.filter_layers:
624                with monit.section(f'Transformer Layer {i}'):
625                    yield self._create_transformer_layer(), \
626                          (f'layer_{i + 2 :02d}-model_00-model_states.pt',
627                           f'layer_{i + 2 :02d}-model_01-model_states.pt')

最终归一化层

630        if self.n_layers + 1 in self.filter_layers:
631            with monit.section('Final norm layer'):
632                layer = self._prepare_layer(self._create_final_norm_layer())
633            yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')

读出层

636        if self.n_layers + 2 in self.filter_layers:
637            with monit.section('Readout layer'):
638                layer = self._prepare_layer(self._create_readout_layer())
639            yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
640
641        for k in self.pre_created_layers.keys():
642            self.pre_created_layers[k] = None

返回总层数

644    @property
645    def total_layers(self):
649        return self.n_layers + 3

用于加载层的生成器

651    @torch.no_grad()
652    def load(self) -> Generator[NeoXModule, None, None]:
656        with monit.section("Layers"):
657            for i, (layer, files) in enumerate(self.get_layers()):
658                if files is not None:
659                    layer.load_state(*checkpoint.load_checkpoint_files(files))
660
661                layer = self.post_load_prepare(layer)
662
663                monit.progress(min(0.99, (i + 1) / self.total_layers))
664                yield layer