This is an annotated PyTorch experiment to train a Masked Language Model.
from typing import List

import torch
from torch import nn

from labml import experiment, tracker, logger
from labml.configs import option
from labml.logger import Text
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import BatchIndex
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import Encoder, Generator
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.mlm import MLM


class TransformerMLM(nn.Module):

encoder is the transformer Encoder, src_embed is the token embedding module (with positional encodings), and generator is the final fully connected layer that gives the logits.

    def __init__(self, *, encoder: Encoder, src_embed: Module, generator: Generator):
        super().__init__()
        self.generator = generator
        self.src_embed = src_embed
        self.encoder = encoder

    def forward(self, x: torch.Tensor):

Get the token embeddings with positional encodings
        x = self.src_embed(x)

Transformer encoder
        x = self.encoder(x, None)

Logits for the output
        y = self.generator(x)

Return results (the second value is for state, since our trainer is also used with RNNs)
        return y, None
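As an aside (not part of this experiment's code): the only structural difference from the autoregressive setup is that the encoder is called with mask None, so attention is bidirectional. Below is a minimal sketch of the same embed → encode → project flow using plain PyTorch modules instead of the labml_nn Encoder and Generator; the names and sizes are made up.

import torch
from torch import nn

class TinyMLMModel(nn.Module):
    def __init__(self, n_tokens: int, d_model: int = 64):
        super().__init__()
        # Token embedding (positional encodings omitted for brevity)
        self.embed = nn.Embedding(n_tokens, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead=4, dim_feedforward=128)
        self.encoder = nn.TransformerEncoder(layer, num_layers=2)
        # Final projection to logits over the vocabulary
        self.generator = nn.Linear(d_model, n_tokens)

    def forward(self, x: torch.Tensor):
        x = self.embed(x)                  # [seq_len, batch_size, d_model]
        x = self.encoder(x)                # no attention mask: every position sees every other
        return self.generator(x), None     # logits plus a placeholder state, like TransformerMLM

model = TinyMLMModel(n_tokens=67)
tokens = torch.randint(0, 67, (32, 4))     # [seq_len, batch_size]
logits, _ = model(tokens)
print(logits.shape)                        # torch.Size([32, 4, 67])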
This inherits from NLPAutoRegressionConfigs because it has the data pipeline implementations that we reuse here. We have implemented a custom training step for MLM.

class Configs(NLPAutoRegressionConfigs):

MLM model
    model: TransformerMLM

Transformer
    transformer: TransformerConfigs

Number of tokens
    n_tokens: int = 'n_tokens_mlm'

Tokens that shouldn't be masked
    no_mask_tokens: List[int] = []

Probability of masking a token
    masking_prob: float = 0.15

Probability of replacing the mask with a random token
    randomize_prob: float = 0.1

Probability of replacing the mask with the original token
    no_change_prob: float = 0.1

Masked Language Model (MLM) class to generate the mask
    mlm: MLM

[MASK] token
    mask_token: int

[PADDING] token
    padding_token: int
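The three probabilities declared above combine in the usual BERT-style scheme, which the labml_nn MLM class implements. Purely as a standalone illustration (not that class), a simplified version might look like the following: selected positions are what the model must predict, and the labels carry the padding id everywhere else so the loss can ignore them.

import torch

def bert_style_mask(x: torch.Tensor, mask_token: int, padding_token: int, n_vocab: int,
                    masking_prob: float = 0.15, randomize_prob: float = 0.1,
                    no_change_prob: float = 0.1):
    # Positions the model will be asked to predict
    selected = torch.rand(x.shape) < masking_prob
    # Labels: the original token at selected positions, the padding id elsewhere
    labels = torch.where(selected, x, torch.full_like(x, padding_token))

    # Decide what the model actually sees at the selected positions
    r = torch.rand(x.shape)
    keep = selected & (r < no_change_prob)                   # keep the original token
    randomize = selected & (r >= no_change_prob) & (r < no_change_prob + randomize_prob)
    to_mask = selected & ~keep & ~randomize                  # the rest become [MASK]

    masked = x.clone()
    masked[randomize] = torch.randint(0, n_vocab, x.shape)[randomize]
    masked[to_mask] = mask_token
    return masked, labels

tokens = torch.randint(0, 65, (32, 4))     # [seq_len, batch_size] of hypothetical token ids
masked, labels = bert_style_mask(tokens, mask_token=66, padding_token=65, n_vocab=65)

In the experiment itself this is handled by the mlm: MLM instance configured in init() below.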
Prompt to sample
    prompt: str = [
        "We are accounted poor citizens, the patricians good.",
        "What authority surfeits on would relieve us: if they",
        "would yield us but the superfluity, while it were",
        "wholesome, we might guess they relieved us humanely;",
        "but they think we are too dear: the leanness that",
        "afflicts us, the object of our misery, is as an",
        "inventory to particularise their abundance; our",
        "sufferance is a gain to them Let us revenge this with",
        "our pikes, ere we become rakes: for the gods know I",
        "speak this in hunger for bread, not in thirst for revenge.",
    ]

    def init(self):

[MASK] token
        self.mask_token = self.n_tokens - 1

[PAD] token
        self.padding_token = self.n_tokens - 2
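To make these two assignments concrete (the numbers below are only an example): the n_tokens_mlm option near the end of this file extends the dataset vocabulary by two, and the last two ids become the special tokens.

text_n_tokens = 65              # hypothetical size of the dataset's character vocabulary
n_tokens = text_n_tokens + 2    # as computed by the n_tokens_mlm option below
padding_token = n_tokens - 2    # 65 -> the [PAD] id
mask_token = n_tokens - 1       # 66 -> the [MASK] id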
Masked Language Model (MLM) class to generate the mask
        self.mlm = MLM(padding_token=self.padding_token,
                       mask_token=self.mask_token,
                       no_mask_tokens=self.no_mask_tokens,
                       n_tokens=self.n_tokens,
                       masking_prob=self.masking_prob,
                       randomize_prob=self.randomize_prob,
                       no_change_prob=self.no_change_prob)

Accuracy metric (ignore the labels equal to [PAD])
        self.accuracy = Accuracy(ignore_index=self.padding_token)

Cross entropy loss (ignore the labels equal to [PAD])
        self.loss_func = nn.CrossEntropyLoss(ignore_index=self.padding_token)

        super().init()
    def step(self, batch: any, batch_idx: BatchIndex):

Move the input to the device
        data = batch[0].to(self.device)

Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

Get the masked input and labels
        with torch.no_grad():
            data, labels = self.mlm(data)

Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):

Get model outputs. The model returns a tuple, where the second value is for state when using RNNs; that is not implemented yet.
            output, *_ = self.model(data)

Calculate and log the loss
        loss = self.loss_func(output.view(-1, output.shape[-1]), labels.view(-1))
        tracker.add("loss.", loss)
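As a standalone aside (not code from the experiment): because labels hold the padding id at every position that was not selected for prediction, CrossEntropyLoss with ignore_index only scores the masked positions. A tiny example with arbitrary sizes:

import torch
from torch import nn

seq_len, batch_size, n_tokens = 4, 2, 10
padding_token = n_tokens - 2                        # same convention as init() above

logits = torch.randn(seq_len, batch_size, n_tokens)
labels = torch.randint(0, n_tokens, (seq_len, batch_size))
labels[0, :] = padding_token                        # pretend these positions were not masked

loss_func = nn.CrossEntropyLoss(ignore_index=padding_token)
# Flatten to (seq_len * batch_size, n_tokens) and (seq_len * batch_size,), as in step()
loss = loss_func(logits.view(-1, logits.shape[-1]), labels.view(-1))
print(loss)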
Calculate and log accuracy
        self.accuracy(output, labels)
        self.accuracy.track()

Train the model
        if self.mode.is_train:

Calculate gradients
            loss.backward()

Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step
            self.optimizer.step()

Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)

Clear the gradients
            self.optimizer.zero_grad()

Save the tracked metrics
        tracker.save()
    @torch.no_grad()
    def sample(self):

Empty tensor for the data, filled with [PAD].
        data = torch.full((self.seq_len, len(self.prompt)), self.padding_token, dtype=torch.long)

Add the prompts one by one
        for i, p in enumerate(self.prompt):

Get token indexes
            d = self.text.text_to_i(p)

Add to the tensor
            s = min(self.seq_len, len(d))
            data[:s, i] = d[:s]

Move the tensor to the current device
        data = data.to(self.device)

Get the masked input and labels
        data, labels = self.mlm(data)

Get model outputs
        output, *_ = self.model(data)

Print the samples generated
        for j in range(data.shape[1]):

Collect output for printing
            log = []

For each token
            for i in range(len(data)):

If the label is not [PAD]
                if labels[i, j] != self.padding_token:

Get the prediction
                    t = output[i, j].argmax().item()

If it's a printable character
                    if t < len(self.text.itos):

Correct prediction
                        if t == labels[i, j]:
                            log.append((self.text.itos[t], Text.value))

Incorrect prediction
                        else:
                            log.append((self.text.itos[t], Text.danger))

If it's not a printable character
                    else:
                        log.append(('*', Text.danger))

If the label is [PAD] (unmasked), print the original.
                elif data[i, j] < len(self.text.itos):
                    log.append((self.text.itos[data[i, j]], Text.subtle))

            logger.log(log)
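For reference, logger.log is being given a list of (text, style) pairs here, which is what produces the colored per-character output. A tiny standalone call with arbitrary strings:

from labml import logger
from labml.logger import Text

# The same styles sample() uses: value for correct predictions,
# danger for wrong ones, subtle for unmasked context.
logger.log([('t', Text.value), ('h', Text.danger), ('e ', Text.subtle)])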
Number of tokens including [PAD] and [MASK]

@option(Configs.n_tokens)
def n_tokens_mlm(c: Configs):
    return c.text.n_tokens + 2
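For readers new to labml's calculated configs: a string default such as n_tokens: int = 'n_tokens_mlm' names an @option function, which labml calls to compute the value when it is needed and not overridden. A minimal sketch of the same pattern, with hypothetical names:

from labml.configs import BaseConfigs, option

class DemoConfigs(BaseConfigs):
    # The string names the option function below, mirroring n_tokens above
    vocab_size: int = 'vocab_size_with_specials'
    base_vocab: int = 65

@option(DemoConfigs.vocab_size)
def vocab_size_with_specials(c: DemoConfigs):
    # Two extra ids for [PAD] and [MASK], mirroring n_tokens_mlm
    return c.base_vocab + 2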
@option(Configs.transformer)
def _transformer_configs(c: Configs):

We use our configurable transformer implementation
    conf = TransformerConfigs()

Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens

Embedding size
    conf.d_model = c.d_model

    return conf
Create the MLM model

@option(Configs.model)
def _model(c: Configs):
    m = TransformerMLM(encoder=c.transformer.encoder,
                       src_embed=c.transformer.src_embed,
                       generator=c.transformer.generator).to(c.device)

    return m
def main():

Create experiment
    experiment.create(name="mlm")

Create configs
    conf = Configs()

Override configurations
    experiment.configs(conf, {

Batch size
        'batch_size': 64,

Sequence length of 32. We use a short sequence length to train faster. Otherwise it takes forever to train.
        'seq_len': 32,

Train for 1024 epochs.
        'epochs': 1024,

Switch between training and validation once per epoch
        'inner_iterations': 1,

Transformer configurations (same as defaults)
        'd_model': 128,
        'transformer.ffn.d_ff': 256,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6,

Use the Noam optimizer
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })
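The Noam schedule (from Attention Is All You Need) warms the learning rate up and then decays it with the inverse square root of the step count, scaled by the configured optimizer.learning_rate. Exactly how labml combines these, and the warmup length used here, are assumptions; the rate itself is roughly:

def noam_rate(step: int, d_model: int = 128, warmup: int = 4000) -> float:
    # d_model^-0.5 * min(step^-0.5, step * warmup^-1.5); a warmup of 4000 steps is an assumption
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

print(noam_rate(1), noam_rate(4000), noam_rate(100_000))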
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

Start the experiment
    with experiment.start():

Run training
        conf.run()


if __name__ == '__main__':
    main()