from typing import Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler

from labml import lab, monit, logger, tracker
from labml.configs import option
from labml.logger import Text
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs

Cross entropy loss that flattens the sequence and batch dimensions of the logits and targets before applying nn.CrossEntropyLoss.

class CrossEntropyLoss(Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, outputs, targets):
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
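As a quick sanity check of the reshaping (the shapes below are illustrative, not from the original code): with seq_len=5, batch_size=2 and n_tokens=65, the logits get flattened to [10, 65] and the targets to [10] before nn.CrossEntropyLoss sees them.

outputs = torch.randn(5, 2, 65)          # [seq_len, batch_size, n_tokens]
targets = torch.randint(0, 65, (5, 2))   # [seq_len, batch_size]
loss = CrossEntropyLoss()(outputs, targets)  # internally viewed as [10, 65] logits and [10] targets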
This has the basic configurations for NLP auto-regressive task training. All the properties are configurable.

class NLPAutoRegressionConfigs(TrainValidConfigs):

Optimizer
    optimizer: torch.optim.Adam

Training device
    device: torch.device = DeviceConfigs()

Autoregressive model
    model: Module

Text dataset
    text: TextDataset

Batch size
    batch_size: int = 16

Length of the sequence, or context size
    seq_len: int = 512

Number of tokens in the vocabulary
    n_tokens: int

Tokenizer
    tokenizer: Callable = 'character'

Text prompt to start sampling (for illustration)
    prompt: str

The token separator when sampling (blank for character level tokenization)
    prompt_separator: str

Whether to periodically save models
    is_save_models = True

Loss function
    loss_func = CrossEntropyLoss()

Accuracy function
    accuracy = Accuracy()

Model embedding size
    d_model: int = 512

Gradient clipping
    grad_norm_clip: float = 1.0

Training data loader
    train_loader: DataLoader = 'shuffled_train_loader'

Validation data loader
    valid_loader: DataLoader = 'shuffled_valid_loader'

Whether the data loaders shuffle with replacement
    dataloader_shuffle_with_replacement: bool = False

Whether to log model parameters and gradients (once per epoch). These are summarized stats per layer, but they could still lead to many indicators for very deep networks.
    is_log_model_params_grads: bool = False

Whether to log model activations (once per epoch). These are summarized stats per layer, but they could still lead to many indicators for very deep networks.
    is_log_model_activations: bool = False

Initialization

    def init(self):

Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        tracker.set_text("sampled", False)

Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')

Add accuracy as a state module. The name is a bit confusing, since state modules are meant to store state between training and validation for RNNs. Here it keeps the accuracy metric's statistics separate for training and validation.
        self.state_modules = [self.accuracy]

Override this to calculate and log other metrics
    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        pass
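A subclass could use this hook to track additional metrics. A minimal sketch, assuming one wants perplexity (the subclass name and the metric are illustrative, not part of the original code); it reuses the same flattening as CrossEntropyLoss above:

import torch.nn.functional as F

class PerplexityConfigs(NLPAutoRegressionConfigs):
    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        # Flatten sequence and batch dimensions, as in CrossEntropyLoss above
        ce = F.cross_entropy(output.view(-1, output.shape[-1]), target.view(-1))
        # Perplexity is the exponential of the cross entropy
        tracker.add('perplexity.', ce.exp())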
Training or validation step

    def step(self, batch: any, batch_idx: BatchIndex):

Set training/eval mode
        self.model.train(self.mode.is_train)
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):

Get model outputs. It's returning a tuple for states when using RNNs. This is not implemented yet. 😜
            output, *_ = self.model(data)

Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        self.other_metrics(output, target)

Train the model
        if self.mode.is_train:

Calculate gradients
            loss.backward()

Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step
            self.optimizer.step()

Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)

Clear the gradients
            self.optimizer.zero_grad()

Save the tracked metrics
        tracker.save()

Sampling function to generate samples periodically while training

    def sample(self):

Starting prompt
        prompt = self.prompt

Collect output for printing
        log = [(prompt, Text.subtle)]

Sample 25 tokens
        for i in monit.iterate('Sample', 25):

Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)

Get the model output
            output, *_ = self.model(data)

Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()

Add the prediction to the prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]

Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

        tracker.add({'sampled': prompt})

Print the sampled output
        logger.log(log)

Default optimizer configurations

@option(NLPAutoRegressionConfigs.optimizer)
def _optimizer(c: NLPAutoRegressionConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer
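These are only defaults. Because the optimizer comes from OptimizerConfigs, its sub-options can also be overridden through the configurations dictionary used to start the experiment. A hedged sketch with illustrative values, assuming the Noam option provided by labml_nn's OptimizerConfigs:

# These keys would go into the dictionary passed to labml's experiment.configs
overrides = {
    'optimizer.optimizer': 'Noam',   # switch from Adam to the Noam schedule
    'optimizer.learning_rate': 1.0,
}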
Get the number of tokens from the text dataset

@option(NLPAutoRegressionConfigs.n_tokens)
def _n_tokens(c: NLPAutoRegressionConfigs):
    return c.text.n_tokens

Basic english tokenizer. We use a character level tokenizer in this experiment. You can switch by setting 'tokenizer': 'basic_english' in the configurations dictionary when starting the experiment, as in the sketch below.
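A hedged sketch of how that override is passed (the experiment name and prompt values are illustrative; a real experiment would use a subclass or option that also defines `model`):

from labml import experiment

conf = NLPAutoRegressionConfigs()
experiment.create(name='nlp_autoregression')
experiment.configs(conf, {
    'tokenizer': 'basic_english',
    'prompt_separator': ' ',
    'prompt': 'It is ',
})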
@option(NLPAutoRegressionConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')

Character level tokenizer

def character_tokenizer(x: str):
    return list(x)

Character level tokenizer configuration

@option(NLPAutoRegressionConfigs.tokenizer)
def character():
    return character_tokenizer
Tiny Shakespeare dataset. It will download the text from the URL if it is not already present.

@option(NLPAutoRegressionConfigs.text)
def tiny_shakespeare(c: NLPAutoRegressionConfigs):
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
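Other plain-text corpora can be registered the same way. A hypothetical sketch (the option name, file name, and URL are placeholders, not part of the original code); selecting it would then just be 'text': 'my_corpus' in the configurations dictionary:

@option(NLPAutoRegressionConfigs.text)
def my_corpus(c: NLPAutoRegressionConfigs):
    # Placeholder path and URL for illustration only
    return TextFileDataset(
        lab.get_data_path() / 'my_corpus.txt',
        c.tokenizer,
        url='https://example.com/my_corpus.txt')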
Sequential training data loader

@option(NLPAutoRegressionConfigs.train_loader)
def sequential_train_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.train,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)

Sequential validation data loader

@option(NLPAutoRegressionConfigs.valid_loader)
def sequential_valid_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.valid,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)

DataLoader collects the batches on the first dimension. We need to transpose them to be sequence first.

def transpose_batch(batch):
    transposed_data = list(zip(*batch))

Stack the batch along the second dimension dim=1
    src = torch.stack(transposed_data[0], dim=1)
    tgt = torch.stack(transposed_data[1], dim=1)

    return src, tgt
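A quick illustration of what this collate function produces (the sizes are made up for the example: three samples of five tokens each):

# Each dataset item is a (src, tgt) pair of shape [seq_len]
batch = [(torch.zeros(5, dtype=torch.long), torch.ones(5, dtype=torch.long)) for _ in range(3)]
src, tgt = transpose_batch(batch)
# Both are now sequence-first: [seq_len, batch_size] == [5, 3]
assert src.shape == (5, 3) and tgt.shape == (5, 3)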
Shuffled training data loader

@option(NLPAutoRegressionConfigs.train_loader)
def shuffled_train_loader(c: NLPAutoRegressionConfigs):
    dataset = SequentialUnBatchedDataset(text=c.text.train,
                                         dataset=c.text,
                                         seq_len=c.seq_len)
    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)

    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      sampler=sampler)
Shuffled validation data loader

@option(NLPAutoRegressionConfigs.valid_loader)
def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
    dataset = SequentialUnBatchedDataset(text=c.text.valid,
                                         dataset=c.text,
                                         seq_len=c.seq_len)
    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)

    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      sampler=sampler)