Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-10-30 02:08:50 +08:00)
alibi docs
@ -1,71 +0,0 @@
"""
---
title: Attention with Linear Biases (ALiBi)
summary: >
  Documented implementation with explanations of Attention with Linear Biases (ALiBi)
---

# Attention with Linear Biases (ALiBi)

This is an implementation of Attention with Linear Biases (ALiBi).
"""
import math

import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention


def get_slopes(n_heads: int):
    """
    ## Get head-specific slope $m$ for each head
    """

    # This formula assumes the number of heads is a power of two
    assert math.log2(n_heads).is_integer()

    # The first slope and the common ratio, $2^{-\frac{8}{n}}$
    s = (2 ** (-2 ** -(math.log2(n_heads) - 3)))
    r = s
    return [s * (r ** i) for i in range(n_heads)]


class AlibiMultiHeadAttention(MultiHeadAttention):
    """
    ## Attention with Linear Biases (ALiBi)

    We override the [Multi-Head Attention](mha.html) module so we only need to
    write the `get_scores` method.
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        # The linear transformations do not need a bias since we
        # explicitly include it when calculating scores.
        # However, having a bias for `value` might make sense.
        super().__init__(heads, d_model, dropout_prob)

        # Head-specific slopes, kept as a non-trainable parameter
        self.slopes = nn.Parameter(torch.tensor(get_slopes(heads)), requires_grad=False)

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        r"""
        ### Calculate attention scores and add attention biases
        """

        # `scores` has shape `[query_seq_len, key_seq_len, batch_size, head]`
        scores = super().get_scores(query, key)

        # Distances along the key dimension, multiplied by the head-specific slopes
        distance = torch.arange(scores.shape[1]).to(scores.device, scores.dtype)
        bias = distance[None, :, None, None] * self.slopes[None, None, None, :]
        # Add the biases to the scores
        scores = scores + bias

        return scores


def _test_slopes():
    inspect(get_slopes(8))
    inspect(get_slopes(16))


if __name__ == '__main__':
    _test_slopes()
129
labml_nn/transformers/alibi/__init__.py
Normal file
@ -0,0 +1,129 @@
"""
---
title: Attention with Linear Biases (ALiBi)
summary: >
  Documented implementation with explanations of Attention with Linear Biases (ALiBi)
---

# Attention with Linear Biases (ALiBi)

This is an implementation of Attention with Linear Biases (ALiBi) from the paper
Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation
[(pdf)](https://ofir.io/train_short_test_long.pdf).

This replaces positional encodings with biases added to the attention scores (the attention logits, before the softmax).
It is a relative positional scheme, tested on autoregressive tasks; the bias is higher for nearby tokens
and lower for far-away tokens.
The biases decrease linearly with distance in log space (they are added to the pre-softmax logits), and each head has a different slope.

Here's the attention formula for the $i$-th token,

\begin{align}
\mathbf{a}_i
&= \text{softmax} \bigg( \mathbf{q}_i \mathbf{K}^\top + m \cdot \big[-(i-1), \dots, -1, 0 \big] \bigg) \\
&= \text{softmax} \bigg( \mathbf{q}_i \mathbf{K}^\top + m \cdot \big[0, 1, \dots, (i - 1) \big] \bigg)
\end{align}

where $\mathbf{q}_i \in \mathbb{R}^d$ is the query of the $i$-th token, $\mathbf{K} \in \mathbb{R}^{i \times d}$ are the keys
up to $i$, and $d$ is the number of features per head.
Note that the above equality holds because $\text{softmax}$ is invariant to translations
(you can add any constant to all elements without changing the result).
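For example, with $i = 4$ the two bias vectors are $m \cdot \big[-3, -2, -1, 0\big]$ and
$m \cdot \big[0, 1, 2, 3\big]$; they differ by the constant $3m$ added to every element,
which the softmax cancels.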

Here is [the training code](experiment.html) for an ALiBi model.

[View Run](https://app.labml.ai/run/e87bec2a074911ec82cdd1759f10c925)
"""
import math

import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention


def get_slopes(n_heads: int):
    """
    ## Get head-specific slope $m$ for each head

    * `n_heads` is the number of heads in the attention layer $n$

    The slope for the first head is

    $$2^{-2^{-(\log_2 n - 3)}} = 2^{-\frac{8}{n}}$$

    The slopes for the rest of the heads form a geometric series with that same value as the common ratio.

    For instance, when the number of heads is $8$ the slopes are
    $$\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$$
    """

    # The first slope $$2^{-2^{-(\log_2 n - 3)}} = 2^{-\frac{8}{n}}$$
    s = (2 ** (-2 ** -(math.log2(n_heads) - 3)))
    # The geometric sequence of slopes
    return [s * (s ** i) for i in range(n_heads)]
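

# (Illustrative sketch, not part of the original commit.) A quick check of the slope
# formula above: for $n = 8$ heads the first slope is $2^{-\frac{8}{8}} = \frac{1}{2}$, and the
# geometric series gives $\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$.
def _check_slopes_example():
    expected = [2. ** -(i + 1) for i in range(8)]
    assert all(abs(a - b) < 1e-6 for a, b in zip(get_slopes(8), expected))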


def get_biases(n_heads: int, max_len: int):
    """
    ## Calculate the attention biases matrix

    * `n_heads` is the number of heads in the attention layer
    * `max_len` is the maximum sequence length

    This returns a matrix of shape `[max_len, n_heads]` with attention biases.
    """

    # Get slopes $m$ for each head
    slopes = torch.tensor(get_slopes(n_heads))
    # Calculate distances $[0, 1, \dots, \text{max\_len} - 1]$
    distance = torch.arange(max_len).to(torch.float)
    # Multiply them pair-wise to get the bias matrix
    return distance[:, None] * slopes[None, :]
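

# (Illustrative sketch, not part of the original commit.) The bias matrix for 8 heads
# and a maximum length of 4 has shape `[4, 8]`; row $j$ holds $j \cdot m$ for every head,
# so row 2 is twice the slope vector.
def _check_biases_example():
    biases = get_biases(8, 4)
    assert biases.shape == (4, 8)
    assert torch.allclose(biases[2], 2 * torch.tensor(get_slopes(8)))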


class AlibiMultiHeadAttention(MultiHeadAttention):
    """
    ## Attention with Linear Biases (ALiBi)

    We override the [Multi-Head Attention](mha.html) module so we only need to
    write the `get_scores` method.
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, max_len: int = 5_000):
        super().__init__(heads, d_model, dropout_prob)

        # Pre-calculate the biases
        self.bias = nn.Parameter(get_biases(heads, max_len), requires_grad=False)

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        r"""
        ### Calculate attention scores and add attention biases
        """

        # Calculate the standard attention scores.
        # They have shape `[query_seq_len, key_seq_len, batch_size, head]`
        scores = super().get_scores(query, key)

        # Number of keys
        key_seq_len = scores.shape[1]
        # Add the biases to the scores.
        #
        # $$\mathbf{q}_i \mathbf{K}^\top + m \cdot \big[0, 1, \dots, (i - 1) \big]$$
        #
        # Note that we add biases for all keys (not just up to $i$). We can do this since
        # those extra entries will get removed by the masking later.
        return scores + self.bias[None, :key_seq_len, None, :]
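

# (Illustrative sketch, not part of the original commit.) A minimal usage example,
# assuming the labml `MultiHeadAttention` conventions: `[seq_len, batch_size, d_model]`
# inputs, keyword-only `query`/`key`/`value`/`mask` arguments, and a mask that
# broadcasts over the batch dimension.
def _usage_example():
    seq_len, batch_size, d_model, n_heads = 16, 2, 64, 8
    mha = AlibiMultiHeadAttention(n_heads, d_model)
    x = torch.randn(seq_len, batch_size, d_model)
    # Causal mask of shape `[query_seq_len, key_seq_len, 1]`
    mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1).bool()
    out = mha(query=x, key=x, value=x, mask=mask)
    assert out.shape == (seq_len, batch_size, d_model)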


def _test_slopes():
    """
    Simple test function to see the slopes.
    """
    inspect(get_slopes(8))
    inspect(get_slopes(16))


#
if __name__ == '__main__':
    _test_slopes()
@ -1,32 +1,66 @@
"""
---
title: Attention with Linear Biases (ALiBi) Experiment
summary: This experiment trains an Attention with Linear Biases (ALiBi) based model on the Tiny Shakespeare dataset.
---

# [Attention with Linear Biases (ALiBi)](index.html) Experiment

This is an annotated PyTorch experiment to train an [ALiBi model](index.html).

This is based on
[our GPT model](../gpt/index.html).

[View Run](https://app.labml.ai/run/e87bec2a074911ec82cdd1759f10c925)
"""

import torch
from torch.utils.data import DataLoader

from labml import experiment, tracker
from labml.configs import option, calculate
from labml_helpers.datasets.text import SequentialUnBatchedDataset
from labml_nn.alibi import AlibiMultiHeadAttention
from labml_nn.transformers.alibi import AlibiMultiHeadAttention
from labml_nn.experiments.nlp_autoregression import transpose_batch
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.gpt import Configs as GPTConfigs


class Configs(GPTConfigs):
    """
    ## Configurations

    We extend [GPT configurations](../gpt/index.html) and change the attention mechanism.
    """

    # ALiBi-based transformer (defined below)
    transformer: TransformerConfigs = 'GPT_ALiBi'
    # Use a longer sequence length for validation
    valid_seq_len: int = 128
    valid_loader = 'shuffled_longer_valid_loader'

    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        """
        Log losses at the initial and final tokens
        """
        # If there are more tokens than the training sequence length (during validation),
        if self.seq_len < output.shape[0]:
            # Log the loss at the training sequence length
            tracker.add(f'loss.{self.seq_len - 1}.', self.loss_func(output[self.seq_len - 1], target[self.seq_len - 1]))
        # Log the loss at the first token
        tracker.add(f'loss.0.', self.loss_func(output[0], target[0]))
        # Log the loss at the final token
        tracker.add(f'loss.{int(output.shape[0]) - 1}.', self.loss_func(output[-1], target[-1]))


# ### Multi-head Attention
def _alibi_mha(c: TransformerConfigs):
    """
    Create an ALiBi attention module
    """
    return AlibiMultiHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)


# Set all attention mechanisms to ALiBi
calculate(TransformerConfigs.encoder_attn, 'alibi_mha', _alibi_mha)
calculate(TransformerConfigs.decoder_attn, 'alibi_mha', _alibi_mha)
calculate(TransformerConfigs.decoder_mem_attn, 'alibi_mha', _alibi_mha)
@ -35,7 +69,7 @@ calculate(TransformerConfigs.decoder_mem_attn, 'alibi_mha', _alibi_mha)
@option(Configs.valid_loader)
def shuffled_longer_valid_loader(c: Configs):
    """
    ### Shuffled validation data loader
    Shuffled validation data loader with `valid_seq_len` sequence length
    """
    return DataLoader(SequentialUnBatchedDataset(text=c.text.valid,
                                                 dataset=c.text,
@ -48,7 +82,7 @@ def shuffled_longer_valid_loader(c: Configs):
@option(Configs.transformer, 'GPT_ALiBi')
def _transformer_configs(c: Configs):
    """
    ### Transformer configurations
    ### ALiBi based Transformer configurations
    """

    # We use our
@ -60,9 +94,11 @@ def _transformer_configs(c: Configs):
    # GPT uses GELU activation for position-wise feedforward
    conf.ffn.activation = 'GELU'

    # ALiBi doesn't use positional embeddings
    conf.src_embed = 'no_pos'
    conf.tgt_embed = 'no_pos'

    # Set all attention mechanisms to ALiBi
    conf.encoder_attn = 'alibi_mha'
    conf.decoder_attn = 'alibi_mha'
    conf.decoder_mem_attn = 'alibi_mha'
@ -105,7 +141,6 @@ def main():
        'transformer.ffn.d_ff': 512,
        'transformer.n_heads': 8,
        'transformer.n_layers': 4,

        'transformer.dropout': 0.1,
    })