This is an annotated PyTorch experiment to train a Masked Language Model.
from typing import Any, List

import torch
from torch import nn

from labml import experiment, tracker, logger
from labml.configs import option
from labml.logger import Text
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import Encoder, Generator
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.mlm import MLM
class TransformerMLM(nn.Module):

encoder is the transformer Encoder, src_embed is the token embedding module (with positional encodings), and generator is the final fully connected layer that gives the logits.

    def __init__(self, *, encoder: Encoder, src_embed: nn.Module, generator: Generator):
        super().__init__()
        self.generator = generator
        self.src_embed = src_embed
        self.encoder = encoder

    def forward(self, x: torch.Tensor):
Get the token embeddings with positional encodings
        x = self.src_embed(x)
Transformer encoder; the attention mask is None, so every token can attend to every other token
        x = self.encoder(x, None)
Logits for the output
        y = self.generator(x)
Return results (second value is for state, since our trainer is also used with RNNs)
        return y, None
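To see how the three parts compose, here is a minimal, self-contained sketch (not part of the experiment) that drives TransformerMLM with stand-in modules. ToyEncoder and ToyGenerator are hypothetical placeholders for the labml_nn Encoder and Generator; the only point is the shape flow from token ids to per-position logits over the vocabulary.

import torch
from torch import nn

class ToyEncoder(nn.Module):
    # Hypothetical stand-in for the labml_nn Encoder; it only mimics the (x, mask) call signature.
    def __init__(self, d_model: int):
        super().__init__()
        self.layer = nn.TransformerEncoderLayer(d_model, nhead=4)

    def forward(self, x: torch.Tensor, mask):
        # mask is None in TransformerMLM.forward, i.e. full bidirectional attention
        return self.layer(x)

class ToyGenerator(nn.Module):
    # Hypothetical stand-in for the labml_nn Generator; a linear projection to token logits.
    def __init__(self, d_model: int, n_tokens: int):
        super().__init__()
        self.proj = nn.Linear(d_model, n_tokens)

    def forward(self, x: torch.Tensor):
        return self.proj(x)

d_model, n_tokens, seq_len, batch_size = 128, 100, 32, 4
model = TransformerMLM(encoder=ToyEncoder(d_model),
                       src_embed=nn.Embedding(n_tokens, d_model),
                       generator=ToyGenerator(d_model, n_tokens))

x = torch.randint(0, n_tokens, (seq_len, batch_size))   # [seq_len, batch_size] token ids
logits, _ = model(x)
assert logits.shape == (seq_len, batch_size, n_tokens)  # per-position logits over the vocabulary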
This inherits from NLPAutoRegressionConfigs because it has the data pipeline implementations that we reuse here. We have implemented a custom training step for MLM.
class Configs(NLPAutoRegressionConfigs):
MLM model
    model: TransformerMLM
Transformer
    transformer: TransformerConfigs
Number of tokens
    n_tokens: int = 'n_tokens_mlm'
Tokens that shouldn't be masked
    no_mask_tokens: List[int] = []
Probability of masking a token
    masking_prob: float = 0.15
Probability of replacing the mask with a random token
    randomize_prob: float = 0.1
Probability of replacing the mask with the original token
    no_change_prob: float = 0.1
Masked Language Model (MLM) class to generate the mask
    mlm: MLM
[MASK] token
    mask_token: int
[PADDING] token
    padding_token: int
Prompt to sample
    prompt: List[str] = [
        "We are accounted poor citizens, the patricians good.",
        "What authority surfeits on would relieve us: if they",
        "would yield us but the superfluity, while it were",
        "wholesome, we might guess they relieved us humanely;",
        "but they think we are too dear: the leanness that",
        "afflicts us, the object of our misery, is as an",
        "inventory to particularise their abundance; our",
        "sufferance is a gain to them Let us revenge this with",
        "our pikes, ere we become rakes: for the gods know I",
        "speak this in hunger for bread, not in thirst for revenge.",
    ]
    def init(self):
[MASK] token
        self.mask_token = self.n_tokens - 1
[PAD] token
        self.padding_token = self.n_tokens - 2
Masked Language Model (MLM) class to generate the mask
        self.mlm = MLM(padding_token=self.padding_token,
                       mask_token=self.mask_token,
                       no_mask_tokens=self.no_mask_tokens,
                       n_tokens=self.n_tokens,
                       masking_prob=self.masking_prob,
                       randomize_prob=self.randomize_prob,
                       no_change_prob=self.no_change_prob)
Accuracy metric (ignore the labels equal to [PAD])
        self.accuracy = Accuracy(ignore_index=self.padding_token)
Cross entropy loss (ignore the labels equal to [PAD])
        self.loss_func = nn.CrossEntropyLoss(ignore_index=self.padding_token)

        super().init()
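A rough reading of the masking probabilities configured above (the exact behaviour lives in labml_nn.transformers.mlm): about 15% of tokens are selected, and of those roughly 10% are replaced with a random token, roughly 10% are kept unchanged, and the remaining ~80% become the [MASK] token, i.e. the familiar BERT recipe. A quick arithmetic sketch, assuming this interpretation:

masking_prob, randomize_prob, no_change_prob = 0.15, 0.1, 0.1

p_mask_token = masking_prob * (1 - randomize_prob - no_change_prob)  # share of all tokens replaced by [MASK], ~0.12
p_random     = masking_prob * randomize_prob                         # share replaced by a random token, ~0.015
p_unchanged  = masking_prob * no_change_prob                         # share selected but left unchanged, ~0.015
print(p_mask_token, p_random, p_unchanged)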
    def step(self, batch: Any, batch_idx: BatchIndex):
Move the input to the device
        data = batch[0].to(self.device)
Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])
Get the masked input and labels
        with torch.no_grad():
            data, labels = self.mlm(data)
Get model outputs. The model returns a tuple so that a state can be passed along when using RNNs; that is not implemented yet, so the extra value is discarded.
        output, *_ = self.model(data)
Calculate and log the loss
        loss = self.loss_func(output.view(-1, output.shape[-1]), labels.view(-1))
        tracker.add("loss.", loss)
Calculate and log accuracy
        self.accuracy(output, labels)
        self.accuracy.track()
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take optimizer step
            self.optimizer.step()
Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
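Note that the loss and accuracy above only see the masked positions: the MLM masker sets the label at every unmasked position to the [PAD] token (which is why sample() below checks labels[i, j] != self.padding_token), and both Accuracy and nn.CrossEntropyLoss are built with ignore_index=padding_token. A small self-contained illustration of ignore_index, with made-up numbers:

import torch
from torch import nn

padding_token = 8                                # hypothetical [PAD] id
loss_func = nn.CrossEntropyLoss(ignore_index=padding_token)

logits = torch.randn(6, 10)                      # flattened, like output.view(-1, n_tokens)
labels = torch.tensor([3, padding_token, 5, padding_token, padding_token, 1])

# Only the three positions whose label is not [PAD] contribute to the mean loss.
print(loss_func(logits, labels))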
    @torch.no_grad()
    def sample(self):
Empty tensor for data, filled with [PAD].
        data = torch.full((self.seq_len, len(self.prompt)), self.padding_token, dtype=torch.long)
Add the prompts one by one
        for i, p in enumerate(self.prompt):
Get token indexes
            d = self.text.text_to_i(p)
Add to the tensor
            s = min(self.seq_len, len(d))
            data[:s, i] = d[:s]
Move the tensor to the current device
        data = data.to(self.device)
Get masked input and labels
        data, labels = self.mlm(data)
Get model outputs
        output, *_ = self.model(data)
Print the samples generated
        for j in range(data.shape[1]):
Collect the output for printing
            log = []
For each token
            for i in range(len(data)):
If the label is not [PAD]
                if labels[i, j] != self.padding_token:
Get the prediction
                    t = output[i, j].argmax().item()
If it's a printable character
                    if t < len(self.text.itos):
Correct prediction
                        if t == labels[i, j]:
                            log.append((self.text.itos[t], Text.value))
Incorrect prediction
                        else:
                            log.append((self.text.itos[t], Text.danger))
If it's not a printable character
                    else:
                        log.append(('*', Text.danger))
If the label is [PAD] (unmasked), print the original token.
                elif data[i, j] < len(self.text.itos):
                    log.append((self.text.itos[data[i, j]], Text.subtle))
            logger.log(log)
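The packing at the top of sample() (one column per prompt, truncated to seq_len and padded with [PAD]) can be tried in isolation with a toy, already-tokenised example; the ids below are hypothetical:

import torch

seq_len, padding_token = 8, 0                        # toy values
prompts = [[5, 6, 7], [1, 2, 3, 4, 5, 6, 7, 8, 9]]   # two "tokenised" prompts

data = torch.full((seq_len, len(prompts)), padding_token, dtype=torch.long)
for i, d in enumerate(prompts):
    s = min(seq_len, len(d))                         # truncate prompts longer than seq_len
    data[:s, i] = torch.tensor(d[:s], dtype=torch.long)

print(data)   # shape [seq_len, n_prompts]; the short prompt is padded with 0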
Number of tokens including [PAD] and [MASK]
@option(Configs.n_tokens)
def n_tokens_mlm(c: Configs):
    return c.text.n_tokens + 2
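The +2 reserves ids at the end of the vocabulary for the two special tokens, matching Configs.init above. With a hypothetical character vocabulary of 65 ids:

text_n_tokens = 65                            # hypothetical c.text.n_tokens
n_tokens = text_n_tokens + 2                  # what n_tokens_mlm returns: 67
padding_token = n_tokens - 2                  # 65 is the [PAD] token
mask_token = n_tokens - 1                     # 66 is the [MASK] token
print(n_tokens, padding_token, mask_token)    # 67 65 66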
@option(Configs.transformer)
def _transformer_configs(c: Configs):
We use our configurable transformer implementation
    conf = TransformerConfigs()
Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
Embedding size
    conf.d_model = c.d_model

    return conf
Create the model
@option(Configs.model)
def _model(c: Configs):
    m = TransformerMLM(encoder=c.transformer.encoder,
                       src_embed=c.transformer.src_embed,
                       generator=c.transformer.generator).to(c.device)

    return m
def main():
Create experiment
    experiment.create(name="mlm")
Create configs
    conf = Configs()
Override configurations
    experiment.configs(conf, {
Batch size
        'batch_size': 64,
Sequence length of 32. We use a short sequence length to train faster; otherwise it takes forever to train.
        'seq_len': 32,
Train for 1024 epochs.
        'epochs': 1024,
Switch between training and validation once per epoch
        'inner_iterations': 1,
Transformer configurations (same as defaults)
        'd_model': 128,
        'transformer.ffn.d_ff': 256,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6,
Use Noam optimizer
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run training
        conf.run()


if __name__ == '__main__':
    main()