This is an annotated PyTorch experiment to train an AFT (Attention Free Transformer) model.
It is based on the general training loop and configurations for auto-regressive NLP tasks.
import torch

from labml import experiment
from labml.configs import option
from labml_helpers.module import Module
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import TransformerConfigs, Encoder
from labml_nn.transformers.utils import subsequent_mask
The model consists of a token embedding layer, a transformer encoder, and a final linear layer that gives token logits.
class AutoregressiveTransformer(Module):
encoder is the transformer Encoder.
src_embed is the token embedding module (with positional encodings).
generator is the final fully connected layer that gives the logits.
    def __init__(self, encoder: Encoder, src_embed: Module, generator: Module):
        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator
The mask will be initialized on the first call
        self.mask = None
    def forward(self, x: torch.Tensor):
Create the subsequent mask if it is not initialized yet, or if its size does not match the current sequence length
        if self.mask is None or self.mask.size(0) != len(x):
Subsequent mask that stops tokens from attending to future tokens
            self.mask = subsequent_mask(len(x)).to(x.device)
Get the token embeddings with positional encodings
        x = self.src_embed(x)
Transformer encoder
        x = self.encoder(x, self.mask)
Get logits
        x = self.generator(x)
Return results (the second value is for state, since our trainer is also used with RNNs)
        return x, None
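As a minimal illustration of the causal masking above, here is a standalone sketch of a mask where position $t$ can only see positions $t' \le t$. It uses torch.tril rather than the subsequent_mask helper itself, whose actual output may also carry an extra trailing dimension; treat this as illustrative only.

# Standalone sketch of a causal (subsequent) mask for a sequence length of 4;
# True means the key position is visible to the query position.
causal = torch.tril(torch.ones(4, 4)).bool()
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])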
class Configs(NLPAutoRegressionConfigs):
AFT-based autoregressive model
    model: AutoregressiveTransformer
Transformer
    transformer: TransformerConfigs
Local window size for the AFT local attention
    local_window_size: int = 32
Transformer configurations
@option(Configs.transformer, 'Transformer')
def _transformer_configs(c: Configs):
We use our configurable transformer implementation
    conf = TransformerConfigs()
Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
Set the embedding size
    conf.d_model = c.d_model
Replace self-attention with an AFT Local Module
    from labml_nn.transformers.aft import AFTLocal
    conf.encoder_attn = AFTLocal(c.d_model, c.seq_len, c.local_window_size)
    return conf
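To show what replaces self-attention here, the following is a minimal, self-contained sketch of the AFT-local computation. It is an illustrative re-implementation, not the AFTLocal module itself: the names aft_local_sketch, w and s are hypothetical, the causal mask is folded in directly, and numerical stabilization is omitted.

def aft_local_sketch(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                     w: torch.Tensor, s: int) -> torch.Tensor:
    # q, k, v have shape [seq_len, d_model]; w holds learned pair-wise
    # position biases of shape [seq_len, seq_len]; s is the local window size.
    seq_len = q.shape[0]
    pos = torch.arange(seq_len)
    # Zero the position bias outside the local window |t - t'| < s
    w = w * ((pos[None, :] - pos[:, None]).abs() < s)
    # Causal mask so that position t only uses positions t' <= t
    causal = torch.tril(torch.ones(seq_len, seq_len)).bool()
    # exp(K_{t'} + w_{t,t'}), with future positions contributing nothing
    weights = torch.exp(k[None, :, :] + w[:, :, None]) * causal[:, :, None]
    # Weighted average of the values, gated element-wise by sigmoid(Q)
    return torch.sigmoid(q) * (weights * v[None, :, :]).sum(dim=1) / weights.sum(dim=1)

Unlike dot-product attention, the query here only gates the output element-wise; the pair-wise interaction comes entirely from the position biases, which is what makes AFT attention-free and lets the local variant limit the number of learned biases.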
Create an auto-regressive model
@option(Configs.model)
def _model(c: Configs):
    m = AutoregressiveTransformer(c.transformer.encoder,
                                  c.transformer.src_embed,
                                  c.transformer.generator).to(c.device)

    return m
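For orientation, here is a hypothetical shape check, assuming the trainer feeds token ids in a [seq_len, batch_size] layout as the sequence-first encoder above expects. This snippet is not part of the experiment.

# Hypothetical sanity check on a fully computed `conf` (Configs) instance
tokens = torch.randint(0, conf.n_tokens, (conf.seq_len, conf.batch_size), device=conf.device)
logits, _ = conf.model(tokens)
assert logits.shape == (conf.seq_len, conf.batch_size, conf.n_tokens)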
def main():
Create experiment
    experiment.create(name="aft")
Create configs
    conf = Configs()
Override configurations
    experiment.configs(conf, {
Use character level tokenizer
        'tokenizer': 'character',
Prompt separator is blank
        'prompt_separator': '',
Starting prompt for sampling
        'prompt': 'It is ',
Use Tiny Shakespeare dataset
        'text': 'tiny_shakespeare',
Use a context size of $256$
        'seq_len': 256,
Train for $128$ epochs
        'epochs': 128,
Batch size $32$
        'batch_size': 32,
Switch between training and validation $10$ times per epoch
        'inner_iterations': 10,
Embedding size
        'd_model': 128,
FFN hidden dimension size
        'transformer.ffn.d_ff': 256,
Optimizer
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })
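The Noam optimizer scales the base learning rate of $1.$ by a schedule that warms up and then decays with the step count. A minimal sketch of the usual schedule is below; the warmup value is hypothetical, and the experiment's actual optimizer configuration defines its own defaults.

def noam_lr(step: int, d_model: int = 128, warmup: int = 4000, factor: float = 1.0) -> float:
    # Learning rate grows linearly for `warmup` steps, then decays as step^{-0.5}
    # (step starts at 1)
    return factor * d_model ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))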
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run training
        conf.run()

if __name__ == '__main__':
    main()