diff --git a/labml_nn/transformers/compressive/__init__.py b/labml_nn/transformers/compressive/__init__.py
new file mode 100644
index 00000000..d0a22c83
--- /dev/null
+++ b/labml_nn/transformers/compressive/__init__.py
@@ -0,0 +1,183 @@
+from typing import Optional, List
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from labml_helpers.module import Module, TypedModuleList
+from labml_nn.transformers.feed_forward import FeedForward
+from labml_nn.transformers.mha import PrepareForMultiHeadAttention
+from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
+from labml_nn.utils import clone_module_list
+
+
+class Conv1dCompression(Module):
+    def __init__(self, compression_ratio: int, d_model: int):
+        """
+        * `compression_ratio` is the compression ratio; it is the kernel size and stride of the 1D convolution
+        * `d_model` is the embedding size
+        """
+        super().__init__()
+        self.conv = nn.Conv1d(d_model, d_model, kernel_size=compression_ratio, stride=compression_ratio)
+
+    def forward(self, mem: torch.Tensor):
+        """
+        * `mem` has shape `[seq_len, batch, d_model]`
+        """
+
+        # Change the dimensions of `mem` so that we can run it through the convolution layer.
+        # The convolution layer accepts input in the form `[batch, features, sequence]`
+        mem = mem.permute(1, 2, 0)
+        # Get compressed memory
+        c_mem = self.conv(mem)
+        # Permute back to form `[seq_len, batch, d_model]`
+        return c_mem.permute(2, 0, 1)
+
+
+class CompressiveTransformerLayer(Module):
+    def __init__(self, *,
+                 d_model: int,
+                 self_attn: RelativeMultiHeadAttention,
+                 feed_forward: FeedForward,
+                 dropout_prob: float,
+                 compress: Conv1dCompression):
+        """
+        * `d_model` is the token embedding size
+        * `self_attn` is the [self attention module](relative_mha.html)
+        * `feed_forward` is the feed forward module
+        * `dropout_prob` is the probability of dropping out after self attention and FFN
+        * `compress` is the memory compression module
+        """
+        super().__init__()
+        self.compress = compress
+        self.size = d_model
+        self.self_attn = self_attn
+        self.feed_forward = feed_forward
+        self.dropout = nn.Dropout(dropout_prob)
+        self.norm_self_attn = nn.LayerNorm([d_model])
+        self.norm_ff = nn.LayerNorm([d_model])
+
+    def with_memory(self, z: torch.Tensor, mem: Optional[torch.Tensor], c_mem: Optional[torch.Tensor]):
+        """
+        Concatenate the normalized memory and compressed memory with the current token embeddings `z`.
+        """
+        # If there is no memory, just return the token embeddings
+        if mem is None:
+            return z
+
+        # If there is compressed memory, concatenate it before the memory
+        if c_mem is not None:
+            mem = torch.cat((c_mem, mem), dim=0)
+
+        # Normalize the memories and concatenate them with the token embeddings
+        mem = self.norm_self_attn(mem)
+        return torch.cat((mem, z), dim=0)
+
+    def forward(self, *,
+                x: torch.Tensor,
+                mem: Optional[torch.Tensor],
+                c_mem: Optional[torch.Tensor],
+                mask: torch.Tensor):
+        """
+        * `x` are the token level feature vectors of shape `[seq_len, batch_size, d_model]`
+        * `mem` are the past token level feature vectors of shape `[mem_len, batch_size, d_model]`
+        * `mask` is a matrix of shape `[seq_len, c_mem_len + mem_len + seq_len, batch_size]` or `[seq_len, c_mem_len + mem_len + seq_len, 1]`.
+          `mask[i, j]` is true if token at `i` can see token at `j`.
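+        * `c_mem` are the compressed memory vectors of shape `[c_mem_len, batch_size, d_model]`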
+ """ + + # Normalize the vectors before doing self attention + z = self.norm_self_attn(x) + m_z = self.with_memory(z, mem, c_mem) + # Attention + self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask) + # Add the attention results + x = x + self.dropout(self_attn) + + # Normalize for feed-forward + z = self.norm_ff(x) + # Pass through the feed-forward network + ff = self.feed_forward(z) + # Add the feed-forward results back + x = x + self.dropout(ff) + + # + return x + + +class CompressiveTransformer(Module): + """ + ## Transformer XL Model + + This consists of multiple transformer XL layers + """ + + def __init__(self, layer: CompressiveTransformerLayer, n_layers: int): + super().__init__() + # Make copies of the transformer layer + self.layers = clone_module_list(layer, n_layers) + # Final normalization layer + self.norm = nn.LayerNorm([layer.size]) + + def forward(self, x: torch.Tensor, mem: List[torch.Tensor], c_mem: List[torch.Tensor], mask: torch.Tensor): + """ + * `x` are the token embeddings vectors of shape `[seq_len, batch_size, d_model]` + * `mem` are the past token level feature vectors of shape `[mem_len, batch_size, d_model]` for each layer + * `mask` is the masking matrix + """ + # List to store token level feature vectors, + # which will be the memories for the next sequential batch. + new_mem = [] + # Run through each transformer layer + for i, layer in enumerate(self.layers): + # Add to the list of feature vectors + new_mem.append(x.detach()) + # Memory + m = mem[i] if mem else None + # Memory + cm = c_mem[i] if c_mem else None + # Run through the transformer XL layer + x = layer(x=x, mem=m, c_mem=cm, mask=mask) + # Finally, normalize the vectors + return self.norm(x), new_mem + + +class AttentionReconstructionLoss: + def __init__(self, layers: TypedModuleList[CompressiveTransformerLayer]): + self.layers = layers + self.loss_func = nn.MSELoss() + + def prepare_for_attn(self, pmha: PrepareForMultiHeadAttention, x: torch.Tensor): + head_shape = x.shape[:-1] + + # Linear transform + weight = pmha.linear.weight.detach() + bias = pmha.linear.bias.detach() if pmha.linear.bias is not None else None + x = F.linear(x, weight, bias) + + # Split last dimension into heads + x = x.view(*head_shape, pmha.heads, pmha.d_k) + + # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, d_model]` + return x + + def attn(self, layer: RelativeMultiHeadAttention, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor): + query = self.prepare_for_attn(layer.query, query) + key = self.prepare_for_attn(layer.key, key) + value = self.prepare_for_attn(layer.value, value) + + # Compute attention scores $Q K^\top$. + # This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`. 
+        scores = torch.einsum('ibhd,jbhd->ijbh', query, key)
+
+        # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
+        scores *= layer.scale
+
+        # $softmax$ attention along the key sequence dimension
+        # $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
+        attn = layer.softmax(scores)
+
+        # Multiply by values
+        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
+        return torch.einsum("ijbh,jbhd->ibhd", attn, value)
+
+    def calc_loss(self, layer: CompressiveTransformerLayer, h: torch.Tensor, mem: torch.Tensor):
+        # Detach the hidden states and memories so that gradients only flow
+        # through the compression network below
+        h = h.detach()
+        mem = mem.detach()
+
+        # Compress the memories
+        c_mem = layer.compress(mem)
+
+        # MSE between attention over the original memories and attention over the compressed memories
+        return self.loss_func(self.attn(layer.self_attn, h, mem, mem),
+                              self.attn(layer.self_attn, h, c_mem, c_mem))
+
+    def __call__(self, h: List[torch.Tensor], mem: List[torch.Tensor]):
+        # Calculate the loss for each layer and sum them up
+        losses = [self.calc_loss(layer, h[n], mem[n]) for n, layer in enumerate(self.layers)]
+        return sum(losses)
diff --git a/labml_nn/transformers/compressive/experiment.py b/labml_nn/transformers/compressive/experiment.py
new file mode 100644
index 00000000..a7430ca7
--- /dev/null
+++ b/labml_nn/transformers/compressive/experiment.py
@@ -0,0 +1,327 @@
+"""
+---
+title: Compressive Transformer Experiment
+summary: This experiment trains a compressive transformer model on the Tiny Shakespeare dataset.
+---
+
+# Compressive Transformer Experiment
+
+This is an annotated PyTorch experiment to train a compressive transformer model.
+"""
+from typing import List, Tuple, NamedTuple
+
+import torch
+import torch.nn as nn
+
+from labml import experiment, tracker, monit, logger
+from labml.configs import option
+from labml.logger import Text
+from labml_helpers.metrics.simple_state import SimpleStateModule
+from labml_helpers.module import Module
+from labml_helpers.train_valid import BatchIndex, hook_model_outputs
+from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \
+    CompressiveTransformerLayer, Conv1dCompression
+
+
+class CompressedMemory(NamedTuple):
+    mem: List[torch.Tensor]
+    c_mem: List[torch.Tensor]
+
+
+class AutoregressiveModel(Module):
+    """
+    ## Auto-regressive model
+    """
+
+    def __init__(self, n_vocab: int, d_model: int, transformer: CompressiveTransformer):
+        super().__init__()
+        # Token embedding module
+        self.src_embed = nn.Embedding(n_vocab, d_model)
+        # Transformer
+        self.transformer = transformer
+        # Final layer
+        self.generator = nn.Linear(d_model, n_vocab)
+        # Masks
+        self.mask_x = None
+        self.mask_mem = None
+
+    def forward(self, x: torch.Tensor, mem: CompressedMemory):
+        # Unpack the memories
+        if mem is not None:
+            mem, c_mem = mem.mem, mem.c_mem
+        else:
+            mem = []
+            c_mem = []
+
+        # Total length of the memory and compressed memory (for the masks)
+        m_len = len(mem[0]) if mem else 0
+        if c_mem:
+            m_len += len(c_mem[0])
+
+        # Create a subsequent mask for tokens
+        if self.mask_x is None or self.mask_x.shape[0] < len(x):
+            from labml_nn.transformers.utils import subsequent_mask
+            self.mask_x = subsequent_mask(len(x)).to(x.device)
+        # Create an all ones (full visibility) mask for memory
+        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
+            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)
+
+        # Concatenate the masks if there is memory
+        if m_len:
+            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
+        # Use the subsequent mask otherwise
+        else:
+            mask = self.mask_x[:len(x), :len(x)]
+
+        # Token embeddings
+        x = self.src_embed(x)
+        # Run it through the transformer
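+        # `res` are the output features and `mem` is the list of detached per-layer features,
+        # which will become the memory for the next batch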
+        res, mem = self.transformer(x, mem, c_mem, mask)
+        # Generate logits of the next token
+        res = self.generator(res)
+        #
+        return res, mem
+
+
+class Configs(NLPAutoRegressionConfigs):
+    """
+    ## Configurations
+
+    The default configurations can and will be overridden when we start the experiment.
+    """
+
+    model: AutoregressiveModel
+
+    # Token embedding size
+    d_model: int = 128
+    # Number of attention heads
+    heads: int = 4
+    # Dropout probability
+    dropout: float = 0.0
+    # Number of features in FFN hidden layer
+    d_ff: int = 256
+    # Number of transformer layers
+    n_layers: int = 6
+    # Number of memories to keep
+    mem_len: int = 8
+    # State module to maintain memories when switching between training and validation
+    memory = SimpleStateModule()
+    # Attention Reconstruction Loss
+    attention_reconstruction_loss: AttentionReconstructionLoss
+    # Compression ratio
+    compression_ratio: int = 4
+    # Maximum number of compressed memories to keep
+    c_mem_len: int = 128
+
+    def init(self):
+        # Set tracker configurations
+        tracker.set_scalar("accuracy.*", True)
+        tracker.set_scalar("loss.*", True)
+        tracker.set_scalar("ar_loss.*", False)
+        # Add a hook to log module outputs
+        hook_model_outputs(self.mode, self.model, 'model')
+        # This will keep the accuracy metric stats and memories separate for training and validation.
+        self.state_modules = [self.accuracy, self.memory]
+
+    @torch.no_grad()
+    def merge_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
+            -> Tuple[CompressedMemory, List[torch.Tensor]]:
+        """
+        Concatenate the new memories with the old, keep at most `mem_len` memories, and
+        compress the oldest memories that overflow `mem_len` into the compressed memory.
+        Returns the updated `CompressedMemory` and the memories that were compressed in
+        this step (used for the attention reconstruction loss).
+        """
+
+        # If it's configured not to use memory
+        if self.mem_len == 0:
+            return CompressedMemory([], []), []
+
+        # Unpack the memories
+        if mem is not None:
+            mem, c_mem = mem.mem, mem.c_mem
+        else:
+            mem, c_mem = [], []
+        # Concatenate with old memory
+        if mem:
+            mem = [torch.cat((m, x), dim=0) for m, x in zip(mem, new_mem)]
+        else:
+            mem = new_mem
+
+        # If there are more memories than `mem_len`, compress the oldest ones
+        if len(mem[0]) > self.mem_len:
+            # Number of compressed memories to create, rounded up to cover the overflow
+            n_c_mem = (len(mem[0]) - self.mem_len + self.compression_ratio - 1) // self.compression_ratio
+            # Memories that will be compressed and memories that will be kept as-is
+            old_mem = []
+            trunc_mem = []
+            # Split the oldest memories off each layer's memory
+            for m in mem:
+                n_old = n_c_mem * self.compression_ratio
+                cm, m = torch.split(m, [n_old, len(m) - n_old])
+                old_mem.append(cm)
+                trunc_mem.append(m)
+            mem = trunc_mem
+
+            # Compress the oldest memories with each layer's compression module
+            new_c_mem = []
+            for i, layer in enumerate(self.model.transformer.layers):
+                new_c_mem.append(layer.compress(old_mem[i]))
+
+            # Concatenate with the existing compressed memories
+            if c_mem:
+                c_mem = [torch.cat((m, nm), dim=0) for m, nm in zip(c_mem, new_c_mem)]
+            else:
+                c_mem = new_c_mem
+
+            # Truncate old compressed memories
+            if len(c_mem[0]) > self.c_mem_len:
+                c_mem = [m[-self.c_mem_len:] for m in c_mem]
+        # No memories were compressed in this step
+        else:
+            old_mem = []
+
+        #
+        return CompressedMemory(mem, c_mem), old_mem
+
+    def step(self, batch: any, batch_idx: BatchIndex):
+        """
+        ### Training/validation step
+        """
+
+        # Move data to the device
+        data, target = batch[0].to(self.device), batch[1].to(self.device)
+
+        # Update global step (number of tokens processed) when in training mode
+        if self.mode.is_train:
+            tracker.add_global_step(data.shape[0] * data.shape[1])
+
+        # Whether to capture model outputs
+        with self.mode.update(is_log_activations=batch_idx.is_last):
+            # Get memories
+            mem = self.memory.get()
+            # Run the model
+            output, new_mem = self.model(data, mem)
+            # Merge memory
+            mem, old_mem = self.merge_memory(mem, new_mem)
+            # Update memories
+            self.memory.set(mem)
+
+        # Calculate and log cross entropy loss
+        loss = self.loss_func(output, target)
+        tracker.add("loss.", loss)
+
+        # Calculate the attention reconstruction loss if memories were compressed in this step
+        if old_mem:
+            ar_loss = self.attention_reconstruction_loss(new_mem, old_mem)
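+            # Track the attention reconstruction loss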
+ tracker.add("ar_loss.", ar_loss) + # loss = loss + ar_loss + + # Calculate and log accuracy + self.accuracy(output, target) + self.accuracy.track() + + # Train the model + if self.mode.is_train: + # Calculate gradients + loss.backward() + # Clip gradients + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip) + # Take optimizer step + self.optimizer.step() + # Log the model parameters and gradients on last batch of every epoch + if batch_idx.is_last: + tracker.add('model', self.model) + # Clear the gradients + self.optimizer.zero_grad() + + # Save the tracked metrics + tracker.save() + + def sample(self): + """ + ### Sampling function to generate samples periodically while training + """ + + # Starting prompt + prompt = self.prompt + # Collect output for printing + log = [(prompt, Text.subtle)] + # memory + mem = CompressedMemory([], []) + # Sample 25 tokens + for i in monit.iterate('Sample', 25): + # Tokenize the prompt + data = self.text.text_to_i(prompt).unsqueeze(-1) + # Move to device + data = data.to(self.device) + # Get the model output + output, new_mem = self.model(data, mem) + # Get the model prediction (greedy) + output = output.argmax(dim=-1).squeeze(1) + # Add the prediction to prompt + prompt += self.prompt_separator + self.text.itos[output[-1]] + # Only feed the last character to model in next iteration, rest will go in as memories + prompt = prompt[-1:] + # Add the prediction for logging + log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)] + # Update memory + mem, _ = self.merge_memory(mem, new_mem) + + # Print the sampled output + logger.log(log) + + +@option(Configs.model) +def autoregressive_model(c: Configs): + """ + ### Initialize the auto-regressive model + """ + from labml_nn.transformers.xl import RelativeMultiHeadAttention + from labml_nn.transformers.feed_forward import FeedForward + m = AutoregressiveModel(c.n_tokens, c.d_model, CompressiveTransformer( + CompressiveTransformerLayer(d_model=c.d_model, + self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout), + feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout), + dropout_prob=c.dropout, + compress=Conv1dCompression(c.compression_ratio, c.d_model)), c.n_layers)) + return m.to(c.device) + + +@option(Configs.attention_reconstruction_loss) +def attention_reconstruction_loss(c: Configs): + """ + ### Initialize the auto-regressive model + """ + return AttentionReconstructionLoss(c.model.transformer.layers) + + +def main(): + """ + ### Run the experiment + """ + # Create experiment + experiment.create(name="compressive_transformer", comment='') + # Create configs + conf = Configs() + # Load configurations + experiment.configs(conf, + # A dictionary of configurations to override + {'tokenizer': 'character', + 'text': 'tiny_shakespeare', + 'optimizer.learning_rate': 2.5e-4, + 'optimizer.optimizer': 'AdamW', + 'prompt': 'It is', + 'prompt_separator': '', + + 'train_loader': 'sequential_train_loader', + 'valid_loader': 'sequential_valid_loader', + + 'seq_len': 8, + 'mem_len': 8, + 'epochs': 128, + 'batch_size': 32, + 'inner_iterations': 25, + }) + + # Set models for saving and loading + experiment.add_pytorch_models({'model': conf.model}) + + # Start the experiment + with experiment.start(): + # `TrainValidConfigs.run` + conf.run() + + +# +if __name__ == '__main__': + main() diff --git a/labml_nn/utils.py b/labml_nn/utils.py index 4352a55c..1734c974 100644 --- a/labml_nn/utils.py +++ b/labml_nn/utils.py @@ -9,13 +9,11 @@ summary: A bunch of utility 
 import copy
 
-from torch import nn
-
-from labml_helpers.module import Module
+from labml_helpers.module import M, TypedModuleList
 
 
-def clone_module_list(module: Module, n: int):
+def clone_module_list(module: M, n: int) -> TypedModuleList[M]:
     """
     ## Make a `nn.ModuleList` with clones of a given layer
     """
-    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
+    return TypedModuleList([copy.deepcopy(module) for _ in range(n)])
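
A minimal usage sketch (not part of the diff above), assuming the `labml_nn` modules referenced in it are available; it wires a `Conv1dCompression` module into a `CompressiveTransformerLayer`, stacks it with `CompressiveTransformer`, and runs a forward pass without memories:

import torch
from labml_nn.transformers.compressive import (
    CompressiveTransformer, CompressiveTransformerLayer, Conv1dCompression)
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
from labml_nn.transformers.utils import subsequent_mask

d_model, heads, dropout = 128, 4, 0.0
# One layer; the compression module compresses every `compression_ratio` memories into one vector
layer = CompressiveTransformerLayer(
    d_model=d_model,
    self_attn=RelativeMultiHeadAttention(heads, d_model, dropout),
    feed_forward=FeedForward(d_model, 256, dropout),
    dropout_prob=dropout,
    compress=Conv1dCompression(compression_ratio=4, d_model=d_model))
model = CompressiveTransformer(layer, 2)

seq_len, batch_size = 8, 2
x = torch.randn(seq_len, batch_size, d_model)
# No memories on the first call; the subsequent mask lets each position attend to itself and earlier positions
out, new_mem = model(x, [], [], subsequent_mask(seq_len))
# `out` has shape `[seq_len, batch_size, d_model]`; `new_mem` holds the detached inputs to each layer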