alibi docs

Varuna Jayasiri
2021-08-28 14:24:41 +05:30
parent 02992a43ab
commit e06a89e2c2
8 changed files with 1348 additions and 254 deletions


@@ -1,71 +0,0 @@
"""
---
title: Attention with Linear Biases (ALiBi)
summary: >
Documented implementation with explanations of Attention with Linear Biases (ALiBi)
---
# Attention with Linear Biases (ALiBi)
This is an implementation of Attention with Linear Biases (ALiBi).
"""
import math
import torch
from torch import nn
from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention
def get_slopes(n_heads: int):
"""
## Get head-specific slope $m$ for each head
"""
assert math.log2(n_heads).is_integer()
s = (2 ** (-2 ** -(math.log2(n_heads) - 3)))
r = s
return [s * (r ** i) for i in range(n_heads)]
class AlibiMultiHeadAttention(MultiHeadAttention):
"""
## Attention with Linear Biases (ALiBi)
We override the [Multi-Head Attention](mha.html) module so we only need to
write the `get_scores` method.
"""
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
# The linear transformations do not need a bias since we
# explicitly include it when calculating scores.
# However, having a bias for `value` might make sense.
super().__init__(heads, d_model, dropout_prob)
self.slopes = nn.Parameter(torch.tensor(get_slopes(heads)), requires_grad=False)
def get_scores(self, query: torch.Tensor, key: torch.Tensor):
r"""
### Calculate attention scores and add attention biases
"""
# scores has shape `[query_seq_len, key_seq_len, batch_size, head]`
scores = super().get_scores(query, key)
distance = torch.arange(scores.shape[1]).to(scores.device, scores.dtype)
bias = distance[None, :, None, None] * self.slopes[None, None, None, :]
# add to scores
scores = scores + bias
return scores
def _test_slopes():
inspect(get_slopes(8))
inspect(get_slopes(16))
if __name__ == '__main__':
_test_slopes()


@@ -0,0 +1,129 @@
"""
---
title: Attention with Linear Biases (ALiBi)
summary: >
Documented implementation with explanations of Attention with Linear Biases (ALiBi)
---
# Attention with Linear Biases (ALiBi)
This is an implementation of Attention with Linear Biases (ALiBi) from the paper
Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation
[(pdf)](https://ofir.io/train_short_test_long.pdf).
This replaces positional encodings with biases added to the attention scores (attention logits, before the softmax).
It is a relative positional scheme tested on autoregressive tasks; the bias is higher for nearby tokens
and lower for far-away tokens.
The biases decrease linearly on the log scale (because they are added before the softmax), and each head has a different slope.
Here's the attention formula for $i$-th token,
\begin{align}
\mathbf{a}_i
&= \text{softmax} \bigg( \mathbf{q}_i \mathbf{K}^\top + m \cdot \big[-(i-1), \dots, -1, 0 \big] \bigg) \\
&= \text{softmax} \bigg( \mathbf{q}_i \mathbf{K}^\top + m \cdot \big[0, 1, \dots, (i - 1) \big] \bigg)
\end{align}
where $\mathbf{q}_i \in \mathbb{R}^d$ is the query of the $i$-th token, $\mathbf{K} \in \mathbb{R}^{i \times d}$ are the keys
up to $i$, and $d$ is the number of features per head.
Note that the above equality holds because $\text{softmax}$ is invariant to translations
(you can add any constant to all elements without changing the result).
Here is [the training code](experiment.html) for an ALiBi model.
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/e87bec2a074911ec82cdd1759f10c925)
"""
import math
import torch
from torch import nn
from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention
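# A quick numerical check of the softmax translation-invariance used above
# (an illustrative sketch, not part of the library): adding the constant
# $m (i - 1)$ to every attention logit leaves the softmax output unchanged,
# which is why the two forms of the bias in the equations are equivalent.
def _check_softmax_shift_invariance():
    # Some arbitrary attention logits $\mathbf{q}_i \mathbf{K}^\top$ and a slope $m$
    logits = torch.randn(5)
    m = 0.25
    # Biases written as $m \cdot [-(i-1), \dots, -1, 0]$
    negative_biases = logits + m * torch.arange(-4., 1.)
    # Biases written as $m \cdot [0, 1, \dots, (i-1)]$
    shifted_biases = logits + m * torch.arange(0., 5.)
    # The two differ only by the constant $m (i - 1)$, so the attention weights match
    assert torch.allclose(torch.softmax(negative_biases, dim=0),
                          torch.softmax(shifted_biases, dim=0))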
def get_slopes(n_heads: int):
"""
## Get head-specific slope $m$ for each head
* `n_heads` is the number of heads in the attention layer, $n$
The slope for the first head is
$$2^{-2^{-(\log_2 n - 3)}} = 2^{-\frac{8}{n}}$$
The slopes for the rest of the heads form a geometric series whose common ratio is the same as the first slope.
For instance when the number of heads is $8$ the slopes are
$$\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$$
"""
# $$2^{-2^{-(\log_2 n - 3)}}$$
s = (2 ** (-2 ** -(math.log2(n_heads) - 3)))
# The geometric sequence
return [s * (s ** i) for i in range(n_heads)]
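# A quick illustrative check of the example above: with $8$ heads the slopes should be
# the geometric series $\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$.
assert all(abs(m - 2 ** -(i + 1)) < 1e-12 for i, m in enumerate(get_slopes(8)))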
def get_biases(n_heads: int, max_len: int):
"""
## Calculate the attention biases matrix
* `n_heads` is the number of heads in the attention layer
* `max_len` is the maximum sequence length
This returns a matrix of shape `[max_len, n_heads]` with attention biases.
"""
# Get slopes $m$ for each head
slopes = torch.tensor(get_slopes(n_heads))
# Calculate distances $[0, 1, \dots, N]$
distance = torch.arange(max_len).to(torch.float)
# Multiply them pair-wise to get the bias matrix
return distance[:, None] * slopes[None, :]
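# An illustrative check of the result: with $4$ heads and a maximum length of $3$,
# the matrix has shape `[max_len, n_heads]` and row $j$ holds $j \cdot m$ for each head's slope $m$.
assert get_biases(4, 3).shape == (3, 4)
assert torch.allclose(get_biases(4, 3)[1], torch.tensor(get_slopes(4)))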
class AlibiMultiHeadAttention(MultiHeadAttention):
"""
## Attention with Linear Biases (ALiBi)
We override the [Multi-Head Attention](mha.html) module so we only need to
write the `get_scores` method.
"""
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, max_len: int = 5_000):
super().__init__(heads, d_model, dropout_prob)
# Pre-calculate the biases
self.bias = nn.Parameter(get_biases(heads, max_len), requires_grad=False)
def get_scores(self, query: torch.Tensor, key: torch.Tensor):
r"""
### Calculate attention scores and add attention biases
"""
# Calculate the standard attention score.
# It has shape `[query_seq_len, key_seq_len, batch_size, head]`
scores = super().get_scores(query, key)
# Number of keys
key_seq_len = scores.shape[1]
# Add the biases to scores.
#
# $$\mathbf{q}_i \mathbf{K}^\top + m \cdot \big[0, 1, \dots, (i - 1) \big]$$
#
# Note that we add biases for all keys (not just up to $i$). We can do this since
# the extra entries will get masked out later.
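# `self.bias` has shape `[max_len, heads]`; the indexing below reshapes it to
# `[1, key_seq_len, 1, heads]` so it broadcasts over the query and batch dimensions of `scores`.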
return scores + self.bias[None, :key_seq_len, None, :]
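# An illustrative usage sketch of the module.
# It assumes the base [MHA](mha.html) `forward` signature
# (keyword arguments `query`, `key`, `value` and `mask`, with inputs of shape
# `[seq_len, batch_size, d_model]` and a mask of shape `[seq_len_q, seq_len_k, batch_size]`);
# treat it as a sketch rather than a reference usage.
def _test_attention():
    mha = AlibiMultiHeadAttention(heads=8, d_model=64, max_len=512)
    # A batch of $2$ sequences of $10$ tokens each
    x = torch.randn(10, 2, 64)
    # Causal mask, broadcast over the batch dimension
    mask = torch.tril(torch.ones(10, 10, 1)).bool()
    # Self-attention with ALiBi biases added to the attention scores
    out = mha(query=x, key=x, value=x, mask=mask)
    # The output keeps the input shape `[seq_len, batch_size, d_model]`
    assert out.shape == (10, 2, 64)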
def _test_slopes():
"""
Simple test function to see the slopes.
"""
inspect(get_slopes(8))
inspect(get_slopes(16))
#
if __name__ == '__main__':
_test_slopes()


@@ -1,32 +1,66 @@
"""
---
title: Attention with Linear Biases (ALiBi) Experiment
summary: This experiment trains an Attention with Linear Biases (ALiBi) based model on the Tiny Shakespeare dataset.
---
# [Attention with Linear Biases (ALiBi)](index.html) Experiment
This is an annotated PyTorch experiment to train an [ALiBi model](index.html).
This is based on
[our GPT model](../gpt/index.html).
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/e87bec2a074911ec82cdd1759f10c925)
"""
import torch
from torch.utils.data import DataLoader
from labml import experiment, tracker
from labml.configs import option, calculate
from labml_helpers.datasets.text import SequentialUnBatchedDataset
from labml_nn.alibi import AlibiMultiHeadAttention
from labml_nn.transformers.alibi import AlibiMultiHeadAttention
from labml_nn.experiments.nlp_autoregression import transpose_batch
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.gpt import Configs as GPTConfigs
class Configs(GPTConfigs):
"""
## Configurations
We extend [GPT configurations](../gpt/index.html) and change the attention mechanism.
"""
# ALiBi based transformer (defined below)
transformer: TransformerConfigs = 'GPT_ALiBi'
# Longer validation set
valid_seq_len: int = 128
valid_loader = 'shuffled_longer_valid_loader'
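# Validating on sequences longer than the training sequence length tests how well
# ALiBi extrapolates to lengths it wasn't trained on (the "train short, test long" setting).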
def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
"""
Log the losses at the first token, at the training sequence length, and at the final token
"""
# If there are more tokens than the training sequence length (during validation),
if self.seq_len < output.shape[0]:
# Log the loss at training sequence length
tracker.add(f'loss.{self.seq_len - 1}.', self.loss_func(output[self.seq_len - 1], target[self.seq_len - 1]))
# Log the loss at the first token
tracker.add(f'loss.0.', self.loss_func(output[0], target[0]))
# Log the loss at the final token
tracker.add(f'loss.{int(output.shape[0]) - 1}.', self.loss_func(output[-1], target[-1]))
# ### Multi-head Attention
def _alibi_mha(c: TransformerConfigs):
"""
Create an ALiBi attention module
"""
return AlibiMultiHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)
# Set all attention mechanisms to ALiBi
calculate(TransformerConfigs.encoder_attn, 'alibi_mha', _alibi_mha)
calculate(TransformerConfigs.decoder_attn, 'alibi_mha', _alibi_mha)
calculate(TransformerConfigs.decoder_mem_attn, 'alibi_mha', _alibi_mha)
@@ -35,7 +69,7 @@ calculate(TransformerConfigs.decoder_mem_attn, 'alibi_mha', _alibi_mha)
@option(Configs.valid_loader)
def shuffled_longer_valid_loader(c: Configs):
"""
### Shuffled validation data loader
Shuffled validation data loader with `valid_seq_len` sequence length
"""
return DataLoader(SequentialUnBatchedDataset(text=c.text.valid,
dataset=c.text,
@@ -48,7 +82,7 @@ def shuffled_longer_valid_loader(c: Configs):
@option(Configs.transformer, 'GPT_ALiBi')
def _transformer_configs(c: Configs):
"""
### Transformer configurations
### ALiBi based Transformer configurations
"""
# We use our
@@ -60,9 +94,11 @@ def _transformer_configs(c: Configs):
# GPT uses GELU activation for position wise feedforward
conf.ffn.activation = 'GELU'
# ALiBi doesn't use positional embeddings
conf.src_embed = 'no_pos'
conf.tgt_embed = 'no_pos'
# Set all attention mechanisms to ALiBi
conf.encoder_attn = 'alibi_mha'
conf.decoder_attn = 'alibi_mha'
conf.decoder_mem_attn = 'alibi_mha'
@@ -105,7 +141,6 @@ def main():
'transformer.ffn.d_ff': 512,
'transformer.n_heads': 8,
'transformer.n_layers': 4,
'transformer.dropout': 0.1,
})