Relative Multi-Headed Attention

This is a PyTorch implementation of the relative multi-headed attention used in Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.

The vanilla Transformer has a limited attention span, equal to the length of the sequence trained in parallel, and all of these positions have a fixed positional encoding. Transformer-XL increases this attention span by letting each position pay attention to pre-calculated past embeddings. For instance, if the context length is $l$, it keeps the embeddings of all layers for the previous batch of length $l$ and feeds them to the current step. If we used fixed positional encodings, these pre-calculated embeddings would have the same positions as the current context. Instead, Transformer-XL introduces relative positional encodings, where the positional information is injected at the attention calculation.
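As a rough sketch of this idea (not the code in this module; mem, seq and the sizes below are hypothetical), the activations cached from the previous segment are concatenated with the current segment, so keys and values span both while queries come only from the current segment.

import torch

# Hypothetical sizes, for illustration only
seq_len, batch_size, d_model = 4, 2, 16

# `mem` holds the pre-calculated embeddings of the previous segment
# (detached, so no gradients flow into the past); `seq` is the current segment
mem = torch.randn(seq_len, batch_size, d_model).detach()
seq = torch.randn(seq_len, batch_size, d_model)

# Keys and values are computed over the memory plus the current segment,
# while queries are computed over the current segment only
kv_input = torch.cat([mem, seq], dim=0)  # [2 * seq_len, batch_size, d_model]
q_input = seq                            # [seq_len, batch_size, d_model]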

import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention

This method shifts the $i^{th}$ row of a matrix by $i$ columns.

If the input is [[1, 2, 3], [4, 5, 6], [7, 8, 9]], the shifted result would be [[1, 2, 3], [0, 4, 5], [6, 0, 7]]. Ideally we should mask out the lower triangle, but it's ok for our purpose.

def shift_right(x: torch.Tensor):

Concatenate a column of zeros

    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
    x_padded = torch.cat([x, zero_pad], dim=1)

Reshape and remove excess elements from the end

    x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
    x = x_padded[:-1].view_as(x)

    return x
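As a sanity check, here is a step-by-step trace of shift_right on the 3x3 example above; the intermediate names are only for illustration.

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Pad a column of zeros on the right: [[1, 2, 3, 0], [4, 5, 6, 0], [7, 8, 9, 0]]
padded = torch.cat([x, x.new_zeros(3, 1)], dim=1)

# View the same memory as a (4, 3) matrix (row-major order):
# [[1, 2, 3], [0, 4, 5], [6, 0, 7], [8, 9, 0]]
reshaped = padded.view(4, 3)

# Drop the last row to get the shifted matrix: [[1, 2, 3], [0, 4, 5], [6, 0, 7]]
shifted = reshaped[:-1]

assert torch.equal(shifted, shift_right(x))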

Relative Multi-Head Attention Module

We override the MultiHeadAttention module, so we only need to write the get_scores method.

class RelativeMultiHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):

The linear transformations don't need a bias, since we take care of it when calculating the scores. However, having a bias for value might make sense.

        super().__init__(heads, d_model, dropout_prob, bias=False)

Number of relative positions

        self.P = 2 ** 12

Relative positional embeddings for key relative to the query. We need $2P$ embeddings because the keys can be before or after the query.

        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)

Relative positional embedding bias for key relative to the query.

        self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)

The positional embeddings for the query are independent of the position of the query.

        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
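A quick shape check of these learned parameters (a sketch with arbitrary sizes, assuming d_k = d_model // heads as set by the base MultiHeadAttention):

mha = RelativeMultiHeadAttention(heads=4, d_model=32)

print(mha.key_pos_embeddings.shape)  # [2 * P, heads, d_k] -> torch.Size([8192, 4, 8])
print(mha.key_pos_bias.shape)        # [2 * P, heads]      -> torch.Size([8192, 4])
print(mha.query_pos_bias.shape)      # [heads, d_k]        -> torch.Size([4, 8])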

Get relative attention scores

With absolute positional encodings, the attention score is

$$A^{abs}_{i,j} = \underset{\color{lightgreen}{A}}{Q_i^\top K_j} + \underset{\color{lightgreen}{B}}{Q_i^\top U^K_j} + \underset{\color{lightgreen}{C}}{{U^Q_i}^\top K_j} + \underset{\color{lightgreen}{D}}{{U^Q_i}^\top U^K_j}$$

where $Q_i, K_j$ are linear transformations of the original embeddings $X^q_i, X^k_j$, and $U^Q_i, U^K_j$ are linear transformations of the absolute positional encodings $P_i, P_j$.

They reason that the attention to a given key should be the same regardless of the position of the query. Hence they replace $\underset{\color{lightgreen}{C}}{{U^Q_i}^\top K_j}$ with a constant $\underset{\color{lightgreen}{C}}{\color{orange}{v^\top} K_j}$.

For the second and third terms relative positional encodings are introduced. So $\underset{\color{lightgreen}{B}}{Q_i^\top U^K_j}$ is replaced with $\underset{\color{lightgreen}{B}}{Q_i^\top \color{orange}{R_{i - j}}}$ and $\underset{\color{lightgreen}{D}}{{U^Q_i}^\top U^K_j}$ with $\underset{\color{lightgreen}{D}}{\color{orange}{S_{i-j}}}$.
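Putting these replacements together, the relative attention score computed by this method is

$$A^{rel}_{i,j} = \underset{\color{lightgreen}{A}}{Q_i^\top K_j} + \underset{\color{lightgreen}{B}}{Q_i^\top \color{orange}{R_{i - j}}} + \underset{\color{lightgreen}{C}}{\color{orange}{v^\top} K_j} + \underset{\color{lightgreen}{D}}{\color{orange}{S_{i-j}}}$$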

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

$\color{orange}{R_k}$

        key_pos_emb = self.key_pos_embeddings[self.P - query.shape[0]:self.P + key.shape[0]]

$\color{orange}{S_k}$

        key_pos_bias = self.key_pos_bias[self.P - query.shape[0]:self.P + key.shape[0]]

$\color{orange}{v^\top}$

        query_pos_bias = self.query_pos_bias[None, None, :, :]

${(\color{lightgreen}{\mathbf{A + C}})}_{i,j} = Q_i^\top K_j + \color{orange}{v^\top} K_j$

        ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)

$\color{lightgreen}{\mathbf{B'}_{i,k}} = Q_i^\top \color{orange}{R_k}$

        b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)

$\color{lightgreen}{\mathbf{D'}_{i,k}} = \color{orange}{S_k}$

        d = key_pos_bias[None, :, None, :]

Shift the rows of $\color{lightgreen}{\mathbf{(B' + D')}_{i,k}}$ to get $\color{lightgreen}{\mathbf{(B + D)}_{i,j}} = Q_i^\top \color{orange}{R_{i - j}} + \color{orange}{S_{i-j}}$

        bd = shift_right(b + d)

Remove extra positions

        bd = bd[:, -key.shape[0]:]

Return the sum

        return ac + bd
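As a rough usage sketch (hypothetical sizes; query and key already have their heads split into [seq_len, batch_size, heads, d_k], as the base MultiHeadAttention does before calling get_scores):

mha = RelativeMultiHeadAttention(heads=4, d_model=32)
d_k = 32 // 4

query = torch.randn(8, 2, 4, d_k)   # [query_len, batch_size, heads, d_k]
key = torch.randn(12, 2, 4, d_k)    # [key_len, batch_size, heads, d_k]; keys are longer because of the memory

scores = mha.get_scores(query, key)
print(scores.shape)  # [query_len, key_len, batch_size, heads] -> torch.Size([8, 12, 2, 4])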
def _test_shift_right():
    # A 3x3 matrix; row i gets shifted right by i columns
    x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    inspect(x)
    inspect(shift_right(x))

    # The row [1, 2, 3, 4, 5] repeated 5 times
    x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1)
    inspect(x[:, :, 0, 0])
    inspect(shift_right(x)[:, :, 0, 0])

    # The row [1, 2, 3, 4, 5] repeated 3 times (fewer rows than columns)
    x = torch.arange(1, 6)[None, :, None, None].repeat(3, 1, 1, 1)
    inspect(x[:, :, 0, 0])
    inspect(shift_right(x)[:, :, 0, 0])


if __name__ == '__main__':
    _test_shift_right()