"""
This is an annotated PyTorch experiment to train a transformer model with Rotary Positional Embeddings (RoPE).
"""
from labml import experiment
from labml.configs import option, calculate
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.basic.autoregressive_experiment import AutoregressiveTransformer, Configs


def _rotary_pe_mha(c: TransformerConfigs):
    """
    Build a multi-head attention module that uses Rotary Positional Embeddings.

    :param c: transformer configurations; only `c.n_heads` and `c.d_model` are read
    :return: a `RotaryPEMultiHeadAttention` with `c.n_heads` heads and model width `c.d_model`
    """
    # Imported inside the function so the RoPE module is only loaded when this builder is called
    from labml_nn.transformers.rope import RotaryPEMultiHeadAttention
    return RotaryPEMultiHeadAttention(c.n_heads, c.d_model)


# Configuration options: register `'rotary'` as a choice for each of the transformer's
# attention configs, all built by `_rotary_pe_mha`
calculate(TransformerConfigs.encoder_attn, 'rotary', _rotary_pe_mha)
calculate(TransformerConfigs.decoder_attn, 'rotary', _rotary_pe_mha)
calculate(TransformerConfigs.decoder_mem_attn, 'rotary', _rotary_pe_mha)


# Create an autoregressive model
@option(Configs.model, 'rotary_pe_transformer')
def _model(c: Configs):
    """
    Create an autoregressive model and move it to `c.device`.

    The model is assembled from the transformer's encoder, source embedding and generator.
    No explicit weight initialization is performed here.

    :param c: experiment configurations
    :return: the `AutoregressiveTransformer` placed on `c.device`
    """
    m = AutoregressiveTransformer(c.transformer.encoder,
                                  c.transformer.src_embed,
                                  c.transformer.generator).to(c.device)

    return m


def main():
    """
    Configure and run the RoPE transformer experiment on the Tiny Shakespeare dataset.
    """
    # Create experiment
    experiment.create(name="rotary_pe_transformer")
    # Create configs
    conf = Configs()
    # Override configurations
    experiment.configs(conf, {
        # CUDA device index 1, i.e. the second GPU; NOTE(review): change this on single-GPU machines
        'device.cuda_device': 1,

        # No fixed positional embeddings
        'transformer.src_embed': 'no_pos',
        'transformer.tgt_embed': 'no_pos',

        # Encoder with RoPE
        'transformer.encoder_attn': 'rotary',

        # Use the model registered above under the name `rotary_pe_transformer`
        'model': 'rotary_pe_transformer',

        # Use character level tokenizer
        'tokenizer': 'character',
        # Prompt separator is blank
        'prompt_separator': '',
        # Starting prompt for sampling
        'prompt': 'It is ',
        # Use Tiny Shakespeare dataset
        'text': 'tiny_shakespeare',

        # Use a context size of 256
        'seq_len': 256,
        # Train for 128 epochs
        'epochs': 128,
        # Batch size 32
        'batch_size': 32,
        # Switch between training and validation 10 times per epoch
        'inner_iterations': 10,

        # Model size
        'd_model': 128,
        'transformer.ffn.d_ff': 256,

        # Use Noam optimizer
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # Run training
        conf.run()


if __name__ == '__main__':
    main()