This is an annotated PyTorch experiment to train an ALiBi model.
It is based on our GPT model.
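As a quick refresher, ALiBi (Attention with Linear Biases) drops positional embeddings and instead adds a fixed, head-specific linear penalty to the attention scores based on how far each key is from the query. Below is only a minimal sketch of that bias computation, assuming the number of heads is a power of two; the actual module used in this experiment is AlibiMultiHeadAttention from labml_nn.transformers.alibi.

import torch

def alibi_biases(n_heads: int, seq_len: int) -> torch.Tensor:
    # Head-specific slopes: a geometric sequence 2^(-8/n), 2^(-16/n), ...
    slopes = torch.tensor([2.0 ** (-8.0 * (i + 1) / n_heads) for i in range(n_heads)])
    # distance[i, j] = j - i, i.e. how far key j is from query i (negative for past keys)
    distance = torch.arange(seq_len)[None, :] - torch.arange(seq_len)[:, None]
    # Bias added to the attention scores before the softmax; the causal mask
    # removes the positions with positive distance anyway
    return slopes[:, None, None] * distance[None, :, :]

# e.g. biases for 8 heads and 16 tokens, shape [8, 16, 16]
biases = alibi_biases(8, 16)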
import torch
from torch.utils.data import DataLoader

from labml import experiment, tracker
from labml.configs import option, calculate
from labml_helpers.datasets.text import SequentialUnBatchedDataset
from labml_nn.transformers.alibi import AlibiMultiHeadAttention
from labml_nn.experiments.nlp_autoregression import transpose_batch
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.gpt import Configs as GPTConfigs

class Configs(GPTConfigs):

ALiBi based transformer (defined below)
    transformer: TransformerConfigs = 'GPT_ALiBi'

Longer validation set

    valid_seq_len: int = 128
    valid_loader = 'shuffled_longer_valid_loader'

Log losses at the initial and final tokens
    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):

If there are more tokens than the training sequence length (during validation),

        if self.seq_len < output.shape[0]:

Log the loss at the training sequence length

            tracker.add(f'loss.{self.seq_len - 1}.', self.loss_func(output[self.seq_len - 1], target[self.seq_len - 1]))

Log the loss at the first token

            tracker.add(f'loss.0.', self.loss_func(output[0], target[0]))

Log the loss at the final token

        tracker.add(f'loss.{int(output.shape[0]) - 1}.', self.loss_func(output[-1], target[-1]))
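For intuition, each of those calls logs the cross entropy at a single position, averaged over the batch. A rough sketch, assuming (as in the labml NLP autoregression setup this inherits from) that output holds logits shaped [seq_len, batch_size, n_tokens] and target is [seq_len, batch_size]:

import torch
import torch.nn.functional as F

seq_len, batch_size, n_tokens = 80, 4, 65   # illustrative sizes only
output = torch.randn(seq_len, batch_size, n_tokens)
target = torch.randint(0, n_tokens, (seq_len, batch_size))

# Loss at one position i across the batch, e.g. the last training position (i = 63)
# or the final validation position (i = 79)
i = 63
loss_at_i = F.cross_entropy(output[i], target[i])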
Create an ALiBi attention module

def _alibi_mha(c: TransformerConfigs):
    return AlibiMultiHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)

Set all attention mechanisms to ALiBi

calculate(TransformerConfigs.encoder_attn, 'alibi_mha', _alibi_mha)
calculate(TransformerConfigs.decoder_attn, 'alibi_mha', _alibi_mha)
calculate(TransformerConfigs.decoder_mem_attn, 'alibi_mha', _alibi_mha)
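These registrations are what make the string 'alibi_mha' meaningful: when a config option is set to that name (as _transformer_configs does below), labml calls the registered function to build the actual module. A rough sketch of the same pattern on a made-up configs class (the class and attribute names here are hypothetical, only for illustration):

from labml.configs import BaseConfigs, calculate
from labml_nn.transformers.alibi import AlibiMultiHeadAttention

class AttnConfigs(BaseConfigs):
    n_heads: int = 8
    d_model: int = 128
    # The string picks which registered calculator builds the module
    attn: AlibiMultiHeadAttention = 'alibi_mha'

def _build_alibi(c: AttnConfigs):
    return AlibiMultiHeadAttention(c.n_heads, c.d_model, dropout_prob=0.1)

# Register _build_alibi under the name 'alibi_mha' for AttnConfigs.attn
calculate(AttnConfigs.attn, 'alibi_mha', _build_alibi)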
Shuffled validation data loader with valid_seq_len sequence length

@option(Configs.valid_loader)
def shuffled_longer_valid_loader(c: Configs):
    return DataLoader(SequentialUnBatchedDataset(text=c.text.valid,
                                                 dataset=c.text,
                                                 seq_len=c.valid_seq_len),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      shuffle=True)
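The point of this loader is that validation sequences (valid_seq_len) are longer than the training context, so the per-token losses logged in other_metrics show how well ALiBi extrapolates beyond the lengths seen during training. As a rough mental model of what an "unbatched sequential" text dataset provides (this is only a sketch, not labml_helpers' SequentialUnBatchedDataset):

import torch
from torch.utils.data import Dataset

class ChunkedTextDataset(Dataset):
    # Split a long token stream into consecutive chunks of seq_len + 1 tokens
    # and return (input, target) pairs shifted by one token
    def __init__(self, tokens: torch.Tensor, seq_len: int):
        self.tokens = tokens
        self.seq_len = seq_len

    def __len__(self):
        return (len(self.tokens) - 1) // self.seq_len

    def __getitem__(self, idx: int):
        s = idx * self.seq_len
        chunk = self.tokens[s: s + self.seq_len + 1]
        return chunk[:-1], chunk[1:]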
@option(Configs.transformer, 'GPT_ALiBi')
def _transformer_configs(c: Configs):

We use our configurable transformer implementation

    conf = TransformerConfigs()

Set the vocabulary sizes for embeddings and generating logits

    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens

GPT uses GELU activation for the position-wise feedforward network

    conf.ffn.activation = 'GELU'

ALiBi doesn't use positional embeddings

    conf.src_embed = 'no_pos'
    conf.tgt_embed = 'no_pos'

Set all attention mechanisms to ALiBi

    conf.encoder_attn = 'alibi_mha'
    conf.decoder_attn = 'alibi_mha'
    conf.decoder_mem_attn = 'alibi_mha'

    return conf
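The no_pos embeddings selected above are plain token embeddings with no positional encoding added, since position information comes entirely from the ALiBi attention biases. A minimal sketch of that idea (an illustration, not labml_nn's exact no_pos implementation):

import math
import torch
import torch.nn as nn

class TokenEmbeddingsNoPos(nn.Module):
    def __init__(self, n_vocab: int, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.emb = nn.Embedding(n_vocab, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale token embeddings as in the original Transformer; no positional term is added
        return self.emb(x) * math.sqrt(self.d_model)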
def main():

Create experiment

    experiment.create(name="gpt_alibi")

Create configs

    conf = Configs()

Override configurations

    experiment.configs(conf, {

Use character level tokenizer

        'tokenizer': 'character',

Prompt separator is blank

        'prompt_separator': '',

Starting prompt for sampling

        'prompt': 'It is ',

Use Tiny Shakespeare dataset

        'text': 'tiny_shakespeare',
        # 'text': 'tiny_shakespeare_no_split',
Use a context size of 64 for training

        'seq_len': 64,

Use a context size of 80 for validation

        'valid_seq_len': 80,

Train for 128 epochs

        'epochs': 128,

Batch size of 128

        'batch_size': 128,

Switch between training and validation 10 times per epoch

        'inner_iterations': 10,

Transformer configurations
        'transformer.d_model': 128,
        'transformer.ffn.d_ff': 512,
        'transformer.n_heads': 8,
        'transformer.n_layers': 4,
        'transformer.dropout': 0.1,
    })

Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run training

        conf.run()

if __name__ == '__main__':
    main()