from collections import Counter
from typing import Callable

import torchtext
import torchtext.vocab
from torchtext.vocab import Vocab

import torch
from labml import lab, tracker, monit
from labml.configs import option
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
from torch import nn
from torch.utils.data import DataLoader

This has the basic configurations for NLP classification task training. All the properties are configurable.
class NLPClassificationConfigs(TrainValidConfigs):

Optimizer
    optimizer: torch.optim.Adam
Training device
    device: torch.device = DeviceConfigs()
Classification model
    model: nn.Module
Batch size
    batch_size: int = 16
Length of the sequence, or context size
    seq_len: int = 512
Vocabulary
    vocab: Vocab = 'ag_news'
Number of tokens in the vocabulary
    n_tokens: int
Number of classes
    n_classes: int = 'ag_news'
Tokenizer
    tokenizer: Callable = 'character'
Whether to periodically save models
    is_save_models = True
Loss function
    loss_func = nn.CrossEntropyLoss()
Accuracy function
    accuracy = Accuracy()
Model embedding size
    d_model: int = 512
Gradient clipping
    grad_norm_clip: float = 1.0
Training data loader
    train_loader: DataLoader = 'ag_news'
Validation data loader
    valid_loader: DataLoader = 'ag_news'
Whether to log model parameters and gradients (once per epoch). These are summarized stats per layer, but it could still lead to many indicators for very deep networks.
    is_log_model_params_grads: bool = False
Whether to log model activations (once per epoch). These are summarized stats per layer, but it could still lead to many indicators for very deep networks.
    is_log_model_activations: bool = False

    def init(self):
Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)

Add accuracy as a state module. The name is probably confusing, since it's meant to store states between training and validation for RNNs. This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]

    def step(self, batch: any, batch_idx: BatchIndex):
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of sequences processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[1])

Get model outputs. The model returns a tuple, where the extra values are for RNN states. This is not implemented yet. 😜
        output, *_ = self.model(data)

Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take optimizer step
            self.optimizer.step()
Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()

Save the tracked metrics
        tracker.save()
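The training loop above only assumes that the model maps a [seq_len, batch_size] tensor of token indices to class logits of shape [batch_size, n_classes], returned as the first element of a tuple. Below is a minimal sketch of such a model (not part of this experiment; the class name and layer choices are illustrative), using an embedding layer, an LSTM encoder, and a linear head that classifies from the final position, where the collate function places the [CLS] token.

class SimpleClassifier(nn.Module):
    def __init__(self, n_tokens: int, d_model: int, n_classes: int):
        super().__init__()
        # Token embeddings
        self.embedding = nn.Embedding(n_tokens, d_model)
        # Sequence encoder (any encoder that keeps the [seq_len, batch, d_model] layout works)
        self.lstm = nn.LSTM(d_model, d_model)
        # Classification head
        self.fc = nn.Linear(d_model, n_classes)

    def forward(self, x: torch.Tensor):
        # x has shape [seq_len, batch_size]
        x = self.embedding(x)
        x, _ = self.lstm(x)
        # Classify from the last position in the sequence (the [CLS] token)
        # and return a dummy state so that `output, *_ = self.model(data)` works
        return self.fc(x[-1]), None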
@option(NLPClassificationConfigs.optimizer)
def _optimizer(c: NLPClassificationConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer

We use a character-level tokenizer in this experiment. You can switch by setting

    'tokenizer': 'basic_english',

in the configurations dictionary when starting the experiment.
@option(NLPClassificationConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')


def character_tokenizer(x: str):
    return list(x)

Character level tokenizer configuration
@option(NLPClassificationConfigs.tokenizer)
def character():
    return character_tokenizer
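As a quick illustration (the input strings are arbitrary), the two tokenizers behave quite differently: basic_english lower-cases the text and splits on words and punctuation, while the character tokenizer simply splits the string into individual characters.

from torchtext.data import get_tokenizer

basic = get_tokenizer('basic_english')
print(basic("AG News, world!"))        # ['ag', 'news', ',', 'world', '!']
print(character_tokenizer("AG News"))  # ['A', 'G', ' ', 'N', 'e', 'w', 's']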
Get number of tokens

@option(NLPClassificationConfigs.n_tokens)
def _n_tokens(c: NLPClassificationConfigs):
    return len(c.vocab) + 2

The +2 accounts for the two extra tokens appended after the vocabulary: the padding token (index len(vocab)) and the [CLS] classifier token (index len(vocab) + 1), which are passed to CollateFunc below.

class CollateFunc:

tokenizer is the tokenizer function
vocab is the vocabulary
seq_len is the length of the sequence
padding_token is the token used for padding when the seq_len is larger than the text length
classifier_token is the [CLS] token which we set at the end of the input

    def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
        self.classifier_token = classifier_token
        self.padding_token = padding_token
        self.seq_len = seq_len
        self.vocab = vocab
        self.tokenizer = tokenizer

batch is the batch of data collected by the DataLoader

    def __call__(self, batch):
Input data tensor, initialized with padding_token
        data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)
Empty labels tensor
        labels = torch.zeros(len(batch), dtype=torch.long)

Loop through the samples
        for (i, (_label, _text)) in enumerate(batch):
Set the label
            labels[i] = int(_label) - 1
Tokenize the input text
            _text = [self.vocab[token] for token in self.tokenizer(_text)]
Truncate up to seq_len
            _text = _text[:self.seq_len]
Add to the data tensor (sample i fills column i)
            data[:len(_text), i] = data.new_tensor(_text)

Set the final token in the sequence to [CLS]
        data[-1, :] = self.classifier_token

        return data, labels
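Below is a toy illustration (not part of the experiment; the vocabulary and strings are made up) of what the collate function produces: a [seq_len, batch_size] tensor of token indices padded with the padding token and terminated with the [CLS] token, plus a vector of zero-based labels.

toy_vocab = torchtext.vocab.vocab(Counter("abc "), min_freq=1)
collate = CollateFunc(character_tokenizer, toy_vocab, seq_len=8,
                      padding_token=len(toy_vocab), classifier_token=len(toy_vocab) + 1)
data, labels = collate([(1, 'a b'), (3, 'cab')])
print(data.shape)  # torch.Size([8, 2]) - one column per sample
print(labels)      # tensor([0, 2])     - AG News labels are 1-based, shifted to 0-based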
This loads the AG News dataset and sets the values for n_classes, vocab, train_loader, and valid_loader.
@option([NLPClassificationConfigs.n_classes,
         NLPClassificationConfigs.vocab,
         NLPClassificationConfigs.train_loader,
         NLPClassificationConfigs.valid_loader])
def ag_news(c: NLPClassificationConfigs):
Get training and validation datasets
    train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))

Load data to memory
    with monit.section('Load data'):
        from labml_nn.utils import MapStyleDataset

Create map-style datasets
        train, valid = MapStyleDataset(train), MapStyleDataset(valid)

Get tokenizer
    tokenizer = c.tokenizer

Create a counter
    counter = Counter()
Collect tokens from the training dataset
    for (label, line) in train:
        counter.update(tokenizer(line))
Collect tokens from the validation dataset
    for (label, line) in valid:
        counter.update(tokenizer(line))
Create the vocabulary
    vocab = torchtext.vocab.vocab(counter, min_freq=1)

Create the training data loader
    train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
Create the validation data loader
    valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))

Return n_classes, vocab, train_loader, and valid_loader
    return 4, vocab, train_loader, valid_loader
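To tie everything together, here is a rough sketch (assuming the labml experiment API; the experiment name, batch size, and learning rate are illustrative and not part of this file) of how these configurations might be used, including the tokenizer override mentioned earlier. Note that a model option, which this file does not define, must also be registered before training can run.

from labml import experiment

conf = NLPClassificationConfigs()
experiment.create(name='ag_news_classification')
experiment.configs(conf, {
    'tokenizer': 'basic_english',     # switch from the default character tokenizer
    'batch_size': 32,
    'optimizer.learning_rate': 1e-4,  # assuming OptimizerConfigs exposes learning_rate
})
with experiment.start():
    conf.run()  # training loop provided by TrainValidConfigs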