from typing import Any

import torch.nn as nn
import torch.utils.data

from labml import tracker
from labml.configs import option
from labml_nn.helpers.datasets import MNISTConfigs as MNISTDatasetConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
    """
    Configurations for training a classification model on MNIST.

    Combines the MNIST dataset configurations (``MNISTDatasetConfigs``) with the
    train/validation loop configurations (``TrainValidConfigs``). ``step`` runs
    one batch of training or validation.
    """

    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()

    # Classification model
    model: nn.Module
    # Number of epochs to train for
    epochs: int = 10

    # Number of times to switch between training and validation within an epoch
    inner_iterations = 10

    # Accuracy function
    accuracy = Accuracy()
    # Loss function
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        """Configure tracker indicators and register the state modules."""
        # Set tracker configurations
        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)
        # Register accuracy as a state module. The name "state module" can be
        # misleading: the mechanism exists to carry state between training and
        # validation (e.g. for RNNs). Registering accuracy here keeps its
        # statistics separate for training and validation.
        self.state_modules = [self.accuracy]

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        Run a single training or validation step on one batch.

        :param batch: indexable batch; ``batch[0]`` is the model input and
            ``batch[1]`` is the classification target.
        :param batch_idx: position of the batch within the epoch; its
            ``is_last`` flag triggers logging of model parameters/gradients.
        """
        # Training/Evaluation mode
        self.model.train(self.mode.is_train)

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Get model outputs.
        output = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients so they do not accumulate into the next step
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()
@option(MNISTConfigs.optimizer)
def _optimizer(c: MNISTConfigs):
    """
    Default optimizer option for ``MNISTConfigs``.

    Builds an ``OptimizerConfigs`` that applies Adam to the parameters of
    ``c.model`` and returns it as the value of ``MNISTConfigs.optimizer``.
    """
    optimizer_configs = OptimizerConfigs()
    # Optimize every parameter of the classification model
    model_params = c.model.parameters()
    optimizer_configs.parameters = model_params
    optimizer_configs.optimizer = 'Adam'
    return optimizer_configs