from typing import Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from labml import lab, monit, logger, tracker
from labml.configs import option
from labml.logger import Text
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
class CrossEntropyLoss(Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, outputs, targets):
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
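This simply flattens the sequence and batch dimensions of both outputs and targets before applying nn.CrossEntropyLoss. A minimal shape check, as a sketch with hypothetical toy dimensions:
# Sketch: hypothetical toy dimensions, just to illustrate the reshaping
seq_len, batch_size, n_tokens = 8, 4, 65
outputs = torch.randn(seq_len, batch_size, n_tokens)         # model logits
targets = torch.randint(0, n_tokens, (seq_len, batch_size))  # next-token ids
loss = CrossEntropyLoss()(outputs, targets)                  # scalar loss over all 8 * 4 positions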
This class has the basic configurations for training auto-regressive NLP models. All the properties are configurable; a usage sketch is given after the class body.
class NLPAutoRegressionConfigs(TrainValidConfigs):
Optimizer
    optimizer: torch.optim.Adam
Training device
    device: torch.device = DeviceConfigs()
Autoregressive model
    model: Module
Text dataset
    text: TextDataset
Batch size
    batch_size: int = 16
Length of the sequence, or context size
    seq_len: int = 512
Number of tokens in the vocabulary
    n_tokens: int
Tokenizer
    tokenizer: Callable = 'character'
Text prompt to start sampling (for illustration)
    prompt: str
The token separator when sampling (blank for character-level tokenization)
    prompt_separator: str
Whether to periodically save models
    is_save_models = True
Loss function
    loss_func = CrossEntropyLoss()
Accuracy function
    accuracy = Accuracy()
Model embedding size
    d_model: int = 512
Gradient clipping
    grad_norm_clip: float = 1.0
Training data loader
    train_loader: DataLoader = 'shuffled_train_loader'
Validation data loader
    valid_loader: DataLoader = 'shuffled_valid_loader'
    def init(self):
Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
Add accuracy as a state module. The name is a little misleading: state modules are meant to keep state (such as RNN hidden states) between training and validation. Here it simply keeps the accuracy metric's statistics separate for training and validation.
        self.state_modules = [self.accuracy]
    def step(self, batch: any, batch_idx: BatchIndex):
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Update the global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])
Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
Get model outputs. The model returns a tuple so that it can also carry state when using RNNs; that part is not implemented yet. 😜
            output, *_ = self.model(data)
Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)
Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take an optimizer step
            self.optimizer.step()
Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
    def sample(self):
Starting prompt
        prompt = self.prompt
Collect output for printing
        log = [(prompt, Text.subtle)]
Sample 25 tokens
        for i in monit.iterate('Sample', 25):
Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
Get the model output
            output, *_ = self.model(data)
Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
Add the prediction to the prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
Print the sampled output
        logger.log(log)
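As a usage sketch (not part of this file), an experiment would typically subclass these configurations, add a model option, and launch with the standard labml experiment API. The class name Configs and the overrides below are assumptions for illustration:
# Sketch of how these configurations are typically used (names here are hypothetical)
from labml import experiment

class Configs(NLPAutoRegressionConfigs):
    pass  # a real experiment would define a `model` option for this class

def main():
    conf = Configs()
    experiment.create(name='nlp_autoregression_sketch')
    # override any of the configurable properties by name
    experiment.configs(conf, {'batch_size': 32, 'seq_len': 256})
    with experiment.start():
        conf.run()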
@option(NLPAutoRegressionConfigs.optimizer)
def _optimizer(c: NLPAutoRegressionConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer
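Since the optimizer is an OptimizerConfigs aggregate, its nested settings can themselves be overridden from the configurations dictionary passed to experiment.configs (continuing the earlier sketch; the key names below are assumptions about what OptimizerConfigs exposes):
# Sketch: nested optimizer overrides (key names assumed)
experiment.configs(conf, {
    'optimizer.optimizer': 'Noam',
    'optimizer.learning_rate': 1.0,
})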
Get the number of tokens
@option(NLPAutoRegressionConfigs.n_tokens)
def _n_tokens(c: NLPAutoRegressionConfigs):
    return c.text.n_tokens
We use a character-level tokenizer in this experiment. You can switch to this word-level tokenizer by passing
'tokenizer': 'basic_english'
in the configurations dictionary when starting the experiment. A short comparison of the two tokenizers appears after the character tokenizer below.
@option(NLPAutoRegressionConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')
def character_tokenizer(x: str):
    return list(x)


@option(NLPAutoRegressionConfigs.tokenizer)
def character():
    return character_tokenizer
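To illustrate the difference between the two tokenizers, a quick sketch (the word-level tokenizer needs torchtext installed):
# Sketch: comparing the two tokenizers on a sample string
from torchtext.data import get_tokenizer
print(character_tokenizer('Hello there!'))             # ['H', 'e', 'l', 'l', 'o', ' ', 't', 'h', 'e', 'r', 'e', '!']
print(get_tokenizer('basic_english')('Hello there!'))  # ['hello', 'there', '!'] (lower-cased, punctuation split)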
@option(NLPAutoRegressionConfigs.text)
def tiny_shakespeare(c: NLPAutoRegressionConfigs):
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
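Other plain-text corpora can be used the same way by adding another option for the text configuration. A sketch; the file name and option name are hypothetical, and it assumes TextFileDataset can be constructed without a url when the file already exists under the data path:
# Sketch: a hypothetical local corpus option (file assumed to exist; no download URL)
@option(NLPAutoRegressionConfigs.text)
def my_corpus(c: NLPAutoRegressionConfigs):
    return TextFileDataset(lab.get_data_path() / 'my_corpus.txt', c.tokenizer)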
@option(NLPAutoRegressionConfigs.train_loader)
def sequential_train_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.train,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)
@option(NLPAutoRegressionConfigs.valid_loader)
def sequential_valid_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.valid,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)
The DataLoader collates samples along the first dimension (batch first). We need to transpose the batches so that they are sequence first.
def transpose_batch(batch):
    transposed_data = list(zip(*batch))
Stack the batch along the second dimension dim=1
    src = torch.stack(transposed_data[0], dim=1)
    tgt = torch.stack(transposed_data[1], dim=1)

    return src, tgt
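A tiny illustration of the resulting shapes, with hypothetical sizes:
# Sketch: a collated batch of 4 samples, each a (src, tgt) pair of length 8
batch = [(torch.zeros(8, dtype=torch.long), torch.ones(8, dtype=torch.long)) for _ in range(4)]
src, tgt = transpose_batch(batch)
assert src.shape == (8, 4)   # [seq_len, batch_size]: sequence first
assert tgt.shape == (8, 4)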
@option(NLPAutoRegressionConfigs.train_loader)
def shuffled_train_loader(c: NLPAutoRegressionConfigs):
    return DataLoader(SequentialUnBatchedDataset(text=c.text.train,
                                                 dataset=c.text,
                                                 seq_len=c.seq_len),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      shuffle=True)
@option(NLPAutoRegressionConfigs.valid_loader)
def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
    return DataLoader(SequentialUnBatchedDataset(text=c.text.valid,
                                                 dataset=c.text,
                                                 seq_len=c.seq_len),
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      shuffle=True)