Compressive Transformer
This is an implementation of Compressive Transformers for Long-Range Sequence Modelling in PyTorch.
This is an extension of Transformer XL where past memories
are compressed to give a longer attention range.
That is, the oldest $n_{cm} c$ memories are compressed into
$n_{cm}$ memories, where $c$ is the compression rate.
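As a concrete sketch of this bookkeeping (the function name and shapes are illustrative, not from the paper): when the memory grows past the uncompressed window, the oldest steps are split off in whole groups of $c$ so they can be compressed $c$-to-one.

```python
import torch

def split_memories(mem: torch.Tensor, keep_len: int, c: int):
    """Split memories of shape [seq_len, batch, d_model] into the oldest
    steps (to be compressed c-to-one) and the most recent `keep_len`-plus
    steps (kept uncompressed).

    Only whole groups of `c` steps are sent for compression, so the
    oldest part's length is a multiple of the compression rate.
    """
    old_len = mem.shape[0] - keep_len
    old_len -= old_len % c  # compress only whole groups of c steps
    return mem[:old_len], mem[old_len:]

mem = torch.randn(10, 2, 16)              # 10 time steps, batch 2, d_model 16
old, recent = split_memories(mem, keep_len=4, c=4)
# old: 4 steps (largest multiple of c that fits), recent: the remaining 6
```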
Compression operation
The compression operation is defined as
$$f_c: \mathbb{R}^{nc \times d} \rightarrow \mathbb{R}^{n \times d}$$
The paper introduces multiple choices for $f_c$; we have implemented only
1D convolution, which seems to give the best results.
Each layer has a separate compression operation $f_c^{(i)}$, where
$i$ is the layer number.
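A 1D-convolution compression can be sketched as below, assuming memories are shaped `[seq_len, batch, d_model]`; setting both the kernel size and the stride to the compression rate $c$ gives the $c$-to-one mapping $f_c$.

```python
import torch
import torch.nn as nn

class Conv1dCompression(nn.Module):
    """Compress c consecutive memories into one with a 1D convolution:
    f_c: R^{nc x d} -> R^{n x d}."""

    def __init__(self, compression_rate: int, d_model: int):
        super().__init__()
        # kernel_size == stride == c, so each output is a learned
        # combination of c non-overlapping input memories
        self.conv = nn.Conv1d(d_model, d_model,
                              kernel_size=compression_rate,
                              stride=compression_rate)

    def forward(self, mem: torch.Tensor) -> torch.Tensor:
        # [seq_len, batch, d_model] -> [batch, d_model, seq_len] for Conv1d
        mem = mem.permute(1, 2, 0)
        c_mem = self.conv(mem)
        # back to [seq_len / c, batch, d_model]
        return c_mem.permute(2, 0, 1)

comp = Conv1dCompression(compression_rate=4, d_model=16)
mem = torch.randn(8, 2, 16)   # 8 memories
c_mem = comp(mem)             # compressed to 2 memories
```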
Training compression operation
Since training the compression operation with BPTT requires maintaining a very large computational graph (spanning many time steps), the paper proposes an auto-encoding loss and an attention reconstruction loss. The auto-encoding loss decodes the original memories from the compressed memories and computes the reconstruction error. The attention reconstruction loss computes multi-headed attention over the compressed memories and over the uncompressed memories, and takes the mean squared error between the two results. We have implemented the latter here, since it gives better results.
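The attention reconstruction loss can be sketched as follows. This is a simplified illustration, not the full implementation: it uses a standard `nn.MultiheadAttention` and a generic `compress` module, and a complete version would also freeze the attention parameters so that only the compression operation is trained by this loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def attention_reconstruction_loss(attn: nn.MultiheadAttention,
                                  h: torch.Tensor,
                                  mem: torch.Tensor,
                                  compress: nn.Module) -> torch.Tensor:
    """MSE between attention over raw memories and attention over
    compressed memories. `h` are the queries; shapes follow the
    (seq_len, batch, d_model) convention."""
    # Detach inputs: this loss should train only the compression module
    # (a full implementation also detaches the attention parameters).
    h, mem = h.detach(), mem.detach()
    target, _ = attn(h, mem, mem)                  # attend to raw memories
    c_mem = compress(mem)
    recon, _ = attn(h, c_mem, c_mem)               # attend to compressed memories
    return F.mse_loss(recon, target.detach())

attn = nn.MultiheadAttention(embed_dim=16, num_heads=2)
h, mem = torch.randn(4, 2, 16), torch.randn(8, 2, 16)
# With an identity "compression" the two attention results coincide
loss = attention_reconstruction_loss(attn, h, mem, nn.Identity())
```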
This implementation uses pre-layer normalization, while the paper uses post-layer normalization. Pre-layer norm applies the layer norm before the [FFN](../feedforward.html) and self-attention, and the pass-through in the residual connection is not normalized. This is supposed to be more stable in standard transformer setups.
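A minimal pre-layer-norm residual block looks like this (a sketch with an FFN sub-layer; `d_ff` and the module names are illustrative):

```python
import torch
import torch.nn as nn

class PreNormFFNBlock(nn.Module):
    """Pre-layer-norm residual block: normalize before the sub-layer,
    and leave the residual pass-through unnormalized."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(),
                                 nn.Linear(d_ff, d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # pre-LN: x + FFN(LN(x)); post-LN would instead be LN(x + FFN(x))
        return x + self.ffn(self.norm(x))

block = PreNormFFNBlock(d_model=16, d_ff=64)
x = torch.randn(4, 2, 16)
out = block(x)
```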
Here are the training code and a notebook for training a compressive transformer model on the Tiny Shakespeare dataset.