diff --git a/docs/transformers/fast_weights/experiment.html b/docs/transformers/fast_weights/experiment.html
index 8da1c344..0e7d4c35 100644
--- a/docs/transformers/fast_weights/experiment.html
+++ b/docs/transformers/fast_weights/experiment.html
@@ -72,17 +72,21 @@
-
+This trains a fast weights transformer model for auto-regression.
+Here’s a Colab notebook for training a fast weights transformer on the Tiny Shakespeare dataset.
-8import torch
-9from torch import nn
-10
-11from labml import experiment
-12from labml.configs import option
-13from labml.utils.pytorch import get_modules
-14from labml_helpers.module import Module
-15from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+17import torch
+18from torch import nn
+19
+20from labml import experiment
+21from labml.configs import option
+22from labml.utils.pytorch import get_modules
+23from labml_helpers.module import Module
+24from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
-18class AutoregressiveModel(Module):
+27class AutoregressiveModel(Module):
-23 def __init__(self, n_vocab: int, d_model: int, transformer: Module):
-24 super().__init__()
+32 def __init__(self, n_vocab: int, d_model: int, transformer: Module):
+33 super().__init__()
Token embedding module
-26 self.src_embed = nn.Embedding(n_vocab, d_model)
-27 self.transformer = transformer
-28 self.generator = nn.Linear(d_model, n_vocab)
+35 self.src_embed = nn.Embedding(n_vocab, d_model)
+36 self.transformer = transformer
+37 self.generator = nn.Linear(d_model, n_vocab)
-30 def forward(self, x: torch.Tensor):
+39 def forward(self, x: torch.Tensor):
Embed the tokens
-32 x = self.src_embed(x)
+41 x = self.src_embed(x)
Run it through the transformer
-34 res = self.transformer(x)
+43 res = self.transformer(x)
Generate logits of the next token
-36 return self.generator(res), None
+45 return self.generator(res), None
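To make the expected shapes concrete, here is a minimal usage sketch. It is not part of the original file: it assumes the trainer feeds a [seq_len, batch_size] tensor of token ids, and uses nn.Identity as a stand-in transformer, since any module mapping [seq_len, batch_size, d_model] to a tensor of the same shape will do. The second return value is None because the training loop expects a (logits, state) pair; the state slot is presumably there for stateful models.

import torch
from torch import nn

# Hypothetical stand-alone shape check of AutoregressiveModel.
model = AutoregressiveModel(n_vocab=65, d_model=512, transformer=nn.Identity())
x = torch.randint(0, 65, (128, 16))   # token ids, [seq_len=128, batch_size=16]
logits, state = model(x)              # logits: [128, 16, 65], state: None
assert logits.shape == (128, 16, 65) and state is None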
The default configs can and will be overridden when we start the experiment
-39class Configs(NLPAutoRegressionConfigs):
+48class Configs(NLPAutoRegressionConfigs):
-46 model: AutoregressiveModel
-47
-48 d_model: int = 512
-49 nu: int = 1
-50 heads: int = 8
-51 dropout: float = 0.0
-52 d_ff: int = 2048
-53 n_layers: int = 6
+55 model: AutoregressiveModel
+56
+57 d_model: int = 512
+58 nu: int = 1
+59 heads: int = 8
+60 dropout: float = 0.0
+61 d_ff: int = 2048
+62 n_layers: int = 6
Create fast weights transformer.
-56@option(Configs.model)
-57def fast_weights_transformer(c: Configs):
+65@option(Configs.model)
+66def fast_weights_transformer(c: Configs):
-61 from labml_nn.transformers.fast_weights import FastWeightsAttentionTransformer, \
-62 FastWeightsAttentionTransformerLayer, FastWeightsAttention, FeedForward
-63
-64 from labml_nn.transformers.fast_weights import DPFP
-65 return AutoregressiveModel(
-66 c.n_tokens, c.d_model,
-67 FastWeightsAttentionTransformer(
-68 FastWeightsAttentionTransformerLayer(d_model=c.d_model,
-69 attn=FastWeightsAttention(c.heads, c.d_model, c.dropout, DPFP(nu=c.nu)),
-70 feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
-71 dropout_prob=c.dropout),
-72 c.n_layers)).to(c.device)
+70 from labml_nn.transformers.fast_weights import FastWeightsAttentionTransformer, \
+71 FastWeightsAttentionTransformerLayer, FastWeightsAttention, FeedForward
+72
+73 from labml_nn.transformers.fast_weights import DPFP
+74 return AutoregressiveModel(
+75 c.n_tokens, c.d_model,
+76 FastWeightsAttentionTransformer(
+77 FastWeightsAttentionTransformerLayer(d_model=c.d_model,
+78 attn=FastWeightsAttention(c.heads, c.d_model, c.dropout, DPFP(nu=c.nu)),
+79 feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
+80 dropout_prob=c.dropout),
+81 c.n_layers)).to(c.device)
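A note on the mechanics here, based on how labml configs generally behave (an assumption, not something this file spells out): `@option(Configs.model)` registers `fast_weights_transformer` as the calculator for `Configs.model`, and it only runs when `conf.model` is first read after `experiment.configs(...)` has applied any overrides. Roughly:

conf = Configs()
experiment.configs(conf, {'d_model': 256})  # hypothetical override
model = conf.model  # fast_weights_transformer(conf) is computed here (assumed lazy evaluation)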
-75def main():
+84def main():
Create experiment
-77 experiment.create(name="fast_weights_transformer")
+86 experiment.create(name="fast_weights_transformer")
Create configs
-79 conf = Configs()
+88 conf = Configs()
Load configurations
-81 experiment.configs(conf,
+90 experiment.configs(conf,
A dictionary of configurations to override
-83 {'tokenizer': 'character',
-84 'text': 'tiny_shakespeare',
-85 'optimizer.learning_rate': 1.0,
-86 'optimizer.optimizer': 'Noam',
-87 'prompt': 'It is',
-88 'prompt_separator': '',
-89
-90 'train_loader': 'shuffled_train_loader',
-91 'valid_loader': 'shuffled_valid_loader',
-92
-93 'seq_len': 128,
-94 'epochs': 128,
-95 'batch_size': 16,
-96 'inner_iterations': 25})
+92 {'tokenizer': 'character',
+93 'text': 'tiny_shakespeare',
+94 'optimizer.learning_rate': 1.0,
+95 'optimizer.optimizer': 'Noam',
+96 'prompt': 'It is',
+97 'prompt_separator': '',
+98
+99 'train_loader': 'shuffled_train_loader',
+100 'valid_loader': 'shuffled_valid_loader',
+101
+102 'seq_len': 128,
+103 'epochs': 128,
+104 'batch_size': 16,
+105 'inner_iterations': 25})
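Some quick arithmetic on these values, plus the assumption that labml's `inner_iterations` splits each epoch into that many chunks that alternate training with validation and prompt sampling: with a character-level tokenizer, `seq_len: 128` and `batch_size: 16` mean each batch covers $128 \cdot 16 = 2048$ characters, and `inner_iterations: 25` would pause for validation 25 times per epoch.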
Set models for saving and loading
-99 experiment.add_pytorch_models(get_modules(conf))
+108 experiment.add_pytorch_models(get_modules(conf))
Start the experiment
-102 with experiment.start():
+111 with experiment.start():
Run the training loop
-104 conf.run()
-105
-106
-107if __name__ == '__main__':
-108 main()
+113 conf.run()
+114
+115
+116if __name__ == '__main__':
+117 main()
The paper introduces a new linear attention projection function $\color{lightgreen}{\phi}$, a new update rule $\color{cyan}{W^{(i)}} = f(\color{cyan}{W^{(i-1)}})$, and changes the normalization $\frac{1}{z^{(i)} \cdot \color{lightgreen}{\phi(q^{(i)})}}$.
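To spell out where that normalization enters (this is a reconstruction from the formula above and the usual linear-attention notation, not text from this file): the output at step $i$ is read out from the fast weight matrix as $y^{(i)} = \frac{\color{cyan}{W^{(i)}} \color{lightgreen}{\phi(q^{(i)})}}{z^{(i)} \cdot \color{lightgreen}{\phi(q^{(i)})}}$ with $z^{(i)} = \sum_{j=1}^{i} \color{lightgreen}{\phi(k^{(j)})}$, so the three changes correspond to the choice of $\color{lightgreen}{\phi}$ (the DPFP projection below), the rule that produces $\color{cyan}{W^{(i)}}$ from $\color{cyan}{W^{(i-1)}}$, and the handling of that denominator.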
+Here’s the training code and a notebook for training a fast weights transformer on the Tiny Shakespeare dataset.
-86import torch
-87from torch import nn
-88
-89from labml_helpers.module import Module
-90from labml_nn.transformers.feed_forward import FeedForward
-91from labml_nn.transformers.mha import PrepareForMultiHeadAttention
-92from labml_nn.utils import clone_module_list
+92import torch
+93from torch import nn
+94
+95from labml_helpers.module import Module
+96from labml_nn.transformers.feed_forward import FeedForward
+97from labml_nn.transformers.mha import PrepareForMultiHeadAttention
+98from labml_nn.utils import clone_module_list
Check the paper for derivation.
-95class DPFP(Module):
+101class DPFP(Module):
-129 def __init__(self, nu: int = 1, eps: float = 1e-6):
+135 def __init__(self, nu: int = 1, eps: float = 1e-6):
-134 super().__init__()
-135 self.nu = nu
-136 self.relu = nn.ReLU()
-137 self.eps = eps
+140 super().__init__()
+141 self.nu = nu
+142 self.relu = nn.ReLU()
+143 self.eps = eps
-139 def __call__(self, k: torch.Tensor):
+145 def __call__(self, k: torch.Tensor):
Get $\color{lightgreen}{\phi(k)}$
-141 k = self.dpfp(k)
+147 k = self.dpfp(k)
Normalize by $\sum^{d_{dot}}_{j=1} \color{lightgreen}{\phi(k)_j}$
-143 return k / (torch.sum(k, dim=-1, keepdim=True) + self.eps)
+149 return k / (torch.sum(k, dim=-1, keepdim=True) + self.eps)
-145 def dpfp(self, k: torch.Tensor):
+151 def dpfp(self, k: torch.Tensor):
$x = \text{ReLU}\Big(\big[k, -k\big]\Big)$
-150 x = self.relu(torch.cat([k, -k], dim=-1))
+156 x = self.relu(torch.cat([k, -k], dim=-1))
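The diff stops in the middle of dpfp. For completeness, here is a sketch of how the rest of the computation could look, following the paper's DPFP construction of rolling $x$ by $1 \dots \nu$ positions and taking element-wise products; this is a reconstruction under that assumption, not the continuation of the file itself.

# Continuing from x = ReLU([k, -k]), which has 2 * d features on the last dimension.
# Roll x by 1..nu positions along the feature dimension and multiply element-wise,
# giving the 2 * nu * d dimensional DPFP features phi(k).
x_rolled = torch.cat([x.roll(shifts=i, dims=-1) for i in range(1, self.nu + 1)], dim=-1)
x_repeat = torch.cat([x] * self.nu, dim=-1)
return x_repeat * x_rolled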