Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-10-29 09:38:56 +08:00)
Zero3 memory optimizations (#140)
@ -125,6 +125,11 @@ Solving games with incomplete information such as poker with CFR.
* [Top-k Sampling](sampling/top_k.html)
* [Nucleus Sampling](sampling/nucleus.html)

#### ✨ [Eleuther GPT-NeoX](neox/index.html)

#### ✨ [Scalable Training/Inference](scaling/index.html)

* [Zero3 memory optimizations](scaling/zero3/index.html)

## Highlighted Research Paper PDFs

* [Autoregressive Search Engines: Generating Substrings as Document Identifiers](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf)
29
labml_nn/neox/__init__.py
Normal file
@ -0,0 +1,29 @@
"""
---
title: GPT-NeoX
summary: >
  Simple GPT-NeoX implementation
---

# GPT-NeoX

This is a simple implementation of [Eleuther GPT-NeoX](https://papers.labml.ai/paper/2204.06745) for inference and fine-tuning.

* [Model definition](model.html)
* [Tokenizer](tokenizer.html)
* [Checkpoint downloading and loading helpers](checkpoint.html)
* [Utilities](utils/index.html)

### [Samples](samples/__init__.py)

* [Generating text](samples/generate.html)
* [Fine tuning the biases with pipeline-parallel training](samples/finetune.html)

### [Evaluation](evaluation/__init__.py)

* [Evaluating the half precision model on a single GPU](evaluation/half_precision.html)

**The official [Eleuther](https://www.eleuther.ai)
GPT-NeoX source code is available at [eleutherai/gpt-neox](https://github.com/eleutherai/gpt-neox).**
"""
|
||||
152
labml_nn/neox/checkpoint.py
Normal file
@ -0,0 +1,152 @@
|
||||
"""
|
||||
---
|
||||
title: GPT-NeoX Checkpoints
|
||||
summary: >
|
||||
Code to download checkpoints and helpers to load them.
|
||||
---
|
||||
|
||||
# GPT-NeoX Checkpoints
|
||||
|
||||
"""
|
||||
from typing import Dict, Union, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from labml import monit, lab, logger
|
||||
from labml.logger import Text, inspect
|
||||
from labml.utils.download import download_file
|
||||
|
||||
# Parent url
|
||||
CHECKPOINTS_URL = 'https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/'
|
||||
# Download path
|
||||
|
||||
CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
|
||||
if not CHECKPOINTS_DOWNLOAD_PATH.exists():
|
||||
CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
|
||||
inspect(neox_checkpoint_path=CHECKPOINTS_DOWNLOAD_PATH)
|
||||
|
||||
|
||||
def get_files_to_download(n_layers: int = 44):
|
||||
"""
|
||||
### Get files to download
|
||||
|
||||
:return: a list of files to be downloaded
|
||||
"""
|
||||
layers = (
|
||||
# Embedding layer
|
||||
[0] +
|
||||
# Transformer layers
|
||||
list(range(2, 2 + n_layers)) +
|
||||
# Final normalization layer and readout layer
|
||||
[47, 48]
|
||||
)
|
||||
|
||||
return (
|
||||
# Vocabulary and configs
|
||||
['20B_tokenizer.json', 'configs/20B.yml', 'latest'] +
|
||||
# Layer checkpoints
|
||||
[f'global_step150000/layer_{i :02d}-model_{p :02d}-model_states.pt' for i in layers for p in range(2)] +
|
||||
# Empty states (not used)
|
||||
[f'global_step150000/mp_rank_{i :02d}_model_states.pt' for i in range(8)]
|
||||
)
|
||||
|
||||
|
||||
def download(n_layers: int = 44):
|
||||
"""
|
||||
## Download all checkpoint files
|
||||
"""
|
||||
|
||||
# Get files to download
|
||||
files = get_files_to_download(n_layers)
|
||||
|
||||
# Iterate
|
||||
for i, f in monit.enum('Download All', files):
|
||||
# Log
|
||||
logger.log(['Downloading ', (f'{i + 1 :3d}/{len(files)}', Text.meta), ': ', (f, Text.value)])
|
||||
# Download
|
||||
download_file(CHECKPOINTS_URL + f, CHECKPOINTS_DOWNLOAD_PATH / f)
|
||||
|
||||
|
||||
def load_checkpoint_files(files: Tuple[str, str]):
|
||||
"""
|
||||
### Load a pair of checkpoint files
|
||||
|
||||
:param files: pair of files to load
|
||||
:return: the loaded parameter tensors
|
||||
"""
|
||||
checkpoint_path = CHECKPOINTS_DOWNLOAD_PATH / 'global_step150000'
|
||||
with monit.section('Load checkpoint'):
|
||||
data = [torch.load(checkpoint_path / f) for f in files]
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def merge_params_dim_0(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
|
||||
p2: Dict[str, torch.Tensor]):
|
||||
"""
|
||||
### Load a parameter by merging the partitions along first dimension
|
||||
|
||||
:param param: is the parameter
|
||||
:param key: is the name of the parameter
|
||||
:param p1: first partition dictionary
|
||||
:param p2: second partition dictionary
|
||||
"""
|
||||
w1, w2 = p1[key], p2[key]
|
||||
param.data[:w1.shape[0]] = w1
|
||||
param.data[w1.shape[0]:] = w2
|
||||
|
||||
|
||||
def merge_params_dim_1(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
|
||||
p2: Dict[str, torch.Tensor]):
|
||||
"""
|
||||
### Load a parameter by merging the partitions along second dimension
|
||||
|
||||
:param param: is the parameter
|
||||
:param key: is the name of the parameter
|
||||
:param p1: first partition dictionary
|
||||
:param p2: second partition dictionary
|
||||
"""
|
||||
w1, w2 = p1[key], p2[key]
|
||||
param.data[:, :w1.shape[1]] = w1
|
||||
param.data[:, w1.shape[1]:] = w2
|
||||
|
||||
|
||||
def merge_params_duplicate(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
|
||||
p2: Dict[str, torch.Tensor]):
|
||||
"""
|
||||
### Load an un-partitioned parameter
|
||||
|
||||
This does a sanity check to make sure both partitions are the same
|
||||
|
||||
:param param: is the parameter
|
||||
:param key: is the name of the parameter
|
||||
:param p1: first partition dictionary
|
||||
:param p2: second partition dictionary
|
||||
"""
|
||||
w1, w2 = p1[key], p2[key]
|
||||
|
||||
diff = sum((w1 - w2) ** 2).item()
|
||||
assert diff < 1e-4, f'The partitions do not match: {key}'
|
||||
|
||||
param.data[:] = (w1 + w2) / 2.
|
||||
|
||||
|
||||
def merge_params_sum(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
|
||||
p2: Dict[str, torch.Tensor]):
|
||||
"""
|
||||
### Load partitioned biases that get added on reduce
|
||||
|
||||
:param param: is the parameter
|
||||
:param key: is the name of the parameter
|
||||
:param p1: first partition dictionary
|
||||
:param p2: second partition dictionary
|
||||
"""
|
||||
w1, w2 = p1[key], p2[key]
|
||||
|
||||
param.data[:] = w1 + w2
|
||||
|
||||
|
||||
#
|
||||
if __name__ == '__main__':
|
||||
download()
|
||||
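To make the merging helpers above concrete, here is a small hedged sketch with toy tensors (not the real 20B shards) showing how two model-parallel partitions are recombined along the first and the second dimension, mirroring `merge_params_dim_0` and `merge_params_dim_1`. In the model code, the first case is used for weights like the QKV projection and the embeddings, and the second for weights like the attention output projection.

import torch
from torch import nn

# Pretend each of the two model-parallel ranks stored half of a [4, 6] weight
full = torch.arange(24, dtype=torch.float32).reshape(4, 6)
p1 = {'w': full[:2].clone()}   # shard from `model_00`
p2 = {'w': full[2:].clone()}   # shard from `model_01`

# Weights split along the first dimension (what `merge_params_dim_0` reconstructs)
param = nn.Parameter(torch.empty(4, 6))
param.data[:p1['w'].shape[0]] = p1['w']
param.data[p1['w'].shape[0]:] = p2['w']
assert torch.equal(param.data, full)

# Weights split along the second dimension (what `merge_params_dim_1` reconstructs)
p1, p2 = {'w': full[:, :3].clone()}, {'w': full[:, 3:].clone()}
param.data[:, :p1['w'].shape[1]] = p1['w']
param.data[:, p1['w'].shape[1]:] = p2['w']
assert torch.equal(param.data, full)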
262
labml_nn/neox/evaluation/__init__.py
Normal file
@ -0,0 +1,262 @@
|
||||
"""
|
||||
---
|
||||
title: Evaluation
|
||||
summary: >
|
||||
Code to evaluate the model on NLP tasks through lm-evaluation-harness
|
||||
---
|
||||
|
||||
# Evaluation
|
||||
|
||||
This is the code to test the model on
|
||||
[EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness).
|
||||
|
||||
* [Evaluating half precision model on a single GPU](half_precision.html)
|
||||
"""
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from lm_eval import tasks, evaluator, utils
|
||||
from lm_eval.base import BaseLM
|
||||
from tokenizers import Tokenizer
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
|
||||
from labml import monit
|
||||
from labml_nn.neox.tokenizer import get_tokenizer
|
||||
|
||||
|
||||
class EvalHarnessAdapter(BaseLM):
|
||||
"""
|
||||
## Evaluation Harness Adapter
|
||||
|
||||
This is based on the [adapter from EleutherAI/gpt-neox](https://github.com/EleutherAI/gpt-neox/blob/main/eval_tasks/eval_adapter.py)
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: Tokenizer, vocab_size: int, batch_size: int):
|
||||
"""
|
||||
:param tokenizer: is the [Huggingface Tokenizer](huggingface/tokenizers)
|
||||
:param vocab_size: is the size of the vocabulary
|
||||
(this differs from the tokenizer vocab size since NeoX adds some extra padding tokens so that the embedding
layer can be partitioned for model parallelism)
|
||||
:param batch_size: is the batch size
|
||||
"""
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
self._eot_token_id = self.tokenizer.token_to_id("<|endoftext|>")
|
||||
self._vocab_size = vocab_size
|
||||
|
||||
self._batch_size = batch_size
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
raise RuntimeError()
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
"""Size of the vocabulary"""
|
||||
return self._vocab_size
|
||||
|
||||
@property
|
||||
def eot_token_id(self):
|
||||
"""End-of-text token"""
|
||||
return self._eot_token_id
|
||||
|
||||
@property
|
||||
def max_length(self):
|
||||
"""Maximum sequence length"""
|
||||
return 2048
|
||||
|
||||
@property
|
||||
def max_gen_toks(self):
|
||||
"""Maximum number of tokens to generate"""
|
||||
return 128
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
"""
|
||||
Batch size
|
||||
"""
|
||||
return self._batch_size
|
||||
|
||||
def tok_encode(self, string: str):
|
||||
"""
|
||||
Encode a given text
|
||||
"""
|
||||
return self.tokenizer.encode(string).ids
|
||||
|
||||
def tok_decode(self, tokens: List[int]):
|
||||
"""
|
||||
Decode text from token ids
|
||||
"""
|
||||
return self.tokenizer.decode(tokens)
|
||||
|
||||
def _model_call(self, inps: torch.Tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
def _model_generate(self, context, max_length, eos_token_id):
|
||||
raise RuntimeError()
|
||||
|
||||
def greedy_until(self, requests):
|
||||
raise RuntimeError()
|
||||
|
||||
@torch.no_grad()
|
||||
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
|
||||
"""
|
||||
### Get log-likelihoods of the next tokens
|
||||
|
||||
:param requests: List of requests containing the context and the expected continuation.
|
||||
:param disable_tqdm: If True, disable tqdm progress bar.
|
||||
"""
|
||||
|
||||
# For results
|
||||
res = []
|
||||
|
||||
# Reorder the requests in the descending order of the lengths,
|
||||
# so that sequences with similar lengths are close
|
||||
def _collate(x):
|
||||
toks = x[1] + x[2]
|
||||
return -len(toks), tuple(toks)
|
||||
|
||||
reord = utils.Reorderer(requests, _collate)
|
||||
|
||||
# Loop through requests with `batch_size` number of requests at a time
|
||||
for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
|
||||
# To store the inputs for the batch
|
||||
inps = []
|
||||
# The continuations for the batch
|
||||
continuations = []
|
||||
# Lengths of the input sequences
|
||||
inplens = []
|
||||
# Padded length for the batch
|
||||
padded_length = None
|
||||
# Loop through each request in the chunk and collect them into PyTorch tensors with paddings
|
||||
for _, context_enc, continuation_enc in chunk:
|
||||
# Concatenate the context and continuation
|
||||
inp = context_enc + continuation_enc
|
||||
# Truncate from left if the size exceeds the `max_length`
|
||||
inp = inp[-(self.max_length + 1):]
|
||||
# Remove final token
|
||||
inp = inp[:-1]
|
||||
# Create a tensor
|
||||
inp = torch.tensor(inp, dtype=torch.long)
|
||||
# Input length
|
||||
inplen = inp.shape[0]
|
||||
|
||||
# Determine the padded length.
|
||||
# Shorter sequences will get padded.
|
||||
if padded_length is None:
|
||||
padded_length = int(math.ceil(inplen / 32)) * 32
|
||||
# padded_length = padded_length if padded_length is not None else inplen
|
||||
|
||||
# Padding
|
||||
padding = torch.zeros(padded_length - inplen, dtype=torch.long)
|
||||
|
||||
# Add padding
|
||||
inp = torch.cat([inp, padding], dim=0)
|
||||
|
||||
inps.append(inp)
|
||||
continuations.append(continuation_enc)
|
||||
inplens.append(inplen)
|
||||
|
||||
# Get model logits
|
||||
logits = self._model_call(torch.stack(inps))
|
||||
|
||||
# Get log softmaxes
|
||||
multi_logits = F.log_softmax(logits, dim=-1)
|
||||
|
||||
# Loop through the input/output pairs of the batch
|
||||
for logits, inplen, cont_toks in zip(multi_logits, inplens, continuations):
|
||||
# Get number of predicted tokens
|
||||
contlen = len(cont_toks)
|
||||
# Get logits of those
|
||||
logits = logits[inplen - contlen: inplen]
|
||||
# Get the tokens with the highest probabilities
|
||||
greedy_tokens = logits.argmax(dim=-1)
|
||||
# Get the target tokens
|
||||
cont_toks = torch.tensor(cont_toks, dtype=torch.long).to(logits.device)
|
||||
# Whether there's an exact match
|
||||
max_equal = (greedy_tokens == cont_toks).all()
|
||||
# Log-likelihoods of the target tokens
|
||||
logits = torch.gather(logits, 1, cont_toks[:, None])
|
||||
# Add the total log-likelihoods and whether there was a match to the results
|
||||
res.append((float(logits.sum()), bool(max_equal)))
|
||||
|
||||
# Re-order and return results
|
||||
return reord.get_original(res)
|
||||
|
||||
@torch.no_grad()
|
||||
def run_eval(self, name: str, eval_tasks: List[str]):
|
||||
"""
|
||||
### Run given evaluations
|
||||
"""
|
||||
|
||||
# Run [EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) evaluator
|
||||
results = evaluator.evaluate(lm=self, task_dict=tasks.get_task_dict(eval_tasks))
|
||||
|
||||
# Add configs
|
||||
results["config"] = {
|
||||
"name": name,
|
||||
}
|
||||
|
||||
#
|
||||
return results
|
||||
|
||||
|
||||
class NoeXEvalHarnessAdapter(EvalHarnessAdapter):
|
||||
"""
|
||||
## Evaluation Harness Adapter
|
||||
|
||||
This is based on the [adapter from EleutherAI/gpt-neox](https://github.com/EleutherAI/gpt-neox/blob/main/eval_tasks/eval_adapter.py)
|
||||
"""
|
||||
|
||||
def __init__(self, model: nn.Module, tokenizer: Tokenizer, vocab_size: int, batch_size: int, device: torch.device):
|
||||
"""
|
||||
:param model: is the model
|
||||
:param tokenizer: is the [Huggingface Tokenizer](huggingface/tokenizers)
|
||||
:param vocab_size: is the size of the vocabulary
|
||||
(this differs from the tokenizer vocab size since NeoX adds some extra padding tokens so that the embedding
layer can be partitioned for model parallelism)
|
||||
:param batch_size: is the batch size
|
||||
:param device: is the device of the model
|
||||
"""
|
||||
super().__init__(tokenizer, vocab_size, batch_size)
|
||||
self.model = model
|
||||
self._device = device
|
||||
|
||||
def _model_call(self, inps: torch.Tensor):
|
||||
"""
|
||||
Call the model
|
||||
"""
|
||||
return self.model(inps.to(self._device))
|
||||
|
||||
|
||||
def run_eval_harness(model: nn.Module, name: str, eval_tasks: List[str], device: torch.device, batch_size: int = 8):
|
||||
"""
|
||||
## Run evaluation harness with a given model
|
||||
"""
|
||||
|
||||
# Load the tokenizer
|
||||
with monit.section('Load tokenizer'):
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
# All tasks if nothing is specified
|
||||
if not eval_tasks:
|
||||
eval_tasks = [
|
||||
"anli_r1",
|
||||
"anli_r2",
|
||||
"anli_r3",
|
||||
"hellaswag",
|
||||
"lambada",
|
||||
"piqa",
|
||||
"winogrande",
|
||||
"wsc",
|
||||
"mathqa",
|
||||
]
|
||||
|
||||
# Create the adapter
|
||||
adapter = NoeXEvalHarnessAdapter(model, tokenizer, 50_432, batch_size, device)
|
||||
|
||||
# Run
|
||||
return adapter.run_eval(name, eval_tasks)
|
||||
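As a quick illustration of what `_loglikelihood_tokens` computes for a single request, here is a hedged toy sketch with random logits standing in for the model output; it follows the same slicing, greedy-match and gather logic as the loop above.

import torch
import torch.nn.functional as F

inplen, contlen, vocab = 6, 2, 10
logits = torch.randn(inplen, vocab)    # model output for one (unpadded) input
cont_toks = torch.tensor([3, 7])       # the expected continuation tokens

# Log-softmax over the vocabulary, keeping only the positions that predict the continuation
cont_logits = F.log_softmax(logits, dim=-1)[inplen - contlen: inplen]
# Exact-match flag and summed log-likelihood, as appended to `res` above
max_equal = bool((cont_logits.argmax(dim=-1) == cont_toks).all())
log_likelihood = float(torch.gather(cont_logits, 1, cont_toks[:, None]).sum())
print(log_likelihood, max_equal)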
19
labml_nn/neox/evaluation/half_precision.py
Normal file
@ -0,0 +1,19 @@
import torch
from torch import nn

from labml import monit
from labml_nn.neox.evaluation import run_eval_harness
from labml_nn.neox.model import LayerGenerator

if __name__ == '__main__':
    device = torch.device('cuda:0')
    layers = list(LayerGenerator(is_clone_layers=True,
                                 filter_layers=None,
                                 dtype=torch.float16,
                                 device=device
                                 ).load())

    with monit.section('Sequential'):
        model = nn.Sequential(*layers)

    print(run_eval_harness(model, 'half_precision', ['lambada'], device))
|
||||
572
labml_nn/neox/model.py
Normal file
@ -0,0 +1,572 @@
|
||||
"""
|
||||
---
|
||||
title: GPT-NeoX Model Definition
|
||||
summary: >
|
||||
This is the model definition of GPT-NeoX.
|
||||
---
|
||||
|
||||
# GPT-NeoX Model
|
||||
|
||||
Here is the code for the layers of the GPT-NeoX model and the code to load
the 20B checkpoint.

The `load_state` method in each layer loads the checkpoint of that layer.
The checkpoint loading helpers are in [`checkpoint.py`](checkpoint.html)
|
||||
"""
|
||||
import copy
|
||||
import math
|
||||
from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
from labml import monit
|
||||
from labml_nn.neox import checkpoint
|
||||
from labml_nn.neox.utils.cache import get_cache
|
||||
|
||||
|
||||
class NeoXModule(nn.Module):
|
||||
def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
|
||||
pass
|
||||
|
||||
|
||||
class Embedding(NeoXModule):
|
||||
"""
|
||||
## Embedding layer
|
||||
|
||||
This is a standard embeddings layer with code to load the checkpoint.
|
||||
"""
|
||||
|
||||
def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
|
||||
"""
|
||||
:param n_vocab: is the size of the vocabulary
|
||||
:param n_hidden: is the size of the embeddings
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.emb = nn.Embedding(n_vocab, n_hidden)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
:param x: are the token ids of shape `[batch_size, seq_len]`
|
||||
"""
|
||||
return self.emb(x)
|
||||
|
||||
def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
|
||||
"""
|
||||
Code to load the checkpoint
|
||||
"""
|
||||
with monit.section('Load embedding layer'):
|
||||
checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)
|
||||
|
||||
|
||||
class RoPE(nn.Module):
|
||||
"""
|
||||
## Rotary Positional Embeddings
|
||||
|
||||
GPT-NeoX uses [rotary positional embeddings (RoPE)](https://papers.labml.ai/paper/2104.09864).
|
||||
|
||||
We have an annotated implementation of RoPE [here](https://nn.labml.ai/transformers/rope/index.html)
with more notes on the theory.
|
||||
"""
|
||||
|
||||
def __init__(self, d_rope: int, base: float = 10_000.):
|
||||
"""
|
||||
:param d_rope: is the number of features for RoPE embeddings
|
||||
:param base: is the base for $\theta_i = 10000^{-\frac{2(i-1)}{d}}$, which defaults to $10000$
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# To store $\theta_i$ for the features
|
||||
self.theta = None
|
||||
# Cache $\cos m\theta_i$ and $\sin m\theta_i$
|
||||
self.cos_cached = None
|
||||
self.sin_cached = None
|
||||
|
||||
# Base for $\theta_i = 10000^{-\frac{2(i-1)}{d}}$
|
||||
self.base = base
|
||||
# Number of features for RoPE
|
||||
self.d_rope = d_rope
|
||||
|
||||
@staticmethod
|
||||
def rotate_half(x: torch.Tensor):
|
||||
"""
|
||||
### Rotate the features
|
||||
|
||||
$[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
|
||||
"""
|
||||
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
def forward(self, x: torch.Tensor, offset: int = 0):
|
||||
"""
|
||||
:param x: has shape `[..., seq, n_heads, d_k]`
|
||||
:param offset: is the starting position of `x`. This is $\gt 0$ when we have
|
||||
cached the keys and values of previous positions
|
||||
"""
|
||||
|
||||
# Get the actual sequence length
|
||||
seq_len = x.shape[-3] + offset
|
||||
|
||||
# Initialize $\theta$
|
||||
if self.theta is None:
|
||||
# $\theta_i = 10000^{-\frac{2(i-1)}{d}}$
|
||||
theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
|
||||
self.theta = theta.to(x.device).to(x.dtype)
|
||||
|
||||
# Initialize $\cos m\theta_i$ and $\sin m\theta_i$ cache
|
||||
if (
|
||||
self.cos_cached is None or
|
||||
seq_len > self.cos_cached.shape[1] or
|
||||
self.cos_cached.device != x.device or
|
||||
self.cos_cached.dtype != x.dtype
|
||||
):
|
||||
# Get position indexes $m$
|
||||
seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
|
||||
# $m \theta_i$
|
||||
idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)
|
||||
# Concatenate so that for row $m$ we have
|
||||
#
|
||||
# $$[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$$
|
||||
idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)
|
||||
|
||||
# Calculate $\cos m\theta_i$ and $\sin m\theta_i$ in fp32
|
||||
with autocast(enabled=False):
|
||||
idx_theta2 = idx_theta2.float()
|
||||
# Add head dimension
|
||||
self.cos_cached = idx_theta2.cos()[:, None, :]
|
||||
self.sin_cached = idx_theta2.sin()[:, None, :]
|
||||
|
||||
# Cache them
|
||||
self.cos_cached = self.cos_cached.to(x.dtype)
|
||||
self.sin_cached = self.sin_cached.to(x.dtype)
|
||||
|
||||
# Split the features. We apply RoPE to only `d_rope` features
|
||||
x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]
|
||||
|
||||
# Get the sin and cos values from the cache
|
||||
cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]
|
||||
|
||||
# RoPE embeddings
|
||||
#
|
||||
# \begin{align}
|
||||
# \begin{pmatrix}
|
||||
# x^{(i)}_m \cos m \theta_i - x^{(i + \frac{d}{2})}_m \sin m \theta_i \\
|
||||
# x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m \theta_i \\
|
||||
# \end{pmatrix} \\
|
||||
# \end{align}
|
||||
#
|
||||
# for $i \in {1, 2, ..., \frac{d}{2}}$
|
||||
x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)
|
||||
|
||||
# Concatenate with features that didn't get RoPE embeddings
|
||||
return torch.cat((x_rope, x_pass), dim=-1)
|
||||
|
||||
|
||||
class AttentionLayer(nn.Module):
|
||||
"""
|
||||
## Attention layer
|
||||
"""
|
||||
|
||||
def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
|
||||
mask_fill: float = -10_000.0):
|
||||
"""
|
||||
:param n_hidden: the number of features in embeddings
|
||||
:param n_heads: the number of attention heads
|
||||
:param rope_percentage: percentage of features to add RoPE embeddings
|
||||
:param mask_fill: masking fill value for attention matrix
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.n_heads = n_heads
|
||||
self.mask_fill = mask_fill
|
||||
|
||||
# Linear layer for query, key and value
|
||||
self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)
|
||||
# Final linear layer
|
||||
self.output = nn.Linear(n_hidden, n_hidden)
|
||||
|
||||
# Number of features per head
|
||||
d_k = n_hidden // n_heads
|
||||
# RoPE embedding module
|
||||
self.rope = RoPE(int(d_k * rope_percentage))
|
||||
|
||||
# Attention scaling factor
|
||||
self.scale = 1 / math.sqrt(d_k)
|
||||
|
||||
# To cache causal mask
|
||||
self.causal_mask = None
|
||||
|
||||
# Attention softmax module
|
||||
self.softmax = nn.Softmax(dim=-2)
|
||||
|
||||
def _get_mask(self, attn: torch.Tensor):
|
||||
"""
|
||||
#### Calculate the causal mask
|
||||
|
||||
* `attn` has shape `[batch_size, query_seq_len, key_seq_len, n_heads]`
|
||||
"""
|
||||
|
||||
# Query and key lengths
|
||||
nq, nk = attn.shape[1:3]
|
||||
|
||||
# Create mask
|
||||
if (
|
||||
self.causal_mask is None or
|
||||
self.causal_mask.shape[0] != nq or
|
||||
self.causal_mask.shape[1] != nk or
|
||||
self.causal_mask.device != attn.device
|
||||
):
|
||||
self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)
|
||||
|
||||
# Return from cache
|
||||
return self.causal_mask[None, :, :, None]
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
:param x: has shape `[batch_size, seq_len, n_hidden]`
|
||||
"""
|
||||
# Get query, key and value embeddings (all concatenated).
|
||||
# The last dimension size will change from n_hidden -> `3 x n_hidden`
|
||||
qkv = self.qkv_lin(x)
|
||||
|
||||
# Split into heads by changing the shape to `[batch_size, seq_len, n_heads, 3 * d_k]`
|
||||
qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)
|
||||
# Split into query, key and value, each of shape `[batch_size, seq_len, n_heads, d_k]`
|
||||
q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)
|
||||
|
||||
# If we are caching the states of previous tokens
|
||||
if get_cache().get('use_cache', False):
|
||||
# Get the state ids. We use them to retrieve previous states and store the next states
|
||||
prev_state_id, next_state_id = get_cache().get('state_ids')
|
||||
# If there's cache
|
||||
if prev_state_id is not None:
|
||||
# Get the past keys and values. These will have shape `[batch_size, prev_seq_len, n_heads, d_k]`
|
||||
k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')
|
||||
# Offset of the current embeddings
|
||||
offset = k_past.shape[1]
|
||||
|
||||
# Add RoPE embeddings
|
||||
q = self.rope(q, offset=offset)
|
||||
k = self.rope(k, offset=offset)
|
||||
|
||||
# Concatenate the past
|
||||
k = torch.cat([k_past, k], dim=1)
|
||||
v = torch.cat([v_past, v], dim=1)
|
||||
else:
|
||||
# Add RoPE embeddings
|
||||
q = self.rope(q)
|
||||
k = self.rope(k)
|
||||
|
||||
# Save the current state
|
||||
get_cache().push(f'attn_kv_{next_state_id}', (k, v))
|
||||
else:
|
||||
# No cache - simply add RoPE embeddings
|
||||
q = self.rope(q)
|
||||
k = self.rope(k)
|
||||
|
||||
# Disable auto-casting to fp16 for attention computation
|
||||
with autocast(enabled=False):
|
||||
if q.dtype == torch.float16:
|
||||
# Convert to fp32 if the current dtype is fp16
|
||||
attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
|
||||
else:
|
||||
# Do not cast for bfloat
|
||||
attn = torch.einsum('bihk,bjhk->bijh', q, k)
|
||||
|
||||
# Scale attention
|
||||
attn = attn * self.scale
|
||||
|
||||
# Get causal mask
|
||||
mask = self._get_mask(attn)
|
||||
# Apply mask
|
||||
attn.masked_fill_(mask, self.mask_fill)
|
||||
|
||||
# Attention softmax
|
||||
attn = self.softmax(attn)
|
||||
|
||||
# Get attention weighted values
|
||||
output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)
|
||||
|
||||
# Reshape from `[batch_size, seq_len, n_heads, d_k]` to `[batch_size, seq_len, n_hidden]`
|
||||
output = output.reshape(*x.shape)
|
||||
|
||||
# Final linear layer
|
||||
return self.output(output)
|
||||
|
||||
|
||||
class FFNLayer(nn.Module):
|
||||
"""
|
||||
## Feedforward Network
|
||||
"""
|
||||
|
||||
def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
|
||||
"""
|
||||
:param n_hidden: is the embedding size
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if not d_ff:
|
||||
d_ff = n_hidden * 4
|
||||
|
||||
# Expansion linear layer
|
||||
self.dense_h_h4 = nn.Linear(n_hidden, d_ff)
|
||||
# GELU activation
|
||||
self.activation = nn.GELU()
|
||||
# Contraction linear layer
|
||||
self.dense_h4_h = nn.Linear(d_ff, n_hidden)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
:param x: has shape `[batch_size, seq_len, n_hidden]`
|
||||
"""
|
||||
x = self.dense_h_h4(x)
|
||||
x = self.activation(x)
|
||||
x = self.dense_h4_h(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class TransformerLayer(NeoXModule):
|
||||
"""
|
||||
## Transformer Layer
|
||||
"""
|
||||
|
||||
def __init__(self, n_hidden: int = 6_144, n_heads: int = 64):
|
||||
"""
|
||||
:param n_hidden: is the embedding size
|
||||
:param n_heads: is the number of heads
|
||||
|
||||
*Our implementation doesn't include dropout.*
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# Layer normalization before attention
|
||||
self.pre_ln_attn = nn.LayerNorm(n_hidden)
|
||||
# Layer normalization before FFN
|
||||
self.pre_ln_ffn = nn.LayerNorm(n_hidden)
|
||||
|
||||
# Attention layer
|
||||
self.attention = AttentionLayer(n_hidden, n_heads)
|
||||
# FFN layer
|
||||
self.ffn = FFNLayer(n_hidden)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
:param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`
|
||||
"""
|
||||
|
||||
# Residual connection
|
||||
residual = x
|
||||
# NeoX runs attention and feedforward network in parallel
|
||||
attn = self.attention(self.pre_ln_attn(x))
|
||||
ffn = self.ffn(self.pre_ln_ffn(x))
|
||||
# Add them and the residual connection
|
||||
return attn + ffn + residual
|
||||
|
||||
def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
|
||||
"""
|
||||
Code to load the checkpoint
|
||||
"""
|
||||
with monit.section('Load transformer layer'):
|
||||
# Attention output transform
|
||||
checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
|
||||
checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)
|
||||
|
||||
# Attention query, key and value transform
|
||||
checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
|
||||
checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)
|
||||
|
||||
# Layer norm before attention
|
||||
checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
|
||||
checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)
|
||||
|
||||
# FFN first (expansion) transform
|
||||
checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
|
||||
checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)
|
||||
|
||||
# FFN second (contraction) transform
|
||||
checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
|
||||
checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)
|
||||
|
||||
# Layer norm before FFN
|
||||
checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
|
||||
checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)
|
||||
|
||||
|
||||
class FinalNorm(NeoXModule):
|
||||
"""
|
||||
## Final normalization layer
|
||||
"""
|
||||
|
||||
def __init__(self, n_hidden: int = 6_144):
|
||||
"""
|
||||
:param n_hidden: is the embedding size
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.ln = nn.LayerNorm(n_hidden)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
:param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`
|
||||
"""
|
||||
return self.ln(x)
|
||||
|
||||
def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
|
||||
"""
|
||||
Code to load the checkpoint
|
||||
"""
|
||||
with monit.section('Load final normalization layer'):
|
||||
checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
|
||||
checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)
|
||||
|
||||
|
||||
class ReadoutLayer(NeoXModule):
|
||||
"""
|
||||
Readout layer
|
||||
"""
|
||||
|
||||
def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
|
||||
"""
|
||||
:param n_hidden: is the embedding size
|
||||
:param n_vocab: is the size of the vocabulary
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.linear = nn.Linear(n_hidden, n_vocab, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
:param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`
|
||||
"""
|
||||
return self.linear(x)
|
||||
|
||||
def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
|
||||
"""
|
||||
Code to load the checkpoint
|
||||
"""
|
||||
with monit.section('Load final linear layer'):
|
||||
checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
|
||||
|
||||
|
||||
class LayerGenerator:
|
||||
pre_created_layers: Dict[Any, Optional[NeoXModule]]
|
||||
|
||||
def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
|
||||
n_layers: int = 44, n_heads: int = 64,
|
||||
filter_layers: Optional[Set] = None,
|
||||
is_clone_layers: bool = True,
|
||||
dtype: torch.dtype = torch.float,
|
||||
device: torch.device = torch.device('cpu')):
|
||||
"""
|
||||
### Generator to create layers
|
||||
|
||||
The layers are generated in the same order as checkpoints.
|
||||
|
||||
It gives `None` when a layer is not available; we use the same layer indices as NeoX, and there are two
transformation layers we don't need in our implementation.
|
||||
|
||||
:param n_vocab: is the number of tokens in the vocabulary
|
||||
:param n_hidden: is the number of features in the embeddings
|
||||
:param n_layers: is the number of transformer layers
|
||||
:param n_heads: is the number of attention heads
|
||||
:param filter_layers: are the set of layers to be used. All layers will be used if None.
|
||||
This is used to test smaller versions of the model with fewer layers
|
||||
:param is_clone_layers: specifies whether to clone the transformer layers (a bit faster)
|
||||
:param dtype: is the data type of the model
|
||||
:param device: is the device of the model
|
||||
:return: the layers as a generator
|
||||
"""
|
||||
if filter_layers is None:
|
||||
filter_layers = set(range(n_layers + 3))
|
||||
|
||||
self.n_vocab = n_vocab
|
||||
self.n_hidden = n_hidden
|
||||
self.n_layers = n_layers
|
||||
self.n_heads = n_heads
|
||||
self.filter_layers = filter_layers
|
||||
self.is_clone_layers = is_clone_layers
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
self.pre_created_layers = dict(
|
||||
transformer_layer=None,
|
||||
)
|
||||
|
||||
def _prepare_layer(self, layer: NeoXModule):
|
||||
layer = layer.to(self.device, self.dtype)
|
||||
return layer
|
||||
|
||||
def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
|
||||
if self.pre_created_layers[name] is None or not self.is_clone_layers:
|
||||
layer = creator()
|
||||
else:
|
||||
layer = copy.deepcopy(self.pre_created_layers[name])
|
||||
|
||||
layer: NeoXModule = self._prepare_layer(layer)
|
||||
|
||||
if self.pre_created_layers[name] is None:
|
||||
self.pre_created_layers[name] = layer
|
||||
|
||||
return layer
|
||||
|
||||
def _create_transformer_layer(self):
|
||||
return self._create_and_cache_layer(
|
||||
'transformer_layer',
|
||||
lambda: TransformerLayer(self.n_hidden, self.n_heads)
|
||||
)
|
||||
|
||||
def _create_embedding_layer(self):
|
||||
return Embedding(self.n_vocab, self.n_hidden)
|
||||
|
||||
def _create_final_norm_layer(self):
|
||||
return FinalNorm(self.n_hidden)
|
||||
|
||||
def _create_readout_layer(self):
|
||||
return ReadoutLayer(self.n_hidden, self.n_vocab)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
|
||||
# Embedding layer
|
||||
if 0 in self.filter_layers:
|
||||
with monit.section('Embedding layer'):
|
||||
layer = self._prepare_layer(self._create_embedding_layer())
|
||||
yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')
|
||||
|
||||
# Transformer layers
|
||||
for i in range(self.n_layers):
|
||||
# Transformer layer
|
||||
if i + 1 in self.filter_layers:
|
||||
with monit.section(f'Transformer Layer {i}'):
|
||||
yield self._create_transformer_layer(), \
|
||||
(f'layer_{i + 2 :02d}-model_00-model_states.pt',
|
||||
f'layer_{i + 2 :02d}-model_01-model_states.pt')
|
||||
|
||||
# Final normalization layer
|
||||
if self.n_layers + 1 in self.filter_layers:
|
||||
with monit.section('Final norm layer'):
|
||||
layer = self._prepare_layer(self._create_final_norm_layer())
|
||||
yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')
|
||||
|
||||
# Readout layer
|
||||
if self.n_layers + 2 in self.filter_layers:
|
||||
with monit.section('Readout layer'):
|
||||
layer = self._prepare_layer(self._create_readout_layer())
|
||||
yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
|
||||
|
||||
@property
|
||||
def total_layers(self):
|
||||
return self.n_layers + 3
|
||||
|
||||
@torch.no_grad()
|
||||
def load(self) -> Generator[NeoXModule, None, None]:
|
||||
with torch.no_grad():
|
||||
with monit.section("Layers"):
|
||||
for i, (layer, files) in enumerate(self.get_layers()):
|
||||
if files is not None:
|
||||
layer.load_state(*checkpoint.load_checkpoint_files(files))
|
||||
|
||||
monit.progress(min(0.99, (i + 1) / self.total_layers))
|
||||
yield layer
|
||||
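The rotation in `RoPE.forward` can be hard to picture from the cached tensors alone, so here is a small hedged sketch with toy sizes that applies the same $\cos/\sin$ rotation outside the module. Position $0$ comes out unchanged because $\cos 0 = 1$ and $\sin 0 = 0$.

import torch

def rotate_half(x: torch.Tensor):
    # Same rotation as `RoPE.rotate_half`
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

d_rope, seq_len = 4, 3
# $\theta_i$, and $m \theta_i$ for every position $m$
theta = 1.0 / (10_000 ** (torch.arange(0, d_rope, 2).float() / d_rope))
idx_theta = torch.einsum('s,d->sd', torch.arange(seq_len).float(), theta)
idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1)
cos, sin = idx_theta2.cos()[:, None, :], idx_theta2.sin()[:, None, :]    # add head dimension

x = torch.randn(1, seq_len, 2, d_rope)          # `[batch, seq, heads, d_rope]`
x_rope = (x * cos) + (rotate_half(x) * sin)     # same formula as in `RoPE.forward`
assert torch.allclose(x_rope[:, 0], x[:, 0])    # position 0 is left unchanged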
12
labml_nn/neox/samples/__init__.py
Normal file
@ -0,0 +1,12 @@
"""
---
title: Samples
summary: >
  Samples for inference and fine-tuning
---

# Samples

* [Generating text](generate.html)
* [Fine tuning the biases with pipeline-parallel training](finetune.html)
"""
|
||||
125
labml_nn/neox/samples/finetune.py
Normal file
@ -0,0 +1,125 @@
|
||||
"""
|
||||
---
|
||||
title: Fine Tune GPT-NeoX
|
||||
summary: >
|
||||
Fine tune GPT-NeoX biases with Fairscale pipeline parallel module
|
||||
---
|
||||
|
||||
# Fine Tune GPT-NeoX
|
||||
|
||||
This shows how to fine tune GPT-NeoX with pipeline parallelism.
|
||||
"""
|
||||
|
||||
import fairscale
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.data
|
||||
import typing
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
|
||||
from labml import experiment, monit, tracker, lab
|
||||
from labml.configs import option
|
||||
from labml.logger import inspect
|
||||
from labml_nn.neox.utils.text_dataset import get_training_data
|
||||
from labml_nn.neox.utils.finetune import FineTuneBiases
|
||||
from labml_nn.neox.model import LayerGenerator, NeoXModule
|
||||
from labml_nn.neox.utils import balance_layers_simple
|
||||
from labml_nn.neox.utils.trainer import PipelineParallelTrainerConf
|
||||
|
||||
|
||||
@option(PipelineParallelTrainerConf.layers, 'PipelineBiases')
|
||||
def neox_layers(c: PipelineParallelTrainerConf):
|
||||
"""
|
||||
### Load GPT-NeoX layers
|
||||
"""
|
||||
return list(LayerGenerator(is_clone_layers=c.is_clone_layers,
|
||||
filter_layers=c.filter_layers,
|
||||
dtype=c.dtype,
|
||||
).load())
|
||||
|
||||
|
||||
@option(PipelineParallelTrainerConf.fine_tuner, 'PipelineBiases')
|
||||
def fine_tune_biases(c: PipelineParallelTrainerConf):
|
||||
"""
|
||||
### Create fine tuner for biases
|
||||
"""
|
||||
|
||||
fine_tuner = FineTuneBiases(typing.cast(typing.List[NeoXModule], c.layers))
|
||||
# Mark biases as trainable
|
||||
fine_tuner.set_trainable_params()
|
||||
|
||||
#
|
||||
return fine_tuner
|
||||
|
||||
|
||||
@option(PipelineParallelTrainerConf.model, 'PipelineBiases')
|
||||
def pipe_model(c: PipelineParallelTrainerConf):
|
||||
"""
|
||||
### Create pipeline parallel model
|
||||
"""
|
||||
|
||||
if c.is_checkpointing:
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
layers = c.layers
|
||||
|
||||
# Create the Pipe module
|
||||
with monit.section('Pipe'):
|
||||
# Get the layer distribution across GPUs
|
||||
balance = balance_layers_simple(len(layers), c.n_gpus)
|
||||
inspect(balance=balance)
|
||||
# Devices for each GPU
|
||||
devices = [torch.device(f'cuda:{i}') for i in range(c.n_gpus)]
|
||||
# Create Fairscale Pipe module
|
||||
pipe_model = fairscale.nn.Pipe(nn.Sequential(*layers),
|
||||
balance=balance,
|
||||
devices=devices,
|
||||
chunks=c.chunks)
|
||||
|
||||
#
|
||||
return pipe_model
|
||||
|
||||
|
||||
@option(PipelineParallelTrainerConf.train_loader)
|
||||
def tiny_shakespeare(c: PipelineParallelTrainerConf):
|
||||
"""
|
||||
#### Tiny Shakespeare dataset
|
||||
"""
|
||||
dataset = get_training_data(c.max_seq_len)
|
||||
|
||||
return DataLoader(dataset,
|
||||
batch_size=c.batch_size,
|
||||
sampler=RandomSampler(dataset, replacement=True))
|
||||
|
||||
|
||||
def main():
|
||||
# Create experiment
|
||||
experiment.create(name='pipe_neox_biases',
|
||||
writers={'screen', 'web_api'})
|
||||
|
||||
# Initialize configs
|
||||
conf = PipelineParallelTrainerConf()
|
||||
experiment.configs(conf, {
|
||||
'learning_rate': 3e-4,
|
||||
'is_checkpointing': False,
|
||||
'max_seq_len': 128,
|
||||
'batch_size': 64,
|
||||
'chunks': 8,
|
||||
})
|
||||
|
||||
# Start the experiment
|
||||
with experiment.start():
|
||||
# Initialize the model. Do this before the loop for cleaner logs.
|
||||
_ = conf.model
|
||||
|
||||
# Train
|
||||
for epoch in monit.loop(conf.epochs):
|
||||
conf.train_epoch()
|
||||
tracker.new_line()
|
||||
torch.save(conf.fine_tuner.state_dict(), str(lab.get_data_path() / 'fine_tune.pt'))
|
||||
|
||||
|
||||
#
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
102
labml_nn/neox/samples/generate.py
Normal file
@ -0,0 +1,102 @@
|
||||
"""
|
||||
---
|
||||
title: Generate Text with GPT-NeoX
|
||||
summary: >
|
||||
Generate Text with GPT-NeoX
|
||||
---
|
||||
|
||||
# Generate Text with GPT-NeoX
|
||||
|
||||
This shows how to generate text from GPT-NeoX with a single GPU.
|
||||
|
||||
This needs a GPU with more than 45GB memory.
|
||||
"""
|
||||
|
||||
# Imports
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from labml import monit
|
||||
from labml_nn.neox.model import LayerGenerator
|
||||
from labml_nn.neox.utils import get_tokens, print_tokens
|
||||
from labml_nn.neox.utils.cache import get_cache
|
||||
|
||||
# List of layers to load. This is used for testing.
|
||||
# You can assign a subset of layers like `{0, 1}` so that it only loads
# the first two transformer layers.
|
||||
LAYERS = None
|
||||
|
||||
# Prompt to complete
|
||||
PROMPT = 'Einstein was born in the German Empire, but moved to Switzerland in 1895, forsaking his German'
|
||||
|
||||
|
||||
def infer(model: nn.Module, ids: List[int], device: torch.device):
|
||||
"""
|
||||
### Predict the next token
|
||||
|
||||
:param model: is the model
|
||||
:param ids: are the input token ids
|
||||
:param device: is the device of the model
|
||||
"""
|
||||
|
||||
with torch.no_grad():
|
||||
# Get the tokens
|
||||
x = torch.tensor(ids)[None, :].to(device)
|
||||
# Eval model
|
||||
x = model(x)
|
||||
|
||||
# Return predicted token
|
||||
return x[0].max(dim=-1)[1].tolist()
|
||||
|
||||
|
||||
def generate():
|
||||
"""
|
||||
## Generate text
|
||||
"""
|
||||
|
||||
# Setup [cache](../utils/cache.html) to cache intermediate key/value pairs for faster generation
|
||||
cache = get_cache()
|
||||
cache.set('use_cache', True)
|
||||
|
||||
# Device
|
||||
device = torch.device('cuda:0')
|
||||
|
||||
# Load layers
|
||||
layers = list(LayerGenerator(is_clone_layers=True,
|
||||
filter_layers=LAYERS,
|
||||
dtype=torch.float16,
|
||||
device=device,
|
||||
).load())
|
||||
|
||||
model = nn.Sequential(*layers)
|
||||
|
||||
# Get token ids
|
||||
ids = get_tokens(PROMPT)
|
||||
|
||||
# Run the model
|
||||
cache.set('state_ids', (None, 1))
|
||||
with monit.section('Infer'):
|
||||
next_token = infer(model, ids, device)[-1]
|
||||
|
||||
# Append the predicted token
|
||||
ids += [next_token]
|
||||
|
||||
# Predict 100 tokens
|
||||
for i in range(1, 100):
|
||||
# Set the state to use cached activations
|
||||
cache.set('state_ids', (i, i + 1))
|
||||
# Get next token. Note that we only feed the last token to the model because
|
||||
# we cache the key/value pairs of previous tokens.
|
||||
with monit.section('Infer'):
|
||||
next_token = infer(model, [next_token], device)[-1]
|
||||
# Append the predicted token
|
||||
ids += [next_token]
|
||||
# Print
|
||||
print_tokens(ids, [ids])
|
||||
|
||||
|
||||
#
|
||||
if __name__ == '__main__':
|
||||
generate()
|
||||
28
labml_nn/neox/tokenizer.py
Normal file
@ -0,0 +1,28 @@
|
||||
"""
|
||||
---
|
||||
title: GPT-NeoX Tokenizer
|
||||
summary: >
|
||||
Loads the GPT-NeoX tokenizer
|
||||
---
|
||||
|
||||
# GPT-NeoX Tokenizer
|
||||
|
||||
This initializes a Hugging Face tokenizer from the downloaded vocabulary.
|
||||
"""
|
||||
|
||||
from tokenizers import Tokenizer
|
||||
|
||||
from labml import lab, monit
|
||||
|
||||
|
||||
@monit.func('Load NeoX Tokenizer')
|
||||
def get_tokenizer() -> Tokenizer:
|
||||
"""
|
||||
### Load NeoX Tokenizer
|
||||
|
||||
:return: the tokenizer
|
||||
"""
|
||||
vocab_file = lab.get_data_path() / 'neox' / 'slim_weights' / '20B_tokenizer.json'
|
||||
tokenizer = Tokenizer.from_file(str(vocab_file))
|
||||
|
||||
return tokenizer
|
||||
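A minimal usage sketch, assuming the `20B_tokenizer.json` vocabulary has already been downloaded by [`checkpoint.py`](checkpoint.html):

from labml_nn.neox.tokenizer import get_tokenizer

tokenizer = get_tokenizer()
ids = tokenizer.encode('Hello GPT-NeoX').ids    # token ids for the prompt
print(tokenizer.decode(ids))                    # decodes back to (approximately) the original text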
134
labml_nn/neox/utils/__init__.py
Normal file
@ -0,0 +1,134 @@
|
||||
"""
|
||||
---
|
||||
title: Utilities and Helpers
|
||||
summary: >
|
||||
Utilities and helper functions
|
||||
---
|
||||
|
||||
# Utilities and Helpers
|
||||
|
||||
* [Cache for intermediate activations (for faster inference)](cache.html)
|
||||
* [Tools for finetuning](finetune.html)
|
||||
* [Trainer](trainer.html)
|
||||
* [Text dataset](text_dataset.html)
|
||||
"""
|
||||
import typing
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from labml import logger
|
||||
from labml.logger import Text
|
||||
from labml_nn.neox.tokenizer import get_tokenizer
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from tokenizers import Tokenizer
|
||||
|
||||
# Tokenizer singleton
|
||||
_TOKENIZER: Optional['Tokenizer'] = None
|
||||
|
||||
|
||||
def get_tokens(text: str) -> List[int]:
|
||||
"""
|
||||
### Get token ids
|
||||
|
||||
:param text: is the text to tokenize
|
||||
:return: the token ids
|
||||
"""
|
||||
global _TOKENIZER
|
||||
if _TOKENIZER is None:
|
||||
_TOKENIZER = get_tokenizer()
|
||||
return _TOKENIZER.encode_batch([text])[0].ids
|
||||
|
||||
|
||||
def print_token_outputs(ids: List[int], *xs: torch.Tensor):
|
||||
"""
|
||||
### Print tokens from model outputs
|
||||
|
||||
Pretty prints target tokens alongside outputs from the model(s).
|
||||
|
||||
:param ids: are the target token ids
|
||||
:param xs: are the model(s) outputs
|
||||
"""
|
||||
ids = ids + [-1]
|
||||
xs = [[-1] + x[0].max(dim=-1)[1].tolist() for x in xs]
|
||||
|
||||
print_tokens(ids, xs)
|
||||
|
||||
|
||||
def print_tokens(target: List[int], others: List[List[int]]):
|
||||
"""
|
||||
### Print tokens
|
||||
|
||||
Pretty prints tokens for comparison
|
||||
|
||||
:param target: are the target token ids
|
||||
:param others: are the sampled outputs from the model(s)
|
||||
"""
|
||||
|
||||
# Load tokenizer
|
||||
global _TOKENIZER
|
||||
if _TOKENIZER is None:
|
||||
_TOKENIZER = get_tokenizer()
|
||||
|
||||
# Convert the tokens to list of strings
|
||||
text = []
|
||||
for i in range(len(target)):
|
||||
tokens = [_TOKENIZER.decode([target[i]]) if target[i] != -1 else '---']
|
||||
for j in range(len(others)):
|
||||
tokens.append(_TOKENIZER.decode([others[j][i]]) if others[j][i] != -1 else '---')
|
||||
|
||||
text.append(tokens)
|
||||
|
||||
# Stats
|
||||
correct = [0 for _ in others]
|
||||
total = 0
|
||||
|
||||
# Iterate through tokens
|
||||
for i in range(len(target)):
|
||||
parts = [(f'{i}: ', Text.meta)]
|
||||
parts += [('"', Text.subtle), (text[i][0], Text.subtle), ('"', Text.subtle), '\t']
|
||||
|
||||
# Empty target
|
||||
if target[i] == -1:
|
||||
for j in range(len(others)):
|
||||
parts += [('"', Text.subtle), (text[i][j + 1], Text.subtle), ('"', Text.subtle), '\t']
|
||||
|
||||
logger.log(parts)
|
||||
continue
|
||||
|
||||
# Number of tokens
|
||||
total += 1
|
||||
|
||||
# Other outputs
|
||||
for j in range(len(others)):
|
||||
correct[j] += 1 if others[j][i] == target[i] else 0
|
||||
|
||||
parts += [('"', Text.subtle),
|
||||
(text[i][j + 1], Text.success if others[j][i] == target[i] else Text.danger),
|
||||
('"', Text.subtle), '\t']
|
||||
|
||||
logger.log(parts)
|
||||
|
||||
# Stats
|
||||
parts = [(f'{total}', Text.highlight), '\t']
|
||||
for j in range(len(others)):
|
||||
parts += [(f'{correct[j]}', Text.value), '\t']
|
||||
logger.log(parts)
|
||||
|
||||
|
||||
def balance_layers_simple(n_layers: int, n_chunks: int):
|
||||
"""
|
||||
### Balance layers
|
||||
|
||||
Split the `n_layers` into `n_chunks`. This is used for pipeline parallel training.
|
||||
|
||||
:param n_layers: is the number of layers
|
||||
:param n_chunks: is the number of chunks
|
||||
:return: returns a list with the number of layers for each chunk
|
||||
"""
|
||||
balance = []
|
||||
for i in range(n_chunks):
|
||||
balance.append((n_layers - sum(balance)) // (n_chunks - i))
|
||||
|
||||
return list(reversed(balance))
|
||||
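For instance, splitting the 47 modules of the full model (embedding, 44 transformer layers, final norm and readout) across 4 GPUs with the function above gives the following balance; the numbers follow directly from the loop.

from labml_nn.neox.utils import balance_layers_simple

# Earlier pipeline stages get the extra layers
print(balance_layers_simple(47, 4))    # [12, 12, 12, 11]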
118
labml_nn/neox/utils/cache.py
Normal file
@ -0,0 +1,118 @@
|
||||
"""
|
||||
---
|
||||
title: Cache for Intermediate Activations
|
||||
summary: >
|
||||
Cache for intermediate activations for faster inference.
|
||||
---
|
||||
|
||||
# Cache for Intermediate Activations
|
||||
|
||||
During inference, the model generates output token by token.
We use this simple cache to store the keys and values of the attention layers,
so that we don't have to recompute them for previous tokens.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Cache:
|
||||
"""
|
||||
## Cache
|
||||
|
||||
This maintains a key-value cache and queues, where we push values and pop them in the same order.
The queues are useful since we have multiple attention layers.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._cache = {}
|
||||
|
||||
def clear_all(self):
|
||||
"""
|
||||
### Clear cache
|
||||
"""
|
||||
self._cache = {}
|
||||
|
||||
def push(self, name: str, value: Any):
|
||||
"""
|
||||
### Push a value to a queue
|
||||
|
||||
:param name: is the name of the queue
|
||||
:param value: is the value to be pushed
|
||||
"""
|
||||
|
||||
# Create an empty queue if it's not present
|
||||
if name not in self._cache:
|
||||
self._cache[name] = []
|
||||
|
||||
# Push to the queue
|
||||
self._cache[name].append(value)
|
||||
|
||||
def q_size(self, name):
|
||||
"""
|
||||
### Return the size of the queue
|
||||
|
||||
:param name: is the name of the queue
|
||||
:return: size of the queue if exists else None
|
||||
"""
|
||||
|
||||
if name not in self._cache:
|
||||
return None
|
||||
|
||||
if type(self._cache[name]) != list:
|
||||
return None
|
||||
|
||||
return len(self._cache[name])
|
||||
|
||||
def pop(self, name: str):
|
||||
"""
|
||||
### Pop from a queue
|
||||
|
||||
:param name: is the name of the queue
|
||||
:return: the value
|
||||
"""
|
||||
return self._cache[name].pop(0)
|
||||
|
||||
def set(self, key: str, value: Any):
|
||||
"""
|
||||
### Cache a value
|
||||
|
||||
:param key: is the name of the value to be cached
|
||||
:param value: is the value
|
||||
"""
|
||||
self._cache[key] = value
|
||||
|
||||
def get(self, key: str, default: Any = None):
|
||||
"""
|
||||
### Retrieve a value from cache
|
||||
|
||||
:param key: is the name used when caching
|
||||
:param default: is the default value if the cache is empty
|
||||
:return: the cached value
|
||||
"""
|
||||
return self._cache.get(key, default)
|
||||
|
||||
def clear(self, key: str):
|
||||
"""
|
||||
### Clear a cache value
|
||||
|
||||
:param key: is the name used when caching
|
||||
"""
|
||||
del self._cache[key]
|
||||
|
||||
|
||||
# Singleton for cache
|
||||
_INSTANCE = None
|
||||
|
||||
|
||||
def get_cache() -> Cache:
|
||||
"""
|
||||
### Get the cache instance
|
||||
|
||||
:return: the cache instance
|
||||
"""
|
||||
global _INSTANCE
|
||||
|
||||
if _INSTANCE is None:
|
||||
_INSTANCE = Cache()
|
||||
|
||||
return _INSTANCE
|
||||
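Here is a hedged sketch of how the attention layers and the [generation sample](../samples/generate.html) drive this cache during token-by-token inference; placeholder strings stand in for the real key/value tensors.

from labml_nn.neox.utils.cache import get_cache

cache = get_cache()
cache.set('use_cache', True)

# First pass over the prompt: no previous state, store the new state under id 1
cache.set('state_ids', (None, 1))
cache.push('attn_kv_1', ('k_layer0', 'v_layer0'))    # each attention layer pushes once

# Next token: read state 1, write state 2
cache.set('state_ids', (1, 2))
k_past, v_past = cache.pop('attn_kv_1')              # layers pop in the order they pushed
cache.push('attn_kv_2', (k_past, v_past))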
55
labml_nn/neox/utils/finetune.py
Normal file
@ -0,0 +1,55 @@
|
||||
from typing import List, Dict
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from labml_nn.neox.model import TransformerLayer, NeoXModule
|
||||
|
||||
|
||||
class FineTuner:
|
||||
def __init__(self, layers: List[NeoXModule]):
|
||||
self.layers = layers
|
||||
|
||||
def get_trainable_params(self) -> Dict[str, nn.Parameter]:
|
||||
params = {}
|
||||
for i, layer in enumerate(self.layers):
|
||||
params.update(self.get_layer_trainable_params(layer, prefix=f'layer_{i :02d}'))
|
||||
|
||||
return params
|
||||
|
||||
def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
|
||||
raise NotImplementedError
|
||||
|
||||
def set_trainable_params(self):
|
||||
for layer in self.layers:
|
||||
# Set `requires_grad` to `False` for the entire layer.
|
||||
layer.requires_grad_(False)
|
||||
#
|
||||
for p in self.get_trainable_params().values():
|
||||
p.requires_grad_(True)
|
||||
|
||||
def state_dict(self):
|
||||
return {n: p.data.cpu() for n, p in self.get_trainable_params().items()}
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
|
||||
params = self.get_trainable_params()
|
||||
for n, p in params.items():
|
||||
p.data[:] = state_dict[n].to(p.data.device)
|
||||
|
||||
for n in state_dict.keys():
|
||||
assert n in params, n
|
||||
|
||||
|
||||
class FineTuneBiases(FineTuner):
|
||||
def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
|
||||
params = {}
|
||||
|
||||
if isinstance(layer, TransformerLayer):
|
||||
# No need to train the MLP output bias because it is added together with the attention output bias
|
||||
params[f'{prefix}.attention.output.bias'] = layer.attention.output.bias
|
||||
params[f'{prefix}.attention.qkv_lin.bias'] = layer.attention.qkv_lin.bias
|
||||
params[f'{prefix}.ffn.dense_h_h4.bias'] = layer.ffn.dense_h_h4.bias
|
||||
else:
|
||||
pass
|
||||
|
||||
return params
|
||||
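A hedged sketch of how `FineTuneBiases` would typically be wired up for training; the [fine-tuning sample](../samples/finetune.html) does this through configs, and the optimizer choice here is only illustrative.

import torch
from labml_nn.neox.utils.finetune import FineTuneBiases

fine_tuner = FineTuneBiases(layers)    # `layers` would come from `LayerGenerator(...).load()`
fine_tuner.set_trainable_params()      # freeze everything except the chosen biases

# Only the trainable biases go to the optimizer
optimizer = torch.optim.Adam(fine_tuner.get_trainable_params().values(), lr=3e-4)

# ... after training, save just the fine-tuned biases
torch.save(fine_tuner.state_dict(), 'fine_tune.pt')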
132
labml_nn/neox/utils/text_dataset.py
Normal file
@ -0,0 +1,132 @@
|
||||
"""
|
||||
---
|
||||
title: Text Dataset for GPT-NeoX
|
||||
summary: >
|
||||
Loads text datasets to fine-tune GPT-NeoX
|
||||
---
|
||||
|
||||
# Text Dataset for GPT-NeoX
|
||||
"""
|
||||
from pathlib import PurePath, Path
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from labml import lab
|
||||
from labml import monit
|
||||
from labml.logger import inspect
|
||||
from labml.utils.download import download_file
|
||||
|
||||
from labml_nn.neox.tokenizer import get_tokenizer
|
||||
|
||||
|
||||
def load_text(path: PurePath, url: Optional[str] = None, *, filter_subset: Optional[int] = None):
|
||||
"""
|
||||
### Load text file
|
||||
|
||||
:param path: is the location of the text file
|
||||
:param url: is the URL to download the file from
|
||||
:param filter_subset: is the number of characters to keep; the rest of the text is dropped.
 Use this during testing when trying out large datasets
|
||||
:return: the text content
|
||||
"""
|
||||
|
||||
path = Path(path)
|
||||
|
||||
# Download if it doesn't exist
|
||||
if not path.exists():
|
||||
if not url:
|
||||
raise FileNotFoundError(str(path))
|
||||
else:
|
||||
download_file(url, path)
|
||||
|
||||
with monit.section("Load data"):
|
||||
# Load data
|
||||
with open(str(path), 'r') as f:
|
||||
text = f.read()
|
||||
# Filter
|
||||
if filter_subset:
|
||||
text = text[:filter_subset]
|
||||
|
||||
#
|
||||
return text
|
||||
|
||||
|
||||
class NeoXDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
## Dataset for fine-tuning GPT-NeoX
|
||||
|
||||
This is not optimized for very large datasets.
|
||||
"""
|
||||
|
||||
def __init__(self, tokens: List[int], seq_len: int):
|
||||
"""
|
||||
:param tokens: is the list of token ids
|
||||
:param seq_len: is the sequence length of a single training sample
|
||||
"""
|
||||
|
||||
self.seq_len = seq_len
|
||||
# Number of samples
|
||||
n_samples = len(tokens) // seq_len
|
||||
self.n_samples = n_samples
|
||||
# Truncate
|
||||
tokens = tokens[:n_samples * seq_len + 1]
|
||||
# Create a PyTorch tensor
|
||||
self.tokens = torch.tensor(tokens)
|
||||
|
||||
def __len__(self):
|
||||
return self.n_samples
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
"""
|
||||
### Get a sample
|
||||
|
||||
:param idx: is the index of the sample
|
||||
:return: the input and the target
|
||||
"""
|
||||
offset = idx * self.seq_len
|
||||
return self.tokens[offset:offset + self.seq_len], self.tokens[offset + 1:offset + 1 + self.seq_len]


DATASETS = {
    'tiny_shakespeare': {
        'file': 'tiny_shakespeare.txt',
        'url': 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    }
}


def get_training_data(seq_len: int = 32, dataset_name: str = 'tiny_shakespeare', truncate: int = -1):
    """
    ### Load Dataset

    :param seq_len: is the sequence length of a single training sample
    :param dataset_name: is the name of the dataset
    :param truncate: is the number of training samples to keep; use `-1` to keep the full dataset
    :return: the dataset
    """

    ds = DATASETS[dataset_name]
    # Load the content
    text = load_text(lab.get_data_path() / ds['file'], ds['url'])
    # Tokenize
    tokenizer = get_tokenizer()
    tokens = tokenizer.encode_batch([text])[0]

    if truncate > 0:
        token_ids = tokens.ids[:truncate * seq_len]
    else:
        token_ids = tokens.ids

    #
    return NeoXDataset(token_ids, seq_len)


def _test():
    dataset = get_training_data()

    inspect(tokens=len(dataset.tokens))


#
if __name__ == '__main__':
    _test()
182
labml_nn/neox/utils/trainer.py
Normal file
@@ -0,0 +1,182 @@
from typing import Optional, Set, List

import torch.nn as nn
import torch.optim
import torch.utils.data
from torch.cuda import amp
from torch.cuda.amp import GradScaler

from labml import monit, tracker
from labml.configs import BaseConfigs, option
from labml_nn.neox.utils.finetune import FineTuner


def get_trainable_params(model: nn.Module):
    """
    ### Get trainable parameters

    :param model: is the model to train
    :return: a list of parameters for training
    """

    # Get all parameters
    params = list(model.parameters())
    # Filter parameters that require gradients
    trainable_params = [p for p in params if p.requires_grad]

    #
    return trainable_params


class TrainerConf(BaseConfigs):
    model: nn.Module
    layers: List[nn.Module]
    optimizer: torch.optim.Optimizer = 'Adam'
    train_loader: torch.utils.data.DataLoader
    valid_loader: Optional[torch.utils.data.DataLoader] = None
    device: torch.device = torch.device('cuda:0')
    scaler: Optional[GradScaler] = 'Default'
    is_amp: bool = True
    dtype: torch.dtype = torch.float16

    is_clone_layers: bool = True

    loss_func: nn.Module = nn.CrossEntropyLoss()
    checkpoints_per_epoch: int = 0
    samples_per_epoch: int = 0

    grad_norm: Optional[float] = 1.0
    learning_rate: float = 3e-4
    max_seq_len: int = 1024
    batch_size: int = 64
    epochs: int = 16

    n_gpus: int = torch.cuda.device_count()

    filter_layers: Optional[Set] = None

    def get_loss(self, sample, dataset_split: str):
        """
        :param sample: is a `(data, target)` batch
        :param dataset_split: is `'train'` or `'valid'`
        :return: the loss, output and the target
        """
        data, target = sample

        # Forward pass
        with monit.section('Forward pass'):
            output = self.model(data.to(self.device))
        # Move targets to the same device as output
        target = target.to(output.device)
        # Calculate loss
        loss = self.loss_func(output.view(target.numel(), -1), target.view(-1))

        return loss, output, target
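
    # A note on the reshape in `get_loss` above: `output` is `[batch_size, seq_len, n_vocab]`
    # and `target` is `[batch_size, seq_len]`, so flattening them to `[batch_size * seq_len, n_vocab]`
    # and `[batch_size * seq_len]` lets `nn.CrossEntropyLoss` treat every token position as an
    # independent classification.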

    def train(self):
        for epoch in monit.loop(self.epochs):
            self.train_epoch()
            tracker.new_line()

    def sample(self, idx):
        pass

    def save_checkpoint(self, idx):
        pass

    def get_iterators(self):
        # Iterators for training, validation, sampling and checkpointing
        iterators = [('train', self.train_loader)]
        if self.valid_loader is not None:
            iterators.append(('valid', self.valid_loader))

        if self.samples_per_epoch > 0:
            iterators.append((self.sample, [i for i in range(self.samples_per_epoch)]))

        if self.checkpoints_per_epoch > 0:
            iterators.append((self.save_checkpoint, [i for i in range(self.checkpoints_per_epoch)]))

        return iterators

    def train_epoch(self):
        # Set the model to training mode
        self.model.train()

        iterators = self.get_iterators()
        for split_name, sample in monit.mix(1024, *iterators):
            if split_name == 'train':
                # Set gradients to zero
                self.optimizer.zero_grad()
                tracker.add_global_step()

            with torch.set_grad_enabled(split_name == 'train'):
                if self.is_amp:
                    # Forward pass
                    with amp.autocast():
                        loss, output, target = self.get_loss(sample, split_name)
                else:
                    loss, output, target = self.get_loss(sample, split_name)

                # Get predictions
                pred = output.argmax(dim=-1)
                # Calculate accuracy
                accuracy = pred.eq(target).sum().item() / (target != -100).sum()

                tracker.add({f'loss.{split_name}': loss, f'acc.{split_name}': accuracy * 100})

            if split_name == 'train':
                if self.scaler is not None:
                    # Scale the loss before the backward pass
                    loss = self.scaler.scale(loss)
                    # tracker.add({'loss.scaled': loss})

                with monit.section('Backward pass'):
                    loss.backward()

                # Optimize
                with monit.section('Optimize'):
                    if self.scaler is None:
                        self.optimizer.step()
                    else:
                        self.scaler.unscale_(self.optimizer)
                        if self.grad_norm is not None:
                            torch.nn.utils.clip_grad_norm_(get_trainable_params(self.model), self.grad_norm)
                        self.scaler.step(self.optimizer)
                        self.scaler.update()

                tracker.save()
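

# A condensed sketch (an illustration, not used by `TrainerConf`) of the mixed-precision update
# that `train_epoch` performs, with the tracking and data-mixing logic stripped away:
# scale the loss, backpropagate, unscale, clip, then step the optimizer and update the scale.
def _example_amp_step(model: nn.Module, optimizer: torch.optim.Optimizer,
                      scaler: GradScaler, loss: torch.Tensor, grad_norm: float = 1.0):
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(get_trainable_params(model), grad_norm)
    scaler.step(optimizer)
    scaler.update()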


@option(TrainerConf.optimizer, 'Adam')
def adam_optimizer(c: TrainerConf):
    if c.dtype == torch.float32:
        return torch.optim.Adam(get_trainable_params(c.model), lr=c.learning_rate)
    elif c.dtype == torch.float16:
        from labml_nn.optimizers.adam_fp16 import AdamFP16
        return AdamFP16(get_trainable_params(c.model), lr=c.learning_rate)
    else:
        raise NotImplementedError()


@option(TrainerConf.optimizer, 'SGD')
def sgd_optimizer(c: TrainerConf):
    return torch.optim.SGD(get_trainable_params(c.model), lr=c.learning_rate)


@option(TrainerConf.scaler, 'Default')
def grad_scaler(c: TrainerConf):
    if not c.is_amp:
        return None

    if c.dtype == torch.float16:
        from labml_nn.optimizers.adam_fp16 import GradScalerFP16
        return GradScalerFP16()
    else:
        return GradScaler()


class PipelineParallelTrainerConf(TrainerConf):
    is_checkpointing: bool = False
    chunks: int

    fine_tuner: FineTuner
136
labml_nn/optimizers/adam_fp16.py
Normal file
@@ -0,0 +1,136 @@
"""
---
title: Adam Optimizer for Half Precision Training
summary: A PyTorch implementation/tutorial of the Adam optimizer for half precision training.
---

# Adam Optimizer for Half Precision Training
"""

from typing import Dict, Tuple, Optional, Any

import torch
from torch import nn
from torch.optim import Optimizer
from torch.cuda.amp import grad_scaler
from collections import defaultdict, abc

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.adam import Adam


class AdamFP16(Adam):
    """
    ## Adam Optimizer for Half Precision Training

    We extend the [Adam Optimizer](adam.html) but use FP32 to store the gradients, the moments
    and a master copy of the parameters.
    """

    def __init__(self, params, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: WeightDecay = WeightDecay(), optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
        # Dictionary to store the 32-bit gradients. This gets populated by the `GradScaler` defined below.
        self.grad_fp32 = {}
        # Call the [Adam Optimizer](adam.html) initializer
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)

    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
        """
        ### Initialize a parameter state

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$

        All the state tensors use FP32.
        """

        # This is the number of optimizer steps taken on the parameter, $t$
        state['step'] = 0
        # Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)
        # Exponential moving average of squared gradient values, $v_t$
        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)
        # Maintain a FP32 copy of the parameters
        state['fp32_copy'] = param.to(torch.float)

    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
        """
        ### Take an update step for a given parameter tensor

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # Get the FP32 parameters
        param_fp32 = state['fp32_copy']
        # Get the FP32 gradients if available
        grad_fp32 = self.grad_fp32.get(param, None)
        if grad_fp32 is not None:
            del self.grad_fp32[param]
            grad = grad_fp32
        else:
            # Otherwise, convert the gradients to FP32
            grad = grad.to(torch.float)

        # Calculate weight decay
        grad = self.weight_decay(param_fp32, grad, group)

        # Get $m_t$ and $v_t$
        m, v = self.get_mv(state, group, grad)

        # Increment $t$, the number of optimizer steps
        state['step'] += 1

        # Perform the *Adam* update
        self.adam_update(state, group, param_fp32, m, v)

        # Set the parameters
        param.data = param_fp32.to(param.dtype)
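

# A small illustration (not used by the optimizer above) of why the FP32 master copy matters:
# tiny updates round away to nothing when accumulated directly in FP16, but survive in FP32.
def _example_fp32_master_copy():
    w16 = torch.tensor([1.0], dtype=torch.float16)
    w32 = w16.to(torch.float)
    for _ in range(20):
        w16 = w16 + 1e-4  # each increment is smaller than the FP16 spacing at 1.0 and is lost
        w32 = w32 + 1e-4  # the same increments accumulate correctly in FP32
    print(w16.item())                    # still 1.0
    print(w32.to(torch.float16).item())  # ~1.002 after casting back to FP16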


class GradScalerFP16(grad_scaler.GradScaler):
    """
    ## Gradient Scaler with half precision gradients

    We extend the PyTorch gradient scaler to hand FP32 gradients to the optimizer.
    """

    def _unscale_grads_(self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor,
                        allow_fp16: bool) -> Dict[torch.device, torch.Tensor]:
        per_device_inv_scale = grad_scaler._MultiDeviceReplicator(inv_scale)
        per_device_found_inf = grad_scaler._MultiDeviceReplicator(found_inf)

        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]

        with torch.no_grad():
            # Loop through the parameters
            for group in optimizer.param_groups:
                for param in group["params"]:
                    # Skip parameters without gradients
                    if param.grad is None:
                        continue
                    # Not implemented for sparse tensors
                    if param.grad.is_sparse:
                        raise NotImplementedError

                    # If we are using the `AdamFP16` optimizer, set `optimizer.grad_fp32[param]` to the FP32 gradients
                    if isinstance(optimizer, AdamFP16):
                        grad = param.grad.to(torch.float)
                        optimizer.grad_fp32[param] = grad
                    # Otherwise, do not convert the gradients to FP32
                    else:
                        grad = param.grad

                    per_device_and_dtype_grads[grad.device][grad.dtype].append(grad)

            # Unscale all the gradients
            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    torch._amp_foreach_non_finite_check_and_unscale_(grads,
                                                                     per_device_found_inf.get(device),
                                                                     per_device_inv_scale.get(device))
        #
        return per_device_found_inf._per_device_tensors
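

# A minimal usage sketch (assuming an FP16 `model` and a computed `loss`): `GradScalerFP16`
# stashes unscaled FP32 gradients in `optimizer.grad_fp32`, and `AdamFP16.step_param` consumes
# them, updating the FP32 master copy before writing the FP16 parameters back.
def _example_half_precision_step(model: nn.Module, loss: torch.Tensor):
    optimizer = AdamFP16(model.parameters(), lr=3e-4)
    scaler = GradScalerFP16()
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # unscales to FP32 and runs the Adam update
    scaler.update()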
11
labml_nn/scaling/__init__.py
Normal file
@@ -0,0 +1,11 @@
"""
---
title: Large scale model training
summary: >
    Large scale model training/inference implementations.
---

# Large scale model training

* [Zero-DP optimizer](zero3/index.html)
"""
495
labml_nn/scaling/zero3/__init__.py
Normal file
@@ -0,0 +1,495 @@
"""
---
title: Zero-DP Memory Optimization
summary: >
    This is an implementation of Zero-DP Memory Optimization written in PyTorch.
---

# Zero-DP Memory Optimization

This is an implementation of Zero-DP, introduced in the paper
[ZeRO: Memory Optimization Towards Training A Trillion Parameter Models](https://papers.labml.ai/paper/1910.02054).

It keeps shards of the optimizer state, gradients and parameters on multiple devices/nodes.
It reduces the per-device memory consumption to $\frac{(2 + 2 + K)\Psi}{N_d}$ bytes,
where $\Psi$ is the number of parameters, $N_d$ is the number of shards,
and $K$ is the number of optimizer bytes per parameter.
$2 + 2$ are the parameter and gradient memory assuming 16-bit precision; i.e. 2 bytes per parameter and gradient.
$K = 12$ for the Adam optimizer, because it maintains a copy of the parameters and two moments per parameter in FP32.
A small worked example of this formula follows right after this docstring.

The communication volume of Zero-DP is $\mathcal{O}(3\Psi)$. For comparison, data-parallel training
has a communication volume of $\mathcal{O}(2\Psi)$.

Although this is named `Zero3`, we have only implemented the Zero-DP part of it and not the
Zero-R memory optimizations which target residual memory consumption.
Our implementation supports training only a subset of the parameters.

This implementation is inspired by [Fairscale FSDP](https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html).

[Here's a script to fine-tune](finetune_neox.html) GPT-NeoX using Zero-DP memory optimization.
"""

import functools
from typing import List, Optional, Tuple

import torch
import torch.distributed as dist
from torch import nn


class Zero3Layer(nn.Module):
    """
    ## Zero3 Layer

    Each layer of the model (or a combination of a few consecutive layers) should be wrapped in
    this module.
    """
    # Each shard keeps its parameters in the `chunk` list.
    # `chunk[0]` is for trainable parameters and `chunk[1]` is for fixed parameters.
    chunk: List[nn.Parameter]
    # These are the sizes of the chunks in the `chunk` list.
    chunk_size: List[int]
    # The first chunk is for trainable parameters.
    TRAINING_PARAMS_IDX = 0

    # This is the list of parameters split into lists as trainable and fixed parameters.
    param_refs: List[List[nn.Parameter]]

    # CUDA stream to fetch parameters
    fetch_stream: Optional[torch.cuda.Stream]
    # CUDA stream to back up/accumulate gradients
    backup_stream: Optional[torch.cuda.Stream]
    # List of layers right before this layer
    prev_layer: List['Zero3Layer']
    # List of layers right after this layer
    next_layer: List['Zero3Layer']
    # The position of the current layer; used for debugging logs
    layer_idx: int

    # Whether parameters have been fetched
    is_fetched: bool

    # Device of the layer
    device: torch.device
    # Data type of the layer
    dtype: torch.dtype
    # The module to be wrapped
    module: nn.Module
    # Number of nodes/devices the data is sharded across
    world_size: int

    def __init__(self, module: nn.Module, rank: int, world_size: int, device: torch.device, dtype: torch.dtype):
        """
        :param module: The module to be wrapped.
        :param rank: The rank of the current node.
        :param world_size: The number of nodes/devices the data is sharded across.
        :param device: The device of the layer.
        :param dtype: The data type of the layer.
        """
        super().__init__()

        # Initialize the properties
        self.device = device
        self.dtype = dtype
        self.module = module
        self.prev_layer = []
        self.next_layer = []
        self.is_fetched = False
        self.world_size = world_size
        self.layer_idx = -1
        self.fetch_stream = None
        self.backup_stream = None

        with torch.no_grad():
            # Collect all the parameters of the layer
            all_param_refs = [p for p in self.parameters()]

            # Store the shape of the parameters because we need it later to reconstruct them
            for p in all_param_refs:
                p._orig_shape = p.shape

            # All parameters should have the same type
            for p in all_param_refs:
                assert p.dtype == dtype, "All parameters should have same dtype"

            # Separate parameters as trainable and fixed
            self.param_refs = [[p for p in all_param_refs if p.requires_grad],
                               [p for p in all_param_refs if not p.requires_grad]]
            del all_param_refs

            # The `rank = 0` node will calculate the size each device/node should store, and
            # distribute the parameters accordingly.
            if rank == 0:
                # Merge and pad trainable (`merged_params[0]`) and fixed (`merged_params[1]`) parameters
                merged_params = [self._merge_and_pad_params(ps) for ps in self.param_refs]
                # Calculate the chunk sizes of trainable and fixed params
                self.chunk_size = [(len(p) // world_size if p is not None else 0) for p in merged_params]
                # Broadcast the sizes
                dist.broadcast(torch.tensor(self.chunk_size, device=device), src=0)
            else:
                # Create an empty tensor to receive the sizes
                chunk_size = torch.tensor([0, 0], device=device)
                # Receive the sizes
                dist.broadcast(chunk_size, src=0)
                self.chunk_size = chunk_size.tolist()

            # Create parameters for trainable (`self.chunk[0]`) and fixed (`self.chunk[1]`)
            # parameters to be stored in the current device/node
            self.chunk = [nn.Parameter(self._empty((s,)), requires_grad=i == self.TRAINING_PARAMS_IDX)
                          for i, s in enumerate(self.chunk_size)]

            # An empty tensor to receive the trainable and fixed parameters combined
            chunk = self._empty((sum(self.chunk_size),))

            if rank == 0:
                # Concatenate both trainable and fixed params
                all_params = torch.cat([p.view(world_size, -1) for p in merged_params], dim=-1).view(-1)
                del merged_params

                # Scatter them to all the nodes/devices
                dist.scatter(chunk, list(all_params.split(sum(self.chunk_size))))
                del all_params
            else:
                # Receive the parameters
                dist.scatter(chunk)

            # Collect the chunk data
            chunk = chunk.split(self.chunk_size)
            for i, c in enumerate(chunk):
                self.chunk[i].data[:] = c
            del chunk

            # Clean up the normal parameters
            self._cleanup_params()

        # Add a backward hook. This gets called when the gradients relative to the module are computed.
        self._backward_hook_ref = self.register_full_backward_hook(self._backward_hook)  # type: ignore

    def _merge_and_pad_params(self, params: List[nn.Parameter]) -> torch.Tensor:
        """
        #### Merge all the parameters and pad them so that the merged tensor is divisible by `world_size`.
        """
        # Total number of parameters
        size = sum(p.shape.numel() for p in params)

        # If it is not divisible by `world_size`, pad it
        if size % self.world_size != 0:
            padding_fixed = self.world_size - (size % self.world_size)
        # Otherwise, no need to pad
        else:
            padding_fixed = 0
        # Create an empty padding tensor
        padding = self._empty((padding_fixed,))
        # Concatenate all the parameters and the padding
        return torch.cat([p.view(-1) for p in params] + [padding], dim=0)
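
    # For example (an illustration, not original code): with `world_size = 4` and two parameters
    # of sizes 10 and 7, the merged length is 17, so `_merge_and_pad_params` appends 3
    # (uninitialized) padding elements to reach 20 values, which split evenly into 4 shards of 5.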

    def get_trainable_chunk(self) -> List[nn.Parameter]:
        """
        ### Get the trainable chunk/shard of the parameters.

        This is what we pass on to the optimizer on the current node.
        """
        # Return an empty list if there are no trainable parameters
        if len(self.chunk[self.TRAINING_PARAMS_IDX]) == 0:
            return []

        # Return the trainable chunk as a list
        return [self.chunk[self.TRAINING_PARAMS_IDX]]

    def _empty(self, shape: Tuple[int, ...]) -> torch.Tensor:
        """
        #### Create an empty tensor of the given shape.
        """
        return torch.empty(shape, device=self.device, dtype=self.dtype)

    @torch.no_grad()
    def _cleanup_params(self):
        """
        #### Clean up the parameter data

        This will release all the memory used by the layer parameters.
        """

        # Set the flag to indicate that the parameters are not fetched
        self.is_fetched = False

        # Iterate through all parameters
        for ps in self.param_refs:
            for p in ps:
                # Wait for operations on the parameters to complete before any new operations
                p.data.record_stream(torch.cuda.current_stream())
                # Check to make sure the parameter is not sharing storage with anything else
                assert p.data.storage_offset() == 0, "The tensor is not the sole occupant of the storage."
                # Resize the storage to $0$. This will release the memory used by the parameter.
                #
                # **Setting `p.data` alone will not release the memory, since the autograd graph keeps a reference to it.**
                p.data.storage().resize_(0)  # This is what actually clears the memory
                # Make sure the parameter has no gradient data
                assert p.grad is None, 'Gradients should be None'

    @torch.no_grad()
    def fetch_params(self):
        """
        ### Fetch the parameters from all shards

        This will fetch all the parameter data from all the nodes and rebuild the parameters on each node.
        """

        # Skip if already fetched
        if self.is_fetched:
            return

        # Set the flag
        self.is_fetched = True

        # Skip if there's nothing to fetch or share.
        if sum(self.chunk_size) == 0:
            return

        # Use `fetch_stream` to fetch the parameters from all the shards
        with torch.cuda.stream(self.fetch_stream):
            # Create an empty tensor to receive the parameters
            buffer = self._empty((self.world_size * sum(self.chunk_size),))
            # Split the continuous buffer into the number of nodes. These splits are views of `buffer`.
            buffers = list(buffer.split(sum(self.chunk_size)))

            # Concatenate both trainable and fixed chunks
            chunk = torch.cat(self.chunk, dim=0)

            # Gather the parameters from all the nodes/devices
            dist.all_gather(buffers, chunk)

            # Split the gathered parameters into the trainable and fixed chunks
            params = buffer.view(-1, sum(self.chunk_size)).split(self.chunk_size, dim=1)
            # Wait for the gather operation to complete and then clear the references to the buffers
            for b in buffers:
                b.record_stream(self.fetch_stream)
            buffer.record_stream(self.fetch_stream)
            del buffer
            del buffers

            # Reshape the trainable and fixed parameters to continuous tensors
            params = [p.reshape(-1) for p in params]

            # Collect the individual parameter tensors
            for cont, ps in zip(params, self.param_refs):
                # If there are no parameters, skip
                if not ps:
                    continue

                # Offset of the continuous tensor
                offset = 0
                # Iterate through model parameters and assign the values from the continuous tensor
                for p in ps:
                    # Original parameter shape
                    shape = p._orig_shape  # type: ignore[attr-defined]
                    # Change the storage size of the parameter. This was set to $0$ when we cleaned up the parameters.
                    p.data.storage().resize_(shape.numel())
                    # Assign the values from the continuous tensor
                    p.data[:] = cont[offset: offset + shape.numel()].reshape(shape)
                    # Wait for the operations to complete before other operations can be performed
                    p.data.record_stream(self.fetch_stream)
                    # Update the offset
                    offset += shape.numel()

                # Wait for the operation to complete before other operations can be performed
                cont.record_stream(self.fetch_stream)

            #
            del params

    def forward(self, *args, **kwargs):
        """
        ### Forward pass
        """

        # Fetch all the parameters of the current layer.
        # This gets called by the previous layer as well, so this call is just to make sure parameters are fetched.
        self.fetch_params()

        # Wait for parameter fetching to complete.
        torch.cuda.current_stream().wait_stream(self.fetch_stream)

        # Start fetching the parameters of the next layers, so that they get fetched while the current layer
        # does its computations.
        for layer in self.next_layer:
            layer.fetch_params()

        # Add backward hooks to the parameters of the current layer if autograd is enabled.
        if torch.is_grad_enabled():
            self._add_backward_hooks()

        # Compute the outputs of the current layer
        res = self.module(*args, **kwargs)

        # Clean up the parameters of the layer.
        #
        # *Skip cleaning up if autograd is enabled and this is the last layer in the network,
        # because we will need to fetch the parameters again for the backward pass.*
        if not torch.is_grad_enabled() or self.next_layer:
            self._cleanup_params()

        return res

    def _add_backward_hooks(self):
        """
        #### Add backward hooks to the parameters of the current layer.
        """

        # Number of backward hooks added
        self._backward_hook_handles = 0

        # Loop through trainable parameters of the current layer
        for p in self.param_refs[self.TRAINING_PARAMS_IDX]:
            # Make sure a hook hasn't already been added
            assert not hasattr(p, "_hook_handle"), 'Parameter has already been hooked'
            # Use `expand_as` to create an autograd step which we can intercept
            p_tmp = p.expand_as(p)
            # Get a handle to add the backward hook.
            # [This blog discusses `grad_acc`](https://amsword.medium.com/understanding-pytorchs-autograd-with-grad-fn-and-next-functions-b2c4836daa00).
            grad_acc = p_tmp.grad_fn.next_functions[0][0]
            # Add the backward hook
            handle = grad_acc.register_hook(
                functools.partial(self._post_backward_hook, p))
            # Keep a reference to the handle
            p._hook_handle = handle
            # Increment the number of hooks added
            self._backward_hook_handles += 1
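
    # A note on the `expand_as` trick above (not original code): `p` itself is a leaf tensor and
    # has no `grad_fn` to hook. Expanding it to its own shape creates a view whose
    # `grad_fn.next_functions[0][0]` is the `AccumulateGrad` node of `p`, and a hook registered
    # there fires once the gradient for `p` has been accumulated during the backward pass.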

    def _backward_event(self):
        """
        #### Handle a backward event

        This gets called by the parameter backward hooks and the module backward hook.
        """

        # Decrement the hooks counter
        self._backward_hook_handles -= 1

        # If all the hooks (including the module hook) have been called,
        # then we can back up the gradients and clean up the parameters.
        if self._backward_hook_handles == -1:
            self._backup_grads()
            self._cleanup_params()

        # Start fetching the parameters of the previous layer, because autograd will process its gradients next.
        for layer in self.prev_layer:
            layer.fetch_params()

    def _post_backward_hook(self, p: nn.Parameter, *args):
        """
        #### Parameter backward hook
        """
        # Remove the handle from the parameter
        p._hook_handle.remove()  # type: ignore[attr-defined]
        delattr(p, "_hook_handle")

        # Handle a backward event
        self._backward_event()

    def _backward_hook(self, *args, **kwargs):
        """
        #### Module backward hook
        """
        # Handle a backward event
        self._backward_event()

        # The previous layer will start computing gradients. We need to make sure it has finished fetching params.
        torch.cuda.current_stream().wait_stream(self.fetch_stream)

        #
        return None

    @torch.no_grad()
    def _backup_grads(self):
        """
        ### Back up the gradients of the current layer
        """
        # Skip if there are no trainable parameters
        if self.chunk_size[self.TRAINING_PARAMS_IDX] == 0:
            return

        # Use the backup stream to back up the gradients
        with torch.cuda.stream(self.backup_stream):
            # Buffer to store the gradients
            buffer = self._empty((self.world_size * self.chunk_size[self.TRAINING_PARAMS_IDX],))
            # Split the continuous buffer into the number of nodes. These splits are views of `buffer`.
            buffers = list(buffer.split(self.chunk_size[self.TRAINING_PARAMS_IDX]))

            # Offset of the continuous buffer
            offset = 0
            # Iterate through trainable parameters
            for p in self.param_refs[self.TRAINING_PARAMS_IDX]:
                # Collect gradients
                shape = p._orig_shape  # type: ignore[attr-defined]
                buffer[offset: offset + shape.numel()] = p.grad.view(-1)
                # Update the offset
                offset += shape.numel()
                # Clean the gradients
                p.grad = None

            # Empty tensor to accumulate the gradients of the current shard
            grad = self._empty((self.chunk_size[self.TRAINING_PARAMS_IDX],))
            # Accumulate the gradients of each shard. It scatters the buffers across the nodes,
            # and each node accumulates (reduces) the tensors it receives.
            dist.reduce_scatter(grad, buffers)

            # Wait for the operation to complete and then clear the references to the buffers
            for b in buffers:
                b.record_stream(self.fetch_stream)
            buffer.record_stream(self.fetch_stream)
            del buffer
            del buffers

            # Set the chunk gradients. This is what the optimizer sees.
            self.chunk[self.TRAINING_PARAMS_IDX].grad = grad
            del grad
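
    # A note on `dist.reduce_scatter` above (an illustration, not original code): with
    # `world_size = 2`, every node fills `buffers = [b0, b1]`, one slice per shard. After
    # `reduce_scatter(grad, buffers)`, node 0 holds `b0(node 0) + b0(node 1)` and node 1 holds
    # `b1(node 0) + b1(node 1)`, so each node ends up with the summed gradients for exactly
    # the shard of parameters it owns.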


class Zero3Sequential(nn.Module):
    """
    ## Sequential module for `Zero3Layer` layers
    """
    def __init__(self, modules: List[Zero3Layer]):
        """
        :param modules: List of `Zero3Layer` layers
        """
        super().__init__()

        # CUDA stream to fetch parameters
        self.fetch_stream = torch.cuda.Stream()
        # CUDA stream to back up (accumulate) gradients
        self.backup_stream = torch.cuda.Stream()

        # Set the streams and the preceding and succeeding layers for each `Zero3Layer` layer
        for i in range(len(modules)):
            # Set layer index
            modules[i].layer_idx = i
            # Set streams
            modules[i].fetch_stream = self.fetch_stream
            modules[i].backup_stream = self.backup_stream
            # Set succeeding layers
            if i + 1 < len(modules):
                modules[i].next_layer.append(modules[i + 1])
            # Set preceding layers
            if i - 1 >= 0:
                modules[i].prev_layer.append(modules[i - 1])

        # Store the list of modules
        self.module_list = nn.ModuleList(modules)

    def get_trainable_chunk(self):
        # Return the list of trainable chunks from each layer
        return sum([m.get_trainable_chunk() for m in self.module_list], [])

    def forward(self, x: torch.Tensor):
        # Make sure the gradient back-up is complete
        torch.cuda.current_stream().wait_stream(self.backup_stream)

        # Forward pass
        for m in self.module_list:
            x = m(x)

        #
        return x
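

# A minimal usage sketch (assumptions: the distributed process group is already initialized and
# `layers` is an ordered list of `nn.Module`s making up the model):
#
#     wrapped = [Zero3Layer(m.to(device), rank, world_size, device, torch.float16)
#                for m in layers]
#     model = Zero3Sequential(wrapped)
#     optimizer = torch.optim.Adam(model.get_trainable_chunk(), lr=3e-4)
#
# Each node's optimizer only ever sees its own shard of the trainable parameters.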
127
labml_nn/scaling/zero3/finetune_neox.py
Normal file
@@ -0,0 +1,127 @@
"""
---
title: Finetune GPT-NeoX with Zero3 memory optimizer
summary: >
    This script trains the bias parameters of GPT-NeoX on multiple devices with Zero-DP Memory Optimization.
---

# Finetune [GPT-NeoX](../../neox/index.html) with [Zero3 memory optimizer](index.html)

This script trains the bias parameters of the [GPT-NeoX model](../../neox/model.html)
on multiple devices with Zero-DP Memory Optimization.
"""

import datetime

import torch
import torch.distributed

from labml import experiment, monit, tracker
from labml.configs import option
from labml.logger import inspect
from labml_nn.neox.samples.finetune import PipelineParallelTrainerConf


# Use the [Pipeline Parallel Trainer configurations](../../neox/samples/finetune.html) and adapt them for
# the Zero3 memory optimizer.
class Configs(PipelineParallelTrainerConf):
    rank: int
    world_size: int


@option(Configs.optimizer, 'Zero3Adam')
def _optimizer(c: Configs):
    """
    #### Set the optimizer for the model

    Note that we pass the sharded parameters from `get_trainable_chunk`.
    """
    from labml_nn.optimizers.adam_fp16 import AdamFP16
    return AdamFP16(c.model.get_trainable_chunk(), lr=c.learning_rate)


@option(Configs.model, 'Zero3')
def _model(c: Configs):
    """
    #### Create the model with the Zero3 memory optimizer
    """
    from labml_nn.scaling.zero3 import Zero3Layer, Zero3Sequential

    # To make sure the fine tuner sets the trainable parameters
    _ = c.fine_tuner

    # Wrap the layers with `Zero3Layer`
    modules = []
    for m in monit.iterate('Zero3', c.layers):
        modules.append(Zero3Layer(m.to(c.device),
                                  c.rank, c.world_size, c.device, c.dtype))

    # Create a sequential model
    model = Zero3Sequential(modules)

    #
    return model


def main(rank: int, world_size: int, init_method: str = 'tcp://localhost:23456'):
    """
    #### Run the training on the node with rank `rank`.
    """
    # Initialize the PyTorch distributed process group
    with monit.section('Distributed'):
        torch.distributed.init_process_group('nccl',
                                             timeout=datetime.timedelta(seconds=30),
                                             init_method=init_method,
                                             rank=rank,
                                             world_size=world_size)

    # Set the current device
    device = torch.device(f'cuda:{rank}')
    torch.cuda.set_device(device)

    # Create the experiment
    experiment.create(name='zero3_neox', writers={'screen', 'labml'})
    experiment.distributed(rank, world_size)

    # Create configurations
    conf = Configs()

    # Load configurations
    experiment.configs(conf, {
        'model': 'Zero3',
        'optimizer': 'Zero3Adam',

        'device': device,
        'rank': rank,
        'world_size': world_size,

        'learning_rate': 3e-4,
        'max_seq_len': 128,
        'batch_size': 16,
    })

    # Start the experiment
    with experiment.start():
        # Initialize the model. Do this before the loop for cleaner logs.
        _ = conf.model

        # Train the model
        for epoch in monit.loop(conf.epochs):
            conf.train_epoch()
            tracker.new_line()


#
if __name__ == '__main__':
    # Log the machine configuration
    inspect([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])
    inspect(
        n_gpus=torch.cuda.device_count(),
        mpi=torch.distributed.is_mpi_available(),
        nccl=torch.distributed.is_nccl_available(),
    )

    n_gpu = torch.cuda.device_count()

    # Start a process for each GPU. You will need a separate launcher if you are using multiple computers.
    torch.multiprocessing.spawn(main, args=(n_gpu,), nprocs=n_gpu, join=True)