Attention with Linear Biases (ALiBi)

This is an implementation of Attention with Linear Biases (ALiBi) from the paper Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation (pdf).

This replaces positional encodings with biases added to the attention scores (attention logits, before the softmax). It is a relative scheme, tested on autoregressive tasks: the bias is highest for nearby tokens and decreases linearly with distance. Because the biases are applied before the softmax (i.e., on the log scale), attention weights fall off geometrically with distance, and each head uses a different slope.

Here's the attention formula for the $i$-th token,

$$\begin{align}
\mathbf{a}_i
&= \text{softmax} \big( \mathbf{q}_i \mathbf{K}^\top + m \cdot \big[ -(i-1), \dots, -1, 0 \big] \big) \\
&= \text{softmax} \big( \mathbf{q}_i \mathbf{K}^\top + m \cdot \big[ 0, 1, \dots, (i-1) \big] \big)
\end{align}$$

where $\mathbf{q}_i \in \mathbb{R}^d$ is the query of the $i$-th token, $\mathbf{K} \in \mathbb{R}^{i \times d}$ are the keys up to $i$, $d$ is the number of features per head, and $m$ is the head-specific slope. Note that the above equality holds because $\text{softmax}$ is invariant to translations (you can add any constant to all elements without changing the result).
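For a concrete (illustrative) example, take $i = 4$ and a head with slope $m = \tfrac{1}{4}$. The two forms above give the same attention weights because the bias vectors differ only by the constant $\tfrac{3}{4}$:

$$\text{softmax} \big( \mathbf{q}_4 \mathbf{K}^\top + \big[ -\tfrac{3}{4}, -\tfrac{1}{2}, -\tfrac{1}{4}, 0 \big] \big)
= \text{softmax} \big( \mathbf{q}_4 \mathbf{K}^\top + \big[ 0, \tfrac{1}{4}, \tfrac{1}{2}, \tfrac{3}{4} \big] \big)$$

The implementation below uses the second form: each key gets a bias of its position index times the head slope, and the autoregressive mask takes care of keys beyond $i$.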

Here is the training code for an ALiBi model.

import math

import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention

Get the slope $m$ for each head

  • n_heads is the number of heads in the attention layer

The slope for the first head is

$$\frac{1}{2^{\frac{8}{n}}} = 2^{-\frac{8}{n}}$$

where $n$ is the number of heads.

The slopes for the rest of the heads form a geometric series whose ratio is that same value, i.e. the $h$-th head (counting from $1$) has slope $\big(2^{-\frac{8}{n}}\big)^{h}$.

For instance, when the number of heads is $8$ the slopes are $\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$.

def get_slopes(n_heads: int):

The first slope, $2^{-\frac{8}{n}}$

    s = 2 ** (-2 ** -(math.log2(n_heads) - 3))

The geometric sequence

    return [s * (s ** i) for i in range(n_heads)]
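This is not part of the implementation, but a quick sanity check one could run next to get_slopes (assuming n_heads is a power of two, which is the case this formula targets): for $8$ heads the slopes come out to exactly $\frac{1}{2}, \frac{1}{4}, \dots, \frac{1}{256}$.

# Sanity check (illustrative only): the slopes for 8 heads are exact powers of two,
# so float equality is safe here.
assert get_slopes(8) == [2 ** -(i + 1) for i in range(8)]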

Calculate the attention biases matrix

  • n_heads is the number of heads in the attention layer
  • max_len is the maximum sequence length

This returns a matrix of shape [max_len, n_heads] with the attention biases.

def get_biases(n_heads: int, max_len: int):

Get slopes for each head

    slopes = torch.tensor(get_slopes(n_heads))

Calculate the distances [0, 1, ..., max_len - 1]

    distance = torch.arange(max_len).to(torch.float)

Multiply them pair-wise to get the bias matrix

    return distance[:, None] * slopes[None, :]
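A small illustration (hypothetical sizes, meant to run alongside the functions above): the bias matrix has one column per head, and row $i$ equals $i$ times the head slopes.

import torch

# 4 heads, maximum sequence length 3 (small values purely for illustration)
biases = get_biases(4, 3)
assert tuple(biases.shape) == (3, 4)                           # [max_len, n_heads]
assert torch.equal(biases[0], torch.zeros(4))                  # zero bias at distance 0
assert torch.allclose(biases[1], torch.tensor(get_slopes(4)))  # row 1 equals the slopes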

Attention with Linear Biases (ALiBi)

We override the Multi-Head Attention module, so we only need to write the get_scores method.

class AlibiMultiHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, max_len: int = 5_000):
        super().__init__(heads, d_model, dropout_prob)

Pre-calculate the biases

        self.bias = nn.Parameter(get_biases(heads, max_len), requires_grad=False)

Calculate attention scores and add attention biases

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

Calculate the standard attention score. It has shape [query_seq_len, key_seq_len, batch_size, head]

        scores = super().get_scores(query, key)

Number of keys

        key_seq_len = scores.shape[1]

Add the biases to scores.

Note that we add biases for all keys (not just up to the query position $i$). We can do this because those extra entries will get removed by the masking later.

        return scores + self.bias[None, :key_seq_len, None, :]
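A minimal usage sketch (sizes are hypothetical; it assumes the labml_nn MultiHeadAttention forward interface with keyword arguments query, key, value and an optional mask of shape [query_seq_len, key_seq_len, batch_size], where the batch dimension may be 1):

import torch

seq_len, batch_size, d_model, n_heads = 10, 2, 64, 8
mha = AlibiMultiHeadAttention(n_heads, d_model, max_len=seq_len)

# Inputs to the labml MHA have shape [seq_len, batch_size, d_model]
x = torch.randn(seq_len, batch_size, d_model)

# Autoregressive mask: query i attends only to keys j <= i
mask = torch.tril(torch.ones(seq_len, seq_len)).bool()[:, :, None]

out = mha(query=x, key=x, value=x, mask=mask)
assert out.shape == (seq_len, batch_size, d_model)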

Simple test function to see the slopes.

def _test_slopes():
    inspect(get_slopes(8))
    inspect(get_slopes(16))

if __name__ == '__main__':
    _test_slopes()