"""
|
|
---
|
|
title: GPT-NeoX Model Definition
|
|
summary: >
|
|
This is the model definition of GPT-NeoX.
|
|
---
|
|
|
|
# GPT-NeoX Model
|
|
|
|
Here is the code for layers of GPT-NeoX model and the code to load
|
|
20B checkpoint.
|
|
|
|
The method `load_state` in the layers load the checkpoints of that layer.
|
|
The checkpoint loading helpers are on [`checkpoint.py`](checkpoint.html)
|
|
"""
|
|
import copy
import math
from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple

import torch
from torch import nn
from torch.cuda.amp import autocast

from labml import monit
from labml_nn.neox import checkpoint
from labml_nn.neox.utils.cache import get_cache


class NeoXModule(nn.Module):
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        pass


class Embedding(NeoXModule):
    """
    ## Embedding layer

    This is a standard embedding layer with code to load the checkpoint.
    """

    def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
        """
        :param n_vocab: is the size of the vocabulary
        :param n_hidden: is the size of the embeddings
        """
        super().__init__()

        self.emb = nn.Embedding(n_vocab, n_hidden)

    def forward(self, x: torch.Tensor):
        """
        :param x: are the token ids of shape `[batch_size, seq_len]`
        """
        return self.emb(x)

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Code to load the checkpoint
        """
        with monit.section('Load embedding layer'):
            checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)


class RoPE(nn.Module):
    """
    ## Rotary Positional Embeddings

    GPT-NeoX uses [rotary positional embeddings (RoPE)](https://papers.labml.ai/paper/2104.09864).

    We have an annotated implementation of RoPE [here](https://nn.labml.ai/transformers/rope/index.html)
    with more notes on the theory.
    """

    def __init__(self, d_rope: int, base: float = 10_000.):
        """
        :param d_rope: is the number of features for RoPE embeddings
        :param base: is the base for $\theta_i = 10000^{-\frac{2(i-1)}{d}}$, which defaults to $10000$
        """
        super().__init__()

        # To store $\theta_i$ for the features
        self.theta = None
        # Cache $\cos m\theta_i$ and $\sin m\theta_i$
        self.cos_cached = None
        self.sin_cached = None

        # Base for $\theta_i = 10000^{-\frac{2(i-1)}{d}}$
        self.base = base
        # Number of features for RoPE
        self.d_rope = d_rope

    @staticmethod
    def rotate_half(x: torch.Tensor):
        """
        ### Rotate the features

        $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        """
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x: torch.Tensor, offset: int = 0):
        """
        :param x: has shape `[..., seq, n_heads, d_k]`
        :param offset: is the starting position of `x`. This is $\gt 0$ when we have
            cached the keys and queries of previous positions
        """

        # Get the actual sequence length
        seq_len = x.shape[-3] + offset

        # Initialize $\theta$
        if self.theta is None:
            # $\theta_i = 10000^{-\frac{2(i-1)}{d}}$
            theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
            self.theta = theta.to(x.device).to(x.dtype)

        # Initialize $\cos m\theta_i$ and $\sin m\theta_i$ cache.
        # The cache is keyed on the cached sequence length (dimension `0`), device and dtype.
        if (
                self.cos_cached is None or
                seq_len > self.cos_cached.shape[0] or
                self.cos_cached.device != x.device or
                self.cos_cached.dtype != x.dtype
        ):
            # Get position indexes $m$
            seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
            # $m \theta_i$
            idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)
            # Concatenate so that for row $m$ we have
            #
            # $$[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$$
            idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)

            # Calculate $\cos m\theta_i$ and $\sin m\theta_i$ in fp32
            with autocast(enabled=False):
                idx_theta2 = idx_theta2.float()
                # Add head dimension
                self.cos_cached = idx_theta2.cos()[:, None, :]
                self.sin_cached = idx_theta2.sin()[:, None, :]

            # Cache them
            self.cos_cached = self.cos_cached.to(x.dtype)
            self.sin_cached = self.sin_cached.to(x.dtype)

        # Split the features. We apply RoPE to only `d_rope` features
        x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]

        # Get the sin and cos values from the cache
        cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]

        # RoPE embeddings
        #
        # \begin{align}
        # \begin{pmatrix}
        # x^{(i)}_m \cos m \theta_i - x^{(i + \frac{d}{2})}_m \sin m \theta_i \\
        # x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m \theta_i \\
        # \end{pmatrix} \\
        # \end{align}
        #
        # for $i \in {1, 2, ..., \frac{d}{2}}$
        x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)

        # Concatenate with features that didn't get RoPE embeddings
        return torch.cat((x_rope, x_pass), dim=-1)


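# A small usage sketch of `RoPE` (an illustrative addition, not part of the original module).
# The shapes and CPU/fp32 tensors below are arbitrary assumptions.
def _rope_usage_sketch():
    # Queries for a batch of 2 sequences of length 8, with 4 heads of 64 features each
    q = torch.randn(2, 8, 4, 64)
    # Apply RoPE to the first 16 features of each head
    rope = RoPE(d_rope=16)
    # Without a key/value cache, positions start at 0
    q_rot = rope(q)
    # With the first 5 positions cached, only the last 3 positions are passed,
    # together with `offset=5` so they get the correct absolute positions
    q_rot_tail = rope(q[:, 5:], offset=5)
    # The results agree because RoPE depends only on the absolute position
    assert torch.allclose(q_rot[:, 5:], q_rot_tail, atol=1e-6)

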
class AttentionLayer(nn.Module):
    """
    ## Attention layer
    """

    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
                 mask_fill: float = -10_000.0):
        """
        :param n_hidden: the number of features in embeddings
        :param n_heads: the number of attention heads
        :param rope_percentage: percentage of features to add RoPE embeddings
        :param mask_fill: masking fill value for attention matrix
        """
        super().__init__()

        self.n_heads = n_heads
        self.mask_fill = mask_fill

        # Linear layer for query, key and value
        self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)
        # Final linear layer
        self.output = nn.Linear(n_hidden, n_hidden)

        # Number of features per head
        d_k = n_hidden // n_heads
        # RoPE embedding module
        self.rope = RoPE(int(d_k * rope_percentage))

        # Attention scaling factor
        self.scale = 1 / math.sqrt(d_k)

        # To cache causal mask
        self.causal_mask = None

        # Attention softmax module
        self.softmax = nn.Softmax(dim=-2)

    def _get_mask(self, attn: torch.Tensor):
        """
        #### Calculate the causal mask

        * `attn` has shape `[batch_size, query_seq_len, key_seq_len, n_heads]`
        """

        # Query and key lengths
        nq, nk = attn.shape[1:3]

        # Create mask
        if (
                self.causal_mask is None or
                self.causal_mask.shape[0] != nq or
                self.causal_mask.shape[1] != nk or
                self.causal_mask.device != attn.device
        ):
            self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)

        # Return from cache
        return self.causal_mask[None, :, :, None]

    def forward(self, x: torch.Tensor):
        """
        :param x: has shape `[batch_size, seq_len, n_hidden]`
        """
        # Get query, key and value embeddings (all concatenated).
        # The last dimension size will change from `n_hidden` to `3 * n_hidden`
        qkv = self.qkv_lin(x)

        # Split into heads by changing the shape to `[batch_size, seq_len, n_heads, 3 * d_k]`
        qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)
        # Split into query, key and value, each of shape `[batch_size, seq_len, n_heads, d_k]`
        q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)

        # If we are caching the states of previous tokens
        if get_cache().get('use_cache', False):
            # Get the state ids. We use them to retrieve the previous states and to store the next states
            prev_state_id, next_state_id = get_cache().get('state_ids')
            # If there's a cache
            if prev_state_id is not None:
                # Get the past keys and values. These will have shape `[batch_size, prev_seq_len, n_heads, d_k]`
                k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')
                # Offset of the current embeddings
                offset = k_past.shape[1]

                # Add RoPE embeddings
                q = self.rope(q, offset=offset)
                k = self.rope(k, offset=offset)

                # Concatenate the past
                k = torch.cat([k_past, k], dim=1)
                v = torch.cat([v_past, v], dim=1)
            else:
                # Add RoPE embeddings
                q = self.rope(q)
                k = self.rope(k)

            # Save the current state
            get_cache().push(f'attn_kv_{next_state_id}', (k, v))
        else:
            # No cache - simply add RoPE embeddings
            q = self.rope(q)
            k = self.rope(k)

        # Disable auto-casting to fp16 for attention computation
        with autocast(enabled=False):
            if q.dtype == torch.float16:
                # Convert to fp32 if the current dtype is fp16
                attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
            else:
                # Do not cast for bfloat
                attn = torch.einsum('bihk,bjhk->bijh', q, k)

            # Scale attention
            attn = attn * self.scale

            # Get causal mask
            mask = self._get_mask(attn)
            # Apply mask
            attn.masked_fill_(mask, self.mask_fill)

            # Attention softmax
            attn = self.softmax(attn)

        # Get attention weighted values
        output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)

        # Reshape from `[batch_size, seq_len, n_heads, d_k]` to `[batch_size, seq_len, n_hidden]`
        output = output.reshape(*x.shape)

        # Final linear layer
        return self.output(output)


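# A quick shape check for the attention layer (an illustrative addition, not part of the
# original module). The sizes are small assumptions; the 20B model uses
# `n_hidden=6144` and `n_heads=64`.
def _attention_shape_sketch():
    attn_layer = AttentionLayer(n_hidden=256, n_heads=4)
    # `[batch_size, seq_len, n_hidden]`
    x = torch.randn(2, 10, 256)
    # Causal self-attention keeps the shape: `[2, 10, 256]`
    y = attn_layer(x)
    assert y.shape == x.shape

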
class FFNLayer(nn.Module):
    """
    ## Feedforward Network
    """

    def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
        """
        :param n_hidden: is the embedding size
        :param d_ff: is the size of the hidden dimension of the FFN; defaults to `4 * n_hidden` when `0`
        """
        super().__init__()

        if not d_ff:
            d_ff = n_hidden * 4

        # Expansion linear layer
        self.dense_h_h4 = nn.Linear(n_hidden, d_ff)
        # GELU activation
        self.activation = nn.GELU()
        # Contraction linear layer
        self.dense_h4_h = nn.Linear(d_ff, n_hidden)

    def forward(self, x: torch.Tensor):
        """
        :param x: has shape `[batch_size, seq_len, n_hidden]`
        """
        x = self.dense_h_h4(x)
        x = self.activation(x)
        x = self.dense_h4_h(x)

        return x


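# A minimal shape sketch for the feedforward network (an illustrative addition, not part
# of the original module); the hidden dimension defaults to `4 * n_hidden`.
def _ffn_shape_sketch():
    ffn = FFNLayer(n_hidden=256)
    # `[batch_size, seq_len, n_hidden]` in, same shape out
    x = torch.randn(2, 10, 256)
    assert ffn(x).shape == (2, 10, 256)

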
class TransformerLayer(NeoXModule):
    """
    ## Transformer Layer
    """

    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64):
        """
        :param n_hidden: is the embedding size
        :param n_heads: is the number of heads

        *Our implementation doesn't include dropout*.
        """
        super().__init__()

        # Layer normalization before attention
        self.pre_ln_attn = nn.LayerNorm(n_hidden)
        # Layer normalization before FFN
        self.pre_ln_ffn = nn.LayerNorm(n_hidden)

        # Attention layer
        self.attention = AttentionLayer(n_hidden, n_heads)
        # FFN layer
        self.ffn = FFNLayer(n_hidden)

    def forward(self, x: torch.Tensor):
        """
        :param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`
        """

        # Residual connection
        residual = x
        # NeoX runs attention and feedforward network in parallel
        attn = self.attention(self.pre_ln_attn(x))
        ffn = self.ffn(self.pre_ln_ffn(x))
        # Add them and the residual connection
        return attn + ffn + residual

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Code to load the checkpoint
        """
        with monit.section('Load transformer layer'):
            # Attention output transform
            checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)

            # Attention query, key and value transform
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)

            # Layer norm before attention
            checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)

            # FFN first (expansion) transform
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)

            # FFN second (contraction) transform
            checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)

            # Layer norm before FFN
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)


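# A small sketch of the parallel-residual block above (an illustrative addition, not part
# of the original module): the output is `x + attention(ln_attn(x)) + ffn(ln_ffn(x))`,
# rather than the sequential GPT-2 style `x' = x + attention(ln_attn(x)); y = x' + ffn(ln_ffn(x'))`.
def _transformer_layer_sketch():
    layer = TransformerLayer(n_hidden=256, n_heads=4)
    x = torch.randn(1, 12, 256)
    y = layer(x)
    # The block preserves the embedding shape
    assert y.shape == x.shape
    # And equals the sum of the residual and the two parallel branches
    expected = x + layer.attention(layer.pre_ln_attn(x)) + layer.ffn(layer.pre_ln_ffn(x))
    assert torch.allclose(y, expected, atol=1e-5)

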
class FinalNorm(NeoXModule):
    """
    ## Final normalization layer
    """

    def __init__(self, n_hidden: int = 6_144):
        """
        :param n_hidden: is the embedding size
        """
        super().__init__()

        self.ln = nn.LayerNorm(n_hidden)

    def forward(self, x: torch.Tensor):
        """
        :param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`
        """
        return self.ln(x)

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Code to load the checkpoint
        """
        with monit.section('Load final normalization layer'):
            checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)


class ReadoutLayer(NeoXModule):
    """
    ## Readout layer
    """

    def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
        """
        :param n_hidden: is the embedding size
        :param n_vocab: is the size of the vocabulary
        """
        super().__init__()

        self.linear = nn.Linear(n_hidden, n_vocab, bias=False)

    def forward(self, x: torch.Tensor):
        """
        :param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`
        """
        return self.linear(x)

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Code to load the checkpoint
        """
        with monit.section('Load final linear layer'):
            checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)


class LayerGenerator:
    pre_created_layers: Dict[Any, Optional[NeoXModule]]

    def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
                 n_layers: int = 44, n_heads: int = 64,
                 filter_layers: Optional[Set] = None,
                 is_clone_layers: bool = True,
                 dtype: torch.dtype = torch.float,
                 device: torch.device = torch.device('cpu'),
                 is_llm_int8: bool = False,
                 llm_int8_threshold: float = 6.0,
                 ):
        """
        ### Generator to create layers

        The layers are generated in the same order as the checkpoints.

        It gives `None` when a layer is not available; we use the same layer indices as NeoX,
        and there are two transformation layers we don't need in our implementation.

        :param n_vocab: is the number of tokens in the vocabulary
        :param n_hidden: is the number of features in the embeddings
        :param n_layers: is the number of transformer layers
        :param n_heads: is the number of attention heads
        :param filter_layers: are the set of layers to be used. All layers will be used if None.
            This is used to test smaller versions of the model with fewer layers
        :param is_clone_layers: specifies whether to clone the transformer layers (a bit faster)
        :param dtype: is the data type of the model
        :param device: is the device of the model
        :param is_llm_int8: specifies whether to use int8 quantization
        :param llm_int8_threshold: is the threshold $\alpha$ used to separate outlier features
        """
        if filter_layers is None:
            filter_layers = set(range(n_layers + 3))

        self.n_vocab = n_vocab
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.filter_layers = filter_layers
        self.is_clone_layers = is_clone_layers
        self.dtype = dtype
        self.device = device
        self.is_llm_int8 = is_llm_int8
        self.llm_int8_threshold = llm_int8_threshold

        self.pre_created_layers = dict(
            transformer_layer=None,
        )

    def _prepare_layer(self, layer: NeoXModule):
        """
        #### Prepares the layer for usage

        We move the layer to the device and convert it to the correct data type

        :param layer: is the layer to prepare
        :return: the prepared layer
        """
        return layer.to(self.device, self.dtype)

    @torch.no_grad()
    def post_load_prepare(self, layer: NeoXModule, *,
                          is_llm_int8: bool = None,
                          device: torch.device = None,
                          llm_int8_threshold: float = None,
                          ):
        """
        <a id="post_load_prepare"></a>

        ### Layer transformations after loading the checkpoint

        This function implements layer transformations after loading the checkpoint.

        Currently, it only applies the int8 quantization.

        :param layer: is the layer to prepare
        :param is_llm_int8: specifies whether to use int8 quantization
        :param device: is the device of the model
        :param llm_int8_threshold: is the threshold $\alpha$ used to separate outlier features
        :return: the prepared layer
        """

        # Get default values if not specified
        if is_llm_int8 is None:
            is_llm_int8 = self.is_llm_int8
        if device is None:
            device = self.device
        if llm_int8_threshold is None:
            llm_int8_threshold = self.llm_int8_threshold

        # Skip if not using int8 quantization
        if not is_llm_int8:
            return layer

        # Only convert the linear layers in the transformer layers
        if not isinstance(layer, TransformerLayer):
            return layer

        # Use `make_llm_int8_linear` defined in [utilities](./utils/llm_int8.html).
        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear

        # Convert the linear layers
        with monit.section('Convert to int8'):
            layer.attention.output = make_llm_int8_linear(layer.attention.output,
                                                          device=device,
                                                          threshold=llm_int8_threshold)
            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
                                                           device=device,
                                                           threshold=llm_int8_threshold)
            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
                                                        device=device,
                                                        threshold=llm_int8_threshold)
            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
                                                        device=device,
                                                        threshold=llm_int8_threshold)
        #
        return layer

    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
        """
        #### Creates and caches a layer

        Copying cached layers is faster than initializing new layers because it takes time to
        initialize parameters.

        :param name: is the name of the layer
        :param creator: is the function to create the layer
        :return: the created layer or a copy of the cached layer
        """

        if not self.is_clone_layers:
            return self._prepare_layer(creator())

        if self.pre_created_layers[name] is None:
            self.pre_created_layers[name] = self._prepare_layer(creator())

        layer = copy.deepcopy(self.pre_created_layers[name])
        return layer

    def _create_transformer_layer(self):
        return self._create_and_cache_layer(
            'transformer_layer',
            lambda: TransformerLayer(self.n_hidden, self.n_heads)
        )

    def _create_embedding_layer(self):
        return Embedding(self.n_vocab, self.n_hidden)

    def _create_final_norm_layer(self):
        return FinalNorm(self.n_hidden)

    def _create_readout_layer(self):
        return ReadoutLayer(self.n_hidden, self.n_vocab)

    @torch.no_grad()
    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
        """
        ### Generator to get layers
        """
        # Embedding layer
        if 0 in self.filter_layers:
            with monit.section('Embedding layer'):
                layer = self._prepare_layer(self._create_embedding_layer())
                yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')

        # Transformer layers
        for i in range(self.n_layers):
            # Transformer layer
            if i + 1 in self.filter_layers:
                with monit.section(f'Transformer Layer {i}'):
                    yield self._create_transformer_layer(), \
                          (f'layer_{i + 2 :02d}-model_00-model_states.pt',
                           f'layer_{i + 2 :02d}-model_01-model_states.pt')

        # Final normalization layer
        if self.n_layers + 1 in self.filter_layers:
            with monit.section('Final norm layer'):
                layer = self._prepare_layer(self._create_final_norm_layer())
                yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')

        # Readout layer
        if self.n_layers + 2 in self.filter_layers:
            with monit.section('Readout layer'):
                layer = self._prepare_layer(self._create_readout_layer())
                yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')

        for k in self.pre_created_layers.keys():
            self.pre_created_layers[k] = None

    @property
    def total_layers(self):
        """
        ### Returns the total number of layers
        """
        return self.n_layers + 3

    @torch.no_grad()
    def load(self) -> Generator[NeoXModule, None, None]:
        """
        ### Generator to load layers
        """
        with monit.section("Layers"):
            for i, (layer, files) in enumerate(self.get_layers()):
                if files is not None:
                    layer.load_state(*checkpoint.load_checkpoint_files(files))

                layer = self.post_load_prepare(layer)

                monit.progress(min(0.99, (i + 1) / self.total_layers))
                yield layer


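# A minimal sketch of how the generator is typically used to assemble the full model
# (an illustrative addition, not part of the original module). It assumes the checkpoint
# files have already been downloaded so that `checkpoint.load_checkpoint_files` can find them,
# and that a CUDA device is available.
def _build_model_sketch():
    generator = LayerGenerator(is_clone_layers=True,
                               dtype=torch.float16,
                               device=torch.device('cuda:0'))
    # `load()` yields the embedding layer, the 44 transformer layers,
    # the final layer norm and the readout layer, in checkpoint order
    layers = list(generator.load())
    # Stack them into a single module for inference
    return nn.Sequential(*layers)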