from typing import Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler

from labml import lab, monit, logger, tracker
from labml.configs import option
from labml.logger import Text
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
Cross entropy loss that flattens the sequence and batch dimensions of the predictions and targets before applying nn.CrossEntropyLoss.
class CrossEntropyLoss(Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, outputs, targets):
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
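For instance, with sequence-first outputs of shape [seq_len, batch_size, n_tokens] and targets of shape [seq_len, batch_size], the view calls above flatten both to [seq_len * batch_size, n_tokens] and [seq_len * batch_size], so nn.CrossEntropyLoss sees one prediction per token. A minimal sketch, with arbitrary shapes chosen purely for illustration:

loss_func = CrossEntropyLoss()
# Fake logits and targets: 512 time steps, a batch of 16, and a vocabulary of 65 tokens
outputs = torch.randn(512, 16, 65)
targets = torch.randint(0, 65, (512, 16))
# Equivalent to nn.CrossEntropyLoss()(outputs.view(-1, 65), targets.view(-1))
loss = loss_func(outputs, targets)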
This class has the basic configurations for NLP auto-regressive task training. All the properties are configurable; a usage sketch showing how an experiment might plug in a model follows the class.
class NLPAutoRegressionConfigs(TrainValidConfigs):
Optimizer
    optimizer: torch.optim.Adam
Training device
    device: torch.device = DeviceConfigs()
Autoregressive model
    model: Module
Text dataset
    text: TextDataset
Batch size
    batch_size: int = 16
Length of the sequence, or context size
    seq_len: int = 512
Number of tokens in the vocabulary
    n_tokens: int
Tokenizer
    tokenizer: Callable = 'character'
Text prompt to start sampling (for illustration)
    prompt: str
The token separator when sampling (blank for character level tokenization)
    prompt_separator: str
Whether to periodically save models
    is_save_models = True
Loss function
    loss_func = CrossEntropyLoss()
Accuracy function
    accuracy = Accuracy()
Model embedding size
    d_model: int = 512
Gradient clipping
    grad_norm_clip: float = 1.0
Training data loader
    train_loader: DataLoader = 'shuffled_train_loader'
Validation data loader
    valid_loader: DataLoader = 'shuffled_valid_loader'
Whether the data loaders shuffle with replacement
    dataloader_shuffle_with_replacement: bool = False
    def init(self):
Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
Add accuracy as a state module. The name is a bit confusing, since state modules are meant to store state (e.g. for RNNs) between training and validation; here it simply keeps the accuracy metric's statistics separate for training and validation.
        self.state_modules = [self.accuracy]

Override this to calculate and log other metrics
    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        pass
    def step(self, batch: any, batch_idx: BatchIndex):
Set training/eval mode
        self.model.train(self.mode.is_train)
Move the data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Update the global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])
Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
Get model outputs. The model returns a tuple so that it can also pass back states when using RNNs; that part is not implemented yet. 😜
            output, *_ = self.model(data)
Calculate and log the loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)
Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        self.other_metrics(output, target)
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take an optimizer step
            self.optimizer.step()
Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()

Save the tracked metrics
        tracker.save()
Sampling function to generate samples periodically while training
    def sample(self):
Starting prompt
        prompt = self.prompt
Collect output for printing
        log = [(prompt, Text.subtle)]
Sample 25 tokens
        for i in monit.iterate('Sample', 25):
Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
Get the model output
            output, *_ = self.model(data)
Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
Add the prediction to the prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

Print the sampled output
        logger.log(log)
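As a rough sketch of how an experiment might plug a model into these configurations, one can register an option for the model property (real experiments usually subclass NLPAutoRegressionConfigs and define their own model). TinyAutoregressiveModel below is a made-up placeholder for illustration, not something provided by this module.

class TinyAutoregressiveModel(Module):
    # A toy stand-in: embedding -> LSTM -> projection back to the vocabulary
    def __init__(self, n_tokens: int, d_model: int):
        super().__init__()
        self.embedding = nn.Embedding(n_tokens, d_model)
        self.lstm = nn.LSTM(d_model, d_model)
        self.fc = nn.Linear(d_model, n_tokens)

    def forward(self, x: torch.Tensor):
        # x is a sequence-first tensor of token indices: [seq_len, batch_size]
        out, state = self.lstm(self.embedding(x))
        # Return logits and state as a tuple, matching output, *_ = self.model(data) in step
        return self.fc(out), state


@option(NLPAutoRegressionConfigs.model)
def tiny_model(c: NLPAutoRegressionConfigs):
    # Build the placeholder model with the configured vocabulary and embedding sizes
    return TinyAutoregressiveModel(c.n_tokens, c.d_model).to(c.device)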
Default optimizer configurations: Adam over the model parameters
@option(NLPAutoRegressionConfigs.optimizer)
def _optimizer(c: NLPAutoRegressionConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer
Get number of tokens
@option(NLPAutoRegressionConfigs.n_tokens)
def _n_tokens(c: NLPAutoRegressionConfigs):
    return c.text.n_tokens
We use a character-level tokenizer in this experiment. You can switch to torchtext's basic_english tokenizer by setting 'tokenizer': 'basic_english' in the configurations dictionary when starting the experiment (see the sketch after the tokenizer options below).
@option(NLPAutoRegressionConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')
Character-level tokenizer
def character_tokenizer(x: str):
    return list(x)

Character-level tokenizer configuration option
@option(NLPAutoRegressionConfigs.tokenizer)
def character():
    return character_tokenizer
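A hedged sketch of switching the tokenizer when starting an experiment. experiment.create and experiment.configs are the usual labml entry points; the experiment name is illustrative.

from labml import experiment

conf = NLPAutoRegressionConfigs()
experiment.create(name='nlp_autoregression_demo')  # illustrative name
# Override the default character-level tokenizer with torchtext's basic_english tokenizer
experiment.configs(conf, {'tokenizer': 'basic_english'})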
Tiny Shakespeare dataset; the text file is downloaded from the given URL if it is not already present
@option(NLPAutoRegressionConfigs.text)
def tiny_shakespeare(c: NLPAutoRegressionConfigs):
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
Sequential training data loader
@option(NLPAutoRegressionConfigs.train_loader)
def sequential_train_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.train,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)

Sequential validation data loader
@option(NLPAutoRegressionConfigs.valid_loader)
def sequential_valid_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.valid,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)
DataLoader collates samples along the first (batch) dimension. We need to transpose the batch so that it is sequence-first; a short illustration follows the function.
def transpose_batch(batch):
Separate the source and target sequences from the list of (src, tgt) pairs
    transposed_data = list(zip(*batch))
Stack the batch along the second dimension (dim=1)
    src = torch.stack(transposed_data[0], dim=1)
    tgt = torch.stack(transposed_data[1], dim=1)

    return src, tgt
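As a small illustration with arbitrary numbers: with a batch of 4 sequences of length 32, the DataLoader hands transpose_batch a list of 4 (src, tgt) pairs of shape [32] each, and the function returns sequence-first tensors of shape [32, 4].

batch = [(torch.zeros(32, dtype=torch.long), torch.ones(32, dtype=torch.long)) for _ in range(4)]
src, tgt = transpose_batch(batch)
assert src.shape == (32, 4) and tgt.shape == (32, 4)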
Shuffled training data loader
@option(NLPAutoRegressionConfigs.train_loader)
def shuffled_train_loader(c: NLPAutoRegressionConfigs):
    dataset = SequentialUnBatchedDataset(text=c.text.train,
                                         dataset=c.text,
                                         seq_len=c.seq_len)
    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)

    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      sampler=sampler)
Shuffled validation data loader
@option(NLPAutoRegressionConfigs.valid_loader)
def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
    dataset = SequentialUnBatchedDataset(text=c.text.valid,
                                         dataset=c.text,
                                         seq_len=c.seq_len)
    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)

    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      sampler=sampler)
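Finally, a hedged end-to-end sketch of how this configuration class might be driven. It assumes a model option has been registered (for example the tiny_model sketch above); the prompt, experiment name, and number of epochs are arbitrary choices for illustration, and epochs and the run() training loop are inherited from TrainValidConfigs.

from labml import experiment


def main():
    conf = NLPAutoRegressionConfigs()
    experiment.create(name='nlp_autoregression_demo')  # illustrative name
    experiment.configs(conf, {
        'tokenizer': 'character',  # default character-level tokenizer
        'prompt': 'It is ',        # arbitrary starting prompt for sampling
        'prompt_separator': '',    # blank separator for character-level tokenization
        'epochs': 1,               # kept small for the sketch
    })
    # Start the experiment and run the training/validation loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()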