from collections import Counter
from typing import Callable

import torch
import torchtext
from torch import nn
from torch.utils.data import DataLoader
import torchtext.vocab
from torchtext.vocab import Vocab

from labml import lab, tracker, monit
from labml.configs import option
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
This contains the basic configurations for training an NLP classification task. All the properties are configurable.
class NLPClassificationConfigs(TrainValidConfigs):
Optimizer
    optimizer: torch.optim.Adam
Training device
    device: torch.device = DeviceConfigs()
Classification model
    model: Module
Batch size
    batch_size: int = 16
Length of the sequence, or context size
    seq_len: int = 512
Vocabulary
    vocab: Vocab = 'ag_news'
Number of tokens in the vocabulary
    n_tokens: int
Number of classes
    n_classes: int = 'ag_news'
Tokenizer
    tokenizer: Callable = 'character'
Whether to periodically save models
    is_save_models = True
Loss function
    loss_func = nn.CrossEntropyLoss()
Accuracy function
    accuracy = Accuracy()
Model embedding size
    d_model: int = 512
Gradient clipping
    grad_norm_clip: float = 1.0
Training data loader
    train_loader: DataLoader = 'ag_news'
Validation data loader
    valid_loader: DataLoader = 'ag_news'
    def init(self):
Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
Add accuracy as a state module. The name is probably confusing, since it's meant to store states between training and validation for RNNs. This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]
    def step(self, batch: any, batch_idx: BatchIndex):
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[1])
Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
Get model outputs. It returns a tuple with states when using RNNs; this is not implemented yet. 😜
            output, *_ = self.model(data)
Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)
Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take optimizer step
            self.optimizer.step()
Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
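The model option itself is left to the experiment that uses these configurations. As a minimal, purely illustrative sketch (not part of this file; the class name is made up), a compatible model takes a (seq_len, batch_size) tensor of token ids and returns a tuple whose first element is a (batch_size, n_classes) tensor of class scores:
class TinyClassifier(Module):
    def __init__(self, n_tokens: int, d_model: int, n_classes: int):
        super().__init__()
        self.embedding = nn.Embedding(n_tokens, d_model)
        self.fc = nn.Linear(d_model, n_classes)

    def forward(self, x: torch.Tensor):
        # x is (seq_len, batch_size); classify from the final [CLS] position
        emb = self.embedding(x)
        return self.fc(emb[-1]), None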
@option(NLPClassificationConfigs.optimizer)
def _optimizer(c: NLPClassificationConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer
We use a character-level tokenizer in this experiment. You can switch by setting 'tokenizer': 'basic_english' in the configurations dictionary when starting the experiment.
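A rough sketch of that override, assuming the experiment is launched with labml's experiment API (the experiment name here is illustrative):
from labml import experiment

conf = NLPClassificationConfigs()
experiment.create(name='nlp_classification')
experiment.configs(conf, {'tokenizer': 'basic_english'})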
@option(NLPClassificationConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')
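For example (output shown as a comment; torchtext's basic_english tokenizer lower-cases the text and splits out punctuation):
tokenizer = basic_english()
tokenizer("Hello, World!")  # ['hello', ',', 'world', '!']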
def character_tokenizer(x: str):
    return list(x)
Character level tokenizer configuration
@option(NLPClassificationConfigs.tokenizer)
def character():
    return character_tokenizer
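Calling the option function directly just returns character_tokenizer, which splits a string into individual characters. For instance (purely illustrative):
character()("ab c")  # ['a', 'b', ' ', 'c']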
Get the number of tokens: the vocabulary size plus two extra ids for the padding and [CLS] tokens used by the data loaders below
@option(NLPClassificationConfigs.n_tokens)
def _n_tokens(c: NLPClassificationConfigs):
    return len(c.vocab) + 2
class CollateFunc:
tokenizer is the tokenizer function
vocab is the vocabulary
seq_len is the length of the sequence
padding_token is the token used for padding when seq_len is larger than the text length
classifier_token is the [CLS] token, which we set at the end of the input
    def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
        self.classifier_token = classifier_token
        self.padding_token = padding_token
        self.seq_len = seq_len
        self.vocab = vocab
        self.tokenizer = tokenizer
batch is the batch of data collected by the DataLoader
    def __call__(self, batch):
Input data tensor, initialized with padding_token
        data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)
Empty labels tensor
        labels = torch.zeros(len(batch), dtype=torch.long)
Loop through the samples
        for (i, (_label, _text)) in enumerate(batch):
Set the label
            labels[i] = int(_label) - 1
Tokenize the input text
            _text = [self.vocab[token] for token in self.tokenizer(_text)]
Truncate up to seq_len
            _text = _text[:self.seq_len]
Copy the token ids into the i-th column of data
            data[:len(_text), i] = data.new_tensor(_text)
Set the final token in the sequence to [CLS]
        data[-1, :] = self.classifier_token
        return data, labels
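A small, purely illustrative usage sketch (tiny made-up vocabulary, character tokenizer from above) showing the shapes this produces:
vocab = torchtext.vocab.vocab(Counter("abc"), min_freq=1)
collate = CollateFunc(character_tokenizer, vocab, seq_len=8,
                      padding_token=len(vocab), classifier_token=len(vocab) + 1)
data, labels = collate([(1, "ab"), (2, "abc")])
# data is (8, 2): token ids padded with padding_token, last row set to the [CLS] id
# labels is tensor([0, 1])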
This loads the AG News dataset and sets the values for n_classes, vocab, train_loader, and valid_loader.
@option([NLPClassificationConfigs.n_classes,
         NLPClassificationConfigs.vocab,
         NLPClassificationConfigs.train_loader,
         NLPClassificationConfigs.valid_loader])
def ag_news(c: NLPClassificationConfigs):
Get training and validation datasets
    train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))
Load data to memory
    with monit.section('Load data'):
        from labml_nn.utils import MapStyleDataset
Create map-style datasets
        train, valid = MapStyleDataset(train), MapStyleDataset(valid)
Get tokenizer
    tokenizer = c.tokenizer
Create a counter
    counter = Counter()
Collect tokens from training dataset
    for (label, line) in train:
        counter.update(tokenizer(line))
Collect tokens from validation dataset
    for (label, line) in valid:
        counter.update(tokenizer(line))
Create vocabulary
    vocab = torchtext.vocab.vocab(counter, min_freq=1)
Create training data loader
    train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
Create validation data loader
    valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
Return n_classes, vocab, train_loader, and valid_loader
    return 4, vocab, train_loader, valid_loader
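As a rough end-to-end sketch (not part of this file; the experiment name, the model option, and the TinyClassifier from the earlier sketch are illustrative assumptions), an experiment script built on these configurations could look like:
from labml import experiment

@option(NLPClassificationConfigs.model)
def tiny_model(c: NLPClassificationConfigs):
    # Hypothetical model option; any module with a compatible forward() works
    return TinyClassifier(c.n_tokens, c.d_model, c.n_classes).to(c.device)

def main():
    conf = NLPClassificationConfigs()
    experiment.create(name='ag_news_classification')
    experiment.configs(conf, {'tokenizer': 'character'})
    with experiment.start():
        conf.run()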