Here is the code for the layers of the GPT-NeoX model and the code to load the 20B checkpoint.
The load_state method in each layer loads that layer's checkpoint. The checkpoint loading helpers are in checkpoint.py.
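The 20B checkpoint stores each layer as two model-parallel shards (the model_00 and model_01 files listed later), which is why load_state receives two state dictionaries p1 and p2. As a rough sketch of what a dim-0 merge helper might do (a hypothetical illustration, not the actual code in checkpoint.py):

import torch

def merge_params_dim_0_sketch(param: torch.nn.Parameter, key: str,
                              p1: dict, p2: dict):
    # Hypothetical illustration: concatenate the two model-parallel shards
    # along dimension 0 and copy the result into the module parameter.
    with torch.no_grad():
        param.copy_(torch.cat([p1[key], p2[key]], dim=0).to(param.dtype))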
import copy
import math
from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple

import torch
from torch import nn
from torch.cuda.amp import autocast

from labml import monit
from labml_nn.neox import checkpoint
from labml_nn.neox.utils.cache import get_cache


class NeoXModule(nn.Module):
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        pass


class Embedding(NeoXModule):

n_vocab is the size of the vocabulary and n_hidden is the size of the embeddings.

    def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
        super().__init__()

        self.emb = nn.Embedding(n_vocab, n_hidden)

x are the token ids of shape [batch_size, seq_len].

    def forward(self, x: torch.Tensor):
        return self.emb(x)

Code to load the checkpoint:

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load embedding layer'):
            checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)

GPT-NeoX uses rotary positional embeddings (RoPE).
We have an annotated implementation of RoPE here with more notes on the theory.
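As a quick reminder, in the convention used by rotate_half below, feature $i$ is paired with feature $i + d/2$ and each pair is rotated by an angle proportional to the token position $m$:

$$
\begin{aligned}
\hat{x}_{m,i} &= x_{m,i}\cos(m\theta_i) - x_{m,i+d/2}\sin(m\theta_i) \\
\hat{x}_{m,i+d/2} &= x_{m,i+d/2}\cos(m\theta_i) + x_{m,i}\sin(m\theta_i) \\
\theta_i &= \mathrm{base}^{-2i/d}, \qquad d = d_{\text{rope}}
\end{aligned}
$$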
class RoPE(nn.Module):

d_rope is the number of features for RoPE embeddings.
base is the base for $\theta_i = \mathrm{base}^{-2i/d}$; it defaults to 10,000.

    def __init__(self, d_rope: int, base: float = 10_000.):
        super().__init__()

        # To store $\theta_i$ for the features
        self.theta = None
        # Cache $\cos m\theta_i$ and $\sin m\theta_i$
        self.cos_cached = None
        self.sin_cached = None

        # Base for $\theta_i$
        self.base = base
        # Number of features for RoPE
        self.d_rope = d_rope

    @staticmethod
    def rotate_half(x: torch.Tensor):
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)
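For a concrete picture of the pairing, here is the same operation applied by hand to a 4-feature vector (a tiny standalone illustration, not part of the model code); feature i is paired with feature i + d/2:

import torch

x = torch.tensor([1., 2., 3., 4.])
x1, x2 = x[..., :2], x[..., 2:]
print(torch.cat((-x2, x1), dim=-1))  # tensor([-3., -4.,  1.,  2.])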

x has shape [..., seq, n_heads, d_k].
offset is the starting position of x. This is non-zero when we have cached the keys and queries of previous positions.

    def forward(self, x: torch.Tensor, offset: int = 0):
        # Get the actual sequence length
        seq_len = x.shape[-3] + offset

        # Initialize $\theta$
        if self.theta is None:
            theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
            self.theta = theta.to(x.device).to(x.dtype)

        # Initialize and cache $\cos m\theta_i$ and $\sin m\theta_i$ values
        if (
                self.cos_cached is None or
                seq_len > self.cos_cached.shape[0] or
                self.cos_cached.device != x.device or
                self.cos_cached.dtype != x.dtype
        ):
            # Get position indexes $s$
            seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
            # $s \theta_i$
            idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)
            # Concatenate so that we have the angles for both halves of the features
            idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)
            # Calculate $\cos m\theta_i$ and $\sin m\theta_i$ in fp32
            with autocast(enabled=False):
                idx_theta2 = idx_theta2.float()
                # Add the head dimension
                self.cos_cached = idx_theta2.cos()[:, None, :]
                self.sin_cached = idx_theta2.sin()[:, None, :]

            # Cache them in the dtype of x
            self.cos_cached = self.cos_cached.to(x.dtype)
            self.sin_cached = self.sin_cached.to(x.dtype)

        # Split the features. We apply RoPE to only the first d_rope features.
        x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]

        # Get the sin and cos values from the cache
        cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]

        # Apply the rotation
        x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)

        # Concatenate with the features that didn't get RoPE embeddings
        return torch.cat((x_rope, x_pass), dim=-1)
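A minimal sketch of using this module on its own (shapes only; the random tensor is just for illustration):

import torch

rope = RoPE(d_rope=16)
x = torch.randn(1, 10, 4, 64)   # [batch, seq, n_heads, d_k]
y = rope(x)
assert y.shape == x.shape       # RoPE rotates only the first 16 features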
class AttentionLayer(nn.Module):

n_hidden is the number of features in the embeddings.
n_heads is the number of attention heads.
rope_percentage is the fraction of features to which RoPE embeddings are applied.
mask_fill is the masking fill value for the attention matrix.

    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
                 mask_fill: float = -10_000.0):
        super().__init__()

        self.n_heads = n_heads
        self.mask_fill = mask_fill

        # Linear layer for query, key and value
        self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)
        # Final linear layer
        self.output = nn.Linear(n_hidden, n_hidden)

        # Number of features per head
        d_k = n_hidden // n_heads
        # RoPE embedding module
        self.rope = RoPE(int(d_k * rope_percentage))

        # Attention scaling factor
        self.scale = 1 / math.sqrt(d_k)

        # To cache the causal mask
        self.causal_mask = None

        # Attention softmax module
        self.softmax = nn.Softmax(dim=-2)
    def _get_mask(self, attn: torch.Tensor):
        # Query and key lengths
        nq, nk = attn.shape[1:3]

        # Create the causal mask if it is not cached
        if (
                self.causal_mask is None or
                self.causal_mask.shape[0] != nq or
                self.causal_mask.shape[1] != nk or
                self.causal_mask.device != attn.device
        ):
            self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)

        # Return from the cache
        return self.causal_mask[None, :, :, None]
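To see what this mask looks like when the keys are longer than the queries (as happens with the KV cache below), here is a tiny sketch with nq = 2 and nk = 4; True entries are the future positions that get masked out:

import torch

nq, nk = 2, 4
mask = torch.triu(torch.ones(nq, nk, dtype=torch.bool), 1 + nk - nq)
print(mask)
# tensor([[False, False, False,  True],
#         [False, False, False, False]])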

x has shape [batch_size, seq_len, n_hidden].

    def forward(self, x: torch.Tensor):
        # Get query, key and value embeddings (all concatenated).
        # The last dimension size changes from n_hidden to 3 * n_hidden.
        qkv = self.qkv_lin(x)

        # Split into heads by changing the shape to [batch_size, seq_len, n_heads, 3 * d_k]
        qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)

        # Split into query, key and value, each of shape [batch_size, seq_len, n_heads, d_k]
        q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)

        # If we are caching the states of previous tokens
        if get_cache().get('use_cache', False):
            # Get the state ids. We use them to retrieve the previous states and store the next states.
            prev_state_id, next_state_id = get_cache().get('state_ids')

            # If there is a cached state
            if prev_state_id is not None:
                # Get the past keys and values.
                # These will have shape [batch_size, prev_seq_len, n_heads, d_k]
                k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')

                # Offset of the current embeddings
                offset = k_past.shape[1]

                # Add RoPE embeddings
                q = self.rope(q, offset=offset)
                k = self.rope(k, offset=offset)

                # Concatenate the past keys and values
                k = torch.cat([k_past, k], dim=1)
                v = torch.cat([v_past, v], dim=1)
            else:
                # Add RoPE embeddings
                q = self.rope(q)
                k = self.rope(k)

            # Save the current state
            get_cache().push(f'attn_kv_{next_state_id}', (k, v))
        else:
            # No cache; simply add RoPE embeddings
            q = self.rope(q)
            k = self.rope(k)
        # Disable auto-casting to fp16 for the attention computation
        with autocast(enabled=False):
            if q.dtype == torch.float16:
                # Convert to fp32 if the current dtype is fp16
                attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
            else:
                # Do not cast for bfloat16
                attn = torch.einsum('bihk,bjhk->bijh', q, k)

            # Scale attention
            attn = attn * self.scale

            # Get the causal mask
            mask = self._get_mask(attn)

            # Apply the mask
            attn.masked_fill_(mask, self.mask_fill)

            # Attention softmax
            attn = self.softmax(attn)

        # Get the attention weighted values
        output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)

        # Reshape from [batch_size, seq_len, n_heads, d_k] to [batch_size, seq_len, n_hidden]
        output = output.reshape(*x.shape)

        # Final linear layer
        return self.output(output)
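The einsum subscripts can be hard to read, so here is a shape-only sketch of the two contractions above (purely illustrative):

import torch

batch, seq, n_heads, d_k = 2, 5, 4, 8
q = torch.randn(batch, seq, n_heads, d_k)
k = torch.randn(batch, seq, n_heads, d_k)
v = torch.randn(batch, seq, n_heads, d_k)

attn = torch.einsum('bihk,bjhk->bijh', q, k)    # [batch, seq_q, seq_k, n_heads]
out = torch.einsum('bijh,bjhk->bihk', attn, v)  # [batch, seq_q, n_heads, d_k]
assert attn.shape == (batch, seq, seq, n_heads)
assert out.shape == (batch, seq, n_heads, d_k)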
class FFNLayer(nn.Module):

n_hidden is the embedding size. d_ff defaults to 4 * n_hidden.

    def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
        super().__init__()

        if not d_ff:
            d_ff = n_hidden * 4

        # Expansion linear layer
        self.dense_h_h4 = nn.Linear(n_hidden, d_ff)
        # GELU activation
        self.activation = nn.GELU()
        # Contraction linear layer
        self.dense_h4_h = nn.Linear(d_ff, n_hidden)

x has shape [batch_size, seq_len, n_hidden].

    def forward(self, x: torch.Tensor):
        x = self.dense_h_h4(x)
        x = self.activation(x)
        x = self.dense_h4_h(x)

        return x
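In equation form, this is the usual two-layer MLP with a GELU activation (written here only to summarize the code above):

$$
\mathrm{FFN}(x) = W_2 \, \mathrm{GELU}(W_1 x + b_1) + b_2,
\qquad W_1 \in \mathbb{R}^{d_{ff} \times n_{hidden}},\; W_2 \in \mathbb{R}^{n_{hidden} \times d_{ff}}
$$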
class TransformerLayer(NeoXModule):

n_hidden is the embedding size and n_heads is the number of heads.
Our implementation doesn't include dropout.

    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64):
        super().__init__()

        # Layer normalization before attention
        self.pre_ln_attn = nn.LayerNorm(n_hidden)
        # Layer normalization before the FFN
        self.pre_ln_ffn = nn.LayerNorm(n_hidden)

        # Attention layer
        self.attention = AttentionLayer(n_hidden, n_heads)
        # FFN layer
        self.ffn = FFNLayer(n_hidden)

x are the embeddings of shape [batch_size, seq_len, n_hidden].

    def forward(self, x: torch.Tensor):
        # Residual connection
        residual = x

        # NeoX runs the attention and feedforward networks in parallel
        attn = self.attention(self.pre_ln_attn(x))
        ffn = self.ffn(self.pre_ln_ffn(x))

        # Add them and the residual connection
        return attn + ffn + residual
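That is, unlike the standard sequential residual block, the layer computes (a summary of the code above, not an alternative formulation):

$$
y = x + \mathrm{Attn}(\mathrm{LN}_1(x)) + \mathrm{FFN}(\mathrm{LN}_2(x))
$$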

Code to load the checkpoint:

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load transformer layer'):
            # Attention output transform
            checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)

            # Attention query, key and value transform
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)

            # Layer norm before attention
            checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)

            # FFN first transform (expansion)
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)

            # FFN second transform (contraction)
            checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)

            # Layer norm before FFN
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)
class FinalNorm(NeoXModule):

n_hidden is the embedding size.

    def __init__(self, n_hidden: int = 6_144):
        super().__init__()

        self.ln = nn.LayerNorm(n_hidden)

x are the embeddings of shape [batch_size, seq_len, n_hidden].

    def forward(self, x: torch.Tensor):
        return self.ln(x)

Code to load the checkpoint:

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load final normalization layer'):
            checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)

Readout layer
class ReadoutLayer(NeoXModule):

n_hidden is the embedding size and n_vocab is the size of the vocabulary.

    def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
        super().__init__()

        self.linear = nn.Linear(n_hidden, n_vocab, bias=False)

x are the embeddings of shape [batch_size, seq_len, n_hidden].

    def forward(self, x: torch.Tensor):
        return self.linear(x)

Code to load the checkpoint:

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load final linear layer'):
            checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
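To make the checkpoint file names in get_layers() below easier to follow, here is the layout they imply (a reference sketch derived from the code, not an official manifest):

# Checkpoint file prefixes; each has '-model_00-model_states.pt' and
# '-model_01-model_states.pt' shards.
checkpoint_layout = {
    'embedding': 'layer_00',
    'transformer_layers': [f'layer_{i + 2:02d}' for i in range(44)],  # layer_02 ... layer_45
    'final_norm': 'layer_47',
    'readout': 'layer_48',
}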
class LayerGenerator:
    pre_created_layers: Dict[Any, Optional[NeoXModule]]

The layers are generated in the same order as the checkpoints.
It gives None when a layer is not available; we use the same layer indices as NeoX, and there are two transformation layers in the checkpoint that we don't need in our implementation.

n_vocab is the number of tokens in the vocabulary.
n_hidden is the number of features in the embeddings.
n_layers is the number of transformer layers.
n_heads is the number of attention heads.
filter_layers is the set of layers to be used. All layers will be used if it is None. This is used to test smaller versions of the model with fewer layers.
is_clone_layers specifies whether to clone the transformer layers (a bit faster).
dtype is the data type of the model.
device is the device of the model.
is_llm_int8 specifies whether to use int8 quantization.
llm_int8_threshold is the threshold used to separate outlier features.

    def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
                 n_layers: int = 44, n_heads: int = 64,
                 filter_layers: Optional[Set] = None,
                 is_clone_layers: bool = True,
                 dtype: torch.dtype = torch.float,
                 device: torch.device = torch.device('cpu'),
                 is_llm_int8: bool = False,
                 llm_int8_threshold: float = 6.0,
                 ):
        if filter_layers is None:
            filter_layers = set(range(n_layers + 3))

        self.n_vocab = n_vocab
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.filter_layers = filter_layers
        self.is_clone_layers = is_clone_layers
        self.dtype = dtype
        self.device = device
        self.is_llm_int8 = is_llm_int8
        self.llm_int8_threshold = llm_int8_threshold

        self.pre_created_layers = dict(
            transformer_layer=None,
        )

We move the layer to the device and convert it to the correct data type.
layer is the layer to prepare. Returns the prepared layer.

    def _prepare_layer(self, layer: NeoXModule):
        return layer.to(self.device, self.dtype)

This function implements layer transformations after loading the checkpoint.
Currently, it only applies int8 quantization.

layer is the layer to prepare.
is_llm_int8 specifies whether to use int8 quantization.
device is the device of the model.
llm_int8_threshold is the threshold used to separate outlier features.
Returns the prepared layer.

    @torch.no_grad()
    def post_load_prepare(self, layer: NeoXModule, *,
                          is_llm_int8: bool = None,
                          device: torch.device = None,
                          llm_int8_threshold: float = None,
                          ):
        # Get default values if not specified
        if is_llm_int8 is None:
            is_llm_int8 = self.is_llm_int8
        if device is None:
            device = self.device
        if llm_int8_threshold is None:
            llm_int8_threshold = self.llm_int8_threshold

        # Skip if not using int8 quantization
        if not is_llm_int8:
            return layer

        # Only convert the linear layers in the transformer layers
        if not isinstance(layer, TransformerLayer):
            return layer

        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear

        # Convert the linear layers
        with monit.section('Convert to int8'):
            layer.attention.output = make_llm_int8_linear(layer.attention.output,
                                                          device=device,
                                                          threshold=llm_int8_threshold)
            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
                                                           device=device,
                                                           threshold=llm_int8_threshold)
            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
                                                        device=device,
                                                        threshold=llm_int8_threshold)
            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
                                                        device=device,
                                                        threshold=llm_int8_threshold)

        return layer

Copying cached layers is faster than initializing new layers because it takes time to initialize parameters.
name is the name of the layer.
creator is the function to create the layer.
Returns the created layer or a copy of the cached layer.

    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
        if not self.is_clone_layers:
            return self._prepare_layer(creator())

        if self.pre_created_layers[name] is None:
            self.pre_created_layers[name] = self._prepare_layer(creator())

        layer = copy.deepcopy(self.pre_created_layers[name])
        return layer

    def _create_transformer_layer(self):
        return self._create_and_cache_layer(
            'transformer_layer',
            lambda: TransformerLayer(self.n_hidden, self.n_heads)
        )

    def _create_embedding_layer(self):
        return Embedding(self.n_vocab, self.n_hidden)

    def _create_final_norm_layer(self):
        return FinalNorm(self.n_hidden)

    def _create_readout_layer(self):
        return ReadoutLayer(self.n_hidden, self.n_vocab)

    @torch.no_grad()
    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
        # Embedding layer
        if 0 in self.filter_layers:
            with monit.section('Embedding layer'):
                layer = self._prepare_layer(self._create_embedding_layer())
            yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')

        # Transformer layers
        for i in range(self.n_layers):
            # Transformer layer
            if i + 1 in self.filter_layers:
                with monit.section(f'Transformer Layer {i}'):
                    yield self._create_transformer_layer(), \
                          (f'layer_{i + 2 :02d}-model_00-model_states.pt',
                           f'layer_{i + 2 :02d}-model_01-model_states.pt')

        # Final normalization layer
        if self.n_layers + 1 in self.filter_layers:
            with monit.section('Final norm layer'):
                layer = self._prepare_layer(self._create_final_norm_layer())
            yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')

        # Readout layer
        if self.n_layers + 2 in self.filter_layers:
            with monit.section('Readout layer'):
                layer = self._prepare_layer(self._create_readout_layer())
            yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')

        # Free the cached layers
        for k in self.pre_created_layers.keys():
            self.pre_created_layers[k] = None

    @property
    def total_layers(self):
        # Total number of layers: embedding, transformer layers, final norm and readout
        return self.n_layers + 3

    @torch.no_grad()
    def load(self) -> Generator[NeoXModule, None, None]:
        with monit.section("Layers"):
            for i, (layer, files) in enumerate(self.get_layers()):
                if files is not None:
                    layer.load_state(*checkpoint.load_checkpoint_files(files))

                layer = self.post_load_prepare(layer)

                monit.progress(min(0.99, (i + 1) / self.total_layers))
                yield layer
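Finally, a minimal usage sketch of the generator; the import path and device are assumptions, so adjust them to your setup:

import torch
from torch import nn

# Assumed import path for this file; adjust if your layout differs.
from labml_nn.neox.model import LayerGenerator

generator = LayerGenerator(
    is_clone_layers=True,
    dtype=torch.float16,
    device=torch.device('cuda:0'),
)

# load() yields the layers one at a time with their checkpoint weights loaded,
# so the full model never has to be materialized twice in memory.
model = nn.Sequential(*generator.load())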