diff --git a/labml_nn/transformers/fast_weights/__init__.py b/labml_nn/transformers/fast_weights/__init__.py
index c142e94b..3acde5a6 100644
--- a/labml_nn/transformers/fast_weights/__init__.py
+++ b/labml_nn/transformers/fast_weights/__init__.py
@@ -6,6 +6,7 @@ summary: >
   Linear Transformers Are Secretly Fast Weight Memory Systems in PyTorch.
 ---
 """
+from typing import Optional

 import torch
 from torch import nn
@@ -16,16 +17,27 @@ from labml_nn.transformers.mha import PrepareForMultiHeadAttention
 from labml_nn.utils import clone_module_list


-class LinearAttentionFunction(Module):
-    def __init__(self):
+class DPFP(Module):
+    def __init__(self, nu: int = 1, eps: float = 1e-6):
         super().__init__()
+        self.nu = nu
+        self.r = nn.ReLU()
+        self.eps = eps

     def __call__(self, x: torch.Tensor):
-        return x
+        x = self.dpfp(x)
+        return x / (torch.sum(x, dim=-1, keepdim=True) + self.eps)
+
+    def dpfp(self, x: torch.Tensor):
+        x = torch.cat([self.r(x), self.r(-x)], dim=-1)
+        x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, self.nu + 1)], dim=-1)
+        x_repeat = torch.cat([x] * self.nu, dim=-1)
+
+        return x_repeat * x_rolled


 class FastWeightAttention(Module):
-    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
+    def __init__(self, heads: int, d_model: int, dropout_prob: float, sigma: DPFP):
         super().__init__()

         # Number of features per head
@@ -42,18 +54,21 @@ class FastWeightAttention(Module):
         self.gate = nn.Sequential(PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
                                   nn.Sigmoid())

-        self.sigma = LinearAttentionFunction()
+        self.sigma = sigma

         # Output layer
         self.output = nn.Linear(d_model, d_model)
         # Dropout
         self.dropout = nn.Dropout(dropout_prob)

-    def __call__(self, x: torch.Tensor, weights: torch.Tensor):
+    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
         query = self.sigma(self.query(x))
         key = self.sigma(self.key(x))
         value = self.value(x)

+        if weights is None:
+            weights = key.new_zeros((key.shape[0], key.shape[1], value.shape[2], key.shape[2]))
+
         value_existing = torch.einsum('bhvk,bhk->bhv', weights, key)

         beta = self.gate(x)
@@ -87,7 +102,7 @@ class FastWeightAttentionTransformerLayer(Module):
         self.norm_self_attn = nn.LayerNorm([d_model])
         self.norm_ff = nn.LayerNorm([d_model])

-    def __call__(self, x: torch.Tensor, weights: torch.Tensor):
+    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
         attn, weights = self.attn(x, weights)
         # Add the self attention results
         x = x + self.dropout(attn)
@@ -117,13 +132,13 @@ class FastWeightAttentionTransformer(Module):
         # List to store the outputs
         res = []
         # For each input step
-        weights = [torch.zeros() for _ in range(len(self.layers))]
+        weights = [None for _ in range(len(self.layers))]
         for x in x_seq:
             # Run through each layer
             for i, layer in enumerate(self.layers):
                 # Get layer output
-                x = layer(x, weights[i])
+                x, weights[i] = layer(x, weights[i])

             res.append(x)
diff --git a/labml_nn/transformers/fast_weights/experiment.py b/labml_nn/transformers/fast_weights/experiment.py
new file mode 100644
index 00000000..2d7f51d0
--- /dev/null
+++ b/labml_nn/transformers/fast_weights/experiment.py
@@ -0,0 +1,108 @@
+"""
+---
+title: Train Fast Weights Transformer
+summary: This is training code with notes for a Fast Weights Transformer.
+---
+"""
+
+import torch
+from torch import nn
+
+from labml import experiment
+from labml.configs import option
+from labml.utils.pytorch import get_modules
+from labml_helpers.module import Module
+from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+
+
+class AutoregressiveModel(Module):
+    """
+    ## Auto-regressive model
+    """
+
+    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
+        super().__init__()
+        # Token embedding module
+        self.src_embed = nn.Embedding(n_vocab, d_model)
+        self.transformer = transformer
+        self.generator = nn.Linear(d_model, n_vocab)
+
+    def forward(self, x: torch.Tensor):
+        # Embed the tokens
+        x = self.src_embed(x)
+        # Run it through the transformer
+        res = self.transformer(x)
+        # Generate logits of the next token
+        return self.generator(res), None
+
+
+class Configs(NLPAutoRegressionConfigs):
+    """
+    ## Configurations
+
+    The default configs can and will be overridden when we start the experiment.
+    """
+
+    model: AutoregressiveModel
+
+    d_model: int = 512
+    nu: int = 1
+    heads: int = 8
+    dropout: float = 0.0
+    d_ff: int = 2048
+    n_layers: int = 6
+
+
+@option(Configs.model)
+def fast_weights_transformer(c: Configs):
+    """
+    Create [fast weights transformer](index.html).
+    """
+    from labml_nn.transformers.fast_weights import FastWeightAttentionTransformer, \
+        FastWeightAttentionTransformerLayer, FastWeightAttention, FeedForward
+
+    from labml_nn.transformers.fast_weights import DPFP
+    return AutoregressiveModel(
+        c.n_tokens, c.d_model,
+        FastWeightAttentionTransformer(
+            FastWeightAttentionTransformerLayer(d_model=c.d_model,
+                                                attn=FastWeightAttention(c.heads, c.d_model, c.dropout, DPFP(nu=c.nu)),
+                                                feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
+                                                dropout_prob=c.dropout),
+            c.n_layers)).to(c.device)
+
+
+def main():
+    # Create experiment
+    experiment.create(name="fast_weights_transformer")
+    # Create configs
+    conf = Configs()
+    # Load configurations
+    experiment.configs(conf,
+                       # A dictionary of configurations to override
+                       {'tokenizer': 'character',
+                        'text': 'tiny_shakespeare',
+                        'optimizer.learning_rate': 1.0,
+                        'optimizer.optimizer': 'Noam',
+                        'prompt': 'It is',
+                        'prompt_separator': '',
+
+                        'train_loader': 'shuffled_train_loader',
+                        'valid_loader': 'shuffled_valid_loader',
+
+                        'seq_len': 128,
+                        'epochs': 128,
+                        'batch_size': 16,
+                        'inner_iterations': 25})
+
+    # Set models for saving and loading
+    experiment.add_pytorch_models(get_modules(conf))
+
+    # Start the experiment
+    with experiment.start():
+        # Run the training loop
+        conf.run()
+
+
+if __name__ == '__main__':
+    main()
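Note (not part of the patch): the sketch below is a minimal standalone restatement of the DPFP feature map added in __init__.py, useful as a sanity check of its shapes. For a per-head input of size d, nu rolled copies of the concatenated ReLU features give 2 * nu * d non-negative features, which are then sum-normalized. The function name dpfp and the example shapes here are illustrative only.

import torch


def dpfp(x: torch.Tensor, nu: int = 1, eps: float = 1e-6) -> torch.Tensor:
    # Concatenate ReLU(x) and ReLU(-x) so every feature is non-negative: shape (..., 2 * d)
    x = torch.cat([x.relu(), (-x).relu()], dim=-1)
    # Element-wise products with nu rolled copies expand to (..., 2 * nu * d) features
    x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, nu + 1)], dim=-1)
    x_repeat = torch.cat([x] * nu, dim=-1)
    phi = x_repeat * x_rolled
    # Normalize so the features along the last dimension sum to one
    return phi / (phi.sum(dim=-1, keepdim=True) + eps)


# Example: 4 head vectors with 16 features each, nu = 2
x = torch.randn(4, 16)
phi = dpfp(x, nu=2)
assert phi.shape == (4, 2 * 2 * 16)
assert (phi >= 0).all() and torch.allclose(phi.sum(dim=-1), torch.ones(4))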