from typing import Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from labml import lab, monit, logger, tracker
from labml.configs import option
from labml.logger import Text
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs


Cross-entropy loss that flattens the model outputs and targets across the batch and sequence dimensions before computing the loss.

class CrossEntropyLoss(Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, outputs, targets):
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
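
For a sense of the shapes involved, here is a minimal sketch using the class above (the sizes are hypothetical, not taken from this experiment). The logits are flattened to [seq_len * batch_size, n_tokens] and the targets to [seq_len * batch_size] before nn.CrossEntropyLoss is applied.

seq_len, batch_size, n_tokens = 512, 16, 65               # hypothetical sizes
outputs = torch.randn(seq_len, batch_size, n_tokens)      # model logits
targets = torch.randint(n_tokens, (seq_len, batch_size))  # target token ids

# Same flattening as CrossEntropyLoss.__call__ above
loss = CrossEntropyLoss()(outputs, targets)
print(loss.item())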

This class has the basic configurations for NLP auto-regressive task training. All the properties are configurable.

class NLPAutoRegressionConfigs(TrainValidConfigs):
    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()
    # Autoregressive model
    model: Module
    # Text dataset
    text: TextDataset
    # Batch size
    batch_size: int = 16
    # Length of the sequence, or context size
    seq_len: int = 512
    # Number of tokens in the vocabulary
    n_tokens: int
    # Tokenizer
    tokenizer: Callable = 'character'
    # Text prompt to start sampling (for illustration)
    prompt: str
    # The token separator when sampling (blank for character-level tokenization)
    prompt_separator: str
    # Whether to periodically save models
    is_save_models = True
    # Loss function
    loss_func = CrossEntropyLoss()
    # Accuracy function
    accuracy = Accuracy()
    # Model embedding size
    d_model: int = 512
    # Gradient clipping
    grad_norm_clip: float = 1.0
    # Training data loader
    train_loader: DataLoader = 'shuffled_train_loader'
    # Validation data loader
    valid_loader: DataLoader = 'shuffled_valid_loader'

    def init(self):
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # Add accuracy as a state module. The name is probably confusing,
        # since it's meant to store states between training and validation
        # for RNNs. This will keep the accuracy metric stats separate for
        # training and validation.
        self.state_modules = [self.accuracy]

    def step(self, batch: any, batch_idx: BatchIndex):
        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update the global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Get model outputs. It returns a tuple for states when using RNNs.
            # This is not implemented yet. 😜
            output, *_ = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()

    def sample(self):
        # Starting prompt
        prompt = self.prompt
        # Collect output for printing
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output, *_ = self.model(data)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to the prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
            # Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

        # Print the sampled output
        logger.log(log)
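
To make the usage concrete, here is a hedged sketch of how an experiment could wire these configurations together, following the usual labml pattern and reusing the imports at the top of this file. TinyRnnModel, the experiment name, and the override values are illustrative assumptions, not part of this module; the only real requirement on the model is that it returns a tuple, because step() unpacks output, *_ = self.model(data).

from labml import experiment


class TinyRnnModel(Module):
    # Hypothetical stand-in model: embedding -> GRU -> projection to the vocabulary
    def __init__(self, n_tokens: int, d_model: int):
        super().__init__()
        self.embedding = nn.Embedding(n_tokens, d_model)
        self.rnn = nn.GRU(d_model, d_model)
        self.fc = nn.Linear(d_model, n_tokens)

    def forward(self, x: torch.Tensor):
        x, state = self.rnn(self.embedding(x))
        # Return a tuple so that `output, *_ = self.model(data)` works
        return self.fc(x), state


@option(NLPAutoRegressionConfigs.model)
def tiny_rnn_model(c: NLPAutoRegressionConfigs):
    # Build the model lazily so that c.n_tokens (derived from the dataset) is available
    return TinyRnnModel(c.n_tokens, c.d_model).to(c.device)


def main():
    conf = NLPAutoRegressionConfigs()
    experiment.create(name="nlp_autoregression_demo")
    # Override some of the configurable properties; everything else keeps its default
    experiment.configs(conf, {'model': 'tiny_rnn_model',
                              'tokenizer': 'character',
                              'prompt': 'It is ',
                              'prompt_separator': '',
                              'text': 'tiny_shakespeare'})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()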

Default optimizer configurations.

@option(NLPAutoRegressionConfigs.optimizer)
def _optimizer(c: NLPAutoRegressionConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer

Get the number of tokens from the text dataset.

@option(NLPAutoRegressionConfigs.n_tokens)
def _n_tokens(c: NLPAutoRegressionConfigs):
    return c.text.n_tokens


We use a character-level tokenizer in this experiment. You can switch to the basic English tokenizer by setting

    'tokenizer': 'basic_english',

in the configurations dictionary when starting the experiment.
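
For example, assuming an experiment set up along the lines of the sketch above, the override is just another entry in the dictionary passed to experiment.configs (the prompt values here are illustrative):

experiment.configs(conf, {'tokenizer': 'basic_english',
                          # word-level tokens need a space separator when sampling
                          'prompt_separator': ' ',
                          'prompt': 'It is'})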

Basic English tokenizer from torchtext.

@option(NLPAutoRegressionConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')


Character-level tokenizer.

def character_tokenizer(x: str):
    return list(x)


@option(NLPAutoRegressionConfigs.tokenizer)
def character():
    return character_tokenizer

Tiny Shakespeare dataset, downloaded from the given URL if it is not already present.

@option(NLPAutoRegressionConfigs.text)
def tiny_shakespeare(c: NLPAutoRegressionConfigs):
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
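
As a hedged aside, the dataset can also be constructed and inspected on its own; this sketch only uses attributes that are referenced elsewhere in this file (n_tokens and text_to_i):

ds = TextFileDataset(
    lab.get_data_path() / 'tiny_shakespeare.txt',
    character_tokenizer,
    url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

print(ds.n_tokens)           # vocabulary size (number of distinct characters)
print(ds.text_to_i('Thou'))  # a prompt encoded as a tensor of token ids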

Sequential training data loader.

@option(NLPAutoRegressionConfigs.train_loader)
def sequential_train_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.train,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)

Sequential validation data loader.

@option(NLPAutoRegressionConfigs.valid_loader)
def sequential_valid_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.valid,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)


The DataLoader collects the batches on the first dimension. We need to transpose them so that they are sequence-first.

def transpose_batch(batch):
    transposed_data = list(zip(*batch))
    # Stack the batch along the second dimension (dim=1)
    src = torch.stack(transposed_data[0], dim=1)
    tgt = torch.stack(transposed_data[1], dim=1)

    return src, tgt
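
A quick shape check with hypothetical sizes: the DataLoader hands collate_fn a list of (input, target) pairs, and transpose_batch stacks them so that both tensors come out as [seq_len, batch_size], i.e. sequence-first.

seq_len, batch_size = 8, 4  # hypothetical sizes
dummy_batch = [(torch.randint(65, (seq_len,)), torch.randint(65, (seq_len,)))
               for _ in range(batch_size)]
src, tgt = transpose_batch(dummy_batch)
print(src.shape, tgt.shape)  # torch.Size([8, 4]) torch.Size([8, 4])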

Shuffled training data loader.

@option(NLPAutoRegressionConfigs.train_loader)
def shuffled_train_loader(c: NLPAutoRegressionConfigs):
    return DataLoader(SequentialUnBatchedDataset(text=c.text.train,
                                                 dataset=c.text,
                                                 seq_len=c.seq_len),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      shuffle=True)

Shuffled validation data loader.

@option(NLPAutoRegressionConfigs.valid_loader)
def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
    return DataLoader(SequentialUnBatchedDataset(text=c.text.valid,
                                                 dataset=c.text,
                                                 seq_len=c.seq_len),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      shuffle=True)