Auto-regressive NLP model trainer

from typing import Any, Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler

from labml import lab, monit, logger, tracker
from labml.configs import option
from labml.logger import Text
from labml_nn.helpers.datasets import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs

Cross entropy loss

class CrossEntropyLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, outputs, targets):
        # Flatten to [seq_len * batch_size, n_tokens] logits and [seq_len * batch_size] targets
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
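
For instance, with sequence-first outputs of shape [seq_len, batch_size, n_tokens] and targets of shape [seq_len, batch_size], the loss is averaged over every token in the batch. A minimal sketch with made-up shapes:

loss_func = CrossEntropyLoss()
outputs = torch.randn(512, 16, 65)         # [seq_len, batch_size, n_tokens]
targets = torch.randint(0, 65, (512, 16))  # [seq_len, batch_size]
loss = loss_func(outputs, targets)         # scalar, averaged over 512 * 16 tokens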

Trainer configurations

This has the basic configurations for training auto-regressive NLP models. All the properties are configurable; a sketch of how to override them follows the list of properties below.

class NLPAutoRegressionConfigs(TrainValidConfigs):

Optimizer

    optimizer: torch.optim.Adam

Training device

    device: torch.device = DeviceConfigs()

Autoregressive model

    model: nn.Module

Text dataset

    text: TextDataset

Batch size

    batch_size: int = 16

Length of the sequence, or context size

    seq_len: int = 512

Number of tokens in the vocabulary

    n_tokens: int

Tokenizer

    tokenizer: Callable = 'character'

Text prompt to start sampling (for illustration)

    prompt: str

The token separator when sampling (blank for character-level tokenization)

    prompt_separator: str

Whether to periodically save models

    is_save_models = True

Loss function

    loss_func = CrossEntropyLoss()

Accuracy function

    accuracy = Accuracy()

Model embedding size

    d_model: int = 512

Gradient clipping

    grad_norm_clip: float = 1.0

Training data loader

    train_loader: DataLoader = 'shuffled_train_loader'

Validation data loader

    valid_loader: DataLoader = 'shuffled_valid_loader'

Whether the shuffled data loaders sample with replacement

    dataloader_shuffle_with_replacement: bool = False

Whether to log model parameters and gradients (once per epoch). These are summarized statistics per layer, but they could still lead to many indicators for very deep networks.

    is_log_model_params_grads: bool = False

Whether to log model activations (once per epoch). These are summarized statistics per layer, but they could still lead to many indicators for very deep networks.

    is_log_model_activations: bool = False
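
For example, a minimal sketch of overriding some of these defaults through the configurations dictionary when creating an experiment, assuming the usual labml experiment API (the experiment name here is hypothetical, and a model option would still have to be defined for training to actually run):

from labml import experiment

conf = NLPAutoRegressionConfigs()
experiment.create(name='nlp_autoregression')
# Properties not listed here keep the defaults above
experiment.configs(conf, {
    'tokenizer': 'basic_english',
    'batch_size': 32,
    'seq_len': 256,
})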

Initialization

    def init(self):

Set tracker configurations

        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        tracker.set_text("sampled", False)

Add accuracy as a state module. The name is probably confusing, since state modules are meant to store state (e.g. recurrent state for RNNs) between training and validation. Here it keeps the accuracy metric's statistics separate for training and validation.

        self.state_modules = [self.accuracy]

Override to calculate and log other metrics

    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        pass
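
For example, a hypothetical subclass could override it to track perplexity, derived from the same cross-entropy loss (a sketch for illustration, not part of the trainer):

class PerplexityConfigs(NLPAutoRegressionConfigs):
    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        # Perplexity is the exponential of the cross-entropy loss
        tracker.add('ppl.', torch.exp(self.loss_func(output, target)))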

Training or validation step

    def step(self, batch: Any, batch_idx: BatchIndex):

Set training/eval mode

        self.model.train(self.mode.is_train)

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of tokens processed) when in training mode

        if self.mode.is_train:
            # data is [seq_len, batch_size], so this is the number of tokens in the batch
            tracker.add_global_step(data.shape[0] * data.shape[1])

Get model outputs. The model returns a tuple, where the extra elements are reserved for state when using RNNs. This is not implemented yet. 😜

        output, *_ = self.model(data)

Calculate and log loss

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

Calculate and log accuracy

        self.accuracy(output, target)
        self.accuracy.track()

        self.other_metrics(output, target)

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

            self.optimizer.step()

Log the model parameters and gradients on the last batch of every epoch

            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()

Sampling function to generate samples periodically while training

    def sample(self):

Starting prompt

        prompt = self.prompt

Collect output for printing

        log = [(prompt, Text.subtle)]

Sample 25 tokens

        for i in monit.iterate('Sample', 25):

Tokenize the prompt

            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)

Get the model output

            output, *_ = self.model(data)

Get the model prediction (greedy)

            output = output.argmax(dim=-1).squeeze()

Add the prediction to the prompt

            prompt += self.prompt_separator + self.text.itos[output[-1]]

Add the prediction for logging

            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

        tracker.add({'sampled': prompt})

Print the sampled output

        logger.log(log)
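
The prediction above is greedy; it always takes the most likely token. A common alternative, sketched here purely for illustration (it is not part of this trainer), is to sample from the softmax distribution with a temperature:

def sample_with_temperature(logits: torch.Tensor, temperature: float = 1.0):
    # logits are the [n_tokens] logits for the last position
    probs = torch.softmax(logits / temperature, dim=-1)
    # Lower temperatures sharpen the distribution towards the greedy choice
    return torch.multinomial(probs, num_samples=1).item()
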
Default optimizer configurations

@option(NLPAutoRegressionConfigs.optimizer)
def _optimizer(c: NLPAutoRegressionConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer
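
The returned OptimizerConfigs is itself configurable. Continuing the configuration sketch above, switching to another optimizer might look like this (assuming the options exposed by labml_nn.optimizers.configs.OptimizerConfigs, such as the Noam learning-rate schedule):

experiment.configs(conf, {
    'optimizer.optimizer': 'Noam',
    'optimizer.learning_rate': 1.0,
})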

Get the number of tokens in the vocabulary from the dataset

@option(NLPAutoRegressionConfigs.n_tokens)
def _n_tokens(c: NLPAutoRegressionConfigs):
    return c.text.n_tokens

Basic English tokenizer

We use a character-level tokenizer in this experiment. You can switch to the basic English tokenizer by setting

'tokenizer': 'basic_english',

in the configurations dictionary when starting the experiment.

@option(NLPAutoRegressionConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')
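
As a rough illustration of torchtext's basic_english tokenizer (it lowercases the text and splits on whitespace and punctuation):

tokenizer = basic_english()
tokenizer('You shall not pass!')  # ['you', 'shall', 'not', 'pass', '!']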

Character-level tokenizer

def character_tokenizer(x: str):
    return list(x)
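
For example, character_tokenizer('alas') returns ['a', 'l', 'a', 's'].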

Character-level tokenizer configuration

@option(NLPAutoRegressionConfigs.tokenizer)
def character():
    return character_tokenizer

Tiny Shakespeare dataset

It will be downloaded from the URL if it is not present locally.

@option(NLPAutoRegressionConfigs.text)
def tiny_shakespeare(c: NLPAutoRegressionConfigs):
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

Sequential training data loader

@option(NLPAutoRegressionConfigs.train_loader)
def sequential_train_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.train,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)

Sequential validation data loader

@option(NLPAutoRegressionConfigs.valid_loader)
def sequential_valid_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.valid,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)

Transpose batch

DataLoader collates batches along the first dimension. We need to transpose them so that the tensors are sequence-first.

def transpose_batch(batch):
    transposed_data = list(zip(*batch))

Stack the batch along the second dimension (dim=1)

    src = torch.stack(transposed_data[0], dim=1)
    tgt = torch.stack(transposed_data[1], dim=1)

    return src, tgt
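
For instance, a minimal sketch: each dataset item is a (src, tgt) pair of [seq_len] tensors, so the default batch is a list of such pairs, and transpose_batch stacks them into sequence-first tensors:

batch = [(torch.arange(4), torch.arange(1, 5)) for _ in range(3)]
src, tgt = transpose_batch(batch)
assert src.shape == (4, 3)  # [seq_len, batch_size]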

Shuffled training data loader

@option(NLPAutoRegressionConfigs.train_loader)
def shuffled_train_loader(c: NLPAutoRegressionConfigs):
    dataset = SequentialUnBatchedDataset(text=c.text.train,
                                         dataset=c.text,
                                         seq_len=c.seq_len)
    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)

    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      sampler=sampler)

Shuffled validation data loader

@option(NLPAutoRegressionConfigs.valid_loader)
def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
    dataset = SequentialUnBatchedDataset(text=c.text.valid,
                                         dataset=c.text,
                                         seq_len=c.seq_len)
    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)

    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      sampler=sampler)