This is an implementation of Compressive Transformers for Long-Range Sequence Modelling in PyTorch.
This is an extension of Transformer XL where past memories are compressed to give a longer attention range. That is, the furthest $n_{cm} c$ memories are compressed into $n_{cm}$ memories, where $c$ is the compression rate.
The compression operation is defined as $f_c: \mathbb{R}^{nc \times d} \rightarrow \mathbb{R}^{n \times d}$. The paper introduces multiple choices for $f_c$ and we have only implemented 1D convolution, which seems to give the best results. Each layer has a separate compression operation $f_c^{(i)}$, where $i$ is the layer number.
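For instance, a 1D convolution whose kernel size and stride both equal $c$ maps $nc$ memory vectors to $n$ compressed vectors. A minimal sketch (sizes here are purely illustrative, and this uses a raw `nn.Conv1d` rather than the wrapper module defined below):

```python
import torch
from torch import nn

# Compression rate c = 4: 128 memory vectors are compressed into 32.
c, d_model, n_c_mem = 4, 128, 128
conv = nn.Conv1d(d_model, d_model, kernel_size=c, stride=c)

mem = torch.randn(1, d_model, n_c_mem)  # nn.Conv1d expects [batch, features, sequence]
print(conv(mem).shape)                  # torch.Size([1, 128, 32])
```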
Since training compression with BPTT requires maintaining a very large computational graph (many time steps), the paper proposes an auto-encoding loss and an attention reconstruction loss. The auto-encoding loss decodes the original memories from the compressed memories and calculates the loss. Attention reconstruction loss computes the multi-headed attention results on the compressed memory and on uncompressed memory and gets a mean squared error between them. We have implemented the latter here since it gives better results.
This implementation uses pre-layer normalization while the paper uses post-layer normalization. Pre-layer norm does the layer norm before FFN and self-attention, and the pass-through in the residual connection is not normalized. This is supposed to be more stable in standard transformer setups.
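As a rough sketch of the difference (the `sublayer` here is just a stand-in for self-attention or the FFN, not part of this implementation):

```python
import torch
from torch import nn

d_model = 64
norm = nn.LayerNorm([d_model])
sublayer = nn.Linear(d_model, d_model)  # stand-in for self-attention or the FFN
dropout = nn.Dropout(0.1)

x = torch.randn(10, 2, d_model)
# Pre-layer norm (this implementation): normalize the sub-layer input,
# leave the residual pass-through un-normalized.
x_pre = x + dropout(sublayer(norm(x)))
# Post-layer norm (as in the paper): normalize after the residual addition.
x_post = norm(x + dropout(sublayer(x)))
```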
Here are the training code and a notebook for training a compressive transformer model on the Tiny Shakespeare dataset.
```python
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn

from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
from labml_nn.utils import clone_module_list
```

The 1D convolution compression $f_c$ is a simple wrapper around `nn.Conv1d` with some tensor dimension permutations.
```python
class Conv1dCompression(nn.Module):
    def __init__(self, compression_rate: int, d_model: int):
        """
        * `compression_rate` is the compression rate $c$
        * `d_model` is the embedding size
        """
        super().__init__()
        self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)

    def forward(self, mem: torch.Tensor):
        """
        `mem` has shape `[seq_len, batch, d_model]`
        """
        # Permute the dimensions of `mem` so that we can run it through the convolution layer.
        # The convolution layer accepts in the form `[batch, features, sequence]`
        mem = mem.permute(1, 2, 0)
        # Get compressed memory by running it through the convolution layer
        c_mem = self.conv(mem)
        # Permute back to form `[seq_len, batch, d_model]`
        return c_mem.permute(2, 0, 1)
```
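A quick shape check (sizes chosen only for illustration):

```python
import torch

compress = Conv1dCompression(compression_rate=4, d_model=128)
mem = torch.randn(32, 2, 128)   # [mem_len, batch, d_model]
c_mem = compress(mem)
print(c_mem.shape)              # torch.Size([8, 2, 128]); mem_len is reduced by the compression rate
```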
This is the implementation of a single compressive transformer layer.

```python
class CompressiveTransformerLayer(nn.Module):
    def __init__(self, *,
                 d_model: int,
                 self_attn: RelativeMultiHeadAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float,
                 compress: Conv1dCompression):
        """
        * `d_model` is the token embedding size
        * `self_attn` is the self attention module
        * `feed_forward` is the feed forward module
        * `dropout_prob` is the probability of dropping out after self attention and FFN
        * `compress` is the compression function $f_c$
        """
        super().__init__()
        self.compress = compress
        self.size = d_model
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def concat_memory(self, z: torch.Tensor, mem: Optional[torch.Tensor], c_mem: Optional[torch.Tensor]):
        """
        Concatenate the normalized token embeddings with memory and compressed memory.

        * `z` is the layer normalized token embeddings.
        * `mem` and `c_mem` are memory and compressed memory (not normalized).
        """
        # If there is no memory just return the token embeddings
        if mem is None:
            return z

        # If there is compressed memory, concatenate it with the memory
        if c_mem is not None:
            mem = torch.cat((c_mem, mem), dim=0)

        # Run the memory through the normalization layer
        mem = self.norm_self_attn(mem)

        # Concatenate normalized memory and normalized token embeddings
        return torch.cat((mem, z), dim=0)

    def forward(self, *,
                x: torch.Tensor,
                mem: Optional[torch.Tensor],
                c_mem: Optional[torch.Tensor],
                mask: torch.Tensor):
        """
        * `x` is a tensor of token level feature vectors of shape `[seq_len, batch_size, d_model]`
        * `mem` is a tensor of the past token level feature vectors (memory) of shape `[mem_len, batch_size, d_model]`
        * `c_mem` is a tensor of the compressed memory of shape `[c_mem_len, batch_size, d_model]`
        * `mask` is a matrix of shape `[seq_len, c_mem_len + mem_len + seq_len, batch_size]`
          or `[seq_len, c_mem_len + mem_len + seq_len, 1]`.
          `mask[i, j]` is true if token at `i` can see token at `j`.
        """
        # Normalize the vectors before doing self attention
        z = self.norm_self_attn(x)
        # Normalize and concatenate memory and compressed memory
        m_z = self.concat_memory(z, mem, c_mem)
        # Attention
        self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
        # Add the attention results
        x = x + self.dropout(self_attn)

        # Normalize for feed-forward
        z = self.norm_ff(x)
        # Pass through the feed-forward network
        ff = self.feed_forward(z)
        # Add the feed-forward results back
        x = x + self.dropout(ff)

        return x
```
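A minimal usage sketch for a single layer, assuming the `RelativeMultiHeadAttention` and `FeedForward` constructors imported above take `(heads, d_model, dropout)`-style arguments as in `labml_nn`; all sizes are illustrative:

```python
import torch

heads, d_model, d_ff = 4, 128, 256
seq_len, mem_len, c_mem_len, batch = 16, 32, 8, 2

layer = CompressiveTransformerLayer(
    d_model=d_model,
    self_attn=RelativeMultiHeadAttention(heads, d_model, dropout_prob=0.1),
    feed_forward=FeedForward(d_model, d_ff, 0.1),
    dropout_prob=0.1,
    compress=Conv1dCompression(compression_rate=4, d_model=d_model))

x = torch.randn(seq_len, batch, d_model)
mem = torch.randn(mem_len, batch, d_model)
c_mem = torch.randn(c_mem_len, batch, d_model)

# Tokens can attend to all of the (compressed) memory and causally to the current sequence;
# this matches the `[seq_len, c_mem_len + mem_len + seq_len, 1]` mask shape described above.
mem_mask = torch.ones(seq_len, c_mem_len + mem_len, 1).bool()
seq_mask = torch.tril(torch.ones(seq_len, seq_len)).bool().unsqueeze(-1)
mask = torch.cat((mem_mask, seq_mask), dim=1)

out = layer(x=x, mem=mem, c_mem=c_mem, mask=mask)  # [seq_len, batch, d_model]
```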
The `CompressiveTransformer` model stacks multiple of these layers and applies a final layer normalization.

```python
class CompressiveTransformer(nn.Module):
    def __init__(self, layer: CompressiveTransformerLayer, n_layers: int):
        super().__init__()
        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], c_mem: List[torch.Tensor], mask: torch.Tensor):
        """
        * `x` is a tensor of the token embedding vectors of shape `[seq_len, batch_size, d_model]`
        * `mem` is a list of tensors of the past token level feature vectors of shape `[mem_len, batch_size, d_model]` for each layer
        * `c_mem` is a list of tensors of the compressed memory of shape `[c_mem_len, batch_size, d_model]` for each layer
        * `mask` is the masking matrix
        """
        # List to store token level feature vectors,
        # which will become the memories for the next sequential batch.
        new_mem = []
        # Run through each transformer layer
        for i, layer in enumerate(self.layers):
            # Add to the list of feature vectors
            new_mem.append(x.detach())
            # Memory
            m = mem[i] if mem else None
            # Compressed memory
            cm = c_mem[i] if c_mem else None
            # Run through the transformer XL layer
            x = layer(x=x, mem=m, c_mem=cm, mask=mask)
        # Finally, normalize the vectors
        return self.norm(x), new_mem
```
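The returned `new_mem` list becomes the memory for the next sequential batch: roughly, the oldest memories that overflow the memory size are compressed with each layer's $f_c^{(i)}$ and appended to the compressed memory. The following is only a simplified sketch of that bookkeeping (the actual logic, including truncating the compressed memory, lives in the linked training code); it assumes memories already exist for every layer and that the overflow length is a multiple of the compression rate. The memories that get compressed in a step are also what the attention reconstruction loss below is computed on in the training code.

```python
import torch
from typing import List

def update_memories(mem: List[torch.Tensor], new_mem: List[torch.Tensor],
                    c_mem: List[torch.Tensor], layers, mem_len: int = 32):
    """Hypothetical helper, for illustration only."""
    next_mem, next_c_mem = [], []
    for m, n, cm, layer in zip(mem, new_mem, c_mem, layers):
        # Append the new token-level feature vectors to this layer's memory
        m = torch.cat((m, n), dim=0)
        # Memories beyond `mem_len` are the oldest; compress them with this layer's f_c
        old, m = m[:-mem_len], m[-mem_len:]
        cm = torch.cat((cm, layer.compress(old)), dim=0)
        next_mem.append(m)
        next_c_mem.append(cm)
    return next_mem, next_c_mem
```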
Attention reconstruction loss recreates the self-attention output with uncompressed memory and with compressed memory and calculates the mean squared error between the two. It does this without positional encoding.

When calculating and training the compression function $f_c$ with the attention reconstruction loss, all parameters but $f_c$ are frozen. This includes the key/value projections and the bias/scaling after normalization.

Since this loss can be computed independently of the cross-entropy loss of the model, you can have a separate optimizer that only updates $f_c$. However, we use the same optimizer to update $f_c$, so when calculating the attention reconstruction loss, we detach all other parameters except $f_c$ from the gradient computation.
```python
class AttentionReconstructionLoss:
    def __init__(self, layers: nn.ModuleList):
        """`layers` is the list of Compressive Transformer layers"""
        self.layers = layers
        self.loss_func = nn.MSELoss()

    def prepare_for_attn(self, pmha: PrepareForMultiHeadAttention, x: torch.Tensor):
        """
        This is a reimplementation of `PrepareForMultiHeadAttention` where the projections
        are done with the parameters detached from gradient computation.

        * `pmha` is the `PrepareForMultiHeadAttention` module
        * `x` is the tensor with the token embeddings
        """
        # Shape of the input except the embedding dimension; `[seq_len, batch_size]`
        head_shape = x.shape[:-1]

        # Detach the projection weights and bias
        weight = pmha.linear.weight.detach()
        bias = pmha.linear.bias.detach() if pmha.linear.bias is not None else None
        # Linear transform
        x = F.linear(x, weight, bias)

        # Split the last dimension into heads
        x = x.view(*head_shape, pmha.heads, pmha.d_k)

        # Output has shape `[seq_len, batch_size, heads, d_k]`
        return x

    def attn(self, layer: RelativeMultiHeadAttention, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
        """
        This is a reimplementation of multi-head attention which calls `prepare_for_attn`
        instead of `PrepareForMultiHeadAttention` to detach the projection parameters.
        """
        # Calculate query, key and value projections
        query = self.prepare_for_attn(layer.query, query)
        key = self.prepare_for_attn(layer.key, key)
        value = self.prepare_for_attn(layer.value, value)

        # Compute attention scores $Q K^\top$;
        # this gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`
        scores = torch.einsum('ibhd,jbhd->ijbh', query, key)
        # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
        scores *= layer.scale
        # Softmax attention along the key sequence dimension
        attn = layer.softmax(scores)
        # Multiply by the values
        return torch.einsum("ijbh,jbhd->ibhd", attn, value)

    def norm(self, ln: nn.LayerNorm, x: torch.Tensor):
        """Perform layer normalization with the shift and scale parameters detached."""
        # Detach the shift (`bias`) and scaling (`weight`) parameters
        weight = ln.weight.detach() if ln.weight is not None else None
        bias = ln.bias.detach() if ln.bias is not None else None
        # Layer normalization
        return F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)

    def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):
        """This calculates the loss for a single layer."""
        # Detach the token embeddings and the memory
        h = h.detach()
        mem = mem.detach()

        # Compress the memory with $f_c^{(i)}$; the parameters of $f_c^{(i)}$ are the
        # only parameters not detached from gradient computation
        c_mem = layer.compress(mem)

        # Normalize the embeddings and the memories
        h = self.norm(layer.norm_self_attn, h)
        mem = self.norm(layer.norm_self_attn, mem)
        c_mem = self.norm(layer.norm_self_attn, c_mem)

        # Calculate the attention with uncompressed memory
        attn_mem = self.attn(layer.self_attn, h, mem, mem)
        # Calculate the attention with compressed memory
        attn_cmem = self.attn(layer.self_attn, h, c_mem, c_mem)

        # Calculate the mean squared error
        return self.loss_func(attn_cmem, attn_mem)

    def __call__(self, h: List[torch.Tensor], mem: List[torch.Tensor]):
        # Calculate the losses for each layer
        losses = [self.calc_loss(layer, h[n], mem[n]) for n, layer in enumerate(self.layers)]
        # Sum of the losses
        return sum(losses)
```
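A short sketch of computing the loss (tensor shapes are illustrative, `layer` is the one constructed in the earlier sketch, and the combination with the cross-entropy loss happens in the linked training code):

```python
import torch

# Stack two copies of `layer` (from the earlier sketch) into a model
model = CompressiveTransformer(layer, n_layers=2)
ar_loss_func = AttentionReconstructionLoss(model.layers)

# Illustrative tensors: `h[i]` are the token-level inputs to layer `i`,
# `mem[i]` are the memories to be compressed for that layer
h = [torch.randn(16, 2, 128) for _ in range(2)]
mem = [torch.randn(32, 2, 128) for _ in range(2)]

ar_loss = ar_loss_func(h, mem)
# During training this is added to the language-model cross-entropy loss,
# e.g. `loss = lm_loss + ar_loss`, and optimized with the same optimizer.
```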