Attention Free Transformer (AFT) Experiment

This is an annotated PyTorch experiment to train an AFT model.

It is based on the general training loop and configurations for an auto-regressive NLP task.


import torch

from labml import experiment
from labml.configs import option
from labml_helpers.module import Module
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import TransformerConfigs, Encoder
from labml_nn.transformers.utils import subsequent_mask

Simple autoregressive model

It consists of a token embedding layer, a transformer encoder, and a final linear layer that gives the token logits.

class AutoregressiveTransformer(Module):
    def __init__(self, encoder: Encoder, src_embed: Module, generator: Module):
        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator

The mask will be initialized on the first call

        self.mask = None

    def forward(self, x: torch.Tensor):

Create the subsequent mask if it hasn't been initialized yet, or if its size doesn't match the current sequence length

        if self.mask is None or self.mask.size(0) != len(x):

The subsequent mask blocks tokens from attending to future tokens

            self.mask = subsequent_mask(len(x)).to(x.device)

Get the token embeddings with positional encodings

        x = self.src_embed(x)

Transformer encoder

        x = self.encoder(x, self.mask)

Get logits

        x = self.generator(x)

Return the results (the second value is the state; it is None here since the same trainer is also used with RNNs)

        return x, None
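To make the role of the mask concrete: the subsequent (causal) mask lets position $t$ attend only to positions $t' \le t$. Below is a minimal sketch of that idea in plain PyTorch; it is only an illustration, not the exact output of subsequent_mask (whose shape and dtype may differ).

import torch

seq_len = 5
# Entry [t, t'] is True iff t' <= t, i.e. each position can only
# see itself and earlier positions.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
print(causal_mask)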

Configurations

This inherits from NLPAutoRegressionConfigs

class Configs(NLPAutoRegressionConfigs):

The autoregressive AFT model

    model: AutoregressiveTransformer

Transformer

    transformer: TransformerConfigs

    local_window_size: int = 32
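local_window_size is the window size $s$ used by AFT-local: the pair-wise position bias $w_{t,t'}$ is only learned for pairs with $|t - t'| < s$ and is fixed to zero elsewhere. Here is a rough sketch of which pairs keep a learned bias (a hypothetical helper for illustration, not the AFTLocal implementation):

import torch

def local_window_mask(seq_len: int, window: int) -> torch.Tensor:
    # True where the learned pair-wise bias w[t, t'] is kept,
    # False where AFT-local fixes it to zero.
    pos = torch.arange(seq_len)
    return (pos[None, :] - pos[:, None]).abs() < window

print(local_window_mask(seq_len=6, window=2).int())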

Transformer configurations

@option(Configs.transformer, 'Transformer')
def _transformer_configs(c: Configs):
    conf = TransformerConfigs()

Set the vocabulary sizes for embeddings and generating logits

    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens

Set the embedding size

    conf.d_model = c.d_model

Replace self-attention with an AFT Local Module

    from labml_nn.transformers.aft import AFTLocal
    conf.encoder_attn = AFTLocal(c.d_model, c.seq_len, c.local_window_size)

    return conf
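For reference, AFT replaces dot-product attention with an element-wise weighting. Following the AFT paper, the output at position $t$ is

$$Y_t = \sigma(Q_t) \odot \frac{\sum_{t'} \exp(K_{t'} + w_{t,t'}) \odot V_{t'}}{\sum_{t'} \exp(K_{t'} + w_{t,t'})}$$

where $\sigma$ is the sigmoid function, $\odot$ is element-wise multiplication, and $w_{t,t'}$ are learned pair-wise position biases (restricted to the local window in AFT-local). In this autoregressive setup the mask limits the sums to $t' \le t$.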

Create an auto-regressive model

@option(Configs.model)
def _model(c: Configs):
    m = AutoregressiveTransformer(c.transformer.encoder,
                                  c.transformer.src_embed,
                                  c.transformer.generator).to(c.device)

    return m

def main():

Create experiment

    experiment.create(name="aft")

Create configs

    conf = Configs()

Override configurations

    experiment.configs(conf, {

Use a character-level tokenizer

        'tokenizer': 'character',

Prompt separator is blank

        'prompt_separator': '',

Starting prompt for sampling

        'prompt': 'It is ',

Use the Tiny Shakespeare dataset

        'text': 'tiny_shakespeare',

Use a context size of $256$

        'seq_len': 256,

Train for $128$ epochs

        'epochs': 128,

Batch size of $32$

        'batch_size': 32,

Switch between training and validation $10$ times per epoch

        'inner_iterations': 10,

Embedding size

        'd_model': 128,

FFN hidden dimension size

        'transformer.ffn.d_ff': 256,

Use the Noam optimizer (see the note on its schedule after this configuration block)

        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })
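A note on the optimizer settings above: the Noam schedule from the original Transformer paper warms the learning rate up linearly and then decays it with the inverse square root of the step count,

$$\eta \propto d_{model}^{-0.5} \cdot \min\big(step^{-0.5},\; step \cdot warmup^{-1.5}\big)$$

so the configured optimizer.learning_rate of $1.$ presumably acts as an overall scale factor on this schedule rather than as a fixed learning rate.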

Set models for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run training

        conf.run()


if __name__ == '__main__':
    main()