11from typing import Callable
12
13import torch
14import torch.nn as nn
15from torch.utils.data import DataLoader, RandomSampler
16
17from labml import lab, monit, logger, tracker
18from labml.configs import option
19from labml.logger import Text
20from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
21from labml_helpers.device import DeviceConfigs
22from labml_helpers.metrics.accuracy import Accuracy
23from labml_helpers.module import Module
24from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
25from labml_nn.optimizers.configs import OptimizerConfigs
class CrossEntropyLoss(Module):
    """
    ### Cross entropy loss

    Wraps `nn.CrossEntropyLoss`. Before computing the loss, the outputs are
    flattened to `[-1, outputs.shape[-1]]` and the targets to `[-1]`, so every
    leading dimension is treated as part of one big batch.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, outputs, targets):
        """
        * `outputs` - model outputs; the last dimension indexes the classes
        * `targets` - target class indices, one per leading position of `outputs`

        Returns whatever `nn.CrossEntropyLoss` returns for the flattened inputs.
        """
        # `reshape` returns a view when possible and copies otherwise, so unlike
        # `view` it also accepts non-contiguous tensors (e.g. transposed batches).
        flat_outputs = outputs.reshape(-1, outputs.shape[-1])
        flat_targets = targets.reshape(-1)
        return self.loss(flat_outputs, flat_targets)
This has the basic configurations for NLP auto-regressive task training. All the properties are configurable.
class NLPAutoRegressionConfigs(TrainValidConfigs):
    """
    ## Trainer configurations

    Basic configurations for training on an NLP auto-regressive task.
    All the properties are configurable.
    """

    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()

    # Autoregressive model
    model: Module
    # Text dataset
    text: TextDataset
    # Batch size
    batch_size: int = 16
    # Length of the sequence, or context size
    seq_len: int = 512
    # Number of tokens in the vocabulary
    n_tokens: int
    # Tokenizer (selected by option name)
    tokenizer: Callable = 'character'

    # Text prompt to start sampling (for illustration)
    prompt: str
    # The token separator when sampling (blank for character level tokenization)
    prompt_separator: str

    # Whether to periodically save models
    is_save_models = True

    # Loss function
    loss_func = CrossEntropyLoss()
    # Accuracy function
    accuracy = Accuracy()
    # Model embedding size
    d_model: int = 512
    # Gradient clipping (maximum gradient norm)
    grad_norm_clip: float = 1.0

    # Training data loader (selected by option name)
    train_loader: DataLoader = 'shuffled_train_loader'
    # Validation data loader (selected by option name)
    valid_loader: DataLoader = 'shuffled_valid_loader'

    # Whether the shuffled data loaders sample with replacement
    dataloader_shuffle_with_replacement: bool = False

    # Whether to log model parameters and gradients (once per epoch).
    # These are summarized stats per layer, but it could still lead to
    # many indicators for very deep networks.
    is_log_model_params_grads: bool = False

    # Whether to log model activations (once per epoch).
    # These are summarized stats per layer, but it could still lead to
    # many indicators for very deep networks.
    is_log_model_activations: bool = False

    def init(self):
        """
        ### Initialization

        Configures the tracker, hooks the model outputs for logging and
        registers the accuracy metric as a state module.
        """
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # Add accuracy as a state module.
        # The name is probably confusing, since it's meant to store
        # states between training and validation for RNNs.
        # This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]

    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        """Override to calculate and log other metrics. Does nothing by default."""
        pass

    def step(self, batch: any, batch_idx: BatchIndex):
        """
        ### Training or validation step

        * `batch` - indexable batch; item `0` is the input and item `1` the target
        * `batch_idx` - batch index information; `is_last` gates the once-per-epoch logging
        """
        # Set training/eval mode
        self.model.train(self.mode.is_train)

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
            # Get model outputs.
            # The model returns a tuple (extra items are for states when using RNNs);
            # only the first item is used here.
            output, *_ = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        self.other_metrics(output, target)

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()

    def sample(self):
        """
        ### Sampling function to generate samples periodically while training

        Greedily extends `self.prompt` by 25 tokens and logs the result.
        """
        # Starting prompt
        prompt = self.prompt
        # Collect output for printing
        log = [(prompt, Text.subtle)]
        # Sampling is inference only; there is no need to build an autograd graph
        with torch.no_grad():
            # Sample 25 tokens
            for _ in monit.iterate('Sample', 25):
                # Tokenize the prompt and add a trailing dimension of size one
                data = self.text.text_to_i(prompt).unsqueeze(-1)
                data = data.to(self.device)
                # Get the model output
                output, *_ = self.model(data)
                # Get the model prediction (greedy).
                # `flatten` instead of `squeeze` keeps the result one-dimensional
                # even for a single-token prompt, so `[-1]` below is always valid.
                output = output.argmax(dim=-1).flatten()
                # Text of the newly predicted token (built once, used twice)
                predicted = self.prompt_separator + self.text.itos[output[-1].item()]
                # Add the prediction to prompt
                prompt += predicted
                # Add the prediction for logging
                log.append((predicted, Text.value))

        # Print the sampled output
        logger.log(log)
@option(NLPAutoRegressionConfigs.optimizer)
def _optimizer(c: NLPAutoRegressionConfigs):
    """
    ### Default optimizer configurations

    Creates an `OptimizerConfigs` over the model parameters, selects the
    `'Adam'` optimizer and passes along the model embedding size.
    """
    opt_conf = OptimizerConfigs()
    # Parameters to optimize
    opt_conf.parameters = c.model.parameters()
    # Optimizer is chosen by name
    opt_conf.optimizer = 'Adam'
    opt_conf.d_model = c.d_model

    return opt_conf
Get number of tokens
@option(NLPAutoRegressionConfigs.n_tokens)
def _n_tokens(c: NLPAutoRegressionConfigs):
    """
    ### Number of tokens

    Reads the vocabulary size from the configured text dataset.
    """
    dataset = c.text
    return dataset.n_tokens
We use a character-level tokenizer in this experiment. You can switch by setting,
'tokenizer': 'basic_english',
in the configurations dictionary when starting the experiment.
@option(NLPAutoRegressionConfigs.tokenizer)
def basic_english():
    """
    ### Basic English tokenizer

    Returns the `torchtext` tokenizer registered under `'basic_english'`.
    """
    # Imported inside the function, so `torchtext` is only loaded when this option is used
    from torchtext.data import get_tokenizer

    tokenizer = get_tokenizer('basic_english')
    return tokenizer
def character_tokenizer(x: str):
    """
    ### Character level tokenizer

    Splits `x` into a list of its individual characters.
    """
    return [*x]
@option(NLPAutoRegressionConfigs.tokenizer)
def character():
    """
    ### Character level tokenizer configuration

    Returns the `character_tokenizer` function itself (not its result).
    """
    tokenizer = character_tokenizer
    return tokenizer
@option(NLPAutoRegressionConfigs.text)
def tiny_shakespeare(c: NLPAutoRegressionConfigs):
    """
    ### Tiny Shakespeare dataset

    Builds a `TextFileDataset` for `tiny_shakespeare.txt` under the lab data path,
    tokenized with the configured tokenizer.
    NOTE(review): the `url` is presumably used to fetch the file when it is missing — confirm
    against `TextFileDataset`.
    """
    data_file = lab.get_data_path() / 'tiny_shakespeare.txt'
    source_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    return TextFileDataset(data_file, c.tokenizer, url=source_url)
@option(NLPAutoRegressionConfigs.train_loader)
def sequential_train_loader(c: NLPAutoRegressionConfigs):
    """
    ### Sequential training data loader

    Wraps the training split of the text dataset in a `SequentialDataLoader`
    using the configured batch size and sequence length.
    """
    corpus = c.text
    loader_args = dict(text=corpus.train,
                       dataset=corpus,
                       batch_size=c.batch_size,
                       seq_len=c.seq_len)
    return SequentialDataLoader(**loader_args)
@option(NLPAutoRegressionConfigs.valid_loader)
def sequential_valid_loader(c: NLPAutoRegressionConfigs):
    """
    ### Sequential validation data loader

    Wraps the validation split of the text dataset in a `SequentialDataLoader`
    using the configured batch size and sequence length.
    """
    corpus = c.text
    loader_args = dict(text=corpus.valid,
                       dataset=corpus,
                       batch_size=c.batch_size,
                       seq_len=c.seq_len)
    return SequentialDataLoader(**loader_args)
DataLoader
collects the batches on the first dimension. We need to transpose it to be sequence first.
def transpose_batch(batch):
    """
    ### Transpose batch

    `DataLoader` collects the samples in a list; this collate function stacks
    them along the second dimension (`dim=1`) so the first dimension of each
    sample stays first.

    * `batch` - sequence of samples; item `0` of each sample is the source and item `1` the target

    Returns the stacked `(src, tgt)` pair.
    """
    # Regroup per-sample items into per-field tuples: (all sources, all targets, ...)
    fields = tuple(zip(*batch))
    sources, targets = fields[0], fields[1]

    # Stack the batch along the second dimension `dim=1`
    return torch.stack(sources, dim=1), torch.stack(targets, dim=1)
@option(NLPAutoRegressionConfigs.train_loader)
def shuffled_train_loader(c: NLPAutoRegressionConfigs):
    """
    ### Shuffled training data loader

    Builds a `SequentialUnBatchedDataset` over the training split and serves it
    through a `DataLoader` with a `RandomSampler`; batches are collated with
    `transpose_batch`.
    """
    corpus = c.text
    samples = SequentialUnBatchedDataset(text=corpus.train,
                                         dataset=corpus,
                                         seq_len=c.seq_len)
    # Random order, with or without replacement as configured
    shuffler = RandomSampler(samples, replacement=c.dataloader_shuffle_with_replacement)

    loader_args = dict(batch_size=c.batch_size,
                       collate_fn=transpose_batch,
                       sampler=shuffler)
    return DataLoader(samples, **loader_args)
@option(NLPAutoRegressionConfigs.valid_loader)
def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
    """
    ### Shuffled validation data loader

    Builds a `SequentialUnBatchedDataset` over the validation split and serves it
    through a `DataLoader` with a `RandomSampler`; batches are collated with
    `transpose_batch`.
    """
    corpus = c.text
    samples = SequentialUnBatchedDataset(text=corpus.valid,
                                         dataset=corpus,
                                         seq_len=c.seq_len)
    # Random order, with or without replacement as configured
    shuffler = RandomSampler(samples, replacement=c.dataloader_shuffle_with_replacement)

    loader_args = dict(batch_size=c.batch_size,
                       collate_fn=transpose_batch,
                       sampler=shuffler)
    return DataLoader(samples, **loader_args)