This is a PyTorch implementation of Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.
The Transformer has a limited attention span, equal to the length of the sequence trained in parallel, and each of these positions has a fixed positional encoding. Transformer-XL increases this attention span by letting each position also attend to precalculated past embeddings. For instance, if the context length is $l$, it keeps the embeddings of all layers for the previous batch of length $l$ and feeds them to the current step. With fixed positional encodings, these precalculated embeddings would have the same positions as the current context, so the model could not tell them apart. The authors therefore introduce relative positional encoding, where the positional encodings are introduced in the attention calculation.
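To make the position clash concrete, here is a small pure-Python sketch of the index bookkeeping only. It assumes a memory of length `mem_len` placed directly before a current segment of length `seq_len`; the helpers `absolute_positions` and `relative_distances` are hypothetical and not part of this codebase, and the sketch does not show how the attention module turns distances into scores.

```python
from typing import List


def absolute_positions(mem_len: int, seq_len: int) -> List[int]:
    """Positions a fixed (absolute) encoding assigns to [memory; current]
    when each segment was encoded on its own: both restart at 0."""
    return list(range(mem_len)) + list(range(seq_len))


def relative_distances(mem_len: int, seq_len: int) -> List[List[int]]:
    """dist[i][j] = position of query i minus position of key j, in a single
    coordinate system where the memory comes right before the current segment."""
    return [[(mem_len + i) - j for j in range(mem_len + seq_len)]
            for i in range(seq_len)]


# Memory key 0 and current key 3 both get position 0: indistinguishable.
print(absolute_positions(3, 3))        # [0, 1, 2, 0, 1, 2]
# Distances from the first current query are all different.
print(relative_distances(3, 3)[0])     # [3, 2, 1, 0, -1, -2]
```

Negative distances point at future keys, which a causal language model masks out anyway.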
The annotated implementation of relative multi-headed attention is in `relative_mha.py`.
Here’s the training code and a notebook for training a Transformer-XL model on the Tiny Shakespeare dataset.
```python
from typing import List, Optional

import torch
import torch.nn as nn

from labml_helpers.module import Module
from labml_nn.utils import clone_module_list
from .relative_mha import RelativeMultiHeadAttention
from ..feed_forward import FeedForward


class TransformerXLLayer(Module):
    def __init__(self, *,
                 d_model: int,
                 self_attn: RelativeMultiHeadAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        """
        * `d_model` is the token embedding size
        * `self_attn` is the self attention module
        * `feed_forward` is the feed forward module
        * `dropout_prob` is the probability of dropping out after self attention and FFN
        """
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def forward(self, *,
                x: torch.Tensor,
                mem: Optional[torch.Tensor],
                mask: torch.Tensor):
        """
        * `x` are the token level feature vectors of shape `[seq_len, batch_size, d_model]`
        * `mem` are the past token level feature vectors of shape `[mem_len, batch_size, d_model]`
        * `mask` is a matrix of shape `[seq_len, mem_len + seq_len, batch_size]`
          or `[seq_len, mem_len + seq_len, 1]`.
          `mask[i, j]` is true if the token at `i` can see the token at `j`.
        """
        # Normalize the vectors before doing self attention
        z = self.norm_self_attn(x)
        # If there is memory
        if mem is not None:
            # Normalize it with the same layer norm
            mem = self.norm_self_attn(mem)
            # Concatenate with `z` along the sequence dimension: [memory; current]
            m_z = torch.cat((mem, z), dim=0)
        # Otherwise, attend only to the current segment
        else:
            m_z = z
        # Attention: queries from the current segment, keys/values from memory + current
        self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
        # Add the attention results
        x = x + self.dropout(self_attn)

        # Normalize for feed-forward
        z = self.norm_ff(x)
        # Pass through the feed-forward network
        ff = self.feed_forward(z)
        # Add the feed-forward results back
        x = x + self.dropout(ff)

        return x


class TransformerXL(Module):
    def __init__(self, layer: TransformerXLLayer, n_layers: int):
        super().__init__()
        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], mask: torch.Tensor):
        """
        * `x` are the token embedding vectors of shape `[seq_len, batch_size, d_model]`
        * `mem` are the past token level feature vectors of shape `[mem_len, batch_size, d_model]` for each layer
        * `mask` is the masking matrix
        """
        # List to store token level feature vectors,
        # which will be the memories for the next sequential batch
        new_mem = []
        # Run through each transformer layer
        for i, layer in enumerate(self.layers):
            # Store this layer's input, detached from the graph, as its memory for the next batch
            new_mem.append(x.detach())
            # Memory for this layer, or `None` if there is none yet
            m = mem[i] if mem else None
            # Run through the transformer XL layer
            x = layer(x=x, mem=m, mask=mask)
        # Finally, normalize the vectors
        return self.norm(x), new_mem
```
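The layer expects `mask[i, j]` to be true when query `i` can see key `j`, with keys ordered as memory followed by the current segment (the order produced by `torch.cat((mem, z), dim=0)`). Below is a minimal pure-Python sketch of such a mask, assuming an autoregressive (causal) language model in which every memory position is visible; `transformer_xl_mask` is a hypothetical helper, not part of this codebase.

```python
from typing import List


def transformer_xl_mask(seq_len: int, mem_len: int) -> List[List[bool]]:
    """Causal mask with memory: mask[i][j] is True if query i of the current
    segment may attend to key j, where keys are ordered [memory; current]."""
    # Query i sits at key index mem_len + i, so it sees every memory key
    # and the current keys up to and including itself.
    return [[j <= mem_len + i for j in range(mem_len + seq_len)]
            for i in range(seq_len)]


for row in transformer_xl_mask(seq_len=3, mem_len=2):
    print("".join("1" if visible else "." for visible in row))
# 111..
# 1111.
# 11111
```

To get the `[seq_len, mem_len + seq_len, 1]` shape the layer accepts, a PyTorch equivalent could be `torch.tril(torch.ones(seq_len, mem_len + seq_len), diagonal=mem_len).bool().unsqueeze(-1)`.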
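`TransformerXL.forward` returns `new_mem`, one detached tensor per layer, to be passed back in as `mem` for the next sequential batch. One plausible way to manage it, sketched here in pure Python with strings standing in for per-position hidden states, is to append each new segment to the old memory and keep only the last `mem_len` positions. The helpers `merge_memory` and `run_segments` and the `mem_len` parameter are hypothetical illustrations, not this codebase's training loop.

```python
from typing import List, Optional


def merge_memory(old_mem: Optional[List[list]],
                 new_mem: List[list],
                 mem_len: int) -> List[list]:
    """Per layer, append the new segment's states to the old memory and
    keep only the most recent `mem_len` positions."""
    if mem_len == 0:
        # No memory; an empty list makes `mem[i] if mem else None` yield None
        return []
    if not old_mem:
        # First segment: nothing to concatenate with yet
        return [states[-mem_len:] for states in new_mem]
    return [(old + new)[-mem_len:] for old, new in zip(old_mem, new_mem)]


def run_segments(num_segments: int, seq_len: int,
                 n_layers: int, mem_len: int) -> List[list]:
    """Simulate feeding consecutive segments and carrying memory between them."""
    mem: List[list] = []
    for s in range(num_segments):
        # In a real loop: output, new_mem = model(x, mem, mask)
        # Here each layer's input states are labelled by layer and absolute time step.
        new_mem = [[f"L{layer}:t{s * seq_len + t}" for t in range(seq_len)]
                   for layer in range(n_layers)]
        mem = merge_memory(mem, new_mem, mem_len)
    return mem


mem = run_segments(num_segments=3, seq_len=3, n_layers=2, mem_len=4)
print(mem[0])  # ['L0:t5', 'L0:t6', 'L0:t7', 'L0:t8']
```

With tensors, the per-layer merge would be `torch.cat((old, new), dim=0)[-mem_len:]`. The simplest setup described at the top of this page keeps exactly one previous segment, i.e. `mem_len = seq_len`.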