Zero3 memory optimizations (#140)

This commit is contained in:
Varuna Jayasiri
2022-08-11 15:44:13 +05:30
committed by GitHub
parent 0bfb210671
commit 980a84ed4f
43 changed files with 12573 additions and 9 deletions

View File

@ -125,6 +125,11 @@ Solving games with incomplete information such as poker with CFR.
* [Top-k Sampling](sampling/top_k.html)
* [Nucleus Sampling](sampling/nucleus.html)
#### ✨ [Eleuther GPT-NeoX](neox/index.html)
#### ✨ [Scalable Training/Inference](scaling/index.html)
* [Zero3 memory optimizations](scaling/zero3/index.html)
## Highlighted Research Paper PDFs
* [Autoregressive Search Engines: Generating Substrings as Document Identifiers](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf)

29
labml_nn/neox/__init__.py Normal file
View File

@ -0,0 +1,29 @@
"""
---
title: GPT-NeoX
summary: >
Simple GPT-NeoX implementation
---
# GPT-NeoX
This is a simple implementation of [Eleuther GPT-NeoX](https://papers.labml.ai/paper/2204.06745) for inference and fine-tuning.
* [Model definition](model.html)
* [Tokenizer](tokenizer.html)
* [Checkpoint downloading and loading helpers](checkpoint.html)
* [Utilities](utils/index.html)
### [Samples](samples/__init__.py)
* [Generating text](samples/generate.html)
* [Fine-tuning the biases with pipeline-parallel training](samples/finetune.html)
### [Evaluation](evaluation/__init__.py)
* [Evaluating half precision model on a single GPU](evaluation/half_precision.html)
**The official [Eleuther](https://www.eleuther.ai) GPT-NeoX source code is available at [eleutherai/gpt-neox](https://github.com/eleutherai/gpt-neox).**
"""

152
labml_nn/neox/checkpoint.py Normal file
View File

@ -0,0 +1,152 @@
"""
---
title: GPT-NeoX Checkpoints
summary: >
Code to download checkpoints and helpers to load them.
---
# GPT-NeoX Checkpoints
"""
from typing import Dict, Union, Tuple
import torch
from torch import nn
from labml import monit, lab, logger
from labml.logger import Text, inspect
from labml.utils.download import download_file
# Parent url
CHECKPOINTS_URL = 'https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/'
# Download path
CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
if not CHECKPOINTS_DOWNLOAD_PATH.exists():
CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
inspect(neox_checkpoint_path=CHECKPOINTS_DOWNLOAD_PATH)
def get_files_to_download(n_layers: int = 44):
"""
### Get files to download
:return: a list of files to be downloaded
"""
layers = (
# Embedding layer
[0] +
# Transformer layers
list(range(2, 2 + n_layers)) +
# Final normalization layer and readout layer
[47, 48]
)
return (
# Vocabulary and configs
['20B_tokenizer.json', 'configs/20B.yml', 'latest'] +
# Layer checkpoints
[f'global_step150000/layer_{i :02d}-model_{p :02d}-model_states.pt' for i in layers for p in range(2)] +
# Empty states (not used)
[f'global_step150000/mp_rank_{i :02d}_model_states.pt' for i in range(8)]
)
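# For example, with the default `n_layers = 44` the `layers` list has
# 47 entries (embedding, 44 transformer layers, final norm, readout),
# so this returns 3 + 47 * 2 + 8 = 105 file names.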
def download(n_layers: int = 44):
"""
## Download all checkpoint files
"""
# Get files to download
files = get_files_to_download(n_layers)
# Iterate
for i, f in monit.enum('Download All', files):
# Log
logger.log(['Downloading ', (f'{i + 1 :3d}/{len(files)}', Text.meta), ': ', (f, Text.value)])
# Download
download_file(CHECKPOINTS_URL + f, CHECKPOINTS_DOWNLOAD_PATH / f)
def load_checkpoint_files(files: Tuple[str, str]):
"""
### Load a pair of checkpoint files
:param files: pair of files to load
:return: the loaded parameter tensors
"""
checkpoint_path = CHECKPOINTS_DOWNLOAD_PATH / 'global_step150000'
with monit.section('Load checkpoint'):
data = [torch.load(checkpoint_path / f) for f in files]
return data
def merge_params_dim_0(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
p2: Dict[str, torch.Tensor]):
"""
### Load a parameter by merging the partitions along the first dimension
:param param: is the parameter
:param key: is the name of the parameter
:param p1: first partition dictionary
:param p2: second partition dictionary
"""
w1, w2 = p1[key], p2[key]
param.data[:w1.shape[0]] = w1
param.data[w1.shape[0]:] = w2
def merge_params_dim_1(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
p2: Dict[str, torch.Tensor]):
"""
### Load a parameter by merging the partitions along the second dimension
:param param: is the parameter
:param key: is the name of the parameter
:param p1: first partition dictionary
:param p2: second partition dictionary
"""
w1, w2 = p1[key], p2[key]
param.data[:, :w1.shape[1]] = w1
param.data[:, w1.shape[1]:] = w2
def merge_params_duplicate(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
p2: Dict[str, torch.Tensor]):
"""
### Load an un-partitioned parameter
This does a sanity check to make sure both partitions are the same
:param param: is the parameter
:param key: is the name of the parameter
:param p1: first partition dictionary
:param p2: second partition dictionary
"""
w1, w2 = p1[key], p2[key]
diff = sum((w1 - w2) ** 2).item()
assert diff < 1e-4, f'The partitions do not match: {key}'
param.data[:] = (w1 + w2) / 2.
def merge_params_sum(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
p2: Dict[str, torch.Tensor]):
"""
### Load biases that are partitioned and get summed on reduce
:param param: is the parameter
:param key: is the name of the parameter
:param p1: first partition dictionary
:param p2: second partition dictionary
"""
w1, w2 = p1[key], p2[key]
param.data[:] = w1 + w2
#
if __name__ == '__main__':
download()
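A minimal sketch of how the merge helpers combine two model-parallel partitions; the key name `'w'` and the tensor sizes are made up for illustration (real checkpoints use the keys shown in the [model definition](model.html)):
import torch
from torch import nn
from labml_nn.neox.checkpoint import merge_params_dim_0
param = nn.Parameter(torch.empty(4, 3))
p1 = {'w': torch.zeros(2, 3)}
p2 = {'w': torch.ones(2, 3)}
# The two partitions are concatenated along the first dimension
merge_params_dim_0(param, 'w', p1, p2)
assert (param.data[:2] == 0.).all() and (param.data[2:] == 1.).all()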

View File

@ -0,0 +1,262 @@
"""
---
title: Evaluation
summary: >
Code to evaluate the model on NLP tasks through lm-evaluation-harness
---
# Evaluation
This is the code to test the model on
[EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness).
* [Evaluating half precision model on a single GPU](half_precision.html)
"""
import math
from typing import List
import torch
import torch.nn.functional as F
from lm_eval import tasks, evaluator, utils
from lm_eval.base import BaseLM
from tokenizers import Tokenizer
from torch import nn
from tqdm import tqdm
from labml import monit
from labml_nn.neox.tokenizer import get_tokenizer
class EvalHarnessAdapter(BaseLM):
"""
## Evaluation Harness Adapter
This is based on the [adapter from EleutherAI/gpt-neox](https://github.com/EleutherAI/gpt-neox/blob/main/eval_tasks/eval_adapter.py)
"""
def __init__(self, tokenizer: Tokenizer, vocab_size: int, batch_size: int):
"""
:param tokenizer: is the [Huggingface Tokenizer](huggingface/tokenizers)
:param vocab_size: is the size of the vocabulary
(this differs from the tokenizer vocab size since NeoX adds a few extra tokens so that the embedding layer
can be split for model parallelism.)
:param batch_size: is the batch size
"""
super().__init__()
self.tokenizer = tokenizer
self._eot_token_id = self.tokenizer.token_to_id("<|endoftext|>")
self._vocab_size = vocab_size
self._batch_size = batch_size
@property
def device(self):
raise RuntimeError()
@property
def vocab_size(self):
"""Size of the vocabulary"""
return self._vocab_size
@property
def eot_token_id(self):
"""End-of-text token"""
return self._eot_token_id
@property
def max_length(self):
"""Maximum sequence length"""
return 2048
@property
def max_gen_toks(self):
"""Maximum number of tokens to generate"""
return 128
@property
def batch_size(self):
"""
Batch size
"""
return self._batch_size
def tok_encode(self, string: str):
"""
Encode a given text
"""
return self.tokenizer.encode(string).ids
def tok_decode(self, tokens: List[int]):
"""
Decode text from token ids
"""
return self.tokenizer.decode(tokens)
def _model_call(self, inps: torch.Tensor):
raise NotImplementedError
def _model_generate(self, context, max_length, eos_token_id):
raise RuntimeError()
def greedy_until(self, requests):
raise RuntimeError()
@torch.no_grad()
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
"""
### Get log-likelihoods of the next tokens
:param requests: List of requests containing the context and the expected continuation.
:param disable_tqdm: If True, disable tqdm progress bar.
"""
# For results
res = []
# Reorder the requests in the descending order of the lengths,
# so that sequences with similar lengths are close
def _collate(x):
toks = x[1] + x[2]
return -len(toks), tuple(toks)
reord = utils.Reorderer(requests, _collate)
# Loop through requests with `batch_size` number of requests at a time
for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
# To store the inputs for the batch
inps = []
# The continuations for the batch
continuations = []
# Lengths of the input sequences
inplens = []
# Padded length for the batch
padded_length = None
# Loop through each request in the chunk and collect them into PyTorch tensors with paddings
for _, context_enc, continuation_enc in chunk:
# Concatenate the context and continuation
inp = context_enc + continuation_enc
# Truncate from left if the size exceeds the `max_length`
inp = inp[-(self.max_length + 1):]
# Remove final token
inp = inp[:-1]
# Create a tensor
inp = torch.tensor(inp, dtype=torch.long)
# Input length
inplen = inp.shape[0]
# Determine the padded length.
# Shorter sequences will get padded.
if padded_length is None:
padded_length = int(math.ceil(inplen / 32)) * 32
# padded_length = padded_length if padded_length is not None else inplen
# Padding
padding = torch.zeros(padded_length - inplen, dtype=torch.long)
# Add padding
inp = torch.cat([inp, padding], dim=0)
inps.append(inp)
continuations.append(continuation_enc)
inplens.append(inplen)
# Get model logits
logits = self._model_call(torch.stack(inps))
# Get log softmaxes
multi_logits = F.log_softmax(logits, dim=-1)
# Loop through the input/output pairs of the batch
for logits, inplen, cont_toks in zip(multi_logits, inplens, continuations):
# Get number of predicted tokens
contlen = len(cont_toks)
# Get logits of those
logits = logits[inplen - contlen: inplen]
# Get the tokens with the highest probabilities
greedy_tokens = logits.argmax(dim=-1)
# Get the target tokens
cont_toks = torch.tensor(cont_toks, dtype=torch.long).to(logits.device)
# Whether there's an exact match
max_equal = (greedy_tokens == cont_toks).all()
# Log-likelihoods of the target tokens
logits = torch.gather(logits, 1, cont_toks[:, None])
# Add the total log-likelihoods and whether there was a match to the results
res.append((float(logits.sum()), bool(max_equal)))
# Re-order and return results
return reord.get_original(res)
@torch.no_grad()
def run_eval(self, name: str, eval_tasks: List[str]):
"""
### Run given evaluations
"""
# Run [EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) evaluator
results = evaluator.evaluate(lm=self, task_dict=tasks.get_task_dict(eval_tasks))
# Add configs
results["config"] = {
"name": name,
}
#
return results
class NeoXEvalHarnessAdapter(EvalHarnessAdapter):
"""
## Evaluation Harness Adapter
This is based on the [adapter from EleutherAI/gpt-neox](https://github.com/EleutherAI/gpt-neox/blob/main/eval_tasks/eval_adapter.py)
"""
def __init__(self, model: nn.Module, tokenizer: Tokenizer, vocab_size: int, batch_size: int, device: torch.device):
"""
:param model: is the model
:param tokenizer: is the [Huggingface Tokenizer](huggingface/tokenizers)
:param vocab_size: is the size of the vocabulary
(this differs from the tokenizer vocab size since NeoX adds a few extra tokens so that the embedding layer
can be split for model parallelism.)
:param batch_size: is the batch size
:param device: is the device of the model
"""
super().__init__(tokenizer, vocab_size, batch_size)
self.model = model
self._device = device
def _model_call(self, inps: torch.Tensor):
"""
Call the model
"""
return self.model(inps.to(self._device))
def run_eval_harness(model: nn.Module, name: str, eval_tasks: List[str], device: torch.device, batch_size: int = 8):
"""
## Run evaluation harness with a given model
"""
# Load the tokenizer
with monit.section('Load tokenizer'):
tokenizer = get_tokenizer()
# All tasks if nothing is specified
if not eval_tasks:
eval_tasks = [
"anli_r1",
"anli_r2",
"anli_r3",
"hellaswag",
"lambada",
"piqa",
"winogrande",
"wsc",
"mathqa",
]
# Create the adapter
adapter = NeoXEvalHarnessAdapter(model, tokenizer, 50_432, batch_size, device)
# Run
return adapter.run_eval(name, eval_tasks)

View File

@ -0,0 +1,19 @@
import torch
from torch import nn
from labml import monit
from labml_nn.neox.evaluation import run_eval_harness
from labml_nn.neox.model import LayerGenerator
if __name__ == '__main__':
device = torch.device('cuda:0')
layers = list(LayerGenerator(is_clone_layers=True,
filter_layers=None,
dtype=torch.float16,
device=device
).load())
with monit.section('Sequential'):
model = nn.Sequential(*layers)
print(run_eval_harness(model, 'half_precision', ['lambada'], device))

572
labml_nn/neox/model.py Normal file
View File

@ -0,0 +1,572 @@
"""
---
title: GPT-NeoX Model Definition
summary: >
This is the model definition of GPT-NeoX.
---
# GPT-NeoX Model
Here is the code for the layers of the GPT-NeoX model and the code to load
the 20B checkpoint.
The `load_state` method in each layer loads the checkpoint of that layer.
The checkpoint loading helpers are in [`checkpoint.py`](checkpoint.html)
"""
import copy
import math
from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple
import torch
from torch import nn
from torch.cuda.amp import autocast
from labml import monit
from labml_nn.neox import checkpoint
from labml_nn.neox.utils.cache import get_cache
class NeoXModule(nn.Module):
def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
pass
class Embedding(NeoXModule):
"""
## Embedding layer
This is a standard embeddings layer with code to load the checkpoint.
"""
def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
"""
:param n_vocab: is the size of the vocabulary
:param n_hidden: is the size of the embeddings
"""
super().__init__()
self.emb = nn.Embedding(n_vocab, n_hidden)
def forward(self, x: torch.Tensor):
"""
:param x: are the token ids of shape `[batch_size, seq_len]`
"""
return self.emb(x)
def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
"""
Code to load the checkpoint
"""
with monit.section('Load embedding layer'):
checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)
class RoPE(nn.Module):
"""
## Rotary Positional Embeddings
GPT-NeoX uses [rotary positional embeddings (RoPE)](https://papers.labml.ai/paper/2104.09864).
We have an annotated implementation of RoPE [here](https://nn.labml.ai/transformers/rope/index.html)
with more notes on the theory.
"""
def __init__(self, d_rope: int, base: float = 10_000.):
"""
:param d_rope: is the number of features for RoPE embeddings
:param base: is the base for $\theta_i = 10000^{\frac{2(i-1)}{d}}$, which defaults to $10000$
"""
super().__init__()
# To store $\theta_i$ for the features
self.theta = None
# Cache $\cos m\theta_i$ and $\sin m\theta_i$
self.cos_cached = None
self.sin_cached = None
# Base for $\theta_i = 10000^{\frac{2(i-1)}{d}}$
self.base = base
# Number of features for RoPE
self.d_rope = d_rope
@staticmethod
def rotate_half(x: torch.Tensor):
"""
### Rotate the features
$[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
"""
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def forward(self, x: torch.Tensor, offset: int = 0):
"""
:param x: has shape `[..., seq, n_heads, d_k]`
:param offset: is the starting position of `x`. This is $\gt 0$ when we have
cached the keys and queries of previous positions
"""
# Get the actual sequence length
seq_len = x.shape[-3] + offset
# Initialize $\theta$
if self.theta is None:
# $\theta_i = 10000^{\frac{2(i-1)}{d}}$
theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
self.theta = theta.to(x.device).to(x.dtype)
# Initialize $\cos m\theta_i$ and $\sin m\theta_i$ cache
if (
self.cos_cached is None or
seq_len > self.cos_cached.shape[1] or
self.cos_cached.device != x.device or
self.cos_cached.dtype != x.dtype
):
# Get position indexes $m$
seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
# $m \theta_i$
idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)
# Concatenate so that for row $m$ we have
#
# $$[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$$
idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)
# Calculate $\cos m\theta_i$ and $\sin m\theta_i$ in fp32
with autocast(enabled=False):
idx_theta2 = idx_theta2.float()
# Add head dimension
self.cos_cached = idx_theta2.cos()[:, None, :]
self.sin_cached = idx_theta2.sin()[:, None, :]
# Cache them
self.cos_cached = self.cos_cached.to(x.dtype)
self.sin_cached = self.sin_cached.to(x.dtype)
# Split the features. We apply RoPE to only `d_rope` features
x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]
# Get the sin and cos values from the cache
cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]
# RoPE embeddings
#
# \begin{align}
# \begin{pmatrix}
# x^{(i)}_m \cos m \theta_i - x^{(i + \frac{d}{2})}_m \sin m \theta_i \\
# x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m \theta_i \\
# \end{pmatrix} \\
# \end{align}
#
# for $i \in {1, 2, ..., \frac{d}{2}}$
x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)
# Concatenate with features that didn't get RoPE embeddings
return torch.cat((x_rope, x_pass), dim=-1)
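# A quick shape check of the RoPE module (a sketch with small, made-up sizes):
#
# rope = RoPE(d_rope=8)
# x = torch.randn(2, 10, 4, 16)  # `[batch_size, seq_len, n_heads, d_k]`
# y = rope(x)  # RoPE is applied to the first `8` features of each head
# assert y.shape == x.shape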
class AttentionLayer(nn.Module):
"""
## Attention layer
"""
def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
mask_fill: float = -10_000.0):
"""
:param n_hidden: the number of features in embeddings
:param n_heads: the number of attention heads
:param rope_percentage: percentage of features to add RoPE embeddings
:param mask_fill: masking fill value for attention matrix
"""
super().__init__()
self.n_heads = n_heads
self.mask_fill = mask_fill
# Linear layer for query, key and value
self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)
# Final linear layer
self.output = nn.Linear(n_hidden, n_hidden)
# Number of features per head
d_k = n_hidden // n_heads
# RoPE embedding module
self.rope = RoPE(int(d_k * rope_percentage))
# Attention scaling factor
self.scale = 1 / math.sqrt(d_k)
# To cache causal mask
self.causal_mask = None
# Attention softmax module
self.softmax = nn.Softmax(dim=-2)
def _get_mask(self, attn: torch.Tensor):
"""
#### Calculate the causal mask
* `attn` has shape [batch_size, query_seq_len, key_seq_len, n_heads]
"""
# Query and key lengths
nq, nk = attn.shape[1:3]
# Create mask
if (
self.causal_mask is None or
self.causal_mask.shape[0] != nq or
self.causal_mask.shape[1] != nk or
self.causal_mask.device != attn.device
):
self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)
# Return from cache
return self.causal_mask[None, :, :, None]
def forward(self, x: torch.Tensor):
"""
:param x: has shape `[batch_size, seq_len, n_hidden]`
"""
# Get query, key and value embeddings (all concatenated).
# The last dimension size will change from `n_hidden` to `3 * n_hidden`
qkv = self.qkv_lin(x)
# Split into heads by changing the shape to `[batch_size, seq_len, n_heads, 3 * d_k]`
qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)
# Split into query, key and value, each of shape `[batch_size, seq_len, n_heads, d_k]`
q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)
# If we are caching the states of previous tokens
if get_cache().get('use_cache', False):
# Get the state ids. We use these to retrieve previous states and store the next states
prev_state_id, next_state_id = get_cache().get('state_ids')
# If there's cache
if prev_state_id is not None:
# Get the past keys and values. These will have shape `[batch_size, prev_seq_len, n_heads, d_k]`
k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')
# Offset of the current embeddings
offset = k_past.shape[1]
# Add RoPE embeddings
q = self.rope(q, offset=offset)
k = self.rope(k, offset=offset)
# Concatenate the past
k = torch.cat([k_past, k], dim=1)
v = torch.cat([v_past, v], dim=1)
else:
# Add RoPE embeddings
q = self.rope(q)
k = self.rope(k)
# Save the current state
get_cache().push(f'attn_kv_{next_state_id}', (k, v))
else:
# No cache - simply add RoPE embeddings
q = self.rope(q)
k = self.rope(k)
# Disable auto-casting to fp16 for attention computation
with autocast(enabled=False):
if q.dtype == torch.float16:
# Convert to fp32 if the current dtype is fp16
attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
else:
# Do not cast for bfloat
attn = torch.einsum('bihk,bjhk->bijh', q, k)
# Scale attention
attn = attn * self.scale
# Get causal mask
mask = self._get_mask(attn)
# Apply mask
attn.masked_fill_(mask, self.mask_fill)
# Attention softmax
attn = self.softmax(attn)
# Get attention weighted values
output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)
# Reshape from `[batch_size, seq_len, n_heads, d_k]` to `[batch_size, seq_len, n_hidden]`
output = output.reshape(*x.shape)
# Final linear layer
return self.output(output)
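# A minimal shape check for the attention layer (a sketch with small, made-up
# sizes; the 20B model uses `n_hidden = 6_144` and `n_heads = 64`):
#
# attn = AttentionLayer(n_hidden=64, n_heads=4)
# x = torch.randn(2, 10, 64)  # `[batch_size, seq_len, n_hidden]`
# assert attn(x).shape == (2, 10, 64)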
class FFNLayer(nn.Module):
"""
## Feedforward Network
"""
def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
"""
:param n_hidden: is the embedding size
"""
super().__init__()
if not d_ff:
d_ff = n_hidden * 4
# Expansion linear layer
self.dense_h_h4 = nn.Linear(n_hidden, d_ff)
# GELU activation
self.activation = nn.GELU()
# Contraction linear layer
self.dense_h4_h = nn.Linear(d_ff, n_hidden)
def forward(self, x: torch.Tensor):
"""
:param x: has shape `[batch_size, seq_len, n_hidden]`
"""
x = self.dense_h_h4(x)
x = self.activation(x)
x = self.dense_h4_h(x)
return x
class TransformerLayer(NeoXModule):
"""
## Transformer Layer
"""
def __init__(self, n_hidden: int = 6_144, n_heads: int = 64):
"""
:param n_hidden: is the embedding size
:param n_heads: is the number of heads
*Our implementation doesn't include dropout*.
"""
super().__init__()
# Layer normalization before attention
self.pre_ln_attn = nn.LayerNorm(n_hidden)
# Layer normalization before FFN
self.pre_ln_ffn = nn.LayerNorm(n_hidden)
# Attention layer
self.attention = AttentionLayer(n_hidden, n_heads)
# FFN layer
self.ffn = FFNLayer(n_hidden)
def forward(self, x: torch.Tensor):
"""
:param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`
"""
# Residual connection
residual = x
# NeoX runs attention and feedforward network in parallel
attn = self.attention(self.pre_ln_attn(x))
ffn = self.ffn(self.pre_ln_ffn(x))
# Add them and the residual connection
return attn + ffn + residual
def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
"""
Code to load the checkpoint
"""
with monit.section('Load transformer layer'):
# Attention output transform
checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)
# Attention query, key and value transform
checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)
# Layer norm before attention
checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)
# FFN first (expansion) transform
checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)
# FFN second (contraction) transform
checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)
# Layer norm before FFN
checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)
class FinalNorm(NeoXModule):
"""
## Final normalization layer
"""
def __init__(self, n_hidden: int = 6_144):
"""
:param n_hidden: is the embedding size
"""
super().__init__()
self.ln = nn.LayerNorm(n_hidden)
def forward(self, x: torch.Tensor):
"""
:param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`
"""
return self.ln(x)
def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
"""
Code to load the checkpoint
"""
with monit.section('Load final normalization layer'):
checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)
class ReadoutLayer(NeoXModule):
"""
## Readout layer
"""
def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
"""
:param n_hidden: is the embedding size
:param n_vocab: is the size of the vocabulary
"""
super().__init__()
self.linear = nn.Linear(n_hidden, n_vocab, bias=False)
def forward(self, x: torch.Tensor):
"""
:param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`
"""
return self.linear(x)
def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
"""
Code to load the checkpoint
"""
with monit.section('Load final linear layer'):
checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
class LayerGenerator:
pre_created_layers: Dict[Any, Optional[NeoXModule]]
def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
n_layers: int = 44, n_heads: int = 64,
filter_layers: Optional[Set] = None,
is_clone_layers: bool = True,
dtype: torch.dtype = torch.float,
device: torch.device = torch.device('cpu')):
"""
### Generator to create layers
The layers are generated in the same order as checkpoints.
It gives `None` when a layer is not available; we use the same layer indices as NeoX, and there are two
transformation layers we don't need in our implementation.
:param n_vocab: is the number of tokens in the vocabulary
:param n_hidden: is the number of features in the embeddings
:param n_layers: is the number of transformer layers
:param n_heads: is the number of attention heads
:param filter_layers: are the set of layers to be used. All layers will be used if None.
This is used to test smaller versions of the model with fewer layers
:param is_clone_layers: specifies whether to clone the transformer layers (a bit faster)
:param dtype: is the data type of the model
:param device: is the device of the model
:return: the layers as a generator
"""
if filter_layers is None:
filter_layers = set(range(n_layers + 3))
self.n_vocab = n_vocab
self.n_hidden = n_hidden
self.n_layers = n_layers
self.n_heads = n_heads
self.filter_layers = filter_layers
self.is_clone_layers = is_clone_layers
self.dtype = dtype
self.device = device
self.pre_created_layers = dict(
transformer_layer=None,
)
def _prepare_layer(self, layer: NeoXModule):
layer = layer.to(self.device, self.dtype)
return layer
def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
if self.pre_created_layers[name] is None or not self.is_clone_layers:
layer = creator()
else:
layer = copy.deepcopy(self.pre_created_layers[name])
layer: NeoXModule = self._prepare_layer(layer)
if self.pre_created_layers[name] is None:
self.pre_created_layers[name] = layer
return layer
def _create_transformer_layer(self):
return self._create_and_cache_layer(
'transformer_layer',
lambda: TransformerLayer(self.n_hidden, self.n_heads)
)
def _create_embedding_layer(self):
return Embedding(self.n_vocab, self.n_hidden)
def _create_final_norm_layer(self):
return FinalNorm(self.n_hidden)
def _create_readout_layer(self):
return ReadoutLayer(self.n_hidden, self.n_vocab)
@torch.no_grad()
def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
# Embedding layer
if 0 in self.filter_layers:
with monit.section('Embedding layer'):
layer = self._prepare_layer(self._create_embedding_layer())
yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')
# Transformer layers
for i in range(self.n_layers):
# Transformer layer
if i + 1 in self.filter_layers:
with monit.section(f'Transformer Layer {i}'):
yield self._create_transformer_layer(), \
(f'layer_{i + 2 :02d}-model_00-model_states.pt',
f'layer_{i + 2 :02d}-model_01-model_states.pt')
# Final normalization layer
if self.n_layers + 1 in self.filter_layers:
with monit.section('Final norm layer'):
layer = self._prepare_layer(self._create_final_norm_layer())
yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')
# Readout layer
if self.n_layers + 2 in self.filter_layers:
with monit.section('Readout layer'):
layer = self._prepare_layer(self._create_readout_layer())
yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
@property
def total_layers(self):
return self.n_layers + 3
@torch.no_grad()
def load(self) -> Generator[NeoXModule, None, None]:
with torch.no_grad():
with monit.section("Layers"):
for i, (layer, files) in enumerate(self.get_layers()):
if files is not None:
layer.load_state(*checkpoint.load_checkpoint_files(files))
monit.progress(min(0.99, (i + 1) / self.total_layers))
yield layer

View File

@ -0,0 +1,12 @@
"""
---
title: Samples
summary: >
Samples for inference and fine-tuning
---
# Samples
* [Generating text](generate.html)
* [Fine tuning the biases with pipeline-parallel training](finetune.html)
"""

View File

@ -0,0 +1,125 @@
"""
---
title: Fine Tune GPT-NeoX
summary: >
Fine tune GPT-NeoX biases with Fairscale pipeline parallel module
---
# Fine Tune GPT-NeoX
This shows how to fine tune GPT-NeoX with pipeline parallelism.
"""
import fairscale
import torch
import torch.nn as nn
import torch.utils.data
import typing
from torch.utils.data import DataLoader, RandomSampler
from labml import experiment, monit, tracker, lab
from labml.configs import option
from labml.logger import inspect
from labml_nn.neox.utils.text_dataset import get_training_data
from labml_nn.neox.utils.finetune import FineTuneBiases
from labml_nn.neox.model import LayerGenerator, NeoXModule
from labml_nn.neox.utils import balance_layers_simple
from labml_nn.neox.utils.trainer import PipelineParallelTrainerConf
@option(PipelineParallelTrainerConf.layers, 'PipelineBiases')
def neox_layers(c: PipelineParallelTrainerConf):
"""
### Load GPT-NeoX layers
"""
return list(LayerGenerator(is_clone_layers=c.is_clone_layers,
filter_layers=c.filter_layers,
dtype=c.dtype,
).load())
@option(PipelineParallelTrainerConf.fine_tuner, 'PipelineBiases')
def fine_tune_biases(c: PipelineParallelTrainerConf):
"""
### Create fine tuner for biases
"""
fine_tuner = FineTuneBiases(typing.cast(typing.List[NeoXModule], c.layers))
# Mark biases as trainable
fine_tuner.set_trainable_params()
#
return fine_tuner
@option(PipelineParallelTrainerConf.model, 'PipelineBiases')
def pipe_model(c: PipelineParallelTrainerConf):
"""
### Create pipeline parallel model
"""
if c.is_checkpointing:
raise NotImplementedError()
else:
layers = c.layers
# Create the Pipe module
with monit.section('Pipe'):
# Get the layer distribution across GPUs
balance = balance_layers_simple(len(layers), c.n_gpus)
inspect(balance=balance)
# Devices for each GPU
devices = [torch.device(f'cuda:{i}') for i in range(c.n_gpus)]
# Create Fairscale Pipe module
pipe_model = fairscale.nn.Pipe(nn.Sequential(*layers),
balance=balance,
devices=devices,
chunks=c.chunks)
#
return pipe_model
@option(PipelineParallelTrainerConf.train_loader)
def tiny_shakespeare(c: PipelineParallelTrainerConf):
"""
#### Tiny Shakespeare dataset
"""
dataset = get_training_data(c.max_seq_len)
return DataLoader(dataset,
batch_size=c.batch_size,
sampler=RandomSampler(dataset, replacement=True))
def main():
# Create experiment
experiment.create(name='pipe_neox_biases',
writers={'screen', 'web_api'})
# Initialize configs
conf = PipelineParallelTrainerConf()
experiment.configs(conf, {
'learning_rate': 3e-4,
'is_checkpointing': False,
'max_seq_len': 128,
'batch_size': 64,
'chunks': 8,
})
# Start the experiment
with experiment.start():
# Initialize the model. Do this before the loop for cleaner logs.
_ = conf.model
# Train
for epoch in monit.loop(conf.epochs):
conf.train_epoch()
tracker.new_line()
torch.save(conf.fine_tuner.state_dict(), str(lab.get_data_path() / 'fine_tune.pt'))
#
if __name__ == '__main__':
main()

View File

@ -0,0 +1,102 @@
"""
---
title: Generate Text with GPT-NeoX
summary: >
Generate Text with GPT-NeoX
---
# Generate Text with GPT-NeoX
This shows how to generate text from GPT-NeoX with a single GPU.
This needs a GPU with more than 45GB of memory.
"""
# Imports
from typing import List
import torch
from torch import nn
from labml import monit
from labml_nn.neox.model import LayerGenerator
from labml_nn.neox.utils import get_tokens, print_tokens
from labml_nn.neox.utils.cache import get_cache
# List of layers to load. This is used for testing.
# You can assign a subset of layers like `{0, 1}` so that it only loads
# the first two layers.
LAYERS = None
# Prompt to complete
PROMPT = 'Einstein was born in the German Empire, but moved to Switzerland in 1895, forsaking his German'
def infer(model: nn.Module, ids: List[int], device: torch.device):
"""
### Predict the next token
:param model: is the model
:param ids: are the input token ids
:param device: is the device of the model
"""
with torch.no_grad():
# Get the tokens
x = torch.tensor(ids)[None, :].to(device)
# Eval model
x = model(x)
# Return predicted token
return x[0].max(dim=-1)[1].tolist()
def generate():
"""
## Generate text
"""
# Setup [cache](../utils/cache.html) to cache intermediate key/value pairs for faster generation
cache = get_cache()
cache.set('use_cache', True)
# Device
device = torch.device('cuda:0')
# Load layers
layers = list(LayerGenerator(is_clone_layers=True,
filter_layers=LAYERS,
dtype=torch.float16,
device=device,
).load())
model = nn.Sequential(*layers)
# Get token ids
ids = get_tokens(PROMPT)
# Run the model
cache.set('state_ids', (None, 1))
with monit.section('Infer'):
next_token = infer(model, ids, device)[-1]
# Append the predicted token
ids += [next_token]
# Predict 100 tokens
for i in range(1, 100):
# Set the state to use cached activations
cache.set('state_ids', (i, i + 1))
# Get next token. Note that we only feed the last token to the model because
# we cache the key/value pairs of previous tokens.
with monit.section('Infer'):
next_token = infer(model, [next_token], device)[-1]
# Append the predicted token
ids += [next_token]
# Print
print_tokens(ids, [ids])
#
if __name__ == '__main__':
generate()

View File

@ -0,0 +1,28 @@
"""
---
title: GPT-NeoX Tokenizer
summary: >
Loads the GPT-NeoX tokenizer
---
# GPT-NeoX Tokenizer
This initializes a Hugging Face tokenizer from the downloaded vocabulary.
"""
from tokenizers import Tokenizer
from labml import lab, monit
@monit.func('Load NeoX Tokenizer')
def get_tokenizer() -> Tokenizer:
"""
### Load NeoX Tokenizer
:return: the tokenizer
"""
vocab_file = lab.get_data_path() / 'neox' / 'slim_weights' / '20B_tokenizer.json'
tokenizer = Tokenizer.from_file(str(vocab_file))
return tokenizer
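A short usage sketch, assuming the vocabulary file has already been downloaded with the [checkpoint downloader](checkpoint.html):
from labml_nn.neox.tokenizer import get_tokenizer
tokenizer = get_tokenizer()
ids = tokenizer.encode('Hello world').ids
print(ids)
print(tokenizer.decode(ids))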

View File

@ -0,0 +1,134 @@
"""
---
title: Utilities and Helpers
summary: >
Utilities and helper functions
---
# Utilities and Helpers
* [Cache for intermediate activations (for faster inference)](cache.html)
* [Tools for finetuning](finetune.html)
* [Trainer](trainer.html)
* [Text dataset](text_dataset.html)
"""
import typing
from typing import List, Optional
import torch
from labml import logger
from labml.logger import Text
from labml_nn.neox.tokenizer import get_tokenizer
if typing.TYPE_CHECKING:
from tokenizers import Tokenizer
# Tokenizer singleton
_TOKENIZER: Optional['Tokenizer'] = None
def get_tokens(text: str) -> List[int]:
"""
### Get token ids
:param text: is the text to tokenize
:return: the token ids
"""
global _TOKENIZER
if _TOKENIZER is None:
_TOKENIZER = get_tokenizer()
return _TOKENIZER.encode_batch([text])[0].ids
def print_token_outputs(ids: List[int], *xs: torch.Tensor):
"""
### Print tokens from model outputs
Pretty prints target tokens alongside outputs from the model(s).
:param ids: are the target token ids
:param xs: are the model(s) outputs
"""
ids = ids + [-1]
xs = [[-1] + x[0].max(dim=-1)[1].tolist() for x in xs]
print_tokens(ids, xs)
def print_tokens(target: List[int], others: List[List[int]]):
"""
### Print tokens
Pretty prints tokens for comparison
:param target: are the target token ids
:param others: are the sampled outputs from the model(s)
"""
# Load tokenizer
global _TOKENIZER
if _TOKENIZER is None:
_TOKENIZER = get_tokenizer()
# Convert the tokens to list of strings
text = []
for i in range(len(target)):
tokens = [_TOKENIZER.decode([target[i]]) if target[i] != -1 else '---']
for j in range(len(others)):
tokens.append(_TOKENIZER.decode([others[j][i]]) if others[j][i] != -1 else '---')
text.append(tokens)
# Stats
correct = [0 for _ in others]
total = 0
# Iterate through tokens
for i in range(len(target)):
parts = [(f'{i}: ', Text.meta)]
parts += [('"', Text.subtle), (text[i][0], Text.subtle), ('"', Text.subtle), '\t']
# Empty target
if target[i] == -1:
for j in range(len(others)):
parts += [('"', Text.subtle), (text[i][j + 1], Text.subtle), ('"', Text.subtle), '\t']
logger.log(parts)
continue
# Number of tokens
total += 1
# Other outputs
for j in range(len(others)):
correct[j] += 1 if others[j][i] == target[i] else 0
parts += [('"', Text.subtle),
(text[i][j + 1], Text.success if others[j][i] == target[i] else Text.danger),
('"', Text.subtle), '\t']
logger.log(parts)
# Stats
parts = [(f'{total}', Text.highlight), '\t']
for j in range(len(others)):
parts += [(f'{correct[j]}', Text.value), '\t']
logger.log(parts)
def balance_layers_simple(n_layers: int, n_chunks: int):
"""
### Balance layers
Split the `n_layers` into `n_chunks`. This is used for pipeline parallel training.
:param n_layers: is the number of layers
:param n_chunks: is the number of chunks
:return: returns a list with the number of layers for each chunk
"""
balance = []
for i in range(n_chunks):
balance.append((n_layers - sum(balance)) // (n_chunks - i))
return list(reversed(balance))
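For example, splitting the 47 layers of the 20B model (embedding, 44 transformer layers, final norm and readout) across 4 GPUs gives the following (a quick check of the helper above):
assert balance_layers_simple(47, 4) == [12, 12, 12, 11]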

View File

@ -0,0 +1,118 @@
"""
---
title: Cache for Intermediate Activations
summary: >
Cache for intermediate activations for faster inference.
---
# Cache for Intermediate Activations
During inference the model outputs token by token.
We use this simple cache to store the keys and values of the attention layers,
so that we don't have to recompute them for previous tokens.
"""
from typing import Any
class Cache:
"""
## Cache
This maintains a key-value cache and queues, where values are pushed and popped in the same order.
The queues are useful since we have multiple attention layers.
"""
def __init__(self):
self._cache = {}
def clear_all(self):
"""
### Clear cache
"""
self._cache = {}
def push(self, name: str, value: Any):
"""
### Push a value to a queue
:param name: is the name of the queue
:param value: is the value to be pushed
"""
# Create an empty queue if it's not present
if name not in self._cache:
self._cache[name] = []
# Push to the queue
self._cache[name].append(value)
def q_size(self, name):
"""
### Return the size of the queue
:param name: is the name of the queue
:return: size of the queue if exists else None
"""
if name not in self._cache:
return None
if not isinstance(self._cache[name], list):
return None
return len(self._cache[name])
def pop(self, name: str):
"""
### Pop from a queue
:param name: is the name of the queue
:return: the value
"""
return self._cache[name].pop(0)
def set(self, key: str, value: Any):
"""
### Cache a value
:param key: is the name of the value to be cached
:param value: is the value
"""
self._cache[key] = value
def get(self, key: str, default: Any = None):
"""
### Retrieve a value from cache
:param key: is the name used when caching
:param default: is the default value if the cache is empty
:return: the cached value
"""
return self._cache.get(key, default)
def clear(self, key: str):
"""
### Clear a cache value
:param key: is the name used when caching
"""
del self._cache[key]
# Singleton for cache
_INSTANCE = None
def get_cache() -> Cache:
"""
### Get the cache instance
:return: the cache instance
"""
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = Cache()
return _INSTANCE
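A minimal usage sketch of the cache; the queue name `attn_kv_1` mirrors how the attention layer in `model.py` names its key/value queues:
from labml_nn.neox.utils.cache import get_cache
cache = get_cache()
cache.set('use_cache', True)
cache.push('attn_kv_1', ('k0', 'v0'))
cache.push('attn_kv_1', ('k1', 'v1'))
# Values come back in the order they were pushed
assert cache.pop('attn_kv_1') == ('k0', 'v0')
assert cache.get('use_cache', False) is True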

View File

@ -0,0 +1,55 @@
from typing import List, Dict
import torch
from torch import nn
from labml_nn.neox.model import TransformerLayer, NeoXModule
class FineTuner:
def __init__(self, layers: List[NeoXModule]):
self.layers = layers
def get_trainable_params(self) -> Dict[str, nn.Parameter]:
params = {}
for i, layer in enumerate(self.layers):
params.update(self.get_layer_trainable_params(layer, prefix=f'layer_{i :02d}'))
return params
def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
raise NotImplementedError
def set_trainable_params(self):
for layer in self.layers:
# Set `requires_grad` to `False` for the entire layer.
layer.requires_grad_(False)
#
for p in self.get_trainable_params().values():
p.requires_grad_(True)
def state_dict(self):
return {n: p.data.cpu() for n, p in self.get_trainable_params().items()}
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
params = self.get_trainable_params()
for n, p in params.items():
p.data[:] = state_dict[n].to(p.data.device)
for n in state_dict.keys():
assert n in params, n
class FineTuneBiases(FineTuner):
def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
params = {}
if isinstance(layer, TransformerLayer):
# No need to train the mlp bias because we are adding it with attention output
params[f'{prefix}.attention.output.bias'] = layer.attention.output.bias
params[f'{prefix}.attention.qkv_lin.bias'] = layer.attention.qkv_lin.bias
params[f'{prefix}.ffn.dense_h_h4.bias'] = layer.ffn.dense_h_h4.bias
else:
pass
return params
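A sketch of how the fine-tuner is typically used, mirroring `labml_nn/neox/samples/finetune.py`; this assumes the checkpoints have been downloaded and there is enough memory to load the layers:
import torch
from labml_nn.neox.model import LayerGenerator
from labml_nn.neox.utils.finetune import FineTuneBiases
layers = list(LayerGenerator(is_clone_layers=True, dtype=torch.float16).load())
fine_tuner = FineTuneBiases(layers)
# Freeze all parameters, then mark only the selected biases as trainable
fine_tuner.set_trainable_params()
# ... training loop updates only those biases ...
torch.save(fine_tuner.state_dict(), 'fine_tune.pt')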

View File

@ -0,0 +1,132 @@
"""
---
title: Text Dataset for GPT-NeoX
summary: >
Loads text datasets to fine-tune GPT-NeoX
---
# Text Dataset for GPT-NeoX
"""
from pathlib import PurePath, Path
from typing import Optional, List
import torch
import torch.utils.data
from labml import lab
from labml import monit
from labml.logger import inspect
from labml.utils.download import download_file
from labml_nn.neox.tokenizer import get_tokenizer
def load_text(path: PurePath, url: Optional[str] = None, *, filter_subset: Optional[int] = None):
"""
### Load text file
:param path: is the location of the text file
:param url: is the URL to download the file from
:param filter_subset: is the number of characters to filter.
Use this during testing when trying large datasets
:return: the text content
"""
path = Path(path)
# Download if it doesn't exist
if not path.exists():
if not url:
raise FileNotFoundError(str(path))
else:
download_file(url, path)
with monit.section("Load data"):
# Load data
with open(str(path), 'r') as f:
text = f.read()
# Filter
if filter_subset:
text = text[:filter_subset]
#
return text
class NeoXDataset(torch.utils.data.Dataset):
"""
## Dataset for fine-tuning GPT-NeoX
This is not optimized for very large datasets.
"""
def __init__(self, tokens: List[int], seq_len: int):
"""
:param tokens: is the list of token ids
:param seq_len: is the sequence length of a single training sample
"""
self.seq_len = seq_len
# Number of samples
n_samples = len(tokens) // seq_len
self.n_samples = n_samples
# Truncate
tokens = tokens[:n_samples * seq_len + 1]
# Create a PyTorch tensor
self.tokens = torch.tensor(tokens)
def __len__(self):
return self.n_samples
def __getitem__(self, idx: int):
"""
### Get a sample
:param idx: is the index of the sample
:return: the input and the target
"""
offset = idx * self.seq_len
return self.tokens[offset:offset + self.seq_len], self.tokens[offset + 1:offset + 1 + self.seq_len]
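# A small worked example of the input/target alignment with a toy token list:
#
# ds = NeoXDataset(list(range(10)), seq_len=4)
# assert len(ds) == 2
# x, y = ds[0]
# assert x.tolist() == [0, 1, 2, 3] and y.tolist() == [1, 2, 3, 4]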
DATASETS = {
'tiny_shakespeare': {
'file': 'tiny_shakespeare.txt',
'url': 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
}
}
def get_training_data(seq_len: int = 32, dataset_name: str = 'tiny_shakespeare', truncate: int = -1):
"""
### Load Dataset
:param seq_len: is the sequence length of a single training sample
:param dataset_name: is the name of the dataset
:param truncate: is the number of samples to truncate the dataset to; no truncation if negative
:return: the dataset
"""
ds = DATASETS[dataset_name]
# Load the content
text = load_text(lab.get_data_path() / ds['file'], ds['url'])
# Tokenize
tokenizer = get_tokenizer()
tokens = tokenizer.encode_batch([text])[0]
if truncate > 0:
token_ids = tokens.ids[:truncate * seq_len]
else:
token_ids = tokens.ids
#
return NeoXDataset(token_ids, seq_len)
def _test():
dataset = get_training_data()
inspect(tokens=len(dataset.tokens))
#
if __name__ == '__main__':
_test()

View File

@ -0,0 +1,182 @@
from typing import Optional, Set, List
import torch.nn as nn
import torch.optim
import torch.utils.data
from torch.cuda import amp
from torch.cuda.amp import GradScaler
from labml import monit, tracker
from labml.configs import BaseConfigs, option
from labml_nn.neox.utils.finetune import FineTuner
def get_trainable_params(model: nn.Module):
"""
### Get trainable parameters
:param model: is the model to train
:return: a list of parameters for training
"""
# Get all parameters
params = list(model.parameters())
# Filter parameters that require gradients
trainable_params = [p for p in params if p.requires_grad]
#
return trainable_params
class TrainerConf(BaseConfigs):
model: nn.Module
layers: List[nn.Module]
optimizer: torch.optim.Optimizer = 'Adam'
train_loader: torch.utils.data.DataLoader
valid_loader: Optional[torch.utils.data.DataLoader] = None
device: torch.device = torch.device('cuda:0')
scaler: Optional[GradScaler] = 'Default'
is_amp: bool = True
dtype: torch.dtype = torch.float16
is_clone_layers: bool = True
loss_func: nn.Module = nn.CrossEntropyLoss()
checkpoints_per_epoch: int = 0
samples_per_epoch: int = 0
grad_norm: Optional[float] = 1.0
learning_rate: float = 3e-4
max_seq_len: int = 1024
batch_size: int = 64
epochs: int = 16
n_gpus: int = torch.cuda.device_count()
filter_layers: Optional[Set] = None
def get_loss(self, sample, dataset_split: str):
"""
:param dataset_split: train/valid
:param sample: is the sample
:return: the loss, output and the target
"""
data, target = sample
# Forward pass
with monit.section('Forward pass'):
output = self.model(data.to(self.device))
# Move targets to the same device as output
target = target.to(output.device)
# Calculate loss
loss = self.loss_func(output.view(target.numel(), -1), target.view(-1))
return loss, output, target
def train(self):
for epoch in monit.loop(self.epochs):
self.train_epoch()
tracker.new_line()
def sample(self, idx):
pass
def save_checkpoint(self, idx):
pass
def get_iterators(self):
# Iterate through the batches
iterators = [('train', self.train_loader)]
if self.valid_loader is not None:
iterators.append(('valid', self.valid_loader))
if self.samples_per_epoch > 0:
iterators.append((self.sample, [i for i in range(self.samples_per_epoch)]))
if self.checkpoints_per_epoch > 0:
iterators.append((self.save_checkpoint, [i for i in range(self.checkpoints_per_epoch)]))
return iterators
def train_epoch(self):
# Set model for train
self.model.train()
iterators = self.get_iterators()
for split_name, sample in monit.mix(1024, *iterators):
if split_name == 'train':
# Set gradients to zero
self.optimizer.zero_grad()
tracker.add_global_step()
with torch.set_grad_enabled(split_name == 'train'):
if self.is_amp:
# Forward pass
with amp.autocast():
loss, output, target = self.get_loss(sample, split_name)
else:
loss, output, target = self.get_loss(sample, split_name)
# Get predictions
pred = output.argmax(dim=-1)
# Calculate accuracy
accuracy = pred.eq(target).sum().item() / (target != -100).sum()
tracker.add({f'loss.{split_name}': loss, f'acc.{split_name}': accuracy * 100})
if split_name == 'train':
if self.scaler is not None:
# Backward pass
loss = self.scaler.scale(loss)
# tracker.add({'loss.scaled': loss})
with monit.section('Backward pass'):
loss.backward()
# Optimize
with monit.section('Optimize'):
if self.scaler is None:
self.optimizer.step()
else:
self.scaler.unscale_(self.optimizer)
if self.grad_norm is not None:
torch.nn.utils.clip_grad_norm_(get_trainable_params(self.model), self.grad_norm)
self.scaler.step(self.optimizer)
self.scaler.update()
tracker.save()
@option(TrainerConf.optimizer, 'Adam')
def adam_optimizer(c: TrainerConf):
if c.dtype == torch.float32:
return torch.optim.Adam(get_trainable_params(c.model), lr=c.learning_rate)
elif c.dtype == torch.float16:
from labml_nn.optimizers.adam_fp16 import AdamFP16
return AdamFP16(get_trainable_params(c.model), lr=c.learning_rate)
else:
raise NotImplementedError()
@option(TrainerConf.optimizer, 'SGD')
def sgd_optimizer(c: TrainerConf):
return torch.optim.SGD(get_trainable_params(c.model), lr=c.learning_rate)
@option(TrainerConf.scaler, 'Default')
def grad_scaler(c: TrainerConf):
if not c.is_amp:
return None
if c.dtype == torch.float16:
from labml_nn.optimizers.adam_fp16 import GradScalerFP16
return GradScalerFP16()
else:
return GradScaler()
class PipelineParallelTrainerConf(TrainerConf):
is_checkpointing: bool = False
chunks: int
fine_tuner: FineTuner

View File

@ -0,0 +1,136 @@
"""
---
title: Adam Optimizer for Half Precision Training
summary: A simple PyTorch implementation/tutorial of Adam optimizer
---
# Adam Optimizer for Half Precision Training
"""
from typing import Dict, Tuple, Optional, Any
import torch
from torch import nn
from torch.optim import Optimizer
from torch.cuda.amp import grad_scaler
from collections import defaultdict, abc
from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.adam import Adam
class AdamFP16(Adam):
"""
## Adam Optimizer for Half Precision Training
We extend [Adam Optimizer](adam.html) but use FP32 to store gradients and moments.
"""
def __init__(self, params, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
weight_decay: WeightDecay = WeightDecay(), optimized_update: bool = True,
defaults: Optional[Dict[str, Any]] = None):
# Dictionary to store 32-bit gradients. This gets populated by the `GradScaler` defined below.
self.grad_fp32 = {}
# Call the [Adam Optimizer](adam.html) initializer
super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)
def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
"""
### Initialize a parameter state
* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `param` is the parameter tensor $\theta_{t-1}$
All the state tensors use FP32.
"""
# This is the number of optimizer steps taken on the parameter, $t$
state['step'] = 0
# Exponential moving average of gradients, $m_t$
state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)
# Exponential moving average of squared gradient values, $v_t$
state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)
# Maintain a FP32 copy of the parameters
state['fp32_copy'] = param.to(torch.float)
def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
"""
### Take an update step for a given parameter tensor
* `state` is the optimizer state of the parameter (tensor)
* `group` stores optimizer attributes of the parameter group
* `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
* `param` is the parameter tensor $\theta_{t-1}$
"""
# Get the FP32 parameters
param_fp32 = state['fp32_copy']
# Get the FP32 gradients if available
grad_fp32 = self.grad_fp32.get(param, None)
if grad_fp32 is not None:
del self.grad_fp32[param]
grad = grad_fp32
else:
# Otherwise, convert the gradients to FP32
grad = grad.to(torch.float)
# Calculate weight decay
grad = self.weight_decay(param_fp32, grad, group)
# Get $m_t$ and $v_t$
m, v = self.get_mv(state, group, grad)
# Increment $t$ the number of optimizer steps
state['step'] += 1
# Perform *Adam* update
self.adam_update(state, group, param_fp32, m, v)
# Set the parameters
param.data = param_fp32.to(param.dtype)
class GradScalerFP16(grad_scaler.GradScaler):
"""
## Gradient Scaler with half precision gradients
We extend PyTorch gradient scaler to use FP32 gradients.
"""
def _unscale_grads_(self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor,
allow_fp16: bool) -> Dict[torch.device, torch.Tensor]:
per_device_inv_scale = grad_scaler._MultiDeviceReplicator(inv_scale)
per_device_found_inf = grad_scaler._MultiDeviceReplicator(found_inf)
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
with torch.no_grad():
# Loop through parameters
for group in optimizer.param_groups:
for param in group["params"]:
# Skip non-trainable parameters
if param.grad is None:
continue
# Not implemented for sparse tensors
if param.grad.is_sparse:
raise NotImplementedError
# If we are using the `AdamFP16` optimizer set `optimizer.grad_fp32[param]` to the FP32 gradients
if isinstance(optimizer, AdamFP16):
grad = param.grad.to(torch.float)
optimizer.grad_fp32[param] = grad
# Otherwise, do not convert the gradients to FP32
else:
grad = param.grad
per_device_and_dtype_grads[grad.device][grad.dtype].append(grad)
# Unscale all the gradients
for device, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values():
torch._amp_foreach_non_finite_check_and_unscale_(grads,
per_device_found_inf.get(device),
per_device_inv_scale.get(device))
#
return per_device_found_inf._per_device_tensors
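A sketch of a half precision training step that uses these two classes together, with a toy model standing in for the real layers (this assumes a CUDA device, since fp16 training targets GPUs):
import torch
from torch import nn
from labml_nn.optimizers.adam_fp16 import AdamFP16, GradScalerFP16
device = torch.device('cuda:0')
model = nn.Linear(8, 8).to(device, torch.float16)
optimizer = AdamFP16(model.parameters(), lr=3e-4)
scaler = GradScalerFP16()
x = torch.randn(4, 8, device=device, dtype=torch.float16)
loss = model(x).square().mean()
# Scale the loss and compute fp16 gradients
scaler.scale(loss).backward()
# Unscaling stores fp32 gradients in `optimizer.grad_fp32`
scaler.unscale_(optimizer)
# The optimizer updates its fp32 copy of the parameters and writes them back as fp16
scaler.step(optimizer)
scaler.update()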

View File

@ -0,0 +1,11 @@
"""
---
title: Large scale model training
summary: >
Large scale model training/inference implementations.
---
# Large scale model training
* [Zero-DP optimizer](zero3/index.html)
"""

View File

@ -0,0 +1,495 @@
"""
---
title: Zero-DP Memory Optimization
summary: >
This is an implementation of Zero-DP Memory Optimization written in PyTorch.
---
# Zero-DP Memory Optimization
This is an implementation of Zero-DP introduced in the paper
[ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://papers.labml.ai/paper/1910.02054).
It keeps shards of the optimizer state, gradients and parameters across multiple devices/nodes.
It reduces the per-device memory consumption to $\frac{(2 + 2 + K)\Psi}{N_d}$,
where $\Psi$ is the number of parameters, $N_d$ is the number of shards,
and $K$ is the number of optimizer bytes per parameter.
$2 + 2$ are the parameter and gradient memory assuming 16-bit precision; i.e. 2 bytes per parameter and gradient.
$K = 12$ for the Adam optimizer because it maintains an fp32 copy of the parameters and two moments per parameter.
The communication volume of Zero-DP is $\mathcal{O}(3\Psi)$. For comparison data-parallel training
has a communication volume of $\mathcal{O}(2\Psi)$.
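For example, for the 20B parameter model trained with Adam ($K = 12$) in 16-bit precision and sharded across
$N_d = 8$ devices, the memory for parameters, gradients and optimizer states drops from roughly
$16 \Psi \approx 320$ GB per device to about $40$ GB per device.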
Although this is named `Zero3`, we have only implemented the Zero-DP part of it and not the
Zero-R memory optimizations which target residual memory consumption.
Our implementation supports training only a subset of the parameters.
This implementation is inspired by [Fairscale FSDP](https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html).
[Here's a script to fine-tune](finetune_neox.html) GPT NeoX using Zero-DP memory optimization.
"""
import functools
from typing import List, Optional, Tuple
import torch
import torch.distributed as dist
from torch import nn
class Zero3Layer(nn.Module):
"""
## Zero3 Layer
Each layer of the model (or a combination of a few consecutive layers) should be wrapped in
this module.
"""
# Each shard keeps parameters in `chunk` list.
# The `chunk[0]` is for trainable parameters and `chunk[1]` is for fixed parameters.
chunk: List[nn.Parameter]
# These are the sizes of the chunks in the `chunk` list.
chunk_size: List[int]
# The first chunk is for trainable parameters.
TRAINING_PARAMS_IDX = 0
# The parameters of the layer, split into lists of trainable and fixed parameters.
param_refs: List[List[nn.Parameter]]
# CUDA stream to fetch parameters
fetch_stream: Optional[torch.cuda.Stream]
# CUDA stream to backup/accumulate gradients
backup_stream: Optional[torch.cuda.Stream]
# List of layers right before this layer
prev_layer: List['Zero3Layer']
# List of layers right after this layer
next_layer: List['Zero3Layer']
# The position of the current layer; used for debugging logs
layer_idx: int
# Whether parameters have been fetched
is_fetched: bool
# Device of the layer
device: torch.device
# Data type of the layer
dtype: torch.dtype
# The module to be wrapped
module: nn.Module
# Number of nodes/devices the data is sharded across
world_size: int
def __init__(self, module: nn.Module, rank: int, world_size: int, device: torch.device, dtype: torch.dtype):
"""
:param module: The module to be wrapped.
:param rank: The rank of the current node.
:param world_size: The number of nodes/devices the data is sharded across.
:param device: The device of the layer.
:param dtype: The data type of the layer.
"""
super().__init__()
# Initialize the properties
self.device = device
self.dtype = dtype
self.module = module
self.prev_layer = []
self.next_layer = []
self.is_fetched = False
self.world_size = world_size
self.layer_idx = -1
self.fetch_stream = None
self.backup_stream = None
with torch.no_grad():
# Collect all the parameters of the layer
all_param_refs = [p for p in self.parameters()]
# Store the shape of the parameters because we need it later to reconstruct them
for p in all_param_refs:
p._orig_shape = p.shape
# All parameters should have the same type
for p in all_param_refs:
assert p.dtype == dtype, "All parameters should have same dtype"
# Separate parameters as trainable and fixed
self.param_refs = [[p for p in all_param_refs if p.requires_grad],
[p for p in all_param_refs if not p.requires_grad]]
del all_param_refs
# The `rank = 0` node will calculate the size each device/node should store, and
# distribute the parameters accordingly.
if rank == 0:
# Merge and pad trainable (`merged_params[0]`) and fixed (`merged_params[1]`) parameters
merged_params = [self._merge_and_pad_params(ps) for ps in self.param_refs]
# Calculate the chunk sizes of trainable and fixed params
self.chunk_size = [(len(p) // world_size if p is not None else 0) for p in merged_params]
# Broadcast the sizes
dist.broadcast(torch.tensor(self.chunk_size, device=device), src=0)
else:
# Create an empty tensor to receive the sizes
chunk_size = torch.tensor([0, 0], device=device)
# Receive the sizes
dist.broadcast(chunk_size, src=0)
self.chunk_size = chunk_size.tolist()
# Create parameters for trainable (`self.chunk[0]`) and fixed (`self.chunk[1]`)
# parameters to be stored in current device/node
self.chunk = [nn.Parameter(self._empty((s,)), requires_grad=i == self.TRAINING_PARAMS_IDX)
for i, s in enumerate(self.chunk_size)]
# An empty tensor to receive the trainable and fixed parameters combined
chunk = self._empty((sum(self.chunk_size),))
if rank == 0:
# Concatenate both trainable and fixed params
all_params = torch.cat([p.view(world_size, -1) for p in merged_params], dim=-1).view(-1)
del merged_params
# Scatter them to all the nodes/devices
dist.scatter(chunk, list(all_params.split(sum(self.chunk_size))))
del all_params
else:
# Receive the parameters
dist.scatter(chunk)
# Collect the chunk data
chunk = chunk.split(self.chunk_size)
for i, c in enumerate(chunk):
self.chunk[i].data[:] = c
del chunk
# Cleanup the normal parameters
self._cleanup_params()
# Add a backward hook. This gets called when the gradients relative to the module are computed.
self._backward_hook_ref = self.register_full_backward_hook(self._backward_hook) # type: ignore
def _merge_and_pad_params(self, params: List[nn.Parameter]) -> torch.Tensor:
"""
#### Merge all the parameters and pad them so that the total size is divisible by `world_size`.
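For example, with `world_size = 4` and two parameters of $10$ and $7$ elements ($17$ in total),
$3$ padding elements are appended so that the merged tensor of $20$ elements splits evenly into $4$ chunks of $5$.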
"""
# Total number of parameters
size = sum(p.shape.numel() for p in params)
# If it is not divisible by `world_size`, pad it
if size % self.world_size != 0:
padding_fixed = self.world_size - (size % self.world_size)
# Otherwise, no need to pad
else:
padding_fixed = 0
# Create an empty padding tensor
padding = self._empty((padding_fixed,))
# Concatenate all the parameters and pad it
return torch.cat([p.view(-1) for p in params] + [padding], dim=0)
def get_trainable_chunk(self) -> List[nn.Parameter]:
"""
### Get trainable chunk/shard of the parameters.
This is what we pass on to the optimizer on the current node.
"""
# Return an empty list if there are no trainable parameters
if len(self.chunk[self.TRAINING_PARAMS_IDX]) == 0:
return []
# Return the trainable chunk as a list
return [self.chunk[self.TRAINING_PARAMS_IDX]]
def _empty(self, shape: Tuple[int, ...]) -> torch.Tensor:
"""
#### Create an empty tensor of the given shape.
"""
return torch.empty(shape, device=self.device, dtype=self.dtype)
@torch.no_grad()
def _cleanup_params(self):
"""
#### Cleanup the parameter data
This will release all the memory used by the layer parameters.
"""
# Set the flag to indicate that the parameters are not fetched
self.is_fetched = False
# Iterate through all parameters
for ps in self.param_refs:
for p in ps:
# Record that the parameter was used on the current stream, so its memory is not reused before pending operations complete
p.data.record_stream(torch.cuda.current_stream())
# Check to make sure the parameter is not sharing storage with anything else
assert p.data.storage_offset() == 0, "The tensor is not the sole occupant of the storage."
# Resize the storage to $0$. This will release the memory used by the parameter.
#
# **Setting `p.data` will not release the memory, since the autograd graph keeps a reference to it.**
p.data.storage().resize_(0) # This is what actually clears the memory
# Make sure the parameter has no gradient data
assert p.grad is None, 'Gradients should be None'
@torch.no_grad()
def fetch_params(self):
"""
### Fetch the parameters from all shards
This will fetch all the parameter data from all the nodes and rebuild the parameters on each node.
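For example, with `world_size = 2` each node holds one half of the merged parameters;
`all_gather` fills a buffer with both halves, and the `view`/`split`/`reshape` steps below reassemble
the full trainable and fixed tensors in their original order before copying them back into the parameter tensors.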
"""
# Skip if already fetched
if self.is_fetched:
return
# Set the flag
self.is_fetched = True
# Skip if there's nothing to fetch or share.
if sum(self.chunk_size) == 0:
return
# Use `fetch_stream` to fetch the parameters from all the shards
with torch.cuda.stream(self.fetch_stream):
# Create an empty tensor to receive the parameters
buffer = self._empty((self.world_size * sum(self.chunk_size),))
# Split the buffer into one view per node/device. These splits are views of `buffer`.
buffers = list(buffer.split(sum(self.chunk_size)))
# Concatenate both trainable and fixed chunks
chunk = torch.cat(self.chunk, dim=0)
# Gather the parameters from all the nodes/devices
dist.all_gather(buffers, chunk)
# Split the gathered parameters into the trainable and fixed chunks
params = buffer.view(-1, sum(self.chunk_size)).split(self.chunk_size, dim=1)
# Record the use of these tensors on `fetch_stream`, so their memory is not reused before the gather completes, and then clear the references to the buffers
buffer.record_stream(self.fetch_stream)
for b in buffers:
b.record_stream(self.fetch_stream)
del buffer
del buffers
# Reshape the trainable and fixed parameters to continuous tensors
params = [p.reshape(-1) for p in params]
# Collect the individual parameter tensors
for cont, ps in zip(params, self.param_refs):
# If there are no parameters, skip
if not ps:
continue
# Offset of the continuous tensor
offset = 0
# Iterate through model parameters and assign the values from the continuous tensor
for p in ps:
# Original parameter shape
shape = p._orig_shape # type: ignore[attr-defined]
# Change the storage size of the parameter. This was set to $0$ when we cleaned up the parameters.
p.data.storage().resize_(shape.numel())
# Assign the values from the continuous tensor
p.data[:] = cont[offset: offset + shape.numel()].reshape(shape)
# Record the use on `fetch_stream`, so this memory is not reused prematurely
p.data.record_stream(self.fetch_stream)
# Update the offset
offset += shape.numel()
# Record the use of the merged tensor on `fetch_stream` as well
cont.record_stream(self.fetch_stream)
#
del params
def forward(self, *args, **kwargs):
"""
### Forward pass
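While the wrapped module runs, the parameters of the following layers are prefetched on `fetch_stream`,
so communication overlaps with computation.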
"""
# Fetch all the parameters of the current layer.
# The previous layer will usually have triggered this already, so this call just makes sure the parameters are fetched.
self.fetch_params()
# Wait for parameter fetching to complete.
torch.cuda.current_stream().wait_stream(self.fetch_stream)
# Start fetching the parameters of the following layers, so that they are fetched
# while the current layer does its computations.
for layer in self.next_layer:
layer.fetch_params()
# Add backward hooks to the parameters of the current layer if autograd is enabled.
if torch.is_grad_enabled():
self._add_backward_hooks()
# Compute the outputs of the current layer
res = self.module(*args, **kwargs)
# Cleanup the parameters of the layer.
#
# *Skip cleaning up if autograd is enabled and this is the last layer in the network,
# because we will need to fetch the parameters again for the backward pass.*
if not torch.is_grad_enabled() or self.next_layer:
self._cleanup_params()
return res
def _add_backward_hooks(self):
"""
#### Add backward hooks to the parameters of the current layer.
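A minimal, self-contained sketch of the `expand_as` trick used below (illustrative only, assuming a
recent PyTorch; it is not part of this module):

```python
import torch

p = torch.nn.Parameter(torch.ones(3))
# `expand_as` adds an autograd node whose `next_functions` contains the
# gradient accumulation node for `p`
grad_acc = p.expand_as(p).grad_fn.next_functions[0][0]
# The hook fires when autograd accumulates the gradient for `p`
handle = grad_acc.register_hook(lambda *args: print('backward reached p'))
(p * 2.0).sum().backward()
handle.remove()
```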
"""
# Number of backward hooks added
self._backward_hook_handles = 0
# Loop through trainable parameters of the current layer
for p in self.param_refs[self.TRAINING_PARAMS_IDX]:
# Make sure a hook hasn't already been added
assert not hasattr(p, "_hook_handle"), 'Parameter has already been hooked'
# Use `expand_as` to create an autograd step which we can intercept
p_tmp = p.expand_as(p)
# Get a handle to add the backward hook.
# [This blog post discusses `grad_acc`](https://amsword.medium.com/understanding-pytorchs-autograd-with-grad-fn-and-next-functions-b2c4836daa00).
grad_acc = p_tmp.grad_fn.next_functions[0][0]
# Add the backward hook
handle = grad_acc.register_hook(
functools.partial(self._post_backward_hook, p))
# Keep a reference to the handle
p._hook_handle = handle
# Increment the number of hooks added
self._backward_hook_handles += 1
def _backward_event(self):
"""
#### Handle a backward event
This gets called by parameter backward hooks and the module backward hook.
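For example, if the layer has three trainable parameters, `_backward_hook_handles` starts at $3$;
the three parameter hooks bring it down to $0$ and the module hook brings it to $-1$,
which triggers backing up the gradients and cleaning up the parameters.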
"""
# Decrement the hooks counter
self._backward_hook_handles -= 1
# If all the hooks (including the module hook) have been called,
# then we can back up gradients and clean up the parameters.
if self._backward_hook_handles == -1:
self._backup_grads()
self._cleanup_params()
# Start fetching the parameters of the previous layer, because autograd will process its gradients next.
for layer in self.prev_layer:
layer.fetch_params()
def _post_backward_hook(self, p: nn.Parameter, *args):
"""
#### Parameter backward hook
"""
# Remove the handle from the parameter
p._hook_handle.remove() # type: ignore[attr-defined]
delattr(p, "_hook_handle")
# Handle a backward event
self._backward_event()
def _backward_hook(self, *args, **kwargs):
"""
#### Module backward hook
"""
# Handle a backward event
self._backward_event()
# The previous layer will start computing gradients. We need to make sure it has finished fetching params.
torch.cuda.current_stream().wait_stream(self.fetch_stream)
#
return None
@torch.no_grad()
def _backup_grads(self):
"""
### Backup the gradients of the current layer
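For example, with `world_size = 2` each node builds a full-size gradient buffer from its local backward pass;
`reduce_scatter` sums the corresponding chunks from both nodes and leaves each node with only its own shard
of the summed gradients, which becomes the `.grad` of the trainable chunk that the optimizer sees.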
"""
# Skip if there are no trainable parameters
if self.chunk_size[self.TRAINING_PARAMS_IDX] == 0:
return
# Use the backup stream to backup the gradients
with torch.cuda.stream(self.backup_stream):
# Buffer to store the gradients
buffer = self._empty((self.world_size * self.chunk_size[self.TRAINING_PARAMS_IDX],))
# Split the buffer into one view per node/device. These splits are views of `buffer`.
buffers = list(buffer.split(self.chunk_size[self.TRAINING_PARAMS_IDX]))
# Offset of the continuous buffer
offset = 0
# Iterate through trainable parameters
for p in self.param_refs[self.TRAINING_PARAMS_IDX]:
# Collect gradients
shape = p._orig_shape # type: ignore[attr-defined]
buffer[offset: offset + shape.numel()] = p.grad.view(-1)
# Update the offset
offset += shape.numel()
# Clean the gradients
p.grad = None
# Empty tensor to accumulate the gradients of the current shard
grad = self._empty((self.chunk_size[self.TRAINING_PARAMS_IDX],))
# Accumulate the gradients across the nodes: `reduce_scatter` sums the corresponding chunks
# from all the nodes and leaves each node with only its own shard of the summed gradients.
dist.reduce_scatter(grad, buffers)
# Record the use of these tensors on the stream, so their memory is not reused before the reduce-scatter completes, and then clear the references to the buffers
for b in buffers:
b.record_stream(self.fetch_stream)
buffer.record_stream(self.fetch_stream)
del buffer
del buffers
# Set the chunk gradients. This is what the optimizer sees.
self.chunk[self.TRAINING_PARAMS_IDX].grad = grad
del grad
class Zero3Sequential(nn.Module):
"""
## Sequential module for `Zero3Layer` layers
"""
def __init__(self, modules: List[Zero3Layer]):
"""
:param modules: List of `Zero3Layer` layers
"""
super().__init__()
# CUDA stream to fetch parameters
self.fetch_stream = torch.cuda.Stream()
# CUDA stream to back up (accumulate) gradients
self.backup_stream = torch.cuda.Stream()
# Set the streams and the preceding/following layers for each `Zero3Layer` layer
for i in range(len(modules)):
# Set layer index
modules[i].layer_idx = i
# Set streams
modules[i].fetch_stream = self.fetch_stream
modules[i].backup_stream = self.backup_stream
# Set the following layer
if i + 1 < len(modules):
modules[i].next_layer.append(modules[i + 1])
# Set preceding layers
if i - 1 >= 0:
modules[i].prev_layer.append(modules[i - 1])
# Store list of modules
self.module_list = nn.ModuleList(modules)
def get_trainable_chunk(self):
# Return the list of trainable chunks from each layer
return sum([m.get_trainable_chunk() for m in self.module_list], [])
def forward(self, x: torch.Tensor):
# Make sure gradient back up is complete
torch.cuda.current_stream().wait_stream(self.backup_stream)
# Forward pass
for m in self.module_list:
x = m(x)
#
return x

View File

@ -0,0 +1,127 @@
"""
---
title: Finetune GPT-NeoX with Zero3 memory optimizer
summary: >
This script trains the bias parameters of GPT-NeoX on multiple devices with Zero-DP memory optimization.
---
# Finetune [GPT-NeoX](../../neox/index.html) with [Zero3 memory optimizer](index.html)
This script trains the bias parameters of the [GPT-NeoX model](../../neox/model.html)
on multiple devices with Zero-DP Memory Optimization.
"""
import datetime
import torch
import torch.distributed
from labml import experiment, monit, tracker
from labml.configs import option
from labml.logger import inspect
from labml_nn.neox.samples.finetune import PipelineParallelTrainerConf
# Use the [Pipeline Parallel Trainer configurations](../../neox/samples/finetune.html) and adapt it for
# Zero3 memory optimizer.
class Configs(PipelineParallelTrainerConf):
rank: int
world_size: int
@option(Configs.optimizer, 'Zero3Adam')
def _optimizer(c: Configs):
"""
#### Set the optimizers for the model
Note that we pass the sharded parameters from `get_trainable_chunk`.
"""
from labml_nn.optimizers.adam_fp16 import AdamFP16
return AdamFP16(c.model.get_trainable_chunk(), lr=c.learning_rate)
@option(Configs.model, 'Zero3')
def _model(c: Configs):
"""
#### Create the model with Zero3 memory optimizer
"""
from labml_nn.scaling.zero3 import Zero3Layer, Zero3Sequential
# Access the fine tuner first, so that it marks which parameters are trainable
_ = c.fine_tuner
# Wrap the layers with `Zero3Layer`
modules = []
for m in monit.iterate('Zero3', c.layers):
modules.append(Zero3Layer(m.to(c.device),
c.rank, c.world_size, c.device, c.dtype))
# Create a sequential model
model = Zero3Sequential(modules)
#
return model
def main(rank: int, world_size: int, init_method: str = 'tcp://localhost:23456'):
"""
#### Run the training on the node with rank `rank`.
"""
# Initialize PyTorch distributed process group
with monit.section('Distributed'):
torch.distributed.init_process_group('nccl',
timeout=datetime.timedelta(seconds=30),
init_method=init_method,
rank=rank,
world_size=world_size)
# Set current device
device = torch.device(f'cuda:{rank}')
torch.cuda.set_device(device)
# Create the experiment
experiment.create(name='zero3_neox', writers={'screen', 'labml'})
experiment.distributed(rank, world_size)
# Create configurations
conf = Configs()
# Load configurations
experiment.configs(conf, {
'model': 'Zero3',
'optimizer': 'Zero3Adam',
'device': device,
'rank': rank,
'world_size': world_size,
'learning_rate': 3e-4,
'max_seq_len': 128,
'batch_size': 16,
})
# Start the experiment
with experiment.start():
# Initialize the model. Do this before the loop for cleaner logs.
_ = conf.model
# Train the model
for epoch in monit.loop(conf.epochs):
conf.train_epoch()
tracker.new_line()
#
if __name__ == '__main__':
# Log the machine configurations
inspect([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])
inspect(
n_gpus=torch.cuda.device_count(),
mpi=torch.distributed.is_mpi_available(),
nccl=torch.distributed.is_nccl_available(),
)
n_gpu = torch.cuda.device_count()
# Start a process for each GPU. `spawn` passes the rank as the first argument to `main`.
# You will need a separate launcher if you are using multiple computers.
torch.multiprocessing.spawn(main, args=(n_gpu,), nprocs=n_gpu, join=True)