import copy
import math
from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple

import torch
from torch import nn
from torch.cuda.amp import autocast

from labml import monit, logger
from labml.logger import Text
from labml_nn.neox import checkpoint
from labml_nn.neox.utils.cache import get_cache


class NeoXModule(nn.Module):
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        pass


class Embedding(NeoXModule):

`n_vocab` is the size of the vocabulary and `n_hidden` is the size of the embeddings.

    def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
        super().__init__()

        self.emb = nn.Embedding(n_vocab, n_hidden)

`x` are the token ids of shape `[batch_size, seq_len]`.

    def forward(self, x: torch.Tensor):
        return self.emb(x)

Code to load the checkpoint.

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load embedding layer'):
            checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)


class RoPE(nn.Module):

Rotary positional embeddings. `d_rope` is the number of features for RoPE embeddings, and `base` is the base for theta, which defaults to 10,000.

    def __init__(self, d_rope: int, base: float = 10_000.):
        super().__init__()

        # To store theta for the features
        self.theta = None
        # Cache for the cos and sin values
        self.cos_cached = None
        self.sin_cached = None

        # Base for theta
        self.base = base
        # Number of features for RoPE
        self.d_rope = d_rope

    @staticmethod
    def rotate_half(x: torch.Tensor):
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)
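As a quick illustration (a standalone sketch, not part of the model code), `rotate_half` negates the second half of the last dimension and moves it in front of the first half:

    # rotate_half on a 4-feature vector: [x1, x2, x3, x4] -> [-x3, -x4, x1, x2]
    _x = torch.tensor([1., 2., 3., 4.])
    _x1, _x2 = _x[..., :2], _x[..., 2:]
    print(torch.cat((-_x2, _x1), dim=-1))  # tensor([-3., -4.,  1.,  2.])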
`x` has shape `[..., seq, n_heads, d_k]`. `offset` is the starting position of `x`; it is non-zero when the keys and queries of previous positions have been cached.

    def forward(self, x: torch.Tensor, offset: int = 0):
        # Get the actual sequence length
        seq_len = x.shape[-3] + offset

        # Initialize theta
        if self.theta is None:
            theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
            self.theta = theta.to(x.device).to(x.dtype)

        # Initialize and cache the cos and sin values
        # (the cache is indexed by sequence position along dim 0)
        if (
                self.cos_cached is None or
                seq_len > self.cos_cached.shape[0] or
                self.cos_cached.device != x.device or
                self.cos_cached.dtype != x.dtype
        ):
            # Get position indexes
            seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
            idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)
            idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)

            # Calculate cos and sin in fp32
            with autocast(enabled=False):
                idx_theta2 = idx_theta2.float()
                # Add the head dimension
                self.cos_cached = idx_theta2.cos()[:, None, :]
                self.sin_cached = idx_theta2.sin()[:, None, :]

            # Cache them in the dtype of x
            self.cos_cached = self.cos_cached.to(x.dtype)
            self.sin_cached = self.sin_cached.to(x.dtype)

Split the features; RoPE is applied only to the first `d_rope` features.
        x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]

        # Get the sin and cos values from the cache
        cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]

        # Apply the rotation
        x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)

        # Concatenate with the features that didn't get RoPE embeddings
        return torch.cat((x_rope, x_pass), dim=-1)
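A small usage sketch of the module above (the shapes are illustrative): only the first `d_rope` features are rotated, the rest pass through unchanged.

    _rope = RoPE(d_rope=8)
    _q = torch.randn(1, 5, 4, 32)                   # [batch, seq, heads, d_k]
    _out = _rope(_q)
    assert _out.shape == _q.shape
    assert torch.equal(_out[..., 8:], _q[..., 8:])  # pass-through features are untouched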
class AttentionLayer(nn.Module):

`n_hidden` is the number of features in the embeddings, `n_heads` is the number of attention heads, `rope_percentage` is the fraction of features to add RoPE embeddings to, `mask_fill` is the fill value used to mask the attention matrix, and `is_flash_attention` specifies whether to use FlashAttention.

    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
                 mask_fill: float = -10_000.0, *, is_flash_attention: bool = False):
        super().__init__()

        self.n_heads = n_heads
        self.mask_fill = mask_fill

        # Linear layer for query, key and value
        self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)
        # Final linear layer
        self.output = nn.Linear(n_hidden, n_hidden)

        # Number of features per head
        d_k = n_hidden // n_heads
        # RoPE embedding module
        self.rope = RoPE(int(d_k * rope_percentage))

        # Attention scaling factor
        self.scale = 1 / math.sqrt(d_k)

        # To cache the causal mask
        self.causal_mask = None

        # Attention softmax module
        self.softmax = nn.Softmax(dim=-2)

        if is_flash_attention:
            try:
                from flash_attn.flash_attention import FlashAttention
                self.flash_attention = FlashAttention()
            except ImportError:
                logger.log('Install flash attention github.com/HazyResearch/flash-attention. '
                           'Falling back to normal attention', Text.warning)
                self.flash_attention = None
        else:
            self.flash_attention = None

    def _get_mask(self, attn: torch.Tensor):
        # Query and key lengths
        nq, nk = attn.shape[1:3]

        # Create the mask if it isn't cached or doesn't match
        if (
                self.causal_mask is None or
                self.causal_mask.shape[0] != nq or
                self.causal_mask.shape[1] != nk or
                self.causal_mask.device != attn.device
        ):
            self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)

        # Return from the cache, with broadcast dimensions for batch and heads
        return self.causal_mask[None, :, :, None]
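To see what this mask looks like (a standalone illustration with made-up sizes), take two new queries attending over four keys, i.e. two cached positions followed by the two current ones; only strictly-future keys are masked:

    _attn_like = torch.zeros(1, 2, 4, 1)  # [batch, nq, nk, heads]
    _mask = torch.triu(_attn_like.new_ones([2, 4], dtype=torch.bool), 1 + 4 - 2)
    print(_mask)
    # tensor([[False, False, False,  True],
    #         [False, False, False, False]])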
`x` has shape `[batch_size, seq_len, n_hidden]`.

    def forward(self, x: torch.Tensor):
        # Get query, key and value embeddings (all concatenated).
        # The last dimension changes from `n_hidden` to `3 * n_hidden`.
        qkv = self.qkv_lin(x)

        # Split into heads by changing the shape to `[batch_size, seq_len, n_heads, 3 * d_k]`
        qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)
        # Split into query, key and value, each of shape `[batch_size, seq_len, n_heads, d_k]`
        q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)

        # If we are caching the states of previous tokens
        if get_cache().get('use_cache', False):
            # Get the state ids. We use them to retrieve the previous state and store the next state.
            prev_state_id, next_state_id = get_cache().get('state_ids')
            # If there is a cached state
            if prev_state_id is not None:
                # Get the past keys and values. These have shape `[batch_size, prev_seq_len, n_heads, d_k]`
                k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')
                # Offset of the current embeddings
                offset = k_past.shape[1]

                # Add RoPE embeddings
                q = self.rope(q, offset=offset)
                k = self.rope(k, offset=offset)

                # Concatenate with the past
                k = torch.cat([k_past, k], dim=1)
                v = torch.cat([v_past, v], dim=1)
            else:
                # Add RoPE embeddings
                q = self.rope(q)
                k = self.rope(k)

            # Save the current state
            get_cache().push(f'attn_kv_{next_state_id}', (k, v))
        else:
            # No caching; just add RoPE embeddings
            q = self.rope(q)
            k = self.rope(k)

        # Use flash attention if it is available and applicable
        if self.flash_attention is not None and q.shape[1] == k.shape[1] and q.shape[-1] <= 128:
            output = self.compute_flash_attention(q, k, v)
        # Otherwise, use normal attention
        else:
            output = self.compute_attention(q, k, v)

        # Reshape from `[batch_size, seq_len, n_heads, d_k]` to `[batch_size, seq_len, n_hidden]`
        output = output.reshape(*x.shape)

        # Final linear layer
        return self.output(output)
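A standalone shape walkthrough of the reshaping above (all sizes are made up): with 4 heads and `d_k = 16`, the concatenated projection splits cleanly into query, key and value.

    _qkv = torch.randn(2, 5, 3 * 64)                   # output of a hypothetical qkv projection
    _qkv = _qkv.view(2, 5, 4, -1)                      # [batch, seq, n_heads, 3 * d_k]
    _q, _k, _v = torch.split(_qkv, _qkv.shape[-1] // 3, dim=-1)
    assert _q.shape == (2, 5, 4, 16)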
    def compute_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        # Stack them into shape `[batch_size, seq_len, 3, n_heads, d_k]`
        qkv = torch.stack((q, k, v), dim=2)
        d_k = qkv.shape[-1]
        # Pad the head size up to 32, 64, or 128 for FlashAttention
        if d_k <= 32:
            pad = 32 - d_k
        elif d_k <= 64:
            pad = 64 - d_k
        elif d_k <= 128:
            pad = 128 - d_k
        else:
            raise ValueError(f'Head size {d_k} too large for flash attention')

        if pad > 0:
            qkv = torch.cat((qkv, qkv.new_zeros(*qkv.shape[:-1], pad)), dim=-1)

        output, _ = self.flash_attention(qkv, causal=True)
        # The output has shape `[batch_size, seq_len, n_heads, d_k + padding]`; strip the padding
        output = output[:, :, :, :d_k]

        return output
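For example (an illustrative calculation, not from the model): a head size of 96 is padded with 128 - 96 = 32 zero features before calling FlashAttention, and the padding is sliced off the output afterwards.

    _qkv = torch.randn(2, 5, 3, 4, 96)
    _pad = 128 - _qkv.shape[-1]
    _qkv = torch.cat((_qkv, _qkv.new_zeros(*_qkv.shape[:-1], _pad)), dim=-1)
    assert _qkv.shape[-1] == 128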
    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        # Disable autocasting to fp16 for the attention computation
        with autocast(enabled=False):
            if q.dtype == torch.float16:
                # Convert to fp32 if the current dtype is fp16
                attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
            else:
                # Do not cast for bfloat16
                attn = torch.einsum('bihk,bjhk->bijh', q, k)

            # Scale attention
            attn = attn * self.scale

            # Get the causal mask
            mask = self._get_mask(attn)
            # Apply the mask
            attn.masked_fill_(mask, self.mask_fill)

            # Attention softmax
            attn = self.softmax(attn)

        # Get the attention-weighted values
        output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)

        return output
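A standalone sketch of the einsum shapes used above (all sizes invented): the attention matrix is laid out as `[batch, query, key, heads]`, which is why the softmax is taken over `dim=-2`.

    _q = torch.randn(2, 5, 4, 16)                    # [batch, i, heads, d_k]
    _k = torch.randn(2, 5, 4, 16)                    # [batch, j, heads, d_k]
    _v = torch.randn(2, 5, 4, 16)
    _attn = torch.einsum('bihk,bjhk->bijh', _q, _k)  # [2, 5, 5, 4]
    _attn = _attn.softmax(dim=-2)                    # normalize over the key dimension j
    _out = torch.einsum('bijh,bjhk->bihk', _attn, _v)
    assert _out.shape == (2, 5, 4, 16)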
class FFNLayer(nn.Module):

`n_hidden` is the size of the embeddings.

    def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
        super().__init__()

        if not d_ff:
            d_ff = n_hidden * 4

        # Expansion linear layer
        self.dense_h_h4 = nn.Linear(n_hidden, d_ff)
        # GELU activation
        self.activation = nn.GELU()
        # Contraction linear layer
        self.dense_h4_h = nn.Linear(d_ff, n_hidden)

`x` has shape `[batch_size, seq_len, n_hidden]`.

    def forward(self, x: torch.Tensor):
        x = self.dense_h_h4(x)
        x = self.activation(x)
        x = self.dense_h4_h(x)

        return x
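As a quick check of the defaults (illustrative sizes, using the class above): the hidden FFN width defaults to four times the embedding size, and the output shape matches the input.

    _ffn = FFNLayer(n_hidden=8)
    assert _ffn.dense_h_h4.out_features == 32
    assert _ffn(torch.randn(2, 3, 8)).shape == (2, 3, 8)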
class TransformerLayer(NeoXModule):
    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, *, is_flash_attention: bool = False):
        super().__init__()

        # Layer normalization before attention
        self.pre_ln_attn = nn.LayerNorm(n_hidden)
        # Layer normalization before the FFN
        self.pre_ln_ffn = nn.LayerNorm(n_hidden)

        # Attention layer
        self.attention = AttentionLayer(n_hidden, n_heads, is_flash_attention=is_flash_attention)
        # FFN layer
        self.ffn = FFNLayer(n_hidden)

`x` are the embeddings of shape `[batch_size, seq_len, n_hidden]`.

    def forward(self, x: torch.Tensor):
        # Residual connection
        residual = x
        # NeoX runs attention and the feedforward network in parallel
        attn = self.attention(self.pre_ln_attn(x))
        ffn = self.ffn(self.pre_ln_ffn(x))
        # Add them and the residual connection
        return attn + ffn + residual
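A small usage sketch of the parallel-residual block above (sizes are illustrative; this assumes `labml_nn` is installed so that `get_cache` resolves): the layer computes `x + attention(pre_ln_attn(x)) + ffn(pre_ln_ffn(x))` rather than applying attention and the FFN sequentially.

    _layer = TransformerLayer(n_hidden=64, n_heads=4)
    _x = torch.randn(2, 5, 64)
    assert _layer(_x).shape == (2, 5, 64)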
Code to load the checkpoint.

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load transformer layer'):
            # Attention output transform
            checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)

            # Attention query, key and value transform
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)

            # Layer norm before attention
            checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)

            # FFN first transform (expansion)
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)

            # FFN second transform (contraction)
            checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)

            # Layer norm before the FFN
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)
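The `merge_params_*` helpers combine the two tensor-parallel checkpoint shards `p1` and `p2`. As a conceptual sketch only (the real logic lives in `labml_nn.neox.checkpoint`), a dim-0 merge of a column-parallel weight amounts to concatenating the shards along the output dimension:

    # Hypothetical shard shapes for illustration
    _shard_1, _shard_2 = torch.randn(3, 4), torch.randn(3, 4)
    _merged = torch.cat([_shard_1, _shard_2], dim=0)  # dim-0 merge
    assert _merged.shape == (6, 4)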
class FinalNorm(NeoXModule):

`n_hidden` is the size of the embeddings.

    def __init__(self, n_hidden: int = 6_144):
        super().__init__()

        self.ln = nn.LayerNorm(n_hidden)

`x` are the embeddings of shape `[batch_size, seq_len, n_hidden]`.

    def forward(self, x: torch.Tensor):
        return self.ln(x)

Code to load the checkpoint.

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load final normalization layer'):
            checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)

Readout layer.

class ReadoutLayer(NeoXModule):

`n_hidden` is the size of the embeddings and `n_vocab` is the size of the vocabulary.

    def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
        super().__init__()

        self.linear = nn.Linear(n_hidden, n_vocab, bias=False)

`x` are the embeddings of shape `[batch_size, seq_len, n_hidden]`.

    def forward(self, x: torch.Tensor):
        return self.linear(x)

Code to load the checkpoint.

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load final linear layer'):
            checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)


class LayerGenerator:
    pre_created_layers: Dict[Any, Optional[NeoXModule]]

Generator to create the layers. The layers are generated in the same order as the checkpoints. It gives `None` when a layer is not available; we use the same layer indices as NeoX, and there are two transformation layers there that we don't need in our implementation.

* `n_vocab` is the number of tokens in the vocabulary
* `n_hidden` is the number of features in the embeddings
* `n_layers` is the number of transformer layers
* `n_heads` is the number of attention heads
* `filter_layers` is the set of layers to be used; all layers are used if it is `None`. This is used to test smaller versions of the model with fewer layers.
* `is_clone_layers` specifies whether to clone the transformer layers (a bit faster)
* `dtype` is the data type of the model
* `device` is the device of the model
* `is_llm_int8` specifies whether to use int8 quantization
* `llm_int8_threshold` is the threshold used to separate outlier features
* `is_flash_attention` specifies whether to use FlashAttention

    def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
                 n_layers: int = 44, n_heads: int = 64,
                 filter_layers: Optional[Set] = None,
                 is_clone_layers: bool = True,
                 dtype: torch.dtype = torch.float,
                 device: torch.device = torch.device('cpu'),
                 is_llm_int8: bool = False,
                 llm_int8_threshold: float = 6.0,
                 is_flash_attention: bool = False
                 ):
        # Use all layers if `filter_layers` is not specified
        if filter_layers is None:
            filter_layers = set(range(n_layers + 3))

        self.n_vocab = n_vocab
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.filter_layers = filter_layers
        self.is_clone_layers = is_clone_layers
        self.dtype = dtype
        self.device = device
        self.is_llm_int8 = is_llm_int8
        self.llm_int8_threshold = llm_int8_threshold
        self.is_flash_attention = is_flash_attention

        self.pre_created_layers = dict(
            transformer_layer=None,
        )

Move a newly created layer to the configured device and data type.

    def _prepare_layer(self, layer: NeoXModule):
        return layer.to(self.device, self.dtype)
This function applies layer transformations after loading the checkpoint; currently, it only applies int8 quantization.

* `layer` is the layer to prepare
* `is_llm_int8` specifies whether to use int8 quantization
* `device` is the device of the model
* `llm_int8_threshold` is the threshold used to separate outlier features

Returns the prepared layer.

    @torch.no_grad()
    def post_load_prepare(self, layer: NeoXModule, *,
                          is_llm_int8: bool = None,
                          device: torch.device = None,
                          llm_int8_threshold: float = None,
                          ):
        # Get the default values if not specified
        if is_llm_int8 is None:
            is_llm_int8 = self.is_llm_int8
        if device is None:
            device = self.device
        if llm_int8_threshold is None:
            llm_int8_threshold = self.llm_int8_threshold

        # Skip if not using int8 quantization
        if not is_llm_int8:
            return layer

        # Only convert the linear layers in the transformer layers
        if not isinstance(layer, TransformerLayer):
            return layer

        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear

        # Convert the linear layers
        with monit.section('Convert to int8'):
            layer.attention.output = make_llm_int8_linear(layer.attention.output,
                                                          device=device,
                                                          threshold=llm_int8_threshold)
            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
                                                           device=device,
                                                           threshold=llm_int8_threshold)
            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
                                                        device=device,
                                                        threshold=llm_int8_threshold)
            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
                                                        device=device,
                                                        threshold=llm_int8_threshold)

        return layer
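In normal use the quantization settings are passed to the constructor rather than by calling `post_load_prepare` directly, since `load()` calls it with the stored defaults. A hypothetical configuration (the device and dtype below are illustrative choices):

    _int8_generator = LayerGenerator(is_llm_int8=True, llm_int8_threshold=6.0,
                                     dtype=torch.float16, device=torch.device('cuda:0'))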
Create a layer, or clone it from a cached template when `is_clone_layers` is set.

    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
        # Create a fresh layer if cloning is disabled
        if not self.is_clone_layers:
            return self._prepare_layer(creator())

        # Create and prepare the template on first use
        if self.pre_created_layers[name] is None:
            self.pre_created_layers[name] = self._prepare_layer(creator())

        # Clone the template
        layer = copy.deepcopy(self.pre_created_layers[name])
        return layer
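The cloning trick avoids re-initializing (and re-transferring) parameters for each identically shaped transformer layer, since the parameters are overwritten from the checkpoint anyway. A minimal sketch of the idea with a stand-in module:

    _template = nn.Linear(8, 8)
    _clone = copy.deepcopy(_template)
    assert _clone is not _template
    assert torch.equal(_clone.weight, _template.weight)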
    def _create_transformer_layer(self):
        return self._create_and_cache_layer(
            'transformer_layer',
            lambda: TransformerLayer(self.n_hidden, self.n_heads, is_flash_attention=self.is_flash_attention)
        )

    def _create_embedding_layer(self):
        return Embedding(self.n_vocab, self.n_hidden)

    def _create_final_norm_layer(self):
        return FinalNorm(self.n_hidden)

    def _create_readout_layer(self):
        return ReadoutLayer(self.n_hidden, self.n_vocab)

Generator that yields each layer along with the names of the checkpoint files that hold its parameters.

    @torch.no_grad()
    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
        # Embedding layer
        if 0 in self.filter_layers:
            with monit.section('Embedding layer'):
                layer = self._prepare_layer(self._create_embedding_layer())
            yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')

        # Transformer layers
        for i in range(self.n_layers):
            if i + 1 in self.filter_layers:
                with monit.section(f'Transformer Layer {i}'):
                    yield self._create_transformer_layer(), \
                          (f'layer_{i + 2 :02d}-model_00-model_states.pt',
                           f'layer_{i + 2 :02d}-model_01-model_states.pt')

        # Final normalization layer
        if self.n_layers + 1 in self.filter_layers:
            with monit.section('Final norm layer'):
                layer = self._prepare_layer(self._create_final_norm_layer())
            yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')

        # Readout layer
        if self.n_layers + 2 in self.filter_layers:
            with monit.section('Readout layer'):
                layer = self._prepare_layer(self._create_readout_layer())
            yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')

        # Clear the cached templates
        for k in self.pre_created_layers.keys():
            self.pre_created_layers[k] = None
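The indices in `filter_layers` follow the layer order: `0` is the embedding, `1` to `n_layers` are the transformer layers, `n_layers + 1` is the final norm, and `n_layers + 2` is the readout. Transformer layer `i` maps to checkpoint files with index `i + 2` (the checkpoint numbering skips some indices, as noted in the class description). An illustrative check of the naming:

    _i = 0
    assert f'layer_{_i + 2 :02d}-model_00-model_states.pt' == 'layer_02-model_00-model_states.pt'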
Total number of layers: the embedding, the transformer layers, the final norm, and the readout.

    @property
    def total_layers(self):
        return self.n_layers + 3

Generator that loads the layers: each layer is created, its parameters are loaded from the checkpoint files, and it is passed through `post_load_prepare` before being yielded.

    @torch.no_grad()
    def load(self) -> Generator[NeoXModule, None, None]:
        with monit.section("Layers"):
            for i, (layer, files) in enumerate(self.get_layers()):
                if files is not None:
                    layer.load_state(*checkpoint.load_checkpoint_files(files))

                layer = self.post_load_prepare(layer)

                monit.progress(min(0.99, (i + 1) / self.total_layers))
                yield layer
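A usage sketch of the generator (an illustration, not part of this file): loading reads the large NeoX checkpoint shards, so the device, dtype, and `filter_layers` values below are assumptions chosen for the example.

    # Build the full model as a sequence of layers
    _generator = LayerGenerator(is_clone_layers=True,
                                dtype=torch.float16,
                                device=torch.device('cuda:0'))
    _model = nn.Sequential(*_generator.load())

    # Or build a small test model with only the embedding and the first two transformer layers
    _tiny = nn.Sequential(*LayerGenerator(filter_layers={0, 1, 2}, dtype=torch.float32).load())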