fast weights experiment

Varuna Jayasiri
2021-03-08 11:32:19 +05:30
parent 2cb1e3f516
commit 9a49ae47a0
2 changed files with 132 additions and 9 deletions


@@ -6,6 +6,7 @@ summary: >
   Linear Transformers Are Secretly Fast Weight Memory Systems in PyTorch.
 ---
 """
+from typing import Optional
 import torch
 from torch import nn
@@ -16,16 +17,27 @@ from labml_nn.transformers.mha import PrepareForMultiHeadAttention
 from labml_nn.utils import clone_module_list


-class LinearAttentionFunction(Module):
-    def __init__(self):
+class DPFP(Module):
+    def __init__(self, nu: int = 1, eps: float = 1e-6):
         super().__init__()
+        self.nu = nu
+        self.r = nn.ReLU()
+        self.eps = eps

     def __call__(self, x: torch.Tensor):
-        return x
+        x = self.dpfp(x)
+        return x / (torch.sum(x, dim=-1, keepdim=True) + self.eps)
+
+    def dpfp(self, x: torch.Tensor):
+        x = torch.cat([self.r(x), self.r(-x)], dim=-1)
+        x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, self.nu + 1)], dim=-1)
+        x_repeat = torch.cat([x] * self.nu, dim=-1)
+        return x_repeat * x_rolled


 class FastWeightAttention(Module):
-    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
+    def __init__(self, heads: int, d_model: int, dropout_prob: float, sigma: DPFP):
         super().__init__()

         # Number of features per head
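As an aside (not part of the commit), here is a toy sketch of what the new DPFP projection does to the feature dimension. It assumes DPFP is importable from labml_nn.transformers.fast_weights, as the experiment file below does, and the shapes are made up for illustration.

import torch

from labml_nn.transformers.fast_weights import DPFP

# Toy shapes (made up): batch of 4, 8 heads, 16 features per head
phi = DPFP(nu=3)
x = torch.randn(4, 8, 16)
y = phi(x)

# DPFP concatenates relu(x) and relu(-x) (2 * 16 features) and multiplies nu rolled
# copies element-wise, so the projection has 2 * nu * 16 = 96 features per head
assert y.shape == (4, 8, 2 * 3 * 16)
# The projection is sum-normalized, so features sum to (approximately) one
print(y.sum(dim=-1))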
@@ -42,18 +54,21 @@ class FastWeightAttention(Module):
         self.gate = nn.Sequential(PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
                                   nn.Sigmoid())

-        self.sigma = LinearAttentionFunction()
+        self.sigma = sigma

         # Output layer
         self.output = nn.Linear(d_model, d_model)
         # Dropout
         self.dropout = nn.Dropout(dropout_prob)

-    def __call__(self, x: torch.Tensor, weights: torch.Tensor):
+    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
         query = self.sigma(self.query(x))
         key = self.sigma(self.key(x))
         value = self.value(x)

+        if weights is None:
+            weights = key.new_zeros((key.shape[0], key.shape[1], value.shape[2], key.shape[2]))

         value_existing = torch.einsum('bhvk,bhk->bhv', weights, key)

         beta = self.gate(x)
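The rest of __call__ falls outside this hunk. For context, a rough single-head sketch of the gated delta-rule update that the fast weight memory performs; variable names and shapes here are illustrative, not the module's actual code.

import torch

# Toy shapes: batch 2, key/query features 6, value features 5 (single head for brevity)
weights = torch.zeros(2, 5, 6)     # fast weight memory W
key = torch.rand(2, 6)             # sigma(k), non-negative and normalized
query = torch.rand(2, 6)           # sigma(q)
value = torch.randn(2, 5)          # v
beta = torch.rand(2, 1)            # write gate in [0, 1)

# Value currently stored in memory for this key: W sigma(k)
value_existing = torch.einsum('bvk,bk->bv', weights, key)
# Write the gated difference back with an outer product (delta update rule)
weights = weights + torch.einsum('bv,bk->bvk', beta * (value - value_existing), key)
# Read with the query: y = W sigma(q)
y = torch.einsum('bvk,bk->bv', weights, query)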
@@ -87,7 +102,7 @@ class FastWeightAttentionTransformerLayer(Module):
         self.norm_self_attn = nn.LayerNorm([d_model])
         self.norm_ff = nn.LayerNorm([d_model])

-    def __call__(self, x: torch.Tensor, weights: torch.Tensor):
+    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
         attn, weights = self.attn(x, weights)
         # Add the self attention results
         x = x + self.dropout(attn)
@@ -117,13 +132,13 @@ class FastWeightAttentionTransformer(Module):
         # List to store the outputs
         res = []
         # For each input step
-        weights = [torch.zeros() for _ in range(len(self.layers))]
+        weights = [None for _ in range(len(self.layers))]
         for x in x_seq:
             # Run through each layer
             for i, layer in enumerate(self.layers):
                 # Get layer output
-                x = layer(x, weights[i])
+                x, weights[i] = layer(x, weights[i])
             res.append(x)
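With the fixed loop, each layer now hands its updated fast weight memory back so it can be reused at the next step. A rough usage sketch built from the constructors used in the experiment file below; hyper-parameter values are made up, and the output stacking/normalization is outside this hunk.

import torch

from labml_nn.transformers.fast_weights import (DPFP, FastWeightAttention,
                                                FastWeightAttentionTransformer,
                                                FastWeightAttentionTransformerLayer,
                                                FeedForward)

d_model, heads, n_layers = 64, 4, 2
layer = FastWeightAttentionTransformerLayer(d_model=d_model,
                                            attn=FastWeightAttention(heads, d_model, 0.1, DPFP(nu=1)),
                                            feed_forward=FeedForward(d_model, 256, 0.1),
                                            dropout_prob=0.1)
model = FastWeightAttentionTransformer(layer, n_layers)

# The model steps through the sequence dimension, carrying one fast weight
# memory per layer across steps: input is [seq_len, batch_size, d_model]
x = torch.randn(10, 3, d_model)
out = model(x)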


@@ -0,0 +1,108 @@
"""
---
title: Train Fast Weights Transformer
summary: This is training code with notes for a Fast Weights Transformer.
---
"""
import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml.utils.pytorch import get_modules
from labml_helpers.module import Module
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
class AutoregressiveModel(Module):
    """
    ## Auto-regressive model
    """

    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        self.transformer = transformer
        self.generator = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor):
        # Embed the tokens
        x = self.src_embed(x)
        # Run it through the transformer
        res = self.transformer(x)
        # Generate logits of the next token
        return self.generator(res), None
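A quick, hypothetical shape check of this wrapper, with the transformer swapped for nn.Identity just to show the input/output contract; this is not part of the committed file.

import torch
from torch import nn

# Toy sizes: vocabulary of 65 tokens, 32-dimensional embeddings
model = AutoregressiveModel(n_vocab=65, d_model=32, transformer=nn.Identity())
tokens = torch.randint(0, 65, (10, 3))   # [seq_len, batch_size] token ids
logits, _ = model(tokens)
assert logits.shape == (10, 3, 65)       # logits over the vocabulary at each position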
class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    The default configs can and will be overridden when we start the experiment.
    """

    model: AutoregressiveModel

    d_model: int = 512
    nu: int = 1
    heads: int = 8
    dropout: float = 0.0
    d_ff: int = 2048
    n_layers: int = 6
@option(Configs.model)
def fast_weights_transformer(c: Configs):
    """
    Create [fast weights transformer](index.html).
    """
    from labml_nn.transformers.fast_weights import FastWeightAttentionTransformer, \
        FastWeightAttentionTransformerLayer, FastWeightAttention, FeedForward
    from labml_nn.transformers.fast_weights import DPFP

    return AutoregressiveModel(
        c.n_tokens, c.d_model,
        FastWeightAttentionTransformer(
            FastWeightAttentionTransformerLayer(d_model=c.d_model,
                                                attn=FastWeightAttention(c.heads, c.d_model, c.dropout, DPFP(nu=c.nu)),
                                                feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                                                dropout_prob=c.dropout),
            c.n_layers)).to(c.device)
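Since nu, d_model and n_layers are ordinary config values, they can be overridden through the same dictionary mechanism main() uses below. A hypothetical variation, not part of the committed file:

from labml import experiment

# Hypothetical override: a smaller model with a larger DPFP expansion,
# using the Configs class defined above
experiment.create(name="fast_weights_transformer")
conf = Configs()
experiment.configs(conf, {'nu': 3,          # 2 * nu * d_k features per head
                          'd_model': 256,   # smaller embeddings than the default 512
                          'n_layers': 4})   # shallower than the default 6 layers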
def main():
    # Create experiment
    experiment.create(name="fast_weights_transformer")
    # Create configs
    conf = Configs()
    # Load configurations
    experiment.configs(conf,
                       # A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.0,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',
                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',
                        'seq_len': 128,
                        'epochs': 128,
                        'batch_size': 16,
                        'inner_iterations': 25})

    # Set models for saving and loading
    experiment.add_pytorch_models(get_modules(conf))

    # Start the experiment
    with experiment.start():
        # Run the training loop
        conf.run()


if __name__ == '__main__':
    main()