from typing import Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from labml import lab, monit, logger, tracker
from labml.configs import option
from labml.logger import Text
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs

Cross entropy loss wrapper that flattens the sequence and batch dimensions of the model outputs and targets before computing the loss.

class CrossEntropyLoss(Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, outputs, targets):
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
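As a quick shape sanity check (a minimal sketch, not part of the original file; the vocabulary size of 65 is an arbitrary assumption), the wrapper accepts outputs of shape [seq_len, batch_size, n_tokens] and targets of shape [seq_len, batch_size]:

    # Minimal sketch of the expected shapes
    loss_func = CrossEntropyLoss()
    outputs = torch.randn(512, 16, 65)          # hypothetical model outputs, 65 is an assumed vocabulary size
    targets = torch.randint(0, 65, (512, 16))   # hypothetical target token indices
    loss = loss_func(outputs, targets)          # a scalar tensor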
This class has the basic configurations for NLP auto-regressive task training. All the properties are configurable.

class NLPAutoRegressionConfigs(TrainValidConfigs):
    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()
    # Autoregressive model
    model: Module
    # Text dataset
    text: TextDataset
    # Batch size
    batch_size: int = 16
    # Length of the sequence, or context size
    seq_len: int = 512
    # Number of tokens in the vocabulary
    n_tokens: int
    # Tokenizer
    tokenizer: Callable = 'character'
    # Text prompt to start sampling (for illustration)
    prompt: str
    # The token separator when sampling (blank for character-level tokenization)
    prompt_separator: str
    # Whether to periodically save models
    is_save_models = True
    # Loss function
    loss_func = CrossEntropyLoss()
    # Accuracy function
    accuracy = Accuracy()
    # Model embedding size
    d_model: int = 512
    # Gradient clipping
    grad_norm_clip: float = 1.0
    # Training data loader
    train_loader: DataLoader = 'shuffled_train_loader'
    # Validation data loader
    valid_loader: DataLoader = 'shuffled_valid_loader'
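Since every property above is a labml configuration, it can be overridden with a dictionary when the experiment is started. A minimal sketch, assuming an experiment that also defines the model option elsewhere (the experiment name and the override values below are illustrative, not from this file):

    from labml import experiment

    conf = NLPAutoRegressionConfigs()
    experiment.create(name='nlp_autoregression_demo')   # hypothetical experiment name
    experiment.configs(conf, {
        'tokenizer': 'character',   # or 'basic_english', see the tokenizer options below
        'prompt': 'It is ',         # illustrative starting prompt
        'prompt_separator': '',     # blank for character-level tokenization
        'seq_len': 256,             # override the default of 512
        'batch_size': 32,           # override the default of 16
    })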
    def init(self):
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # Add accuracy as a state module. The name is probably confusing, since it's
        # meant to store states between training and validation for RNNs.
        # This keeps the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]
    # Override to calculate and log other metrics
    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        pass
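For example, a subclass could override other_metrics to log an extra metric. The sketch below (not part of the original file) tracks perplexity computed from the same outputs and targets:

    # Hypothetical override in a subclass of NLPAutoRegressionConfigs:
    # track perplexity alongside loss and accuracy
    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        with torch.no_grad():
            ce = nn.functional.cross_entropy(output.view(-1, output.shape[-1]), target.view(-1))
            tracker.add("perplexity.", ce.exp())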
    # Training or validation step
    def step(self, batch: any, batch_idx: BatchIndex):
        # Set training/eval mode
        self.model.train(self.mode.is_train)

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update the global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Get model outputs. The model returns a tuple so that it can also pass
            # states when using RNNs; that part is not implemented yet. 😜
            output, *_ = self.model(data)

        # Calculate and log the loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Calculate any other metrics
        self.other_metrics(output, target)

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take an optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()
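A quick note on the global-step bookkeeping above (plain arithmetic, not from the original): data has shape [seq_len, batch_size], so each training batch advances the counter by seq_len × batch_size tokens.

    # Tokens processed per training batch with the default configuration
    seq_len, batch_size = 512, 16
    tokens_per_batch = seq_len * batch_size   # 512 * 16 = 8192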
    # Sampling function to generate text while training, starting from the given prompt
    def sample(self):
        # Starting prompt
        prompt = self.prompt
        # Collect output for printing
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output, *_ = self.model(data)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to the prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
            # Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

        # Print the sampled output
        logger.log(log)
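The loop above decodes greedily with argmax. A common alternative (not part of this file) is to sample from the softmax distribution with a temperature; a minimal sketch of how the prediction step could be replaced, assuming the model outputs logits of shape [seq_len, 1, n_tokens]:

    # Hypothetical replacement for the greedy argmax step
    temperature = 0.8                          # assumed value; lower is more conservative
    logits = output[-1, 0, :] / temperature    # logits for the last position
    probs = torch.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1).item()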
Default optimizer configurations.

@option(NLPAutoRegressionConfigs.optimizer)
def _optimizer(c: NLPAutoRegressionConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer
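Because the optimizer is itself a set of configurations, its fields can be overridden from the experiment's configurations dictionary. Continuing the earlier configuration sketch (the option name 'Noam' and the values below are assumptions about what OptimizerConfigs provides, not taken from this file):

    experiment.configs(conf, {
        'optimizer.optimizer': 'Noam',     # assuming a Noam schedule is registered in OptimizerConfigs
        'optimizer.learning_rate': 1.0,    # illustrative value
        'optimizer.d_model': 512,          # d_model feeds the learning-rate schedule
    })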
Get the number of tokens in the vocabulary from the dataset.

@option(NLPAutoRegressionConfigs.n_tokens)
def _n_tokens(c: NLPAutoRegressionConfigs):
    return c.text.n_tokens
We use a character-level tokenizer in this experiment. You can switch to the basic English tokenizer by setting

    'tokenizer': 'basic_english',

in the configurations dictionary when starting the experiment.

@option(NLPAutoRegressionConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')

Character-level tokenizer.

def character_tokenizer(x: str):
    return list(x)

@option(NLPAutoRegressionConfigs.tokenizer)
def character():
    return character_tokenizer
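To illustrate the difference (a small sketch, not from the original file): the character tokenizer splits text into individual characters, while torchtext's basic_english tokenizer lowercases and splits on whitespace and punctuation.

    # Illustrative comparison of the two tokenizer options
    print(character_tokenizer('It is'))              # ['I', 't', ' ', 'i', 's']
    from torchtext.data import get_tokenizer         # assuming torchtext is installed
    print(get_tokenizer('basic_english')('It is'))   # ['it', 'is']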
Tiny Shakespeare dataset. It is downloaded from the given URL if it is not already present in the data directory.

@option(NLPAutoRegressionConfigs.text)
def tiny_shakespeare(c: NLPAutoRegressionConfigs):
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
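Other corpora can be plugged in the same way by registering another option for the text configuration. A hypothetical sketch (the file name is a placeholder, and it assumes the url argument of TextFileDataset is optional when the file already exists locally):

    # Hypothetical option for training on a different local text file
    @option(NLPAutoRegressionConfigs.text)
    def my_corpus(c: NLPAutoRegressionConfigs):
        # Assumes 'my_corpus.txt' has already been placed in the lab data directory
        return TextFileDataset(lab.get_data_path() / 'my_corpus.txt', c.tokenizer)

It would then be selected with 'text': 'my_corpus' in the configurations dictionary.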
Sequential training data loader.

@option(NLPAutoRegressionConfigs.train_loader)
def sequential_train_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.train,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)
Sequential validation data loader.

@option(NLPAutoRegressionConfigs.valid_loader)
def sequential_valid_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.valid,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)

DataLoader collects the batches on the first dimension. We need to transpose the batches to be sequence-first.
def transpose_batch(batch):
    transposed_data = list(zip(*batch))
    # Stack the batch along the second dimension dim=1
    src = torch.stack(transposed_data[0], dim=1)
    tgt = torch.stack(transposed_data[1], dim=1)

    return src, tgt
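A small sketch of what this collation does (the shapes are assumptions based on the dataset returning (src, tgt) pairs of length seq_len; the values are made up):

    # Hypothetical batch of two (src, tgt) samples, each a sequence of 4 token ids
    batch = [(torch.tensor([1, 2, 3, 4]), torch.tensor([2, 3, 4, 5])),
             (torch.tensor([6, 7, 8, 9]), torch.tensor([7, 8, 9, 0]))]
    src, tgt = transpose_batch(batch)
    print(src.shape)   # torch.Size([4, 2]) -- [seq_len, batch_size], i.e. sequence-first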
Shuffled training data loader.

@option(NLPAutoRegressionConfigs.train_loader)
def shuffled_train_loader(c: NLPAutoRegressionConfigs):
    return DataLoader(SequentialUnBatchedDataset(text=c.text.train,
                                                 dataset=c.text,
                                                 seq_len=c.seq_len),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      shuffle=True)
Shuffled validation data loader.

@option(NLPAutoRegressionConfigs.valid_loader)
def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
    return DataLoader(SequentialUnBatchedDataset(text=c.text.valid,
                                                 dataset=c.text,
                                                 seq_len=c.seq_len),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      shuffle=True)
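The shuffled loaders are the defaults, since the train_loader and valid_loader properties default to the option names 'shuffled_train_loader' and 'shuffled_valid_loader'. To use the sequential loaders instead, pass the other option names in the configurations dictionary; a sketch continuing the earlier configuration example:

    experiment.configs(conf, {
        'train_loader': 'sequential_train_loader',
        'valid_loader': 'sequential_valid_loader',
    })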