NLP model trainer for classification

11from collections import Counter
12from typing import Callable
13
14import torch
15import torchtext
16from torch import nn
17from torch.utils.data import DataLoader
18import torchtext.vocab
19from torchtext.vocab import Vocab
20
21from labml import lab, tracker, monit
22from labml.configs import option
23from labml_helpers.device import DeviceConfigs
24from labml_helpers.metrics.accuracy import Accuracy
25from labml_helpers.module import Module
26from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
27from labml_nn.optimizers.configs import OptimizerConfigs

Trainer configurations

This has the basic configurations for NLP classification task training. All the properties are configurable.

class NLPClassificationConfigs(TrainValidConfigs):
    """
    Trainer configurations for NLP classification tasks.

    This has the basic configurations for NLP classification task training.
    All the properties are configurable. String defaults such as ``'ag_news'``
    or ``'character'`` refer to the ``@option`` functions of the same name
    defined in this file, which compute the actual value.
    """

    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()

    # Classification model
    model: Module

    # Batch size
    batch_size: int = 16
    # Length of the sequence, or context size
    seq_len: int = 512
    # Vocabulary
    vocab: Vocab = 'ag_news'
    # Number of tokens in the vocabulary
    n_tokens: int
    # Number of classes
    n_classes: int = 'ag_news'
    # Tokenizer
    tokenizer: Callable = 'character'

    # Whether to periodically save models
    is_save_models = True

    # Loss function
    loss_func = nn.CrossEntropyLoss()
    # Accuracy function
    accuracy = Accuracy()
    # Model embedding size
    d_model: int = 512
    # Gradient clipping (maximum gradient norm)
    grad_norm_clip: float = 1.0

    # Training data loader
    train_loader: DataLoader = 'ag_news'
    # Validation data loader
    valid_loader: DataLoader = 'ag_news'

    def init(self):
        """
        Initialization: configure the tracker, hook the model outputs and
        register the state modules.
        """
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # Add accuracy as a state module.
        # The name is probably confusing, since it's meant to store
        # states between training and validation for RNNs.
        # This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]

    def step(self, batch: any, batch_idx: BatchIndex):
        """
        Training or validation step.

        * `batch` is indexed as `(data, target)`; only the first two items are used
        * `batch_idx` tells whether this is the last batch (`batch_idx.is_last`)
        """

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step when in training mode. `data.shape[1]` is the
        # number of samples in the batch when the batch comes from `CollateFunc`,
        # which lays data out as `[seq_len, batch_size]`.
        if self.mode.is_train:
            tracker.add_global_step(data.shape[1])

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Get model outputs.
            # It's returning a tuple for states when using RNNs.
            # This is not implemented yet; everything after the first item is discarded.
            output, *_ = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch.
            # This has to happen before the gradients are cleared.
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()
@option(NLPClassificationConfigs.optimizer)
def _optimizer(c: NLPClassificationConfigs):
    """
    Default optimizer configurations.

    Builds an `OptimizerConfigs` for the model's parameters, selecting
    `'Adam'` and passing on the model embedding size `d_model`.
    """
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer

Basic English tokenizer

We use a character-level tokenizer in this experiment by default. You can switch to the basic English tokenizer by setting,

'tokenizer': 'basic_english',

in the configurations dictionary when starting the experiment.

@option(NLPClassificationConfigs.tokenizer)
def basic_english():
    """
    Basic English tokenizer option.

    Returns the tokenizer that `torchtext` provides under the name
    `'basic_english'`. Select it with `'tokenizer': 'basic_english'`
    in the configurations dictionary when starting the experiment.
    """
    # Imported lazily so `torchtext.data` is only needed when this option is used
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')

Character level tokenizer

def character_tokenizer(x: str):
    """
    Character level tokenizer.

    Splits the text `x` into a list with one single-character string per
    character; an empty string gives an empty list.
    """
    return list(x)

Character level tokenizer configuration

@option(NLPClassificationConfigs.tokenizer)
def character():
    """
    Character level tokenizer configuration.

    Returns the `character_tokenizer` function itself (not its result),
    so that it can be called on each text later.
    """
    return character_tokenizer

Get number of tokens

@option(NLPClassificationConfigs.n_tokens)
def _n_tokens(c: NLPClassificationConfigs):
    """
    Get number of tokens.

    This is the vocabulary size plus two. The `ag_news` option uses the
    indices `len(vocab)` and `len(vocab) + 1` for the padding and `[CLS]`
    tokens, so two extra ids are needed on top of the vocabulary.
    """
    return len(c.vocab) + 2

Function to load data into batches

class CollateFunc:
    """
    Function to load data into batches.

    Instances are callable and turn a list of `(label, text)` samples into a
    `(data, labels)` pair of `torch.long` tensors, where `data` has shape
    `[seq_len, batch_size]` and `labels` has shape `[batch_size]`.
    """

    def __init__(self, tokenizer, vocab: 'Vocab', seq_len: int, padding_token: int, classifier_token: int):
        """
        * `tokenizer` is the tokenizer function
        * `vocab` is the vocabulary; it is indexed with tokens to get token ids
        * `seq_len` is the length of the sequence
        * `padding_token` is the token used for padding when the `seq_len` is larger than the text length
        * `classifier_token` is the `[CLS]` token which we set at end of the input
        """
        self.classifier_token = classifier_token
        self.padding_token = padding_token
        self.seq_len = seq_len
        self.vocab = vocab
        self.tokenizer = tokenizer

    def __call__(self, batch):
        """
        * `batch` is the batch of data collected by the `DataLoader`;
          a sequence of `(label, text)` pairs

        Returns the `(data, labels)` tensors described on the class.
        """
        batch_size = len(batch)

        # Input data tensor, initialized with `padding_token`
        data = torch.full((self.seq_len, batch_size), self.padding_token, dtype=torch.long)
        # Empty labels tensor
        labels = torch.zeros(batch_size, dtype=torch.long)

        # The final position of every sequence is reserved for `[CLS]`,
        # so at most `seq_len - 1` text tokens are kept.
        max_text_len = max(self.seq_len - 1, 0)

        # Loop through the samples
        for i, (label, text) in enumerate(batch):
            # Set the label; one is subtracted, presumably because the
            # dataset labels start from 1 while class indices start from 0
            labels[i] = int(label) - 1
            # Tokenize the input text and map tokens to ids
            token_ids = [self.vocab[token] for token in self.tokenizer(text)]
            # Truncate, leaving room for `[CLS]`
            token_ids = token_ids[:max_text_len]
            # Write the ids into column `i`; the rest of the column stays as padding
            if token_ids:
                data[:len(token_ids), i] = torch.tensor(token_ids, dtype=torch.long)

        # Set the final token in the sequence to `[CLS]`
        data[-1, :] = self.classifier_token

        return data, labels

AG News dataset

This loads the AG News dataset and sets the values for n_classes , vocab , train_loader , and valid_loader .

@option([NLPClassificationConfigs.n_classes,
         NLPClassificationConfigs.vocab,
         NLPClassificationConfigs.train_loader,
         NLPClassificationConfigs.valid_loader])
def ag_news(c: NLPClassificationConfigs):
    """
    AG News dataset.

    This loads the AG News dataset and sets the values for
    `n_classes`, `vocab`, `train_loader`, and `valid_loader`,
    which are returned as a tuple in that order.
    """

    # Get training and validation datasets (the `test` split is used for validation)
    train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))

    # Load data to memory
    with monit.section('Load data'):
        from labml_nn.utils import MapStyleDataset

        train, valid = MapStyleDataset(train), MapStyleDataset(valid)

    # Get tokenizer
    tokenizer = c.tokenizer

    # Create a counter and collect tokens from the training dataset and then
    # from the validation dataset, so that no token is missing from the vocabulary
    counter = Counter()
    for dataset in (train, valid):
        for _, line in dataset:
            counter.update(tokenizer(line))

    # Create vocabulary
    vocab = torchtext.vocab.vocab(counter, min_freq=1)

    # The two special tokens take the ids right after the vocabulary
    padding_token = len(vocab)
    classifier_token = len(vocab) + 1
    # Both loaders batch samples the same way, so a single collate function is shared
    collate_fn = CollateFunc(tokenizer, vocab, c.seq_len, padding_token, classifier_token)

    # Create training data loader
    train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True, collate_fn=collate_fn)
    # Create validation data loader
    valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True, collate_fn=collate_fn)

    # Return `n_classes`, `vocab`, `train_loader`, and `valid_loader`
    return 4, vocab, train_loader, valid_loader