"""
This is an annotated PyTorch experiment to train an FNet model.

It is based on the general training loop and configurations for the AG News classification task.
"""
import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_classification import NLPClassificationConfigs
from labml_nn.transformers import Encoder
from labml_nn.transformers import TransformerConfigs


class TransformerClassifier(nn.Module):
    """
    Transformer-based sequence classifier.

    * `encoder` is the transformer `Encoder`
    * `src_embed` is the token embedding module (with positional encodings)
    * `generator` is the final fully connected layer that gives the logits.
    """

    def __init__(self, encoder: Encoder, src_embed: nn.Module, generator: nn.Linear):
        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator

    def forward(self, x: torch.Tensor):
        # Get the token embeddings with positional encodings
        x = self.src_embed(x)
        # Transformer encoder
        # NOTE(review): the second argument (`None`) is presumably the attention mask — confirm
        x = self.encoder(x, None)
        # Get logits for classification.
        #
        # We set the `[CLS]` token at the last position of the sequence.
        # This is extracted by `x[-1]`, where `x` is of shape
        # `[seq_len, batch_size, d_model]`.
        x = self.generator(x[-1])

        # Return results (second value is for state, since our trainer is
        # used with RNNs also)
        return x, None


class Configs(NLPClassificationConfigs):
    """
    Experiment configurations, extending `NLPClassificationConfigs`.
    """

    # Classification model
    model: TransformerClassifier
    # Transformer
    transformer: TransformerConfigs


@option(Configs.transformer)
def _transformer_configs(c: Configs):
    """
    Transformer configurations.
    """
    # We use our configurable transformer implementation
    conf = TransformerConfigs()
    # Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens

    return conf


@option(TransformerConfigs.encoder_attn)
def fnet_mix():
    """
    Create an `FNetMix` module that can replace the self-attention in the
    transformer encoder layer.

    This option is selected in `main()` via `'transformer.encoder_attn': 'fnet_mix'`.
    """
    # Imported inside the function so the FNet module is only loaded when
    # this option is actually computed
    from labml_nn.transformers.fnet import FNetMix
    return FNetMix()


@option(Configs.model)
def _model(c: Configs):
    """
    Create the classification model.
    """
    # Size the classifier head from the transformer's own `d_model` so it always
    # matches the width of the encoder output, even when `transformer.d_model`
    # is overridden independently of `d_model`.
    m = TransformerClassifier(c.transformer.encoder,
                              c.transformer.src_embed,
                              nn.Linear(c.transformer.d_model, c.n_classes)).to(c.device)

    return m


def main():
    # Create experiment
    experiment.create(name="fnet")
    # Create configs
    conf = Configs()
    # Override configurations
    experiment.configs(conf, {
        # Use word level tokenizer
        'tokenizer': 'basic_english',

        # Train for 32 epochs
        'epochs': 32,
        # Switch between training and validation 10 times per epoch
        'inner_iterations': 10,

        # Transformer configurations (same as defaults)
        'transformer.d_model': 512,
        'transformer.ffn.d_ff': 2048,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6,

        # Use FNet token mixing instead of self-attention in the encoder.
        # Without this, the encoder keeps its default `encoder_attn` option
        # and the `fnet_mix` option defined above goes unused.
        'transformer.encoder_attn': 'fnet_mix',

        # Use Noam optimizer
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # Run training
        conf.run()


if __name__ == '__main__':
    main()