📚 compressive transformer docs - work in progress
@@ -1,3 +1,42 @@
"""
---
title: Compressive Transformer
summary: >
  Documented implementation with explanations of a
  Compressive Transformer model.
---

# Compressive Transformer

This is an implementation of
[Compressive Transformers for Long-Range Sequence Modelling](https://arxiv.org/abs/1911.05507)
in [PyTorch](https://pytorch.org).

This is an extension of [Transformer XL](../xl/index.html) where past memories
are compressed to give a longer attention range.
That is, the furthest $n_{cm} c$ memories are compressed into
$n_{cm}$ memories, where $c$ is the compression rate.

## Compression operation

The compression operation is defined as
$f_c: \mathbb{R}^{nc \times d} \rightarrow \mathbb{R}^{n \times d}$.
The paper introduces multiple choices for $f_c$, and we have only implemented
1D convolution, which seems to give the best results.
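
For instance, here is a minimal sketch (with example values we picked for `d_model` and $c$; it is not part of the implementation below) of how a strided 1D convolution maps $nc$ memory vectors to $n$ compressed ones:

```python
import torch
import torch.nn as nn

d_model, c = 16, 4  # embedding size and compression rate (example values)
f_c = nn.Conv1d(d_model, d_model, kernel_size=c, stride=c)

mem = torch.zeros(2, d_model, 12)  # [batch, d_model, n * c] with n * c = 12 past memories
c_mem = f_c(mem)                   # compressed to [batch, d_model, n] with n = 3
assert c_mem.shape == (2, d_model, 3)
```

Since the kernel size and stride both equal $c$, every block of $c$ consecutive memories is summarized by a single vector.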

## Training compression operation

Since training the compression function with BPTT requires maintaining
a very large computational graph (many time steps), the paper proposes
an *auto-encoding loss* and an *attention reconstruction loss*.
The auto-encoding loss decodes the original memories from the compressed memories
and calculates the reconstruction loss.
The attention reconstruction loss computes the multi-headed attention results
on the compressed memory and on the uncompressed memory, and takes the mean squared error
between them.
We have implemented the latter here since it gives better results.
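
As a rough sketch of the latter idea (a single attention head with no projections; the function below is illustrative and simplified, not the `AttentionReconstructionLoss` implementation that follows):

```python
import torch
import torch.nn.functional as F


def attention_reconstruction_sketch(h: torch.Tensor, mem: torch.Tensor, c_mem: torch.Tensor):
    """`h` is `[seq_len, d]`, `mem` is `[n * c, d]` and `c_mem` is `[n, d]`"""
    def attend(q, kv):
        # Scaled dot-product attention, using `kv` for both keys and values
        scores = q @ kv.t() / kv.shape[-1] ** 0.5
        return F.softmax(scores, dim=-1) @ kv

    # Mean squared error between attending to the full memories
    # and attending to their compressed form
    return F.mse_loss(attend(h, mem), attend(h, c_mem))
```

In the full implementation this comparison is done with each layer's multi-head attention module.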
"""

from typing import Optional, List

import torch
@@ -12,25 +51,41 @@ from labml_nn.utils import clone_module_list


class Conv1dCompression(Module):
    def __init__(self, compression_ratio: int, d_model: int):
    """
    ## 1D Convolution Compression $f_c$

    This is a simple wrapper around
    [`nn.Conv1d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html)
    with some tensor dimension permutations.
    """
    def __init__(self, compression_rate: int, d_model: int):
        """
        * `compression_rate` $c$
        * `d_model` is the embedding size
        """
        super().__init__()
        self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_ratio, stride=compression_ratio)
        self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)

    def forward(self, mem: torch.Tensor):
        """
        * `mem` has shape `[seq_len, batch, d_model]`
        `mem` has shape `[seq_len, batch, d_model]`
        """

        # Change the dimensions of `mem` so that we can run it through the convolution layer.
        # Permute the dimensions of `mem` so that we can run it through the convolution layer.
        # The convolution layer accepts input in the form `[batch, features, sequence]`
        mem = mem.permute(1, 2, 0)
        # Get compressed memory
        # Get compressed memory by running it through the convolution layer
        c_mem = self.conv(mem)
        # Permute back to form `[seq_len, batch, d_model]`
        return c_mem.permute(2, 0, 1)


class CompressiveTransformerLayer(Module):
    """
    ## Compressive Transformer Layer

    This is the implementation of a single compressive transformer layer
    """
    def __init__(self, *,
                 d_model: int,
                 self_attn: RelativeMultiHeadAttention,
@@ -39,9 +94,10 @@ class CompressiveTransformerLayer(Module):
                 compress: Conv1dCompression):
        """
        * `d_model` is the token embedding size
        * `self_attn` is the [self attention module](relative_mha.html)
        * `feed_forward` is the feed forward module
        * `self_attn` is the [self attention module](../xl/relative_mha.html)
        * `feed_forward` is the [feed forward module](../feed_forward.html)
        * `dropout_prob` is the probability of dropping out after self attention and FFN
        * `compress` is the compression function $f_c$
        """
        super().__init__()
        self.compress = compress
@@ -52,14 +108,25 @@ class CompressiveTransformerLayer(Module):
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def with_memory(self, z: torch.Tensor, mem: Optional[torch.Tensor], c_mem: Optional[torch.Tensor]):
    def concat_memory(self, z: torch.Tensor, mem: Optional[torch.Tensor], c_mem: Optional[torch.Tensor]):
        """
        Concatenate the normalized token embeddings with the memory and the compressed memory.

        * `z` is the layer-normalized token embeddings.
        * `mem` and `c_mem` are the memory and compressed memory (not normalized).
        """

        # If there is no memory, just return the token embeddings
        if mem is None:
            return z

        # If there is compressed memory, concatenate it with the memory
        if c_mem is not None:
            mem = torch.cat((c_mem, mem), dim=0)

        # Run the memory through the normalization layer
        mem = self.norm_self_attn(mem)
        # Concatenate the normalized memory and the normalized token embeddings
        return torch.cat((mem, z), dim=0)

    def forward(self, *,
@@ -68,15 +135,17 @@ class CompressiveTransformerLayer(Module):
                c_mem: Optional[torch.Tensor],
                mask: torch.Tensor):
        """
        * `x` are the token level feature vectors of shape `[seq_len, batch_size, d_model]`
        * `mem` are the past token level feature vectors of shape `[mem_len + c_mem_len * c, batch_size, d_model]`
        * `x` is a tensor of token level feature vectors of shape `[seq_len, batch_size, d_model]`
        * `mem` is a tensor of the past token level feature vectors (memory) of shape `[mem_len, batch_size, d_model]`
        * `c_mem` is a tensor of the compressed memory of shape `[c_mem_len, batch_size, d_model]`
        * `mask` is a matrix of shape `[seq_len, c_mem_len + mem_len + seq_len, batch_size]` or `[seq_len, c_mem_len + mem_len + seq_len, 1]`.
          `mask[i, j]` is true if token at `i` can see token at `j`.
        """

        # Normalize the vectors before doing self attention
        z = self.norm_self_attn(x)
        m_z = self.with_memory(z, mem, c_mem)
        # Normalize and concatenate the memory and the compressed memory
        m_z = self.concat_memory(z, mem, c_mem)
        # Attention
        self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
        # Add the attention results
@@ -95,9 +164,9 @@ class CompressiveTransformerLayer(Module):

class CompressiveTransformer(Module):
    """
    ## Transformer XL Model
    ## Compressive Transformer Model

    This consists of multiple transformer XL layers
    This consists of multiple compressive transformer layers
    """

    def __init__(self, layer: CompressiveTransformerLayer, n_layers: int):
@@ -109,12 +178,15 @@ class CompressiveTransformer(Module):

    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], c_mem: List[torch.Tensor], mask: torch.Tensor):
        """
        * `x` are the token embeddings vectors of shape `[seq_len, batch_size, d_model]`
        * `mem` are the past token level feature vectors of shape `[mem_len, batch_size, d_model]` for each layer
        * `x` is a tensor of the token embedding vectors of shape `[seq_len, batch_size, d_model]`
        * `mem` is a list of tensors of the past token level feature vectors of shape
          `[mem_len, batch_size, d_model]` for each layer
        * `c_mem` is a list of tensors of the compressed memory of shape
          `[c_mem_len, batch_size, d_model]` for each layer
        * `mask` is the masking matrix
        """
        # List to store the token level feature vectors,
        # which will be the memories for the next sequential batch.
        # which will become the memories for the next sequential batch.
        new_mem = []
        # Run through each transformer layer
        for i, layer in enumerate(self.layers):
@@ -122,7 +194,7 @@ class CompressiveTransformer(Module):
            new_mem.append(x.detach())
            # Memory
            m = mem[i] if mem else None
            # Memory
            # Compressed Memory
            cm = c_mem[i] if c_mem else None
            # Run through the compressive transformer layer
            x = layer(x=x, mem=m, c_mem=cm, mask=mask)
@@ -131,7 +203,13 @@ class CompressiveTransformer(Module):


class AttentionReconstructionLoss:
    """
    ## Attention Reconstruction Loss
    """
    def __init__(self, layers: TypedModuleList[CompressiveTransformerLayer]):
        """
        `layers` is the list of Compressive Transformer layers
        """
        self.layers = layers
        self.loss_func = nn.MSELoss()


@@ -74,8 +74,8 @@ class TransformerXLLayer(Module):
                mem: Optional[torch.Tensor],
                mask: torch.Tensor):
        """
        * `x` are the token level feature vectors of shape `[seq_len, batch_size, d_model]`
        * `mem` are the past token level feature vectors of shape `[mem_len, batch_size, d_model]`
        * `x` is a tensor of the token level feature vectors of shape `[seq_len, batch_size, d_model]`
        * `mem` is a tensor of the past token level feature vectors of shape `[mem_len, batch_size, d_model]`
        * `mask` is a matrix of shape `[seq_len, mem_len + seq_len, batch_size]` or `[seq_len, mem_len + seq_len, 1]`.
          `mask[i, j]` is true if token at `i` can see token at `j`.
        """
@@ -122,12 +122,13 @@ class TransformerXL(Module):

    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], mask: torch.Tensor):
        """
        * `x` are the token embeddings vectors of shape `[seq_len, batch_size, d_model]`
        * `mem` are the past token level feature vectors of shape `[mem_len, batch_size, d_model]` for each layer
        * `x` is a tensor of the token embedding vectors of shape `[seq_len, batch_size, d_model]`
        * `mem` is a list of tensors of the past token level feature vectors of shape
          `[mem_len, batch_size, d_model]` for each layer
        * `mask` is the masking matrix
        """
        # List to store the token level feature vectors,
        # which will be the memories for the next sequential batch.
        # which will become the memories for the next sequential batch.
        new_mem = []
        # Run through each transformer layer
        for i, layer in enumerate(self.layers):