これは、GPT-Neoxモデルのレイヤー用のコードと20Bのチェックポイントをロードするコードです。
load_state
レイヤー内のメソッドは、そのレイヤーのチェックポイントをロードします。チェックポイントロードヘルパーがオンになっています checkpoint.py
16import copy
17import math
18from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple
19
20import torch
21from torch import nn
22from torch.cuda.amp import autocast
23
24from labml import monit, logger
25from labml.logger import Text
26from labml_nn.neox import checkpoint
27from labml_nn.neox.utils.cache import get_cache30class NeoXModule(nn.Module):31    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
32        pass35class Embedding(NeoXModule):n_vocab
ボキャブラリーの大きさですn_hidden
は埋め込みのサイズです42    def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):47        super().__init__()
48
49        self.emb = nn.Embedding(n_vocab, n_hidden)x
形状のトークンIDです [batch_size, seq_len]
51    def forward(self, x: torch.Tensor):55        return self.emb(x)チェックポイントをロードするコード
57    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):61        with monit.section('Load embedding layer'):
62            checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)65class RoPE(nn.Module):d_rope
RoPE 埋め込みの機能の数ですbase
がの基底で、デフォルトは 75    def __init__(self, d_rope: int, base: float = 10_000.):80        super().__init__()機能用に保存するには
83        self.theta = Noneキャッシュと
85        self.cos_cached = None
86        self.sin_cached = Noneのベース
89        self.base = baseRoPE の機能の数
91        self.d_rope = d_rope93    @staticmethod
94    def rotate_half(x: torch.Tensor):100        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
101        return torch.cat((-x2, x1), dim=-1)x
形がある [..., seq, n_heads, d_k]
offset
x
の開始位置です。これは、以前のポジションのキーとクエリをキャッシュしたときです103    def forward(self, x: torch.Tensor, offset: int = 0):実際のシーケンス長を取得
111        seq_len = x.shape[-3] + offset[初期化]
114        if self.theta is None:116            theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
117            self.theta = theta.to(x.device).to(x.dtype)初期化とキャッシュ
120        if (
121                self.cos_cached is None or
122                seq_len > self.cos_cached.shape[1] or
123                self.cos_cached.device != x.device or
124                self.cos_cached.dtype != x.dtype
125        ):位置インデックスを取得
127            seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)129            idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)133            idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)計算して fp32 で
136            with autocast(enabled=False):
137                idx_theta2 = idx_theta2.float()頭部寸法を追加
139                self.cos_cached = idx_theta2.cos()[:, None, :]
140                self.sin_cached = idx_theta2.sin()[:, None, :]それらをキャッシュする
143            self.cos_cached = self.cos_cached.to(x.dtype)
144            self.sin_cached = self.sin_cached.to(x.dtype)機能を分割してください。RoPE d_rope
 は機能にのみ適用されます
147        x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]キャッシュから sin と cos の値を取得
150        cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]162        x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)RoPe 埋め込みに対応していなかった機能との連携
165        return torch.cat((x_rope, x_pass), dim=-1)168class AttentionLayer(nn.Module):n_hidden
埋め込みに含まれる機能の数n_heads
アテンション・ヘッドの数rope_percentage
RoPe 埋め込みを追加する機能の割合mask_fill
アテンション・マトリックスのマスキング・フィル値is_flash_attention
フラッシュアテンションを使用するかどうかを指定します173    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
174                 mask_fill: float = -10_000.0, *, is_flash_attention: bool = False):183        super().__init__()
184
185        self.n_heads = n_heads
186        self.mask_fill = mask_fillクエリ、キー、値の線形レイヤー
189        self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)最終線形レイヤー
191        self.output = nn.Linear(n_hidden, n_hidden)ヘッドあたりの機能数
194        d_k = n_hidden // n_headsRoPE 埋め込みモジュール
196        self.rope = RoPE(int(d_k * rope_percentage))アテンションスケーリングファクター
199        self.scale = 1 / math.sqrt(d_k)因果マスクをキャッシュするには
202        self.causal_mask = Noneアテンションソフトマックスモジュール
205        self.softmax = nn.Softmax(dim=-2)208        if is_flash_attention:
209            try:
210                from flash_attn.flash_attention import FlashAttention
211                self.flash_attention = FlashAttention()
212            except ImportError:
213                logger.log('Install flash attention github.com/HazyResearch/flash-attention. '
214                           'Falling back to normal attention', Text.warning)
215                self.flash_attention = None
216        else:
217            self.flash_attention = None219    def _get_mask(self, attn: torch.Tensor):クエリとキーの長さ
227        nq, nk = attn.shape[1:3]マスク作成
230        if (
231                self.causal_mask is None or
232                self.causal_mask.shape[0] != nq or
233                self.causal_mask.shape[1] != nk or
234                self.causal_mask.device != attn.device
235        ):
236            self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)キャッシュから戻る
239        return self.causal_mask[None, :, :, None]x
形がある [batch_size, seq_len, n_hidden]
241    def forward(self, x: torch.Tensor):247        qkv = self.qkv_lin(x)形状を以下のように変更して頭部に分割します [batch_size, seq_len, n_heads, 3 * d_k]
250        qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)形状ごとにクエリ、キー、値に分割 [batch_size, seq_len, n_heads, 3 * d_k]
252        q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)以前のトークンの状態をキャッシュする場合
255        if get_cache().get('use_cache', False):ステート ID を取得します。前のステートを取得したり、次のステートを保存したりするのに使います。
257            prev_state_id, next_state_id = get_cache().get('state_ids')キャッシュがある場合
259            if prev_state_id is not None:過去のキーと値を取得します。これらは形になります [batch_size, prev_seq_len, n_heads, d_k]
261                k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')現在の埋め込みのオフセット
263                offset = k_past.shape[1]RoPe 埋め込みを追加
266                q = self.rope(q, offset=offset)
267                k = self.rope(k, offset=offset)過去を連結する
270                k = torch.cat([k_past, k], dim=1)
271                v = torch.cat([v_past, v], dim=1)
272            else:RoPe 埋め込みを追加
274                q = self.rope(q)
275                k = self.rope(k)現在の状態を保存する
278            get_cache().push(f'attn_kv_{next_state_id}', (k, v))
279        else:キャッシュなし-RoPE 埋め込みを追加するだけ
281            q = self.rope(q)
282            k = self.rope(k)フラッシュアテンションを使う
285        if self.flash_attention is not None and q.shape[1] == k.shape[1] and q.shape[-1] <= 128:
286            output = self.compute_flash_attention(q, k, v)それ以外の場合は、通常の注意を払ってください
288        else:
289            output = self.compute_attention(q, k, v)[batch_size, seq_len, n_heads, d_k] to
バッチサイズ、シーケンス番号、n_hidden `から形状を変更
292        output = output.reshape(*x.shape)最終線形レイヤー
295        return self.output(output)297    def compute_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):それらを積み重ねて形を整える [batch_size, seq_len, 3, n_heads, d_k]
299        qkv = torch.stack((q, k, v), dim=2)
300        d_k = qkv.shape[-1]
301        if d_k <= 32:
302            pad = 32 - d_k
303        elif d_k <= 64:
304            pad = 64 - d_k
305        elif d_k <= 128:
306            pad = 128 - d_k
307        else:
308            raise ValueError(f'Head size {d_k} too large for flash attention')
309
310        if pad > 0:
311            qkv = torch.cat((qkv, qkv.new_zeros(*qkv.shape[:-1], pad)), dim=-1)
312
313        output, _ = self.flash_attention(qkv, causal=True)出力は整形しています [batch_size, seq_len, n_heads, d_k + padding]
315        output = output[:, :, :, :d_k]
316
317        return output319    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):アテンション計算の fp16 への自動キャストを無効にする
321        with autocast(enabled=False):
322            if q.dtype == torch.float16:現在の dtype が fp16 の場合は fp32 に変換
324                attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
325            else:bfloatにはキャストしないでください
327                attn = torch.einsum('bihk,bjhk->bijh', q, k)スケールアテンション
330            attn = attn * self.scaleカジュアルマスクをゲット
333            mask = self._get_mask(attn)マスクを適用
335            attn.masked_fill_(mask, self.mask_fill)注意ソフトマックス
338            attn = self.softmax(attn)アテンション加重値を取得
341        output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)
342
343        return output346class FFNLayer(nn.Module):n_hidden
は埋め込みサイズ351    def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):355        super().__init__()
356
357        if not d_ff:
358            d_ff = n_hidden * 4拡張リニアレイヤー
361        self.dense_h_h4 = nn.Linear(n_hidden, d_ff)GELU アクティベーション
363        self.activation = nn.GELU()収縮線状層
365        self.dense_h4_h = nn.Linear(d_ff, n_hidden)x
形がある [batch_size, seq_len, n_hidden]
367    def forward(self, x: torch.Tensor):371        x = self.dense_h_h4(x)
372        x = self.activation(x)
373        x = self.dense_h4_h(x)
374
375        return x378class TransformerLayer(NeoXModule):n_hidden
は埋め込みサイズn_heads
は頭の数ですis_flash_attention
フラッシュアテンションを使用するかどうかを指定しますアウトの実装にはドロップアウトは含まれていません。
383    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, *, is_flash_attention: bool = False):392        super().__init__()注意前のレイヤー正規化
395        self.pre_ln_attn = nn.LayerNorm(n_hidden)FFN 前のレイヤー正規化
397        self.pre_ln_ffn = nn.LayerNorm(n_hidden)アテンションレイヤー
400        self.attention = AttentionLayer(n_hidden, n_heads, is_flash_attention=is_flash_attention)FFN レイヤー
402        self.ffn = FFNLayer(n_hidden)x
形が埋め込まれているものです [batch_size, seq_len, n_hidden]
404    def forward(self, x: torch.Tensor):残留接続
410        residual = xNeoXはアテンションネットワークとフィードフォワードネットワークを並行して実行します
412        attn = self.attention(self.pre_ln_attn(x))
413        ffn = self.ffn(self.pre_ln_ffn(x))それらと残りの接続を追加します
415        return attn + ffn + residualチェックポイントをロードするコード
417    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):421        with monit.section('Load transformer layer'):アテンション出力変換
423            checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
424            checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)アテンションクエリ、キー、値の変換
427            checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
428            checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)注目される前のレイヤーノルム
431            checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
432            checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)FFN 2 番目のトランスフォーム
435            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
436            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)FFN ファーストトランスフォーム
439            checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
440            checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)FFN 前のレイヤーノルム
443            checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
444            checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)447class FinalNorm(NeoXModule):n_hidden
は埋め込みサイズ452    def __init__(self, n_hidden: int = 6_144):456        super().__init__()
457
458        self.ln = nn.LayerNorm(n_hidden)x
形が埋め込まれているものです [batch_size, seq_len, n_hidden]
460    def forward(self, x: torch.Tensor):464        return self.ln(x)チェックポイントをロードするコード
466    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):470        with monit.section('Load final normalization layer'):
471            checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
472            checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)読み出し層
475class ReadoutLayer(NeoXModule):n_hidden
は埋め込みサイズn_vocab
ボキャブラリーの大きさです480    def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):485        super().__init__()
486
487        self.linear = nn.Linear(n_hidden, n_vocab, bias=False)x
形が埋め込まれているものです [batch_size, seq_len, n_hidden]
489    def forward(self, x: torch.Tensor):493        return self.linear(x)チェックポイントをロードするコード
495    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):499        with monit.section('Load final linear layer'):
500            checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)503class LayerGenerator:
504    pre_created_layers: Dict[Any, Optional[NeoXModule]]レイヤーはチェックポイントと同じ順序で生成されます。
None
レイヤーが使用できない場合に返されます。レイヤーインデックスをNeoXとして使用し、実装には必要のない変換レイヤーが2つあります。
n_vocab
ボキャブラリ内のトークンの数ですn_hidden
は埋め込み内のフィーチャの数ですn_layers
変圧器層の数ですn_heads
アテンション・ヘッドの数ですfilter_layers
使用するレイヤーのセットです。None の場合はすべてのレイヤーが使用されます。これは、レイヤー数の少ないモデルの小さいバージョンをテストする場合に使用しますis_clone_layers
トランスフォーマーレイヤーのクローンを作成するかどうかを指定します (少し速くなります)dtype
モデルのデータ型ですdevice
モデルのデバイスですis_llm_int8
int8 量子化を使用するかどうかを指定しますllm_int8_threshold
外れ値の特徴を分離するための閾値ですis_flash_attention
フラッシュアテンションを使用するかどうかを指定します506    def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
507                 n_layers: int = 44, n_heads: int = 64,
508                 filter_layers: Optional[Set] = None,
509                 is_clone_layers: bool = True,
510                 dtype: torch.dtype = torch.float,
511                 device: torch.device = torch.device('cpu'),
512                 is_llm_int8: bool = False,
513                 llm_int8_threshold: float = 6.0,
514                 is_flash_attention: bool = False
515                 ):538        if filter_layers is None:
539            filter_layers = set(range(n_layers + 3))
540
541        self.n_vocab = n_vocab
542        self.n_hidden = n_hidden
543        self.n_layers = n_layers
544        self.n_heads = n_heads
545        self.filter_layers = filter_layers
546        self.is_clone_layers = is_clone_layers
547        self.dtype = dtype
548        self.device = device
549        self.is_llm_int8 = is_llm_int8
550        self.llm_int8_threshold = llm_int8_threshold
551        self.is_flash_attention = is_flash_attention
552
553        self.pre_created_layers = dict(
554            transformer_layer=None,
555        )557    def _prepare_layer(self, layer: NeoXModule):566        return layer.to(self.device, self.dtype)この関数は、チェックポイントを読み込んだ後にレイヤー変換を実装します。
現在、適用されるのは int8 量子化のみです。
layer
準備するレイヤーですis_llm_int8
int8 量子化を使用するかどうかを指定しますdevice
モデルのデバイスですllm_int8_threshold
外れ値の特徴を分離するための閾値です準備したレイヤーを返します
568    @torch.no_grad()
569    def post_load_prepare(self, layer: NeoXModule, *,
570                          is_llm_int8: bool = None,
571                          device: torch.device = None,
572                          llm_int8_threshold: float = None,
573                          ):指定しない場合はデフォルト値を取得
591        if is_llm_int8 is None:
592            is_llm_int8 = self.is_llm_int8
593        if device is None:
594            device = self.device
595        if llm_int8_threshold is None:
596            llm_int8_threshold = self.llm_int8_thresholdint8 量子化を使用しない場合はスキップ
599        if not is_llm_int8:
600            return layerトランスレイヤーの線形レイヤーのみを変換します
603        if not isinstance(layer, TransformerLayer):
604            return layermake_llm_int8_linear
ユーティリティで定義されている用途。
607        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear線形レイヤーの変換
610        with monit.section('Convert to int8'):
611            layer.attention.output = make_llm_int8_linear(layer.attention.output,
612                                                          device=device,
613                                                          threshold=llm_int8_threshold)
614            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
615                                                           device=device,
616                                                           threshold=llm_int8_threshold)
617            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
618                                                        device=device,
619                                                        threshold=llm_int8_threshold)
620            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
621                                                        device=device,
622                                                        threshold=llm_int8_threshold)624        return layerキャッシュされたレイヤーのコピーは、パラメーターの初期化に時間がかかるため、新しいレイヤーを初期化するよりも高速です。
name
レイヤーの名前ですcreator
レイヤーを作成する関数です作成されたレイヤーまたはキャッシュされたレイヤーのコピーを返します
626    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):638        if not self.is_clone_layers:
639            return self._prepare_layer(creator())
640
641        if self.pre_created_layers[name] is None:
642            self.pre_created_layers[name] = self._prepare_layer(creator())
643
644        layer = copy.deepcopy(self.pre_created_layers[name])
645        return layer647    def _create_transformer_layer(self):
648        return self._create_and_cache_layer(
649            'transformer_layer',
650            lambda: TransformerLayer(self.n_hidden, self.n_heads, is_flash_attention=self.is_flash_attention)
651        )653    def _create_embedding_layer(self):
654        return Embedding(self.n_vocab, self.n_hidden)656    def _create_final_norm_layer(self):
657        return FinalNorm(self.n_hidden)659    def _create_readout_layer(self):
660        return ReadoutLayer(self.n_hidden, self.n_vocab)662    @torch.no_grad()
663    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:埋め込みレイヤー
668        if 0 in self.filter_layers:
669            with monit.section('Embedding layer'):
670                layer = self._prepare_layer(self._create_embedding_layer())
671            yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')トランスフォーマー層
674        for i in range(self.n_layers):変圧器層
676            if i + 1 in self.filter_layers:
677                with monit.section(f'Transformer Layer {i}'):
678                    yield self._create_transformer_layer(), \
679                          (f'layer_{i + 2 :02d}-model_00-model_states.pt',
680                           f'layer_{i + 2 :02d}-model_01-model_states.pt')最終正規化レイヤー
683        if self.n_layers + 1 in self.filter_layers:
684            with monit.section('Final norm layer'):
685                layer = self._prepare_layer(self._create_final_norm_layer())
686            yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')読み出し層
689        if self.n_layers + 2 in self.filter_layers:
690            with monit.section('Readout layer'):
691                layer = self._prepare_layer(self._create_readout_layer())
692            yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
693
694        for k in self.pre_created_layers.keys():
695            self.pre_created_layers[k] = None697    @property
698    def total_layers(self):702        return self.n_layers + 3704    @torch.no_grad()
705    def load(self) -> Generator[NeoXModule, None, None]:709        with monit.section("Layers"):
710            for i, (layer, files) in enumerate(self.get_layers()):
711                if files is not None:
712                    layer.load_state(*checkpoint.load_checkpoint_files(files))
713
714                layer = self.post_load_prepare(layer)
715
716                monit.progress(min(0.99, (i + 1) / self.total_layers))
717                yield layer