This is an implementation of Compressive Transformers for Long-Range Sequence Modelling in PyTorch.
This is an extension of Transformer XL where past memories are compressed to give a longer attention range. That is, the furthest $n_{cm} c$ memories are compressed into $n_{cm}$ memories, where $c$ is the compression rate.
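To make the sizes concrete, here is a minimal pure-Python sketch of the memory bookkeeping. It is an illustration under stated assumptions, not the training code: `update_memories`, `mem_len` ($n_m$ slots), `c_mem_len` (compressed slots) and `compression_rate` ($c$) are hypothetical names, each float stands in for one token's feature vector, and averaging each group of $c$ memories stands in for $f_c$.

```python
from typing import List, Tuple


def update_memories(mem: List[float], c_mem: List[float], new: List[float],
                    mem_len: int, c_mem_len: int,
                    compression_rate: int) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: append new memories and compress the overflow.

    Each float stands in for one token's feature vector, and the mean of each
    group of `compression_rate` consecutive memories stands in for f_c.
    Assumes mem_len >= compression_rate.
    """
    mem = mem + new
    # How many of the oldest memories no longer fit in the mem_len slots
    overflow = max(len(mem) - mem_len, 0)
    # Round up to whole groups of c, so that n_cm * c memories are compressed
    n_cm = -(-overflow // compression_rate)
    n_old = n_cm * compression_rate
    old, mem = mem[:n_old], mem[n_old:]
    # Each group of c old memories becomes one compressed memory
    compressed = [sum(old[k:k + compression_rate]) / compression_rate
                  for k in range(0, n_old, compression_rate)]
    # Keep only the most recent c_mem_len compressed memories
    c_mem = (c_mem + compressed)[-c_mem_len:] if c_mem_len > 0 else []
    return mem, c_mem


mem, c_mem = update_memories([], [], [1.0, 2.0, 3.0, 4.0], mem_len=4, c_mem_len=2, compression_rate=2)
mem, c_mem = update_memories(mem, c_mem, [5.0, 6.0], mem_len=4, c_mem_len=2, compression_rate=2)
print(mem, c_mem)  # [3.0, 4.0, 5.0, 6.0] [1.5]
```

With these sizes the model stores only 4 + 2 = 6 memory slots, yet a query can reach back over up to `mem_len + c_mem_len * compression_rate = 4 + 2 * 2 = 8` past tokens.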
The compression operation is defined as $f_c: \mathbb{R}^{nc \times d} \rightarrow \mathbb{R}^{n \times d}$. The paper introduces multiple choices for $f_c$, and we have only implemented 1D convolution, which seems to give the best results. Each layer has a separate compression operation $f_c^{(i)}$, where $i$ is the layer number.
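The shape contract of a 1D-convolution $f_c$ can be checked in isolation. This sketch assumes toy sizes ($c = 3$, $d = 8$) chosen only for illustration; setting both the kernel size and the stride to $c$ makes each compressed memory summarize one non-overlapping group of $c$ memories, so $n_{cm} c$ memories map to $n_{cm}$.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Assumed toy sizes, not taken from the original configuration
compression_rate, d_model = 3, 8
n_cm, batch = 4, 2

# Kernel size == stride == c: one output per non-overlapping group of c inputs
f_c = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)

# n_cm * c old memories in [seq_len, batch, d_model] layout
old_mem = torch.randn(n_cm * compression_rate, batch, d_model)

# nn.Conv1d expects [batch, channels, length], so move the sequence axis last,
# convolve, then move it back to the front
c_mem = f_c(old_mem.permute(1, 2, 0)).permute(2, 0, 1)

print(c_mem.shape)  # torch.Size([4, 2, 8])
```

Because the windows do not overlap, the first compressed memory depends only on the first $c$ memories; changing later memories leaves it untouched.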
Since training compression with BPTT requires maintaining a very large computational graph (many time steps), the paper proposes an auto-encoding loss and an attention reconstruction loss. The auto-encoding loss decodes the original memories from the compressed memories and computes a reconstruction loss against them. The attention reconstruction loss computes the multi-headed attention results on the compressed memory and on the uncompressed memory and takes the mean squared error between them. We have implemented the latter here since it gives better results.
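A minimal single-head sketch of the attention reconstruction loss follows. It assumes toy sizes and randomly initialized projection matrices `w_q`, `w_k`, `w_v` and a helper `attend` (all hypothetical; the actual layer uses multi-head attention), and it uses no positional encoding. The projections are plain tensors without `requires_grad`, so only the convolution $f_c$ receives gradients from this loss.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
seq_len, mem_len, d_model, c = 4, 6, 8, 2

# Hypothetical single-head attention: softmax(Q K^T / sqrt(d)) V,
# with Q = h W_q, K = m W_k, V = m W_v and no positional encoding
w_q, w_k, w_v = (torch.randn(d_model, d_model) for _ in range(3))


def attend(h: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
    q, k, v = h @ w_q, m @ w_k, m @ w_v
    weights = torch.softmax(q @ k.T / d_model ** 0.5, dim=-1)
    return weights @ v


f_c = nn.Conv1d(d_model, d_model, kernel_size=c, stride=c)

h = torch.randn(seq_len, d_model)    # current segment (queries)
mem = torch.randn(mem_len, d_model)  # memories about to be compressed
# Conv1d wants [batch, channels, length]; treat the memory as a batch of one
c_mem = f_c(mem.T.unsqueeze(0)).squeeze(0).T  # [mem_len // c, d_model]

# Attention over compressed memory should reproduce attention over raw memory
loss = F.mse_loss(attend(h, c_mem), attend(h, mem))
loss.backward()
```

Since the target `attend(h, mem)` does not depend on $f_c$, minimizing this loss only pushes the compressed memories toward preserving what the queries actually read from the original memories.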
This implementation uses pre-layer normalization, while the paper uses post-layer normalization. Pre-layer norm applies layer normalization before the FFN and self-attention, and the pass-through in the residual connection is not normalized. This is expected to be more stable in standard transformer setups.
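The difference between the two orderings fits in two functions. This is a sketch with a hypothetical `sublayer` argument standing in for either self-attention or the FFN; it shows that with pre-layer norm the residual path carries `x` through unchanged, whereas post-layer norm normalizes the sum.

```python
import torch
from torch import nn

d_model = 8
norm = nn.LayerNorm(d_model)


def post_ln_block(x: torch.Tensor, sublayer) -> torch.Tensor:
    # Post-LN (as in the paper): normalize after the residual addition
    return norm(x + sublayer(x))


def pre_ln_block(x: torch.Tensor, sublayer) -> torch.Tensor:
    # Pre-LN (as in this implementation): normalize only the sublayer input;
    # the residual pass-through is not normalized
    return x + sublayer(norm(x))


def zero_sublayer(t: torch.Tensor) -> torch.Tensor:
    # A sublayer that contributes nothing, to isolate the residual path
    return torch.zeros_like(t)


x = torch.randn(5, 2, d_model) * 10 + 3
```

With a sublayer that outputs zeros, the pre-LN block returns `x` exactly, while the post-LN block returns a normalized tensor with (approximately) zero mean over the feature dimension.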
Here are the training code and a notebook for training a compressive transformer model on the Tiny Shakespeare dataset.
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn

from labml_helpers.module import Module, TypedModuleList
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
from labml_nn.utils import clone_module_list


class Conv1dCompression(Module):
    """
    1D convolution compression $f_c$

    This is a simple wrapper around nn.Conv1d with some tensor dimension permutations.
    """

    def __init__(self, compression_rate: int, d_model: int):
        """
        * compression_rate is the compression rate $c$
        * d_model is the embedding size
        """
        super().__init__()
        self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)

    def forward(self, mem: torch.Tensor):
        """
        mem has shape [seq_len, batch, d_model]
        """
        # Permute the dimensions of mem so that we can run it through the convolution layer.
        # The convolution layer accepts input in the form [batch, features, sequence]
        mem = mem.permute(1, 2, 0)
        # Get compressed memory by running it through the convolution layer
        c_mem = self.conv(mem)
        # Permute back to form [seq_len, batch, d_model]
        return c_mem.permute(2, 0, 1)


# This is the implementation of a single compressive transformer layer
class CompressiveTransformerLayer(Module):
    def __init__(self, *,
                 d_model: int,
                 self_attn: RelativeMultiHeadAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float,
                 compress: Conv1dCompression):
        """
        * d_model is the token embedding size
        * self_attn is the self attention module
        * feed_forward is the feed forward module
        * dropout_prob is the probability of dropping out after self attention and FFN
        * compress is the compression function $f_c$
        """
        super().__init__()
        self.compress = compress
        self.size = d_model
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def concat_memory(self, z: torch.Tensor, mem: Optional[torch.Tensor], c_mem: Optional[torch.Tensor]):
        """
        Concatenate the normalized token embeddings with memory and compressed memory.

        * z is the layer-normalized token embeddings
        * mem and c_mem are the memory and compressed memory (not normalized)
        """
        # If there is no memory, just return the token embeddings
        if mem is None:
            return z
        # If there is compressed memory, concatenate it with the memory
        if c_mem is not None:
            mem = torch.cat((c_mem, mem), dim=0)
        # Run the memory through the normalization layer
        mem = self.norm_self_attn(mem)
        # Concatenate the normalized memory and the normalized token embeddings
        return torch.cat((mem, z), dim=0)

    def forward(self, *,
                x: torch.Tensor,
                mem: Optional[torch.Tensor],
                c_mem: Optional[torch.Tensor],
                mask: torch.Tensor):
        """
        * x is a tensor of token level feature vectors of shape [seq_len, batch_size, d_model]
        * mem is a tensor of the past token level feature vectors (memory) of shape [mem_len, batch_size, d_model]
        * c_mem is a tensor of the compressed memory of shape [c_mem_len, batch_size, d_model]
        * mask is a matrix of shape [seq_len, c_mem_len + mem_len + seq_len, batch_size]
          or [seq_len, c_mem_len + mem_len + seq_len, 1].
          mask[i, j] is true if the token at i can see the token at j.
        """
        # Normalize the vectors before doing self attention
        z = self.norm_self_attn(x)
        # Normalize and concatenate memory and compressed memory
        m_z = self.concat_memory(z, mem, c_mem)
        # Attention
        self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
        # Add the attention results
        x = x + self.dropout(self_attn)

        # Normalize for feed-forward
        z = self.norm_ff(x)
        # Pass through the feed-forward network
        ff = self.feed_forward(z)
        # Add the feed-forward results back
        x = x + self.dropout(ff)

        return x


class CompressiveTransformer(Module):
    """
    Compressive Transformer model. This consists of multiple compressive transformer layers.
    """

    def __init__(self, layer: CompressiveTransformerLayer, n_layers: int):
        super().__init__()
        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], c_mem: List[torch.Tensor], mask: torch.Tensor):
        """
        * x is a tensor of the token embedding vectors of shape [seq_len, batch_size, d_model]
        * mem is a list of tensors of the past token level feature vectors of shape
          [mem_len, batch_size, d_model], one for each layer
        * c_mem is a list of tensors of the compressed memory of shape
          [c_mem_len, batch_size, d_model], one for each layer
        * mask is the masking matrix
        """
        # List to store token level feature vectors, which will become the memories for the next sequential batch
        new_mem = []
        # Run through each transformer layer
        for i, layer in enumerate(self.layers):
            # Add to the list of feature vectors
            new_mem.append(x.detach())
            # Memory
            m = mem[i] if mem else None
            # Compressed memory
            cm = c_mem[i] if c_mem else None
            # Run through the compressive transformer layer
            x = layer(x=x, mem=m, c_mem=cm, mask=mask)
        # Finally, normalize the vectors
        return self.norm(x), new_mem


Attention reconstruction loss recreates the self-attention output with uncompressed memory and with compressed memory and calculates the mean squared error between the two. It does this without positional encoding.

When calculating and training the compression function $f_c$ with attention reconstruction loss, all parameters but $f_c$ are frozen. This includes the key/value projections and the bias/scaling after normalization.

Since this loss can be computed independently of the cross-entropy loss of the model, you can have a separate optimizer that only updates $f_c$. However, we use the same optimizer to update $f_c$, so when calculating the attention reconstruction loss, we detach all parameters except $f_c$ from the gradient computation.
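The detaching trick can be checked on its own. In this sketch, `proj` is a hypothetical `nn.Linear` standing in for a key projection, and `f_c` is a toy convolution; calling `F.linear` with `weight.detach()` and `bias.detach()` gives the same forward result as `proj(...)`, but no gradient can reach the projection parameters.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
d_model, c = 8, 2
# Hypothetical stand-ins: a key projection and the compression f_c
proj = nn.Linear(d_model, d_model)
f_c = nn.Conv1d(d_model, d_model, kernel_size=c, stride=c)

mem = torch.randn(6, d_model)                 # [mem_len, d_model]
c_mem = f_c(mem.T.unsqueeze(0)).squeeze(0).T  # [mem_len // c, d_model]

# Same computation as proj(c_mem), but with detached parameters, so the
# backward pass never reaches proj.weight or proj.bias
keys = F.linear(c_mem, proj.weight.detach(), proj.bias.detach())
keys.pow(2).mean().backward()

print(proj.weight.grad is None, f_c.weight.grad is not None)  # True True
```

With a single shared optimizer, any parameter left attached would also be nudged by this auxiliary loss; detaching is what keeps its effect confined to $f_c$. The separate-optimizer alternative would instead pass only `f_c.parameters()` to its own optimizer.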
class AttentionReconstructionLoss:
    def __init__(self, layers: TypedModuleList[CompressiveTransformerLayer]):
        """
        * layers is the list of Compressive Transformer layers
        """
        self.layers = layers
        self.loss_func = nn.MSELoss()

    def prepare_for_attn(self, pmha: PrepareForMultiHeadAttention, x: torch.Tensor):
        """
        This is a reimplementation of PrepareForMultiHeadAttention where the projections
        are done with the parameters detached from gradient computation.

        * pmha is the PrepareForMultiHeadAttention module
        * x is the tensor with the token embeddings
        """
        # Shape of the input except the embedding dimension; [seq_len, batch_size]
        head_shape = x.shape[:-1]

        # Detach projection weights and bias
        weight = pmha.linear.weight.detach()
        bias = pmha.linear.bias.detach() if pmha.linear.bias is not None else None
        # Linear transform
        x = F.linear(x, weight, bias)

        # Split the last dimension into heads
        x = x.view(*head_shape, pmha.heads, pmha.d_k)

        # Output has shape [seq_len, batch_size, heads, d_k] or [batch_size, heads, d_k]
        return x

    def attn(self, layer: RelativeMultiHeadAttention, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
        """
        This is a reimplementation of multi-head attention which calls prepare_for_attn
        instead of PrepareForMultiHeadAttention to detach the projection parameters.
        """
        # Calculate query, key and value projections
        query = self.prepare_for_attn(layer.query, query)
        key = self.prepare_for_attn(layer.key, key)
        value = self.prepare_for_attn(layer.value, value)

        # Compute attention scores $Q K^\top$.
        # This gives a tensor of shape [query_len, key_len, batch_size, heads].
        scores = torch.einsum('ibhd,jbhd->ijbh', query, key)

        # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
        scores *= layer.scale

        # Softmax attention along the key sequence dimension
        # $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
        attn = layer.softmax(scores)

        # Multiply by values
        # $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$
        return torch.einsum("ijbh,jbhd->ibhd", attn, value)

    def norm(self, ln: nn.LayerNorm, x: torch.Tensor):
        """
        Perform layer normalization with the shift and scale parameters detached.
        """
        # Detach the shift (bias) and scaling (weight) parameters
        weight = ln.weight.detach() if ln.weight is not None else None
        bias = ln.bias.detach() if ln.bias is not None else None

        # Layer normalization
        return F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)

    def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):
        """
        This calculates the loss for a layer
        """
        # Detach the token embeddings and memory
        h = h.detach()
        mem = mem.detach()

        # Compress the memory with $f_c^{(i)}$.
        # The parameters of $f_c^{(i)}$ are the only parameters not detached from gradient computation.
        c_mem = layer.compress(mem)

        # Normalize the embeddings and memories
        h = self.norm(layer.norm_self_attn, h)
        mem = self.norm(layer.norm_self_attn, mem)
        c_mem = self.norm(layer.norm_self_attn, c_mem)

        # Calculate the attention with uncompressed memory
        attn_mem = self.attn(layer.self_attn, h, mem, mem)
        # Calculate the attention with compressed memory
        attn_cmem = self.attn(layer.self_attn, h, c_mem, c_mem)

        # Calculate the mean squared error
        return self.loss_func(attn_cmem, attn_mem)

    def __call__(self, h: List[torch.Tensor], mem: List[torch.Tensor]):
        # Calculate the losses for each layer
        losses = [self.calc_loss(layer, h[n], mem[n]) for n, layer in enumerate(self.layers)]
        # Sum of the losses
        return sum(losses)
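The layer's `mask` argument follows the convention stated in its docstring: shape `[seq_len, c_mem_len + mem_len + seq_len, 1]`, with `mask[i, j]` true when token `i` can see position `j`, and keys ordered as compressed memory, memory, then the current segment (the order produced by `concat_memory`). The following is a minimal sketch of one way to build such a mask, assuming every query may see all memories and attention within the segment is causal; `make_mask` is a hypothetical helper, not part of the original code.

```python
import torch


def make_mask(seq_len: int, mem_len: int, c_mem_len: int) -> torch.Tensor:
    """Hypothetical helper returning a boolean mask of shape
    [seq_len, c_mem_len + mem_len + seq_len, 1]; mask[i, j] is True when
    token i may attend to position j."""
    n_mem = c_mem_len + mem_len
    # Every query may see all compressed memories and memories,
    # which precede the current segment
    mem_part = torch.ones(seq_len, n_mem, dtype=torch.bool)
    # Within the segment, token i may see tokens 0..i (causal)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Trailing singleton dimension broadcasts over the batch
    return torch.cat((mem_part, causal), dim=1).unsqueeze(-1)


mask = make_mask(seq_len=3, mem_len=2, c_mem_len=1)
```

The trailing dimension of size 1 corresponds to the `[seq_len, c_mem_len + mem_len + seq_len, 1]` variant, which is shared across the batch.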