📚 compressive transformer docs - work in progress

Varuna Jayasiri
2021-02-18 14:58:45 +05:30
parent 7af3e03d24
commit e5751ab341
4 changed files with 417 additions and 233 deletions


@@ -1,3 +1,42 @@
"""
---
title: Compressive Transformer
summary: >
Documented implementation with explanations of a
Compressive Transformer model.
---
# Compressive Transformer
This is an implementation of
[Compressive Transformers for Long-Range Sequence Modelling](https://arxiv.org/abs/1911.05507)
in [PyTorch](https://pytorch.org).
This is an extension of [Transformer XL](../xl/index.html) where past memories
are compressed to give a longer attention range.
That is, the furthest $n_{cm} c$ memories are compressed into
$n_{cm}$ memories, where $c$ is the compression rate.
## Compression operation
The compression operation is defined as
$f_c: \mathbb{R}^{n_{cm} c \times d} \rightarrow \mathbb{R}^{n_{cm} \times d}$.
The paper introduces multiple choices for $f_c$ and we have only implemented
1D convolution, which seems to give the best results.
## Training compression operation
Since training compression with BPTT requires maintaining
a very large computational graph (many time steps), the paper proposes
an *auto-encoding loss* and an *attention reconstruction loss*.
The auto-encoding loss decodes the original memories from the compressed memories
and calculates the reconstruction loss.
The attention reconstruction loss computes the multi-headed attention results
on the compressed memory and on the uncompressed memory, and takes the mean squared error
between them.
We have implemented the latter here since it gives better results.
"""
from typing import Optional, List
import torch
@@ -12,25 +51,41 @@ from labml_nn.utils import clone_module_list
class Conv1dCompression(Module):
def __init__(self, compression_ratio: int, d_model: int):
"""
## 1D Convolution Compression $f_c$
This is a simple wrapper around
[`nn.Conv1d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html)
with some tensor dimension permutations.
"""
def __init__(self, compression_rate: int, d_model: int):
"""
* `compression_rate` $c$
* `d_model` is the embedding size
"""
super().__init__()
self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_ratio, stride=compression_ratio)
self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)
def forward(self, mem: torch.Tensor):
"""
* `mem` has shape `[seq_len, batch, d_model]`
`mem` has shape `[seq_len, batch, d_model]`
"""
# Change the dimensions of `mem` so that we can run it through the convolution layer.
# Permute the dimensions of `mem` so that we can run it through the convolution layer.
# The convolution layer accepts input in the form `[batch, features, sequence]`
mem = mem.permute(1, 2, 0)
# Get compressed memory
# Get compressed memory by running it through the convolution layer
c_mem = self.conv(mem)
# Permute back to form `[seq_len, batch, d_model]`
return c_mem.permute(2, 0, 1)
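# A minimal usage sketch (illustrative only, not part of the original file):
# compressing 8 memory steps at rate $c = 4$ gives 2 compressed memories with the
# same embedding size.
_compress = Conv1dCompression(compression_rate=4, d_model=16)
_mem = torch.zeros(8, 2, 16)  # `[seq_len, batch, d_model]`
assert _compress(_mem).shape == (2, 2, 16)  # `[c_mem_len, batch, d_model]`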
class CompressiveTransformerLayer(Module):
"""
## Compressive Transformer Layer
This is the implementation of a single compressive transformer layer
"""
def __init__(self, *,
d_model: int,
self_attn: RelativeMultiHeadAttention,
@@ -39,9 +94,10 @@ class CompressiveTransformerLayer(Module):
compress: Conv1dCompression):
"""
* `d_model` is the token embedding size
* `self_attn` is the [self attention module](relative_mha.html)
* `feed_forward` is the feed forward module
* `self_attn` is the [self attention module](../xl/relative_mha.html)
* `feed_forward` is the [feed forward module](../feed_forward.html)
* `dropout_prob` is the probability of dropping out after self attention and FFN
* `compress` is the compression function $f_c$
"""
super().__init__()
self.compress = compress
@@ -52,14 +108,25 @@ class CompressiveTransformerLayer(Module):
self.norm_self_attn = nn.LayerNorm([d_model])
self.norm_ff = nn.LayerNorm([d_model])
def with_memory(self, z: torch.Tensor, mem: Optional[torch.Tensor], c_mem: Optional[torch.Tensor]):
def concat_memory(self, z: torch.Tensor, mem: Optional[torch.Tensor], c_mem: Optional[torch.Tensor]):
"""
Concatenate the normalized token embeddings with memory and compressed memory.
* `z` is the layer normalized token embeddings.
* `mem` and `c_mem` are memory and compressed memory (not normalized).
"""
# If there is no memory just return the token embeddings
if mem is None:
return z
# If there is compressed memory, concatenate it with the memory
if c_mem is not None:
mem = torch.cat((c_mem, mem), dim=0)
# Run the memory through the normalization layer
mem = self.norm_self_attn(mem)
# Concatenate normalized memory and normalized token embeddings
return torch.cat((mem, z), dim=0)
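# A shape sketch (illustrative only, not part of the original file) of what `concat_memory`
# produces: compressed memory, memory and the current embeddings are stacked along the
# sequence dimension, which is why the attention mask covers `c_mem_len + mem_len + seq_len`
# positions.
_z = torch.zeros(4, 2, 16)  # `[seq_len, batch, d_model]`
_mem = torch.zeros(8, 2, 16)  # `[mem_len, batch, d_model]`
_c_mem = torch.zeros(2, 2, 16)  # `[c_mem_len, batch, d_model]`
# Same ordering as `concat_memory` (ignoring the layer normalization)
assert torch.cat((_c_mem, _mem, _z), dim=0).shape == (2 + 8 + 4, 2, 16)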
def forward(self, *,
@@ -68,15 +135,17 @@ class CompressiveTransformerLayer(Module):
c_mem: Optional[torch.Tensor],
mask: torch.Tensor):
"""
* `x` are the token level feature vectors of shape `[seq_len, batch_size, d_model]`
* `mem` are the past token level feature vectors of shape `[mem_len + c_mem_len * c, batch_size, d_model]`
* `x` is a tensor of token level feature vectors of shape `[seq_len, batch_size, d_model]`
* `mem` is a tensor of the past token level feature vectors (memory) of shape `[mem_len, batch_size, d_model]`
* `c_mem` is a tensor of the compressed memory `[c_mem_len, batch_size, d_model]`
* `mask` is a matrix of shape `[seq_len, c_mem_len + mem_len + seq_len, batch_size]` or `[seq_len, c_mem_len + mem_len + seq_len, 1]`.
`mask[i, j]` is true if token at `i` can see token at `j`.
"""
# Normalize the vectors before doing self attention
z = self.norm_self_attn(x)
m_z = self.with_memory(z, mem, c_mem)
# Normalize the memory and compressed memory and concatenate them with the token embeddings
m_z = self.concat_memory(z, mem, c_mem)
# Attention
self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
# Add the attention results
@@ -95,9 +164,9 @@ class CompressiveTransformerLayer(Module):
class CompressiveTransformer(Module):
"""
## Transformer XL Model
## Compressive Transformer Model
This consists of multiple transformer XL layers
This consists of multiple compressive transformer layers
"""
def __init__(self, layer: CompressiveTransformerLayer, n_layers: int):
@@ -109,12 +178,15 @@ class CompressiveTransformer(Module):
def forward(self, x: torch.Tensor, mem: List[torch.Tensor], c_mem: List[torch.Tensor], mask: torch.Tensor):
"""
* `x` are the token embeddings vectors of shape `[seq_len, batch_size, d_model]`
* `mem` are the past token level feature vectors of shape `[mem_len, batch_size, d_model]` for each layer
* `x` is a tensor of the token embeddings vectors of shape `[seq_len, batch_size, d_model]`
* `mem` is a list of tensors of the past token level feature vectors of shape
`[mem_len, batch_size, d_model]` for each layer
* `c_mem` is a list of tensors of the compressed memory
`[c_mem_len, batch_size, d_model]` for each layer
* `mask` is the masking matrix
"""
# List to store token level feature vectors,
# which will be the memories for the next sequential batch.
# which will become the memories for the next sequential batch.
new_mem = []
# Run through each transformer layer
for i, layer in enumerate(self.layers):
@@ -122,7 +194,7 @@ class CompressiveTransformer(Module):
new_mem.append(x.detach())
# Memory
m = mem[i] if mem else None
# Memory
# Compressed Memory
cm = c_mem[i] if c_mem else None
# Run through the compressive transformer layer
x = layer(x=x, mem=m, c_mem=cm, mask=mask)
@@ -131,7 +203,13 @@ class CompressiveTransformer(Module):
class AttentionReconstructionLoss:
"""
## Attention Reconstruction Loss
"""
def __init__(self, layers: TypedModuleList[CompressiveTransformerLayer]):
"""
`layers` is the list of Compressive Transformer layers
"""
self.layers = layers
self.loss_func = nn.MSELoss()
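# A simplified sketch (illustrative only, not the actual implementation of this class) of
# the attention reconstruction loss described in the module docstring: attend over the
# original memories and over their compressed form with the same queries, and take the
# mean squared error between the two results. A single attention head is assumed here
# for brevity; the real loss uses the layers' multi-headed attention.
def _attention_reconstruction_sketch(h: torch.Tensor, old_mem: torch.Tensor,
                                     compress: Conv1dCompression) -> torch.Tensor:
    # Compress the memories that are about to be discarded
    c_mem = compress(old_mem)

    # Single-head scaled dot-product attention over a given memory
    def attn(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
        scores = torch.einsum('ibd,jbd->ijb', q, kv) / (q.shape[-1] ** 0.5)
        return torch.einsum('ijb,jbd->ibd', scores.softmax(dim=1), kv)

    # Mean squared error between attention over the original and compressed memories
    return nn.MSELoss()(attn(h, old_mem), attn(h, c_mem))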


@@ -74,8 +74,8 @@ class TransformerXLLayer(Module):
mem: Optional[torch.Tensor],
mask: torch.Tensor):
"""
* `x` are the token level feature vectors of shape `[seq_len, batch_size, d_model]`
* `mem` are the past token level feature vectors of shape `[mem_len, batch_size, d_model]`
* `x` is a tensor of the token level feature vectors of shape `[seq_len, batch_size, d_model]`
* `mem` is a tensor of the past token level feature vectors of shape `[mem_len, batch_size, d_model]`
* `mask` is a matrix of shape `[seq_len, mem_len + seq_len, batch_size]` or `[seq_len, mem_len + seq_len, 1]`.
`mask[i, j]` is true if token at `i` can see token at `j`.
"""
@@ -122,12 +122,13 @@ class TransformerXL(Module):
def forward(self, x: torch.Tensor, mem: List[torch.Tensor], mask: torch.Tensor):
"""
* `x` are the token embeddings vectors of shape `[seq_len, batch_size, d_model]`
* `mem` are the past token level feature vectors of shape `[mem_len, batch_size, d_model]` for each layer
* `x` is a tensor of the token embeddings vectors of shape `[seq_len, batch_size, d_model]`
* `mem` is a list of tensors of the past token level feature vectors of shape
`[mem_len, batch_size, d_model]` for each layer
* `mask` is the masking matrix
"""
# List to store token level feature vectors,
# which will be the memories for the next sequential batch.
# which will become the memories for the next sequential batch.
new_mem = []
# Run through each transformer layer
for i, layer in enumerate(self.layers):