diff --git a/docs/transformers/fast_weights/experiment.html b/docs/transformers/fast_weights/experiment.html
index 8da1c344..0e7d4c35 100644
--- a/docs/transformers/fast_weights/experiment.html
+++ b/docs/transformers/fast_weights/experiment.html
@@ -72,17 +72,21 @@
-
+

Train Fast Weights Transformer

+

This trains a fast weights transformer model for auto-regression.

+

Here’s a Colab notebook for training a fast weights transformer on the Tiny Shakespeare dataset.

+

Open In Colab
+View Run

-
8import torch
-9from torch import nn
-10
-11from labml import experiment
-12from labml.configs import option
-13from labml.utils.pytorch import get_modules
-14from labml_helpers.module import Module
-15from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+
17import torch
+18from torch import nn
+19
+20from labml import experiment
+21from labml.configs import option
+22from labml.utils.pytorch import get_modules
+23from labml_helpers.module import Module
+24from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
@@ -93,7 +97,7 @@

Autoregressive model

-
18class AutoregressiveModel(Module):
+
27class AutoregressiveModel(Module):
@@ -104,8 +108,8 @@
-
23    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
-24        super().__init__()
+
32    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
+33        super().__init__()
@@ -116,9 +120,9 @@

Token embedding module

-
26        self.src_embed = nn.Embedding(n_vocab, d_model)
-27        self.transformer = transformer
-28        self.generator = nn.Linear(d_model, n_vocab)
+
35        self.src_embed = nn.Embedding(n_vocab, d_model)
+36        self.transformer = transformer
+37        self.generator = nn.Linear(d_model, n_vocab)
@@ -129,7 +133,7 @@
-
30    def forward(self, x: torch.Tensor):
+
39    def forward(self, x: torch.Tensor):
@@ -140,7 +144,7 @@

Embed the tokens

-
32        x = self.src_embed(x)
+
41        x = self.src_embed(x)
@@ -151,7 +155,7 @@

Run it through the transformer

-
34        res = self.transformer(x)
+
43        res = self.transformer(x)
@@ -162,7 +166,7 @@

Generate logits of the next token

-
36        return self.generator(res), None
+
45        return self.generator(res), None
@@ -174,7 +178,7 @@

The default configs can and will be overridden when we start the experiment.

-
39class Configs(NLPAutoRegressionConfigs):
+
48class Configs(NLPAutoRegressionConfigs):
@@ -185,14 +189,14 @@
-
46    model: AutoregressiveModel
-47
-48    d_model: int = 512
-49    nu: int = 1
-50    heads: int = 8
-51    dropout: float = 0.0
-52    d_ff: int = 2048
-53    n_layers: int = 6
+
55    model: AutoregressiveModel
+56
+57    d_model: int = 512
+58    nu: int = 1
+59    heads: int = 8
+60    dropout: float = 0.0
+61    d_ff: int = 2048
+62    n_layers: int = 6
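
These fields are exactly what the dictionary passed to experiment.configs() in main() below can override. As a hypothetical illustration (not part of this diff), a smaller model could be requested like this:

from labml import experiment

conf = Configs()
# Hypothetical overrides for a smaller model; the actual run below keeps these defaults
experiment.configs(conf, {'d_model': 256, 'n_layers': 4, 'heads': 4})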
@@ -203,8 +207,8 @@

Create fast weights transformer.

-
56@option(Configs.model)
-57def fast_weights_transformer(c: Configs):
+
65@option(Configs.model)
+66def fast_weights_transformer(c: Configs):
@@ -215,18 +219,18 @@
-
61    from labml_nn.transformers.fast_weights import FastWeightsAttentionTransformer, \
-62        FastWeightsAttentionTransformerLayer, FastWeightsAttention, FeedForward
-63
-64    from labml_nn.transformers.fast_weights import DPFP
-65    return AutoregressiveModel(
-66        c.n_tokens, c.d_model,
-67        FastWeightsAttentionTransformer(
-68            FastWeightsAttentionTransformerLayer(d_model=c.d_model,
-69                                                 attn=FastWeightsAttention(c.heads, c.d_model, c.dropout, DPFP(nu=c.nu)),
-70                                                 feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
-71                                                 dropout_prob=c.dropout),
-72            c.n_layers)).to(c.device)
+
70    from labml_nn.transformers.fast_weights import FastWeightsAttentionTransformer, \
+71        FastWeightsAttentionTransformerLayer, FastWeightsAttention, FeedForward
+72
+73    from labml_nn.transformers.fast_weights import DPFP
+74    return AutoregressiveModel(
+75        c.n_tokens, c.d_model,
+76        FastWeightsAttentionTransformer(
+77            FastWeightsAttentionTransformerLayer(d_model=c.d_model,
+78                                                 attn=FastWeightsAttention(c.heads, c.d_model, c.dropout, DPFP(nu=c.nu)),
+79                                                 feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
+80                                                 dropout_prob=c.dropout),
+81            c.n_layers)).to(c.device)
@@ -237,7 +241,7 @@
-
75def main():
+
84def main():
@@ -248,7 +252,7 @@

Create experiment

-
77    experiment.create(name="fast_weights_transformer")
+
86    experiment.create(name="fast_weights_transformer")
@@ -259,7 +263,7 @@

Create configs

-
79    conf = Configs()
+
88    conf = Configs()
@@ -270,7 +274,7 @@

Load configurations

-
81    experiment.configs(conf,
+
90    experiment.configs(conf,
@@ -281,20 +285,20 @@

A dictionary of configurations to override

-
83                       {'tokenizer': 'character',
-84                        'text': 'tiny_shakespeare',
-85                        'optimizer.learning_rate': 1.0,
-86                        'optimizer.optimizer': 'Noam',
-87                        'prompt': 'It is',
-88                        'prompt_separator': '',
-89
-90                        'train_loader': 'shuffled_train_loader',
-91                        'valid_loader': 'shuffled_valid_loader',
-92
-93                        'seq_len': 128,
-94                        'epochs': 128,
-95                        'batch_size': 16,
-96                        'inner_iterations': 25})
+
92                       {'tokenizer': 'character',
+93                        'text': 'tiny_shakespeare',
+94                        'optimizer.learning_rate': 1.0,
+95                        'optimizer.optimizer': 'Noam',
+96                        'prompt': 'It is',
+97                        'prompt_separator': '',
+98
+99                        'train_loader': 'shuffled_train_loader',
+100                        'valid_loader': 'shuffled_valid_loader',
+101
+102                        'seq_len': 128,
+103                        'epochs': 128,
+104                        'batch_size': 16,
+105                        'inner_iterations': 25})
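
A note on 'optimizer.learning_rate': 1.0 together with 'optimizer.optimizer': 'Noam': under the Noam scheme this value acts as a multiplicative factor on the warm-up schedule from the original Transformer paper, not as an absolute step size. Roughly (the exact constant handling in labml's optimizer configs may differ):

$$\eta(t) = 1.0 \cdot d_{model}^{-0.5} \cdot \min\big(t^{-0.5},\; t \cdot n_{warmup}^{-1.5}\big)$$

where $t$ is the training step.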
@@ -305,7 +309,7 @@

Set models for saving and loading

-
99    experiment.add_pytorch_models(get_modules(conf))
+
108    experiment.add_pytorch_models(get_modules(conf))
@@ -316,7 +320,7 @@

Start the experiment

-
102    with experiment.start():
+
111    with experiment.start():
@@ -327,11 +331,11 @@

Run the training loop

-
104        conf.run()
-105
-106
-107if __name__ == '__main__':
-108    main()
+
113        conf.run()
+114
+115
+116if __name__ == '__main__':
+117    main()
diff --git a/docs/transformers/fast_weights/index.html b/docs/transformers/fast_weights/index.html
index 21a15609..7604576a 100644
--- a/docs/transformers/fast_weights/index.html
+++ b/docs/transformers/fast_weights/index.html
@@ -140,15 +140,19 @@ y^{(i)} &= \frac{1}{z^{(i)} \cdot \color{lightgreen}{\phi(q^{(i)})}}

The paper introduces a new linear attention projection function $\color{lightgreen}{\phi}$, a new update rule for $\color{cyan}{W^{(i)}} = f(\color{cyan}{W^{(i-1)}})$, and changes the normalization $\frac{1}{z^{(i)} \cdot \color{lightgreen}{\phi(q^{(i)})}}$.
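
As a compact restatement of that update rule (my paraphrase of the paper's delta-rule formulation, in the same notation as above, so treat it as a sketch rather than a verbatim quote):

$$
\begin{aligned}
\bar{v}^{(i)} &= \color{cyan}{W^{(i-1)}} \color{lightgreen}{\phi(k^{(i)})} \\
\color{cyan}{W^{(i)}} &= \color{cyan}{W^{(i-1)}} + \beta^{(i)} \big(v^{(i)} - \bar{v}^{(i)}\big) \otimes \color{lightgreen}{\phi(k^{(i)})} \\
y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi(q^{(i)})}
\end{aligned}
$$

where $\beta^{(i)} \in (0, 1)$ is a learned, input-dependent gate, $\bar{v}^{(i)}$ is the value currently stored under key $k^{(i)}$, and the explicit $\frac{1}{z^{(i)} \cdot \color{lightgreen}{\phi(q^{(i)})}}$ normalization is replaced by normalizing $\color{lightgreen}{\phi(k)}$ and $\color{lightgreen}{\phi(q)}$ themselves (the sum-normalization done by the DPFP module below).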

+

Here’s the training code and a notebook for training a fast weights transformer on the Tiny Shakespeare dataset.

+

Open In Colab
+View Run

-
86import torch
-87from torch import nn
-88
-89from labml_helpers.module import Module
-90from labml_nn.transformers.feed_forward import FeedForward
-91from labml_nn.transformers.mha import PrepareForMultiHeadAttention
-92from labml_nn.utils import clone_module_list
+
92import torch
+93from torch import nn
+94
+95from labml_helpers.module import Module
+96from labml_nn.transformers.feed_forward import FeedForward
+97from labml_nn.transformers.mha import PrepareForMultiHeadAttention
+98from labml_nn.utils import clone_module_list
@@ -183,7 +187,7 @@ unless $k^{(i)}$ and $k^{(j)}$ are very similar.

Check the paper for derivation.

-
95class DPFP(Module):
+
101class DPFP(Module):
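
Before the line-by-line diff, here is a minimal self-contained sketch of what DPFP computes, folding the sum-normalization from __call__ into one function. The rolled element-wise products are my reading of the paper and of this module; treat the details as an approximation rather than a verbatim copy of the implementation:

import torch
from torch import nn


def dpfp(k: torch.Tensor, nu: int = 1, eps: float = 1e-6) -> torch.Tensor:
    # x = ReLU([k, -k]): doubles the last dimension and makes everything non-negative
    x = nn.functional.relu(torch.cat([k, -k], dim=-1))
    # Element-wise products of x with nu rolled copies of itself
    x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, nu + 1)], dim=-1)
    x_repeat = torch.cat([x] * nu, dim=-1)
    phi_k = x_repeat * x_rolled
    # Normalize so the features along the last dimension sum to one
    return phi_k / (phi_k.sum(dim=-1, keepdim=True) + eps)


# Example: keys of size 16 map to non-negative features of size 2 * nu * 16 = 32
phi = dpfp(torch.randn(4, 16))
assert phi.shape == (4, 32)

With nu = 1 the feature size is 2d; larger nu trades memory for a more expressive feature map.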
@@ -197,7 +201,7 @@ unless $k^{(i)}$ and $k^{(j)}$ are very similar.

-
129    def __init__(self, nu: int = 1, eps: float = 1e-6):
+
135    def __init__(self, nu: int = 1, eps: float = 1e-6):
@@ -208,10 +212,10 @@ unless $k^{(i)}$ and $k^{(j)}$ are very similar.

-
134        super().__init__()
-135        self.nu = nu
-136        self.relu = nn.ReLU()
-137        self.eps = eps
+
140        super().__init__()
+141        self.nu = nu
+142        self.relu = nn.ReLU()
+143        self.eps = eps
@@ -222,7 +226,7 @@ unless $k^{(i)}$ and $k^{(j)}$ are very similar.

-
139    def __call__(self, k: torch.Tensor):
+
145    def __call__(self, k: torch.Tensor):
@@ -233,7 +237,7 @@ unless $k^{(i)}$ and $k^{(j)}$ are very similar.

Get $\color{lightgreen}{\phi(k)}$

-
141        k = self.dpfp(k)
+
147        k = self.dpfp(k)
@@ -244,7 +248,7 @@ unless $k^{(i)}$ and $k^{(j)}$ are very similar.

Normalize by $\sum^{d_{dot}}_{j=1} \color{lightgreen}{\phi(k)_j}$

-
143        return k / (torch.sum(k, dim=-1, keepdim=True) + self.eps)
+
149        return k / (torch.sum(k, dim=-1, keepdim=True) + self.eps)
@@ -257,7 +261,7 @@ unless $k^{(i)}$ and $k^{(j)}$ are very similar.

-
145    def dpfp(self, k: torch.Tensor):
+
151    def dpfp(self, k: torch.Tensor):
@@ -268,7 +272,7 @@ unless $k^{(i)}$ and $k^{(j)}$ are very similar.

$x = \text{ReLU}\Big(\big[k, -k\big]\Big)$

-
150        x = self.relu(torch.cat([k, -k], dim=-1))
+
156        x = self.relu(torch.cat([k, -k], dim=-1))
@@ -281,7 +285,7 @@ to get