Auto-regressive NLP model trainer

from typing import Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler

from labml import lab, monit, logger, tracker
from labml.configs import option
from labml.logger import Text
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs

Cross entropy loss

class CrossEntropyLoss(Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, outputs, targets):
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
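
For example, with sequence-first outputs of shape [seq_len, batch_size, n_tokens] and targets of shape [seq_len, batch_size], the view calls flatten these to the (N, C) logits and (N,) targets that nn.CrossEntropyLoss expects. A quick shape check (the sizes here are only illustrative):

outputs = torch.randn(512, 16, 65)           # [seq_len, batch_size, n_tokens]
targets = torch.randint(0, 65, (512, 16))    # [seq_len, batch_size]
loss = CrossEntropyLoss()(outputs, targets)  # scalar averaged over all 512 * 16 positions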

Trainer configurations

This holds the basic configurations for training auto-regressive NLP models. All of the properties are configurable.

class NLPAutoRegressionConfigs(TrainValidConfigs):

Optimizer

    optimizer: torch.optim.Adam

Training device

    device: torch.device = DeviceConfigs()

Autoregressive model

    model: Module

Text dataset

    text: TextDataset

Batch size

    batch_size: int = 16

Length of the sequence, or context size

    seq_len: int = 512

Number of tokens in the vocabulary

    n_tokens: int

Tokenizer

    tokenizer: Callable = 'character'

Text prompt to start sampling (for illustration)

    prompt: str

The token separator when sampling (blank for character-level tokenization)

    prompt_separator: str

Whether to periodically save models

    is_save_models = True

Loss function

    loss_func = CrossEntropyLoss()

Accuracy function

    accuracy = Accuracy()

Model embedding size

    d_model: int = 512

Gradient clipping

    grad_norm_clip: float = 1.0

Training data loader

    train_loader: DataLoader = 'shuffled_train_loader'

Validation data loader

    valid_loader: DataLoader = 'shuffled_valid_loader'

Whether the data loaders shuffle with replacement

    dataloader_shuffle_with_replacement: bool = False

Whether to log model parameters and gradients (once per epoch). These are summarized stats per layer, but they could still produce many indicators for very deep networks.

    is_log_model_params_grads: bool = False

Whether to log model activations (once per epoch). These are summarized stats per layer, but they could still produce many indicators for very deep networks.

    is_log_model_activations: bool = False

Initialization

    def init(self):

Set tracker configurations

        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)

Add a hook to log module outputs

        hook_model_outputs(self.mode, self.model, 'model')

Add accuracy as a state module. The name is probably confusing, since state modules are meant to store state between training and validation for RNNs. Here it keeps the accuracy metric's statistics separate for training and validation.

        self.state_modules = [self.accuracy]

Override to calculate and log other metrics

    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        pass
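
For instance, a subclass could override this hook to track extra metrics with the tracker. A hypothetical sketch (top-5 accuracy is just an example, not part of this trainer):

class MyConfigs(NLPAutoRegressionConfigs):
    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        # Top-5 accuracy: is the target among the five highest-scoring tokens?
        top5 = output.topk(5, dim=-1).indices
        correct = (top5 == target.unsqueeze(-1)).any(dim=-1)
        tracker.add('top5.', correct.float().mean())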

Training or validation step

    def step(self, batch: any, batch_idx: BatchIndex):

Set training/eval mode

        self.model.train(self.mode.is_train)

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of tokens processed) when in training mode

        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

Whether to capture model outputs

        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):

Get model outputs. The model returns a tuple so it can also pass back states when using RNNs; that part is not implemented yet. 😜

            output, *_ = self.model(data)

Calculate and log loss

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

Calculate and log accuracy

        self.accuracy(output, target)
        self.accuracy.track()

        self.other_metrics(output, target)

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

            self.optimizer.step()

Log the model parameters and gradients on the last batch of every epoch

            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()

Sampling function to generate samples periodically while training

    def sample(self):

Starting prompt

        prompt = self.prompt

Collect output for printing

        log = [(prompt, Text.subtle)]

Sample 25 tokens

        for i in monit.iterate('Sample', 25):

Tokenize the prompt

            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)

Get the model output

            output, *_ = self.model(data)

Get the model prediction (greedy)

            output = output.argmax(dim=-1).squeeze()

Add the prediction to the prompt

            prompt += self.prompt_separator + self.text.itos[output[-1]]

Add the prediction for logging

            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

Print the sampled output

        logger.log(log)
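
A rough usage sketch, assuming a concrete experiment that defines a model option for these configurations (the experiment name and overrides here are placeholders):

from labml import experiment

def main():
    experiment.create(name='nlp_autoregression_demo')
    conf = NLPAutoRegressionConfigs()
    # Override defaults through the configurations dictionary
    experiment.configs(conf, {'tokenizer': 'character',
                              'seq_len': 512,
                              'batch_size': 16})
    # Start the experiment and run the training loop from TrainValidConfigs
    with experiment.start():
        conf.run()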

Default optimizer configurations

@option(NLPAutoRegressionConfigs.optimizer)
def _optimizer(c: NLPAutoRegressionConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer
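
Since this is an OptimizerConfigs instance, its sub-options can also be overridden from the configurations dictionary, continuing the sketch above (option names follow the conventions used elsewhere in labml_nn):

experiment.configs(conf, {
    'optimizer.optimizer': 'Noam',
    'optimizer.learning_rate': 1.0,
})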

Get number of tokens

@option(NLPAutoRegressionConfigs.n_tokens)
def _n_tokens(c: NLPAutoRegressionConfigs):
    return c.text.n_tokens

Basic English tokenizer

We use a character-level tokenizer in this experiment. You can switch to basic English tokenization by setting,

'tokenizer': 'basic_english',

in the configurations dictionary when starting the experiment.

@option(NLPAutoRegressionConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')
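
For reference, torchtext's basic_english tokenizer lowercases the text and splits out punctuation:

tokenizer = basic_english()
tokenizer('Hello, World!')  # -> ['hello', ',', 'world', '!']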

Character-level tokenizer

def character_tokenizer(x: str):
    return list(x)

Character-level tokenizer configuration

@option(NLPAutoRegressionConfigs.tokenizer)
def character():
    return character_tokenizer

Tiny Shakespeare dataset

It will be downloaded from the URL if it's not present locally.

@option(NLPAutoRegressionConfigs.text)
def tiny_shakespeare(c: NLPAutoRegressionConfigs):
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
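
The returned dataset exposes the attributes used elsewhere in this file, such as train, valid, n_tokens, text_to_i, and itos. A quick sketch of inspecting it directly (using the character tokenizer defined above):

text = TextFileDataset(
    lab.get_data_path() / 'tiny_shakespeare.txt',
    character_tokenizer,
    url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
print(text.n_tokens)                # vocabulary size (number of distinct characters)
print(text.text_to_i('ROMEO')[:3])  # first few token ids, as a tensor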

Sequential training data loader

@option(NLPAutoRegressionConfigs.train_loader)
def sequential_train_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.train,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)

Sequential validation data loader

@option(NLPAutoRegressionConfigs.valid_loader)
def sequential_valid_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.valid,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)

Transpose batch

DataLoader collates batches along the first dimension. We need to transpose the data so that the sequence dimension comes first.

def transpose_batch(batch):
    transposed_data = list(zip(*batch))

Stack the batch along the second dimension (dim=1)

    src = torch.stack(transposed_data[0], dim=1)
    tgt = torch.stack(transposed_data[1], dim=1)

    return src, tgt
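
A toy example: given a batch of two (src, tgt) pairs, each a sequence of three tokens, the result is stacked sequence-first:

batch = [(torch.tensor([1, 2, 3]), torch.tensor([2, 3, 4])),
         (torch.tensor([5, 6, 7]), torch.tensor([6, 7, 8]))]
src, tgt = transpose_batch(batch)
assert src.shape == (3, 2)  # [seq_len, batch_size]
assert (src[:, 0] == torch.tensor([1, 2, 3])).all()  # first sample in the batch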

Shuffled training data loader

@option(NLPAutoRegressionConfigs.train_loader)
def shuffled_train_loader(c: NLPAutoRegressionConfigs):
    dataset = SequentialUnBatchedDataset(text=c.text.train,
                                         dataset=c.text,
                                         seq_len=c.seq_len)
    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)

    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      sampler=sampler)

Shuffled validation data loader

@option(NLPAutoRegressionConfigs.valid_loader)
def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
    dataset = SequentialUnBatchedDataset(text=c.text.valid,
                                         dataset=c.text,
                                         seq_len=c.seq_len)
    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)

    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      sampler=sampler)
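
The shuffled loaders above are the defaults for train_loader and valid_loader. Switching to the sequential loaders uses the same configurations-dictionary mechanism as in the sketches above:

experiment.configs(conf, {
    'train_loader': 'sequential_train_loader',
    'valid_loader': 'sequential_valid_loader',
})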