"""
This is an annotated PyTorch experiment to train an AFT (Attention Free Transformer) model.

It is based on the general training loop and configurations for auto-regressive NLP tasks.
"""
import torch

from labml import experiment
from labml.configs import option
from labml_helpers.module import Module
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import TransformerConfigs, Encoder
from labml_nn.transformers.utils import subsequent_mask


class AutoregressiveTransformer(Module):
    """
    Auto-regressive transformer model.

    This consists of a token embedding layer, a transformer encoder, and a final
    linear layer that gives token logits.
    """

    def __init__(self, encoder: Encoder, src_embed: Module, generator: Module):
        """
        * `encoder` is the transformer `Encoder`
        * `src_embed` is the token embedding module (with positional encodings)
        * `generator` is the final fully connected layer that gives the logits
        """
        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator

        # The mask will be initialized on the first call to `forward`
        self.mask = None

    def forward(self, x: torch.Tensor):
        """
        Run the model on token ids `x` and return `(logits, None)`.

        NOTE(review): `len(x)` is used as the mask size, so the first dimension
        of `x` is assumed to be the sequence dimension
        (presumably `[seq_len, batch_size]`) — confirm against the trainer.
        """
        # Re-create the subsequent mask if it is not initialized, if its size
        # differs from the current sequence length, or if it sits on a different
        # device than the input. `self.mask` is a plain attribute (not a
        # registered buffer), so `module.to(device)` does not move it.
        if (self.mask is None
                or self.mask.size(0) != len(x)
                or self.mask.device != x.device):
            # Subsequent mask, masks out tokens from seeing future tokens
            self.mask = subsequent_mask(len(x)).to(x.device)
        # Get the token embeddings with positional encodings
        x = self.src_embed(x)
        # Transformer encoder
        x = self.encoder(x, self.mask)
        # Get logits
        x = self.generator(x)

        # Return results (the second value is for state, since the trainer
        # is also used with RNNs)
        return x, None


class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    Extends `NLPAutoRegressionConfigs` with the transformer model settings.
    """

    # Auto-regressive (GPT-style) model
    model: AutoregressiveTransformer
    # Transformer configurations
    transformer: TransformerConfigs

    # Local window size passed to `AFTLocal`
    local_window_size: int = 32


@option(Configs.transformer, 'Transformer')
def _transformer_configs(c: Configs):
    """
    ### Transformer configurations
    """

    # We use our configurable transformer implementation
    conf = TransformerConfigs()
    # Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
    # Set the embedding size
    conf.d_model = c.d_model
    # Replace self-attention with an AFT Local module
    from labml_nn.transformers.aft import AFTLocal
    conf.encoder_attn = AFTLocal(c.d_model, c.seq_len, c.local_window_size)

    return conf
@option(Configs.model)
def _model(c: Configs):
    """
    Build the auto-regressive model from the transformer configurations
    and move it to the configured device.
    """
    encoder = c.transformer.encoder
    src_embed = c.transformer.src_embed
    generator = c.transformer.generator
    model = AutoregressiveTransformer(encoder, src_embed, generator)

    return model.to(c.device)


def main():
    """
    Create the experiment, apply configuration overrides, and run training.
    """
    # Create the experiment
    experiment.create(name="aft")
    # Create the configs
    conf = Configs()

    # Configuration overrides
    overrides = {
        # Use a character-level tokenizer
        'tokenizer': 'character',
        # The prompt separator is blank
        'prompt_separator': '',
        # Starting prompt for sampling
        'prompt': 'It is ',
        # Use the Tiny Shakespeare dataset
        'text': 'tiny_shakespeare',

        # Use a context size of 256
        'seq_len': 256,
        # Train for 128 epochs
        'epochs': 128,
        # Batch size
        'batch_size': 32,
        # Switch between training and validation 10 times per epoch
        'inner_iterations': 10,

        # Embedding size
        'd_model': 128,
        # FFN hidden dimension size
        'transformer.ffn.d_ff': 256,

        # Optimizer
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    }
    experiment.configs(conf, overrides)

    # Register the models for saving and loading
    saved_models = {'model': conf.model}
    experiment.add_pytorch_models(saved_models)

    # Start the experiment and run training inside it
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()