from collections import Counter
from typing import Callable

import torch
import torchtext
from torch import nn
from torch.utils.data import DataLoader
import torchtext.vocab
from torchtext.vocab import Vocab

from labml import lab, tracker, monit
from labml.configs import option
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs

This has the basic configurations for NLP classification task training. All the properties are configurable. String defaults such as 'ag_news' and 'character' are the names of configuration options defined later in this file.
class NLPClassificationConfigs(TrainValidConfigs):

Optimizer

    optimizer: torch.optim.Adam

Training device

    device: torch.device = DeviceConfigs()

Classification model

    model: Module

Batch size

    batch_size: int = 16

Length of the sequence, or context size

    seq_len: int = 512

Vocabulary

    vocab: Vocab = 'ag_news'

Number of tokens in the vocabulary

    n_tokens: int

Number of classes

    n_classes: int = 'ag_news'

Tokenizer

    tokenizer: Callable = 'character'

Whether to periodically save models

    is_save_models = True

Loss function

    loss_func = nn.CrossEntropyLoss()

Accuracy function

    accuracy = Accuracy()

Model embedding size

    d_model: int = 512

Gradient clipping

    grad_norm_clip: float = 1.0

Training data loader

    train_loader: DataLoader = 'ag_news'

Validation data loader

    valid_loader: DataLoader = 'ag_news'

Whether to log model parameters and gradients (once per epoch). These are summarized stats per layer, but it could still lead to many indicators for very deep networks.

    is_log_model_params_grads: bool = False

Whether to log model activations (once per epoch). These are summarized stats per layer, but it could still lead to many indicators for very deep networks.

    is_log_model_activations: bool = False

    def init(self):

Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)

Add a hook to log module outputs

        hook_model_outputs(self.mode, self.model, 'model')

Add accuracy as a state module. The name is probably confusing, since it's meant to store states between training and validation for RNNs. This will keep the accuracy metric stats separate for training and validation.

        self.state_modules = [self.accuracy]

    def step(self, batch: any, batch_idx: BatchIndex):

Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of samples processed) when in training mode; data has shape [seq_len, batch_size]

        if self.mode.is_train:
            tracker.add_global_step(data.shape[1])

Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):

Get model outputs. It returns a tuple so that it can also carry states when using RNNs. This is not implemented yet. 😜

            output, *_ = self.model(data)

Calculate and log loss

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

Calculate and log accuracy

        self.accuracy(output, target)
        self.accuracy.track()

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

            self.optimizer.step()

Log the model parameters and gradients on the last batch of every epoch

            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()

Default optimizer configurations

@option(NLPClassificationConfigs.optimizer)
def _optimizer(c: NLPClassificationConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer

We use a character level tokenizer in this experiment. You can switch to the basic English tokenizer by setting 'tokenizer': 'basic_english' in the configurations dictionary when starting the experiment.
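As a rough sketch of how that override might be passed when setting up a run (assuming the usual labml experiment API; conf here stands for an NLPClassificationConfigs instance created in your run script, which is not part of this file):

from labml import experiment

# Hedged sketch: apply configuration overrides before starting the run.
# `conf` is assumed to be an NLPClassificationConfigs (or subclass) instance.
experiment.configs(conf, {'tokenizer': 'basic_english'})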
@option(NLPClassificationConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')

Character level tokenizer

def character_tokenizer(x: str):
    return list(x)

Character level tokenizer configuration

@option(NLPClassificationConfigs.tokenizer)
def character():
    return character_tokenizer
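As a quick illustration, the character tokenizer simply splits a string into its characters:

character_tokenizer("AG News")  # -> ['A', 'G', ' ', 'N', 'e', 'w', 's']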
Get the number of tokens: the vocabulary size plus two extra ids, which are used below for the padding token and the [CLS] token.

@option(NLPClassificationConfigs.n_tokens)
def _n_tokens(c: NLPClassificationConfigs):
    return len(c.vocab) + 2
Function to load data into batches

class CollateFunc:

tokenizer is the tokenizer function
vocab is the vocabulary
seq_len is the length of the sequence
padding_token is the token used for padding when the seq_len is larger than the text length
classifier_token is the [CLS] token which we set at the end of the input

    def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
        self.classifier_token = classifier_token
        self.padding_token = padding_token
        self.seq_len = seq_len
        self.vocab = vocab
        self.tokenizer = tokenizer

batch is the batch of data collected by the DataLoader
    def __call__(self, batch):

Input data tensor, initialized with padding_token

        data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)

Empty labels tensor

        labels = torch.zeros(len(batch), dtype=torch.long)

Loop through the samples

        for (i, (_label, _text)) in enumerate(batch):

Set the label

            labels[i] = int(_label) - 1

Tokenize the input text

            _text = [self.vocab[token] for token in self.tokenizer(_text)]

Truncate up to seq_len

            _text = _text[:self.seq_len]

Transpose and add to data

            data[:len(_text), i] = data.new_tensor(_text)

Set the final token in the sequence to [CLS]

        data[-1, :] = self.classifier_token

        return data, labels
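To make the shapes concrete, here is a hedged sketch of using this collate function on a toy batch; the sample texts are made up and vocab is assumed to be a character vocabulary like the one built in ag_news below.

# Hedged sketch (not part of the original file). Assumes `vocab` covers the sample
# texts, with ids `len(vocab)` and `len(vocab) + 1` reserved for padding and `[CLS]`,
# exactly as they are passed in `ag_news` below.
samples = [(3, "Stocks rally"), (1, "Election results")]  # (label, text) pairs
collate = CollateFunc(character_tokenizer, vocab, seq_len=512,
                      padding_token=len(vocab), classifier_token=len(vocab) + 1)
data, labels = collate(samples)
# `data` has shape [512, 2]: one column per sample, padded with `padding_token`,
# and `data[-1, :]` set to the `[CLS]` token. Labels are shifted to start at 0,
# so `labels` is tensor([2, 0]).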
AG News dataset

This loads the AG News dataset and sets the values for n_classes, vocab, train_loader, and valid_loader.

@option([NLPClassificationConfigs.n_classes,
         NLPClassificationConfigs.vocab,
         NLPClassificationConfigs.train_loader,
         NLPClassificationConfigs.valid_loader])
def ag_news(c: NLPClassificationConfigs):

Get training and validation datasets
    train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))

Load data to memory

    with monit.section('Load data'):
        from labml_nn.utils import MapStyleDataset

Create map-style datasets

        train, valid = MapStyleDataset(train), MapStyleDataset(valid)

Get tokenizer

    tokenizer = c.tokenizer

Create a counter

    counter = Counter()

Collect tokens from training dataset

    for (label, line) in train:
        counter.update(tokenizer(line))

Collect tokens from validation dataset

    for (label, line) in valid:
        counter.update(tokenizer(line))

Create vocabulary

    vocab = torchtext.vocab.vocab(counter, min_freq=1)
Create training data loader

    train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))

Create validation data loader

    valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))

Return n_classes, vocab, train_loader, and valid_loader

    return 4, vocab, train_loader, valid_loader
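To see how these options fit together, here is a minimal usage sketch. SimpleClassifier, the option name 'simple_classifier', and the experiment name are illustrative placeholders, not part of this file; the sketch only assumes the standard labml experiment API (experiment.create, experiment.configs, experiment.start).

from labml import experiment


class SimpleClassifier(Module):
    # Toy classifier: embedding -> mean pool over the sequence -> linear layer.
    # The trainer's step() unpacks `output, *_ = self.model(data)`, so forward
    # returns a tuple. Padding positions are included in the mean for simplicity.
    def __init__(self, n_tokens: int, d_model: int, n_classes: int):
        super().__init__()
        self.embedding = nn.Embedding(n_tokens, d_model)
        self.fc = nn.Linear(d_model, n_classes)

    def forward(self, x: torch.Tensor):
        # `x` has shape [seq_len, batch_size]
        return self.fc(self.embedding(x).mean(dim=0)), None


class Configs(NLPClassificationConfigs):
    # Pick the model option by name, as with the other string defaults above
    model: Module = 'simple_classifier'


@option(Configs.model)
def simple_classifier(c: Configs):
    return SimpleClassifier(c.n_tokens, c.d_model, c.n_classes).to(c.device)


def main():
    experiment.create(name="ag_news_classification")
    conf = Configs()
    # Defaults pick the AG News loaders and the character tokenizer defined above
    experiment.configs(conf, {'epochs': 5})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()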