Auto-regressive NLP model trainer

from typing import Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from labml import lab, monit, logger, tracker
from labml.configs import option
from labml.logger import Text
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs

Cross entropy loss

class CrossEntropyLoss(Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, outputs, targets):
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
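
The loss flattens both tensors: the outputs (here typically of shape [seq_len, batch_size, n_tokens]) and the targets ([seq_len, batch_size]), so cross entropy is computed over every token position. A small sketch with made-up sizes:

import torch

outputs = torch.randn(512, 16, 65)          # [seq_len, batch_size, n_tokens]
targets = torch.randint(0, 65, (512, 16))   # [seq_len, batch_size]

# CrossEntropyLoss is the class defined above; this is equivalent to
# nn.CrossEntropyLoss()(outputs.view(-1, 65), targets.view(-1)),
# i.e. cross entropy over 512 * 16 = 8192 token predictions
loss = CrossEntropyLoss()(outputs, targets)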

Trainer configurations

This has the basic configurations for training auto-regressive NLP models. All the properties are configurable.

class NLPAutoRegressionConfigs(TrainValidConfigs):

Optimizer

    optimizer: torch.optim.Adam

Training device

    device: torch.device = DeviceConfigs()

Autoregressive model

    model: Module

Text dataset

    text: TextDataset

Batch size

    batch_size: int = 16

Length of the sequence, or context size

    seq_len: int = 512

Number of tokens in the vocabulary

    n_tokens: int

Tokenizer

    tokenizer: Callable = 'character'

Text prompt to start sampling (for illustration)

    prompt: str

The token separator when sampling (blank for character level tokenization)

    prompt_separator: str

Whether to periodically save models

    is_save_models = True

Loss function

    loss_func = CrossEntropyLoss()

Accuracy function

    accuracy = Accuracy()

Model embedding size

    d_model: int = 512

Gradient clipping

    grad_norm_clip: float = 1.0

Training data loader

    train_loader: DataLoader = 'shuffled_train_loader'

Validation data loader

    valid_loader: DataLoader = 'shuffled_valid_loader'

Initialization

    def init(self):

Set tracker configurations

        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)

Add a hook to log module outputs

        hook_model_outputs(self.mode, self.model, 'model')

Add accuracy as a state module. The name is a bit confusing, since state modules are meant to store state between training and validation for RNNs; here it simply keeps the accuracy metric's statistics separate for training and validation.

        self.state_modules = [self.accuracy]

Override to calculate and log other metrics

    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        pass

Training or validation step

    def step(self, batch: any, batch_idx: BatchIndex):

Set training/eval mode

        self.model.train(self.mode.is_train)

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of tokens processed) when in training mode

        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

Whether to capture model outputs

        with self.mode.update(is_log_activations=batch_idx.is_last):

Get the model outputs. The model returns a tuple of outputs and state to support RNNs, but state handling is not implemented yet. 😜

            output, *_ = self.model(data)

Calculate and log loss

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

Calculate and log accuracy

        self.accuracy(output, target)
        self.accuracy.track()

        self.other_metrics(output, target)

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

            self.optimizer.step()

Log the model parameters and gradients on the last batch of every epoch

            if batch_idx.is_last:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()
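
To make the shapes concrete, here is a rough sketch of what a training step sees. The toy model below is only illustrative (it is not part of this trainer); it just mimics the tuple interface that the trainer unpacks with output, *_:

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self, n_tokens: int = 65, d_model: int = 64):
        super().__init__()
        self.embed = nn.Embedding(n_tokens, d_model)
        self.proj = nn.Linear(d_model, n_tokens)

    def forward(self, x: torch.Tensor):
        # x: [seq_len, batch_size] -> logits: [seq_len, batch_size, n_tokens]
        return self.proj(self.embed(x)), None

seq_len, batch_size = 512, 16
data = torch.randint(0, 65, (seq_len, batch_size))
output, *_ = ToyModel()(data)
assert output.shape == (seq_len, batch_size, 65)
# In training mode the global step advances by the number of tokens in the batch,
# data.shape[0] * data.shape[1] = 512 * 16 = 8192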

Sampling function to generate samples periodically while training

    def sample(self):

Starting prompt

        prompt = self.prompt

Collect output for printing

        log = [(prompt, Text.subtle)]

Sample 25 tokens

        for i in monit.iterate('Sample', 25):

Tokenize the prompt

            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)

Get the model output

            output, *_ = self.model(data)

Get the model prediction (greedy)

            output = output.argmax(dim=-1).squeeze()

Add the prediction to the prompt

            prompt += self.prompt_separator + self.text.itos[output[-1]]

Add the prediction for logging

            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

Print the sampled output

        logger.log(log)

Default optimizer configurations

@option(NLPAutoRegressionConfigs.optimizer)
def _optimizer(c: NLPAutoRegressionConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer

Get the number of tokens in the vocabulary

@option(NLPAutoRegressionConfigs.n_tokens)
def _n_tokens(c: NLPAutoRegressionConfigs):
    return c.text.n_tokens

Basic English tokenizer

We use a character-level tokenizer in this experiment. You can switch by setting

    'tokenizer': 'basic_english',

in the configurations dictionary when starting the experiment.

@option(NLPAutoRegressionConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')
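
As a sketch of how that switch might look when launching an experiment (assuming a configs class derived from NLPAutoRegressionConfigs that also defines the model option; the class name, experiment name, and prompt below are only illustrative):

from labml import experiment

conf = Configs()  # hypothetical subclass of NLPAutoRegressionConfigs that defines a model
experiment.create(name='nlp_autoregression')
experiment.configs(conf, {
    'tokenizer': 'basic_english',  # word-level tokens instead of the default 'character'
    'prompt_separator': ' ',       # separate sampled tokens with spaces
    'prompt': 'It is',
})
with experiment.start():
    conf.run()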

Character level tokenizer

def character_tokenizer(x: str):
    return list(x)

Character level tokenizer configuration

@option(NLPAutoRegressionConfigs.tokenizer)
def character():
    return character_tokenizer

Tiny Shakespeare dataset

The file will be downloaded from the URL if it is not present locally.

@option(NLPAutoRegressionConfigs.text)
def tiny_shakespeare(c: NLPAutoRegressionConfigs):
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
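
To train on your own corpus instead, you could register another option for the text configuration along the same lines (a sketch; the file name is illustrative and it assumes the file already exists under the data path, otherwise pass a url as above):

@option(NLPAutoRegressionConfigs.text)
def my_corpus(c: NLPAutoRegressionConfigs):
    # 'my_corpus.txt' is a hypothetical file placed in the lab data path
    return TextFileDataset(lab.get_data_path() / 'my_corpus.txt', c.tokenizer)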

Sequential training data loader

@option(NLPAutoRegressionConfigs.train_loader)
def sequential_train_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.train,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)

Sequential validation data loader

@option(NLPAutoRegressionConfigs.valid_loader)
def sequential_valid_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.valid,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)

Transpose batch

DataLoader collates batches along the first dimension. We need to transpose them to be sequence-first.

def transpose_batch(batch):
    transposed_data = list(zip(*batch))

Stack the batch along the second dimension dim=1

    src = torch.stack(transposed_data[0], dim=1)
    tgt = torch.stack(transposed_data[1], dim=1)

    return src, tgt
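
For instance, with a batch of four (source, target) pairs of length three each, the collated source and target come out sequence-first (a small sketch with toy tensors):

import torch

# Four samples, each a (source, target) pair of length 3
batch = [(torch.tensor([1, 2, 3]), torch.tensor([2, 3, 4])) for _ in range(4)]
src, tgt = transpose_batch(batch)
assert src.shape == (3, 4)   # [seq_len, batch_size]
assert tgt.shape == (3, 4)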

Shuffled training data loader

@option(NLPAutoRegressionConfigs.train_loader)
def shuffled_train_loader(c: NLPAutoRegressionConfigs):
    return DataLoader(SequentialUnBatchedDataset(text=c.text.train,
                                                 dataset=c.text,
                                                 seq_len=c.seq_len),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      shuffle=True)

Shuffled validation data loader

@option(NLPAutoRegressionConfigs.valid_loader)
def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
    return DataLoader(SequentialUnBatchedDataset(text=c.text.valid,
                                                 dataset=c.text,
                                                 seq_len=c.seq_len),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      shuffle=True)