🚧 compressive transformer

This commit is contained in:
Varuna Jayasiri
2021-02-17 16:36:52 +05:30
parent 9a9d6c671d
commit ff8f80039a
3 changed files with 513 additions and 5 deletions

@@ -0,0 +1,183 @@
from typing import Optional, List
import torch
import torch.nn.functional as F
from torch import nn
from labml_helpers.module import Module, TypedModuleList
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
from labml_nn.utils import clone_module_list
class Conv1dCompression(Module):
"""
## 1D convolution memory compression

Compress memories along the sequence dimension with a 1D convolution whose
kernel size and stride are both equal to the compression ratio.
"""
def __init__(self, compression_ratio: int, d_model: int):
super().__init__()
self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_ratio, stride=compression_ratio)
def forward(self, mem: torch.Tensor):
"""
* `mem` has shape `[seq_len, batch, d_model]`
"""
# Change the dimensions of `mem` so that we can run it through the convolution layer.
# The convolution layer accepts in the form `[batch, features, sequence]`
mem = mem.permute(1, 2, 0)
# Get compressed memory
c_mem = self.conv(mem)
# Permute back to form `[seq_len, batch, d_model]`
return c_mem.permute(2, 0, 1)
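For intuition, a minimal shape check of the compression (the sizes are arbitrary and chosen only for illustration; it uses the imports above):

compress = Conv1dCompression(compression_ratio=4, d_model=128)
mem = torch.randn(16, 2, 128)   # `[mem_len, batch, d_model]`
c_mem = compress(mem)           # 16 memories compressed down to 16 / 4 = 4
assert c_mem.shape == (4, 2, 128)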
class CompressiveTransformerLayer(Module):
def __init__(self, *,
d_model: int,
self_attn: RelativeMultiHeadAttention,
feed_forward: FeedForward,
dropout_prob: float,
compress: Conv1dCompression):
"""
* `d_model` is the token embedding size
* `self_attn` is the [self attention module](relative_mha.html)
* `feed_forward` is the feed forward module
* `dropout_prob` is the probability of dropping out after self attention and FFN
"""
super().__init__()
self.compress = compress
self.size = d_model
self.self_attn = self_attn
self.feed_forward = feed_forward
self.dropout = nn.Dropout(dropout_prob)
self.norm_self_attn = nn.LayerNorm([d_model])
self.norm_ff = nn.LayerNorm([d_model])
def with_memory(self, z: torch.Tensor, mem: Optional[torch.Tensor], c_mem: Optional[torch.Tensor]):
"""
Concatenate the normalized memories with the current token embeddings,
so that attention keys and values cover `[c_mem; mem; z]`.
"""
if mem is None:
return z
# Compressed memory comes before the uncompressed memory
if c_mem is not None:
mem = torch.cat((c_mem, mem), dim=0)
# Normalize the memory and prepend it to the current token embeddings
mem = self.norm_self_attn(mem)
return torch.cat((mem, z), dim=0)
def forward(self, *,
x: torch.Tensor,
mem: Optional[torch.Tensor],
c_mem: Optional[torch.Tensor],
mask: torch.Tensor):
"""
* `x` are the token level feature vectors of shape `[seq_len, batch_size, d_model]`
* `mem` are the past token level feature vectors of shape `[mem_len + c_mem_len * c, batch_size, d_model]`
* `mask` is a matrix of shape `[seq_len, c_mem_len + mem_len + seq_len, batch_size]` or `[seq_len, c_mem_len + mem_len + seq_len, 1]`.
`mask[i, j]` is true if token at `i` can see token at `j`.
"""
# Normalize the vectors before doing self attention
z = self.norm_self_attn(x)
m_z = self.with_memory(z, mem, c_mem)
# Attention
self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
# Add the attention results
x = x + self.dropout(self_attn)
# Normalize for feed-forward
z = self.norm_ff(x)
# Pass through the feed-forward network
ff = self.feed_forward(z)
# Add the feed-forward results back
x = x + self.dropout(ff)
#
return x
class CompressiveTransformer(Module):
"""
## Transformer XL Model
This consists of multiple transformer XL layers
"""
def __init__(self, layer: CompressiveTransformerLayer, n_layers: int):
super().__init__()
# Make copies of the transformer layer
self.layers = clone_module_list(layer, n_layers)
# Final normalization layer
self.norm = nn.LayerNorm([layer.size])
def forward(self, x: torch.Tensor, mem: List[torch.Tensor], c_mem: List[torch.Tensor], mask: torch.Tensor):
"""
* `x` are the token embeddings vectors of shape `[seq_len, batch_size, d_model]`
* `mem` are the past token level feature vectors of shape `[mem_len, batch_size, d_model]` for each layer
* `mask` is the masking matrix
"""
# List to store token level feature vectors,
# which will be the memories for the next sequential batch.
new_mem = []
# Run through each transformer layer
for i, layer in enumerate(self.layers):
# Add to the list of feature vectors
new_mem.append(x.detach())
# Uncompressed memory for this layer
m = mem[i] if mem else None
# Compressed memory for this layer
cm = c_mem[i] if c_mem else None
# Run through the compressive transformer layer
x = layer(x=x, mem=m, c_mem=cm, mask=mask)
# Finally, normalize the vectors
return self.norm(x), new_mem
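A minimal end-to-end sketch of the stack (hyper-parameters and sizes are illustrative; it assumes `subsequent_mask` from `labml_nn.transformers.utils` returns a `[seq_len, seq_len, 1]` boolean mask, as it is used in the experiment below):

import torch
from labml_nn.transformers.utils import subsequent_mask

layer = CompressiveTransformerLayer(d_model=128,
                                    self_attn=RelativeMultiHeadAttention(4, 128, 0.1),
                                    feed_forward=FeedForward(128, 256, 0.1),
                                    dropout_prob=0.1,
                                    compress=Conv1dCompression(4, 128))
model = CompressiveTransformer(layer, n_layers=2)

x = torch.randn(8, 2, 128)  # `[seq_len, batch, d_model]`
# First segment: no memories yet, so a plain causal mask of shape `[8, 8, 1]` is enough
out, new_mem = model(x, mem=[], c_mem=[], mask=subsequent_mask(8))
assert out.shape == (8, 2, 128) and len(new_mem) == 2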
class AttentionReconstructionLoss:
"""
## Attention Reconstruction Loss

The MSE between self-attention computed over the original memories and over the
compressed memories. Inputs and attention parameters are detached, so this loss
only trains the memory compression network.
"""
def __init__(self, layers: TypedModuleList[CompressiveTransformerLayer]):
self.layers = layers
self.loss_func = nn.MSELoss()
def prepare_for_attn(self, pmha: PrepareForMultiHeadAttention, x: torch.Tensor):
"""
Re-apply the linear projection of `PrepareForMultiHeadAttention` with detached
parameters, so no gradients flow into the attention projection weights.
"""
head_shape = x.shape[:-1]
# Linear transform with detached weight and bias
weight = pmha.linear.weight.detach()
bias = pmha.linear.bias.detach() if pmha.linear.bias is not None else None
x = F.linear(x, weight, bias)
# Split last dimension into heads
x = x.view(*head_shape, pmha.heads, pmha.d_k)
# Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_k]`
return x
def attn(self, layer: RelativeMultiHeadAttention, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
"""
Standard (non-relative) multi-head attention computed with detached projection parameters.
"""
query = self.prepare_for_attn(layer.query, query)
key = self.prepare_for_attn(layer.key, key)
value = self.prepare_for_attn(layer.value, value)
# Compute attention scores $Q K^\top$.
# This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.
scores = torch.einsum('ibhd,jbhd->ijbh', query, key)
# Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
scores *= layer.scale
# $softmax$ attention along the key sequence dimension
# $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
attn = layer.softmax(scores)
# Multiply by values
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
return torch.einsum("ijbh,jbhd->ibhd", attn, value)
def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):
# Detach the hidden states and memories so that only the compression network is trained
h = h.detach()
mem = mem.detach()
# Compress the memories
c_mem = layer.compress(mem)
# MSE between attention over the original memories and attention over the compressed memories
return self.loss_func(self.attn(layer.self_attn, h, mem, mem),
self.attn(layer.self_attn, h, c_mem, c_mem))
def __call__(self, h: List[torch.Tensor], mem: List[torch.Tensor]):
losses = [self.calc_loss(layer, h[n], mem[n]) for n, layer in enumerate(self.layers)]
return sum(losses)
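A sketch of how the loss is meant to be used (illustrative sizes; `h` plays the role of the hidden states fed into each layer and `mem` the memories about to be compressed, as in the experiment below; it relies on the imports at the top of this file):

layer = CompressiveTransformerLayer(d_model=128,
                                    self_attn=RelativeMultiHeadAttention(4, 128, 0.1),
                                    feed_forward=FeedForward(128, 256, 0.1),
                                    dropout_prob=0.1,
                                    compress=Conv1dCompression(4, 128))
ar_loss = AttentionReconstructionLoss(clone_module_list(layer, 2))

h = [torch.randn(8, 2, 128) for _ in range(2)]    # hidden states, one per layer
mem = [torch.randn(8, 2, 128) for _ in range(2)]  # memories to be compressed, one per layer
loss = ar_loss(h, mem)  # scalar; gradients flow only into the `compress` modules
loss.backward()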

@@ -0,0 +1,327 @@
"""
---
title: Compressive Transformer Experiment
summary: This experiment trains a compressive transformer model on tiny Shakespeare dataset.
---
# Compressive Transformer Experiment
This is an annotated PyTorch experiment to train a compressive transformer model.
"""
from typing import List, Tuple, NamedTuple
import torch
import torch.nn as nn
from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml.logger import Text
from labml_helpers.metrics.simple_state import SimpleStateModule
from labml_helpers.module import Module
from labml_helpers.train_valid import BatchIndex, hook_model_outputs
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \
CompressiveTransformerLayer, Conv1dCompression
class CompressedMemory(NamedTuple):
mem: List[torch.Tensor]
c_mem: List[torch.Tensor]
class AutoregressiveModel(Module):
"""
## Auto regressive model
"""
def __init__(self, n_vocab: int, d_model: int, transformer: CompressiveTransformer):
super().__init__()
# Token embedding module
self.src_embed = nn.Embedding(n_vocab, d_model)
# Transformer
self.transformer = transformer
# Final layer
self.generator = nn.Linear(d_model, n_vocab)
# Masks
self.mask_x = None
self.mask_mem = None
def forward(self, x: torch.Tensor, mem: CompressedMemory):
# Unpack the memory and compressed memory
if mem is not None:
mem, c_mem = mem.mem, mem.c_mem
else:
mem = []
c_mem = []
# Total length of the memory and compressed memory (used to build the memory mask)
m_len = len(mem[0]) if mem else 0
if c_mem:
m_len += len(c_mem[0])
# Create a subsequent mask for tokens
if self.mask_x is None or self.mask_x.shape[0] < len(x):
from labml_nn.transformers.utils import subsequent_mask
self.mask_x = subsequent_mask(len(x)).to(x.device)
# Create an all ones (full visibility) mask for memory
if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)
# Concatenate the masks if there is memory
if m_len:
mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
# Use the subsequent mask otherwise
else:
mask = self.mask_x[:len(x), :len(x)]
# Token embeddings
x = self.src_embed(x)
# Run it through the transformer
res, mem = self.transformer(x, mem, c_mem, mask)
# Generate logits of the next token
res = self.generator(res)
#
return res, mem
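The mask bookkeeping above can be checked in isolation (illustrative sizes; this assumes `subsequent_mask` returns a `[seq_len, seq_len, 1]` boolean mask, which is how it is used here):

import torch
from labml_nn.transformers.utils import subsequent_mask

seq_len, m_len = 4, 6 + 2                      # 6 memories and 2 compressed memories
mask_x = subsequent_mask(seq_len)              # causal mask for the current tokens
mask_mem = mask_x.new_ones(seq_len, m_len, 1)  # memory is fully visible
mask = torch.cat((mask_mem, mask_x), dim=1)
assert mask.shape == (4, 12, 1)                # `[seq_len, m_len + seq_len, 1]`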
class Configs(NLPAutoRegressionConfigs):
"""
## Configurations
The default configs can and will be over-ridden when we start the experiment
"""
model: AutoregressiveModel
# Token embedding size
d_model: int = 128
# Number of attention heads
heads: int = 4
# Dropout probability
dropout: float = 0.0
# Number of features in FFN hidden layer
d_ff: int = 256
# Number of transformer layers
n_layers: int = 6
# Number of memories to keep
mem_len: int = 8
# State module to maintain memories when switching between training and validation
memory = SimpleStateModule()
# Attention Reconstruction Loss
attention_reconstruction_loss: AttentionReconstructionLoss
# Compression ratio
compression_ratio: int = 4
# Compressed memory length
c_mem_len: int = 128
def init(self):
# Set tracker configurations
tracker.set_scalar("accuracy.*", True)
tracker.set_scalar("loss.*", True)
tracker.set_scalar("ar_loss.*", False)
# Add a hook to log module outputs
hook_model_outputs(self.mode, self.model, 'model')
# This will keep the accuracy metric stats and memories separate for training and validation.
self.state_modules = [self.accuracy, self.memory]
@torch.no_grad()
def merge_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
-> Tuple[CompressedMemory, List[torch.Tensor]]:
"""
Concatenate memories and remove old memories to keep a maximum of
`mem_len` memories.
"""
# If it's configured not to use memory
if self.mem_len == 0:
return CompressedMemory([], []), []
if mem is not None:
mem, c_mem = mem.mem, mem.c_mem
else:
mem, c_mem = [], []
# Concatenate with old memory
if mem:
mem = [torch.cat((m, x), dim=0) for m, x in zip(mem, new_mem)]
else:
mem = new_mem
# Compress the oldest memories if there are more memories than `mem_len`
if len(mem[0]) > self.mem_len:
# Number of compressed memories to create, rounded up to cover the overflow
n_c_mem = (len(mem[0]) - self.mem_len + self.compression_ratio - 1) // self.compression_ratio
# Memories that get compressed and memories that are kept
old_mem = []
trunc_mem = []
# Split the memories of each layer
for m in mem:
# Number of oldest memories to compress
n_old = n_c_mem * self.compression_ratio
cm, m = torch.split(m, [n_old, len(m) - n_old])
old_mem.append(cm)
trunc_mem.append(m)
# Keep the most recent memories
mem = trunc_mem
# Compress the oldest memories with each layer's compression module
new_c_mem = []
for i, layer in enumerate(self.model.transformer.layers):
new_c_mem.append(layer.compress(old_mem[i]))
# Concatenate with existing compressed memories
if c_mem:
c_mem = [torch.cat((m, nm), dim=0) for m, nm in zip(c_mem, new_c_mem)]
else:
c_mem = new_c_mem
# Truncate compressed memories to keep at most `c_mem_len` of them
if len(c_mem[0]) > self.c_mem_len:
c_mem = [m[-self.c_mem_len:] for m in c_mem]
# If the memories did not overflow `mem_len`, nothing gets compressed
else:
old_mem = []
#
return CompressedMemory(mem, c_mem), old_mem
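As a worked example of the bookkeeping above, with the default `compression_ratio` of 4, `mem_len` of 8, and 16 memories accumulated after concatenation:

mem_total, mem_len, ratio = 16, 8, 4
n_c_mem = (mem_total - mem_len + ratio - 1) // ratio  # ceil((16 - 8) / 4) = 2
n_old = n_c_mem * ratio                               # the 8 oldest memories get compressed
assert (n_c_mem, n_old) == (2, 8)
# so 8 recent memories are kept as-is and 2 new compressed memories are produced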
def step(self, batch: any, batch_idx: BatchIndex):
"""
### Training/validation step
"""
# Move data to the device
data, target = batch[0].to(self.device), batch[1].to(self.device)
# Update global step (number of tokens processed) when in training mode
if self.mode.is_train:
tracker.add_global_step(data.shape[0] * data.shape[1])
# Whether to capture model outputs
with self.mode.update(is_log_activations=batch_idx.is_last):
# Get memories
mem = self.memory.get()
# Run the model
output, new_mem = self.model(data, mem)
# Merge memory
mem, old_mem = self.merge_memory(mem, new_mem)
# Update memories
self.memory.set(mem)
# Calculate and log cross entropy loss
loss = self.loss_func(output, target)
tracker.add("loss.", loss)
# Calculate and log the attention reconstruction loss if memories were compressed in this step
if old_mem:
ar_loss = self.attention_reconstruction_loss(new_mem, old_mem)
tracker.add("ar_loss.", ar_loss)
# loss = loss + ar_loss
# Calculate and log accuracy
self.accuracy(output, target)
self.accuracy.track()
# Train the model
if self.mode.is_train:
# Calculate gradients
loss.backward()
# Clip gradients
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
# Take optimizer step
self.optimizer.step()
# Log the model parameters and gradients on last batch of every epoch
if batch_idx.is_last:
tracker.add('model', self.model)
# Clear the gradients
self.optimizer.zero_grad()
# Save the tracked metrics
tracker.save()
def sample(self):
"""
### Sampling function to generate samples periodically while training
"""
# Starting prompt
prompt = self.prompt
# Collect output for printing
log = [(prompt, Text.subtle)]
# Start with an empty memory
mem = CompressedMemory([], [])
# Sample 25 tokens
for i in monit.iterate('Sample', 25):
# Tokenize the prompt
data = self.text.text_to_i(prompt).unsqueeze(-1)
# Move to device
data = data.to(self.device)
# Get the model output
output, new_mem = self.model(data, mem)
# Get the model prediction (greedy)
output = output.argmax(dim=-1).squeeze(1)
# Add the prediction to prompt
prompt += self.prompt_separator + self.text.itos[output[-1]]
# Only feed the last character to model in next iteration, rest will go in as memories
prompt = prompt[-1:]
# Add the prediction for logging
log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
# Update memory
mem, _ = self.merge_memory(mem, new_mem)
# Print the sampled output
logger.log(log)
@option(Configs.model)
def autoregressive_model(c: Configs):
"""
### Initialize the auto-regressive model
"""
from labml_nn.transformers.xl import RelativeMultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
m = AutoregressiveModel(c.n_tokens, c.d_model, CompressiveTransformer(
CompressiveTransformerLayer(d_model=c.d_model,
self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
dropout_prob=c.dropout,
compress=Conv1dCompression(c.compression_ratio, c.d_model)), c.n_layers))
return m.to(c.device)
@option(Configs.attention_reconstruction_loss)
def attention_reconstruction_loss(c: Configs):
"""
### Initialize the auto-regressive model
"""
return AttentionReconstructionLoss(c.model.transformer.layers)
def main():
"""
### Run the experiment
"""
# Create experiment
experiment.create(name="compressive_transformer", comment='')
# Create configs
conf = Configs()
# Load configurations
experiment.configs(conf,
# A dictionary of configurations to override
{'tokenizer': 'character',
'text': 'tiny_shakespeare',
'optimizer.learning_rate': 2.5e-4,
'optimizer.optimizer': 'AdamW',
'prompt': 'It is',
'prompt_separator': '',
'train_loader': 'sequential_train_loader',
'valid_loader': 'sequential_valid_loader',
'seq_len': 8,
'mem_len': 8,
'epochs': 128,
'batch_size': 32,
'inner_iterations': 25,
})
# Set models for saving and loading
experiment.add_pytorch_models({'model': conf.model})
# Start the experiment
with experiment.start():
# `TrainValidConfigs.run`
conf.run()
#
if __name__ == '__main__':
main()

@@ -9,13 +9,11 @@ summary: A bunch of utility functions and classes
import copy
from torch import nn
from labml_helpers.module import Module
from labml_helpers.module import M, TypedModuleList
def clone_module_list(module: Module, n: int):
def clone_module_list(module: M, n: int) -> TypedModuleList[M]:
"""
## Make a `nn.ModuleList` with clones of a given layer
"""
return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
return TypedModuleList([copy.deepcopy(module) for _ in range(n)])
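A small usage sketch of the typed helper (assuming `TypedModuleList` behaves like `nn.ModuleList`, which is what the return type suggests):

layers = clone_module_list(nn.Linear(16, 16), 3)
assert len(layers) == 3
# The clones are independent copies with their own parameters
assert layers[0].weight.data_ptr() != layers[1].weight.data_ptr()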