This is an annotated PyTorch experiment to train a Masked Language Model.
```python
from typing import List

import torch
from torch import nn

from labml import experiment, tracker, logger
from labml.configs import option
from labml.logger import Text
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import Encoder, Generator
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.mlm import MLM
```

`encoder` is the transformer Encoder, `src_embed` is the token embedding module (with positional encodings), and `generator` is the final fully connected layer that gives the logits.

```python
class TransformerMLM(nn.Module):
    def __init__(self, *, encoder: Encoder, src_embed: nn.Module, generator: Generator):
        super().__init__()
        self.generator = generator
        self.src_embed = src_embed
        self.encoder = encoder

    def forward(self, x: torch.Tensor):
        # Get the token embeddings with positional encodings
        x = self.src_embed(x)
        # Transformer encoder
        x = self.encoder(x, None)
        # Logits for the output
        y = self.generator(x)
        # Return results (second value is for state, since our trainer is used with RNNs also)
        return y, None
```

This inherits from `NLPAutoRegressionConfigs` because it has the data pipeline implementations that we reuse here. We have implemented a custom training step for MLM.
```python
class Configs(NLPAutoRegressionConfigs):
    # MLM model
    model: TransformerMLM
    # Transformer
    transformer: TransformerConfigs
    # Number of tokens
    n_tokens: int = 'n_tokens_mlm'
    # Tokens that shouldn't be masked
    no_mask_tokens: List[int] = []
    # Probability of masking a token
    masking_prob: float = 0.15
    # Probability of replacing the mask with a random token
    randomize_prob: float = 0.1
    # Probability of replacing the mask with the original token
    no_change_prob: float = 0.1
    # Masked Language Model (MLM) class to generate the mask
    mlm: MLM
    # [MASK] token
    mask_token: int
    # [PADDING] token
    padding_token: int
    # Prompt to sample
    prompt: str = [
        "We are accounted poor citizens, the patricians good.",
        "What authority surfeits on would relieve us: if they",
        "would yield us but the superfluity, while it were",
        "wholesome, we might guess they relieved us humanely;",
        "but they think we are too dear: the leanness that",
        "afflicts us, the object of our misery, is as an",
        "inventory to particularise their abundance; our",
        "sufferance is a gain to them Let us revenge this with",
        "our pikes, ere we become rakes: for the gods know I",
        "speak this in hunger for bread, not in thirst for revenge.",
    ]
```
```python
    def init(self):
        # [MASK] token
        self.mask_token = self.n_tokens - 1
        # [PAD] token
        self.padding_token = self.n_tokens - 2

        # Masked Language Model (MLM) class to generate the mask
        self.mlm = MLM(padding_token=self.padding_token,
                       mask_token=self.mask_token,
                       no_mask_tokens=self.no_mask_tokens,
                       n_tokens=self.n_tokens,
                       masking_prob=self.masking_prob,
                       randomize_prob=self.randomize_prob,
                       no_change_prob=self.no_change_prob)

        # Accuracy metric (ignore the labels equal to [PAD])
        self.accuracy = Accuracy(ignore_index=self.padding_token)
        # Cross entropy loss (ignore the labels equal to [PAD])
        self.loss_func = nn.CrossEntropyLoss(ignore_index=self.padding_token)

        super().init()
```
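The three probabilities combine as the comments above describe: roughly `masking_prob` of the positions are selected for prediction, and of those, about `randomize_prob` get a random token and about `no_change_prob` keep the original token, with the remainder replaced by `[MASK]`. The following is a minimal standalone sketch of that scheme; the function name `toy_mlm_mask` and its exact sampling details are illustrative assumptions, not the labml `MLM` class used here.

```python
import torch


def toy_mlm_mask(x: torch.Tensor, *, mask_token: int, padding_token: int, n_tokens: int,
                 masking_prob: float = 0.15, randomize_prob: float = 0.1, no_change_prob: float = 0.1):
    # Illustrative BERT-style masking split; NOT the labml `MLM` implementation
    # Positions selected for prediction (~masking_prob of all tokens)
    selected = torch.rand(x.shape) < masking_prob
    # Labels keep the original token at selected positions and [PAD] elsewhere,
    # so a loss with ignore_index=padding_token only scores the masked positions
    labels = torch.where(selected, x, torch.full_like(x, padding_token))
    # Of the selected positions: ~randomize_prob get a random token,
    # ~no_change_prob keep the original token, the rest become [MASK]
    r = torch.rand(x.shape)
    randomized = selected & (r < randomize_prob)
    unchanged = selected & (r >= randomize_prob) & (r < randomize_prob + no_change_prob)
    masked = selected & ~randomized & ~unchanged
    x = x.clone()
    x[randomized] = torch.randint(0, n_tokens, (int(randomized.sum()),))
    x[masked] = mask_token
    return x, labels


# Toy usage: a character vocabulary of 65 tokens, with ids 65 and 66 reserved for [PAD] and [MASK]
tokens = torch.randint(0, 65, (32, 4))
masked, labels = toy_mlm_mask(tokens, mask_token=66, padding_token=65, n_tokens=65)
```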
```python
    def step(self, batch: any, batch_idx: BatchIndex):
        # Move the input to the device
        data = batch[0].to(self.device)

        # Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Get the masked input and labels
        with torch.no_grad():
            data, labels = self.mlm(data)

        # Get model outputs.
        # It's returning a tuple for states when using RNNs.
        # This is not implemented yet.
        output, *_ = self.model(data)

        # Calculate and log the loss
        loss = self.loss_func(output.view(-1, output.shape[-1]), labels.view(-1))
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, labels)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()
```
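Only the masked positions contribute to training: the labels produced by the masking step are `[PAD]` at positions that were not selected, and both the accuracy metric and the loss were constructed with `ignore_index=self.padding_token`. A minimal sketch of that PyTorch behavior, with made-up tensors:

```python
import torch
from torch import nn

PAD = 0  # hypothetical padding id
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD)

# Logits for 4 token positions over a vocabulary of 5 tokens
logits = torch.randn(4, 5)
# Positions labelled PAD are skipped; only positions 1 and 3 contribute
labels = torch.tensor([PAD, 3, PAD, 1])

print(loss_fn(logits, labels))  # mean cross entropy over the two non-PAD positions
```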
```python
    @torch.no_grad()
    def sample(self):
        # Empty tensor for data filled with [PAD]
        data = torch.full((self.seq_len, len(self.prompt)), self.padding_token, dtype=torch.long)
        # Add the prompts one by one
        for i, p in enumerate(self.prompt):
            # Get token indexes
            d = self.text.text_to_i(p)
            # Add to the tensor
            s = min(self.seq_len, len(d))
            data[:s, i] = d[:s]
        # Move the tensor to the current device
        data = data.to(self.device)

        # Get masked input and labels
        data, labels = self.mlm(data)
        # Get model outputs
        output, *_ = self.model(data)

        # Print the samples generated
        for j in range(data.shape[1]):
            # Collect output for printing
            log = []
            # For each token
            for i in range(len(data)):
                # If the label is not [PAD]
                if labels[i, j] != self.padding_token:
                    # Get the prediction
                    t = output[i, j].argmax().item()
                    # If it's a printable character
                    if t < len(self.text.itos):
                        # Correct prediction
                        if t == labels[i, j]:
                            log.append((self.text.itos[t], Text.value))
                        # Incorrect prediction
                        else:
                            log.append((self.text.itos[t], Text.danger))
                    # If it's not a printable character
                    else:
                        log.append(('*', Text.danger))
                # If the label is [PAD] (unmasked), print the original
                elif data[i, j] < len(self.text.itos):
                    log.append((self.text.itos[data[i, j]], Text.subtle))

            # Print the collected tokens for this prompt
            logger.log(log)
```

Number of tokens including `[PAD]` and `[MASK]`:
```python
@option(Configs.n_tokens)
def n_tokens_mlm(c: Configs):
    return c.text.n_tokens + 2
```
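The two extra slots sit just past the text vocabulary, which is where `init` places the special tokens (`padding_token = n_tokens - 2`, `mask_token = n_tokens - 1`). A quick check of the arithmetic, using a hypothetical character vocabulary of 65 tokens:

```python
text_vocab = 65                # hypothetical size of the character vocabulary
n_tokens = text_vocab + 2      # what `n_tokens_mlm` returns
padding_token = n_tokens - 2   # the [PAD] id set in `Configs.init`
mask_token = n_tokens - 1      # the [MASK] id set in `Configs.init`
assert (padding_token, mask_token) == (65, 66)
```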
We use our configurable transformer implementation.

```python
@option(Configs.transformer)
def _transformer_configs(c: Configs):
    conf = TransformerConfigs()
    # Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
    # Embedding size
    conf.d_model = c.d_model

    return conf
```

Create the model.
```python
@option(Configs.model)
def _model(c: Configs):
    m = TransformerMLM(encoder=c.transformer.encoder,
                       src_embed=c.transformer.src_embed,
                       generator=c.transformer.generator).to(c.device)

    return m
```
```python
def main():
    # Create experiment
    experiment.create(name="mlm")
    # Create configs
    conf = Configs()
    # Override configurations
    experiment.configs(conf, {
        # Batch size
        'batch_size': 64,
        # Sequence length of 32. We use a short sequence length to train faster.
        # Otherwise it takes forever to train.
        'seq_len': 32,
        # Train for 1024 epochs.
        'epochs': 1024,
        # Switch between training and validation this many times per epoch
        'inner_iterations': 1,

        # Transformer configurations (same as defaults)
        'd_model': 128,
        'transformer.ffn.d_ff': 256,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6,

        # Use Noam optimizer
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # Run training
        conf.run()


if __name__ == '__main__':
    main()
```