from typing import Any

import torch.nn as nn
import torch.utils.data

from labml import tracker
from labml.configs import option
from labml_nn.helpers.datasets import MNISTConfigs as MNISTDatasetConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs


class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()
    # Classification model
    model: nn.Module
    # Number of epochs to train for
    epochs: int = 10
    # Number of times to switch between training and validation within an epoch
    inner_iterations = 10
    # Accuracy function
    accuracy = Accuracy()
    # Loss function
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        # Set tracker configurations
        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)
        # Add accuracy as a state module. The name is probably confusing,
        # since state modules are meant to store state between training and
        # validation for RNNs. Here it keeps the accuracy metric's statistics
        # separate for training and validation.
        self.state_modules = [self.accuracy]

    def step(self, batch: Any, batch_idx: BatchIndex):
        # Set training/evaluation mode
        self.model.train(self.mode.is_train)

        # Move the data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update the global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Get model outputs
        output = self.model(data)

        # Calculate and log the loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log the accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Take an optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()
        # Save the tracked metrics
        tracker.save()
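

# `model` has no default value above; each experiment completes the
# configurations by registering its own model option. Below is a minimal
# sketch of that pattern (illustrative only, not part of the original
# experiment code; `SimpleMNISTModel` is a hypothetical example network).
class SimpleMNISTModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Flatten 28x28 images and classify them into 10 digit classes
        self.layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x: torch.Tensor):
        return self.layers(x)


@option(MNISTConfigs.model)
def _model(c: MNISTConfigs):
    # Create the model and move it to the configured training device
    return SimpleMNISTModel().to(c.device)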


@option(MNISTConfigs.optimizer)
def _optimizer(c: MNISTConfigs):
    # Default optimizer configurations: a configurable `OptimizerConfigs`
    # that optimizes the model parameters, with Adam as the default
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    opt_conf.optimizer = 'Adam'
    return opt_conf
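

# A minimal sketch of how this experiment might be launched (assumed usage
# following the common labml pattern; the experiment name and the
# learning-rate override are arbitrary choices, not from the original file).
from labml import experiment


def main():
    # Create the experiment and the configurations
    experiment.create(name='mnist_example')
    conf = MNISTConfigs()
    # Override selected options; everything else falls back to the defaults,
    # e.g. the Adam optimizer registered above
    experiment.configs(conf, {'optimizer.learning_rate': 1e-3})
    # Start the experiment and run the training/validation loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()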