This is an implementation of Attention with Linear Biases (ALiBi) from the paper "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation".
This replaces positional encodings with biases added to the attention scores (the attention logits, before the softmax). It is a relative scheme, tested on autoregressive tasks: the bias is higher for nearby tokens and lower for far-away ones. The biases decrease linearly with distance. Because they are added before the softmax, their effect on the attention weights is linear in log scale, which is an exponential decay with distance. Each head has a different slope.
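To make "linear in log scale" concrete, here is a small plain-Python sketch (not part of the original implementation). It assumes every content score $\mathbf{q}_i \cdot \mathbf{k}_j$ is zero, so only the ALiBi bias $-m(i-j)$ reaches the softmax; the slope value is illustrative. The log-weights of neighbouring keys then differ by exactly $m$.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def alibi_weights(i, m):
    """Attention weights of query i over keys 0..i, assuming all content
    scores are zero so that only the bias -m * (i - j) forms the logits."""
    logits = [-m * (i - j) for j in range(i + 1)]
    return softmax(logits)


weights = alibi_weights(i=4, m=0.5)
# Differences of log-weights between neighbouring keys are constant (= m).
log_gaps = [math.log(weights[j + 1]) - math.log(weights[j]) for j in range(4)]
print([round(g, 6) for g in log_gaps])  # [0.5, 0.5, 0.5, 0.5]
```

The weights grow towards the query (the most recent key gets the largest weight), and each step back in distance multiplies the weight by $e^{-m}$.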
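Here's the attention formula for the $i$-th token,

$$
\begin{aligned}
\mathbf{a}_i
&= \text{softmax} \bigg( \mathbf{q}_i \mathbf{K}^\top + m \cdot \big[-(i-1), \dots, -1, 0 \big] \bigg) \\
&= \text{softmax} \bigg( \mathbf{q}_i \mathbf{K}^\top + m \cdot \big[0, 1, \dots, (i - 1) \big] \bigg)
\end{aligned}
$$

where $\mathbf{q}_i \in \mathbb{R}^d$ is the query of the $i$-th token, $\mathbf{K} \in \mathbb{R}^{i \times d}$ are the keys up to $i$, $m$ is the head-specific slope, and $d$ is the number of features per head. Note that the above equality holds because the softmax is invariant to translations: adding the same constant (here $m(i-1)$) to all elements does not change the result.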
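The equality between the two bias forms can be checked numerically. This plain-Python sketch uses made-up content scores (hypothetical values, chosen only for illustration) for the $i = 4$-th token and a slope $m = 0.25$, and compares the softmax outputs of both forms.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical content scores q_i . k_j for the i-th token (i = 4 keys).
content = [0.2, -1.0, 0.7, 0.1]
m = 0.25
i = len(content)

relative = [m * -(i - 1 - k) for k in range(i)]  # m * [-(i-1), ..., -1, 0]
shifted = [m * k for k in range(i)]              # m * [0, 1, ..., i-1]

a_relative = softmax([c + b for c, b in zip(content, relative)])
a_shifted = softmax([c + b for c, b in zip(content, shifted)])
print(all(math.isclose(x, y) for x, y in zip(a_relative, a_shifted)))  # True
```

The second form is convenient because the bias then depends only on the key position, not on the query position, which is how get_biases builds its matrix.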
Here is the training code for an ALiBi model.
import math

import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention
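n_heads is the number of heads in the attention layer, $n$. The slope for the first head is

$$\frac{1}{2^{\frac{8}{n}}} = 2^{-\frac{8}{n}}$$

The slopes for the rest of the heads form a geometric series with the same ratio, so the slope of head $h$ (counting from 1) is $2^{-\frac{8h}{n}}$.

For instance, when the number of heads is $8$, the slopes are

$$\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$$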
def get_slopes(n_heads: int):
    # Slope of the first head; this expression equals 2 ** (-8 / n_heads)
    s = (2 ** (-2 ** -(math.log2(n_heads) - 3)))
Return the geometric sequence of slopes $s, s^2, \dots, s^n$
    return [s * (s ** i) for i in range(n_heads)]
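As a quick check of the slope formula, this plain-Python sketch copies the expression from get_slopes and compares it with the closed form $2^{-8h/n}$. The helper slopes_closed_form is hypothetical and only used for the comparison; the check is limited to head counts that are powers of two, matching the examples in the text.

```python
import math


def get_slopes(n_heads):
    # Same expression as get_slopes above: s = 2 ** (-8 / n_heads)
    s = 2 ** (-2 ** -(math.log2(n_heads) - 3))
    return [s * (s ** i) for i in range(n_heads)]


def slopes_closed_form(n_heads):
    # Hypothetical helper: slope of head h (1-based) is 2 ** (-8 * h / n_heads)
    return [2 ** (-8 * h / n_heads) for h in range(1, n_heads + 1)]


for n in (8, 16):
    same = all(math.isclose(a, b) for a, b in zip(get_slopes(n), slopes_closed_form(n)))
    print(n, same)  # prints "8 True" then "16 True"

print(get_slopes(8))
# [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
```

With 16 heads the sequence starts at $2^{-0.5}$ and still ends at $2^{-8}$, so the extra heads interleave between the 8-head slopes.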
n_heads is the number of heads in the attention layer and max_len is the maximum sequence length. This returns a matrix of shape [max_len, n_heads] with the attention biases.
def get_biases(n_heads: int, max_len: int):
Get slopes for each head
    slopes = torch.tensor(get_slopes(n_heads))
Calculate distances
    distance = torch.arange(max_len).to(torch.float)
Multiply them pairwise (an outer product) to get the bias matrix of shape [max_len, n_heads]
    return distance[:, None] * slopes[None, :]
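get_biases builds the matrix with tensor broadcasting. The same computation written with plain Python lists (a sketch, not the torch code; get_biases_list is a hypothetical name) makes the layout explicit: row $j$ is the key position, column $h$ is the head, and entry $[j][h]$ is $j$ times the slope of head $h$.

```python
def get_biases_list(slopes, max_len):
    # Hypothetical list-based version of get_biases: bias[j][h] = j * slopes[h]
    return [[float(j) * m for m in slopes] for j in range(max_len)]


slopes = [0.5, 0.25]  # illustrative values, not produced by get_slopes
bias = get_biases_list(slopes, max_len=4)
print(len(bias), len(bias[0]))  # 4 2
print(bias)  # [[0.0, 0.0], [0.5, 0.25], [1.0, 0.5], [1.5, 0.75]]
```

One row per key position is enough: by the translation invariance of the softmax, the query position only shifts all biases of a row by a constant.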
We override the Multi-Head Attention module, so we only need to write the get_scores method.
class AlibiMultiHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, max_len: int = 5_000):
        super().__init__(heads, d_model, dropout_prob)
Pre-calculate the biases
        self.bias = nn.Parameter(get_biases(heads, max_len), requires_grad=False)
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
Calculate the standard attention scores. They have shape [query_seq_len, key_seq_len, batch_size, heads]
        scores = super().get_scores(query, key)
Number of keys
        key_seq_len = scores.shape[1]
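Add the biases to the scores. Indexing the bias matrix as [None, :key_seq_len, None, :] gives shape [1, key_seq_len, 1, heads], which broadcasts against the scores.
Note that we add biases for all keys (not just up to $i$). We can do this because those extra entries are removed by the masking that is applied later.
        return scores + self.bias[None, :key_seq_len, None, :]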
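The claim that biases on keys beyond $i$ are harmless can be checked with a single-head sketch in plain Python. It adds the key-position bias $m \cdot j$ to every key, then applies a causal mask by setting future logits to $-\infty$ (mirroring the masking step that runs after get_scores; the base class code itself is not reproduced here), and compares the result with a softmax over only the allowed keys. The scores and slope are hypothetical.

```python
import math


def masked_softmax(logits, allowed):
    """Softmax in which disallowed positions get -inf, i.e. weight 0."""
    masked = [x if ok else float("-inf") for x, ok in zip(logits, allowed)]
    top = max(masked)  # finite as long as at least one position is allowed
    exps = [math.exp(x - top) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]


m = 0.5  # illustrative slope for a single head
# Hypothetical content scores q_i . k_j: one row per query i, one column per key j.
content = [[0.3, -0.2, 0.8, 0.1],
           [0.0, 0.4, -0.5, 0.9],
           [-0.7, 0.2, 0.6, 0.3],
           [0.5, 0.5, -0.1, 0.2]]
n = len(content)

causal_rows, reference_rows = [], []
for i in range(n):
    # Add the key-position bias m * j to every key, future keys included.
    logits = [content[i][j] + m * j for j in range(n)]
    # Causal mask: query i may only attend to keys 0..i.
    causal_rows.append(masked_softmax(logits, [j <= i for j in range(n)]))
    # Reference: softmax over the allowed keys only; future keys are dropped.
    reference_rows.append(masked_softmax(logits[: i + 1], [True] * (i + 1)))

for i, (row, ref) in enumerate(zip(causal_rows, reference_rows)):
    assert all(w == 0.0 for w in row[i + 1:])  # future keys get zero weight
    assert all(math.isclose(a, b) for a, b in zip(row, ref))
print("biases on future keys have no effect after masking")
```

Adding the bias to the full [key_seq_len] range keeps get_scores a single broadcast addition instead of building a separate triangular bias per query; the price is a few values that the mask immediately discards.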
Simple test function to see the slopes.
def _test_slopes():
    inspect(get_slopes(8))
    inspect(get_slopes(16))


if __name__ == '__main__':
    _test_slopes()