Compressive Transformer
This is an implementation of Compressive Transformers for Long-Range Sequence Modelling in PyTorch.
This is an extension of Transformer XL where past memories
are compressed to give a longer attention range.
That is, the furthest $n_{cm} c$ memories are compressed into $n_{cm}$ memories, where $c$ is the compression rate.
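The bookkeeping described above can be sketched in plain Python. This is a minimal illustration and not the code from this implementation: the names `update_memories`, `mean_pool`, `mem_len` and `c_mem_len` are hypothetical, memories are plain numbers held in lists ordered oldest first, and mean pooling is used only as a simple stand-in for the compression function. It also assumes `c_mem_len >= 1` and rounds the number of compressed groups up, which is a choice made for this sketch.

```python
def mean_pool(group):
    # Stand-in compression: average `rate` memories into one (assumption).
    return sum(group) / len(group)


def update_memories(mem, c_mem, new, mem_len, c_mem_len, rate, compress=mean_pool):
    """Append `new` to `mem`; compress whatever overflows `mem_len`.

    All lists are ordered oldest first. Returns (mem, c_mem).
    """
    mem = mem + new
    overflow = len(mem) - mem_len
    if overflow <= 0:
        return mem, c_mem

    # Compress whole groups of `rate` memories, rounding the group count up.
    n_groups = -(-overflow // rate)
    n_old = n_groups * rate
    old, mem = mem[:n_old], mem[n_old:]

    compressed = [compress(old[i:i + rate]) for i in range(0, len(old), rate)]

    # Keep only the most recent `c_mem_len` compressed memories.
    c_mem = (c_mem + compressed)[-c_mem_len:]
    return mem, c_mem


mem, c_mem = [0, 1, 2, 3, 4, 5], []
for step in ([6, 7, 8], [9, 10, 11], [12, 13, 14]):
    mem, c_mem = update_memories(mem, c_mem, step, mem_len=6, c_mem_len=2, rate=3)
```

With a compression rate of 3, each group of three evicted memories becomes a single compressed memory, so the compressed memory covers three times as many time steps per slot as the uncompressed memory.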
Compression operation
The compression operation is defined as $f_c: \mathbb{R}^{nc \times d} \rightarrow \mathbb{R}^{n \times d}$.
The paper introduces multiple choices for $f_c$, and we have only implemented 1D convolution, which seems to give the best results. Each layer has a separate compression operation $f_c^{(i)}$, where $i$ is the layer number.
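A 1D-convolution compression of this shape could look like the sketch below. It assumes memories are stored as `[seq_len, batch, d_model]` tensors and uses a kernel size and stride both equal to the compression rate, so every `c` consecutive memories map to one output. The class name `Conv1dCompression` and the layer count are illustrative assumptions, not necessarily what the training code uses.

```python
import torch
import torch.nn as nn


class Conv1dCompression(nn.Module):
    """Maps [n * c, batch, d_model] memories to [n, batch, d_model]."""

    def __init__(self, compression_rate: int, d_model: int):
        super().__init__()
        # Non-overlapping windows: kernel_size == stride == compression rate.
        self.conv = nn.Conv1d(d_model, d_model,
                              kernel_size=compression_rate,
                              stride=compression_rate)

    def forward(self, mem: torch.Tensor) -> torch.Tensor:
        # nn.Conv1d expects [batch, channels, length].
        mem = mem.permute(1, 2, 0)
        c_mem = self.conv(mem)
        # Back to [length, batch, d_model].
        return c_mem.permute(2, 0, 1)


# One separate compression module per layer (4 layers assumed here).
compressions = nn.ModuleList([Conv1dCompression(4, 16) for _ in range(4)])
mem = torch.randn(12, 2, 16)        # n * c = 12 memories, batch of 2
c_mem = compressions[0](mem)        # 12 / 4 = 3 compressed memories
```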
Training compression operation
Training the compression operation with backpropagation through time (BPTT) would require maintaining a very large computational graph spanning many time steps. The paper instead proposes two auxiliary losses: an auto-encoding loss and an attention reconstruction loss. The auto-encoding loss decodes the original memories from the compressed memories and computes a reconstruction loss. The attention reconstruction loss computes multi-headed attention over the compressed memory and over the uncompressed memory, and takes the mean squared error between the two results. We have implemented the latter here since it gives better results.
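The attention reconstruction loss can be sketched as follows. This is a simplified illustration that uses PyTorch's `nn.MultiheadAttention` in place of the relative multi-head attention a Transformer XL style model would use; the function name `attention_reconstruction_loss` and the toy sizes are assumptions. Inputs are detached so the loss does not backpropagate into earlier time steps. In this sketch gradients would still reach the attention parameters through the compressed branch, so the example takes gradients only with respect to the compression parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def attention_reconstruction_loss(attn, h, mem, compress):
    """MSE between attention over uncompressed and over compressed memory.

    h:   [seq_len, batch, d_model] hidden states used as queries
    mem: [n * c, batch, d_model] memories about to be compressed
    """
    # Cut the graph so no gradient flows back through time.
    h, mem = h.detach(), mem.detach()

    # Target: attention over the original memories, treated as a constant.
    with torch.no_grad():
        target, _ = attn(h, mem, mem, need_weights=False)

    # Reconstruction: the same queries attending to the compressed memories.
    c_mem = compress(mem)
    recon, _ = attn(h, c_mem, c_mem, need_weights=False)

    return F.mse_loss(recon, target)


d_model, rate = 16, 4
attn = nn.MultiheadAttention(d_model, num_heads=4)
conv = nn.Conv1d(d_model, d_model, kernel_size=rate, stride=rate)


def compress(m):
    # [len, batch, d] -> [batch, d, len] -> conv -> [len / rate, batch, d]
    return conv(m.permute(1, 2, 0)).permute(2, 0, 1)


h = torch.randn(5, 2, d_model)
mem = torch.randn(8, 2, d_model)
loss = attention_reconstruction_loss(attn, h, mem, compress)

# Update only the compression function with this loss.
grads = torch.autograd.grad(loss, list(conv.parameters()))
```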
This implementation uses pre-layer normalization, while the paper uses post-layer normalization. Pre-layer normalization applies layer norm before the feed-forward network (FFN) and before self-attention, so the pass-through in the residual connection is not normalized. This is supposed to be more stable in standard transformer setups.
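The pre-layer normalization ordering can be illustrated with a stripped-down block. This sketch leaves out memories, masking and relative positional encodings and uses `nn.MultiheadAttention` for brevity; the class name `PreNormBlock` and the layer sizes are assumptions made for the example.

```python
import torch
import torch.nn as nn


class PreNormBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.norm_attn = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads)
        self.norm_ff = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff),
                                nn.ReLU(),
                                nn.Linear(d_ff, d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize first, then attend; the residual path carries raw x.
        z = self.norm_attn(x)
        a, _ = self.attn(z, z, z, need_weights=False)
        x = x + a
        # Same pattern for the feed-forward sub-layer.
        x = x + self.ff(self.norm_ff(x))
        return x


block = PreNormBlock(d_model=16, n_heads=4, d_ff=32)
x = torch.randn(5, 2, 16)           # [seq_len, batch, d_model]
y = block(x)
```

A post-layer-norm block would instead compute `norm(x + sublayer(x))`, which normalizes the residual path as well.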
Here are the training code and a notebook for training a compressive transformer model on the Tiny Shakespeare dataset.