Attention Free Transformer (AFT) Experiment

This is an annotated PyTorch experiment to train an AFT model.

This is based on the general training loop and configurations for auto-regressive NLP tasks.
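For reference, AFT replaces dot-product self-attention with a learned pairwise position bias $w_{t,t'}$; this is the formulation from the AFT paper, stated here for context, where $\sigma$ is the sigmoid function and $\odot$ is element-wise multiplication:

$$Y_t = \sigma\big(Q_t\big) \odot \frac{\sum_{t'=1}^{T} \exp\big(K_{t'} + w_{t,t'}\big) \odot V_{t'}}{\sum_{t'=1}^{T} \exp\big(K_{t'} + w_{t,t'}\big)}$$

AFT Local, the variant used in this experiment, only learns $w_{t,t'}$ for positions within a local window $|t - t'| < s$ and keeps it zero elsewhere, where $s$ is the local_window_size configuration below.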

import torch
from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import TransformerConfigs, Encoder
from labml_nn.transformers.utils import subsequent_mask
from torch import nn

Simple autoregressive model

This consists of a token embedding layer, a transformer encoder, and a final linear layer that gives token logits.

class AutoregressiveTransformer(nn.Module):
    def __init__(self, encoder: Encoder, src_embed: nn.Module, generator: nn.Module):
        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator

The mask will be initialized on the first call

        self.mask = None

    def forward(self, x: torch.Tensor):

Create the subsequent mask if it is not yet initialized, or if its size doesn't match the current sequence length

        if self.mask is None or self.mask.size(0) != len(x):

The subsequent mask is a lower-triangular (causal) mask that prevents tokens from attending to future tokens

            self.mask = subsequent_mask(len(x)).to(x.device)

Get the token embeddings with positional encodings

        x = self.src_embed(x)

Transformer encoder

        x = self.encoder(x, self.mask)

Get logits

        x = self.generator(x)

Return results (the second value is for state, since our trainer is also used with RNNs)

        return x, None
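As an illustration of the expected tensor shapes, here is a hypothetical smoke test with stand-in modules. The module sizes and names below are illustrative only; the real encoder, embeddings with positional encodings, and generator come from the transformer configurations defined further down.

import torch
from torch import nn

class DummyEncoder(nn.Module):
    # Stand-in for the transformer encoder: ignores the mask and returns its input unchanged
    def forward(self, x, mask):
        return x

model = AutoregressiveTransformer(
    encoder=DummyEncoder(),
    src_embed=nn.Embedding(65, 128),  # hypothetical sizes: 65 tokens, d_model = 128
    generator=nn.Linear(128, 65),
)
tokens = torch.randint(0, 65, (256, 32))  # [seq_len, batch_size] of token ids
logits, _ = model(tokens)                 # logits: [256, 32, 65]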

Configurations

This inherits from NLPAutoRegressionConfigs

class Configs(NLPAutoRegressionConfigs):

AFT autoregressive model

    model: AutoregressiveTransformer

Transformer

    transformer: TransformerConfigs

Local window size for AFT Local attention

    local_window_size: int = 32
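To make the role of local_window_size concrete: the pairwise position bias of AFT Local is only learned where |t - t'| < local_window_size. A small hypothetical helper that computes this support could look like the following sketch.

import torch

def local_window_support(seq_len: int, window: int) -> torch.Tensor:
    # True where the pairwise position bias w[t, t'] is learned, i.e. |t - t'| < window;
    # outside this band the bias is fixed to zero
    idx = torch.arange(seq_len)
    return (idx[:, None] - idx[None, :]).abs() < window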

Transformer configurations

@option(Configs.transformer, 'Transformer')
def _transformer_configs(c: Configs):
    conf = TransformerConfigs()

Set the vocabulary sizes for embeddings and generating logits

    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens

Set the embedding size

    conf.d_model = c.d_model

Replace self-attention with an AFT Local module

    from labml_nn.transformers.aft import AFTLocal
    conf.encoder_attn = AFTLocal(c.d_model, c.seq_len, c.local_window_size)

    return conf
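To see what this replacement computes, here is a simplified, numerically naive sketch of the AFT operation that AFTLocal implements (names are hypothetical; the actual module also learns the pairwise bias, zeroes it outside the local window, applies the causal mask, and uses numerical-stability tricks before the exponentials):

import torch

def aft_sketch(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # q, k, v: [seq_len, batch, d_model]
    # w: [seq_len, seq_len] pairwise position bias, assumed zero outside the local
    #    window and -inf at masked (future) positions
    e = torch.exp(w[:, :, None, None] + k[None, :, :, :])  # [T, T, batch, d_model]
    num = torch.einsum('tsbd,sbd->tbd', e, v)               # weighted sum of values over key positions
    den = e.sum(dim=1)                                      # normalization term
    return torch.sigmoid(q) * num / den

With the mask folded into w, each position t only aggregates values from earlier positions inside its local window.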

Create an auto-regressive model

@option(Configs.model)
def _model(c: Configs):
    m = AutoregressiveTransformer(c.transformer.encoder,
                                  c.transformer.src_embed,
                                  c.transformer.generator).to(c.device)

    return m


def main():

Create experiment

    experiment.create(name="aft")

Create configs

    conf = Configs()

Override configurations

    experiment.configs(conf, {

Use a character-level tokenizer

        'tokenizer': 'character',

Prompt separator is blank

        'prompt_separator': '',

Starting prompt for sampling

        'prompt': 'It is ',

Use Tiny Shakespeare dataset

        'text': 'tiny_shakespeare',

Use a context size of 256

        'seq_len': 256,

Train for 128 epochs

        'epochs': 128,

Batch size of 32

        'batch_size': 32,

Switch between training and validation 10 times per epoch

        'inner_iterations': 10,

Embedding size

        'd_model': 128,

FFN hidden dimension size

        'transformer.ffn.d_ff': 256,

Use the Noam optimizer (see the note on its schedule right after this configuration block)

        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })
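For reference, the Noam schedule (from the original Transformer paper) warms up and then decays the learning rate as

$$\text{lr} = d_\text{model}^{-0.5} \cdot \min\big(\text{step}^{-0.5},\ \text{step} \cdot \text{warmup}^{-1.5}\big)$$

so the learning_rate of 1. above presumably acts as a multiplicative factor on this schedule rather than a fixed learning rate.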

Set models for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run training

        conf.run()

if __name__ == '__main__':
    main()