```python
from collections import Counter
from typing import Callable

import torch
import torchtext
from torch import nn
from torch.utils.data import DataLoader
from torchtext.vocab import Vocab

from labml import lab, tracker, monit
from labml.configs import option
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
```

`NLPClassificationConfigs` has the basic configurations for NLP classification task training. All the properties are configurable. Defaults given as strings, such as `'ag_news'` and `'character'`, refer to configuration options registered later in this file with the `@option` decorator.
```python
class NLPClassificationConfigs(TrainValidConfigs):
    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()

    # Autoregressive model
    model: Module
    # Batch size
    batch_size: int = 16
    # Length of the sequence, or context size
    seq_len: int = 512
    # Vocabulary
    vocab: Vocab = 'ag_news'
    # Number of tokens in the vocabulary
    n_tokens: int
    # Number of classes
    n_classes: int = 'ag_news'
    # Tokenizer
    tokenizer: Callable = 'character'
    # Whether to periodically save models
    is_save_models = True
    # Loss function
    loss_func = nn.CrossEntropyLoss()
    # Accuracy function
    accuracy = Accuracy()
    # Model embedding size
    d_model: int = 512
    # Gradient clipping
    grad_norm_clip: float = 1.0
    # Training data loader
    train_loader: DataLoader = 'ag_news'
    # Validation data loader
    valid_loader: DataLoader = 'ag_news'

    def init(self):
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # Add accuracy as a state module. The name is probably confusing, since it's
        # meant to store states between training and validation for RNNs.
        # This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]
```
The training or validation step:

```python
    def step(self, batch: any, batch_idx: BatchIndex):
        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[1])

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Get model outputs.
            # It's returning a tuple for states when using RNNs.
            # This is not implemented yet. 😜
            output, *_ = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()
```
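The `model` configuration is left for the specific experiment to define. As a reference for what `step` expects, here is a hypothetical minimal model, not from the source: it takes `data` of shape `[seq_len, batch_size]` and returns a tuple whose first element is the class logits of shape `[batch_size, n_classes]`, so that `output, *_ = self.model(data)` and `nn.CrossEntropyLoss` work as written. The class name, the mean-pooling choice, and the option registration are all illustrative assumptions.

```python
class MeanPoolClassifier(Module):
    # Hypothetical model, for illustration only: embed tokens, average over the
    # sequence, and project to class logits.
    def __init__(self, n_tokens: int, d_model: int, n_classes: int):
        super().__init__()
        self.embedding = nn.Embedding(n_tokens, d_model)
        self.fc = nn.Linear(d_model, n_classes)

    def forward(self, data: torch.Tensor):
        # `data` has shape `[seq_len, batch_size]`
        x = self.embedding(data)  # [seq_len, batch_size, d_model]
        x = x.mean(dim=0)         # [batch_size, d_model]
        logits = self.fc(x)       # [batch_size, n_classes]
        # Return a tuple so that `output, *_ = self.model(data)` unpacks correctly
        return logits,


# Hypothetical model option, following the same pattern as the options below
@option(NLPClassificationConfigs.model)
def simple_model(c: NLPClassificationConfigs):
    return MeanPoolClassifier(c.n_tokens, c.d_model, c.n_classes).to(c.device)
```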
The optimizer is built with labml's `OptimizerConfigs`, using the model parameters and `d_model`:

```python
@option(NLPClassificationConfigs.optimizer)
def _optimizer(c: NLPClassificationConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer
```
We use a character level tokenizer in this experiment. You can switch to the basic English tokenizer by setting `'tokenizer': 'basic_english'` in the configurations dictionary when starting the experiment.

```python
# Basic English tokenizer from torchtext
@option(NLPClassificationConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')


# Character level tokenizer
def character_tokenizer(x: str):
    return list(x)


# Character level tokenizer configuration
@option(NLPClassificationConfigs.tokenizer)
def character():
    return character_tokenizer
```
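For a quick sense of the difference between the two tokenizers, here is an illustrative snippet (not from the source); the exact `basic_english` output follows torchtext's lower-casing and punctuation-splitting rules.

```python
print(character_tokenizer("AG News"))
# ['A', 'G', ' ', 'N', 'e', 'w', 's']

from torchtext.data import get_tokenizer
print(get_tokenizer('basic_english')("AG News!"))
# ['ag', 'news', '!']
```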
Get the number of tokens. Two extra ids are reserved beyond the vocabulary: the padding token and the `[CLS]` token used by the collate function (see `ag_news` below, which passes `len(vocab)` and `len(vocab) + 1` for them).

```python
@option(NLPClassificationConfigs.n_tokens)
def _n_tokens(c: NLPClassificationConfigs):
    return len(c.vocab) + 2
```

`CollateFunc` collates a list of samples into a batch for the `DataLoader`:

* `tokenizer` is the tokenizer function
* `vocab` is the vocabulary
* `seq_len` is the length of the sequence
* `padding_token` is the token used for padding when the `seq_len` is larger than the text length
* `classifier_token` is the `[CLS]` token which we set at the end of the input

```python
class CollateFunc:
    def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
        self.classifier_token = classifier_token
        self.padding_token = padding_token
        self.seq_len = seq_len
        self.vocab = vocab
        self.tokenizer = tokenizer
```
`batch` is the batch of data collected by the `DataLoader`.

```python
    def __call__(self, batch):
        # Input data tensor, initialized with `padding_token`
        data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)
        # Empty labels tensor
        labels = torch.zeros(len(batch), dtype=torch.long)

        # Loop through the samples
        for (i, (_label, _text)) in enumerate(batch):
            # Set the label
            labels[i] = int(_label) - 1
            # Tokenize the input text
            _text = [self.vocab[token] for token in self.tokenizer(_text)]
            # Truncate up to `seq_len`
            _text = _text[:self.seq_len]
            # Add the sample to the data tensor as a column
            data[:len(_text), i] = data.new_tensor(_text)

        # Set the final token in the sequence to `[CLS]`
        data[-1, :] = self.classifier_token

        return data, labels
```
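Here is a toy illustration (not from the source) of what the collate function produces. A plain `dict` stands in for the torchtext `Vocab`, since only item lookup is used, and `list` acts as the character tokenizer.

```python
toy_vocab = {'a': 0, 'b': 1, 'c': 2}
collate = CollateFunc(tokenizer=list, vocab=toy_vocab, seq_len=6,
                      padding_token=3, classifier_token=4)
data, labels = collate([(1, 'abc'), (2, 'cb')])
# `data` has shape `[seq_len, batch_size] = [6, 2]`; each column is one sample,
# padded with `padding_token` and ending with the `[CLS]` token:
#   [[0, 2],
#    [1, 1],
#    [2, 3],
#    [3, 3],
#    [3, 3],
#    [4, 4]]
# `labels` is `[0, 1]` (labels are shifted to start from zero)
```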
This loads the AG News dataset and sets the values for `n_classes`, `vocab`, `train_loader`, and `valid_loader`.

```python
@option([NLPClassificationConfigs.n_classes,
         NLPClassificationConfigs.vocab,
         NLPClassificationConfigs.train_loader,
         NLPClassificationConfigs.valid_loader])
def ag_news(c: NLPClassificationConfigs):
    # Get training and validation datasets
    train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))

    # Load data to memory
    with monit.section('Load data'):
        from labml_nn.utils import MapStyleDataset

        # Create map-style datasets
        train, valid = MapStyleDataset(train), MapStyleDataset(valid)

    # Get tokenizer
    tokenizer = c.tokenizer

    # Create a counter
    counter = Counter()
    # Collect tokens from the training dataset
    for (label, line) in train:
        counter.update(tokenizer(line))
    # Collect tokens from the validation dataset
    for (label, line) in valid:
        counter.update(tokenizer(line))
    # Create the vocabulary
    vocab = Vocab(counter, min_freq=1)

    # Create the training data loader
    train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
    # Create the validation data loader
    valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))

    # Return `n_classes`, `vocab`, `train_loader`, and `valid_loader`
    return 4, vocab, train_loader, valid_loader
```
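Putting it together, here is a minimal sketch of how these configurations might be used in an experiment. The experiment name and the configuration overrides are illustrative assumptions, and a `model` option must be registered (for example with `@option(NLPClassificationConfigs.model)`) before running.

```python
from labml import experiment


def main():
    # Create the configurations
    conf = NLPClassificationConfigs()
    # Create an experiment (the name is hypothetical)
    experiment.create(name='ag_news_classification')
    # Override configuration values with a dictionary of dotted keys (illustrative values)
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 1e-4,
        'batch_size': 32,
    })
    # Start the experiment and run the training/validation loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
```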