Auto-regressive NLP model trainer

from typing import Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler

from labml import lab, monit, logger, tracker
from labml.configs import option
from labml.logger import Text
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs

Cross entropy loss

class CrossEntropyLoss(Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, outputs, targets):
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
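As an illustrative sketch (not part of the trainer), the loss takes sequence-first logits of shape [seq_len, batch_size, n_tokens] and integer targets of shape [seq_len, batch_size], and flattens both before calling nn.CrossEntropyLoss. The shapes below are hypothetical:

import torch

loss_func = CrossEntropyLoss()
# Hypothetical shapes: seq_len=512, batch_size=16, n_tokens=65
logits = torch.randn(512, 16, 65)
targets = torch.randint(0, 65, (512, 16))
# Internally flattened to [512 * 16, 65] and [512 * 16]
loss = loss_func(logits, targets)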

Trainer configurations

This class has the basic configurations for training auto-regressive NLP models. All the properties are configurable.

class NLPAutoRegressionConfigs(TrainValidConfigs):

Optimizer

    optimizer: torch.optim.Adam

Training device

    device: torch.device = DeviceConfigs()

Autoregressive model

    model: Module

Text dataset

    text: TextDataset

Batch size

    batch_size: int = 16

Length of the sequence, or context size

    seq_len: int = 512

Number of tokens in the vocabulary

    n_tokens: int

Tokenizer

    tokenizer: Callable = 'character'

Text prompt to start sampling (for illustration)

    prompt: str

The token separator when sampling (blank for character-level tokenization)

    prompt_separator: str

Whether to periodically save models

    is_save_models = True

Loss function

    loss_func = CrossEntropyLoss()

Accuracy function

    accuracy = Accuracy()

Model embedding size

    d_model: int = 512

Gradient clipping

    grad_norm_clip: float = 1.0

Training data loader

    train_loader: DataLoader = 'shuffled_train_loader'

Validation data loader

    valid_loader: DataLoader = 'shuffled_valid_loader'

Whether the data loaders shuffle with replacement

    dataloader_shuffle_with_replacement: bool = False

Initialization

    def init(self):

Set tracker configurations

        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)

Add a hook to log module outputs

        hook_model_outputs(self.mode, self.model, 'model')

Add accuracy as a state module. The name is a bit misleading, since state modules are meant to store state (such as RNN hidden state) between training and validation. Here it keeps the accuracy metric's statistics separate for training and validation.

        self.state_modules = [self.accuracy]

Override to calculate and log other metrics

    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        pass

Training or validation step

    def step(self, batch: any, batch_idx: BatchIndex):

Set training/eval mode

        self.model.train(self.mode.is_train)

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of tokens processed) when in training mode

        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

Whether to capture model outputs

        with self.mode.update(is_log_activations=batch_idx.is_last):

Get the model outputs. The model returns a tuple, with the extra elements reserved for state when using RNNs; that is not implemented yet. 😜

            output, *_ = self.model(data)

Calculate and log loss

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

Calculate and log accuracy

        self.accuracy(output, target)
        self.accuracy.track()

        self.other_metrics(output, target)

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

            self.optimizer.step()

Log the model parameters and gradients on last batch of every epoch

            if batch_idx.is_last:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()

Sampling function to generate samples periodically while training

    def sample(self):

Starting prompt

        prompt = self.prompt

Collect output for printing

        log = [(prompt, Text.subtle)]

Sample 25 tokens

        for i in monit.iterate('Sample', 25):

Tokenize the prompt

            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)

Get the model output

            output, *_ = self.model(data)

Get the model prediction (greedy)

            output = output.argmax(dim=-1).squeeze()

Add the prediction to prompt

            prompt += self.prompt_separator + self.text.itos[output[-1]]

Add the prediction for logging

            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

Print the sampled output

        logger.log(log)

Default optimizer configurations

@option(NLPAutoRegressionConfigs.optimizer)
def _optimizer(c: NLPAutoRegressionConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer
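These defaults can be overridden when starting the experiment. For example, assuming labml's dotted-key configuration overrides and the Noam option defined by OptimizerConfigs in labml_nn, adding

'optimizer.optimizer': 'Noam',
'optimizer.learning_rate': 1.0,

to the configurations dictionary would switch the optimizer.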

Get the number of tokens from the dataset

@option(NLPAutoRegressionConfigs.n_tokens)
def _n_tokens(c: NLPAutoRegressionConfigs):
    return c.text.n_tokens

Basic English tokenizer

We use a character-level tokenizer in this experiment. You can switch to torchtext's basic English tokenizer by setting

'tokenizer': 'basic_english',

in the configurations dictionary when starting the experiment.

@option(NLPAutoRegressionConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')

Character-level tokenizer

def character_tokenizer(x: str):
    return list(x)

Character-level tokenizer configuration

@option(NLPAutoRegressionConfigs.tokenizer)
def character():
    return character_tokenizer

Tiny Shakespeare dataset

It downloads the text from the URL if it's not already present.

@option(NLPAutoRegressionConfigs.text)
def tiny_shakespeare(c: NLPAutoRegressionConfigs):
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

Sequential training data loader

@option(NLPAutoRegressionConfigs.train_loader)
def sequential_train_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.train,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)

Sequential validation data loader

@option(NLPAutoRegressionConfigs.valid_loader)
def sequential_valid_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.valid,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)

Transpose batch

DataLoader collects the batches along the first dimension. We need to transpose them to be sequence-first.

def transpose_batch(batch):
    transposed_data = list(zip(*batch))

Stack the batch along the second dimension (dim=1)

    src = torch.stack(transposed_data[0], dim=1)
    tgt = torch.stack(transposed_data[1], dim=1)

    return src, tgt
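As a quick illustration (with hypothetical shapes), a batch arrives as a list of (src, tgt) pairs, each of shape [seq_len]; transposing and stacking gives sequence-first tensors of shape [seq_len, batch_size]:

import torch

# Two hypothetical samples with seq_len=4, as a DataLoader would collect them
batch = [(torch.tensor([1, 2, 3, 4]), torch.tensor([2, 3, 4, 5])),
         (torch.tensor([6, 7, 8, 9]), torch.tensor([7, 8, 9, 10]))]
src, tgt = transpose_batch(batch)
# src.shape == tgt.shape == torch.Size([4, 2]), i.e. [seq_len, batch_size]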

Shuffled training data loader

@option(NLPAutoRegressionConfigs.train_loader)
def shuffled_train_loader(c: NLPAutoRegressionConfigs):
    dataset = SequentialUnBatchedDataset(text=c.text.train,
                                         dataset=c.text,
                                         seq_len=c.seq_len)
    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)

    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      sampler=sampler)

Shuffled validation data loader

@option(NLPAutoRegressionConfigs.valid_loader)
def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
    dataset = SequentialUnBatchedDataset(text=c.text.valid,
                                         dataset=c.text,
                                         seq_len=c.seq_len)
    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)

    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      sampler=sampler)