GPT-NeoX Model

Here is the code for the layers of the GPT-NeoX model and the code to load the 20B checkpoint.

The load_state method of each layer loads that layer's parameters from the checkpoint. The checkpoint loading helpers are in checkpoint.py.

16import copy
17import math
18from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple
19
20import torch
21from torch import nn
22from torch.cuda.amp import autocast
23
24from labml import monit
25from labml_nn.neox import checkpoint
26from labml_nn.neox.utils.cache import get_cache
class NeoXModule(nn.Module):
    """
    Base class for GPT-NeoX layers whose parameters can be loaded from a checkpoint.
    """

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Load this layer's parameters from the two checkpoint dictionaries `p1` and `p2`.

        The base implementation does nothing; layers with parameters override it.
        """
        return None

Embedding layer

This is a standard embeddings layer with code to load the checkpoint.

class Embedding(NeoXModule):
    """
    ## Embedding layer

    This is a standard embeddings layer with code to load the checkpoint.
    """

    def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
        """
        * `n_vocab` is the size of the vocabulary
        * `n_hidden` is the size of the embeddings
        """
        super().__init__()

        # Lookup table of shape `[n_vocab, n_hidden]`
        self.emb = nn.Embedding(n_vocab, n_hidden)

    def forward(self, x: torch.Tensor):
        """
        * `x` are the token ids of shape `[batch_size, seq_len]`

        Returns the embeddings of shape `[batch_size, seq_len, n_hidden]`.
        """
        embeddings = self.emb(x)
        return embeddings

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Code to load the checkpoint.

        The embedding weight is merged from `p1` and `p2` with
        `checkpoint.merge_params_dim_0`, i.e. presumably split along dim 0
        (the vocabulary dimension) across the two checkpoint files.
        """
        with monit.section('Load embedding layer'):
            checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)

Rotary Positional Embeddings

GPT-NeoX uses rotary positional embeddings (RoPE).

We have an annotated implementation of RoPE here, with more notes on the theory.

class RoPE(nn.Module):
    r"""
    ## Rotary Positional Embeddings

    GPT-NeoX uses rotary positional embeddings (RoPE). RoPE is applied to the
    first `d_rope` features of each head; the remaining features are passed
    through unchanged.

    The angles $m \theta_i$ are always computed in fp32 so that position
    indexes stay exact even when the input is fp16/bf16.
    """

    def __init__(self, d_rope: int, base: float = 10_000.):
        r"""
        * `d_rope` is the number of features for RoPE embeddings
        * `base` is the base for $\theta_i = base^{\frac{-2(i-1)}{d}}$, which defaults to $10000$
        """
        super().__init__()

        # $\theta_i$ for the features; created lazily in fp32 on the input's device
        self.theta = None
        # Cache of $\cos m\theta_i$ and $\sin m\theta_i$, shape `[seq_len, 1, d_rope]`, in the input dtype
        self.cos_cached = None
        self.sin_cached = None

        # Base for $\theta_i$
        self.base = base
        # Number of features for RoPE
        self.d_rope = d_rope

    @staticmethod
    def rotate_half(x: torch.Tensor):
        r"""
        Rotate the features: $[-x^{(\frac{d}{2} + 1)}, ..., -x^{(d)}, x^{(1)}, ..., x^{(\frac{d}{2})}]$
        """
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x: torch.Tensor, offset: int = 0):
        """
        * `x` has shape `[..., seq, n_heads, d_k]`
        * `offset` is the starting position of `x`. This is non-zero when we have
          cached the keys and queries of previous positions

        Returns a tensor with the same shape and dtype as `x`.
        """
        # Total sequence length, including the `offset` cached positions before `x`
        seq_len = x.shape[-3] + offset

        # Initialize theta in fp32 on the input's device. Computing theta (and the
        # position indexes derived from it) in `x.dtype` would round positions in
        # fp16/bf16 (bf16 cannot even represent 599), corrupting the angles.
        if self.theta is None or self.theta.device != x.device:
            exponents = torch.arange(0, self.d_rope, 2, device=x.device).float() / self.d_rope
            self.theta = 1.0 / (self.base ** exponents)

        # (Re)build the cos/sin cache when it is missing, too short for `seq_len`
        # (dim 0 is the sequence dimension), or on another device / dtype
        if (
                self.cos_cached is None or
                seq_len > self.cos_cached.shape[0] or
                self.cos_cached.device != x.device or
                self.cos_cached.dtype != x.dtype
        ):
            # Do the whole angle computation in fp32, unaffected by autocast
            with autocast(enabled=False):
                # Position indexes $[0, 1, ..., seq_len - 1]$
                seq_idx = torch.arange(seq_len, device=x.device, dtype=torch.float32)
                # $m \theta_i$, shape `[seq_len, d_rope / 2]`
                idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)
                # Row $m$ becomes $[m\theta_0, ..., m\theta_{d/2}, m\theta_0, ..., m\theta_{d/2}]$
                idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1)
                # Add the head dimension and cache in the input dtype
                self.cos_cached = idx_theta2.cos()[:, None, :].to(x.dtype)
                self.sin_cached = idx_theta2.sin()[:, None, :].to(x.dtype)

        # Split the features. We apply RoPE to only `d_rope` features
        x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]

        # cos and sin for positions `offset .. seq_len - 1`
        cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]

        # Rotate each feature pair $(x^{(i)}, x^{(i + d/2)})$ by angle $m \theta_i$
        x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)

        # Concatenate with features that didn't get RoPE embeddings
        return torch.cat((x_rope, x_pass), dim=-1)

Attention layer

class AttentionLayer(nn.Module):
    """
    ## Attention layer

    Multi-head causal self-attention. RoPE is added to a fraction of each head's
    features, and keys/values can be cached across calls through `get_cache()`.
    """

    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
                 mask_fill: float = -10_000.0):
        """
        * `n_hidden` the number of features in embeddings
        * `n_heads` the number of attention heads
        * `rope_percentage` percentage of features to add RoPE embeddings
        * `mask_fill` masking fill value for attention matrix
        """
        super().__init__()

        self.n_heads = n_heads
        self.mask_fill = mask_fill

        # One projection produces query, key and value, concatenated
        self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)
        # Output projection
        self.output = nn.Linear(n_hidden, n_hidden)

        # Number of features per head
        d_k = n_hidden // n_heads
        # RoPE on the first `int(d_k * rope_percentage)` features of each head
        self.rope = RoPE(int(d_k * rope_percentage))
        # Attention scaling factor $\frac{1}{\sqrt{d_k}}$
        self.scale = 1 / math.sqrt(d_k)
        # Causal mask, built lazily and reused while lengths and device match
        self.causal_mask = None
        # Softmax over the key dimension of `[batch_size, query, key, n_heads]` scores
        self.softmax = nn.Softmax(dim=-2)

    def _get_mask(self, attn: torch.Tensor):
        """
        Calculate the causal mask for attention scores `attn` of shape
        `[batch_size, query_seq_len, key_seq_len, n_heads]`.

        Returns a boolean mask of shape `[1, query_seq_len, key_seq_len, 1]` where `True`
        marks positions to mask out. Queries are aligned to the end of the keys, so
        query `i` may attend to keys `0 .. i + key_seq_len - query_seq_len`.
        """
        nq, nk = attn.shape[1:3]

        cached = self.causal_mask
        stale = cached is None or tuple(cached.shape) != (nq, nk) or cached.device != attn.device
        if stale:
            self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)

        return self.causal_mask[None, :, :, None]

    def forward(self, x: torch.Tensor):
        """
        * `x` has shape `[batch_size, seq_len, n_hidden]`

        Returns a tensor of shape `[batch_size, seq_len, n_hidden]`.
        """
        # `[batch_size, seq_len, n_hidden]` -> `[batch_size, seq_len, 3 * n_hidden]`
        qkv = self.qkv_lin(x)
        # Split into heads: `[batch_size, seq_len, n_heads, 3 * d_k]`
        qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)
        # Query, key and value, each `[batch_size, seq_len, n_heads, d_k]`
        q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)

        if not get_cache().get('use_cache', False):
            # No cache - simply add RoPE embeddings
            q = self.rope(q)
            k = self.rope(k)
        else:
            # Ids used to retrieve the previous state and to store the next one
            prev_state_id, next_state_id = get_cache().get('state_ids')
            if prev_state_id is None:
                # Nothing cached yet, so positions start at 0
                q = self.rope(q)
                k = self.rope(k)
            else:
                # Past keys and values, `[batch_size, prev_seq_len, n_heads, d_k]`
                k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')
                # The current tokens come right after the cached ones
                offset = k_past.shape[1]
                q = self.rope(q, offset=offset)
                k = self.rope(k, offset=offset)
                # Prepend the cached keys and values
                k = torch.cat([k_past, k], dim=1)
                v = torch.cat([v_past, v], dim=1)
            # Save keys and values (past included) for the next step
            get_cache().push(f'attn_kv_{next_state_id}', (k, v))

        # Compute scores without autocast; fp16 is upcast to fp32, other dtypes
        # (e.g. bfloat16) are used as is
        with autocast(enabled=False):
            if q.dtype == torch.float16:
                q_s, k_s = q.float(), k.float()
            else:
                q_s, k_s = q, k
            # Scores of shape `[batch_size, query, key, n_heads]`
            attn = torch.einsum('bihk,bjhk->bijh', q_s, k_s)
            attn = attn * self.scale
            # Hide future positions
            attn.masked_fill_(self._get_mask(attn), self.mask_fill)
            # Normalize over keys
            attn = self.softmax(attn)

        # Attention-weighted values, `[batch_size, seq_len, n_heads, d_k]`
        output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)
        # Merge the heads back into `[batch_size, seq_len, n_hidden]` and project
        return self.output(output.reshape(*x.shape))

Feedforward Network

class FFNLayer(nn.Module):
    """
    ## Feedforward Network

    Expansion linear layer, GELU activation, then contraction linear layer.
    """

    def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
        """
        * `n_hidden` is the embedding size
        * `d_ff` is the size of the expanded hidden layer; a falsy value
          (the default `0`) means `4 * n_hidden`
        """
        super().__init__()

        hidden_size = d_ff if d_ff else n_hidden * 4

        # Expansion linear layer
        self.dense_h_h4 = nn.Linear(n_hidden, hidden_size)
        # GELU activation
        self.activation = nn.GELU()
        # Contraction linear layer
        self.dense_h4_h = nn.Linear(hidden_size, n_hidden)

    def forward(self, x: torch.Tensor):
        """
        * `x` has shape `[batch_size, seq_len, n_hidden]`

        Returns a tensor of the same shape.
        """
        expanded = self.activation(self.dense_h_h4(x))
        return self.dense_h4_h(expanded)

Transformer Layer

class TransformerLayer(NeoXModule):
    """
    ## Transformer Layer

    Our implementation doesn't include dropout.
    """

    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64):
        """
        * `n_hidden` is the embedding size
        * `n_heads` is the number of heads
        """
        super().__init__()

        # Layer normalizations applied to the input of the attention and FFN branches
        self.pre_ln_attn = nn.LayerNorm(n_hidden)
        self.pre_ln_ffn = nn.LayerNorm(n_hidden)

        # The two branches
        self.attention = AttentionLayer(n_hidden, n_heads)
        self.ffn = FFNLayer(n_hidden)

    def forward(self, x: torch.Tensor):
        """
        * `x` are the embeddings of shape `[batch_size, seq_len, n_hidden]`

        NeoX runs attention and the feedforward network in parallel on the same
        input (each with its own layer norm) and adds both to the residual.
        """
        attn_out = self.attention(self.pre_ln_attn(x))
        ffn_out = self.ffn(self.pre_ln_ffn(x))
        return attn_out + ffn_out + x

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Code to load the checkpoint.

        Each entry pairs a `checkpoint` merge helper with the target parameter and
        its checkpoint key; the helpers combine the values found in `p1` and `p2`.
        """
        with monit.section('Load transformer layer'):
            merges = [
                # Attention output transform
                (checkpoint.merge_params_sum, self.attention.output.bias, 'attention.dense.bias'),
                (checkpoint.merge_params_dim_1, self.attention.output.weight, 'attention.dense.weight'),
                # Attention query, key and value transform
                (checkpoint.merge_params_dim_0, self.attention.qkv_lin.bias, 'attention.query_key_value.bias'),
                (checkpoint.merge_params_dim_0, self.attention.qkv_lin.weight, 'attention.query_key_value.weight'),
                # Layer norm before attention
                (checkpoint.merge_params_duplicate, self.pre_ln_attn.bias, 'input_layernorm.bias'),
                (checkpoint.merge_params_duplicate, self.pre_ln_attn.weight, 'input_layernorm.weight'),
                # FFN first (expansion) transform
                (checkpoint.merge_params_dim_0, self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias'),
                (checkpoint.merge_params_dim_0, self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight'),
                # FFN second (contraction) transform
                (checkpoint.merge_params_sum, self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias'),
                (checkpoint.merge_params_dim_1, self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight'),
                # Layer norm before FFN
                (checkpoint.merge_params_duplicate, self.pre_ln_ffn.bias, 'post_attention_layernorm.bias'),
                (checkpoint.merge_params_duplicate, self.pre_ln_ffn.weight, 'post_attention_layernorm.weight'),
            ]
            for merge, param, key in merges:
                merge(param, key, p1, p2)

Final normalization layer

class FinalNorm(NeoXModule):
    """
    ## Final normalization layer

    Layer normalization applied to the output of the last transformer layer.
    """

    def __init__(self, n_hidden: int = 6_144):
        """
        * `n_hidden` is the embedding size
        """
        super().__init__()

        self.ln = nn.LayerNorm(n_hidden)

    def forward(self, x: torch.Tensor):
        """
        * `x` are the embeddings of shape `[batch_size, seq_len, n_hidden]`
        """
        normalized = self.ln(x)
        return normalized

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Code to load the checkpoint.

        Both layer norm parameters are loaded with `checkpoint.merge_params_duplicate`
        (presumably the same values are stored in both `p1` and `p2`).
        """
        with monit.section('Load final normalization layer'):
            for param, key in ((self.ln.bias, 'norm.bias'), (self.ln.weight, 'norm.weight')):
                checkpoint.merge_params_duplicate(param, key, p1, p2)

Readout layer

class ReadoutLayer(NeoXModule):
    """
    ## Readout layer

    Projects the final embeddings to the vocabulary with a bias-free linear layer.
    """

    def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
        """
        * `n_hidden` is the embedding size
        * `n_vocab` is the size of the vocabulary
        """
        super().__init__()

        self.linear = nn.Linear(n_hidden, n_vocab, bias=False)

    def forward(self, x: torch.Tensor):
        """
        * `x` are the embeddings of shape `[batch_size, seq_len, n_hidden]`

        Returns a tensor of shape `[batch_size, seq_len, n_vocab]`.
        """
        out = self.linear(x)
        return out

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Code to load the checkpoint.

        The weight (shape `[n_vocab, n_hidden]`) is merged with
        `checkpoint.merge_params_dim_0`, i.e. along the vocabulary dimension.
        """
        with monit.section('Load final linear layer'):
            checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
454class LayerGenerator:
455    pre_created_layers: Dict[Any, Optional[NeoXModule]]

Generator to create layers

The layers are generated in the same order as checkpoints.

It gives None when a layer is not available; we use the layer indices as NeoX and there are two transformation layers we don't need in our implementation.

  • n_vocab is the number of tokens in the vocabulary
  • n_hidden is the number of features in the embeddings
  • n_layers is the number of transformer layers
  • n_heads is the number of attention heads
  • filter_layers are the set of layers to be used. All layers will be used if None. This is used to test smaller versions of the model with fewer layers
  • is_clone_layers specifies whether to clone the transformer layers (a bit faster)
  • dtype is the data type of the model
  • device is the device of the model
  • is_llm_int8 specifies whether to use int8 quantization
  • llm_int8_threshold is the threshold used to separate outlier features
457    def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
458                 n_layers: int = 44, n_heads: int = 64,
459                 filter_layers: Optional[Set] = None,
460                 is_clone_layers: bool = True,
461                 dtype: torch.dtype = torch.float,
462                 device: torch.device = torch.device('cpu'),
463                 is_llm_int8: bool = False,
464                 llm_int8_threshold: float = 6.0,
465                 ):
486        if filter_layers is None:
487            filter_layers = set(range(n_layers + 3))
488
489        self.n_vocab = n_vocab
490        self.n_hidden = n_hidden
491        self.n_layers = n_layers
492        self.n_heads = n_heads
493        self.filter_layers = filter_layers
494        self.is_clone_layers = is_clone_layers
495        self.dtype = dtype
496        self.device = device
497        self.is_llm_int8 = is_llm_int8
498        self.llm_int8_threshold = llm_int8_threshold
499
500        self.pre_created_layers = dict(
501            transformer_layer=None,
502        )

Prepares the layer for usage

We move the layer to the device and convert it to the correct data type

  • layer is the layer to prepare
  • Returns the prepared layer

504    def _prepare_layer(self, layer: NeoXModule):
513        return layer.to(self.device, self.dtype)

Layer transformations after loading the checkpoint

This function implements layer transformations after loading the checkpoint.

Currently, it only applies the int8 quantization.

  • layer is the layer to prepare
  • is_llm_int8 specifies whether to use int8 quantization
  • device is the device of the model
  • llm_int8_threshold is the threshold used to separate outlier features
  • Returns the prepared layer

515    @torch.no_grad()
516    def post_load_prepare(self, layer: NeoXModule, *,
517                          is_llm_int8: bool = None,
518                          device: torch.device = None,
519                          llm_int8_threshold: float = None,
520                          ):

Get default values if not specified

538        if is_llm_int8 is None:
539            is_llm_int8 = self.is_llm_int8
540        if device is None:
541            device = self.device
542        if llm_int8_threshold is None:
543            llm_int8_threshold = self.llm_int8_threshold

Skip if not using int8 quantization

546        if not is_llm_int8:
547            return layer

Only convert the linear layers in the transformer layers

550        if not isinstance(layer, TransformerLayer):
551            return layer

Use make_llm_int8_linear defined in utilities.

554        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear

Convert the linear layers

557        with monit.section('Convert to int8'):
558            layer.attention.output = make_llm_int8_linear(layer.attention.output,
559                                                          device=device,
560                                                          threshold=llm_int8_threshold)
561            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
562                                                           device=device,
563                                                           threshold=llm_int8_threshold)
564            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
565                                                        device=device,
566                                                        threshold=llm_int8_threshold)
567            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
568                                                        device=device,
569                                                        threshold=llm_int8_threshold)

571        return layer

Creates and caches a layer

Copying cached layers is faster than initializing new layers because it takes time to initialize parameters.

  • name is the name of the layer
  • creator is the function to create the layer
  • Returns the created layer or a copy of the cached layer

573    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
585        if not self.is_clone_layers:
586            return self._prepare_layer(creator())
587
588        if self.pre_created_layers[name] is None:
589            self.pre_created_layers[name] = self._prepare_layer(creator())
590
591        layer = copy.deepcopy(self.pre_created_layers[name])
592        return layer
594    def _create_transformer_layer(self):
595        return self._create_and_cache_layer(
596            'transformer_layer',
597            lambda: TransformerLayer(self.n_hidden, self.n_heads)
598        )
600    def _create_embedding_layer(self):
601        return Embedding(self.n_vocab, self.n_hidden)
603    def _create_final_norm_layer(self):
604        return FinalNorm(self.n_hidden)
606    def _create_readout_layer(self):
607        return ReadoutLayer(self.n_hidden, self.n_vocab)

Generator to get layers

609    @torch.no_grad()
610    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:

Embedding layer

615        if 0 in self.filter_layers:
616            with monit.section('Embedding layer'):
617                layer = self._prepare_layer(self._create_embedding_layer())
618            yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')

Transformer layers

621        for i in range(self.n_layers):

Transformer layer

623            if i + 1 in self.filter_layers:
624                with monit.section(f'Transformer Layer {i}'):
625                    yield self._create_transformer_layer(), \
626                          (f'layer_{i + 2 :02d}-model_00-model_states.pt',
627                           f'layer_{i + 2 :02d}-model_01-model_states.pt')

Final normalization layer

630        if self.n_layers + 1 in self.filter_layers:
631            with monit.section('Final norm layer'):
632                layer = self._prepare_layer(self._create_final_norm_layer())
633            yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')

Readout layer

636        if self.n_layers + 2 in self.filter_layers:
637            with monit.section('Readout layer'):
638                layer = self._prepare_layer(self._create_readout_layer())
639            yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
640
641        for k in self.pre_created_layers.keys():
642            self.pre_created_layers[k] = None

Returns the total number of layers

644    @property
645    def total_layers(self):
649        return self.n_layers + 3

Generator to load layers

651    @torch.no_grad()
652    def load(self) -> Generator[NeoXModule, None, None]:
656        with monit.section("Layers"):
657            for i, (layer, files) in enumerate(self.get_layers()):
658                if files is not None:
659                    layer.load_state(*checkpoint.load_checkpoint_files(files))
660
661                layer = self.post_load_prepare(layer)
662
663                monit.progress(min(0.99, (i + 1) / self.total_layers))
664                yield layer