This is an implementation of Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.
The Transformer has a limited attention span, equal to the length of the sequence trained in parallel, and each of these positions has a fixed positional encoding. Transformer-XL extends the attention span by letting each position also attend to pre-computed embeddings from the past. For instance, if the context length is $l$, it keeps the embeddings of all layers for the previous segment of length $l$ and feeds them to the current step. With fixed (absolute) positional encodings, these cached embeddings would carry the same positions as the current context, so the model could not tell them apart. The authors therefore introduce relative positional encodings, which enter in the attention-score calculation rather than being added to the input embeddings.
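The memory mechanism described above can be sketched for a single layer. This is a minimal illustration under simplifying assumptions, not the implementation in this file: there are no query/key/value projections, no causal mask and no positional encodings, and the helper `attend_with_memory` is hypothetical. It shows the two key ideas: keys and values span the cached previous segment plus the current one, and the cache is detached so no gradients flow into earlier segments. In the full model every layer keeps its own cache.

```python
from typing import Optional

import torch

# Hypothetical sizes: segment length l, batch size, embedding width
l, batch, d_model = 4, 2, 8


def attend_with_memory(x: torch.Tensor, mem: Optional[torch.Tensor]):
    """Single-layer toy attention over [memory; current segment].

    x:   [l, batch, d_model] embeddings of the current segment
    mem: [l, batch, d_model] cached embeddings of the previous segment, or None
    """
    # The cache is treated as a constant: no gradients flow into earlier segments
    context = x if mem is None else torch.cat([mem.detach(), x], dim=0)
    # Queries come only from the current segment; keys/values span the context
    scores = torch.einsum('ibd,jbd->ijb', x, context) / d_model ** 0.5
    attn = scores.softmax(dim=1)  # normalise over the key dimension
    out = torch.einsum('ijb,jbd->ibd', attn, context)
    # This segment's embeddings become the memory for the next step
    return out, attn, x


mem = None
for segment in torch.randn(3, l, batch, d_model):
    out, attn, mem = attend_with_memory(segment, mem)

print(attn.shape)  # torch.Size([4, 8, 2]): 4 queries over 4 cached + 4 current positions
```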
```python
import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention
```

The `shift_right` function shifts the $i^{th}$ row of a matrix right by $i$ columns.
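If the input is `[[1, 2, 3], [4, 5, 6], [7, 8, 9]]`, the shifted result is `[[1, 2, 3], [0, 4, 5], [6, 0, 7]]`. Below the diagonal the result holds values wrapped around from the previous row instead of zeros. Ideally we would mask out this lower triangle, but it is fine for our purpose: `get_scores` keeps only the last `key.shape[0]` columns after shifting, and the wrapped values all fall in the discarded columns.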
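The padding-and-reshape trick can be checked against a direct loop. In this sketch `shift_right_loop` is a hypothetical reference helper (not part of this file), and `shift_right_trick` repeats the steps of `shift_right`. The two agree on and above the diagonal; below it, the loop leaves zeros while the trick leaves values wrapped around from the previous row.

```python
import torch


def shift_right_loop(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference: shift row i right by i columns, padding with zeros."""
    out = torch.zeros_like(x)
    n_cols = x.shape[1]
    for i in range(min(x.shape[0], n_cols)):
        out[i, i:] = x[i, :n_cols - i]
    return out


def shift_right_trick(x: torch.Tensor) -> torch.Tensor:
    """The padding-and-reshape steps used by shift_right."""
    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
    x_padded = torch.cat([x, zero_pad], dim=1)
    x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
    return x_padded[:-1].view_as(x)


x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(shift_right_loop(x).tolist())   # [[1, 2, 3], [0, 4, 5], [0, 0, 7]]
print(shift_right_trick(x).tolist())  # [[1, 2, 3], [0, 4, 5], [6, 0, 7]]

# The two agree on and above the diagonal
upper = torch.triu(torch.ones(x.shape[0], x.shape[1])).bool()
print(torch.equal(shift_right_loop(x)[upper], shift_right_trick(x)[upper]))  # True
```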
```python
def shift_right(x: torch.Tensor):
    # Concatenate a column of zeros
    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
    x_padded = torch.cat([x, zero_pad], dim=1)

    # Remove excess elements from the end, then reshape back to the shape of x.
    # Each padded row is one element longer than the original, so reading the
    # data back with the original row length shifts row i right by i columns.
    x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
    x = x_padded[:-1].view_as(x)

    return x
```

We override the Multi-Head Attention module, so we only need to write the `get_scores` method.
```python
class RelativeMultiHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        # The linear transformations don't need a bias, since we take care of it
        # when calculating scores. However, having a bias for value might make sense.
        super().__init__(heads, d_model, dropout_prob, bias=False)

        # Number of relative positions
        self.P = 2 ** 12

        # Relative positional embeddings for the key relative to the query.
        # We need 2P embeddings because the keys can be before or after the query.
        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)
        # Relative positional embedding bias for the key relative to the query
        self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)
        # Positional embedding for the query; it is independent of the query's position
        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
```

With absolute attention, using the bias-free query and key projections $\mathrm{lin}_q$ and $\mathrm{lin}_k$, the score between query $i$ and key $j$ expands into four terms:

$$A^{abs}_{i,j} = \mathrm{lin}_q(X^q_i + P_i)^\top \mathrm{lin}_k(X^k_j + P_j) = \underset{\color{lightgreen}{A}}{Q_i^\top K_j} + \underset{\color{lightgreen}{B}}{Q_i^\top U^K_j} + \underset{\color{lightgreen}{C}}{{U^Q_i}^\top K_j} + \underset{\color{lightgreen}{D}}{{U^Q_i}^\top U^K_j}$$
where $Q_i, K_j$ are linear transformations of the original embeddings $X^q_i, X^k_j$, and $U^Q_i, U^K_j$ are linear transformations of the absolute positional encodings $P_i, P_j$.
The authors reason that the attention to a given key should be the same regardless of the position of the query. Hence they replace $\underset{\color{lightgreen}{C}}{{U^Q_i}^\top K_j}$ with $\underset{\color{lightgreen}{C}}{\color{orange}{v^\top} K_j}$, where $\color{orange}{v}$ is a learned vector shared by all query positions.
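Because $v$ does not depend on the query position, the C term can be folded into the query before the dot product, $(Q_i + v)^\top K_j = Q_i^\top K_j + v^\top K_j$, which is how `get_scores` computes A + C with a single `einsum`. It also suggests why the query projection can drop its bias: a bias there would contribute exactly a term of the form $v^\top K_j$. The toy check below uses arbitrary sizes and `torch.nn.Linear` as a stand-in for the query projection; it is an illustration, not code from this file.

```python
import torch
from torch import nn

torch.manual_seed(0)
d = 4
x_q = torch.randn(3, d)   # 3 query embeddings
keys = torch.randn(5, d)  # 5 already-projected keys K_j

# A query projection with a bias, and a copy of its weights without one
lin_bias = nn.Linear(d, d, bias=True)
lin_no_bias = nn.Linear(d, d, bias=False)
with torch.no_grad():
    lin_no_bias.weight.copy_(lin_bias.weight)
v = lin_bias.bias.detach()  # plays the role of the learned vector v

q = lin_no_bias(x_q)                # Q_i
a_plus_c = (q + v) @ keys.T         # Q_i^T K_j + v^T K_j in one product
with_bias = lin_bias(x_q) @ keys.T  # scores from the biased projection
c = v @ keys.T                      # v^T K_j: one value per key

print(torch.allclose(a_plus_c, with_bias, atol=1e-5))                    # True
print(torch.allclose(a_plus_c - q @ keys.T, c.expand(3, 5), atol=1e-5))  # True
```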
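For the second and fourth terms, relative positional encodings are introduced. So $\underset{\color{lightgreen}{B}}{Q_i^\top U^K_j}$ is replaced with $\underset{\color{lightgreen}{B}}{Q_i^\top \color{orange}{R_{i - j}}}$ and $\underset{\color{lightgreen}{D}}{{U^Q_i}^\top U^K_j}$ with $\underset{\color{lightgreen}{D}}{\color{orange}{S_{i-j}}}$, where $\color{orange}{R}$ and $\color{orange}{S}$ are learned embeddings and biases indexed by the relative distance (`key_pos_embeddings` and `key_pos_bias` in the code). The relative attention score is therefore

$$A^{rel}_{i,j} = \underset{\color{lightgreen}{A}}{Q_i^\top K_j} + \underset{\color{lightgreen}{B}}{Q_i^\top \color{orange}{R_{i - j}}} + \underset{\color{lightgreen}{C}}{\color{orange}{v^\top} K_j} + \underset{\color{lightgreen}{D}}{\color{orange}{S_{i-j}}}$$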
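`get_scores` computes the B and D terms without building a query-by-key table of relative embeddings: it scores each query against a contiguous window of $R_k$ and $S_k$, then uses `shift_right` to line each row up with the keys. The sketch below checks that idea against a direct double loop. It is a toy under stated assumptions: the sizes are arbitrary, the keys are assumed to be `mem` cached positions followed by the current segment, the relative tables are indexed by `P + (key position - query position)`, and the window is assumed to be `[P - key_len, P + query_len)`.

```python
import torch

torch.manual_seed(0)

# Toy sizes (assumptions): n queries, m keys of which mem come from memory
n, m, heads, d_k, P = 3, 5, 2, 4, 16
mem = m - n

query = torch.randn(n, 1, heads, d_k)  # [query_len, batch, heads, d_k]
R = torch.randn(2 * P, heads, d_k)     # relative embeddings, index P + offset
S = torch.randn(2 * P, heads)          # relative biases, index P + offset


def shift_right(x: torch.Tensor) -> torch.Tensor:
    # Same padding-and-reshape trick; reshape behaves like view here
    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
    x_padded = torch.cat([x, zero_pad], dim=1)
    x_padded = x_padded.reshape(x.shape[1] + 1, x.shape[0], *x.shape[2:])
    return x_padded[:-1].reshape(x.shape)


# Vectorised: score every query against the window of offsets, then shift
b = torch.einsum('ibhd,jhd->ijbh', query, R[P - m:P + n])  # B'
d = S[P - m:P + n][None, :, None, :]                       # D'
bd = shift_right(b + d)[:, -m:]

# Direct: query i sits at absolute position mem + i, key j at position j
naive = torch.zeros(n, m, 1, heads)
for i in range(n):
    for j in range(m):
        offset = j - (mem + i)  # key position minus query position
        naive[i, j, 0] = (query[i, 0] * R[P + offset]).sum(-1) + S[P + offset]

print(torch.allclose(bd, naive, atol=1e-5))  # True
```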
`get_scores` computes the relative attention scores for all queries and keys at once. `query` has shape `[query_len, batch, heads, d_k]` and `key` has shape `[key_len, batch, heads, d_k]`, where the keys are the memory followed by the current segment.

```python
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        # R_k for relative offsets k in [-key_len, query_len); keys are memory
        # followed by the current segment, so they can lie before or after a query
        key_pos_emb = self.key_pos_embeddings[self.P - key.shape[0]:self.P + query.shape[0]]
        # S_k for the same offsets
        key_pos_bias = self.key_pos_bias[self.P - key.shape[0]:self.P + query.shape[0]]
        # v^T, broadcast over query positions and batch
        query_pos_bias = self.query_pos_bias[None, None, :, :]

        # (A + C)_{i,j} = Q_i^T K_j + v^T K_j
        ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)
        # B'_{i,k} = Q_i^T R_k
        b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
        # D'_{i,k} = S_k
        d = key_pos_bias[None, :, None, :]
        # Shift row i of (B' + D') right by i columns, so that after the slice below
        # (B + D)_{i,j} = (B' + D')_{i, query_len + j - i}
        bd = shift_right(b + d)
        # Remove the extra positions, keeping the last key_len columns
        bd = bd[:, -key.shape[0]:]

        # Return the sum A + B + C + D
        return ac + bd
```

A quick check of `shift_right`, printing each input followed by its shifted result:

```python
def _test_shift_right():
    # The 3x3 example from above
    x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    inspect(x)
    inspect(shift_right(x))

    # A 5x5 matrix with two trailing singleton dimensions, like a score tensor
    x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1)
    inspect(x[:, :, 0, 0])
    inspect(shift_right(x)[:, :, 0, 0])

    # A 3x5 matrix, with fewer rows than columns
    x = torch.arange(1, 6)[None, :, None, None].repeat(3, 1, 1, 1)
    inspect(x[:, :, 0, 0])
    inspect(shift_right(x)[:, :, 0, 0])


if __name__ == '__main__':
    _test_shift_right()
```