Here is the code for the layers of the GPT-NeoX model and the code to load the 20B checkpoint.
The load_state method in each layer loads that layer's part of the checkpoint.
The checkpoint loading helpers are in checkpoint.py.
import copy
import math
from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple

import torch
from torch import nn
from torch.cuda.amp import autocast

from labml import monit
from labml_nn.neox import checkpoint
from labml_nn.neox.utils.cache import get_cache
class NeoXModule(nn.Module):
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        pass
class Embedding(NeoXModule):

n_vocab is the size of the vocabulary
n_hidden is the size of the embeddings

    def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
        super().__init__()

        self.emb = nn.Embedding(n_vocab, n_hidden)
x are the token ids of shape [batch_size, seq_len]

    def forward(self, x: torch.Tensor):
        return self.emb(x)
Code to load the checkpoint
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load embedding layer'):
            checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)
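The merge helpers in checkpoint.py combine the two model-parallel shards of the checkpoint into a single tensor. As a rough, hypothetical sketch of the idea behind merge_params_dim_0 (the real helper lives in labml_nn.neox.checkpoint and may differ in detail):

def merge_params_dim_0_sketch(param: torch.Tensor, key: str,
                              p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
    # Parameters partitioned along dim 0 (e.g. the embedding table) are
    # concatenated back together and copied into the model parameter.
    with torch.no_grad():
        param.copy_(torch.cat([p1[key], p2[key]], dim=0))

Judging by the names, merge_params_dim_1 presumably concatenates along dim 1, while merge_params_duplicate and merge_params_sum handle parameters that are replicated or summed across the two shards.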
GPT-NeoX uses rotary positional embeddings (RoPE).
We have an annotated implementation of RoPE here with more notes on the theory.
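For reference, RoPE rotates each pair of features $\big(x^{(1)}_i, x^{(2)}_i\big)$ at position $m$ by the angle $m \theta_i$, where $\theta_i = 10000^{-2i/d}$:

$$\begin{pmatrix} x^{(1)}_i \\ x^{(2)}_i \end{pmatrix} \mapsto \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix} \begin{pmatrix} x^{(1)}_i \\ x^{(2)}_i \end{pmatrix}$$

The implementation below applies the same rotation as $x \cos m\theta + \mathrm{rotate\_half}(x)\,\sin m\theta$, pairing feature $i$ with feature $i + d/2$.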
class RoPE(nn.Module):

d_rope is the number of features for RoPE embeddings
base is the base for $\theta$, which defaults to $10000$

    def __init__(self, d_rope: int, base: float = 10_000.):
        super().__init__()

To store $\theta$ for the features
        self.theta = None
Cache for $\cos$ and $\sin$ values
        self.cos_cached = None
        self.sin_cached = None
Base for $\theta$
        self.base = base
Number of features for RoPE
        self.d_rope = d_rope
    @staticmethod
    def rotate_half(x: torch.Tensor):
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)
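As a quick illustration (not part of the model code), rotate_half moves the second half of the features to the front and negates it:

x = torch.tensor([1., 2., 3., 4.])
assert torch.equal(RoPE.rotate_half(x), torch.tensor([-3., -4., 1., 2.]))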
x has shape [..., seq, n_heads, d_k]
offset is the starting position of x. This is $> 0$ when we have cached the keys and queries of previous positions.

    def forward(self, x: torch.Tensor, offset: int = 0):
Get the actual sequence length
        seq_len = x.shape[-3] + offset
Initialize $\theta$
        if self.theta is None:
            theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
            self.theta = theta.to(x.device).to(x.dtype)
Initialize the $\cos$ and $\sin$ cache if needed
        if (
                self.cos_cached is None or
                seq_len > self.cos_cached.shape[0] or
                self.cos_cached.device != x.device or
                self.cos_cached.dtype != x.dtype
        ):
Get position indexes
            seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
            idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)
            idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)
Calculate $\cos$ and $\sin$ in fp32
            with autocast(enabled=False):
                idx_theta2 = idx_theta2.float()
Add head dimension
                self.cos_cached = idx_theta2.cos()[:, None, :]
                self.sin_cached = idx_theta2.sin()[:, None, :]
Cache them
            self.cos_cached = self.cos_cached.to(x.dtype)
            self.sin_cached = self.sin_cached.to(x.dtype)
Split the features. We apply RoPE only to the first d_rope features.
        x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]
Get the sin and cos values from the cache
        cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]
        x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)
Concatenate with features that didn't get RoPE embeddings
        return torch.cat((x_rope, x_pass), dim=-1)
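A small usage sketch of the offset argument when generating token by token (the shapes here are illustrative, not taken from the original):

rope = RoPE(d_rope=16)
q = torch.randn(1, 10, 8, 64)      # [batch_size, seq_len, n_heads, d_k]
q0 = rope(q)                       # rotated as positions 0..9
q_new = torch.randn(1, 1, 8, 64)   # a single new token
q1 = rope(q_new, offset=10)        # rotated as position 10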
class AttentionLayer(nn.Module):

n_hidden is the number of features in the embeddings
n_heads is the number of attention heads
rope_percentage is the percentage of features to add RoPE embeddings to
mask_fill is the masking fill value for the attention matrix

    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
                 mask_fill: float = -10_000.0):
        super().__init__()

        self.n_heads = n_heads
        self.mask_fill = mask_fill
Linear layer for query, key and value
        self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)
Final linear layer
        self.output = nn.Linear(n_hidden, n_hidden)
Number of features per head
        d_k = n_hidden // n_heads
RoPE embedding module
        self.rope = RoPE(int(d_k * rope_percentage))
Attention scaling factor
        self.scale = 1 / math.sqrt(d_k)
To cache causal mask
        self.causal_mask = None
Attention softmax module
        self.softmax = nn.Softmax(dim=-2)
    def _get_mask(self, attn: torch.Tensor):
Query and key lengths
        nq, nk = attn.shape[1:3]
Create mask
        if (
                self.causal_mask is None or
                self.causal_mask.shape[0] != nq or
                self.causal_mask.shape[1] != nk or
                self.causal_mask.device != attn.device
        ):
            self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)
Return from cache
        return self.causal_mask[None, :, :, None]
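A small check of the masking logic (computed directly, outside the class): with nq = 2 queries attending over nk = 4 keys, the queries correspond to the last two positions, so only the final key is hidden from the first query:

nq, nk = 2, 4
mask = torch.triu(torch.ones(nq, nk, dtype=torch.bool), 1 + nk - nq)
# mask == [[False, False, False, True],
#          [False, False, False, False]]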
x has shape [batch_size, seq_len, n_hidden]

    def forward(self, x: torch.Tensor):
Get query, key and value embeddings (all concatenated). The last dimension size will change from n_hidden to 3 * n_hidden.
        qkv = self.qkv_lin(x)
Split into heads by changing the shape to [batch_size, seq_len, n_heads, 3 * d_k]
        qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)
Split into query, key and value, each of shape [batch_size, seq_len, n_heads, d_k]
        q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)
If we are caching the states of previous tokens
        if get_cache().get('use_cache', False):
Get the state ids. We use prev_state_id to retrieve previous states and next_state_id to store the next states.
            prev_state_id, next_state_id = get_cache().get('state_ids')
If there's cache
            if prev_state_id is not None:
Get the past keys and values. These will have shape [batch_size, prev_seq_len, n_heads, d_k]
                k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')
Offset of the current embeddings
                offset = k_past.shape[1]
Add RoPE embeddings
                q = self.rope(q, offset=offset)
                k = self.rope(k, offset=offset)
Concatenate the past
                k = torch.cat([k_past, k], dim=1)
                v = torch.cat([v_past, v], dim=1)
            else:
Add RoPE embeddings
                q = self.rope(q)
                k = self.rope(k)
Save the current state
            get_cache().push(f'attn_kv_{next_state_id}', (k, v))
        else:
No cache - simply add RoPE embeddings
            q = self.rope(q)
            k = self.rope(k)
Disable auto-casting to fp16 for attention computation
        with autocast(enabled=False):
            if q.dtype == torch.float16:
Convert to fp32 if the current dtype is fp16
                attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
            else:
Do not cast for bfloat
                attn = torch.einsum('bihk,bjhk->bijh', q, k)
Scale attention
            attn = attn * self.scale
Get causal mask
            mask = self._get_mask(attn)
Apply mask
            attn.masked_fill_(mask, self.mask_fill)
Attention softmax
            attn = self.softmax(attn)
Get attention weighted values
        output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)
Reshape from [batch_size, seq_len, n_heads, d_k] to [batch_size, seq_len, n_hidden]
        output = output.reshape(*x.shape)
Final linear layer
        return self.output(output)
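Putting the pieces together, each head computes standard scaled dot-product attention,

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}} + M\right) V$$

where the causal mask $M$ blocks attention to future positions; in the code, masked entries are filled with mask_fill ($-10{,}000$) rather than $-\infty$.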
class FFNLayer(nn.Module):

n_hidden is the embedding size

    def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
        super().__init__()

        if not d_ff:
            d_ff = n_hidden * 4
Expansion linear layer
        self.dense_h_h4 = nn.Linear(n_hidden, d_ff)
GELU activation
        self.activation = nn.GELU()
Contraction linear layer
        self.dense_h4_h = nn.Linear(d_ff, n_hidden)
x has shape [batch_size, seq_len, n_hidden]

    def forward(self, x: torch.Tensor):
        x = self.dense_h_h4(x)
        x = self.activation(x)
        x = self.dense_h4_h(x)

        return x
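In equation form, with $d_{ff} = 4 \cdot n_{hidden}$ by default:

$$\mathrm{FFN}(x) = W_2 \, \mathrm{GELU}(W_1 x + b_1) + b_2$$

where $W_1 \in \mathbb{R}^{d_{ff} \times n_{hidden}}$ and $W_2 \in \mathbb{R}^{n_{hidden} \times d_{ff}}$.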
class TransformerLayer(NeoXModule):

n_hidden is the embedding size
n_heads is the number of heads

Our implementation doesn't include dropout.

    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64):
        super().__init__()
Layer normalization before attention
        self.pre_ln_attn = nn.LayerNorm(n_hidden)
Layer normalization before FFN
        self.pre_ln_ffn = nn.LayerNorm(n_hidden)
Attention layer
        self.attention = AttentionLayer(n_hidden, n_heads)
FFN layer
        self.ffn = FFNLayer(n_hidden)
x are the embeddings of shape [batch_size, seq_len, n_hidden]

    def forward(self, x: torch.Tensor):
Residual connection
        residual = x
NeoX runs attention and feedforward network in parallel
        attn = self.attention(self.pre_ln_attn(x))
        ffn = self.ffn(self.pre_ln_ffn(x))
Add them and the residual connection
        return attn + ffn + residual
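Since attention and the FFN both read the same input (rather than being applied sequentially as in the standard transformer block), the layer computes

$$y = x + \mathrm{Attn}\big(\mathrm{LN}_1(x)\big) + \mathrm{FFN}\big(\mathrm{LN}_2(x)\big)$$

where $\mathrm{LN}_1$ and $\mathrm{LN}_2$ are pre_ln_attn and pre_ln_ffn.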
Code to load the checkpoint
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load transformer layer'):
Attention output transform
            checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)
Attention query, key and value transform
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)
Layer norm before attention
            checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)
FFN first (expansion) transform
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)
FFN second (contraction) transform
            checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)
Layer norm before FFN
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)
class FinalNorm(NeoXModule):

n_hidden is the embedding size

    def __init__(self, n_hidden: int = 6_144):
        super().__init__()

        self.ln = nn.LayerNorm(n_hidden)
x are the embeddings of shape [batch_size, seq_len, n_hidden]

    def forward(self, x: torch.Tensor):
        return self.ln(x)
Code to load the checkpoint
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load final normalization layer'):
            checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)
Readout layer
class ReadoutLayer(NeoXModule):

n_hidden is the embedding size
n_vocab is the size of the vocabulary

    def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
        super().__init__()

        self.linear = nn.Linear(n_hidden, n_vocab, bias=False)
x are the embeddings of shape [batch_size, seq_len, n_hidden]

    def forward(self, x: torch.Tensor):
        return self.linear(x)
Code to load the checkpoint
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load final linear layer'):
            checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
class LayerGenerator:
    pre_created_layers: Dict[Any, Optional[NeoXModule]]
The layers are generated in the same order as the checkpoints.
It gives None when a layer is not available; we use the same layer indices as NeoX, and there are two transformation layers we don't need in our implementation.
n_vocab is the number of tokens in the vocabulary
n_hidden is the number of features in the embeddings
n_layers is the number of transformer layers
n_heads is the number of attention heads
filter_layers is the set of layers to be used. All layers will be used if None. This is used to test smaller versions of the model with fewer layers.
is_clone_layers specifies whether to clone the transformer layers (a bit faster)
dtype is the data type of the model
device is the device of the model
is_llm_int8 specifies whether to use int8 quantization
llm_int8_threshold is the threshold used to separate outlier features

    def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
                 n_layers: int = 44, n_heads: int = 64,
                 filter_layers: Optional[Set] = None,
                 is_clone_layers: bool = True,
                 dtype: torch.dtype = torch.float,
                 device: torch.device = torch.device('cpu'),
                 is_llm_int8: bool = False,
                 llm_int8_threshold: float = 6.0,
                 ):
        if filter_layers is None:
            filter_layers = set(range(n_layers + 3))

        self.n_vocab = n_vocab
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.filter_layers = filter_layers
        self.is_clone_layers = is_clone_layers
        self.dtype = dtype
        self.device = device
        self.is_llm_int8 = is_llm_int8
        self.llm_int8_threshold = llm_int8_threshold

        self.pre_created_layers = dict(
            transformer_layer=None,
        )
We move the layer to the device and convert it to the correct data type.
layer is the layer to prepare
Returns the prepared layer

    def _prepare_layer(self, layer: NeoXModule):
        return layer.to(self.device, self.dtype)
This function implements layer transformations after loading the checkpoint.
Currently, it only applies the int8 quantization.
layer is the layer to prepare
is_llm_int8 specifies whether to use int8 quantization
device is the device of the model
llm_int8_threshold is the threshold used to separate outlier features
Returns the prepared layer

    @torch.no_grad()
    def post_load_prepare(self, layer: NeoXModule, *,
                          is_llm_int8: bool = None,
                          device: torch.device = None,
                          llm_int8_threshold: float = None,
                          ):
Get default values if not specified
        if is_llm_int8 is None:
            is_llm_int8 = self.is_llm_int8
        if device is None:
            device = self.device
        if llm_int8_threshold is None:
            llm_int8_threshold = self.llm_int8_threshold
Skip if not using int8 quantization
        if not is_llm_int8:
            return layer
Only convert the linear layers in the transformer layers
        if not isinstance(layer, TransformerLayer):
            return layer
        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear
Convert the linear layers
        with monit.section('Convert to int8'):
            layer.attention.output = make_llm_int8_linear(layer.attention.output,
                                                          device=device,
                                                          threshold=llm_int8_threshold)
            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
                                                           device=device,
                                                           threshold=llm_int8_threshold)
            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
                                                        device=device,
                                                        threshold=llm_int8_threshold)
            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
                                                        device=device,
                                                        threshold=llm_int8_threshold)

        return layer
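For example, a loaded transformer layer could be re-quantized onto a GPU like this (a minimal sketch; the generator and layer variables are assumed to already exist):

layer = generator.post_load_prepare(layer,
                                    is_llm_int8=True,
                                    device=torch.device('cuda:0'),
                                    llm_int8_threshold=6.0)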
Copying cached layers is faster than initializing new layers because it takes time to initialize parameters.
name is the name of the layer
creator is the function to create the layer
Returns the created layer or a copy of the cached layer
    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
        if not self.is_clone_layers:
            return self._prepare_layer(creator())

        if self.pre_created_layers[name] is None:
            self.pre_created_layers[name] = self._prepare_layer(creator())

        layer = copy.deepcopy(self.pre_created_layers[name])
        return layer
    def _create_transformer_layer(self):
        return self._create_and_cache_layer(
            'transformer_layer',
            lambda: TransformerLayer(self.n_hidden, self.n_heads)
        )

    def _create_embedding_layer(self):
        return Embedding(self.n_vocab, self.n_hidden)

    def _create_final_norm_layer(self):
        return FinalNorm(self.n_hidden)

    def _create_readout_layer(self):
        return ReadoutLayer(self.n_hidden, self.n_vocab)
    @torch.no_grad()
    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
Embedding layer
        if 0 in self.filter_layers:
            with monit.section('Embedding layer'):
                layer = self._prepare_layer(self._create_embedding_layer())
                yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')
Transformer layers
        for i in range(self.n_layers):
Transformer layer
            if i + 1 in self.filter_layers:
                with monit.section(f'Transformer Layer {i}'):
                    yield self._create_transformer_layer(), \
                        (f'layer_{i + 2 :02d}-model_00-model_states.pt',
                         f'layer_{i + 2 :02d}-model_01-model_states.pt')
Final normalization layer
        if self.n_layers + 1 in self.filter_layers:
            with monit.section('Final norm layer'):
                layer = self._prepare_layer(self._create_final_norm_layer())
                yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')
Readout layer
        if self.n_layers + 2 in self.filter_layers:
            with monit.section('Readout layer'):
                layer = self._prepare_layer(self._create_readout_layer())
                yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')

        for k in self.pre_created_layers.keys():
            self.pre_created_layers[k] = None
    @property
    def total_layers(self):
        return self.n_layers + 3
    @torch.no_grad()
    def load(self) -> Generator[NeoXModule, None, None]:
        with monit.section("Layers"):
            for i, (layer, files) in enumerate(self.get_layers()):
                if files is not None:
                    layer.load_state(*checkpoint.load_checkpoint_files(files))

                layer = self.post_load_prepare(layer)

                monit.progress(min(0.99, (i + 1) / self.total_layers))
                yield layer
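A minimal end-to-end sketch of using the generator, assuming the 20B checkpoint files are already available to checkpoint.load_checkpoint_files:

generator = LayerGenerator(is_clone_layers=True,
                           dtype=torch.float16,
                           device=torch.device('cuda:0'))
# Layers are loaded lazily, one checkpoint file pair at a time
layers = list(generator.load())
# The layers can be stacked into a single module for inference
model = nn.Sequential(*layers)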