Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-11-01 12:01:45 +08:00)
labml_nn/alibi/__init__.py (new file, 71 lines)
@@ -0,0 +1,71 @@
"""
---
title: Attention with Linear Biases (ALiBi)
summary: >
  Documented implementation with explanations of Attention with Linear Biases (ALiBi)
---

# Attention with Linear Biases (ALiBi)

This is an implementation of Attention with Linear Biases (ALiBi) from the paper
[Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation](https://arxiv.org/abs/2108.12409).
Instead of adding positional embeddings, ALiBi biases the attention scores with a
penalty proportional to the distance between the key and the query, using a fixed,
head-specific slope.
"""

import math

import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention


def get_slopes(n_heads: int):
    """
    ## Get head-specific slope $m$ for each head
    """
    # This implementation only supports a power-of-two number of heads
    assert math.log2(n_heads).is_integer()

    # The starting value of the geometric sequence, $2^{-\frac{8}{n}}$
    s = (2 ** (-2 ** -(math.log2(n_heads) - 3)))
    # The common ratio is the same as the starting value
    r = s
    # Head $i$ gets slope $s \cdot r^i$
    return [s * (r ** i) for i in range(n_heads)]
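

# As a concrete example of the slopes above: `get_slopes(8)` returns
# 1/2, 1/4, ..., 1/256, and `get_slopes(16)` starts at $2^{-0.5}$ and keeps
# the same ratio, matching the geometric sequences prescribed in the ALiBi paper.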


class AlibiMultiHeadAttention(MultiHeadAttention):
    """
    ## Attention with Linear Biases (ALiBi)

    We override the [Multi-Head Attention](mha.html) module so we only need to
    write the `get_scores` method.
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        # The linear transformations do not need a bias since we
        # explicitly include it when calculating scores.
        # However, having a bias for `value` might make sense.
        super().__init__(heads, d_model, dropout_prob)

        # Head-specific slopes $m$; these are fixed and not trained
        self.slopes = nn.Parameter(torch.tensor(get_slopes(heads)), requires_grad=False)

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        r"""
        ### Calculate attention scores and add attention biases
        """

        # `scores` has shape `[query_seq_len, key_seq_len, batch_size, heads]`
        scores = super().get_scores(query, key)

        # Distance of each key position from the start of the sequence.
        # Since softmax is invariant to adding a constant to every score of a query,
        # biasing by the absolute key position $m \cdot j$ is equivalent to the
        # relative bias $-m (i - j)$ used in the paper.
        distance = torch.arange(scores.shape[1]).to(scores.device, scores.dtype)
        # Multiply by the head-specific slopes to get biases of shape
        # `[1, key_seq_len, 1, heads]`
        bias = distance[None, :, None, None] * self.slopes[None, None, None, :]
        # Add the biases to the scores
        scores = scores + bias

        return scores
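

# As a concrete example of the bias computed in `get_scores`: with two heads of
# hypothetical slopes [1/2, 1/4] and `key_seq_len = 4`, the per-head biases added
# to every query row are
#   head 0: [0.0, 0.5, 1.0, 1.5]
#   head 1: [0.0, 0.25, 0.5, 0.75]
# so keys further along the sequence receive a larger head-specific bias, which in
# causal attention is equivalent to penalizing keys that are further behind the query.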


def _test_slopes():
    inspect(get_slopes(8))
    inspect(get_slopes(16))


if __name__ == '__main__':
    _test_slopes()
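
A quick usage sketch, assuming labml's `MultiHeadAttention` takes sequence-first
tensors of shape `[seq_len, batch_size, d_model]` with keyword arguments `query`,
`key`, `value` and an optional `mask`:

import torch
from labml_nn.alibi import AlibiMultiHeadAttention

# Hypothetical sizes, for illustration only
seq_len, batch_size, d_model, n_heads = 16, 4, 128, 8
mha = AlibiMultiHeadAttention(n_heads, d_model)

# Self-attention over a random sequence; no mask for simplicity
x = torch.randn(seq_len, batch_size, d_model)
out = mha(query=x, key=x, value=x, mask=None)
assert out.shape == (seq_len, batch_size, d_model)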

labml_nn/alibi/experiment.py (new file, 158 lines)
@@ -0,0 +1,158 @@
from pathlib import PurePath, Path
from typing import Callable, Optional

from torch.utils.data import DataLoader

from labml import experiment, monit, lab
from labml.configs import option, calculate
from labml.utils.download import download_file
from labml_helpers.datasets.text import SequentialDataLoader, SequentialUnBatchedDataset, TextDataset
from labml_nn.alibi import AlibiMultiHeadAttention
from labml_nn.experiments.nlp_autoregression import transpose_batch
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.gpt import Configs as GPTConfigs


class Configs(GPTConfigs):
    # Use the ALiBi based transformer (defined below)
    transformer: TransformerConfigs = 'GPT_ALiBi'
    # Use a longer sequence length for validation
    valid_seq_len: int = 128
    # Validation data loader that uses `valid_seq_len`
    valid_loader = 'shuffled_longer_valid_loader'
    # Tiny Shakespeare dataset without a train/validation split
    text: TextDataset = 'tiny_shakespeare_no_split'


# ### ALiBi multi-head attention
def _alibi_mha(c: TransformerConfigs):
    return AlibiMultiHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)


# Register `alibi_mha` as an option for encoder, decoder and decoder-memory attention
calculate(TransformerConfigs.encoder_attn, 'alibi_mha', _alibi_mha)
calculate(TransformerConfigs.decoder_attn, 'alibi_mha', _alibi_mha)
calculate(TransformerConfigs.decoder_mem_attn, 'alibi_mha', _alibi_mha)


@option(Configs.valid_loader)
def shuffled_longer_valid_loader(c: Configs):
    """
    ### Shuffled validation data loader with the longer sequence length
    """
    return DataLoader(SequentialUnBatchedDataset(text=c.text.valid,
                                                 dataset=c.text,
                                                 seq_len=c.valid_seq_len),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      shuffle=True)


@option(Configs.transformer, 'GPT_ALiBi')
def _transformer_configs(c: Configs):
    """
    ### ALiBi transformer configurations
    """

    # We use our
    # [configurable transformer implementation](../configs.html#TransformerConfigs)
    conf = TransformerConfigs()
    # Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
    # GPT uses GELU activation for the position-wise feedforward layer
    conf.ffn.activation = 'GELU'

    # ALiBi does not use positional embeddings
    conf.src_embed = 'no_pos'
    conf.tgt_embed = 'no_pos'

    # Use ALiBi attention in the encoder, decoder, and decoder memory attention
    conf.encoder_attn = 'alibi_mha'
    conf.decoder_attn = 'alibi_mha'
    conf.decoder_mem_attn = 'alibi_mha'

    #
    return conf


class TextFileDataset(TextDataset):
    """
    ### Text file dataset

    Loads a single text file, downloading it from `url` if it is not present,
    and passes the same text as both the training and validation splits.
    """
    standard_tokens = []

    def __init__(self, path: PurePath, tokenizer: Callable, *,
                 url: Optional[str] = None,
                 filter_subset: Optional[int] = None):
        path = Path(path)
        # Download the file if it is not available locally
        if not path.exists():
            if not url:
                raise FileNotFoundError(str(path))
            else:
                download_file(url, path)

        with monit.section("Load data"):
            # Load the text
            text = self.load(path)
            # Optionally keep only the first `filter_subset` characters
            if filter_subset:
                text = text[:filter_subset]

        # Use the same text for training and validation
        super().__init__(path, tokenizer, text, text, '')


@option(Configs.text)
def tiny_shakespeare_no_split(c: Configs):
    """
    ### Tiny Shakespeare dataset

    It will be downloaded from the URL if it is not present.
    """
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')


def main():
    # Create experiment
    experiment.create(name="gpt_alibi")
    # Create configs
    conf = Configs()
    # Override configurations
    experiment.configs(conf, {
        # Use character level tokenizer
        'tokenizer': 'character',
        # Prompt separator is blank
        'prompt_separator': '',
        # Starting prompt for sampling
        'prompt': 'It is ',
        # Use Tiny Shakespeare dataset
        'text': 'tiny_shakespeare_no_split',

        # Use a context size of $64$ for training
        'seq_len': 64,
        # Use a longer context size of $80$ for validation
        'valid_seq_len': 80,
        # Train for $128$ epochs
        'epochs': 128,
        # Batch size $128$
        'batch_size': 128,
        # Switch between training and validation $10$ times per epoch
        'inner_iterations': 10,

        # Transformer configurations
        'transformer.d_model': 128,
        'transformer.ffn.d_ff': 512,
        'transformer.n_heads': 8,
        'transformer.n_layers': 3,

        'transformer.dropout': 0.2,

        # Report the loss at the last token separately
        'is_log_last_token_loss': True,
    })

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Load from a previously saved run
    experiment.load('511bfbc8071b11ecad290d807660f656')

    # Start the experiment
    with experiment.start():
        # Run training
        conf.run()


#
if __name__ == '__main__':
    main()
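
`main()` trains with a context of 64 tokens and validates on sequences of 80 tokens
to test length extrapolation; the `experiment.load(...)` call above refers to a
specific earlier run. A minimal sketch of a fresh run, using only the APIs shown
above and leaving the remaining options at their defaults:

from labml import experiment
from labml_nn.alibi.experiment import Configs

experiment.create(name="gpt_alibi")
conf = Configs()
experiment.configs(conf, {'tokenizer': 'character',
                          'text': 'tiny_shakespeare_no_split',
                          'seq_len': 64,
                          'valid_seq_len': 80})
experiment.add_pytorch_models({'model': conf.model})
with experiment.start():
    conf.run()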

labml_nn/experiments/nlp_autoregression.py
@@ -88,6 +88,9 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
     # Validation data loader
     valid_loader: DataLoader = 'shuffled_valid_loader'
 
+    # Report last token loss
+    is_log_last_token_loss: bool = False
+
     def init(self):
         """
         ### Initialization
@@ -108,6 +111,9 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
         ### Training or validation step
         """
 
+        # Set training/eval mode
+        self.model.train(self.mode.is_train)
+
         # Move data to the device
         data, target = batch[0].to(self.device), batch[1].to(self.device)
 
@@ -126,6 +132,11 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
         loss = self.loss_func(output, target)
         tracker.add("loss.", loss)
 
+        if self.is_log_last_token_loss:
+            if self.seq_len < output.shape[0]:
+                tracker.add('loss.seq_len.', self.loss_func(output[self.seq_len - 1], target[self.seq_len - 1]))
+            tracker.add('loss.last.', self.loss_func(output[-1], target[-1]))
+
         # Calculate and log accuracy
         self.accuracy(output, target)
         self.accuracy.track()
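
The hunk above optionally logs the loss at the last position of the training context
and at the final token, which matters when validating on sequences longer than the
training `seq_len` (as the ALiBi experiment does). A minimal sketch of that indexing,
assuming `output` of shape `[seq_len, batch_size, n_tokens]` and `target` of shape
`[seq_len, batch_size]`:

import torch
import torch.nn.functional as F

# Hypothetical sizes: trained with context 64, validated with context 80
train_seq_len, valid_seq_len, batch_size, n_tokens = 64, 80, 4, 65
output = torch.randn(valid_seq_len, batch_size, n_tokens)
target = torch.randint(0, n_tokens, (valid_seq_len, batch_size))

# Loss at the last position that lies within the training context length
loss_seq_len = F.cross_entropy(output[train_seq_len - 1], target[train_seq_len - 1])
# Loss at the final token of the longer validation sequence
loss_last = F.cross_entropy(output[-1], target[-1])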

labml_nn/transformers/configs.py
@@ -199,7 +199,7 @@ class TransformerConfigs(BaseConfigs):
 
 # ### Multi-head Attention
 def _mha(c: TransformerConfigs):
-    return MultiHeadAttention(c.n_heads, c.d_model)
+    return MultiHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)
 
 
 calculate(TransformerConfigs.encoder_attn, 'mha', _mha)