from typing import Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from labml import lab, monit, logger, tracker
from labml.configs import option
from labml.logger import Text
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs


class CrossEntropyLoss(Module):
    """
    Cross entropy loss that flattens the sequence and batch dimensions
    before calling `nn.CrossEntropyLoss`.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, outputs, targets):
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
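

# A minimal shape check (illustrative sketch, not part of the original file):
# the loss accepts model outputs of shape `[seq_len, batch_size, n_tokens]` and
# targets of shape `[seq_len, batch_size]`, flattening both before
# `nn.CrossEntropyLoss` is applied. The function name and sizes are arbitrary.
def _cross_entropy_shape_example():
    seq_len, batch_size, n_tokens = 8, 4, 32
    # Random logits and random target token indices
    outputs = torch.randn(seq_len, batch_size, n_tokens)
    targets = torch.randint(0, n_tokens, (seq_len, batch_size))
    loss = CrossEntropyLoss()(outputs, targets)
    # The result is a scalar averaged over all `seq_len * batch_size` positions
    assert loss.dim() == 0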


class NLPAutoRegressionConfigs(TrainValidConfigs):
    """
    This has the basic configurations for NLP auto-regressive task training.
    All the properties are configurable; a usage sketch is given at the end of this file.
    """

    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()

    # Autoregressive model
    model: Module
    # Text dataset
    text: TextDataset
    # Batch size
    batch_size: int = 16
    # Length of the sequence, or context size
    seq_len: int = 512
    # Number of tokens in the vocabulary
    n_tokens: int
    # Tokenizer
    tokenizer: Callable = 'character'

    # Text prompt to start sampling (for illustration)
    prompt: str
    # The token separator when sampling (blank for character level tokenization)
    prompt_separator: str

    # Whether to periodically save models
    is_save_models = True

    # Loss function
    loss_func = CrossEntropyLoss()
    # Accuracy function
    accuracy = Accuracy()
    # Model embedding size
    d_model: int = 512
    # Gradient clipping
    grad_norm_clip: float = 1.0

    # Training data loader
    train_loader: DataLoader = 'shuffled_train_loader'
    # Validation data loader
    valid_loader: DataLoader = 'shuffled_valid_loader'

    def init(self):
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # Add accuracy as a state module.
        # The name is probably confusing, since it is meant to store
        # states between training and validation for RNNs.
        # This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]

    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        """Override to calculate and log other metrics"""
        pass

    def step(self, batch: any, batch_idx: BatchIndex):
        """A single training or validation step"""

        # Set training/eval mode
        self.model.train(self.mode.is_train)

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Get model outputs.
            # It returns a tuple with states when using RNNs;
            # this is not implemented yet. 😜
            output, *_ = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        self.other_metrics(output, target)

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()

    def sample(self):
        """Generate a short greedy sample from the model for logging"""

        # Starting prompt
        prompt = self.prompt
        # Collect output for printing
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output, *_ = self.model(data)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to the prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
            # Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

        # Print the sampled output
        logger.log(log)


@option(NLPAutoRegressionConfigs.optimizer)
def _optimizer(c: NLPAutoRegressionConfigs):
    """Default optimizer configurations"""
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer
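

# An illustrative alternative (assumption, not part of the original file): the same
# `@option` mechanism can register other optimizer choices. The function name becomes
# the option name, so an experiment could select this one with `'optimizer': 'plain_adam'`
# in its configurations dictionary. The learning rate is an arbitrary placeholder.
@option(NLPAutoRegressionConfigs.optimizer)
def plain_adam(c: NLPAutoRegressionConfigs):
    # Plain PyTorch Adam, bypassing the `OptimizerConfigs` helper
    return torch.optim.Adam(c.model.parameters(), lr=3e-4)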


@option(NLPAutoRegressionConfigs.n_tokens)
def _n_tokens(c: NLPAutoRegressionConfigs):
    """Get the number of tokens from the dataset"""
    return c.text.n_tokens


@option(NLPAutoRegressionConfigs.tokenizer)
def basic_english():
    """
    Basic English tokenizer

    We use a character level tokenizer in this experiment.
    You can switch to this one by setting

        'tokenizer': 'basic_english',

    in the configurations dictionary when starting the experiment.
    """
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')


def character_tokenizer(x: str):
    """Character level tokenizer"""
    return list(x)


@option(NLPAutoRegressionConfigs.tokenizer)
def character():
    """Character level tokenizer configuration option"""
    return character_tokenizer
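

# A small illustrative check (not part of the original file) of what the character
# level tokenizer produces. The `basic_english` option would instead split the same
# string into word-level tokens via `torchtext`.
def _tokenizer_example():
    text = 'To be, or not'
    tokens = character_tokenizer(text)
    # One token per character, including spaces and punctuation
    assert tokens[:5] == ['T', 'o', ' ', 'b', 'e']
    assert len(tokens) == len(text)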


@option(NLPAutoRegressionConfigs.text)
def tiny_shakespeare(c: NLPAutoRegressionConfigs):
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
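

# An illustrative sketch (assumption, not part of the original file): any other
# plain-text corpus can be plugged in by registering another option for `text`
# in the same way. The file name and URL below are hypothetical placeholders.
@option(NLPAutoRegressionConfigs.text)
def my_text_corpus(c: NLPAutoRegressionConfigs):
    return TextFileDataset(
        lab.get_data_path() / 'my_corpus.txt',
        c.tokenizer,
        url='https://example.com/my_corpus.txt')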


@option(NLPAutoRegressionConfigs.train_loader)
def sequential_train_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.train,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)


@option(NLPAutoRegressionConfigs.valid_loader)
def sequential_valid_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.valid,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)


def transpose_batch(batch):
    """
    `DataLoader` collects the batches on the first dimension.
    We need to transpose it to be sequence first.
    """
    transposed_data = list(zip(*batch))
    # Stack the batch along the second dimension `dim=1`
    src = torch.stack(transposed_data[0], dim=1)
    tgt = torch.stack(transposed_data[1], dim=1)

    return src, tgt
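

# A minimal shape check (illustrative, not part of the original file): used as a
# `collate_fn`, `transpose_batch` receives a list of `(input, target)` samples and
# stacks them along `dim=1`, producing sequence-first tensors of shape
# `[seq_len, batch_size]` instead of the default batch-first layout.
def _transpose_batch_example():
    # Two (input, target) samples of sequence length 4
    batch = [(torch.arange(4), torch.arange(1, 5)),
             (torch.arange(4), torch.arange(1, 5))]
    src, tgt = transpose_batch(batch)
    # Stacked along `dim=1`, so the shape is `[seq_len, batch_size]`
    assert src.shape == (4, 2) and tgt.shape == (4, 2)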


@option(NLPAutoRegressionConfigs.train_loader)
def shuffled_train_loader(c: NLPAutoRegressionConfigs):
    return DataLoader(SequentialUnBatchedDataset(text=c.text.train,
                                                 dataset=c.text,
                                                 seq_len=c.seq_len),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      shuffle=True)


@option(NLPAutoRegressionConfigs.valid_loader)
def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
    return DataLoader(SequentialUnBatchedDataset(text=c.text.valid,
                                                 dataset=c.text,
                                                 seq_len=c.seq_len),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      shuffle=True)
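

# A hedged usage sketch (assumption, not part of the original file) of how an experiment
# typically wires these configurations together: pick options by name, override properties
# through a configurations dictionary, and run. The `labml` experiment calls below
# (`experiment.create`, `experiment.configs`, `experiment.start`) follow the usual pattern
# but are written from memory, so treat the exact names and values as illustrative.
def _example_experiment():
    from labml import experiment

    # A real experiment would subclass `NLPAutoRegressionConfigs` and register a
    # concrete autoregressive model with `@option` before running.
    conf = NLPAutoRegressionConfigs()
    experiment.create(name='nlp_autoregression_example')
    # Every property can be overridden here; string values select registered
    # options by function name.
    experiment.configs(conf, {
        'tokenizer': 'character',
        'text': 'tiny_shakespeare',
        'prompt': 'It is ',
        'prompt_separator': '',
        'seq_len': 256,
        'batch_size': 32,
    })
    with experiment.start():
        conf.run()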