import copy
import math
from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple

import torch
from torch import nn
from torch.cuda.amp import autocast

from labml import monit
from labml_nn.neox import checkpoint
from labml_nn.neox.utils.cache import get_cache


class NeoXModule(nn.Module):
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        pass


class Embedding(NeoXModule):
    """
    Embedding layer

    * `n_vocab` is the size of the vocabulary
    * `n_hidden` is the size of the embeddings
    """

    def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
        super().__init__()

        self.emb = nn.Embedding(n_vocab, n_hidden)

    def forward(self, x: torch.Tensor):
        """
        `x` are the token ids of shape `[batch_size, seq_len]`
        """
        return self.emb(x)

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Code to load the checkpoint
        """
        with monit.section('Load embedding layer'):
            checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)
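
# A minimal usage sketch of the embedding layer above. The sizes are illustrative
# stand-ins for the 50,432 x 6,144 table of the 20B checkpoint, and the helper
# name `_demo_embedding` is ours, purely for illustration.
def _demo_embedding():
    emb = Embedding(n_vocab=100, n_hidden=16)
    # Token ids of shape `[batch_size, seq_len]`
    ids = torch.randint(0, 100, (2, 5))
    # The output has shape `[batch_size, seq_len, n_hidden]`
    assert emb(ids).shape == (2, 5, 16)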
class RoPE(nn.Module):
    """
    Rotary positional embeddings (RoPE)

    * `d_rope` is the number of features for RoPE embeddings
    * `base` is the base for the rotation angles, which defaults to 10,000
    """

    def __init__(self, d_rope: int, base: float = 10_000.):
        super().__init__()

        # To store the per-feature rotation frequencies
        self.theta = None
        # Cache for the cosine and sine values
        self.cos_cached = None
        self.sin_cached = None

        # Base for the rotation angles
        self.base = base
        # Number of features for RoPE
        self.d_rope = d_rope

    @staticmethod
    def rotate_half(x: torch.Tensor):
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x: torch.Tensor, offset: int = 0):
        """
        * `x` has shape `[..., seq, n_heads, d_k]`
        * `offset` is the starting position of `x`; this is non-zero when we have
          cached the keys and queries of previous positions
        """

        # Get the actual sequence length
        seq_len = x.shape[-3] + offset

        # Initialize the frequencies
        if self.theta is None:
            theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
            self.theta = theta.to(x.device).to(x.dtype)

        # Initialize and cache the cosine and sine values
        if (
                self.cos_cached is None or
                seq_len > self.cos_cached.shape[1] or
                self.cos_cached.device != x.device or
                self.cos_cached.dtype != x.dtype
        ):
            # Get position indexes
            seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
            # Angle for each (position, feature) pair
            idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)
            # Concatenate so that the same angles are used for both halves of the features
            idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)

            # Calculate the cosine and sine values in fp32
            with autocast(enabled=False):
                idx_theta2 = idx_theta2.float()
                # Add the head dimension
                self.cos_cached = idx_theta2.cos()[:, None, :]
                self.sin_cached = idx_theta2.sin()[:, None, :]

            # Cache them
            self.cos_cached = self.cos_cached.to(x.dtype)
            self.sin_cached = self.sin_cached.to(x.dtype)

        # Split the features. We apply RoPE only to the first `d_rope` features
        x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]

        # Get the sin and cos values from the cache
        cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]

        # Apply the rotation
        x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)

        # Concatenate with the features that didn't get RoPE embeddings
        return torch.cat((x_rope, x_pass), dim=-1)
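
# A small sketch showing how the RoPE module above is used. It checks that applying
# RoPE to a full sequence matches applying it to only the last tokens with an `offset`,
# which is what the key/value cache in the attention layer relies on. The sizes and
# the helper name `_demo_rope` are illustrative only.
def _demo_rope():
    rope = RoPE(d_rope=16)
    # Queries or keys of shape `[..., seq, n_heads, d_k]`
    x = torch.randn(1, 10, 4, 32)
    # Rotate the whole sequence at once
    full = rope(x)
    # Rotate only the last two positions, telling RoPE where they start
    tail = rope(x[:, 8:], offset=8)
    # Both give the same embeddings for the last two positions
    assert torch.allclose(full[:, 8:], tail, atol=1e-6)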
class AttentionLayer(nn.Module):
    """
    Attention layer

    * `n_hidden` is the number of features in the embeddings
    * `n_heads` is the number of attention heads
    * `rope_percentage` is the percentage of features to add RoPE embeddings to
    * `mask_fill` is the fill value for masking the attention matrix
    """

    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
                 mask_fill: float = -10_000.0):
        super().__init__()

        self.n_heads = n_heads
        self.mask_fill = mask_fill

        # Linear layer for query, key and value
        self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)
        # Final linear layer
        self.output = nn.Linear(n_hidden, n_hidden)

        # Number of features per head
        d_k = n_hidden // n_heads
        # RoPE embedding module
        self.rope = RoPE(int(d_k * rope_percentage))

        # Attention scaling factor
        self.scale = 1 / math.sqrt(d_k)

        # To cache the causal mask
        self.causal_mask = None

        # Attention softmax module
        self.softmax = nn.Softmax(dim=-2)

    def _get_mask(self, attn: torch.Tensor):
        # Query and key lengths
        nq, nk = attn.shape[1:3]

        # Create the mask
        if (
                self.causal_mask is None or
                self.causal_mask.shape[0] != nq or
                self.causal_mask.shape[1] != nk or
                self.causal_mask.device != attn.device
        ):
            self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)

        # Return the cached mask
        return self.causal_mask[None, :, :, None]

    def forward(self, x: torch.Tensor):
        """
        `x` has shape `[batch_size, seq_len, n_hidden]`
        """
        # Get the query, key and value embeddings (all concatenated).
        # The last dimension changes from `n_hidden` to `3 * n_hidden`
        qkv = self.qkv_lin(x)

        # Split into heads by changing the shape to `[batch_size, seq_len, n_heads, 3 * d_k]`
        qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)
        # Split into query, key and value, each of shape `[batch_size, seq_len, n_heads, d_k]`
        q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)

        # If we are caching the states of previous tokens
        if get_cache().get('use_cache', False):
            # Get the state ids. We use them to retrieve the previous state and to store the next state
            prev_state_id, next_state_id = get_cache().get('state_ids')
            # If there is a cached state
            if prev_state_id is not None:
                # Get the past keys and values. These have shape `[batch_size, prev_seq_len, n_heads, d_k]`
                k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')
                # Offset of the current embeddings
                offset = k_past.shape[1]

                # Add RoPE embeddings
                q = self.rope(q, offset=offset)
                k = self.rope(k, offset=offset)

                # Concatenate with the past
                k = torch.cat([k_past, k], dim=1)
                v = torch.cat([v_past, v], dim=1)
            else:
                # Add RoPE embeddings
                q = self.rope(q)
                k = self.rope(k)

            # Save the current state
            get_cache().push(f'attn_kv_{next_state_id}', (k, v))
        else:
            # No cache; simply add RoPE embeddings
            q = self.rope(q)
            k = self.rope(k)

        # Disable autocasting to fp16 for the attention computation
        with autocast(enabled=False):
            if q.dtype == torch.float16:
                # Convert to fp32 if the current dtype is fp16
                attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
            else:
                # Do not cast for bfloat
                attn = torch.einsum('bihk,bjhk->bijh', q, k)

            # Scale the attention
            attn = attn * self.scale

            # Get the causal mask
            mask = self._get_mask(attn)
            # Apply the mask
            attn.masked_fill_(mask, self.mask_fill)

            # Attention softmax
            attn = self.softmax(attn)

        # Get the attention-weighted values
        output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)

        # Reshape from `[batch_size, seq_len, n_heads, d_k]` to `[batch_size, seq_len, n_hidden]`
        output = output.reshape(*x.shape)

        # Final linear layer
        return self.output(output)
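
# A minimal forward pass through the attention layer above, without the key/value
# cache (the generation utilities enable it by setting `use_cache` and `state_ids`
# in the cache). The hidden size and head count are illustrative, not the 20B
# configuration, and `_demo_attention` is just a sketch.
def _demo_attention():
    attn = AttentionLayer(n_hidden=64, n_heads=4)
    # `[batch_size, seq_len, n_hidden]`
    x = torch.randn(2, 7, 64)
    # Causal self-attention preserves the input shape
    assert attn(x).shape == x.shape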
class FFNLayer(nn.Module):
    """
    Feedforward network

    * `n_hidden` is the size of the embeddings
    * `d_ff` is the size of the hidden layer; it defaults to `4 * n_hidden`
    """

    def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
        super().__init__()

        if not d_ff:
            d_ff = n_hidden * 4

        # Expansion linear layer
        self.dense_h_h4 = nn.Linear(n_hidden, d_ff)
        # GELU activation
        self.activation = nn.GELU()
        # Contraction linear layer
        self.dense_h4_h = nn.Linear(d_ff, n_hidden)

    def forward(self, x: torch.Tensor):
        """
        `x` has shape `[batch_size, seq_len, n_hidden]`
        """
        x = self.dense_h_h4(x)
        x = self.activation(x)
        x = self.dense_h4_h(x)

        return x
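
# A quick sketch of the feedforward block above: it expands `n_hidden` to
# `4 * n_hidden` (unless `d_ff` is given), applies GELU and contracts back.
# The sizes and the helper name `_demo_ffn` are illustrative.
def _demo_ffn():
    ffn = FFNLayer(n_hidden=32)
    assert ffn.dense_h_h4.out_features == 4 * 32
    x = torch.randn(2, 5, 32)
    assert ffn(x).shape == x.shape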
class TransformerLayer(NeoXModule):
    """
    Transformer layer

    * `n_hidden` is the number of features in the embeddings
    * `n_heads` is the number of attention heads
    """

    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64):
        super().__init__()

        # Layer normalization before attention
        self.pre_ln_attn = nn.LayerNorm(n_hidden)
        # Layer normalization before the FFN
        self.pre_ln_ffn = nn.LayerNorm(n_hidden)

        # Attention layer
        self.attention = AttentionLayer(n_hidden, n_heads)
        # FFN layer
        self.ffn = FFNLayer(n_hidden)

    def forward(self, x: torch.Tensor):
        """
        `x` are the embeddings of shape `[batch_size, seq_len, n_hidden]`
        """

        # Residual connection
        residual = x
        # NeoX runs attention and the feedforward network in parallel
        attn = self.attention(self.pre_ln_attn(x))
        ffn = self.ffn(self.pre_ln_ffn(x))
        # Add them and the residual connection
        return attn + ffn + residual

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Code to load the checkpoint
        """
        with monit.section('Load transformer layer'):
            # Attention output transform
            checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)

            # Attention query, key and value transform
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)

            # Layer norm before attention
            checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)

            # FFN first transform (expansion, h to 4h)
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)

            # FFN second transform (contraction, 4h to h)
            checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)

            # Layer norm before the FFN
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)
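
# A sketch of the parallel residual structure of the transformer layer above:
# attention and the FFN both read the (separately normalized) input and are added
# to the same residual, y = x + Attention(LN_attn(x)) + FFN(LN_ffn(x)), rather
# than being applied sequentially. The sizes here are illustrative.
def _demo_transformer_layer():
    layer = TransformerLayer(n_hidden=64, n_heads=4)
    x = torch.randn(1, 6, 64)
    y = layer(x)
    expected = x + layer.attention(layer.pre_ln_attn(x)) + layer.ffn(layer.pre_ln_ffn(x))
    assert torch.allclose(y, expected, atol=1e-5)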
class FinalNorm(NeoXModule):
    """
    Final normalization layer

    * `n_hidden` is the size of the embeddings
    """

    def __init__(self, n_hidden: int = 6_144):
        super().__init__()

        self.ln = nn.LayerNorm(n_hidden)

    def forward(self, x: torch.Tensor):
        """
        `x` are the embeddings of shape `[batch_size, seq_len, n_hidden]`
        """
        return self.ln(x)

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Code to load the checkpoint
        """
        with monit.section('Load final normalization layer'):
            checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)


class ReadoutLayer(NeoXModule):
    """
    Readout layer

    * `n_hidden` is the size of the embeddings
    * `n_vocab` is the size of the vocabulary
    """

    def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
        super().__init__()

        self.linear = nn.Linear(n_hidden, n_vocab, bias=False)

    def forward(self, x: torch.Tensor):
        """
        `x` are the embeddings of shape `[batch_size, seq_len, n_hidden]`
        """
        return self.linear(x)

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Code to load the checkpoint
        """
        with monit.section('Load final linear layer'):
            checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
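
# A sketch of how the modules above compose into a full language model:
# embedding -> transformer layers -> final norm -> readout. The sizes are tiny and
# the weights random; `LayerGenerator` below is what builds the real model from the
# checkpoint shards.
def _demo_tiny_model():
    n_vocab, n_hidden, n_heads = 1_000, 64, 4
    model = nn.Sequential(
        Embedding(n_vocab, n_hidden),
        TransformerLayer(n_hidden, n_heads),
        TransformerLayer(n_hidden, n_heads),
        FinalNorm(n_hidden),
        ReadoutLayer(n_hidden, n_vocab),
    )
    ids = torch.randint(0, n_vocab, (1, 8))
    # Logits over the vocabulary for every position
    assert model(ids).shape == (1, 8, n_vocab)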
class LayerGenerator:
    pre_created_layers: Dict[Any, Optional[NeoXModule]]

    def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
                 n_layers: int = 44, n_heads: int = 64,
                 filter_layers: Optional[Set] = None,
                 is_clone_layers: bool = True,
                 dtype: torch.dtype = torch.float,
                 device: torch.device = torch.device('cpu'),
                 is_llm_int8: bool = False,
                 llm_int8_threshold: float = 6.0,
                 ):
        """
        The layers are generated in the same order as the checkpoints.
        It gives `None` when a layer is not available; we use the same layer
        indices as NeoX, and there are two transformation layers that are not
        needed in our implementation.

        * `n_vocab` is the number of tokens in the vocabulary
        * `n_hidden` is the number of features in the embeddings
        * `n_layers` is the number of transformer layers
        * `n_heads` is the number of attention heads
        * `filter_layers` is the set of layers to be used. If `None`, all layers are used.
          This is used to test smaller versions of the model with fewer layers
        * `is_clone_layers` specifies whether to clone the transformer layers (a bit faster)
        * `dtype` is the data type of the model
        * `device` is the device of the model
        * `is_llm_int8` specifies whether to use int8 quantization
        * `llm_int8_threshold` is the threshold used to separate outlier features
        """
        if filter_layers is None:
            filter_layers = set(range(n_layers + 3))

        self.n_vocab = n_vocab
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.filter_layers = filter_layers
        self.is_clone_layers = is_clone_layers
        self.dtype = dtype
        self.device = device
        self.is_llm_int8 = is_llm_int8
        self.llm_int8_threshold = llm_int8_threshold

        self.pre_created_layers = dict(
            transformer_layer=None,
        )

    def _prepare_layer(self, layer: NeoXModule):
        # Move the layer to the device and convert it to the configured data type
        return layer.to(self.device, self.dtype)
    @torch.no_grad()
    def post_load_prepare(self, layer: NeoXModule, *,
                          is_llm_int8: bool = None,
                          device: torch.device = None,
                          llm_int8_threshold: float = None,
                          ):
        """
        This function applies layer transformations after loading the checkpoint.
        Currently, it only applies int8 quantization.

        * `layer` is the layer to prepare
        * `is_llm_int8` specifies whether to use int8 quantization
        * `device` is the device of the model
        * `llm_int8_threshold` is the threshold used to separate outlier features

        Returns the prepared layer.
        """
        # Get the default values if not specified
        if is_llm_int8 is None:
            is_llm_int8 = self.is_llm_int8
        if device is None:
            device = self.device
        if llm_int8_threshold is None:
            llm_int8_threshold = self.llm_int8_threshold

        # Skip if not using int8 quantization
        if not is_llm_int8:
            return layer

        # Only convert the linear layers in the transformer layers
        if not isinstance(layer, TransformerLayer):
            return layer

        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear

        # Convert the linear layers
        with monit.section('Convert to int8'):
            layer.attention.output = make_llm_int8_linear(layer.attention.output,
                                                          device=device,
                                                          threshold=llm_int8_threshold)
            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
                                                           device=device,
                                                           threshold=llm_int8_threshold)
            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
                                                        device=device,
                                                        threshold=llm_int8_threshold)
            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
                                                        device=device,
                                                        threshold=llm_int8_threshold)

        return layer
    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
        # Create a fresh layer every time if cloning is disabled
        if not self.is_clone_layers:
            return self._prepare_layer(creator())

        # Otherwise create the layer once and return deep copies of it
        if self.pre_created_layers[name] is None:
            self.pre_created_layers[name] = self._prepare_layer(creator())

        layer = copy.deepcopy(self.pre_created_layers[name])
        return layer

    def _create_transformer_layer(self):
        return self._create_and_cache_layer(
            'transformer_layer',
            lambda: TransformerLayer(self.n_hidden, self.n_heads)
        )

    def _create_embedding_layer(self):
        return Embedding(self.n_vocab, self.n_hidden)

    def _create_final_norm_layer(self):
        return FinalNorm(self.n_hidden)

    def _create_readout_layer(self):
        return ReadoutLayer(self.n_hidden, self.n_vocab)

    @torch.no_grad()
    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
        # Embedding layer
        if 0 in self.filter_layers:
            with monit.section('Embedding layer'):
                layer = self._prepare_layer(self._create_embedding_layer())
            yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')

        # Transformer layers
        for i in range(self.n_layers):
            # Transformer layer
            if i + 1 in self.filter_layers:
                with monit.section(f'Transformer Layer {i}'):
                    yield self._create_transformer_layer(), \
                        (f'layer_{i + 2 :02d}-model_00-model_states.pt',
                         f'layer_{i + 2 :02d}-model_01-model_states.pt')

        # Final normalization layer
        if self.n_layers + 1 in self.filter_layers:
            with monit.section('Final norm layer'):
                layer = self._prepare_layer(self._create_final_norm_layer())
            yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')

        # Readout layer
        if self.n_layers + 2 in self.filter_layers:
            with monit.section('Readout layer'):
                layer = self._prepare_layer(self._create_readout_layer())
            yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')

        # Clear the pre-created layer cache
        for k in self.pre_created_layers.keys():
            self.pre_created_layers[k] = None

    @property
    def total_layers(self):
        # Total number of layers
        return self.n_layers + 3

    @torch.no_grad()
    def load(self) -> Generator[NeoXModule, None, None]:
        # Generate the layers and load each from its checkpoint files
        with monit.section("Layers"):
            for i, (layer, files) in enumerate(self.get_layers()):
                if files is not None:
                    layer.load_state(*checkpoint.load_checkpoint_files(files))

                layer = self.post_load_prepare(layer)

                monit.progress(min(0.99, (i + 1) / self.total_layers))
                yield layer
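
# A sketch of the typical way to assemble the full model with `LayerGenerator`,
# assuming the GPT-NeoX checkpoint shards have already been downloaded (see
# `labml_nn.neox.checkpoint`) and that a CUDA device is available; adjust `device`
# and `dtype` to your setup.
def _demo_layer_generator():
    generator = LayerGenerator(is_clone_layers=True,
                               dtype=torch.float16,
                               device=torch.device('cuda:0'))
    # `load()` yields each layer after loading its pair of checkpoint files
    layers = list(generator.load())
    # The layers map embeddings to embeddings, so they can be stacked directly
    return nn.Sequential(*layers)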