Masked Language Model (MLM) Experiment

This is an annotated PyTorch experiment to train a Masked Language Model.

from typing import Any, List

import torch
from torch import nn

from labml import experiment, tracker, logger
from labml.configs import option
from labml.logger import Text
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import Encoder, Generator
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.mlm import MLM

Transformer-based model for MLM

class TransformerMLM(nn.Module):
    def __init__(self, *, encoder: Encoder, src_embed: nn.Module, generator: Generator):
        super().__init__()
        self.generator = generator
        self.src_embed = src_embed
        self.encoder = encoder

    def forward(self, x: torch.Tensor):

Get the token embeddings with positional encodings

        x = self.src_embed(x)

Transformer encoder

        x = self.encoder(x, None)

Logits for the output

        y = self.generator(x)

Return results (the second value is the state, since the trainer is also used with RNNs)

        return y, None
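Since the encoder is called with None for the attention mask, every position can attend to every other position; the autoregressive experiments instead pass a causal (subsequent) mask. Below is a minimal, self-contained sketch of that difference, using plain torch.nn.TransformerEncoder rather than the labml_nn Encoder. All sizes are made up for illustration, and this sketch is not part of the experiment file.

import torch
from torch import nn

# Hypothetical sizes, for illustration only
seq_len, batch_size, d_model, n_heads = 8, 2, 32, 4

layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads)
encoder = nn.TransformerEncoder(layer, num_layers=2)
x = torch.randn(seq_len, batch_size, d_model)

# MLM-style: no attention mask, every token sees the whole sequence
bidirectional = encoder(x, mask=None)

# Autoregressive-style (for comparison): a causal mask hides future tokens
causal_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
left_to_right = encoder(x, mask=causal_mask)

assert bidirectional.shape == left_to_right.shape == (seq_len, batch_size, d_model)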

Configurations

This inherits from NLPAutoRegressionConfigs because it has the data pipeline implementations that we reuse here. We have implemented a custom training step for MLM.

class Configs(NLPAutoRegressionConfigs):

MLM model

    model: TransformerMLM

Transformer

    transformer: TransformerConfigs

Number of tokens (computed by the n_tokens_mlm option defined below)

    n_tokens: int = 'n_tokens_mlm'

Tokens that shouldn't be masked

    no_mask_tokens: List[int] = []

Probability of masking a token

    masking_prob: float = 0.15

Probability of replacing a token selected for masking with a random token instead of [MASK]

    randomize_prob: float = 0.1

Probability of keeping a token selected for masking unchanged instead of replacing it with [MASK]

    no_change_prob: float = 0.1
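How these three probabilities combine is easiest to see in code. The sketch below is not the labml_nn.transformers.mlm.MLM implementation (and it skips no_mask_tokens handling); it is a BERT-style illustration under the assumption that, of the roughly 15% of positions selected for prediction, about 10% get a random token, about 10% keep the original token, and the rest become [MASK].

import torch

def toy_mlm_mask(x: torch.Tensor, mask_token: int, padding_token: int, n_tokens: int,
                 masking_prob: float = 0.15, randomize_prob: float = 0.1,
                 no_change_prob: float = 0.1):
    """Illustrative only: mask a [seq_len, batch_size] tensor of token ids."""
    # Select roughly `masking_prob` of the positions for prediction
    selected = torch.rand(x.shape) < masking_prob
    # Labels keep the original token at selected positions and [PAD] everywhere else,
    # so a loss with ignore_index=padding_token only covers the selected positions
    labels = torch.where(selected, x, torch.full_like(x, padding_token))

    masked = x.clone()
    r = torch.rand(x.shape)
    # Most selected positions are replaced with [MASK] ...
    masked[selected & (r < 1.0 - randomize_prob - no_change_prob)] = mask_token
    # ... some get a random token ...
    use_random = selected & (r >= 1.0 - randomize_prob - no_change_prob) & (r < 1.0 - no_change_prob)
    masked[use_random] = torch.randint_like(x, n_tokens)[use_random]
    # ... and the remaining selected positions keep the original token unchanged.
    return masked, labels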

Masked Language Model (MLM) class to generate the mask

    mlm: MLM

[MASK] token

    mask_token: int

[PADDING] token

    padding_token: int

Prompts to sample

    prompt: List[str] = [
        "We are accounted poor citizens, the patricians good.",
        "What authority surfeits on would relieve us: if they",
        "would yield us but the superfluity, while it were",
        "wholesome, we might guess they relieved us humanely;",
        "but they think we are too dear: the leanness that",
        "afflicts us, the object of our misery, is as an",
        "inventory to particularise their abundance; our",
        "sufferance is a gain to them Let us revenge this with",
        "our pikes, ere we become rakes: for the gods know I",
        "speak this in hunger for bread, not in thirst for revenge.",
    ]

Initialization

    def init(self):

[MASK] token

        self.mask_token = self.n_tokens - 1

[PAD] token

        self.padding_token = self.n_tokens - 2

Masked Language Model (MLM) class to generate the mask

        self.mlm = MLM(padding_token=self.padding_token,
                       mask_token=self.mask_token,
                       no_mask_tokens=self.no_mask_tokens,
                       n_tokens=self.n_tokens,
                       masking_prob=self.masking_prob,
                       randomize_prob=self.randomize_prob,
                       no_change_prob=self.no_change_prob)

Accuracy metric (ignore the labels equal to [PAD])

        self.accuracy = Accuracy(ignore_index=self.padding_token)

Cross entropy loss (ignore the labels equal to [PAD])

        self.loss_func = nn.CrossEntropyLoss(ignore_index=self.padding_token)

        super().init()
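Since only the masked positions carry real labels (everything else is labelled [PAD]), ignore_index is what restricts training to the masked tokens. Here is a small, self-contained check of that behaviour with made-up numbers; it is illustrative only, not part of the experiment.

import torch
from torch import nn

padding_token = 5                                  # hypothetical [PAD] id
loss_func = nn.CrossEntropyLoss(ignore_index=padding_token)

logits = torch.randn(4, 7)                         # 4 positions, 7-token vocabulary
labels = torch.tensor([2, padding_token, padding_token, 6])

# Positions labelled [PAD] contribute nothing; the mean is taken over positions 0 and 3 only
assert torch.allclose(loss_func(logits, labels),
                      nn.CrossEntropyLoss()(logits[[0, 3]], labels[[0, 3]]))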

Training or validation step

    def step(self, batch: Any, batch_idx: BatchIndex):

Move the input to the device

        data = batch[0].to(self.device)

Update global step (number of tokens processed) when in training mode

        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

Get the masked input and labels

        with torch.no_grad():
            data, labels = self.mlm(data)

Get model outputs. The second element of the returned tuple is the state, for when the trainer is used with RNNs; that is not implemented here.

        output, *_ = self.model(data)

Calculate and log the loss

        loss = self.loss_func(output.view(-1, output.shape[-1]), labels.view(-1))
        tracker.add("loss.", loss)

Calculate and log accuracy

        self.accuracy(output, labels)
        self.accuracy.track()

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

            self.optimizer.step()

Log the model parameters and gradients on last batch of every epoch

            if batch_idx.is_last:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()
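For reference, the same training step without the labml trainer machinery: a condensed, illustrative loop assuming a model, an MLM masker, a loss function, an optimizer and a data loader that yields batches of [seq_len, batch_size] token ids. All names and the default clip value are placeholders, not this experiment's API.

import torch

def train_epoch(model, mlm, loss_func, optimizer, data_loader, device, grad_norm_clip=0.5):
    """Illustrative only: mirrors Configs.step for the training case."""
    model.train()
    for batch in data_loader:
        data = batch[0].to(device)                 # [seq_len, batch_size] token ids
        with torch.no_grad():
            data, labels = mlm(data)               # mask the input and build the labels
        output, *_ = model(data)                   # [seq_len, batch_size, n_tokens] logits
        loss = loss_func(output.view(-1, output.shape[-1]), labels.view(-1))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_norm_clip)
        optimizer.step()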

Sampling function to generate samples periodically while training

    @torch.no_grad()
    def sample(self):

Empty tensor for the data, filled with [PAD].

        data = torch.full((self.seq_len, len(self.prompt)), self.padding_token, dtype=torch.long)

Add the prompts one by one

        for i, p in enumerate(self.prompt):

Get token indexes

            d = self.text.text_to_i(p)

Add to the tensor

            s = min(self.seq_len, len(d))
            data[:s, i] = d[:s]

Move the tensor to the current device

        data = data.to(self.device)

Get masked input and labels

        data, labels = self.mlm(data)

Get model outputs

        output, *_ = self.model(data)

Print the samples generated

        for j in range(data.shape[1]):

Collect output for printing

            log = []

For each token

            for i in range(len(data)):

If the label is not [PAD]

                if labels[i, j] != self.padding_token:

Get the prediction

                    t = output[i, j].argmax().item()

If it's a printable character

                    if t < len(self.text.itos):

Correct prediction

                        if t == labels[i, j]:
                            log.append((self.text.itos[t], Text.value))

Incorrect prediction

                        else:
                            log.append((self.text.itos[t], Text.danger))

If it's not a printable character

                    else:
                        log.append(('*', Text.danger))

If the label is [PAD] (the token was not masked), print the original token.

                elif data[i, j] < len(self.text.itos):
                    log.append((self.text.itos[data[i, j]], Text.subtle))

Print

            logger.log(log)

Number of tokens including [PAD] and [MASK]

@option(Configs.n_tokens)
def n_tokens_mlm(c: Configs):
    return c.text.n_tokens + 2
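For example, assuming a character vocabulary of 65 tokens (roughly what a character-level dataset such as Tiny Shakespeare gives; the exact size depends on the data), the numbers work out as follows. This is only an illustration of the arithmetic in init().

text_n_tokens = 65                    # hypothetical size of the character vocabulary
n_tokens = text_n_tokens + 2          # 67: two extra ids for the special tokens
padding_token = n_tokens - 2          # 65 -> [PAD]
mask_token = n_tokens - 1             # 66 -> [MASK]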

Transformer configurations

@option(Configs.transformer)
def _transformer_configs(c: Configs):
    conf = TransformerConfigs()

Set the vocabulary sizes for embeddings and generating logits

    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens

Embedding size

    conf.d_model = c.d_model

    return conf

Create the MLM model

@option(Configs.model)
def _model(c: Configs):
    m = TransformerMLM(encoder=c.transformer.encoder,
                       src_embed=c.transformer.src_embed,
                       generator=c.transformer.generator).to(c.device)

    return m


def main():

Create experiment

    experiment.create(name="mlm")

Create configs

    conf = Configs()

Override configurations

    experiment.configs(conf, {

Batch size

        'batch_size': 64,

Sequence length of 32. We use a short sequence length to train faster. Otherwise it takes forever to train.

        'seq_len': 32,

Train for 1024 epochs.

        'epochs': 1024,

Switch between training and validation once per epoch

        'inner_iterations': 1,

Transformer configurations (same as defaults)

        'd_model': 128,
        'transformer.ffn.d_ff': 256,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6,

Use the Noam optimizer

        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })
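The learning rate of 1.0 is not used directly: the Noam schedule from "Attention Is All You Need" scales it by the model size and a warm-up factor, so the configured value acts as a multiplier. A small illustrative sketch of that schedule follows; the warm-up of 4000 steps is the paper's value and an assumption here, not necessarily this experiment's default.

def noam_lr(step: int, d_model: int = 128, warmup: int = 4000, lr: float = 1.0) -> float:
    """Illustrative Noam schedule: lr * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)."""
    # Rises roughly linearly until step == warmup, then decays proportionally to step ** -0.5
    return lr * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)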

Set models for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run training

        conf.run()


if __name__ == '__main__':
    main()