This is an implementation of Attention with Linear Biases (ALiBi) from the paper Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation.

This replaces positional encodings with biases added to the attention scores (attention logits, before the softmax). It is a relative scheme tested on autoregressive tasks; the bias is higher for nearby tokens and lower for far-away tokens. The biases decrease linearly on the log scale (because they are added before the softmax) and each head has a different slope.

Here's the attention formula for the $i$-th token,

$$\mathbf{a}_i = \text{softmax}\Big(\mathbf{q}_i \mathbf{K}^\top + m \cdot \big[-(i-1), \dots, -1, 0\big]\Big) = \text{softmax}\Big(\mathbf{q}_i \mathbf{K}^\top + m \cdot \big[0, 1, \dots, (i-1)\big]\Big)$$

where $\mathbf{q}_i$ is the query of the $i$-th token, $\mathbf{K}$ are the keys up to $i$, $d$ is the number of features per head, and $m$ is the head-specific slope. Note that the above equality holds because $\text{softmax}$ is invariant to translations (you can add any constant to all elements without changing the result).

Here is the training code for an ALiBi model.
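As a minimal sketch (not part of the original implementation), the snippet below shows how a single head's linear biases shift the attention logits of one query, and why adding a constant to all logits leaves the softmax unchanged. The logit values and the slope are made up for illustration.

import torch

# Attention logits of the 4th token (i = 4) over its 4 keys, for one head
logits = torch.tensor([0.3, 1.2, -0.5, 0.8])
# Head-specific slope (value chosen only for illustration)
m = 0.25

# Linear biases: 0 for the most recent key, more negative further back
bias = m * torch.tensor([-3.0, -2.0, -1.0, 0.0])

a1 = torch.softmax(logits + bias, dim=-1)
# Shifting all biases by a constant (here +3m) gives m * [0, 1, 2, 3] and the same
# attention weights, because softmax ignores a constant added to every element
a2 = torch.softmax(logits + bias + 3 * m, dim=-1)
assert torch.allclose(a1, a2)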
import math
from typing import Optional

import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention

n_heads is the number of heads in the attention layer, $n$. The slope for the first head is

$$\frac{1}{2^{\frac{8}{n}}} = 2^{-\frac{8}{n}}$$

The slopes for the rest of the heads are in a geometric series with the same ratio. For instance, when the number of heads is $8$ the slopes are

$$\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$$
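To make the geometric series concrete, here is a small sketch (not part of the original module) that derives the slopes for 8 heads directly from the formula above; it should match what get_slopes(8), defined next, returns.

n = 8
ratio = 2.0 ** (-8.0 / n)
# Geometric series: 1/2, 1/4, ..., 1/256
slopes = [ratio ** i for i in range(1, n + 1)]
print(slopes)  # should match get_slopes(8).tolist()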
def get_slopes(n_heads: int):

Get the closest power of 2 to n_heads. If n_heads is not a power of 2, then we first calculate slopes for the closest (smaller) power of 2, and then add the remaining slopes.

    n = 2 ** math.floor(math.log2(n_heads))
    m_0 = 2.0 ** (-8.0 / n)
    m = torch.pow(m_0, torch.arange(1, 1 + n))

If n_heads is not a power of 2, then we add the remaining slopes. We calculate the remaining slopes for $2n$ heads and pick the slopes up to n_heads. Note that we take steps of 2 to avoid the slopes added previously.

    if n < n_heads:
        m_hat_0 = 2.0 ** (-4.0 / n)
        m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (n_heads - n), 2))

Concatenate the slopes with the remaining slopes.

        m = torch.cat([m, m_hat])

    return m
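As a further illustration (not part of the original module), consider 12 heads, which is not a power of 2: the first 8 slopes follow the series for 8 heads, and the remaining 4 are taken at odd steps of the finer series with ratio $2^{-\frac{4}{8}}$.

# First 8 slopes: 2^-1, 2^-2, ..., 2^-8
# Remaining 4 (steps 1, 3, 5, 7 of the 2^-0.5 series): 2^-0.5, 2^-1.5, 2^-2.5, 2^-3.5
slopes = get_slopes(12)
assert slopes.shape[0] == 12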
n_heads is the number of heads in the attention layer. mask is the attention mask of shape [seq_len_q, seq_len_k].

This returns a matrix of shape [seq_len_q, seq_len_k, n_heads] with ALiBi attention biases.
@torch.no_grad()
def get_alibi_biases(n_heads: int, mask: torch.Tensor):

Get the slopes $m$ for each head.

    m = get_slopes(n_heads).to(mask.device)

Calculate the distances. Here we calculate them using the mask. Since it's a causal mask, we could just use distance = torch.arange(mask.shape[1], dtype=torch.long, device=mask.device)[None, :] too; the two differ only by a constant offset, which the softmax ignores.

    distance = mask.cumsum(dim=-1)

Multiply them pair-wise to get the ALiBi bias matrix.

    return distance[:, :, None] * m[None, None, :]


class AlibiMultiHeadAttention(MultiHeadAttention):

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)

To cache the ALiBi biases

        self.alibi_biases = None

query, key and value are the tensors that store the collection of query, key and value vectors. They have shape [seq_len, batch_size, d_model].
mask has shape [seq_len, seq_len, batch_size] and mask[i, j, b] indicates whether, for batch b, the query at position i has access to the key-value at position j.
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):

ALiBi only works with causal masks.

        assert mask is not None
        assert mask.shape[0] == mask.shape[1] and mask.shape[2] == 1

query, key and value have shape [seq_len, batch_size, d_model].

        seq_len, batch_size, _ = query.shape

Add a head dimension to the mask and check its shape.
        mask = self.prepare_mask(mask, query.shape, key.shape)

Prepare query, key and value for attention computation. These will then have shape [seq_len, batch_size, heads, d_k].

        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

Compute attention scores $Q K^\top$. This gives a tensor of shape [seq_len, seq_len, batch_size, heads].
        scores = self.get_scores(query, key)

Scale the scores, $\frac{Q K^\top}{\sqrt{d_k}}$

        scores *= self.scale

Create the ALiBi biases if they are not cached

        if self.alibi_biases is None or self.alibi_biases.shape[1] < seq_len:

mask has shape [seq_len, seq_len, 1, 1]

            self.alibi_biases = get_alibi_biases(scores.shape[-1], mask[:, :, 0, 0])

Add the ALiBi biases to the attention scores. The ALiBi biases have shape [seq_len, seq_len, n_heads] and scores has shape [seq_len, seq_len, batch_size, n_heads].

        scores += self.alibi_biases[:seq_len, :seq_len, None, :]

Apply the mask
        scores = scores.masked_fill(mask == 0, float('-inf'))

Compute the softmax attention along the key sequence dimension

        attn = self.softmax(scores)

Apply dropout

        attn = self.dropout(attn)

Multiply by the values

        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

Concatenate the multiple heads

        x = x.reshape(seq_len, batch_size, -1)

Output layer

        return self.output(x)

A simple test function to inspect the slopes and biases.

def _test_alibi():
    inspect(get_slopes(12).tolist(), _n=-1)
    from labml_nn.transformers.utils import subsequent_mask

    mask = subsequent_mask(8)[:, :, 0]
    inspect(mask)

    inspect(get_alibi_biases(12, mask)[:, :, 3], _n=-1)


if __name__ == '__main__':
    _test_alibi()
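As an additional usage sketch (not part of the original module, and assuming the labml_nn MultiHeadAttention base class and the subsequent_mask helper used above behave as in the rest of this file), the attention layer can be exercised roughly like this:

import torch
from labml_nn.transformers.utils import subsequent_mask

# Small configuration chosen only for illustration
seq_len, batch_size, d_model, heads = 10, 2, 128, 4

attn = AlibiMultiHeadAttention(heads, d_model)
x = torch.randn(seq_len, batch_size, d_model)
# Causal mask of shape [seq_len, seq_len, 1]
mask = subsequent_mask(seq_len)
# Self-attention; the output should have shape [seq_len, batch_size, d_model]
out = attn(query=x, key=x, value=x, mask=mask)
assert out.shape == (seq_len, batch_size, d_model)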