This is an implementation of Attention with Linear Biases (ALiBi) from the paper Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation.
This replaces positional encodings with biases added to the attention scores (the attention logits, before the softmax). It is a relative scheme, tested on autoregressive tasks, and the bias is higher for nearby tokens and lower for far-away tokens. The biases decrease linearly with distance. Because they are added before the softmax, the attention weights decrease linearly in log scale. Each head uses a different slope.
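Suppose a key at distance $d$ receives a bias of $-m \cdot d$. Because that bias is added to the logits, it multiplies the key's unnormalized attention weight by $e^{-m d}$, so the log of the weight drops linearly with distance. The pure-Python sketch below checks this for a single query. The helper names, the example logits and the slope $m = 0.25$ are illustrative assumptions and are not values from the implementation.

```python
import math


def softmax(xs):
    # Subtracting the max keeps exp() from overflowing and does not change the result
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def alibi_logits(scores, slope):
    # scores[j] = q_i . k_j for keys j = 0..i, with the query at the last position i.
    # ALiBi adds -slope * distance, where distance = i - j.
    i = len(scores) - 1
    return [s - slope * (i - j) for j, s in enumerate(scores)]


scores = [0.3, -1.2, 0.8, 0.1, 0.5]  # arbitrary example logits
slope = 0.25                         # illustrative slope m for one head

plain = softmax(scores)
biased = softmax(alibi_logits(scores, slope))

# log(biased_j / plain_j) = -slope * (i - j) + constant, so it rises by exactly
# `slope` for every step closer to the query.
log_ratio = [math.log(b / p) for b, p in zip(biased, plain)]
steps = [log_ratio[j + 1] - log_ratio[j] for j in range(len(log_ratio) - 1)]
print([round(s, 6) for s in steps])  # [0.25, 0.25, 0.25, 0.25]
```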
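Here's the attention formula for the $i$-th token,

$$\begin{aligned}
\mathbf{a}_i
&= \text{softmax} \bigg( \mathbf{q}_i \mathbf{K}^\top + m \cdot \big[-(i-1), \dots, -1, 0 \big] \bigg) \\
&= \text{softmax} \bigg( \mathbf{q}_i \mathbf{K}^\top + m \cdot \big[0, 1, \dots, (i - 1) \big] \bigg)
\end{aligned}$$

where $\mathbf{q}_i \in \mathbb{R}^d$ is the query of the $i$-th token, $\mathbf{K} \in \mathbb{R}^{i \times d}$ are the keys up to $i$, $d$ is the number of features per head, and $m$ is the head-specific slope. Note that the above equality holds because $\text{softmax}$ is invariant to translations (you can add any constant to all elements without changing the result).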
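The equality of the two bias forms can be checked numerically. The two bias vectors differ by the constant $i - 1$, so the softmax outputs match. In this pure-Python sketch, the helper `attention_row` and the values of `qk` and `m` are illustrative assumptions and are not part of the implementation.

```python
import math


def softmax(xs):
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_row(qk, m, bias):
    # softmax(q_i K^T + m * bias) for a single query
    return softmax([s + m * b for s, b in zip(qk, bias)])


qk = [1.0, -0.5, 2.0, 0.25]  # arbitrary q_i K^T values, so i = 4
m = 0.5                      # illustrative slope
i = len(qk)

first = attention_row(qk, m, [-(i - 1 - j) for j in range(i)])  # [-(i-1), ..., -1, 0]
second = attention_row(qk, m, list(range(i)))                   # [0, 1, ..., i-1]
print(max(abs(a - b) for a, b in zip(first, second)) < 1e-12)  # True
```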
Here is the training code for an ALiBi model.
```python
import math
from typing import Optional

import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention
```

`n_heads` is the number of heads in the attention layer, $n$. The slope for the first head is

$$\frac{1}{2^{\frac{8}{n}}} = 2^{-\frac{8}{n}}$$

The slopes for the rest of the heads form a geometric series with the same ratio, $2^{-\frac{8}{n}}$. For instance, when the number of heads is $8$, the slopes are

$$\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$$
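The slope schedule can be reproduced in plain Python to check the numbers. The sketch below mirrors the `get_slopes` function in this section, including its rule for head counts that are not a power of 2. `slopes` is a hypothetical helper name, and it returns a list of floats rather than a tensor.

```python
import math


def slopes(n_heads: int) -> list:
    # Pure-Python mirror of `get_slopes` (assumption: float lists instead of tensors)
    n = 2 ** math.floor(math.log2(n_heads))  # closest power of 2 <= n_heads
    m_0 = 2.0 ** (-8.0 / n)                   # first slope, 2^{-8/n}
    m = [m_0 ** k for k in range(1, n + 1)]   # geometric series with ratio m_0
    if n < n_heads:
        # Remaining slopes: odd powers of 2^{-8/(2n)}; even powers are already in `m`
        m_hat_0 = 2.0 ** (-4.0 / n)
        m += [m_hat_0 ** k for k in range(1, 1 + 2 * (n_heads - n), 2)]
    return m


print(slopes(8))  # [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
print(len(slopes(12)))  # 12
```

With 12 heads, the first 8 slopes are the 8-head series. The other 4 are $2^{-0.5}, 2^{-1.5}, 2^{-2.5}, 2^{-3.5}$, which fall between the existing ones.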
```python
def get_slopes(n_heads: int):
    # Get the closest power of 2 to `n_heads`, rounding down. If `n_heads` is not a
    # power of 2, we first calculate slopes for that power of 2 and then add the rest.
    n = 2 ** math.floor(math.log2(n_heads))
    # 2^{-8/n}
    m_0 = 2.0 ** (-8.0 / n)
    # 2^{-1 * 8/n}, 2^{-2 * 8/n}, ..., 2^{-8}
    m = torch.pow(m_0, torch.arange(1, 1 + n))
    # If `n_heads` is not a power of 2, add the remaining slopes. They come from the
    # series for 2n heads (avoiding slopes added previously), up to `n_heads` in total.
    if n < n_heads:
        # 2^{-8/(2n)}
        m_hat_0 = 2.0 ** (-4.0 / n)
        # Step by 2 to avoid slopes added previously (even powers of m_hat_0 are in `m`)
        m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (n_heads - n), 2))
        # Concatenate the slopes with the remaining slopes
        m = torch.cat([m, m_hat])

    return m
```

`n_heads` is the number of heads in the attention layer, and `mask` is the attention mask of shape `[seq_len_q, seq_len_k]`. The function returns a matrix of shape `[seq_len_q, seq_len_k, n_heads]` with the ALiBi attention biases.
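The pure-Python sketch below illustrates this matrix on a small causal mask. Nested lists stand in for tensors. The helpers `causal_mask`, `cumsum_rows` and `alibi_biases` and the two slopes are hypothetical. The distances follow the row-wise cumulative sum of the mask, so for each query the bias is largest at the nearest allowed key.

```python
def causal_mask(seq_len):
    # mask[i][j] = 1 if query i may attend to key j (j <= i), else 0
    return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]


def cumsum_rows(mask):
    # Row-wise cumulative sum, like `mask.cumsum(dim=-1)` on a 2-D tensor
    out = []
    for row in mask:
        total, acc = 0, []
        for v in row:
            total += v
            acc.append(total)
        out.append(acc)
    return out


def alibi_biases(slopes, mask):
    # biases[i][j][h] = distance[i][j] * slopes[h]: shape [seq_len_q][seq_len_k][n_heads]
    distance = cumsum_rows(mask)
    return [[[d * m for m in slopes] for d in row] for row in distance]


mask = causal_mask(4)
biases = alibi_biases([0.5, 0.25], mask)  # two heads with illustrative slopes
print([b[0] for b in biases[3]])  # head 0, last query: [0.5, 1.0, 1.5, 2.0]

# Where attention is allowed (j <= i) the cumsum distance is j + 1, i.e. [0, 1, ..., i]
# shifted by a constant, which the softmax ignores.
distance = cumsum_rows(mask)
assert all(distance[i][j] == j + 1 for i in range(4) for j in range(i + 1))
```

On masked positions ($j > i$) the cumulative sum stays at $i + 1$. Those entries do not matter because they are filled with $-\infty$ before the softmax.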
```python
@torch.no_grad()
def get_alibi_biases(n_heads: int, mask: torch.Tensor):
    # Get slopes m for each head
    m = get_slopes(n_heads).to(mask.device)
    # Calculate distances from the mask. For a causal mask, row i of the cumulative sum is
    # [1, 2, ..., i + 1] up to the diagonal: [0, 1, ..., i] shifted by a constant, which the
    # softmax ignores (masked positions are filled with -inf later). Since the mask is causal,
    # `torch.arange(mask.shape[1], dtype=torch.long, device=mask.device)[None, :]` works too.
    distance = mask.cumsum(dim=-1)
    # Multiply them pair-wise to get the ALiBi bias matrix [seq_len_q, seq_len_k, n_heads]
    return distance[:, :, None] * m[None, None, :]
```

`AlibiMultiHeadAttention` extends `MultiHeadAttention` and adds the ALiBi biases to the attention scores.

```python
class AlibiMultiHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)
        # To cache the ALiBi biases
        self.alibi_biases = None
```

`query`, `key` and `value` are the tensors that store the collections of query, key and value vectors. They have shape `[seq_len, batch_size, d_model]`.

`mask` has shape `[seq_len, seq_len, batch_size]`, and `mask[i, j, b]` indicates whether, for batch `b`, the query at position `i` has access to the key-value at position `j`.
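The pure-Python sketch below traces the steps that `forward` applies for a single head and a single batch element. It computes scaled dot-product scores, adds the ALiBi bias, masks with $-\infty$, takes the softmax over keys and forms the weighted sum of values. `alibi_attention` is a hypothetical helper. It drops the batch and head dimensions and computes the cumulative-sum distances inline.

```python
import math


def alibi_attention(q, k, v, mask, slope):
    # Illustrative single-head, single-batch ALiBi attention on nested lists.
    # q, k: [seq_len][d]; v: [seq_len][d_v]; mask[i][j] is truthy if query i may
    # attend to key j (the batch dimension of the real mask is dropped here).
    seq_len, d = len(q), len(q[0])
    out = []
    for i in range(seq_len):
        logits, dist = [], 0
        for j in range(seq_len):
            # Cumulative-sum distance along the mask row, as in `get_alibi_biases`
            dist += 1 if mask[i][j] else 0
            score = sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d)
            # Scaled score plus ALiBi bias; disallowed positions get -inf
            logits.append(score + slope * dist if mask[i][j] else float("-inf"))
        top = max(logits)
        exps = [math.exp(x - top) for x in logits]  # exp(-inf) == 0.0
        total = sum(exps)
        weights = [e / total for e in exps]
        out.append([sum(w * v[j][c] for j, w in enumerate(weights))
                    for c in range(len(v[0]))])
    return out


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
mask = [[1 if j <= i else 0 for j in range(3)] for i in range(3)]
out = alibi_attention(q, k, v, mask, slope=0.5)
print(out[0])  # [1.0, 2.0]: the first query can only attend to itself
```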
```python
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        # ALiBi only works with causal masks.
        assert mask is not None
        assert mask.shape[0] == mask.shape[1] and mask.shape[2] == 1

        # `query`, `key` and `value` have shape [seq_len, batch_size, d_model]
        seq_len, batch_size, _ = query.shape

        # Add head dimension to mask and check its shape
        mask = self.prepare_mask(mask, query.shape, key.shape)

        # Prepare `query`, `key` and `value` for attention computation.
        # These will then have shape [seq_len, batch_size, heads, d_k]
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Compute attention scores Q K^T, of shape [seq_len, seq_len, batch_size, heads]
        scores = self.get_scores(query, key)

        # Scale scores: Q K^T / sqrt(d_k)
        scores *= self.scale

        # Create ALiBi biases if they are not cached (or the cache is too short).
        # `mask` has shape [seq_len, seq_len, 1, 1]
        if self.alibi_biases is None or self.alibi_biases.shape[1] < seq_len:
            self.alibi_biases = get_alibi_biases(scores.shape[-1], mask[:, :, 0, 0])

        # Add ALiBi biases to attention scores. The biases have shape
        # [seq_len, seq_len, n_heads] and `scores` has shape [seq_len, seq_len, batch_size, n_heads]
        scores += self.alibi_biases[:seq_len, :seq_len, None, :]

        # Apply mask
        scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax attention along the key sequence dimension
        attn = self.softmax(scores)

        # Apply dropout
        attn = self.dropout(attn)

        # Multiply by values
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

        # Concatenate multiple heads
        x = x.reshape(seq_len, batch_size, -1)

        # Output layer
        return self.output(x)
```

A simple test function to inspect the slopes, a causal mask and the resulting biases.
```python
def _test_alibi():
    # Print the slopes for 12 heads
    inspect(get_slopes(12).tolist(), _n=-1)
    from labml_nn.transformers.utils import subsequent_mask

    mask = subsequent_mask(8)[:, :, 0]
    inspect(mask)

    # Print the ALiBi biases of the fourth head (index 3)
    inspect(get_alibi_biases(12, mask)[:, :, 3], _n=-1)


if __name__ == '__main__':
    _test_alibi()
```
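`forward` caches a single bias tensor and uses `self.alibi_biases[:seq_len, :seq_len, None, :]` for shorter sequences. This slicing is valid for causal masks because each bias entry depends only on the query and key positions. As a result, the biases for a shorter sequence are the top-left block of those for a longer one. The pure-Python check below works under that assumption for a single head. `alibi_bias_matrix` is a hypothetical helper.

```python
def alibi_bias_matrix(seq_len, slope):
    # Single-head ALiBi biases for a causal mask, following `mask.cumsum(dim=-1)`:
    # slope * (j + 1) for j <= i, and slope * (i + 1) on the masked positions j > i.
    return [[slope * (min(i, j) + 1) for j in range(seq_len)] for i in range(seq_len)]


short_b = alibi_bias_matrix(4, 0.25)
long_b = alibi_bias_matrix(16, 0.25)

# Each entry depends only on (i, j), so the biases for a shorter sequence are the
# top-left block of the biases for a longer one; a cached tensor can be sliced.
assert all(long_b[i][j] == short_b[i][j] for i in range(4) for j in range(4))
print(short_b[3])  # [0.25, 0.5, 0.75, 1.0]
```

The same property means a longer sequence at test time only triggers one recomputation of the cache.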