Auto-regressive NLP model trainer

from typing import Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from labml import lab, monit, logger, tracker
from labml.configs import option
from labml.logger import Text
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs

Cross entropy loss

class CrossEntropyLoss(Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, outputs, targets):
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
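
nn.CrossEntropyLoss expects [N, C] logits and [N] class indices, so the forward pass collapses the sequence and batch dimensions of both tensors. A minimal sketch of the shapes involved (the sizes are hypothetical):

loss_func = CrossEntropyLoss()
outputs = torch.randn(512, 16, 65)          # [seq_len, batch_size, n_tokens]
targets = torch.randint(0, 65, (512, 16))   # [seq_len, batch_size]
loss = loss_func(outputs, targets)          # flattened to [8192, 65] and [8192]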

Trainer configurations

This has the basic configurations for auto-regressive NLP task training. All the properties are configurable.
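
The string-valued defaults below (e.g. 'character' for tokenizer, 'shuffled_train_loader' for train_loader) name @option functions defined later in this file, which labml resolves lazily on first access. A minimal sketch of that pattern, using a hypothetical Demo config:

from labml.configs import BaseConfigs, option

class Demo(BaseConfigs):
    greeting: str = 'say_hello'  # names the @option function below

@option(Demo.greeting)
def say_hello(c: Demo):
    return 'hello'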

class NLPAutoRegressionConfigs(TrainValidConfigs):

Optimizer

    optimizer: torch.optim.Adam

Training device

    device: torch.device = DeviceConfigs()

Autoregressive model

    model: Module

Text dataset

    text: TextDataset

Batch size

    batch_size: int = 16

Length of the sequence, or context size

    seq_len: int = 512

Number of tokens in the vocabulary

    n_tokens: int

Tokenizer

    tokenizer: Callable = 'character'

Text prompt to start sampling (for illustration)

    prompt: str

The token separator when sampling (blank for character-level tokenization)

    prompt_separator: str

Whether to periodically save models

    is_save_models = True

Loss function

    loss_func = CrossEntropyLoss()

Accuracy function

    accuracy = Accuracy()

Model embedding size

    d_model: int = 512

Gradient clipping

    grad_norm_clip: float = 1.0

Training data loader

    train_loader: DataLoader = 'shuffled_train_loader'

Validation data loader

    valid_loader: DataLoader = 'shuffled_valid_loader'

Initialization

    def init(self):

Set tracker configurations

        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)

Add a hook to log module outputs

        hook_model_outputs(self.mode, self.model, 'model')

Add accuracy as a state module. The name is probably confusing, since it's meant to store states between training and validation for RNNs. This will keep the accuracy metric stats separate for training and validation.

        self.state_modules = [self.accuracy]

Override to calculate and log other metrics

    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        pass

Training or validation step

    def step(self, batch: any, batch_idx: BatchIndex):

Set training/eval mode

        self.model.train(self.mode.is_train)

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update the global step (number of tokens processed) when in training mode; e.g. with the defaults seq_len=512 and batch_size=16, each batch advances it by 8,192 tokens

        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

Whether to capture model outputs

        with self.mode.update(is_log_activations=batch_idx.is_last):

Get model outputs. The model returns a tuple, with extra elements for state when using RNNs; that part is not implemented yet. 😜

            output, *_ = self.model(data)

Calculate and log loss

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

Calculate and log accuracy

        self.accuracy(output, target)
        self.accuracy.track()

        self.other_metrics(output, target)

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

            self.optimizer.step()

Log the model parameters and gradients on the last batch of every epoch

            if batch_idx.is_last:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()

Sampling function to generate samples periodically while training

    def sample(self):

Starting prompt

        prompt = self.prompt

Collect output for printing

        log = [(prompt, Text.subtle)]

Sample 25 tokens

        for i in monit.iterate('Sample', 25):

Tokenize the prompt

            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)

Get the model output

            output, *_ = self.model(data)

Get the model prediction (greedy)

            output = output.argmax(dim=-1).squeeze()

Add the prediction to the prompt

            prompt += self.prompt_separator + self.text.itos[output[-1]]

Add the prediction for logging

            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

Print the sampled output

        logger.log(log)
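
The argmax above decodes greedily, always taking the most likely token. A common alternative, not used by this trainer, is temperature-scaled multinomial sampling; a minimal sketch with a hypothetical helper:

def sample_next_token(logits: torch.Tensor, temperature: float = 1.0) -> int:
    # logits: [n_tokens] scores for the last position
    probs = torch.softmax(logits / temperature, dim=-1)
    # Lower temperature approaches greedy decoding; higher gives more diverse samples
    return int(torch.multinomial(probs, num_samples=1))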

Default optimizer configurations

@option(NLPAutoRegressionConfigs.optimizer)
def _optimizer(c: NLPAutoRegressionConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer
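
d_model is passed to the optimizer configurations because some setups in labml_nn, such as the Noam learning-rate schedule, scale the learning rate by the model dimension. If my reading of the configs API is right, the defaults can be swapped through nested configuration keys; a hypothetical override dictionary:

# Hypothetical overrides, assuming OptimizerConfigs exposes these keys;
# passed to experiment.configs(conf, ...) when starting the experiment.
overrides = {
    'optimizer.optimizer': 'Noam',
    'optimizer.learning_rate': 1.0,
}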

Get number of tokens

@option(NLPAutoRegressionConfigs.n_tokens)
def _n_tokens(c: NLPAutoRegressionConfigs):
    return c.text.n_tokens

Basic English tokenizer

We use a character-level tokenizer in this experiment. You can switch to the basic English tokenizer by setting

'tokenizer': 'basic_english',

in the configurations dictionary when starting the experiment (see the end-to-end sketch at the bottom of this file).

@option(NLPAutoRegressionConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')

Character level tokenizer

def character_tokenizer(x: str):
    return list(x)

Character level tokenizer configuration

@option(NLPAutoRegressionConfigs.tokenizer)
def character():
    return character_tokenizer

Tiny Shakespeare dataset

It will be downloaded from the given URL if it's not already present.

@option(NLPAutoRegressionConfigs.text)
def tiny_shakespeare(c: NLPAutoRegressionConfigs):
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

Sequential training data loader

@option(NLPAutoRegressionConfigs.train_loader)
def sequential_train_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.train,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)

Sequential validation data loader

@option(NLPAutoRegressionConfigs.valid_loader)
def sequential_valid_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.valid,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)

Transpose batch

DataLoader stacks samples into a batch along the first dimension. We need to transpose the batch so that it is sequence-first.

def transpose_batch(batch):
    transposed_data = list(zip(*batch))

Stack the batch along the second dimension (dim=1)

    src = torch.stack(transposed_data[0], dim=1)
    tgt = torch.stack(transposed_data[1], dim=1)

    return src, tgt
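
A quick sketch of the resulting shapes, with hypothetical sizes: a DataLoader batch of 16 (src, tgt) pairs, each of length 512, becomes a pair of [512, 16] tensors:

batch = [(torch.zeros(512, dtype=torch.long),
          torch.ones(512, dtype=torch.long)) for _ in range(16)]
src, tgt = transpose_batch(batch)
assert src.shape == (512, 16) and tgt.shape == (512, 16)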

Shuffled training data loader

@option(NLPAutoRegressionConfigs.train_loader)
def shuffled_train_loader(c: NLPAutoRegressionConfigs):
    return DataLoader(SequentialUnBatchedDataset(text=c.text.train,
                                                 dataset=c.text,
                                                 seq_len=c.seq_len),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      shuffle=True)

Shuffled validation data loader

@option(NLPAutoRegressionConfigs.valid_loader)
def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
    return DataLoader(SequentialUnBatchedDataset(text=c.text.valid,
                                                 dataset=c.text,
                                                 seq_len=c.seq_len),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      shuffle=True)
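
For completeness, here is a minimal sketch of how these configurations are typically wired into a labml experiment. The experiment name and the override values are hypothetical, and the model option (not defined in this file) must be provided by the concrete experiment:

from labml import experiment

conf = NLPAutoRegressionConfigs()
experiment.create(name='nlp_autoregression_demo')  # hypothetical name
experiment.configs(conf, {
    'tokenizer': 'basic_english',  # or keep the default 'character'
    'text': 'tiny_shakespeare',
    'prompt': 'It is ',
    'prompt_separator': ' ',       # use '' for character-level tokenization
})

with experiment.start():
    conf.run()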