from typing import Any, Callable

import torch
import torch.nn as nn
from labml import lab, monit, logger, tracker
from labml.configs import option
from labml.logger import Text
from labml_nn.helpers.datasets import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
from torch.utils.data import DataLoader, RandomSampler
Cross entropy loss that flattens the batch and sequence dimensions of the outputs and targets before calling nn.CrossEntropyLoss.
class CrossEntropyLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, outputs, targets):
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
This class has the basic configurations for training auto-regressive NLP models. All the properties are configurable; a usage sketch follows the class.
class NLPAutoRegressionConfigs(TrainValidConfigs):
Optimizer
    optimizer: torch.optim.Adam
Training device
    device: torch.device = DeviceConfigs()
Autoregressive model
    model: nn.Module
Text dataset
    text: TextDataset
Batch size
    batch_size: int = 16
Length of the sequence, or context size
    seq_len: int = 512
Number of tokens in the vocabulary
    n_tokens: int
Tokenizer
    tokenizer: Callable = 'character'
Text prompt to start sampling (for illustration)
    prompt: str
The token separator when sampling (blank for character-level tokenization)
    prompt_separator: str
Whether to periodically save models
    is_save_models = True
Loss function
    loss_func = CrossEntropyLoss()
Accuracy function
    accuracy = Accuracy()
Model embedding size
    d_model: int = 512
Gradient clipping
    grad_norm_clip: float = 1.0
Training data loader
    train_loader: DataLoader = 'shuffled_train_loader'
Validation data loader
    valid_loader: DataLoader = 'shuffled_valid_loader'
Whether the data loaders shuffle with replacement
    dataloader_shuffle_with_replacement: bool = False
Whether to log model parameters and gradients (once per epoch). These are summarized stats per layer, but this could still lead to many indicators for very deep networks.
    is_log_model_params_grads: bool = False
Whether to log model activations (once per epoch). These are summarized stats per layer, but this could still lead to many indicators for very deep networks.
    is_log_model_activations: bool = False
    def init(self):
Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        tracker.set_text("sampled", False)
Add accuracy as a state module. The name is probably confusing, since it's meant to store states between training and validation for RNNs. This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]
Override to calculate and log other metrics
    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        pass
    def step(self, batch: Any, batch_idx: BatchIndex):
Set training/eval mode
        self.model.train(self.mode.is_train)
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])
Get model outputs. The model returns a tuple, with the extra elements reserved for states when using RNNs; this is not implemented yet. 😜
        output, *_ = self.model(data)
Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)
Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        self.other_metrics(output, target)
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take optimizer step
            self.optimizer.step()
Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
    def sample(self):
Starting prompt
        prompt = self.prompt
Collect output for printing
        log = [(prompt, Text.subtle)]
Sample 25 tokens
        for i in monit.iterate('Sample', 25):
Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
Get the model output
            output, *_ = self.model(data)
Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
Add the prediction to the prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

        tracker.add({'sampled': prompt})
Print the sampled output
        logger.log(log)
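The following is a minimal usage sketch and is not part of the file above. The tiny GRU stand-in model and the experiment name are hypothetical; a real experiment registers its own model option and overrides whichever of the configurable properties it needs before calling run().
from labml import experiment


class TinyAutoregressiveModel(nn.Module):
    # Hypothetical stand-in model, only here to make the sketch self-contained.
    # It takes sequence-first token ids and returns (logits, state), matching
    # the `output, *_ = self.model(data)` unpacking in `step` and `sample`.
    def __init__(self, n_tokens: int, d_model: int):
        super().__init__()
        self.embedding = nn.Embedding(n_tokens, d_model)
        self.rnn = nn.GRU(d_model, d_model)
        self.output = nn.Linear(d_model, n_tokens)

    def forward(self, x: torch.Tensor):
        x = self.embedding(x)
        x, state = self.rnn(x)
        return self.output(x), state


@option(NLPAutoRegressionConfigs.model)
def _tiny_model(c: NLPAutoRegressionConfigs):
    # Build the hypothetical model from the configured vocabulary and embedding sizes
    return TinyAutoregressiveModel(c.n_tokens, c.d_model).to(c.device)


def main():
    conf = NLPAutoRegressionConfigs()
    experiment.create(name='nlp_autoregression_sketch')
    # Override configurable properties; `prompt` and `prompt_separator` have no defaults
    experiment.configs(conf, {'prompt': 'It is ', 'prompt_separator': ''})
    with experiment.start():
        conf.run()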
@option(NLPAutoRegressionConfigs.optimizer)
def _optimizer(c: NLPAutoRegressionConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer
Get number of tokens
@option(NLPAutoRegressionConfigs.n_tokens)
def _n_tokens(c: NLPAutoRegressionConfigs):
    return c.text.n_tokens
We use a character-level tokenizer in this experiment. You can switch to this basic English tokenizer by setting 'tokenizer': 'basic_english' in the configurations dictionary when starting the experiment, as illustrated after the function below.
@option(NLPAutoRegressionConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')
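For example, continuing the usage sketch above, switching to this tokenizer only changes the overrides passed to experiment.configs (the prompt values here are illustrative):
    experiment.configs(conf, {'tokenizer': 'basic_english',
                              'prompt_separator': ' ',
                              'prompt': 'It is'})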
def character_tokenizer(x: str):
    return list(x)
@option(NLPAutoRegressionConfigs.tokenizer)
def character():
    return character_tokenizer
@option(NLPAutoRegressionConfigs.text)
def tiny_shakespeare(c: NLPAutoRegressionConfigs):
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
@option(NLPAutoRegressionConfigs.train_loader)
def sequential_train_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.train,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)
@option(NLPAutoRegressionConfigs.valid_loader)
def sequential_valid_loader(c: NLPAutoRegressionConfigs):
    return SequentialDataLoader(text=c.text.valid,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)
DataLoader collects the batches on the first dimension. We need to transpose it to be sequence first.
def transpose_batch(batch):
    transposed_data = list(zip(*batch))
Stack the batch along the second dimension dim=1
    src = torch.stack(transposed_data[0], dim=1)
    tgt = torch.stack(transposed_data[1], dim=1)

    return src, tgt
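A quick shape check with illustrative sizes (a batch of 16 samples, each a pair of 512-token tensors), showing that the collated batch is sequence-first:
    batch = [(torch.zeros(512, dtype=torch.long), torch.zeros(512, dtype=torch.long))
             for _ in range(16)]
    src, tgt = transpose_batch(batch)
    # Both tensors are [seq_len, batch_size]
    assert src.shape == (512, 16) and tgt.shape == (512, 16)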
@option(NLPAutoRegressionConfigs.train_loader)
def shuffled_train_loader(c: NLPAutoRegressionConfigs):
    dataset = SequentialUnBatchedDataset(text=c.text.train,
                                         dataset=c.text,
                                         seq_len=c.seq_len)
    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)

    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      sampler=sampler)
@option(NLPAutoRegressionConfigs.valid_loader)
def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
    dataset = SequentialUnBatchedDataset(text=c.text.valid,
                                         dataset=c.text,
                                         seq_len=c.seq_len)
    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)

    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      sampler=sampler)