Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git, synced 2025-11-01 20:28:41 +08:00
fast weights experiment
@@ -6,6 +6,7 @@ summary: >
   Linear Transformers Are Secretly Fast Weight Memory Systems in PyTorch.
 ---
 """
+from typing import Optional

 import torch
 from torch import nn
@@ -16,16 +17,27 @@ from labml_nn.transformers.mha import PrepareForMultiHeadAttention
 from labml_nn.utils import clone_module_list


-class LinearAttentionFunction(Module):
-    def __init__(self):
+class DPFP(Module):
+    def __init__(self, nu: int = 1, eps: float = 1e-6):
         super().__init__()
+        self.nu = nu
+        self.r = nn.ReLU()
+        self.eps = eps

     def __call__(self, x: torch.Tensor):
-        return x
+        x = self.dpfp(x)
+        return x / (torch.sum(x, dim=-1, keepdim=True) + self.eps)
+
+    def dpfp(self, x: torch.Tensor):
+        x = torch.cat([self.r(x), self.r(-x)], dim=-1)
+        x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, self.nu + 1)], dim=-1)
+        x_repeat = torch.cat([x] * self.nu, dim=-1)
+
+        return x_repeat * x_rolled


 class FastWeightAttention(Module):
-    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
+    def __init__(self, heads: int, d_model: int, dropout_prob: float, sigma: DPFP):
         super().__init__()

         # Number of features per head
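For reference, a minimal stand-alone sketch of the DPFP feature map introduced in this hunk (the helper name dpfp_features and the sample sizes are illustrative, not part of the commit): it concatenates ReLU(x) with ReLU(-x), multiplies the result element-wise with nu rolled copies of itself, so a d_k-dimensional head vector becomes 2 * nu * d_k non-negative features, which __call__ then normalizes by their sum.

import torch
from torch import nn

def dpfp_features(x: torch.Tensor, nu: int = 1, eps: float = 1e-6) -> torch.Tensor:
    """Stand-alone version of DPFP.__call__ above, for illustration only."""
    r = nn.ReLU()
    # [ReLU(x); ReLU(-x)] doubles the feature dimension and makes it non-negative
    x = torch.cat([r(x), r(-x)], dim=-1)
    # element-wise products with nu rolled copies of the features
    x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, nu + 1)], dim=-1)
    x_repeat = torch.cat([x] * nu, dim=-1)
    x = x_repeat * x_rolled
    # normalize so the features (approximately) sum to one
    return x / (torch.sum(x, dim=-1, keepdim=True) + eps)

# illustrative shapes only: batch of 2, 8 heads, d_k = 16
phi = dpfp_features(torch.randn(2, 8, 16), nu=1)
assert phi.shape == (2, 8, 2 * 1 * 16)   # 2 * nu * d_k features per head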
@@ -42,18 +54,21 @@ class FastWeightAttention(Module):
         self.gate = nn.Sequential(PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
                                   nn.Sigmoid())

-        self.sigma = LinearAttentionFunction()
+        self.sigma = sigma

         # Output layer
         self.output = nn.Linear(d_model, d_model)
         # Dropout
         self.dropout = nn.Dropout(dropout_prob)

-    def __call__(self, x: torch.Tensor, weights: torch.Tensor):
+    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
         query = self.sigma(self.query(x))
         key = self.sigma(self.key(x))
         value = self.value(x)

+        if weights is None:
+            weights = key.new_zeros((key.shape[0], key.shape[1], value.shape[2], key.shape[2]))
+
         value_existing = torch.einsum('bhvk,bhk->bhv', weights, key)

         beta = self.gate(x)
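Here weights acts as a per-head fast weight memory of shape (batch, heads, d_v, d_phi), where d_phi is the DPFP feature size, and the einsum above reads the value the memory currently associates with this key. Below is a rough sketch of that read, followed by a delta-rule style write gated by beta; the write is not part of the hunk shown here, so treat it as an assumption about the rest of __call__, based on the paper.

import torch

batch, heads, d_v, d_phi = 2, 8, 16, 32          # illustrative sizes only

weights = torch.zeros(batch, heads, d_v, d_phi)  # fast weight memory, one matrix per head
key = torch.rand(batch, heads, d_phi)            # sigma(key): non-negative DPFP features (illustrative)
value = torch.randn(batch, heads, d_v)
beta = torch.rand(batch, heads, 1)               # sigmoid gate, as produced by self.gate(x)

# read: what the memory currently associates with this key (matches the einsum in the diff)
value_existing = torch.einsum('bhvk,bhk->bhv', weights, key)

# write (assumed, delta-rule style): move the stored association towards the new value
weights = weights + torch.einsum('bhv,bhk->bhvk', beta * (value - value_existing), key)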
@@ -87,7 +102,7 @@ class FastWeightAttentionTransformerLayer(Module):
         self.norm_self_attn = nn.LayerNorm([d_model])
         self.norm_ff = nn.LayerNorm([d_model])

-    def __call__(self, x: torch.Tensor, weights: torch.Tensor):
+    def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
         attn, weights = self.attn(x, weights)
         # Add the self attention results
         x = x + self.dropout(attn)
@@ -117,13 +132,13 @@ class FastWeightAttentionTransformer(Module):
         # List to store the outputs
         res = []
         # For each input step
-        weights = [torch.zeros() for _ in range(len(self.layers))]
+        weights = [None for _ in range(len(self.layers))]

         for x in x_seq:
             # Run through each layer
             for i, layer in enumerate(self.layers):
                 # Get layer output
-                x = layer(x, weights[i])
+                x, weights[i] = layer(x, weights[i])

             res.append(x)

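The per-layer fast weights are now explicit recurrent state: each layer returns its updated memory, which is passed back in at the next time step, and the initial None lets the attention module lazily create a zero memory on the first step. A toy sketch of that threading pattern, with layer_step as a hypothetical stand-in for FastWeightAttentionTransformerLayer:

from typing import List, Optional
import torch

def layer_step(x: torch.Tensor, w: Optional[torch.Tensor]):
    """Hypothetical stand-in for a layer call: returns output and updated memory."""
    if w is None:
        w = torch.zeros(x.shape[0], 1)   # lazily created state, illustrative shape
    return x, w + 1.0                    # dummy update, just to show the threading

n_layers = 3
weights: List[Optional[torch.Tensor]] = [None] * n_layers   # one memory per layer, empty at t = 0
x_seq = torch.randn(5, 4, 16)                                # (seq_len, batch, d_model), illustrative

res = []
for x in x_seq:                        # step through time
    for i in range(n_layers):          # each layer reads and writes its own memory
        x, weights[i] = layer_step(x, weights[i])
    res.append(x)
out = torch.stack(res)                 # same shape as x_seq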

labml_nn/transformers/fast_weights/experiment.py (new file, 108 lines added)
@@ -0,0 +1,108 @@
+"""
+---
+title: Train Fast Weights Transformer
+summary: This is training code with notes for a Fast Weights Transformer.
+---
+"""
+
+import torch
+from torch import nn
+
+from labml import experiment
+from labml.configs import option
+from labml.utils.pytorch import get_modules
+from labml_helpers.module import Module
+from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+
+
+class AutoregressiveModel(Module):
+    """
+    ## Auto-regressive model
+    """
+
+    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
+        super().__init__()
+        # Token embedding module
+        self.src_embed = nn.Embedding(n_vocab, d_model)
+        self.transformer = transformer
+        self.generator = nn.Linear(d_model, n_vocab)
+
+    def forward(self, x: torch.Tensor):
+        # Embed the tokens
+        x = self.src_embed(x)
+        # Run it through the transformer
+        res = self.transformer(x)
+        # Generate logits of the next token
+        return self.generator(res), None
+
+
+class Configs(NLPAutoRegressionConfigs):
+    """
+    ## Configurations
+
+    The default configs can and will be overridden when we start the experiment.
+    """
+
+    model: AutoregressiveModel
+
+    d_model: int = 512
+    nu: int = 1
+    heads: int = 8
+    dropout: float = 0.0
+    d_ff: int = 2048
+    n_layers: int = 6
+
+
+@option(Configs.model)
+def fast_weights_transformer(c: Configs):
+    """
+    Create [fast weights transformer](index.html).
+    """
+    from labml_nn.transformers.fast_weights import FastWeightAttentionTransformer, \
+        FastWeightAttentionTransformerLayer, FastWeightAttention, FeedForward
+
+    from labml_nn.transformers.fast_weights import DPFP
+    return AutoregressiveModel(
+        c.n_tokens, c.d_model,
+        FastWeightAttentionTransformer(
+            FastWeightAttentionTransformerLayer(d_model=c.d_model,
+                                                attn=FastWeightAttention(c.heads, c.d_model, c.dropout, DPFP(nu=c.nu)),
+                                                feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
+                                                dropout_prob=c.dropout),
+            c.n_layers)).to(c.device)
+
+
+def main():
+    # Create experiment
+    experiment.create(name="fast_weights_transformer")
+    # Create configs
+    conf = Configs()
+    # Load configurations
+    experiment.configs(conf,
+                       # A dictionary of configurations to override
+                       {'tokenizer': 'character',
+                        'text': 'tiny_shakespeare',
+                        'optimizer.learning_rate': 1.0,
+                        'optimizer.optimizer': 'Noam',
+                        'prompt': 'It is',
+                        'prompt_separator': '',
+
+                        'train_loader': 'shuffled_train_loader',
+                        'valid_loader': 'shuffled_valid_loader',
+
+                        'seq_len': 128,
+                        'epochs': 128,
+                        'batch_size': 16,
+                        'inner_iterations': 25})
+
+    # Set models for saving and loading
+    experiment.add_pytorch_models(get_modules(conf))
+
+    # Start the experiment
+    with experiment.start():
+        # Run the training loop
+        conf.run()
+
+
+if __name__ == '__main__':
+    main()
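As a rough usage sketch, the model that fast_weights_transformer builds above can also be put together without the labml config machinery. The constructor signatures follow the diff; the vocabulary size, the hyperparameter values, and the (seq_len, batch_size) token layout are illustrative assumptions.

import torch
from labml_nn.transformers.fast_weights import (DPFP, FastWeightAttention,
                                                FastWeightAttentionTransformerLayer,
                                                FastWeightAttentionTransformer, FeedForward)
from labml_nn.transformers.fast_weights.experiment import AutoregressiveModel

d_model, heads, d_ff, n_layers, nu, dropout = 512, 8, 2048, 6, 1, 0.0
n_vocab = 65   # assumed: roughly the character vocabulary of tiny_shakespeare

model = AutoregressiveModel(
    n_vocab, d_model,
    FastWeightAttentionTransformer(
        FastWeightAttentionTransformerLayer(d_model=d_model,
                                            attn=FastWeightAttention(heads, d_model, dropout, DPFP(nu=nu)),
                                            feed_forward=FeedForward(d_model, d_ff, dropout),
                                            dropout_prob=dropout),
        n_layers))

x = torch.randint(0, n_vocab, (128, 16))   # (seq_len, batch_size) token ids, assumed layout
logits, _ = model(x)                       # forward returns (logits, None) as in the diff
print(logits.shape)                        # expected: (128, 16, n_vocab)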