from collections import Counter
from typing import Callable

import torch
import torchtext
from torch import nn
from torch.utils.data import DataLoader
from torchtext.vocab import Vocab

from labml import lab, tracker, monit
from labml.configs import option
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
These are the basic configurations for training an NLP classification task. All the properties are configurable.
class NLPClassificationConfigs(TrainValidConfigs):
Optimizer
    optimizer: torch.optim.Adam
Training device
    device: torch.device = DeviceConfigs()
Classification model
    model: Module
Batch size
    batch_size: int = 16
Length of the sequence, or context size
    seq_len: int = 512
Vocabulary
    vocab: Vocab = 'ag_news'
Number of tokens in the vocabulary
    n_tokens: int
Number of classes
    n_classes: int = 'ag_news'
Tokenizer
    tokenizer: Callable = 'character'
Whether to periodically save models
    is_save_models = True
Loss function
    loss_func = nn.CrossEntropyLoss()
Accuracy function
    accuracy = Accuracy()
Model embedding size
    d_model: int = 512
Gradient clipping
    grad_norm_clip: float = 1.0
Training data loader
    train_loader: DataLoader = 'ag_news'
Validation data loader
    valid_loader: DataLoader = 'ag_news'
    def init(self):
Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
Add accuracy as a state module. The name is probably confusing, since it’s meant to store states between training and validation for RNNs. This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]
    def step(self, batch: any, batch_idx: BatchIndex):
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Update the global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[1])
Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
Get model outputs. The model returns a tuple so that RNN state could also be returned; handling that state is not implemented yet. 😜
            output, *_ = self.model(data)
Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)
Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take optimizer step
            self.optimizer.step()
Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
@option(NLPClassificationConfigs.optimizer)
def _optimizer(c: NLPClassificationConfigs):
    """Default optimizer configurations."""
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer
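These optimizer settings can also be overridden from the experiment's configurations dictionary. The snippet below is only a rough sketch, assuming labml's dotted sub-config keys (the learning rate value is purely illustrative).
# Hypothetical overrides to pass to experiment.configs(conf, ...);
# keys follow labml's `<config>.<sub-config>` naming convention.
optimizer_overrides = {
    'optimizer.optimizer': 'Adam',      # keep the default Adam optimizer
    'optimizer.learning_rate': 2.5e-4,  # illustrative learning rate
}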
We use a character level tokenizer in this experiment. You can switch to a basic English tokenizer by setting `'tokenizer': 'basic_english'` in the configurations dictionary when starting the experiment.
@option(NLPClassificationConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')
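For example, a minimal sketch of switching the tokenizer when setting up an experiment (the experiment name is illustrative, and this assumes the usual labml `experiment` API):
from labml import experiment

# Assumes this runs in an experiment script that can see NLPClassificationConfigs
conf = NLPClassificationConfigs()
experiment.create(name="nlp_classification")              # illustrative name
experiment.configs(conf, {'tokenizer': 'basic_english'})  # override the default character tokenizer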
def character_tokenizer(x: str):
    """Character level tokenizer: splits a string into a list of characters."""
    return list(x)
Character level tokenizer configuration
@option(NLPClassificationConfigs.tokenizer)
def character():
    return character_tokenizer
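For instance, the character tokenizer simply splits text into individual characters:
character_tokenizer("AG News")  # ['A', 'G', ' ', 'N', 'e', 'w', 's']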
Get the number of tokens: the vocabulary size plus two, for the padding and `[CLS]` tokens
@option(NLPClassificationConfigs.n_tokens)
def _n_tokens(c: NLPClassificationConfigs):
    return len(c.vocab) + 2
class CollateFunc:
    """Collates (label, text) samples into padded batches for the `DataLoader`."""
* `tokenizer` is the tokenizer function
* `vocab` is the vocabulary
* `seq_len` is the length of the sequence
* `padding_token` is the token used for padding when the `seq_len` is larger than the text length
* `classifier_token` is the `[CLS]` token which we set at the end of the input
    def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
        self.classifier_token = classifier_token
        self.padding_token = padding_token
        self.seq_len = seq_len
        self.vocab = vocab
        self.tokenizer = tokenizer
`batch` is the batch of data collected by the `DataLoader`
    def __call__(self, batch):
Input data tensor, initialized with padding_token
        data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)
Empty labels tensor
        labels = torch.zeros(len(batch), dtype=torch.long)
Loop through the samples
        for (i, (_label, _text)) in enumerate(batch):
Set the label (AG News labels are 1 to 4, so shift them to 0 to 3)
            labels[i] = int(_label) - 1
Tokenize the input text and map the tokens to vocabulary ids
            _text = [self.vocab[token] for token in self.tokenizer(_text)]
Truncate up to `seq_len`
            _text = _text[:self.seq_len]
Copy the token ids into the `i`-th column of `data` (the data tensor is sequence-first)
            data[:len(_text), i] = data.new_tensor(_text)
Set the final token in the sequence to [CLS]
        data[-1, :] = self.classifier_token

        return data, labels
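As a small illustration of what the collate function produces (a self-contained sketch; the tiny vocabulary and the sample texts below are made up):
import torch

# Toy stand-in vocabulary: any mapping from token to id works for illustration
toy_vocab = {ch: i for i, ch in enumerate("abcdefghijklmnopqrstuvwxyz ")}
collate = CollateFunc(character_tokenizer, toy_vocab, seq_len=16,
                      padding_token=len(toy_vocab), classifier_token=len(toy_vocab) + 1)
# AG News style (label, text) pairs; raw labels are 1 to 4
data, labels = collate([(3, "good game"), (1, "markets rally")])
print(data.shape)  # torch.Size([16, 2]) -> (seq_len, batch_size)
print(labels)      # tensor([2, 0])
print(data[-1])    # tensor([28, 28]) -> the [CLS] token id for every sample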
This loads the AG News dataset and sets the values for `n_classes`, `vocab`, `train_loader`, and `valid_loader`.
@option([NLPClassificationConfigs.n_classes,
         NLPClassificationConfigs.vocab,
         NLPClassificationConfigs.train_loader,
         NLPClassificationConfigs.valid_loader])
def ag_news(c: NLPClassificationConfigs):
Get training and validation datasets
    train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))
Load data to memory
    with monit.section('Load data'):
        from labml_nn.utils import MapStyleDataset
Create map-style datasets
        train, valid = MapStyleDataset(train), MapStyleDataset(valid)
Get tokenizer
    tokenizer = c.tokenizer
Create a counter
    counter = Counter()
Collect tokens from training dataset
    for (label, line) in train:
        counter.update(tokenizer(line))
Collect tokens from validation dataset
    for (label, line) in valid:
        counter.update(tokenizer(line))
Create vocabulary
    vocab = Vocab(counter, min_freq=1)
Create training data loader
    train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
Create validation data loader
    valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
Return `n_classes`, `vocab`, `train_loader`, and `valid_loader`
    return 4, vocab, train_loader, valid_loader
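Putting it together, an experiment built on these configurations might look roughly like the sketch below. The model is a made-up placeholder (real experiments register their own `model` option), and the experiment name and overrides are illustrative.
import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml_helpers.module import Module


class SimpleClassifier(Module):
    """Placeholder model: embedding -> mean pooling -> linear classifier."""

    def __init__(self, n_tokens: int, d_model: int, n_classes: int):
        super().__init__()
        self.embedding = nn.Embedding(n_tokens, d_model)
        self.fc = nn.Linear(d_model, n_classes)

    def forward(self, x: torch.Tensor):
        # `x` is `(seq_len, batch_size)`; return logits and a dummy state,
        # since `step` unpacks the output with `output, *_ = self.model(data)`
        emb = self.embedding(x).mean(dim=0)
        return self.fc(emb), None


@option(NLPClassificationConfigs.model)
def simple_model(c: NLPClassificationConfigs):
    # Build the placeholder model from the configured sizes and move it to the device
    return SimpleClassifier(c.n_tokens, c.d_model, c.n_classes).to(c.device)


def main():
    experiment.create(name="ag_news_classification")  # illustrative name
    conf = NLPClassificationConfigs()
    experiment.configs(conf, {'model': 'simple_model', 'epochs': 5})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()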