Here is the code for the layers of the GPT-NeoX model and the code to load the 20B checkpoint.

The `load_state` method in each layer loads the checkpoint of that layer. The checkpoint loading helpers are in `checkpoint.py`.
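The helpers in `checkpoint.py` are not reproduced here. As a rough mental model only (an assumption about their behaviour, not the actual implementation): the 20B checkpoint stores each layer as two model-parallel partitions, which is why `load_state` receives two state dicts `p1` and `p2`, and a "dim 0" merge would concatenate the two shards of a parameter along its first dimension.

```python
# Hypothetical sketch only - the real helpers live in labml_nn/neox/checkpoint.py
# and may differ from this.
import torch
from typing import Dict


def merge_dim_0_sketch(param: torch.nn.Parameter, key: str,
                       p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
    """Copy the concatenation of the two partition shards into `param`."""
    with torch.no_grad():
        merged = torch.cat([p1[key], p2[key]], dim=0)
        assert merged.shape == param.shape
        param.copy_(merged)
```

The `merge_params_dim_1`, `merge_params_duplicate` and `merge_params_sum` helpers used below presumably differ only in how the two shards are combined.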
```python
import copy
import math
from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple

import torch
from torch import nn
from torch.cuda.amp import autocast

from labml import monit, logger
from labml.logger import Text
from labml_nn.neox import checkpoint
from labml_nn.neox.utils.cache import get_cache


class NeoXModule(nn.Module):
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        pass
```
`n_vocab` is the size of the vocabulary and `n_hidden` is the size of the embeddings.

```python
class Embedding(NeoXModule):
    def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
        super().__init__()

        self.emb = nn.Embedding(n_vocab, n_hidden)
```

`x` are the token ids of shape `[batch_size, seq_len]`.

```python
    def forward(self, x: torch.Tensor):
        return self.emb(x)
```

Code to load the checkpoint:

```python
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load embedding layer'):
            checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)
```
GPT-NeoX uses rotary positional embeddings (RoPE). We have an annotated implementation of RoPE here, with more notes on the theory.
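Concretely, in this implementation the first `d_rope` features of each head are split into pairs $(x^{(i)}, x^{(i + d_{rope}/2)})$, and for a token at position $m$ each pair is rotated by the angle $m\theta_i$, where $\theta_i = \text{base}^{-2i/d_{rope}}$ for $i = 0, \dots, d_{rope}/2 - 1$:

$$\begin{aligned}
\hat{x}^{(i)} &= x^{(i)} \cos m\theta_i - x^{(i + d_{rope}/2)} \sin m\theta_i \\
\hat{x}^{(i + d_{rope}/2)} &= x^{(i + d_{rope}/2)} \cos m\theta_i + x^{(i)} \sin m\theta_i
\end{aligned}$$

The remaining $d_k - d_{rope}$ features are passed through unchanged.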
`d_rope` is the number of features for RoPE embeddings and `base` is the base for $\theta_i$, which defaults to 10,000.

```python
class RoPE(nn.Module):
    def __init__(self, d_rope: int, base: float = 10_000.):
        super().__init__()

        # To store the theta values for the features
        self.theta = None
        # Cache for the cos and sin values
        self.cos_cached = None
        self.sin_cached = None

        # Base for theta
        self.base = base
        # Number of features for RoPE
        self.d_rope = d_rope

    @staticmethod
    def rotate_half(x: torch.Tensor):
        # Rotate the feature pairs: [x1, x2] -> [-x2, x1]
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)
```
`x` has shape `[..., seq, n_heads, d_k]` and `offset` is the starting position of `x`. The offset is greater than zero when we have cached the keys and queries of previous positions.

```python
    def forward(self, x: torch.Tensor, offset: int = 0):
        # Get the actual sequence length
        seq_len = x.shape[-3] + offset

        # Initialize theta
        if self.theta is None:
            theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
            self.theta = theta.to(x.device).to(x.dtype)

        # Initialize and cache the cos and sin values
        if (
                self.cos_cached is None or
                seq_len > self.cos_cached.shape[1] or
                self.cos_cached.device != x.device or
                self.cos_cached.dtype != x.dtype
        ):
            # Get position indexes
            seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
            idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)
            idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)

            # Calculate cos and sin in fp32
            with autocast(enabled=False):
                idx_theta2 = idx_theta2.float()
                # Add head dimension
                self.cos_cached = idx_theta2.cos()[:, None, :]
                self.sin_cached = idx_theta2.sin()[:, None, :]

            # Cache them
            self.cos_cached = self.cos_cached.to(x.dtype)
            self.sin_cached = self.sin_cached.to(x.dtype)

        # Split the features. We apply RoPE to only the first d_rope features
        x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]

        # Get the sin and cos values from the cache
        cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]

        # Apply rotary embeddings
        x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)

        # Concatenate with the features that didn't get RoPE embeddings
        return torch.cat((x_rope, x_pass), dim=-1)
```
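A quick sanity check (not part of the model code, and assuming the `RoPE` class above is in scope): RoPE preserves the shape, leaves the non-RoPE features untouched, and leaves position 0 unrotated since the angle there is zero.

```python
import torch

rope = RoPE(d_rope=8)
x = torch.randn(1, 5, 2, 16)   # [batch, seq, heads, d_k]
y = rope(x)

print(y.shape)                 # torch.Size([1, 5, 2, 16])
# Features beyond d_rope pass through unchanged
assert torch.equal(y[..., 8:], x[..., 8:])
# Position 0 is rotated by angle 0, so it is unchanged as well
assert torch.allclose(y[:, 0], x[:, 0])
```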
* `n_hidden` is the number of features in the embeddings
* `n_heads` is the number of attention heads
* `rope_percentage` is the percentage of features to add RoPE embeddings to
* `mask_fill` is the masking fill value for the attention matrix
* `is_flash_attention` specifies whether to use FlashAttention

```python
class AttentionLayer(nn.Module):
    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
                 mask_fill: float = -10_000.0, *, is_flash_attention: bool = False):
        super().__init__()

        self.n_heads = n_heads
        self.mask_fill = mask_fill

        # Linear layer for query, key and value
        self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)
        # Final linear layer
        self.output = nn.Linear(n_hidden, n_hidden)

        # Number of features per head
        d_k = n_hidden // n_heads
        # RoPE embedding module
        self.rope = RoPE(int(d_k * rope_percentage))

        # Attention scaling factor
        self.scale = 1 / math.sqrt(d_k)

        # To cache the causal mask
        self.causal_mask = None

        # Attention softmax module
        self.softmax = nn.Softmax(dim=-2)

        if is_flash_attention:
            try:
                from flash_attn.flash_attention import FlashAttention
                self.flash_attention = FlashAttention()
            except ImportError:
                logger.log('Install flash attention github.com/HazyResearch/flash-attention. '
                           'Falling back to normal attention', Text.warning)
                self.flash_attention = None
        else:
            self.flash_attention = None
```
```python
    def _get_mask(self, attn: torch.Tensor):
        # Query and key lengths
        nq, nk = attn.shape[1:3]

        # Create the mask
        if (
                self.causal_mask is None or
                self.causal_mask.shape[0] != nq or
                self.causal_mask.shape[1] != nk or
                self.causal_mask.device != attn.device
        ):
            self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)

        # Return from cache
        return self.causal_mask[None, :, :, None]
```
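To make the `torch.triu(..., 1 + nk - nq)` call concrete, here is a small illustration (not from the original code) with three new queries attending over five keys, two of which come from the cache; query `i` may attend to keys up to position `i + (nk - nq)`, and `True` marks positions that will be filled with `mask_fill`.

```python
import torch

nq, nk = 3, 5
mask = torch.triu(torch.ones(nq, nk, dtype=torch.bool), 1 + nk - nq)
print(mask)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False,  True],
#         [False, False, False, False, False]])
```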
`x` has shape `[batch_size, seq_len, n_hidden]`.

```python
    def forward(self, x: torch.Tensor):
        # Get query, key and value embeddings (all concatenated).
        # The last dimension size will change from n_hidden -> 3 * n_hidden
        qkv = self.qkv_lin(x)

        # Split into heads by changing the shape to [batch_size, seq_len, n_heads, 3 * d_k]
        qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)
        # Split into query, key and value, each of shape [batch_size, seq_len, n_heads, d_k]
        q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)

        # If we are caching the states of previous tokens
        if get_cache().get('use_cache', False):
            # Get the state ids. We use them to retrieve the previous states and store the next states
            prev_state_id, next_state_id = get_cache().get('state_ids')
            # If there's a cache
            if prev_state_id is not None:
                # Get the past keys and values.
                # These will have shape [batch_size, prev_seq_len, n_heads, d_k]
                k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')
                # Offset of the current embeddings
                offset = k_past.shape[1]

                # Add RoPE embeddings
                q = self.rope(q, offset=offset)
                k = self.rope(k, offset=offset)

                # Concatenate the past
                k = torch.cat([k_past, k], dim=1)
                v = torch.cat([v_past, v], dim=1)
            else:
                # Add RoPE embeddings
                q = self.rope(q)
                k = self.rope(k)

            # Save the current state
            get_cache().push(f'attn_kv_{next_state_id}', (k, v))
        else:
            # No cache - simply add RoPE embeddings
            q = self.rope(q)
            k = self.rope(k)

        # Use flash attention
        if self.flash_attention is not None and q.shape[1] == k.shape[1] and q.shape[-1] <= 128:
            output = self.compute_flash_attention(q, k, v)
        # Otherwise, use normal attention
        else:
            output = self.compute_attention(q, k, v)

        # Reshape from [batch_size, seq_len, n_heads, d_k] to [batch_size, seq_len, n_hidden]
        output = output.reshape(*x.shape)

        # Final linear layer
        return self.output(output)
```
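A small, self-contained check of the split used in `forward` above (not part of the model code): because the tensor is first reshaped to `[batch, seq, n_heads, 3 * d_k]` and only then split along the last dimension, the fused projection output is assumed to interleave query, key and value per head (the `attention.query_key_value` checkpoint parameters must follow the same per-head layout for this to be correct).

```python
import torch

batch, seq, n_heads, d_k = 1, 2, 2, 3
qkv = torch.arange(batch * seq * n_heads * 3 * d_k).view(batch, seq, n_heads * 3 * d_k)

# Reshape to [batch, seq, n_heads, 3 * d_k], then split the last dimension into thirds
qkv = qkv.view(batch, seq, n_heads, 3 * d_k)
q, k, v = torch.split(qkv, d_k, dim=-1)

print(q.shape)               # torch.Size([1, 2, 2, 3])
print(q[0, 0, 0].tolist())   # [0, 1, 2]
print(k[0, 0, 0].tolist())   # [3, 4, 5]
print(v[0, 0, 0].tolist())   # [6, 7, 8]
```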
 
```python
    def compute_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        # Stack them into shape [batch_size, seq_len, 3, n_heads, d_k]
        qkv = torch.stack((q, k, v), dim=2)
        d_k = qkv.shape[-1]
        # Pad the head size up to 32, 64 or 128
        if d_k <= 32:
            pad = 32 - d_k
        elif d_k <= 64:
            pad = 64 - d_k
        elif d_k <= 128:
            pad = 128 - d_k
        else:
            raise ValueError(f'Head size {d_k} too large for flash attention')

        if pad > 0:
            qkv = torch.cat((qkv, qkv.new_zeros(*qkv.shape[:-1], pad)), dim=-1)

        output, _ = self.flash_attention(qkv, causal=True)
        # The output is of shape [batch_size, seq_len, n_heads, d_k + padding]; drop the padding
        output = output[:, :, :, :d_k]

        return output
```
```python
    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        # Disable auto-casting to fp16 for attention computation
        with autocast(enabled=False):
            if q.dtype == torch.float16:
                # Convert to fp32 if the current dtype is fp16
                attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
            else:
                # Do not cast for bfloat
                attn = torch.einsum('bihk,bjhk->bijh', q, k)

            # Scale attention
            attn = attn * self.scale

            # Get the causal mask
            mask = self._get_mask(attn)
            # Apply the mask
            attn.masked_fill_(mask, self.mask_fill)

            # Attention softmax
            attn = self.softmax(attn)

        # Get the attention-weighted values
        output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)

        return output
```
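As a sanity check (not part of the model code, and assuming the `AttentionLayer` class above is in scope), the causal mask means that perturbing future keys and values must not change the outputs at earlier positions:

```python
import torch

attn_layer = AttentionLayer(n_hidden=64, n_heads=4)   # d_k = 16
q = torch.randn(1, 6, 4, 16)
k = torch.randn(1, 6, 4, 16)
v = torch.randn(1, 6, 4, 16)

out1 = attn_layer.compute_attention(q, k, v)

# Perturb the last two positions of the keys and values
k2, v2 = k.clone(), v.clone()
k2[:, 4:] += 1.0
v2[:, 4:] += 1.0
out2 = attn_layer.compute_attention(q, k2, v2)

# Outputs for the first four positions are (numerically) unaffected
assert torch.allclose(out1[:, :4], out2[:, :4], atol=1e-6)
```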
`n_hidden` is the embedding size; `d_ff` defaults to `4 * n_hidden` when not specified.

```python
class FFNLayer(nn.Module):
    def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
        super().__init__()

        if not d_ff:
            d_ff = n_hidden * 4

        # Expansion linear layer
        self.dense_h_h4 = nn.Linear(n_hidden, d_ff)
        # GELU activation
        self.activation = nn.GELU()
        # Contraction linear layer
        self.dense_h4_h = nn.Linear(d_ff, n_hidden)
```

`x` has shape `[batch_size, seq_len, n_hidden]`.

```python
    def forward(self, x: torch.Tensor):
        x = self.dense_h_h4(x)
        x = self.activation(x)
        x = self.dense_h4_h(x)

        return x
```
* `n_hidden` is the embedding size
* `n_heads` is the number of heads
* `is_flash_attention` specifies whether to use FlashAttention

Our implementation doesn't include dropout.

```python
class TransformerLayer(NeoXModule):
    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, *, is_flash_attention: bool = False):
        super().__init__()

        # Layer normalization before attention
        self.pre_ln_attn = nn.LayerNorm(n_hidden)
        # Layer normalization before FFN
        self.pre_ln_ffn = nn.LayerNorm(n_hidden)

        # Attention layer
        self.attention = AttentionLayer(n_hidden, n_heads, is_flash_attention=is_flash_attention)
        # FFN layer
        self.ffn = FFNLayer(n_hidden)
```

`x` are the embeddings of shape `[batch_size, seq_len, n_hidden]`.

```python
    def forward(self, x: torch.Tensor):
        # Residual connection
        residual = x
        # NeoX runs the attention and feedforward network in parallel
        attn = self.attention(self.pre_ln_attn(x))
        ffn = self.ffn(self.pre_ln_ffn(x))
        # Add them and the residual connection
        return attn + ffn + residual
```
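Unlike the standard transformer block, which feeds the output of the attention sub-layer into the FFN, the two branches here read the same input and are summed together with the residual:

$$y = x + \operatorname{Attn}(\operatorname{LN}_1(x)) + \operatorname{FFN}(\operatorname{LN}_2(x))$$

instead of the sequential form $y = x' + \operatorname{FFN}(\operatorname{LN}_2(x'))$ with $x' = x + \operatorname{Attn}(\operatorname{LN}_1(x))$.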
Code to load the checkpoint:

```python
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load transformer layer'):
            # Attention output transform
            checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)

            # Attention query, key and value transform
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)

            # Layer norm before attention
            checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)

            # FFN first (expansion) transform
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)

            # FFN second (contraction) transform
            checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)

            # Layer norm before FFN
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)
```
`n_hidden` is the embedding size.

```python
class FinalNorm(NeoXModule):
    def __init__(self, n_hidden: int = 6_144):
        super().__init__()

        self.ln = nn.LayerNorm(n_hidden)
```

`x` are the embeddings of shape `[batch_size, seq_len, n_hidden]`.

```python
    def forward(self, x: torch.Tensor):
        return self.ln(x)
```

Code to load the checkpoint:

```python
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load final normalization layer'):
            checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)
```
Readout layer

`n_hidden` is the embedding size and `n_vocab` is the size of the vocabulary.

```python
class ReadoutLayer(NeoXModule):
    def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
        super().__init__()

        self.linear = nn.Linear(n_hidden, n_vocab, bias=False)
```

`x` are the embeddings of shape `[batch_size, seq_len, n_hidden]`.

```python
    def forward(self, x: torch.Tensor):
        return self.linear(x)
```

Code to load the checkpoint:

```python
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load final linear layer'):
            checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
```
The layers are generated in the same order as the checkpoints. It gives `None` when a layer is not available; we use the same layer indices as NeoX, and there are two transformation layers that we don't need in our implementation.

* `n_vocab` is the number of tokens in the vocabulary
* `n_hidden` is the number of features in the embeddings
* `n_layers` is the number of transformer layers
* `n_heads` is the number of attention heads
* `filter_layers` is the set of layer indices to be used; all layers are used if `None`. This is used to test smaller versions of the model with fewer layers
* `is_clone_layers` specifies whether to clone the transformer layers (a bit faster)
* `dtype` is the data type of the model
* `device` is the device of the model
* `is_llm_int8` specifies whether to use int8 quantization
* `llm_int8_threshold` is the threshold used to separate outlier features
* `is_flash_attention` specifies whether to use FlashAttention

```python
class LayerGenerator:
    pre_created_layers: Dict[Any, Optional[NeoXModule]]

    def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
                 n_layers: int = 44, n_heads: int = 64,
                 filter_layers: Optional[Set] = None,
                 is_clone_layers: bool = True,
                 dtype: torch.dtype = torch.float,
                 device: torch.device = torch.device('cpu'),
                 is_llm_int8: bool = False,
                 llm_int8_threshold: float = 6.0,
                 is_flash_attention: bool = False
                 ):
        # Use all layers if `filter_layers` is not specified
        if filter_layers is None:
            filter_layers = set(range(n_layers + 3))

        self.n_vocab = n_vocab
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.filter_layers = filter_layers
        self.is_clone_layers = is_clone_layers
        self.dtype = dtype
        self.device = device
        self.is_llm_int8 = is_llm_int8
        self.llm_int8_threshold = llm_int8_threshold
        self.is_flash_attention = is_flash_attention

        self.pre_created_layers = dict(
            transformer_layer=None,
        )
```
We move the layer to the device and convert it to the correct data type.

* `layer` is the layer to prepare

Returns the prepared layer.

```python
    def _prepare_layer(self, layer: NeoXModule):
        return layer.to(self.device, self.dtype)
```

This function implements layer transformations after loading the checkpoint. Currently, it only applies int8 quantization.

* `layer` is the layer to prepare
* `is_llm_int8` specifies whether to use int8 quantization
* `device` is the device of the model
* `llm_int8_threshold` is the threshold used to separate outlier features

Returns the prepared layer.
```python
    @torch.no_grad()
    def post_load_prepare(self, layer: NeoXModule, *,
                          is_llm_int8: bool = None,
                          device: torch.device = None,
                          llm_int8_threshold: float = None,
                          ):
        # Get default values if not specified
        if is_llm_int8 is None:
            is_llm_int8 = self.is_llm_int8
        if device is None:
            device = self.device
        if llm_int8_threshold is None:
            llm_int8_threshold = self.llm_int8_threshold

        # Skip if not using int8 quantization
        if not is_llm_int8:
            return layer

        # Only convert the linear layers in the transformer layers
        if not isinstance(layer, TransformerLayer):
            return layer

        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear

        # Convert the linear layers
        with monit.section('Convert to int8'):
            layer.attention.output = make_llm_int8_linear(layer.attention.output,
                                                          device=device,
                                                          threshold=llm_int8_threshold)
            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
                                                           device=device,
                                                           threshold=llm_int8_threshold)
            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
                                                        device=device,
                                                        threshold=llm_int8_threshold)
            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
                                                        device=device,
                                                        threshold=llm_int8_threshold)

        return layer
```
Copying cached layers is faster than initializing new layers because it takes time to initialize parameters.

* `name` is the name of the layer
* `creator` is the function to create the layer

Returns the created layer or a copy of the cached layer.

```python
    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
        if not self.is_clone_layers:
            return self._prepare_layer(creator())

        if self.pre_created_layers[name] is None:
            self.pre_created_layers[name] = self._prepare_layer(creator())

        layer = copy.deepcopy(self.pre_created_layers[name])
        return layer
```
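A rough illustration of this design choice (not from the original source): constructing a `TransformerLayer` runs parameter initialization, while `copy.deepcopy` of an already-prepared template mostly copies tensors. The exact numbers depend on the machine; a smaller hidden size is used here to keep it light.

```python
import copy
import time

start = time.time()
template = TransformerLayer(n_hidden=1_024, n_heads=16)
create_time = time.time() - start

start = time.time()
clone = copy.deepcopy(template)
clone_time = time.time() - start

print(f'create: {create_time * 1000:.1f} ms, clone: {clone_time * 1000:.1f} ms')
```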
```python
    def _create_transformer_layer(self):
        return self._create_and_cache_layer(
            'transformer_layer',
            lambda: TransformerLayer(self.n_hidden, self.n_heads, is_flash_attention=self.is_flash_attention)
        )

    def _create_embedding_layer(self):
        return Embedding(self.n_vocab, self.n_hidden)

    def _create_final_norm_layer(self):
        return FinalNorm(self.n_hidden)

    def _create_readout_layer(self):
        return ReadoutLayer(self.n_hidden, self.n_vocab)
```
```python
    @torch.no_grad()
    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
        # Embedding layer
        if 0 in self.filter_layers:
            with monit.section('Embedding layer'):
                layer = self._prepare_layer(self._create_embedding_layer())
            yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')

        # Transformer layers
        for i in range(self.n_layers):
            # Transformer layer
            if i + 1 in self.filter_layers:
                with monit.section(f'Transformer Layer {i}'):
                    yield self._create_transformer_layer(), \
                          (f'layer_{i + 2 :02d}-model_00-model_states.pt',
                           f'layer_{i + 2 :02d}-model_01-model_states.pt')

        # Final normalization layer
        if self.n_layers + 1 in self.filter_layers:
            with monit.section('Final norm layer'):
                layer = self._prepare_layer(self._create_final_norm_layer())
            yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')

        # Readout layer
        if self.n_layers + 2 in self.filter_layers:
            with monit.section('Readout layer'):
                layer = self._prepare_layer(self._create_readout_layer())
            yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')

        for k in self.pre_created_layers.keys():
            self.pre_created_layers[k] = None
```
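A small illustration (not part of the original code) of how layer indices map to checkpoint files: a tiny configuration is used, and `filter_layers` keeps only the embedding, the first transformer layer, the final norm and the readout layer. Note that the final norm and readout file names are fixed to the 20B checkpoint layout (`layer_47` / `layer_48`).

```python
import torch

gen = LayerGenerator(n_vocab=128, n_hidden=64, n_layers=4, n_heads=4,
                     filter_layers={0, 1, 5, 6},
                     is_clone_layers=False,
                     device=torch.device('cpu'))

for layer, files in gen.get_layers():
    print(type(layer).__name__, files)

# Embedding        ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')
# TransformerLayer ('layer_02-model_00-model_states.pt', 'layer_02-model_01-model_states.pt')
# FinalNorm        ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')
# ReadoutLayer     ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
```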
```python
    @property
    def total_layers(self):
        # Total number of layers: the transformer layers plus the embedding, final norm and readout layers
        return self.n_layers + 3

    @torch.no_grad()
    def load(self) -> Generator[NeoXModule, None, None]:
        # Generator to load the layers one by one
        with monit.section("Layers"):
            for i, (layer, files) in enumerate(self.get_layers()):
                # Load the checkpoint of the layer if files are available
                if files is not None:
                    layer.load_state(*checkpoint.load_checkpoint_files(files))

                # Apply post-load transformations (e.g. int8 quantization)
                layer = self.post_load_prepare(layer)

                monit.progress(min(0.99, (i + 1) / self.total_layers))
                yield layer
```
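Finally, a minimal usage sketch (an assumption about how the pieces fit together, not code from the original source): load the model layer by layer and greedily pick a next token. This assumes the checkpoint files have already been downloaded to where `checkpoint.load_checkpoint_files` expects them, that this module is importable as `labml_nn.neox.model`, and that the device has roughly 40 GB of memory for the fp16 weights (`is_llm_int8=True` reduces this).

```python
import torch
from labml_nn.neox.model import LayerGenerator

device = torch.device('cuda:0')
layers = list(LayerGenerator(is_clone_layers=True,
                             dtype=torch.float16,
                             device=device).load())

# Placeholder token ids of shape [batch_size, seq_len]
x = torch.randint(0, 50_432, (1, 16), device=device)

with torch.no_grad():
    # The first layer embeds the token ids; the last layer returns the logits
    for layer in layers:
        x = layer(x)

next_token = x[0, -1].argmax().item()
```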