diff --git a/docs/sitemap.xml b/docs/sitemap.xml index 72c09fb8..af3bc135 100644 --- a/docs/sitemap.xml +++ b/docs/sitemap.xml @@ -253,14 +253,21 @@ https://nn.labml.ai/transformers/feedback/experiment.html - 2021-01-14T16:30:00+00:00 + 2021-01-29T16:30:00+00:00 1.00 https://nn.labml.ai/transformers/feedback/index.html - 2021-01-25T16:30:00+00:00 + 2021-01-29T16:30:00+00:00 + 1.00 + + + + + https://nn.labml.ai/transformers/feedback/experiment.html + 2021-01-29T16:30:00+00:00 1.00 diff --git a/docs/transformers/feedback/experiment.html b/docs/transformers/feedback/experiment.html new file mode 100644 index 00000000..bfdefdeb --- /dev/null +++ b/docs/transformers/feedback/experiment.html @@ -0,0 +1,422 @@ + + + + + + + + + + + + + + + + + + + + + + + Train Feedback Transformer + + + + + + + + +
+
+
+
+

+ home + transformers + feedback +

+

+ Github + Join Slack + Twitter +

+
+
+
+
+ +

Train Feedback Transformer

+

This trains a feedback transformer model for auto-regression. +You can pick the original feedback transformer or the new version +where the keys and values are precalculated.

+

Here’s a Colab notebook for training a feedback transformer on the Tiny Shakespeare dataset.

+

Open In Colab +View Run
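As a rough usage sketch (using only the Configs class and the model options defined in this file), switching between the two implementations from your own script is a matter of which model option you pass when overriding the configurations; main() below does the same thing with a fuller set of overrides.

    from labml import experiment
    from labml_nn.transformers.feedback.experiment import Configs

    conf = Configs()
    # 'feedback_transformer' selects the original version,
    # 'feedback_transformer_kv' the version with precalculated keys and values
    experiment.configs(conf, {'model': 'feedback_transformer_kv'})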

+
+
+
19import torch
+20from torch import nn
+21
+22from labml import experiment
+23from labml.configs import option
+24from labml.utils.pytorch import get_modules
+25from labml_helpers.module import Module
+26
+27from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+28from labml_nn.transformers import Encoder, Generator, TransformerConfigs
+29from labml_nn.transformers.utils import subsequent_mask
+
+
+
+
+ +

Auto regressive model

+
+
+
32class AutoregressiveModel(Module):
+
+
+
+
+ + +
+
+
37    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
+38        super().__init__()
+
+
+
+
+ +

Token embedding module

+
+
+
40        self.src_embed = nn.Embedding(n_vocab, d_model)
+41        self.transformer = transformer
+42        self.generator = nn.Linear(d_model, n_vocab)
+
+
+
+
+ + +
+
+
44    def __call__(self, x: torch.Tensor):
+
+
+
+
+ +

Embed the tokens

+
+
+
46        x = self.src_embed(x)
+
+
+
+
+ +

Run it through the transformer

+
+
+
48        res = self.transformer(x)
+
+
+
+
+ +

Generate logits of the next token

+
+
+
50        return self.generator(res), None
+
+
+
+
+ +

Configurations

+

The default configs can and will be overridden when we start the experiment

+
+
+
53class Configs(NLPAutoRegressionConfigs):
+
+
+
+
+ + +
+
+
60    model: AutoregressiveModel
+61
+62    d_model: int = 512
+63    heads: int = 8
+64    dropout: float = 0.0
+65    d_ff: int = 2048
+66    n_layers: int = 6
+
+
+
+
+ +

Create original feedback transformer.

+
+
+
69@option(Configs.model)
+70def feedback_transformer(c: Configs):
+
+
+
+
+ + +
+
+
74    from labml_nn.transformers.feedback import FeedbackTransformer, FeedbackTransformerLayer, \
+75        FeedbackAttention, FeedForward
+76
+77    return AutoregressiveModel(
+78        c.n_tokens, c.d_model,
+79        FeedbackTransformer(
+80            FeedbackTransformerLayer(d_model=c.d_model,
+81                                     attn=FeedbackAttention(c.heads, c.d_model, c.dropout),
+82                                     feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
+83                                     dropout_prob=c.dropout),
+84            c.n_layers)).to(c.device)
+
+
+
+
+ +

Create updated feedback transformer, with precalculated keys and values.

+
+
+
87@option(Configs.model)
+88def feedback_transformer_kv(c: Configs):
+
+
+
+
+ + +
+
+
92    from labml_nn.transformers.feedback import FeedbackTransformerKV, FeedbackTransformerLayer, \
+93        FeedbackAttention, FeedForward
+94
+95    return AutoregressiveModel(
+96        c.n_tokens, c.d_model,
+97        FeedbackTransformerKV(
+98            FeedbackTransformerLayer(d_model=c.d_model,
+99                                     attn=FeedbackAttention(c.heads, c.d_model, c.dropout,
+100                                                            is_kv_precomputed=True),
+101                                     feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
+102                                     dropout_prob=c.dropout),
+103            c.n_layers, c.d_model, c.heads)).to(c.device)
+
+
+
+
+ + +
+
+
106def main():
+
+
+
+
+ +

Create experiment

+
+
+
108    experiment.create(name="feedback_transformer")
+
+
+
+
+ +

Create configs

+
+
+
110    conf = Configs()
+
+
+
+
+ +

Load configurations

+
+
+
112    experiment.configs(conf,
+
+
+
+
+ +

A dictionary of configurations to override

+
+
+
114                       {'tokenizer': 'character',
+115                        'text': 'tiny_shakespeare',
+116                        'optimizer.learning_rate': 1.0,
+117                        'optimizer.optimizer': 'Noam',
+118                        'prompt': 'It is',
+119                        'prompt_separator': '',
+
+
+
+
+ +

Use feedback_transformer for original feedback transformer

+
+
+
122                        'model': 'feedback_transformer_kv',
+123
+124                        'train_loader': 'shuffled_train_loader',
+125                        'valid_loader': 'shuffled_valid_loader',
+126
+127                        'seq_len': 128,
+128                        'epochs': 128,
+129                        'batch_size': 64,
+130                        'inner_iterations': 25})
+
+
+
+
+ +

Set models for saving and loading

+
+
+
133    experiment.add_pytorch_models(get_modules(conf))
+
+
+
+
+ +

Start the experiment

+
+
+
136    with experiment.start():
+
+
+
+
+ +

Run the training loop

+
+
+
138        conf.run()
+139
+140
+141if __name__ == '__main__':
+142    main()
+
+
+
+ + + + + + \ No newline at end of file diff --git a/docs/transformers/feedback/index.html b/docs/transformers/feedback/index.html index 6ea9e2ce..f6807bef 100644 --- a/docs/transformers/feedback/index.html +++ b/docs/transformers/feedback/index.html @@ -85,24 +85,31 @@ if you cache the memory vectors.

In order to speed up the training, the paper discusses starting with a short sequence length and gradually increasing it. They also discuss using a pretrained parallel transformer as the starting point.

-

The feedback transformer doesn’t keep the outputs of all layers. +

The original feedback transformer doesn’t keep the outputs of all layers. Instead it keeps a weighted sum of the outputs of all layers. -This reduces the memory used for caching during prediction.

-

Here’s a notebook for training a feedback transformer on Tiny Shakespeare dataset.

+This reduces the memory used for caching during prediction. +The first half of this file implements this.
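As a minimal sketch of that weighted sum (sizes taken from the defaults used in the experiment; the tensors are random stand-ins):

    import torch

    n_layers, batch_size, d_model = 6, 64, 512
    # outputs of the embedding plus every layer, for a single step
    layer_outputs = torch.randn(n_layers + 1, batch_size, d_model)
    # learnable per-layer weights (a Parameter in the module)
    weights = torch.ones(n_layers + 1)
    # the memory vector for this step is a softmax-weighted sum over layers
    mem = torch.einsum('lbd,l->bd', layer_outputs, torch.softmax(weights, dim=0))
    assert mem.shape == (batch_size, d_model)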

+

The updated feedback transformer shares weights $W^l_k$ and $W^l_v$ used +to calculate keys and values among the layers. +We then calculate the keys and values for each step only once and keep +them cached. +The second half of this file implements this. +We implemented a custom PyTorch function to improve performance.
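A rough sketch of the caching idea, with plain nn.Linear standing in for the PrepareForMultiHeadAttention projections used by the real implementation (which also uses the custom Stack class below instead of Python lists):

    import torch
    from torch import nn

    seq_len, batch_size, d_model, heads = 8, 4, 512, 8
    d_k = d_model // heads
    # a single key projection and a single value projection shared by all layers
    key_proj = nn.Linear(d_model, d_model, bias=False)
    value_proj = nn.Linear(d_model, d_model, bias=False)

    cached_keys, cached_values = [], []
    for step in range(seq_len):
        # stands in for the weighted sum of layer outputs at this step
        mem = torch.randn(batch_size, d_model)
        # project once per step; every layer reuses these cached tensors
        cached_keys.append(key_proj(mem).view(batch_size, heads, d_k))
        cached_values.append(value_proj(mem).view(batch_size, heads, d_k))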

+

Here’s the training code and a notebook for training a feedback transformer on the Tiny Shakespeare dataset.

Open In Colab View Run

-
35import math
-36from typing import Optional
-37
-38import torch
-39from torch import nn
-40
-41from labml_helpers.module import Module
-42from labml_nn.transformers.mha import PrepareForMultiHeadAttention
-43from labml_nn.transformers.feed_forward import FeedForward
-44from labml_nn.utils import clone_module_list
+
43import math
+44from typing import Optional
+45
+46import torch
+47from torch import nn
+48
+49from labml_helpers.module import Module
+50from labml_nn.transformers.feed_forward import FeedForward
+51from labml_nn.transformers.mha import PrepareForMultiHeadAttention
+52from labml_nn.utils import clone_module_list
@@ -118,7 +125,7 @@ paper.

-
47class FeedbackAttention(Module):
+
55class FeedbackAttention(Module):
@@ -130,10 +137,12 @@ paper.

  • ‘heads’ is the number of attention heads
  • d_model is the number of features in the transformer
  • dropout_prob is the attention dropout probability
  • +
  • is_kv_precomputed is whether key, value tensors are already calculated
  • -
    58    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
    +
    66    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, *,
    +67                 is_kv_precomputed: bool = False):
    @@ -144,7 +153,7 @@ paper.

    -
    65        super().__init__()
    +
    75        super().__init__()
    @@ -155,7 +164,7 @@ paper.

    Number of features per head

    -
    68        self.d_k = d_model // heads
    +
    78        self.d_k = d_model // heads
    @@ -166,7 +175,7 @@ paper.

    -
    70        self.heads = heads
    +
    80        self.heads = heads
    @@ -174,12 +183,10 @@ paper.

    -

    These transform the query, key and value vectors for multi-headed attention.

    +

These transform the query for multi-headed attention.

    -
    73        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k,  bias=False)
    -74        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k,  bias=False)
    -75        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k,  bias=True)
    +
    83        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
    @@ -187,10 +194,12 @@ paper.

    -

    Output layer

    +

These transform the key and value for multi-headed attention.

    -
    78        self.output = nn.Linear(d_model, d_model)
    +
    85        if not is_kv_precomputed:
    +86            self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
    +87            self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
    @@ -198,10 +207,12 @@ paper.

    -

    Dropout

    +

    Keys and values are already calculated

    -
    80        self.dropout = nn.Dropout(dropout_prob)
    +
    89        else:
    +90            self.key = None
    +91            self.value = None
    @@ -209,10 +220,10 @@ paper.

    -

    Scaling factor before the softmax

    +

    Output layer

    -
    82        self.scale = 1 / math.sqrt(self.d_k)
    +
    94        self.output = nn.Linear(d_model, d_model)
    @@ -220,10 +231,10 @@ paper.

    -

    Softmax for attention along the time dimension of key

    +

    Dropout

    -
    85        self.softmax = nn.Softmax(dim=0)
    +
    96        self.dropout = nn.Dropout(dropout_prob)
    @@ -231,10 +242,10 @@ paper.

    -

    Number of relative positions

    +

    Scaling factor before the softmax

    -
    88        self.P = 2 ** 12
    +
    98        self.scale = 1 / math.sqrt(self.d_k)
    @@ -242,10 +253,10 @@ paper.

    -

    Relative positional embeddings for key relative to the query.

    +

    Softmax for attention along the time dimension of key

    -
    91        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
    +
    101        self.softmax = nn.Softmax(dim=0)
    @@ -253,10 +264,10 @@ paper.

    -

    Positional embeddings for the query is independent of the position of the query

    +

    Number of relative positions

    -
    93        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
    +
    104        self.P = 2 ** 12
    @@ -264,32 +275,21 @@ paper.

    -

We store attentions so that they can be used for logging, or other computations if needed

    +

    Relative positional embeddings for key relative to the query.

    -
    96        self.attn = None
    +
    107        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
    -
    +
    -

    Get attention scores

    -

    - -

    -

    where $Q, K_j$, are linear transformations of - original embeddings $X^q, X^k_j$ - and $U^Q, U^K_j$ are linear transformations of - absolute positional encodings $P_q, P_j$.

    +

    Relative positional embedding bias for key relative to the query.

    -
    98    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
    +
    109        self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
    @@ -297,10 +297,10 @@ A_{j} &= Q^\top K_j \\ -

    $U^K_j$

    +

    Positional embeddings for the query is independent of the position of the query

    -
    115        key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
    +
    111        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
    @@ -308,38 +308,51 @@ A_{j} &= Q^\top K_j \\ -

    $U^Q$

    +

We store attentions so that they can be used for logging, or other computations if needed

    -
    117        query_pos_bias = self.query_pos_bias[None, :, :]
    +
    114        self.attn = None
    -
    +
    -

    $(Q + U^Q)^\top(K_j + U^K_j)$

    +

    Get attention scores

    +

We use relative positional encodings for attention, similar +to relative multi-head attention from the Transformer-XL paper.

    +

    Attention from current step’s query to key in step $j$ (relative to current step) is,

    +

\begin{align}
A_{j} &= Q^\top K_j \\
&= lin_q(X^q + P_q)^\top lin_k(X^k_j + P_j) \\
&= (Q + U^Q)^\top(K_j + U^K_j) \\
&= \underset{\color{lightgreen}{A}}{Q^\top K_j} +
   \underset{\color{lightgreen}{B}}{Q^\top U^K_j} +
   \underset{\color{lightgreen}{C}}{{U^Q}^\top K_j} +
   \underset{\color{lightgreen}{D}}{{U^Q}^\top U^K_j}
\end{align}

    +

    where $Q, K_j$, are linear transformations of + original embeddings $X^q, X^k_j$ + and $U^Q, U^K_j$ are linear transformations of + positional encodings $P_q, P_j$.

    +

    We replace term $\color{lightgreen}{D}$ with $S_j$.
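As a quick shape check of these terms (random values and arbitrary sizes; the einsum strings match the code that follows):

    import torch

    seq_len, batch_size, heads, d_k = 5, 2, 8, 64
    query = torch.randn(batch_size, heads, d_k)
    key = torch.randn(seq_len, batch_size, heads, d_k)
    query_pos_bias = torch.zeros(1, heads, d_k)     # U^Q, broadcast over the batch
    key_pos_emb = torch.randn(seq_len, heads, d_k)  # U^K_j
    key_pos_bias = torch.zeros(seq_len, heads)      # S_j

    # A + C: (Q + U^Q)^T K_j
    ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
    # B + D: Q^T U^K_j + S_j
    bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :]
    assert (ac + bd).shape == (seq_len, batch_size, heads)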

    -
    120        return torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key + key_pos_emb[:, None, :, :])
    +
    116    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
    -
    +
    -
      -
    • query has shape [batch_size, d_model]
    • -
    • key and value has shape [seq_len, batch_size, d_model]
    • -
    +

    $U^K_j$

    -
    122    def __call__(self, *,
    -123                 query: torch.Tensor,
    -124                 key: torch.Tensor,
    -125                 value: torch.Tensor):
    +
    144        key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
    @@ -347,14 +360,10 @@ A_{j} &= Q^\top K_j \\ -

    Prepare query, key and value for attention computation -key and value will then have shape [seq_len, batch_size, heads, d_k] -and query will have shape [batch_size, heads, d_k]

    +

    $U^Q$

    -
    134        query = self.query(query)
    -135        key = self.key(key)
    -136        value = self.value(value)
    +
    146        query_pos_bias = self.query_pos_bias[None, :, :]
    @@ -362,11 +371,10 @@ and query will have shape [batch_size, heads, d_k]

    -

    Compute attention scores -Results in a tensor of shape [seq_len, batch_size, heads]

    +

    $S_j$

    -
    140        scores = self.get_scores(query, key)
    +
    148        key_pos_bias = self.key_pos_bias[-key.shape[0]:]
    @@ -374,10 +382,10 @@ Results in a tensor of shape [seq_len, batch_size, heads]

    -

    Scale scores $\frac{1}{\sqrt{d_k}}$

    +

    $\underset{\color{lightgreen}{A}}{Q^\top K_j} + \underset{\color{lightgreen}{C}}{{U^Q}^\top K_j}$

    -
    143        scores *= self.scale
    +
    151        ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
    @@ -385,10 +393,10 @@ Results in a tensor of shape [seq_len, batch_size, heads]

    -

    Softmax

    +

    $\underset{\color{lightgreen}{B}}{Q^\top U^K_j} + \underset{\color{lightgreen}{D}}{S_j}$

    -
    146        attn = self.softmax(scores)
    +
    153        bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :]
    @@ -396,21 +404,27 @@ Results in a tensor of shape [seq_len, batch_size, heads]

    -

    Apply dropout

    +

    $A_j$

    -
    149        attn = self.dropout(attn)
    +
    156        return ac + bd
    -
    +
    -

    Multiply by the values

    +
      +
    • query has shape [batch_size, d_model]
    • +
• key and value have shape [seq_len, batch_size, d_model]
    • +
    -
    152        x = torch.einsum("jbh,jbhd->bhd", attn, value)
    +
    158    def __call__(self, *,
    +159                 query: torch.Tensor,
    +160                 key: torch.Tensor,
    +161                 value: torch.Tensor):
    @@ -418,10 +432,16 @@ Results in a tensor of shape [seq_len, batch_size, heads]

    -

    Concatenate multiple heads

    +

    Prepare query, key and value for attention computation +key and value will then have shape [seq_len, batch_size, heads, d_k] +and query will have shape [batch_size, heads, d_k]

    -
    155        x = x.reshape(x.shape[0], -1)
    +
    170        query = self.query(query)
    +171        if self.key:
    +172            key = self.key(key)
    +173        if self.value:
    +174            value = self.value(value)
    @@ -429,28 +449,95 @@ Results in a tensor of shape [seq_len, batch_size, heads]

    -

    Output layer

    +

    Compute attention scores +Results in a tensor of shape [seq_len, batch_size, heads]

    -
    158        return self.output(x)
    +
    178        scores = self.get_scores(query, key)
    -
    +
    +

    Scale scores $\frac{1}{\sqrt{d_k}}$

    +
    +
    +
    181        scores *= self.scale
    +
    +
    +
    +
    + +

    Softmax

    +
    +
    +
    184        attn = self.softmax(scores)
    +
    +
    +
    +
    + +

    Apply dropout

    +
    +
    +
    187        attn = self.dropout(attn)
    +
    +
    +
    +
    + +

    Multiply by the values

    +
    +
    +
    190        x = torch.einsum("jbh,jbhd->bhd", attn, value)
    +
    +
    +
    +
    + +

    Concatenate multiple heads

    +
    +
    +
    193        x = x.reshape(x.shape[0], -1)
    +
    +
    +
    +
    + +

    Output layer

    +
    +
    +
    196        return self.output(x)
    +
    +
    +
    +
    +

    Feedback Transformer Layer

    This implements a single transformer layer in the feedback transformer.

    -
    161class FeedbackTransformerLayer(Module):
    +
    199class FeedbackTransformerLayer(Module):
    -
    +
    • d_model is the number of features in the transformer
    • @@ -460,82 +547,11 @@ Results in a tensor of shape [seq_len, batch_size, heads]

    -
    168    def __init__(self, *,
    -169                 d_model: int,
    -170                 attn: FeedbackAttention,
    -171                 feed_forward: FeedForward,
    -172                 dropout_prob: float):
    -
    -
    -
    -
    - - -
    -
    -
    179        super().__init__()
    -
    -
    -
    -
    - -

    Transformer size $d_{model}$

    -
    -
    -
    181        self.size = d_model
    -
    -
    -
    -
    - - -
    -
    -
    183        self.attn = attn
    -184        self.feed_forward = feed_forward
    -185        self.dropout = nn.Dropout(dropout_prob)
    -
    -
    -
    -
    - -

    Normalization layers

    -
    -
    -
    188        self.norm_self_attn = nn.LayerNorm([d_model])
    -189        self.norm_ff = nn.LayerNorm([d_model])
    -
    -
    -
    -
    - - -
    -
    -
    191    def __call__(self, *,
    -192                 x: torch.Tensor,
    -193                 mem: Optional[torch.Tensor]):
    -
    -
    -
    -
    - -

    If there is memory

    -
    -
    -
    195        if mem is not None:
    +
    206    def __init__(self, *,
    +207                 d_model: int,
    +208                 attn: FeedbackAttention,
    +209                 feed_forward: FeedForward,
    +210                 dropout_prob: float):
    @@ -543,10 +559,10 @@ Results in a tensor of shape [seq_len, batch_size, heads]

    -

    Normalize the vectors before doing self attention

    +
    -
    197            z = self.norm_self_attn(x)
    +
    217        super().__init__()
    @@ -554,10 +570,10 @@ Results in a tensor of shape [seq_len, batch_size, heads]

    -

    Run through self attention, i.e. keys and values are from self

    +

    Transformer size $d_{model}$

    -
    199            self_attn = self.attn(query=z, key=mem, value=mem)
    +
    219        self.size = d_model
    @@ -565,10 +581,12 @@ Results in a tensor of shape [seq_len, batch_size, heads]

    -

    Add the self attention results

    +
    -
    201            x = x + self.dropout(self_attn)
    +
    221        self.attn = attn
    +222        self.feed_forward = feed_forward
    +223        self.dropout = nn.Dropout(dropout_prob)
    @@ -576,10 +594,11 @@ Results in a tensor of shape [seq_len, batch_size, heads]

    -

    Normalize for feed-forward

    +

    Normalization layers

    -
    204        z = self.norm_ff(x)
    +
    226        self.norm_self_attn = nn.LayerNorm([d_model])
    +227        self.norm_ff = nn.LayerNorm([d_model])
    @@ -587,10 +606,13 @@ Results in a tensor of shape [seq_len, batch_size, heads]

    -

    Pass through the feed-forward network

    +
    -
    206        ff = self.feed_forward(z)
    +
    229    def __call__(self, *,
    +230                 x: torch.Tensor,
    +231                 key: Optional[torch.Tensor],
    +232                 value: Optional[torch.Tensor]):
    @@ -598,10 +620,10 @@ Results in a tensor of shape [seq_len, batch_size, heads]

    -

    Add the feed-forward results back

    +

    If there is memory

    -
    208        x = x + self.dropout(ff)
    +
    234        if key is not None:
    @@ -609,35 +631,32 @@ Results in a tensor of shape [seq_len, batch_size, heads]

    - +

    Normalize the vectors before doing self attention

    -
    211        return x
    +
    236            z = self.norm_self_attn(x)
    -
    +
    -

    Feedback Transformer Module

    +

    Run through self attention, i.e. keys and values are from self

    -
    214class FeedbackTransformer(Module):
    +
    238            self_attn = self.attn(query=z, key=key, value=value)
    -
    +
    -
      -
    • layer is the feedback transformer layer, which we clone for each layer
    • -
    • n_layers is the number of layers in the transformer
    • -
    +

    Add the self attention results

    -
    219    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int):
    +
    240            x = x + self.dropout(self_attn)
    @@ -645,10 +664,10 @@ Results in a tensor of shape [seq_len, batch_size, heads]

    - +

    Normalize for feed-forward

    -
    225        super().__init__()
    +
    243        z = self.norm_ff(x)
    @@ -656,10 +675,10 @@ Results in a tensor of shape [seq_len, batch_size, heads]

    -

    Make copies of the transformer layer

    +

    Pass through the feed-forward network

    -
    227        self.layers = clone_module_list(layer, n_layers)
    +
    245        ff = self.feed_forward(z)
    @@ -667,10 +686,10 @@ Results in a tensor of shape [seq_len, batch_size, heads]

    -

    Final normalization layer

    +

    Add the feed-forward results back

    -
    229        self.norm = nn.LayerNorm([layer.size])
    +
    247        x = x + self.dropout(ff)
    @@ -678,22 +697,21 @@ Results in a tensor of shape [seq_len, batch_size, heads]

    -

    Memory vectors are computed as a weighted sum of representations of each layer. -This is the weights parameter for that.

    +
    -
    232        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
    +
    250        return x
    -
    +
    -

    Softmax for weights before taking the weighted sum

    +

    Feedback Transformer Module

    -
    234        self.softmax = nn.Softmax(0)
    +
    253class FeedbackTransformer(Module):
    @@ -702,11 +720,12 @@ This is the weights parameter for that.

    #
    -
    236    def __call__(self, x_seq: torch.Tensor):
    +
    258    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int):
    @@ -714,10 +733,10 @@ This is the weights parameter for that.

    -

    Split the input to a list along the sequence axis

    +
    -
    242        x_seq = torch.unbind(x_seq, dim=0)
    +
    264        super().__init__()
    @@ -725,10 +744,10 @@ This is the weights parameter for that.

    -

    List to store the outputs

    +

    Make copies of the transformer layer

    -
    244        res = []
    +
    266        self.layers = clone_module_list(layer, n_layers)
    @@ -736,10 +755,10 @@ This is the weights parameter for that.

    -

    List to store the memory vectors

    +

    Final normalization layer

    -
    246        mem = []
    +
    268        self.norm = nn.LayerNorm([layer.size])
    @@ -747,10 +766,11 @@ This is the weights parameter for that.

    -

    For each input step

    +

    Memory vectors are computed as a weighted sum of representations of each layer. +This is the weights parameter for that.

    -
    248        for x in x_seq:
    +
    271        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
    @@ -758,21 +778,23 @@ This is the weights parameter for that.

    -

    List to store layer outputs

    +

    Softmax for weights before taking the weighted sum

    -
    250            layer_outputs = [x]
    +
    273        self.softmax = nn.Softmax(0)
    -
    +
    -

    If there is memory, stack them into a vector

    +
      +
    • x_seq is the input with shape [seq_len, batch_size, d_model]
    • +
    -
    253            mem_tensor = torch.stack(mem) if mem else None
    +
    275    def __call__(self, x_seq: torch.Tensor):
    @@ -780,10 +802,10 @@ This is the weights parameter for that.

    -

    Run through each layer

    +

    Split the input to a list along the sequence axis

    -
    256            for layer in self.layers:
    +
    281        x_seq = torch.unbind(x_seq, dim=0)
    @@ -791,10 +813,10 @@ This is the weights parameter for that.

    -

    Get layer output

    +

    List to store the outputs

    -
    258                x = layer(x=x, mem=mem_tensor)
    +
    283        res = []
    @@ -802,10 +824,10 @@ This is the weights parameter for that.

    -

    Append them to the list of layer outputs

    +

    List to store the memory vectors

    -
    260                layer_outputs.append(x)
    +
    285        mem = []
    @@ -813,10 +835,10 @@ This is the weights parameter for that.

    -

    Stack the layer outputs to a tensor

    +

    For each input step

    -
    263            layer_outputs = torch.stack(layer_outputs)
    +
    287        for x in x_seq:
    @@ -824,10 +846,10 @@ This is the weights parameter for that.

    -

    Calculate the memory vector as a weighted sum of layer outputs

    +

    List to store layer outputs

    -
    265            mem.append(torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights)))
    +
    289            layer_outputs = [x]
    @@ -835,10 +857,10 @@ This is the weights parameter for that.

    -

    Append the output to results

    +

    If there is memory, stack them into a vector

    -
    267            res.append(x)
    +
    292            mem_tensor = torch.stack(mem) if mem else None
    @@ -846,10 +868,10 @@ This is the weights parameter for that.

    -

    Stack the output tensors

    +

    Run through each layer

    -
    270        res = torch.stack(res)
    +
    295            for layer in self.layers:
    @@ -857,10 +879,765 @@ This is the weights parameter for that.

    +

    Get layer output

    +
    +
    +
    297                x = layer(x=x, key=mem_tensor, value=mem_tensor)
    +
    + +
    +
    + +

    Append them to the list of layer outputs

    +
    +
    +
    299                layer_outputs.append(x)
    +
    +
    +
    +
    + +

    Stack the layer outputs to a tensor

    +
    +
    +
    302            layer_outputs = torch.stack(layer_outputs)
    +
    +
    +
    +
    + +

    Calculate the memory vector as a weighted sum of layer outputs

    +
    +
    +
    304            mem.append(torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights)))
    +
    +
    +
    +
    + +

    Append the output to results

    +
    +
    +
    306            res.append(x)
    +
    +
    +
    +
    + +

    Stack the output tensors

    +
    +
    +
    309        res = torch.stack(res)
    +
    +
    +
    +
    +

    Normalize the output

    -
    272        return self.norm(res)
    +
    311        return self.norm(res)
    +
    +
    +
    +
    + +

    +

Shared keys and values among layers

    +

    +
    +
    +
    +
    +
    +
    +
    + +

    Stack Function implementation

    +

We implement a custom function instead of appending to a Python list +and then doing torch.stack. +This greatly improves the performance over calling torch.stack at +each step along the sequence. +Every time torch.stack is called it creates a new tensor, while +this method and the accompanying class Stack share memory for each step.
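For contrast, this is roughly the pattern it replaces; each call to torch.stack copies everything accumulated so far into a freshly allocated tensor:

    import torch

    steps, batch_size, d_model = 8, 4, 512
    mem = []
    for step in range(steps):
        mem.append(torch.randn(batch_size, d_model))
        # a new [step + 1, batch_size, d_model] tensor is allocated every step
        stacked = torch.stack(mem)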

    +
    +
    +
    318class StackFunction(torch.autograd.Function):
    +
    +
    +
    +
    + + +

This returns the stacked tensor for steps up to n.

    +
    +
    +
    330    @staticmethod
    +331    def forward(ctx, memory, memory_grad, last, n):
    +
    +
    +
    +
    + +

    Cache accumulated gradients

    +
    +
    +
    343        ctx._mem_grad = memory_grad
    +
    +
    +
    +
    + +

    Cache the size of the stack

    +
    +
    +
    345        ctx._n = n
    +
    +
    +
    +
    + +

    Return the stack

    +
    +
    +
    347        return memory[:n + 1]
    +
    +
    +
    +
    + + +

This accumulates the gradients in the shared memory tensor and returns the +gradients with respect to the last result in the stack.

    +
    +
    +
    349    @staticmethod
    +350    def backward(ctx, grad_output):
    +
    +
    +
    +
    + +

    Get the current size of the stack

    +
    +
    +
    358        n = ctx._n
    +
    +
    +
    +
    + +

    Get the accumulated gradients

    +
    +
    +
    360        memory_grad = ctx._mem_grad
    +
    +
    +
    +
    + +

    Add the gradients

    +
    +
    +
    362        memory_grad[:n + 1] += grad_output
    +
    +
    +
    +
    + +

    Return the gradients w.r.t to last value in the stack

    +
    +
    +
    364        return None, None, memory_grad[n], None
    +
    +
    +
    +
    + +

    Stack Module

    +

    This uses the stack function defined above, and does the necessary initializations.

    +
    +
    +
    367class Stack:
    +
    +
    +
    +
    + + +
    +
    +
    373    def __init__(self, max_len: int):
    +
    +
    +
    +
    + + +
    +
    +
    377        self.max_len = max_len
    +378        self.memory = None
    +379        self.memory_grad = None
    +380        self.last = None
    +381        self.n = -1
    +382        self.last_get_n = -1
    +
    +
    +
    +
    + + +
    +
    +
    384    def append(self, n: int, value: torch.Tensor):
    +
    +
    +
    +
    + +

    You need to get (use) the stack after adding a value. +Otherwise this implementation fails

    +
    +
    +
    392        assert n == 0 or self.last_get_n == n - 1, f"{n}, {self.last_get_n}"
    +
    +
    +
    +
    + +

    Do this without gradients

    +
    +
    +
    395        with torch.no_grad():
    +
    +
    +
    +
    + +

    Initialize the shared memory tensor to keep the stack

    +
    +
    +
    397            if self.memory is None or self.memory.shape[1:] != value.shape:
    +
    +
    +
    +
    + +

    This should only happen when the stack is empty

    +
    +
    +
    399                assert n == 0
    +
    +
    +
    +
    + +

    Create a tensor for the stack

    +
    +
    +
    401                self.memory = value.new_zeros(self.max_len, *value.shape, requires_grad=False)
    +
    +
    +
    +
    + +

    Create a tensor to accumulate the gradients

    +
    +
    +
    403                self.memory_grad = value.new_zeros(self.memory.shape, requires_grad=False)
    +
    +
    +
    +
    + +

    The memory is already initialized but we are resetting the stack.

    +

    This could have been another function like reset, but +we found this easier to use.

    +
    +
    +
    408            elif n == 0:
    +
    +
    +
    +
    + +

    Reset accumulated gradients

    +
    +
    +
    410                self.memory_grad.fill_(0.)
    +
    +
    +
    +
    + +

    Set the value in the correct position of the stack

    +
    +
    +
    413            self.memory.data[n] = value.detach()
    +
    +
    +
    +
    + +

    Keep track of the stack (for debugging)

    +
    +
    +
    415            self.n = n
    +
    +
    +
    +
    + +

    Keep track of the last value added to the stack. +We need this to be passed on to StackFunction in order +to get the gradients propagated backwards.

    +
    +
    +
    420        self.last = value
    +
    +
    +
    +
    + +

    Returns the stack

    +
    +
    +
    422    def get(self):
    +
    +
    +
    +
    + +

    Keep track of the size of the stack when it was used. +This is used for a sanity check in append.

    +
    +
    +
    429        self.last_get_n = self.n
    +
    +
    +
    +
    + +

Take it all through StackFunction so that StackFunction.backward +is called by PyTorch during backpropagation.

    +
    +
    +
    432        return StackFunction.apply(self.memory, self.memory_grad, self.last, self.n)
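A minimal usage sketch of Stack (imported from this module); note the append/get discipline enforced by the assertion in append: append the value for step n, then get before appending step n + 1.

    import torch
    from labml_nn.transformers.feedback import Stack

    stack = Stack(max_len=512)
    for step in range(4):
        value = torch.randn(4, 8, 64)   # e.g. the projected keys for one step
        stack.append(step, value)
        stacked = stack.get()           # shape [step + 1, 4, 8, 64]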
    +
    +
    +
    +
    + +

    Updated Feedback Transformer Module

    +

    This is the updated feedback transformer module that caches the keys and values.

    +
    +
    +
    435class FeedbackTransformerKV(Module):
    +
    +
    +
    +
    + + +
    +
    +
    442    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int, d_model: int, heads: int):
    +
    +
    +
    +
    + + +
    +
    +
    450        super().__init__()
    +
    +
    +
    +
    + +

    Make copies of the transformer layer

    +
    +
    +
    452        self.layers = clone_module_list(layer, n_layers)
    +
    +
    +
    +
    + +

    Final normalization layer

    +
    +
    +
    454        self.norm = nn.LayerNorm([layer.size])
    +
    +
    +
    +
    + +

    Memory vectors are computed as a weighted sum of representations of each layer. +This is the weights parameter for that.

    +
    +
    +
    457        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
    +
    +
    +
    +
    + +

    Softmax for weights before taking the weighted sum

    +
    +
    +
    459        self.softmax = nn.Softmax(0)
    +
    +
    +
    +
    + +

    Number of features in a head

    +
    +
    +
    462        d_k = d_model // heads
    +
    +
    +
    +
    + +

    Module to transform embeddings (memory) to get keys

    +
    +
    +
    464        self.key = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
    +
    +
    +
    +
    + +

Module to transform embeddings (memory) to get values

    +
    +
    +
    466        self.value = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
    +
    +
    +
    +
    + +

    Memory for stacked keys

    +
    +
    +
    469        self.mem_key = Stack(512)
    +
    +
    +
    +
    + +

    Memory for stacked values

    +
    +
    +
    471        self.mem_value = Stack(512)
    +
    +
    +
    +
    + + +
    +
    +
    473    def __call__(self, x_seq: torch.Tensor):
    +
    +
    +
    +
    + +

    Split the input to a list along the sequence axis

    +
    +
    +
    479        x_seq = torch.unbind(x_seq, dim=0)
    +
    +
    +
    +
    + +

    List to store the outputs

    +
    +
    +
    481        res = []
    +
    +
    +
    +
    + +

    For each input step

    +
    +
    +
    483        for step, x in enumerate(x_seq):
    +
    +
    +
    +
    + +

    List to store layer outputs

    +
    +
    +
    485            layer_outputs = [x]
    +
    +
    +
    +
    + +

    Stack of keys and values

    +
    +
    +
    488            key_tensor = None
    +489            value_tensor = None
    +
    +
    +
    +
    + +

    Get the keys and values tensors if we are beyond the initial step

    +
    +
    +
    491            if step > 0:
    +492                key_tensor = self.mem_key.get()
    +493                value_tensor = self.mem_value.get()
    +
    +
    +
    +
    + +

    Run through each layer

    +
    +
    +
    496            for layer in self.layers:
    +
    +
    +
    +
    + +

    Get layer output

    +
    +
    +
    498                x = layer(x=x, key=key_tensor, value=value_tensor)
    +
    +
    +
    +
    + +

    Append them to the list of layer outputs

    +
    +
    +
    500                layer_outputs.append(x)
    +
    +
    +
    +
    + +

    Stack the layer outputs to a tensor

    +
    +
    +
    503            layer_outputs = torch.stack(layer_outputs)
    +
    +
    +
    +
    + +

    Calculate the memory vector as a weighted sum of layer outputs

    +
    +
    +
    505            mem = torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights))
    +
    +
    +
    +
    + +

    Calculate the keys from memory and add it to the stack

    +
    +
    +
    507            self.mem_key.append(step, self.key(mem))
    +
    +
    +
    +
    + +

    Calculate the values from memory and add it to the stack

    +
    +
    +
    509            self.mem_value.append(step, self.value(mem))
    +
    +
    +
    +
    + +

    Append the output to results

    +
    +
    +
    511            res.append(x)
    +
    +
    +
    +
    + +

    Stack the output tensors

    +
    +
    +
    514        res = torch.stack(res)
    +
    +
    +
    +
    + +

    Normalize the output

    +
    +
    +
    516        return self.norm(res)
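Putting the pieces together, a rough construction-and-forward sketch of the updated module, mirroring the factory function in the experiment file of this PR (hyper-parameters arbitrary):

    import torch
    from labml_nn.transformers.feedback import (FeedbackTransformerKV, FeedbackTransformerLayer,
                                                FeedbackAttention, FeedForward)

    d_model, heads, d_ff, n_layers, dropout = 512, 8, 2048, 6, 0.0
    layer = FeedbackTransformerLayer(d_model=d_model,
                                     attn=FeedbackAttention(heads, d_model, dropout,
                                                            is_kv_precomputed=True),
                                     feed_forward=FeedForward(d_model, d_ff, dropout),
                                     dropout_prob=dropout)
    model = FeedbackTransformerKV(layer, n_layers, d_model, heads)

    x_seq = torch.randn(32, 4, d_model)  # [seq_len, batch_size, d_model]
    out = model(x_seq)                   # [seq_len, batch_size, d_model]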
    diff --git a/labml_nn/transformers/feedback/__init__.py b/labml_nn/transformers/feedback/__init__.py index 6db5e377..8d8b0862 100644 --- a/labml_nn/transformers/feedback/__init__.py +++ b/labml_nn/transformers/feedback/__init__.py @@ -22,11 +22,19 @@ In order to speed up the training the paper discusses starting with a short sequ gradually increasing it. They also discuss using a pretrained parallel transformer as the starting point. -The feedback transformer doesn't keep the outputs of all layers. +The original feedback transformer doesn't keep the outputs of all layers. Instead it keeps weighted sum of the output of all layers. This reduces the memory used for caching during prediction. +The first half of this file implements this. -Here's a notebook for training a feedback transformer on Tiny Shakespeare dataset. +The updated feedback transformer shares weights $W^l_k$ and $W^l_v$ used +to calculate keys and values for among the layers. +We then calculate the keys and values for each step only once and keep +them cached. +The [second half](#shared_kv) of this file implements this. +We implemented a custom PyTorch function to improve performance. + +Here's [the training code](experiment.html) and a notebook for training a feedback transformer on Tiny Shakespeare dataset. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/feedback/experiment.ipynb) [![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=d8eb9416530a11eb8fb50242ac1c0002) @@ -39,8 +47,8 @@ import torch from torch import nn from labml_helpers.module import Module -from labml_nn.transformers.mha import PrepareForMultiHeadAttention from labml_nn.transformers.feed_forward import FeedForward +from labml_nn.transformers.mha import PrepareForMultiHeadAttention from labml_nn.utils import clone_module_list @@ -61,6 +69,7 @@ class FeedbackAttention(Module): * 'heads' is the number of attention heads * `d_model` is the number of features in the transformer * `dropout_prob` is the attention dropout probability + * `is_kv_precomputed` is whether key, value tensors are already calculated """ super().__init__() @@ -70,11 +79,13 @@ class FeedbackAttention(Module): # self.heads = heads - # These transform the `query`, `key` and `value` vectors for multi-headed attention. + # These transform the `query` multi-headed attention. self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False) + # These transform the `key` and `value` fir multi-headed attention. if not is_kv_precomputed: self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False) self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True) + # Keys and values are already calculated else: self.key = None self.value = None @@ -94,6 +105,8 @@ class FeedbackAttention(Module): # Relative positional embeddings for key relative to the query. self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True) + # Relative positional embedding bias for key relative to the query. 
+ self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True) # Positional embeddings for the query is independent of the position of the query self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True) @@ -104,27 +117,42 @@ class FeedbackAttention(Module): """ ### Get attention scores + We use relative positional encodings for attention, similar + to [relative multi-head attention form Transformer-XL paper](../relative_mha.html). + + Attention from current step's query to key in step $j$ (relative to current step) is, + \begin{align} A_{j} &= Q^\top K_j \\ &= lin_q(X^q + P_q)^\top lin_k(X^k_j + P_j) \\ - &= (Q + U^Q)^\top(K_j + U^K_j) + &= (Q + U^Q)^\top(K_j + U^K_j) \\ + &= \underset{\color{lightgreen}{A}}{Q^\top K_j} + + \underset{\color{lightgreen}{B}}{Q^\top U^K_j} + + \underset{\color{lightgreen}{C}}{{U^Q}^\top K_j} + + \underset{\color{lightgreen}{D}}{{U^Q}^\top U^K_j} \end{align} where $Q, K_j$, are linear transformations of original embeddings $X^q, X^k_j$ and $U^Q, U^K_j$ are linear transformations of - absolute positional encodings $P_q, P_j$. + positional encodings $P_q, P_j$. + + We replace term $\color{lightgreen}{D}$ with $S_j$. """ # $U^K_j$ key_pos_emb = self.key_pos_embeddings[-key.shape[0]:] # $U^Q$ query_pos_bias = self.query_pos_bias[None, :, :] + # $S_j$ + key_pos_bias = self.key_pos_bias[-key.shape[0]:] - # $(Q + U^Q)^\top(K_j + U^K_j)$ + # $\underset{\color{lightgreen}{A}}{Q^\top K_j} + \underset{\color{lightgreen}{C}}{{U^Q}^\top K_j}$ ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key) - bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + # $\underset{\color{lightgreen}{B}}{Q^\top U^K_j} + \underset{\color{lightgreen}{D}}{S_j}$ + bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :] + # $A_j$ return ac + bd def __call__(self, *, @@ -283,23 +311,69 @@ class FeedbackTransformer(Module): return self.norm(res) +# +# # Shared keys and values for among layers +# + class StackFunction(torch.autograd.Function): + """ + ### Stack Function implementation + + We implement a custom function instead of appending to a python list + and then doing `torch.stack`. + This greatly improves the performance over calling `torch.stack` at + each step along the sequence. + Everytime `torch.stack` is called it creates a new tensor, while + this method and the accompanying class `Stack` share memory for each step. + """ + @staticmethod def forward(ctx, memory, memory_grad, last, n): + """ + * `ctx` is the context of the function (which lets as cache stuff) + * `memory` is the shared memory tensor where we stack and store the values of each step (keys & values) + * `memory_grad` is the shared memory tensor to store and accumulate gradients of each step + * `last` is the last value stacked + * `n` is the number of steps (i.e. size of the stack) + + This returns the stacked tensor for steps upto `n`. + """ + + # Cache accumulated gradients ctx._mem_grad = memory_grad + # Cache the size of the stack ctx._n = n + # Return the stack return memory[:n + 1] @staticmethod def backward(ctx, grad_output): + """ + * `grad_output` is the gradient with respect to the output of about `forward` function + + This accumulates the gradients in the shared memory tensor and return the + gradients with respect to the `last` result in the stack. 
+ """ + # Get the current size of the stack n = ctx._n + # Get the accumulated gradients memory_grad = ctx._mem_grad + # Add the gradients memory_grad[:n + 1] += grad_output + # Return the gradients w.r.t to last value in the stack return None, None, memory_grad[n], None class Stack: + """ + ### Stack Module + + This uses the stack function defined above, and does the necessary initializations. + """ def __init__(self, max_len: int): + """ + * `max_len` is the maximum size of the stack + """ self.max_len = max_len self.memory = None self.memory_grad = None @@ -307,37 +381,70 @@ class Stack: self.n = -1 self.last_get_n = -1 - def append(self, n: int, vector: torch.Tensor): + def append(self, n: int, value: torch.Tensor): + """ + * `n` is the size of the stack + * `value` is the tensor that needs to be added to the stack + """ + + # You need to get (use) the stack after adding a value. + # Otherwise this implementation fails assert n == 0 or self.last_get_n == n - 1, f"{n}, {self.last_get_n}" + # Do this without gradients with torch.no_grad(): - if self.memory is None or self.memory.shape[1:] != vector.shape: + # Initialize the shared memory tensor to keep the stack + if self.memory is None or self.memory.shape[1:] != value.shape: + # This should only happen when the stack is empty assert n == 0 - self.memory = vector.new_zeros(self.max_len, *vector.shape, requires_grad=False) - self.memory_grad = vector.new_zeros(self.memory.shape, requires_grad=False) + # Create a tensor for the stack + self.memory = value.new_zeros(self.max_len, *value.shape, requires_grad=False) + # Create a tensor to accumulate the gradients + self.memory_grad = value.new_zeros(self.memory.shape, requires_grad=False) + # The memory is already initialized but we are resetting the stack. + # + # This could have been another function like `reset`, but + # we found this easier to use. elif n == 0: + # Reset accumulated gradients self.memory_grad.fill_(0.) - # memory[n] = vector.detach() - self.memory.data[n] = vector.detach() + # Set the value in the correct position of the stack + self.memory.data[n] = value.detach() + # Keep track of the stack (for debugging) self.n = n - self.last = vector + # Keep track of the last value added to the stack. + # We need this to be passed on to `StackFunction` in order + # to get the gradients propagated backwards. + self.last = value def get(self): + """ + Returns the stack + """ + + # Keep track of the size of the stack when it was used. + # This is used for a sanity check in `append`. self.last_get_n = self.n + # Take it all through `StackFunction` so that `StackFunction.backwards` + # is called by PyTorch during backpropagation. return StackFunction.apply(self.memory, self.memory_grad, self.last, self.n) class FeedbackTransformerKV(Module): """ - ## Feedback Transformer Module + ## Updated Feedback Transformer Module + + This is the updated feedback transformer module that caches the keys and values. 
""" def __init__(self, layer: FeedbackTransformerLayer, n_layers: int, d_model: int, heads: int): """ * `layer` is the feedback transformer layer, which we clone for each layer * `n_layers` is the number of layers in the transformer + * `d_model` is the number of features in the transformer + * 'heads' is the number of attention heads """ super().__init__() @@ -351,11 +458,16 @@ class FeedbackTransformerKV(Module): # Softmax for weights before taking the weighted sum self.softmax = nn.Softmax(0) + # Number of features in a head d_k = d_model // heads + # Module to transform embeddings (memory) to get keys self.key = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False) + # Module to transform embeddings (memory) to get keys self.value = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False) + # Memory for stacked keys self.mem_key = Stack(512) + # Memory for stacked values self.mem_value = Stack(512) def __call__(self, x_seq: torch.Tensor): @@ -372,9 +484,10 @@ class FeedbackTransformerKV(Module): # List to store layer outputs layer_outputs = [x] - # If there is memory, stack them into a vector + # Stack of keys and values key_tensor = None value_tensor = None + # Get the keys and values tensors if we are beyond the initial step if step > 0: key_tensor = self.mem_key.get() value_tensor = self.mem_value.get() @@ -390,7 +503,9 @@ class FeedbackTransformerKV(Module): layer_outputs = torch.stack(layer_outputs) # Calculate the memory vector as a weighted sum of layer outputs mem = torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights)) + # Calculate the keys from memory and add it to the stack self.mem_key.append(step, self.key(mem)) + # Calculate the values from memory and add it to the stack self.mem_value.append(step, self.value(mem)) # Append the output to results res.append(x) diff --git a/labml_nn/transformers/feedback/experiment.py b/labml_nn/transformers/feedback/experiment.py index 4499fb17..791869c2 100644 --- a/labml_nn/transformers/feedback/experiment.py +++ b/labml_nn/transformers/feedback/experiment.py @@ -7,6 +7,13 @@ summary: This is training code with notes for a feedback transformer. # Train Feedback Transformer This trains a [feedback transformer](index.html) model for auto-regression. +You can pick the original feedback transformer or the new version +where the keys and values are precalculated. + +Here's a Colab notebook for training a feedback transformer on Tiny Shakespeare dataset. + +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/feedback/experiment.ipynb) +[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=d8eb9416530a11eb8fb50242ac1c0002) """ import torch @@ -35,8 +42,9 @@ class AutoregressiveModel(Module): self.generator = nn.Linear(d_model, n_vocab) def __call__(self, x: torch.Tensor): + # Embed the tokens x = self.src_embed(x) - # Embed the tokens (`src`) and run it through the the transformer + # Run it through the the transformer res = self.transformer(x) # Generate logits of the next token return self.generator(res), None @@ -60,6 +68,9 @@ class Configs(NLPAutoRegressionConfigs): @option(Configs.model) def feedback_transformer(c: Configs): + """ + Create [original feedback transformer](index.html). 
+ """ from labml_nn.transformers.feedback import FeedbackTransformer, FeedbackTransformerLayer, \ FeedbackAttention, FeedForward @@ -75,6 +86,9 @@ def feedback_transformer(c: Configs): @option(Configs.model) def feedback_transformer_kv(c: Configs): + """ + Create [updated feedback transformer](index.html#kv_shared), with precalculated keys and values. + """ from labml_nn.transformers.feedback import FeedbackTransformerKV, FeedbackTransformerLayer, \ FeedbackAttention, FeedForward @@ -104,6 +118,7 @@ def main(): 'prompt': 'It is', 'prompt_separator': '', + # Use `feedback_transformer` for original feedback transformer 'model': 'feedback_transformer_kv', 'train_loader': 'shuffled_train_loader', @@ -119,7 +134,7 @@ def main(): # Start the experiment with experiment.start(): - # `TrainValidConfigs.run` + # Run the training loop conf.run()