"""
---
title: GPT-NeoX Model Definition
summary: >
This is the model definition of GPT-NeoX.
---
# GPT-NeoX Model
Here is the code for the layers of the GPT-NeoX model and the code to load
the 20B checkpoint.
The `load_state` method in each layer loads the checkpoint of that layer.
The checkpoint loading helpers are in [`checkpoint.py`](checkpoint.html).
"""
import copy
import math
from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple
import torch
from torch import nn
from torch.cuda.amp import autocast
from labml import monit
from labml_nn.neox import checkpoint
from labml_nn.neox.utils.cache import get_cache
class NeoXModule(nn.Module):
"""
Base class for layers that can load their parameters from the two
model-parallel partitions of the checkpoint, `p1` and `p2`.
"""
def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
pass
class Embedding(NeoXModule):
"""
## Embedding layer
This is a standard embeddings layer with code to load the checkpoint.
"""
def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
"""
:param n_vocab: is the size of the vocabulary
:param n_hidden: is the size of the embeddings
"""
super().__init__()
self.emb = nn.Embedding(n_vocab, n_hidden)
def forward(self, x: torch.Tensor):
"""
:param x: are the token ids of shape `[batch_size, seq_len]`
"""
return self.emb(x)
def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
"""
Code to load the checkpoint
"""
with monit.section('Load embedding layer'):
checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)
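# The following is a tiny usage sketch, not part of the original model code:
# with assumed toy sizes, token ids of shape `[batch_size, seq_len]` map to
# embedding vectors with `n_hidden` features.
def _example_embedding_usage():
    emb = Embedding(n_vocab=1_000, n_hidden=64)
    ids = torch.randint(0, 1_000, (2, 8))
    # Output has shape `[batch_size, seq_len, n_hidden]`
    assert emb(ids).shape == (2, 8, 64)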
class RoPE(nn.Module):
"""
## Rotary Positional Embeddings
GPT-NeoX uses [rotary positional embeddings (RoPE)](https://papers.labml.ai/paper/2104.09864).
We have an annotated implementation of RoPE [here](https://nn.labml.ai/transformers/rope/index.html)
with more notes on the theory.
"""
def __init__(self, d_rope: int, base: float = 10_000.):
"""
:param d_rope: is the number of features for RoPE embeddings
:param base: is the base for $\theta_i = 10000^{-\frac{2(i-1)}{d}}$, which defaults to $10000$
"""
super().__init__()
# To store $\theta_i$ for the features
self.theta = None
# Cache $\cos m\theta_i$ and $\sin m\theta_i$
self.cos_cached = None
self.sin_cached = None
# Base for $\theta_i = 10000^{-\frac{2(i-1)}{d}}$
self.base = base
# Number of features for RoPE
self.d_rope = d_rope
@staticmethod
def rotate_half(x: torch.Tensor):
"""
### Rotate the features
$[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
"""
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def forward(self, x: torch.Tensor, offset: int = 0):
"""
:param x: has shape `[..., seq, n_heads, d_k]`
:param offset: is the starting position of `x`. This is $\gt 0$ when we have
cached the keys and queries of previous positions
"""
# Get the actual sequence length
seq_len = x.shape[-3] + offset
# Initialize $\theta$
if self.theta is None:
# $\theta_i = 10000^{-\frac{2(i-1)}{d}}$
theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
self.theta = theta.to(x.device).to(x.dtype)
# Initialize $\cos m\theta_i$ and $\sin m\theta_i$ cache
if (
self.cos_cached is None or
seq_len > self.cos_cached.shape[0] or
self.cos_cached.device != x.device or
self.cos_cached.dtype != x.dtype
):
# Get position indexes $m$
seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
# $m \theta_i$
idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)
# Concatenate so that for row $m$ we have
#
# $$[m \theta_1, m \theta_2, ..., m \theta_{\frac{d}{2}}, m \theta_1, m \theta_2, ..., m \theta_{\frac{d}{2}}]$$
idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)
# Calculate $\cos m\theta_i$ and $\sin m\theta_i$ in fp32
with autocast(enabled=False):
idx_theta2 = idx_theta2.float()
# Add head dimension
self.cos_cached = idx_theta2.cos()[:, None, :]
self.sin_cached = idx_theta2.sin()[:, None, :]
# Cache them
self.cos_cached = self.cos_cached.to(x.dtype)
self.sin_cached = self.sin_cached.to(x.dtype)
# Split the features. We apply RoPE to only `d_rope` features
x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]
# Get the sin and cos values from the cache
cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]
# RoPE embeddings
#
# \begin{align}
# \begin{pmatrix}
# x^{(i)}_m \cos m \theta_i - x^{(i + \frac{d}{2})}_m \sin m \theta_i \\
# x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m \theta_i \\
# \end{pmatrix} \\
# \end{align}
#
# for $i \in \{1, 2, ..., \frac{d}{2}\}$
x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)
# Concatenate with features that didn't get RoPE embeddings
return torch.cat((x_rope, x_pass), dim=-1)
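# A minimal sketch of using `RoPE` (illustrative only, with assumed shapes):
# the input has shape `[batch_size, seq_len, n_heads, d_k]`, the output keeps the
# same shape, and only the first `d_rope` features of each head get rotated while
# the remaining features pass through unchanged.
def _example_rope_usage():
    batch_size, seq_len, n_heads, d_k = 1, 8, 2, 16
    rope = RoPE(d_rope=d_k // 4)
    x = torch.randn(batch_size, seq_len, n_heads, d_k)
    y = rope(x)
    # Shape is preserved
    assert y.shape == x.shape
    # Features beyond `d_rope` are not modified
    assert torch.allclose(y[..., rope.d_rope:], x[..., rope.d_rope:])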
class AttentionLayer(nn.Module):
"""
## Attention layer
"""
def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
mask_fill: float = -10_000.0):
"""
:param n_hidden: is the number of features in the embeddings
:param n_heads: is the number of attention heads
:param rope_percentage: is the fraction of features RoPE embeddings are applied to
:param mask_fill: is the fill value for masked-out positions in the attention matrix
"""
super().__init__()
self.n_heads = n_heads
self.mask_fill = mask_fill
# Linear layer for query, key and value
self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)
# Final linear layer
self.output = nn.Linear(n_hidden, n_hidden)
# Number of features per head
d_k = n_hidden // n_heads
# RoPE embedding module
self.rope = RoPE(int(d_k * rope_percentage))
# Attention scaling factor
self.scale = 1 / math.sqrt(d_k)
# To cache causal mask
self.causal_mask = None
# Attention softmax module
self.softmax = nn.Softmax(dim=-2)
def _get_mask(self, attn: torch.Tensor):
"""
#### Calculate the causal mask
* `attn` has shape `[batch_size, query_seq_len, key_seq_len, n_heads]`
"""
# Query and key lengths
nq, nk = attn.shape[1:3]
# Create mask
if (
self.causal_mask is None or
self.causal_mask.shape[0] != nq or
self.causal_mask.shape[1] != nk or
self.causal_mask.device != attn.device
):
self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)
# Return from cache
return self.causal_mask[None, :, :, None]
def forward(self, x: torch.Tensor):
"""
:param x: has shape `[batch_size, seq_len, n_hidden]`
"""
# Get query, key and value embeddings (all concatenated).
# The last dimension size will change from `n_hidden` to `3 * n_hidden`
qkv = self.qkv_lin(x)
# Split into heads by changing the shape to `[batch_size, seq_len, n_heads, 3 * d_k]`
qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)
# Split into query, key and value, each of shape `[batch_size, seq_len, n_heads, d_k]`
q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)
# If we are caching the states of previous tokens
if get_cache().get('use_cache', False):
# Get the state ids. We use them to retrieve the previous states and to store the next states
prev_state_id, next_state_id = get_cache().get('state_ids')
# If there is a cached state
if prev_state_id is not None:
# Get the past keys and values. These will have shape `[batch_size, prev_seq_len, n_heads, d_k]`
k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')
# Offset of the current embeddings
offset = k_past.shape[1]
# Add RoPE embeddings
q = self.rope(q, offset=offset)
k = self.rope(k, offset=offset)
# Concatenate the past
k = torch.cat([k_past, k], dim=1)
v = torch.cat([v_past, v], dim=1)
else:
# Add RoPE embeddings
q = self.rope(q)
k = self.rope(k)
# Save the current state
get_cache().push(f'attn_kv_{next_state_id}', (k, v))
else:
# No cache - simply add RoPE embeddings
q = self.rope(q)
k = self.rope(k)
# Disable auto-casting to fp16 for attention computation
with autocast(enabled=False):
if q.dtype == torch.float16:
# Convert to fp32 if the current dtype is fp16
attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
else:
# Do not cast for bfloat
attn = torch.einsum('bihk,bjhk->bijh', q, k)
# Scale attention
attn = attn * self.scale
# Get causal mask
mask = self._get_mask(attn)
# Apply mask
attn.masked_fill_(mask, self.mask_fill)
# Attention softmax
attn = self.softmax(attn)
# Get attention weighted values
output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)
# Reshape from `[batch_size, seq_len, n_heads, d_k]` to `[batch_size, seq_len, n_hidden]`
output = output.reshape(*x.shape)
# Final linear layer
return self.output(output)
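# A minimal forward-pass sketch for `AttentionLayer` (illustrative only; the sizes
# here are assumed toy values, not the 20B configuration). It runs without the
# generation cache, so RoPE is applied with `offset=0` and the causal mask covers
# the whole sequence.
def _example_attention_usage():
    layer = AttentionLayer(n_hidden=128, n_heads=4)
    x = torch.randn(2, 10, 128)
    # Output keeps the input shape `[batch_size, seq_len, n_hidden]`
    assert layer(x).shape == x.shape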
class FFNLayer(nn.Module):
"""
## Feedforward Network
"""
def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
"""
:param n_hidden: is the embedding size
:param d_ff: is the size of the hidden layer; it defaults to `4 * n_hidden` when set to `0`
"""
super().__init__()
if not d_ff:
d_ff = n_hidden * 4
# Expansion linear layer
self.dense_h_h4 = nn.Linear(n_hidden, d_ff)
# GELU activation
self.activation = nn.GELU()
# Contraction linear layer
self.dense_h4_h = nn.Linear(d_ff, n_hidden)
def forward(self, x: torch.Tensor):
"""
:param x: has shape `[batch_size, seq_len, n_hidden]`
"""
x = self.dense_h_h4(x)
x = self.activation(x)
x = self.dense_h4_h(x)
return x
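# A small sketch of the feedforward network (illustrative, assumed sizes): it expands
# the features to `4 * n_hidden` by default, applies GELU, and projects back, so the
# output shape matches the input shape.
def _example_ffn_usage():
    ffn = FFNLayer(n_hidden=64)
    # Default hidden size of the FFN is `4 * n_hidden`
    assert ffn.dense_h_h4.out_features == 4 * 64
    x = torch.randn(2, 5, 64)
    assert ffn(x).shape == x.shape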
class TransformerLayer(NeoXModule):
"""
## Transformer Layer
"""
def __init__(self, n_hidden: int = 6_144, n_heads: int = 64):
"""
:param n_hidden: is the embedding size
:param n_heads: is the number of heads
*Our implementation doesn't include dropout.*
"""
super().__init__()
# Layer normalization before attention
self.pre_ln_attn = nn.LayerNorm(n_hidden)
# Layer normalization before FFN
self.pre_ln_ffn = nn.LayerNorm(n_hidden)
# Attention layer
self.attention = AttentionLayer(n_hidden, n_heads)
# FFN layer
self.ffn = FFNLayer(n_hidden)
def forward(self, x: torch.Tensor):
"""
:param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`
"""
# Residual connection
residual = x
# NeoX runs attention and feedforward network in parallel
attn = self.attention(self.pre_ln_attn(x))
ffn = self.ffn(self.pre_ln_ffn(x))
# Add them and the residual connection
return attn + ffn + residual
def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
"""
Code to load the checkpoint
"""
with monit.section('Load transformer layer'):
# Attention output transform
checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)
# Attention query, key and value transform
checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)
# Layer norm before attention
checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)
# FFN first transform (expansion)
checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)
# FFN second transform (contraction)
checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)
# Layer norm before FFN
checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)
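# An illustrative sketch of the parallel residual structure (assumed toy sizes):
# attention and the feedforward network both read the same layer input, and the layer
# output is `x + attention(ln_attn(x)) + ffn(ln_ffn(x))` rather than the usual
# sequential arrangement.
def _example_transformer_layer_usage():
    layer = TransformerLayer(n_hidden=64, n_heads=4)
    x = torch.randn(2, 6, 64)
    y = layer(x)
    # Writing out the same computation explicitly
    manual = x + layer.attention(layer.pre_ln_attn(x)) + layer.ffn(layer.pre_ln_ffn(x))
    assert torch.allclose(y, manual, atol=1e-6)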
class FinalNorm(NeoXModule):
"""
## Final normalization layer
"""
def __init__(self, n_hidden: int = 6_144):
"""
:param n_hidden: is the embedding size
"""
super().__init__()
self.ln = nn.LayerNorm(n_hidden)
def forward(self, x: torch.Tensor):
"""
:param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`
"""
return self.ln(x)
def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
"""
Code to load the checkpoint
"""
with monit.section('Load final normalization layer'):
checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)
class ReadoutLayer(NeoXModule):
"""
## Readout layer
"""
def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
"""
:param n_hidden: is the embedding size
:param n_vocab: is the size of the vocabulary
"""
super().__init__()
self.linear = nn.Linear(n_hidden, n_vocab, bias=False)
def forward(self, x: torch.Tensor):
"""
:param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`
"""
return self.linear(x)
def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
"""
Code to load the checkpoint
"""
with monit.section('Load final linear layer'):
checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
class LayerGenerator:
pre_created_layers: Dict[Any, Optional[NeoXModule]]
def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
n_layers: int = 44, n_heads: int = 64,
filter_layers: Optional[Set] = None,
is_clone_layers: bool = True,
dtype: torch.dtype = torch.float,
device: torch.device = torch.device('cpu'),
is_llm_int8: bool = False,
llm_int8_threshold: float = 6.0,
):
"""
### Generator to create layers
The layers are generated in the same order as the checkpoints.
We use the same layer indices as NeoX; layers that are not in `filter_layers` are skipped,
and there are two transformation layers in the NeoX checkpoint that we don't need in our implementation.
:param n_vocab: is the number of tokens in the vocabulary
:param n_hidden: is the number of features in the embeddings
:param n_layers: is the number of transformer layers
:param n_heads: is the number of attention heads
:param filter_layers: are the set of layers to be used. All layers will be used if None.
This is used to test smaller versions of the model with fewer layers
:param is_clone_layers: specifies whether to clone the transformer layers (a bit faster)
:param dtype: is the data type of the model
:param device: is the device of the model
:param is_llm_int8: specifies whether to use int8 quantization
:param llm_int8_threshold: is the threshold $\alpha$ used to separate outlier features
"""
if filter_layers is None:
filter_layers = set(range(n_layers + 3))
self.n_vocab = n_vocab
self.n_hidden = n_hidden
self.n_layers = n_layers
self.n_heads = n_heads
self.filter_layers = filter_layers
self.is_clone_layers = is_clone_layers
self.dtype = dtype
self.device = device
self.is_llm_int8 = is_llm_int8
self.llm_int8_threshold = llm_int8_threshold
self.pre_created_layers = dict(
transformer_layer=None,
)
def _prepare_layer(self, layer: NeoXModule):
"""
#### Prepares the layer for usage
We move the layer to the device and convert it to the correct data type
:param layer: is the layer to prepare
:return: the prepared layer
"""
return layer.to(self.device, self.dtype)
@torch.no_grad()
def post_load_prepare(self, layer: NeoXModule, *,
is_llm_int8: Optional[bool] = None,
device: Optional[torch.device] = None,
llm_int8_threshold: Optional[float] = None,
):
"""
<a id="post_load_prepare"></a>
### Layer transformations after loading the checkpoint
This function implements layer transformations after loading the checkpoint.
Currently, it only applies the int8 quantization.
:param layer: is the layer to prepare
:param is_llm_int8: specifies whether to use int8 quantization
:param device: is the device of the model
:param llm_int8_threshold: is the threshold $\alpha$ used to separate outlier features
:return: the prepared layer
"""
# Get default values if not specified
if is_llm_int8 is None:
is_llm_int8 = self.is_llm_int8
if device is None:
device = self.device
if llm_int8_threshold is None:
llm_int8_threshold = self.llm_int8_threshold
# Skip if not using int8 quantization
if not is_llm_int8:
return layer
# Only convert the linear layers in the transformer layers
if not isinstance(layer, TransformerLayer):
return layer
# Use `make_llm_int8_linear` defined in [utilities](./utils/llm_int8.html).
from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear
# Convert the linear layers
with monit.section('Convert to int8'):
layer.attention.output = make_llm_int8_linear(layer.attention.output,
device=device,
threshold=llm_int8_threshold)
layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
device=device,
threshold=llm_int8_threshold)
layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
device=device,
threshold=llm_int8_threshold)
layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
device=device,
threshold=llm_int8_threshold)
#
return layer
def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
"""
#### Creates and caches a layer
Copying cached layers is faster than initializing new layers because it takes time to
initialize parameters.
:param name: is the name of the layer
:param creator: is the function to create the layer
:return: the created layer or a copy of the cached layer
"""
if not self.is_clone_layers:
return self._prepare_layer(creator())
if self.pre_created_layers[name] is None:
self.pre_created_layers[name] = self._prepare_layer(creator())
layer = copy.deepcopy(self.pre_created_layers[name])
return layer
def _create_transformer_layer(self):
return self._create_and_cache_layer(
'transformer_layer',
lambda: TransformerLayer(self.n_hidden, self.n_heads)
)
def _create_embedding_layer(self):
return Embedding(self.n_vocab, self.n_hidden)
def _create_final_norm_layer(self):
return FinalNorm(self.n_hidden)
def _create_readout_layer(self):
return ReadoutLayer(self.n_hidden, self.n_vocab)
@torch.no_grad()
def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
"""
### Generator to get layers
"""
# Embedding layer
if 0 in self.filter_layers:
with monit.section('Embedding layer'):
layer = self._prepare_layer(self._create_embedding_layer())
yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')
# Transformer layers
for i in range(self.n_layers):
# Transformer layer
if i + 1 in self.filter_layers:
with monit.section(f'Transformer Layer {i}'):
yield self._create_transformer_layer(), \
(f'layer_{i + 2 :02d}-model_00-model_states.pt',
f'layer_{i + 2 :02d}-model_01-model_states.pt')
# Final normalization layer
if self.n_layers + 1 in self.filter_layers:
with monit.section('Final norm layer'):
layer = self._prepare_layer(self._create_final_norm_layer())
yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')
# Readout layer
if self.n_layers + 2 in self.filter_layers:
with monit.section('Readout layer'):
layer = self._prepare_layer(self._create_readout_layer())
yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
for k in self.pre_created_layers.keys():
self.pre_created_layers[k] = None
@property
def total_layers(self):
"""
### Returns the total number of layers
"""
return self.n_layers + 3
@torch.no_grad()
def load(self) -> Generator[NeoXModule, None, None]:
"""
### Generator to load layers
"""
with monit.section("Layers"):
for i, (layer, files) in enumerate(self.get_layers()):
if files is not None:
layer.load_state(*checkpoint.load_checkpoint_files(files))
layer = self.post_load_prepare(layer)
monit.progress(min(0.99, (i + 1) / self.total_layers))
yield layer
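# A minimal sketch of assembling a model skeleton with `LayerGenerator` (illustrative
# only): it uses assumed toy sizes instead of the real configuration
# (`n_vocab=50_432`, `n_hidden=6_144`, `n_layers=44`, `n_heads=64`) and does not load
# the checkpoint files, so the layers keep their random initialization.
def _example_layer_generator_usage():
    generator = LayerGenerator(n_vocab=1_000, n_hidden=64, n_layers=2, n_heads=4,
                               is_clone_layers=False,
                               dtype=torch.float,
                               device=torch.device('cpu'))
    # Collect the layers in checkpoint order; the checkpoint file names are ignored here
    layers = [layer for layer, _ in generator.get_layers()]
    model = nn.Sequential(*layers)
    # Token ids in, logits over the vocabulary out
    ids = torch.randint(0, 1_000, (1, 8))
    assert model(ids).shape == (1, 8, 1_000)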