mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
fast weights experiment
@@ -6,6 +6,7 @@ summary: >
   Linear Transformers Are Secretly Fast Weight Memory Systems in PyTorch.
 ---
 """
+from typing import Optional
 
 import torch
 from torch import nn
@@ -16,16 +17,27 @@ from labml_nn.transformers.mha import PrepareForMultiHeadAttention
 from labml_nn.utils import clone_module_list
 
 
-class LinearAttentionFunction(Module):
-    def __init__(self):
+class DPFP(Module):
+    def __init__(self, nu: int = 1, eps: float = 1e-6):
         super().__init__()
+        self.nu = nu
+        self.r = nn.ReLU()
+        self.eps = eps
 
     def __call__(self, x: torch.Tensor):
-        return x
+        x = self.dpfp(x)
+        return x / (torch.sum(x, dim=-1, keepdim=True) + self.eps)
+
+    def dpfp(self, x: torch.Tensor):
+        x = torch.cat([self.r(x), self.r(-x)], dim=-1)
+        x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, self.nu + 1)], dim=-1)
+        x_repeat = torch.cat([x] * self.nu, dim=-1)
+
+        return x_repeat * x_rolled
 
 
 class FastWeightAttention(Module):
-    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
+    def __init__(self, heads: int, d_model: int, dropout_prob: float, sigma: DPFP):
         super().__init__()
 
         # Number of features per head
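A minimal standalone check of the DPFP feature map added in this hunk; the shapes and the nu value are illustrative, and the import path assumes DPFP is exported from labml_nn.transformers.fast_weights the same way the new experiment module imports it:

import torch
from labml_nn.transformers.fast_weights import DPFP

# dpfp() concatenates relu(x) and relu(-x) (2 * d features), then multiplies
# nu rolled copies with nu repeated copies, giving 2 * nu * d features;
# __call__ then normalizes them to sum to one along the last dimension.
sigma = DPFP(nu=3)
x = torch.randn(8, 4, 16)      # illustrative [batch, heads, d_k]
phi = sigma(x)
assert phi.shape == (8, 4, 2 * 3 * 16)
assert (phi >= 0).all()        # products of ReLU outputs are non-negative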
@@ -42,18 +54,21 @@ class FastWeightAttention(Module):
         self.gate = nn.Sequential(PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
                                   nn.Sigmoid())
 
-        self.sigma = LinearAttentionFunction()
+        self.sigma = sigma
 
         # Output layer
         self.output = nn.Linear(d_model, d_model)
         # Dropout
         self.dropout = nn.Dropout(dropout_prob)
 
-    def __call__(self, x: torch.Tensor, weights: torch.Tensor):
+    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
         query = self.sigma(self.query(x))
         key = self.sigma(self.key(x))
         value = self.value(x)
 
+        if weights is None:
+            weights = key.new_zeros((key.shape[0], key.shape[1], value.shape[2], key.shape[2]))
+
         value_existing = torch.einsum('bhvk,bhk->bhv', weights, key)
 
         beta = self.gate(x)
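A sketch of what the new zero initialization above does: with an all-zero fast-weight memory, the einsum retrieval returns zeros, so the first step of a sequence starts from an empty memory. The sizes below are hypothetical, laid out to match the new_zeros call:

import torch

batch, heads, d_k, d_v = 2, 4, 96, 32          # hypothetical sizes; d_k is the DPFP-expanded key size
weights = torch.zeros(batch, heads, d_v, d_k)  # same layout as key.new_zeros(...) above
key = torch.rand(batch, heads, d_k)

# Retrieve the value currently stored for `key` from the fast-weight memory.
value_existing = torch.einsum('bhvk,bhk->bhv', weights, key)
assert value_existing.shape == (batch, heads, d_v)
assert torch.all(value_existing == 0)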
@@ -87,7 +102,7 @@ class FastWeightAttentionTransformerLayer(Module):
         self.norm_self_attn = nn.LayerNorm([d_model])
         self.norm_ff = nn.LayerNorm([d_model])
 
-    def __call__(self, x: torch.Tensor, weights: torch.Tensor):
+    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
         attn, weights = self.attn(x, weights)
         # Add the self attention results
         x = x + self.dropout(attn)
@@ -117,13 +132,13 @@ class FastWeightAttentionTransformer(Module):
         # List to store the outputs
         res = []
         # For each input step
-        weights = [torch.zeros() for _ in range(len(self.layers))]
+        weights = [None for _ in range(len(self.layers))]
 
         for x in x_seq:
             # Run through each layer
             for i, layer in enumerate(self.layers):
                 # Get layer output
-                x = layer(x, weights[i])
+                x, weights[i] = layer(x, weights[i])
 
             res.append(x)
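A short usage sketch of the updated classes, wired together the way the new experiment file below does; the hyper-parameter values and the [seq_len, batch_size, d_model] input layout are assumptions:

import torch
from labml_nn.transformers.fast_weights import (DPFP, FastWeightAttention,
                                                FastWeightAttentionTransformer,
                                                FastWeightAttentionTransformerLayer,
                                                FeedForward)

d_model, heads, d_ff, n_layers = 128, 4, 256, 2   # small, illustrative sizes

layer = FastWeightAttentionTransformerLayer(d_model=d_model,
                                            attn=FastWeightAttention(heads, d_model, 0.1, DPFP(nu=1)),
                                            feed_forward=FeedForward(d_model, d_ff, 0.1),
                                            dropout_prob=0.1)
model = FastWeightAttentionTransformer(layer, n_layers)

# The transformer now starts each layer's fast-weight memory as None and threads
# the returned (output, weights) pair across the steps of the sequence.
x_seq = torch.randn(16, 2, d_model)   # assumed [seq_len, batch_size, d_model]
out = model(x_seq)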
labml_nn/transformers/fast_weights/experiment.py (new file, 108 lines)
@@ -0,0 +1,108 @@
+"""
+---
+title: Train Fast Weights Transformer
+summary: This is training code with notes for a Fast Weights Transformer.
+---
+"""
+
+import torch
+from torch import nn
+
+from labml import experiment
+from labml.configs import option
+from labml.utils.pytorch import get_modules
+from labml_helpers.module import Module
+from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+
+
+class AutoregressiveModel(Module):
+    """
+    ## Auto-regressive model
+    """
+
+    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
+        super().__init__()
+        # Token embedding module
+        self.src_embed = nn.Embedding(n_vocab, d_model)
+        self.transformer = transformer
+        self.generator = nn.Linear(d_model, n_vocab)
+
+    def forward(self, x: torch.Tensor):
+        # Embed the tokens
+        x = self.src_embed(x)
+        # Run it through the transformer
+        res = self.transformer(x)
+        # Generate logits of the next token
+        return self.generator(res), None
+
+
+class Configs(NLPAutoRegressionConfigs):
+    """
+    ## Configurations
+
+    The default configs can and will be overridden when we start the experiment.
+    """
+
+    model: AutoregressiveModel
+
+    d_model: int = 512
+    nu: int = 1
+    heads: int = 8
+    dropout: float = 0.0
+    d_ff: int = 2048
+    n_layers: int = 6
+
+
+@option(Configs.model)
+def fast_weights_transformer(c: Configs):
+    """
+    Create a [fast weights transformer](index.html).
+    """
+    from labml_nn.transformers.fast_weights import FastWeightAttentionTransformer, \
+        FastWeightAttentionTransformerLayer, FastWeightAttention, FeedForward
+
+    from labml_nn.transformers.fast_weights import DPFP
+    return AutoregressiveModel(
+        c.n_tokens, c.d_model,
+        FastWeightAttentionTransformer(
+            FastWeightAttentionTransformerLayer(d_model=c.d_model,
+                                                attn=FastWeightAttention(c.heads, c.d_model, c.dropout, DPFP(nu=c.nu)),
+                                                feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
+                                                dropout_prob=c.dropout),
+            c.n_layers)).to(c.device)
+
+
+def main():
+    # Create experiment
+    experiment.create(name="fast_weights_transformer")
+    # Create configs
+    conf = Configs()
+    # Load configurations
+    experiment.configs(conf,
+                       # A dictionary of configurations to override
+                       {'tokenizer': 'character',
+                        'text': 'tiny_shakespeare',
+                        'optimizer.learning_rate': 1.0,
+                        'optimizer.optimizer': 'Noam',
+                        'prompt': 'It is',
+                        'prompt_separator': '',
+
+                        'train_loader': 'shuffled_train_loader',
+                        'valid_loader': 'shuffled_valid_loader',
+
+                        'seq_len': 128,
+                        'epochs': 128,
+                        'batch_size': 16,
+                        'inner_iterations': 25})
+
+    # Set models for saving and loading
+    experiment.add_pytorch_models(get_modules(conf))
+
+    # Start the experiment
+    with experiment.start():
+        # Run the training loop
+        conf.run()
+
+
+if __name__ == '__main__':
+    main()
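A hedged note on reusing this experiment: since nu is a plain Configs field, it can be overridden through the same experiment.configs call as the other settings. The run name, the nu value, and the choice to keep the remaining defaults below are assumptions for illustration, not part of the commit:

from labml import experiment
from labml_nn.transformers.fast_weights.experiment import Configs

# Hypothetical variant run with a larger DPFP expansion (2 * nu * d_k features per head).
experiment.create(name="fast_weights_transformer_nu3")
conf = Configs()
experiment.configs(conf, {'tokenizer': 'character',
                          'text': 'tiny_shakespeare',
                          'optimizer.learning_rate': 1.0,
                          'optimizer.optimizer': 'Noam',
                          'nu': 3,
                          'seq_len': 128,
                          'batch_size': 16})
with experiment.start():
    conf.run()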