"""
FNet experiment.

This is an annotated PyTorch experiment that trains an FNet model.
It is based on the general training loop and configurations for the
AG News classification task.
"""
import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml_helpers.module import Module
from labml_nn.experiments.nlp_classification import NLPClassificationConfigs
from labml_nn.transformers import Encoder
from labml_nn.transformers import TransformerConfigs


class TransformerClassifier(nn.Module):
    """
    Transformer-encoder based classifier.

    Embeds the input tokens, runs them through the transformer encoder, and
    projects the encoding at the last sequence position to class logits.
    """

    def __init__(self, encoder: Encoder, src_embed: Module, generator: nn.Linear):
        """
        :param encoder: the transformer encoder
        :param src_embed: the token embedding module (with positional encodings)
        :param generator: the final fully connected layer that gives the logits
        """
        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator

    def forward(self, x: torch.Tensor):
        """
        Compute classification logits for a batch of token sequences.

        :return: a tuple ``(logits, None)``
        """
        # Get the token embeddings with positional encodings
        x = self.src_embed(x)
        # Transformer encoder (no mask is passed)
        x = self.encoder(x, None)
        # Get logits for classification.
        #
        # We set the [CLS] token at the last position of the sequence.
        # It is extracted by `x[-1]`, where `x` is of shape
        # `[seq_len, batch_size, d_model]`.
        x = self.generator(x[-1])

        # Return results (the second value is for state, since our trainer
        # is also used with RNNs)
        return x, None


class Configs(NLPClassificationConfigs):
    """
    Configurations for the FNet classification experiment.

    Extends the general training loop and configurations for the
    AG News classification task.
    """

    # Classification model
    model: TransformerClassifier
    # Transformer
    transformer: TransformerConfigs


@option(Configs.transformer)
def _transformer_configs(c: Configs):
    """
    Create the transformer configurations.

    We use our configurable transformer implementation.
    """
    conf = TransformerConfigs()
    # Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens

    return conf
 module that can replace the self-attention in transformer encoder layer .
90@option(TransformerConfigs.encoder_attn)
91def fnet_mix():97    from labml_nn.transformers.fnet import FNetMix
98    return FNetMix()Create classification model
@option(Configs.model)
def _model(c: Configs):
    """
    Create the classification model.

    The encoder and embeddings come from the transformer configurations; the
    final linear layer maps each encoding to `c.n_classes` logits.
    """
    # Size the logits layer from the transformer's own `d_model`, so it always
    # matches the width of the encoder built from `c.transformer`, even when
    # only `transformer.d_model` is overridden.
    m = TransformerClassifier(c.transformer.encoder,
                              c.transformer.src_embed,
                              nn.Linear(c.transformer.d_model, c.n_classes)).to(c.device)

    return m


def main():
    """Create, configure, and run the FNet experiment."""
    # Create experiment
    experiment.create(name="fnet")
    # Create configs
    conf = Configs()
    # Override configurations
    experiment.configs(conf, {
        # Use word level tokenizer
        'tokenizer': 'basic_english',

        # Train for 32 epochs
        'epochs': 32,
        # Switch between training and validation 10 times per epoch
        'inner_iterations': 10,

        # Transformer configurations (same as defaults)
        'transformer.d_model': 512,
        'transformer.ffn.d_ff': 2048,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6,

        # Use FNet token mixing instead of self-attention. Without this
        # override the `fnet_mix` option is never selected and the
        # experiment trains a plain self-attention transformer.
        'transformer.encoder_attn': 'fnet_mix',

        # Use Noam optimizer
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # Run training
        conf.run()


if __name__ == '__main__':
    main()