This is an annotated PyTorch experiment to train a gMLP model. The paper also applies Stochastic Depth regularization, where some layers are dropped randomly during training; we have not implemented that here.
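For context, Stochastic Depth drops entire residual branches at random while training, so each forward pass sees a shallower network. Below is a minimal sketch of the idea, assuming a generic residual block; the wrapper name, the survival probability, and the per-batch (rather than per-sample) dropping are illustrative choices, not the paper's code.

import torch
import torch.nn as nn

class StochasticDepth(nn.Module):
    # Illustrative stochastic depth wrapper around a residual block
    def __init__(self, block: nn.Module, survival_prob: float = 0.9):
        super().__init__()
        self.block = block
        self.survival_prob = survival_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # At inference, always apply the block; the inverted scaling
        # below keeps the training-time expectation consistent.
        if not self.training:
            return x + self.block(x)
        # During training, skip the residual branch with probability
        # 1 - survival_prob, scaling the kept branch by 1 / survival_prob.
        if torch.rand(()).item() < self.survival_prob:
            return x + self.block(x) / self.survival_prob
        return x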
This experiment is based on the training loop and configurations for a simple transformer auto-regressive NLP task.
from labml import experiment
from labml.configs import option
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.basic.autoregressive_experiment import Configs as BasicAutoRegressionConfigs
from labml_nn.transformers.gmlp import GMLPBlock
# This inherits from the training loop and configurations for a simple
# transformer auto-regressive NLP task.
class Configs(BasicAutoRegressionConfigs):
    # Transformer
    transformer: TransformerConfigs = 'gMLP'
    # gMLP block
    gmlp: GMLPBlock
    # `d_ffn` for the gMLP projection layer
    d_ffn: int = 2048


# Create a gMLP block; the spatial gating unit inside it projects across
# token positions, so the block is built for a fixed `seq_len`.
@option(Configs.gmlp, 'gMLP')
def _gmlp_configs(c: Configs):
    return GMLPBlock(c.d_model, c.d_ffn, c.seq_len)
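The @option decorator registers the function above as the computation for Configs.gmlp whenever that option is set to the string 'gMLP'. A rough, self-contained analogue of this string-to-builder pattern follows; it is purely illustrative and not labml's actual implementation.

# Hypothetical registry sketch, not labml's code
_builders = {}

def option_like(name: str):
    # Register a builder function under a string key
    def decorator(fn):
        _builders[name] = fn
        return fn
    return decorator

@option_like('gMLP')
def build_gmlp_block(c):
    return GMLPBlock(c.d_model, c.d_ffn, c.seq_len)

def resolve(value, c):
    # If a config value is a registered string, call its builder lazily
    if isinstance(value, str) and value in _builders:
        return _builders[value](c)
    return value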
# Transformer configurations
@option(Configs.transformer, 'gMLP')
def _transformer_configs(c: Configs):
    # We use our configurable transformer implementation
    conf = TransformerConfigs()
    # Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
    # Set the model size
    conf.d_model = c.d_model
    # Replace the encoder layer with a gMLP layer
    conf.encoder_layer = c.gmlp

    return conf
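Note that conf.encoder_layer is a single block specification; labml's transformer encoder replicates it for each layer of the stack. A sketch of that replication pattern, with copy.deepcopy assumed as the cloning mechanism rather than taken from labml's exact code:

import copy
import torch.nn as nn

def make_stack(layer: nn.Module, n_layers: int) -> nn.ModuleList:
    # Deep-copy so every layer in the stack gets its own parameters
    return nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])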
def main():
    # Create the experiment
    experiment.create(name="gMLP")
    # Create configs
    conf = Configs()
    # Override configurations
    experiment.configs(conf, {
        # Use character level tokenizer
        'tokenizer': 'character',
        # Prompt separator is blank
        'prompt_separator': '',
        # Starting prompt for sampling
        'prompt': 'It is ',
        # Use Tiny Shakespeare dataset
        'text': 'tiny_shakespeare',

        # Use a context size of 256
        'seq_len': 256,
        # Train for 128 epochs
        'epochs': 128,
        # Batch size of 32
        'batch_size': 32,
        # Switch between training and validation 10 times per epoch
        'inner_iterations': 10,

        # Model size
        'd_model': 512,
        'd_ffn': 2048,

        # Use the Noam optimizer
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })
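With the Noam optimizer, 'optimizer.learning_rate' acts as a scale factor on the warmup-then-decay schedule from "Attention Is All You Need". A small sketch of that schedule; the warmup value here is an illustrative assumption, not this experiment's setting:

def noam_lr(step: int, d_model: int = 512, warmup: int = 4000, factor: float = 1.0) -> float:
    # Learning rate rises linearly for `warmup` steps, then decays as step^-0.5
    step = max(step, 1)  # guard against step 0
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)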
    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # Run training
        conf.run()


if __name__ == '__main__':
    main()