This is an implementation of Attention with Linear Biases (ALiBi) from the paper Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation.

ALiBi replaces positional encodings with biases added to the attention scores (the attention logits, before the softmax). It is a relative scheme tested on autoregressive tasks: the bias is higher for nearby tokens and lower for far-away tokens. The biases decrease linearly on the log scale (because they are added before the softmax), and each head has a different slope.
Here is the attention formula for the $i$-th token:

$$\mathbf{a}_i = \text{softmax}\Big(\mathbf{q}_i K^\top + m \cdot \big[-(i-1), \dots, -1, 0\big]\Big) = \text{softmax}\Big(\mathbf{q}_i K^\top + m \cdot \big[0, 1, \dots, (i-1)\big]\Big)$$

where $\mathbf{q}_i \in \mathbb{R}^d$ is the query of the $i$-th token, $K \in \mathbb{R}^{i \times d}$ are the keys up to $i$, $m$ is the head-specific slope, and $d$ is the number of features per head. Note that the above equality holds because $\text{softmax}$ is invariant to translations (you can add any constant to all elements without changing the result).
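As a quick, hypothetical sanity check of that invariance (plain PyTorch, not part of the implementation below), the two forms of the bias give identical attention weights for a single head:

```python
import torch

# Illustrative numbers: attention logits of one head for the i-th token, slope m
i, m = 5, 0.5
logits = torch.randn(i)

# Bias written relative to the query position: m * [-(i-1), ..., -1, 0]
relative_bias = m * torch.arange(-(i - 1), 1).float()
# Shifted form used by the code below: m * [0, 1, ..., i-1]
shifted_bias = m * torch.arange(i).float()

# The two differ only by the constant m * (i - 1), so the softmax outputs match
assert torch.allclose(torch.softmax(logits + relative_bias, dim=-1),
                      torch.softmax(logits + shifted_bias, dim=-1))
```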
Here is the training code for an ALiBi model.
```python
import math

import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention
```

`get_slopes` calculates the head-specific slope $m$ for each head. `n_heads` is the number of heads in the attention layer $n$. The slope for the first head is

$$\frac{1}{2^{\frac{8}{n}}} = 2^{-\frac{8}{n}}$$
The slopes for the rest of the heads form a geometric series with the same ratio.

For instance, when the number of heads is $8$ the slopes are

$$\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$$
```python
def get_slopes(n_heads: int):
    # The ratio (and the first slope) $2^{-\frac{8}{n}}$
    s = (2 ** (-2 ** -(math.log2(n_heads) - 3)))
    # The geometric sequence of slopes
    return [s * (s ** i) for i in range(n_heads)]
```
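A quick check (not part of the module) confirms the slopes listed above for $8$ heads:

```python
print(get_slopes(8))
# [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
```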
`get_biases` calculates the attention biases. `n_heads` is the number of heads in the attention layer and `max_len` is the maximum sequence length. It returns a matrix of shape `[max_len, n_heads]` with attention biases.

```python
def get_biases(n_heads: int, max_len: int):
    # Get slopes $m$ for each head
    slopes = torch.tensor(get_slopes(n_heads))
    # Calculate distances $[0, 1, \dots, N]$
    distance = torch.arange(max_len).to(torch.float)
    # Multiply them pair-wise to get the bias matrix
    return distance[:, None] * slopes[None, :]
```
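As a small illustration (again, not part of the module), with $2$ heads and a maximum length of $4$ the slopes are $2^{-4}$ and $2^{-8}$, and each column of the bias matrix grows linearly with the key position (values rounded to four decimals):

```python
biases = get_biases(n_heads=2, max_len=4)
print(biases.shape)  # torch.Size([4, 2]), i.e. [max_len, n_heads]
print(biases)
# tensor([[0.0000, 0.0000],
#         [0.0625, 0.0039],
#         [0.1250, 0.0078],
#         [0.1875, 0.0117]])
```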
We override the Multi-Head Attention module so we only need to write the `get_scores` method.
```python
class AlibiMultiHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, max_len: int = 5_000):
        super().__init__(heads, d_model, dropout_prob)

        # Pre-calculate the biases
        self.bias = nn.Parameter(get_biases(heads, max_len), requires_grad=False)

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        # Calculate the standard attention score.
        # It has shape `[query_seq_len, key_seq_len, batch_size, head]`
        scores = super().get_scores(query, key)

        # Number of keys
        key_seq_len = scores.shape[1]
        # Add the biases to the scores.
        #
        # Note that we add biases for all keys (not just up to $i$). We can do this since
        # those extra entries will get removed because of the masking later.
        return scores + self.bias[None, :key_seq_len, None, :]
```
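Below is a minimal usage sketch. It assumes the conventions of the `labml_nn` `MultiHeadAttention` base class: sequence-first inputs of shape `[seq_len, batch_size, d_model]`, keyword arguments `query`, `key`, `value`, and an optional boolean mask of shape `[query_seq_len, key_seq_len, batch_size]` where `True` marks positions a query may attend to. Those conventions, and the helper name `_test_alibi_attention`, are assumptions for illustration, not something defined in this file.

```python
def _test_alibi_attention():
    seq_len, batch_size, d_model = 10, 2, 64
    mha = AlibiMultiHeadAttention(heads=4, d_model=d_model, max_len=512)

    # Sequence-first input (assumed layout of the base class)
    x = torch.randn(seq_len, batch_size, d_model)
    # Causal mask: token $i$ attends only to tokens up to $i$,
    # which also removes the extra bias entries mentioned above
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)

    out = mha(query=x, key=x, value=x, mask=mask)
    assert out.shape == (seq_len, batch_size, d_model)
```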
A simple test function to see the slopes.

```python
def _test_slopes():
    inspect(get_slopes(8))
    inspect(get_slopes(16))


if __name__ == '__main__':
    _test_slopes()
```