Attention with Linear Biases (ALiBi)

This is an implementation of Attention with Linear Biases (ALiBi) from the paper Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation.

ALiBi replaces positional encodings with biases added to the attention scores (attention logits, before the softmax). It is a relative positional scheme, tested on autoregressive tasks, where the bias is higher for nearby tokens and lower for far-away tokens. The biases decrease linearly in log scale (because they are added before the softmax), and each head has a different slope.

Here’s the attention formula for the $i$-th token,

$$\begin{align}
\mathbf{a}_i
&= \text{softmax} \Big( \mathbf{q}_i K^\top + m \cdot \big[ -(i-1), \dots, -1, 0 \big] \Big) \\
&= \text{softmax} \Big( \mathbf{q}_i K^\top + m \cdot \big[ 0, 1, \dots, (i-1) \big] \Big)
\end{align}$$

where $\mathbf{q}_i \in \mathbb{R}^d$ is the query of the $i$-th token, $K \in \mathbb{R}^{i \times d}$ are the keys up to $i$, $d$ is the number of features per head, and $m$ is the head-specific slope. Note that the second equality holds because $\text{softmax}$ is invariant to translations (you can add any constant to all elements without changing the result).
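
As a quick illustration (not from the paper), the following sketch checks that invariance numerically for a hypothetical query at position $i = 4$ with slope $m = 0.5$: adding $m \cdot [-3, -2, -1, 0]$ or $m \cdot [0, 1, 2, 3]$ to the same logits gives identical attention weights.

import torch

# Hypothetical attention logits for a query at position i = 4 over keys 1..4
logits = torch.tensor([0.2, -1.3, 0.7, 0.4])
m = 0.5

# The two equivalent forms of the ALiBi bias differ only by a constant
bias_a = m * torch.tensor([-3., -2., -1., 0.])
bias_b = m * torch.tensor([0., 1., 2., 3.])

# Softmax is translation invariant, so the attention weights are the same
print(torch.allclose(torch.softmax(logits + bias_a, dim=-1),
                     torch.softmax(logits + bias_b, dim=-1)))  # True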

Here is the training code for an ALiBi model.


import math

import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention

Get head-specific slope $m$ for each head

  • n_heads is the number of heads in the attention layer $n$

The slope for the first head is

$$\frac{1}{2^{\frac{8}{n}}} = 2^{-\frac{8}{n}}$$

The slopes for the rest of the heads form a geometric sequence with the same ratio, $2^{-\frac{8}{n}}$.

For instance, when the number of heads is $8$ the slopes are

$$\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$$

def get_slopes(n_heads: int):

The slope of the first head, $2^{-\frac{8}{n}}$

    s = (2 ** (-2 ** -(math.log2(n_heads) - 3)))

The geometric sequence

    return [s * (s ** i) for i in range(n_heads)]
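
As a quick check (the expected output follows from the code above), the slopes for $8$ heads come out to $\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$:

print(get_slopes(8))
# [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]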

Calculate the attention biases matrix

  • n_heads is the number of heads in the attention layer
  • max_len is the maximum sequence length

This returns a matrix of shape [max_len, n_heads] with the attention biases.

def get_biases(n_heads: int, max_len: int):

Get slopes $m$ for each head

    slopes = torch.tensor(get_slopes(n_heads))

Calculate the distances $[0, 1, \dots, N - 1]$, where $N$ is max_len

    distance = torch.arange(max_len).to(torch.float)

Take the outer product of the distances and the slopes to get the bias matrix of shape [max_len, n_heads]

    return distance[:, None] * slopes[None, :]
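
For instance (a small check, not part of the original code), the bias matrix for $2$ heads and a maximum length of $4$ has one column per head, and each column grows linearly with distance at that head's slope ($2^{-4}$ and $2^{-8}$):

print(get_biases(n_heads=2, max_len=4))
# tensor([[0.0000, 0.0000],
#         [0.0625, 0.0039],
#         [0.1250, 0.0078],
#         [0.1875, 0.0117]])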

Attention with Linear Biases (ALiBi)

We extend the Multi-Head Attention module and only override the get_scores method.

class AlibiMultiHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, max_len: int = 5_000):
        super().__init__(heads, d_model, dropout_prob)

Pre-calculate the biases. They are stored as a non-trainable nn.Parameter, so they move with the module when it is transferred to a device and are saved with its state.

        self.bias = nn.Parameter(get_biases(heads, max_len), requires_grad=False)

Calculate attention scores and add attention biases

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

Calculate the standard attention scores. They have shape [query_seq_len, key_seq_len, batch_size, heads]

        scores = super().get_scores(query, key)

Number of keys

        key_seq_len = scores.shape[1]

Add the biases to scores.

Note that we add the bias for every key position (not just up to the query position $i$). We can do this because the extra entries are removed later by the mask.

        return scores + self.bias[None, :key_seq_len, None, :]
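
A minimal usage sketch (not part of the original code), assuming the labml_nn MultiHeadAttention interface: forward takes keyword arguments query, key, value of shape [seq_len, batch_size, d_model] and an optional boolean mask of shape [seq_len, seq_len, 1], where a True entry lets a query position attend to a key position.

seq_len, batch_size, d_model, heads = 10, 2, 64, 8
mha = AlibiMultiHeadAttention(heads=heads, d_model=d_model)

# Random token representations, shape [seq_len, batch_size, d_model]
x = torch.randn(seq_len, batch_size, d_model)

# Causal mask: query position i can attend to key positions j <= i
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)

# Self-attention with ALiBi biases added to the attention logits before the softmax
out = mha(query=x, key=x, value=x, mask=mask)
print(out.shape)  # torch.Size([10, 2, 64])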

Simple test function to see the slopes.

def _test_slopes():
    inspect(get_slopes(8))
    inspect(get_slopes(16))


if __name__ == '__main__':
    _test_slopes()