from typing import Any

import torch.nn as nn
import torch.utils.data

from labml import tracker
from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs as MNISTDatasetConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
21from labml_nn.optimizers.configs import OptimizerConfigs24class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):Optimizer
32    optimizer: torch.optim.AdamTraining device
34    device: torch.device = DeviceConfigs()Classification model
37    model: ModuleNumber of epochs to train for
39    epochs: int = 10Number of times to switch between training and validation within an epoch
42    inner_iterations = 10Accuracy function
45    accuracy = Accuracy()Loss function
47    loss_func = nn.CrossEntropyLoss()49    def init(self):Set tracker configurations
54        tracker.set_scalar("loss.*", True)
55        tracker.set_scalar("accuracy.*", True)Add a hook to log module outputs
57        hook_model_outputs(self.mode, self.model, 'model')Add accuracy as a state module. The name is probably confusing, since it’s meant to store states between training and validation for RNNs. This will keep the accuracy metric stats separate for training and validation.
62        self.state_modules = [self.accuracy]64    def step(self, batch: any, batch_idx: BatchIndex):Move data to the device
70        data, target = batch[0].to(self.device), batch[1].to(self.device)Update global step (number of samples processed) when in training mode
73        if self.mode.is_train:
74            tracker.add_global_step(len(data))Whether to capture model outputs
77        with self.mode.update(is_log_activations=batch_idx.is_last):Get model outputs.
79            output = self.model(data)Calculate and log loss
82        loss = self.loss_func(output, target)
83        tracker.add("loss.", loss)Calculate and log accuracy
86        self.accuracy(output, target)
87        self.accuracy.track()Train the model
90        if self.mode.is_train:Calculate gradients
92            loss.backward()Take optimizer step
94            self.optimizer.step()Log the model parameters and gradients on last batch of every epoch
96            if batch_idx.is_last:
97                tracker.add('model', self.model)Clear the gradients
99            self.optimizer.zero_grad()Save the tracked metrics
102        tracker.save()105@option(MNISTConfigs.optimizer)
def _optimizer(c: MNISTConfigs):
    """
    Default value for ``MNISTConfigs.optimizer``.

    Builds an ``OptimizerConfigs`` bound to the model's parameters, with
    ``'Adam'`` selected as the optimizer type. The optimizer type and
    hyper-parameters can then be changed through configurations.
    """
    conf = OptimizerConfigs()
    model_params = c.model.parameters()
    conf.parameters = model_params
    conf.optimizer = 'Adam'
    return conf