MNIST Experiment

This module defines reusable trainer configurations for training classifiers on the MNIST dataset; a concrete experiment only needs to supply the model.

from typing import Any

import torch
import torch.nn as nn
import torch.utils.data

from labml import tracker
from labml.configs import option
from labml_nn.helpers.datasets import MNISTConfigs as MNISTDatasetConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs

Trainer configurations. MNISTConfigs combines the MNIST dataset configurations with the generic training and validation loop; a usage sketch appears at the end of this page.

class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):

Optimizer

    optimizer: torch.optim.Adam

Training device. DeviceConfigs picks a CUDA device when one is available and falls back to the CPU otherwise.

    device: torch.device = DeviceConfigs()

Classification model

    model: nn.Module

Number of epochs to train for

    epochs: int = 10

Number of times to switch between training and validation within an epoch

    inner_iterations = 10

Accuracy function

    accuracy = Accuracy()

Loss function

    loss_func = nn.CrossEntropyLoss()
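
As an aside, nn.CrossEntropyLoss expects raw, unnormalized logits and integer class targets. The tensors below are made up purely for illustration:

# Illustration only: CrossEntropyLoss takes logits of shape (batch, classes)
# and integer class indices of shape (batch,).
logits = torch.randn(4, 10)            # a batch of 4 samples, 10 digit classes
targets = torch.tensor([3, 0, 9, 1])   # ground-truth class indices
loss = nn.CrossEntropyLoss()(logits, targets)  # a scalar tensor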

Initialization

    def init(self):

Set tracker configurations, so that metrics matching loss.* and accuracy.* are printed to the console output

        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)

Add accuracy as a state module. The name is a bit confusing, since state modules were meant to store state (such as RNN hidden states) between training and validation. Here it keeps the accuracy metric's statistics separate for training and validation; the sketch after the code below illustrates the idea.

        self.state_modules = [self.accuracy]
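
To make that concrete, here is a hypothetical illustration of the idea (not labml's actual implementation): a metric with independent accumulators per mode, so validation statistics never mix with training statistics.

class TwoModeAccuracy:
    """Illustration only: independent accuracy accumulators per mode."""

    def __init__(self):
        # [correct, total] counters, one pair per mode
        self.stats = {'train': [0, 0], 'valid': [0, 0]}

    def __call__(self, output: torch.Tensor, target: torch.Tensor, mode: str):
        pred = output.argmax(dim=-1)
        self.stats[mode][0] += (pred == target).sum().item()
        self.stats[mode][1] += target.numel()

    def value(self, mode: str) -> float:
        correct, total = self.stats[mode]
        return correct / max(total, 1)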

Training or validation step

    def step(self, batch: Any, batch_idx: BatchIndex):

Training/Evaluation mode

        self.model.train(self.mode.is_train)

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of samples processed) when in training mode

        if self.mode.is_train:
            tracker.add_global_step(len(data))

Get model outputs.

        output = self.model(data)

Calculate and log the loss. The trailing dot in the tracked name keeps training and validation statistics separate (loss.train, loss.valid), matching the loss.* pattern set in init.

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

Calculate and log accuracy

        self.accuracy(output, target)
        self.accuracy.track()

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Take an optimizer step

            self.optimizer.step()

Log the model parameters and gradients on the last batch of every epoch

            if batch_idx.is_last:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()
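
For orientation, one training pass through step corresponds roughly to the plain PyTorch iteration below. This is a simplified sketch that omits the tracking and inner-iteration bookkeeping, and the function name is ours:

def train_batch(model: nn.Module, optimizer: torch.optim.Optimizer,
                loss_func: nn.Module, data: torch.Tensor,
                target: torch.Tensor, device: torch.device) -> torch.Tensor:
    model.train()
    data, target = data.to(device), target.to(device)
    output = model(data)              # forward pass
    loss = loss_func(output, target)  # e.g. cross-entropy
    loss.backward()                   # compute gradients
    optimizer.step()                  # update parameters
    optimizer.zero_grad()             # clear gradients for the next batch
    return loss.detach()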

Default optimizer configurations. The @option decorator registers this function as the computed default for MNISTConfigs.optimizer; settings such as the learning rate can then be overridden from experiment configurations.

@option(MNISTConfigs.optimizer)
def _optimizer(c: MNISTConfigs):
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    opt_conf.optimizer = 'Adam'
    return opt_conf
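
Finally, a minimal sketch of how these configurations might be used. The Model class, _model option, and main function are illustrative assumptions, not part of this module:

from labml import experiment

class Model(nn.Module):
    """Illustration only: a small MLP classifier for 28x28 MNIST digits."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(),
                                 nn.Linear(28 * 28, 128), nn.ReLU(),
                                 nn.Linear(128, 10))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

@option(MNISTConfigs.model)
def _model(c: MNISTConfigs):
    # Create the model on the configured device
    return Model().to(c.device)

def main():
    conf = MNISTConfigs()
    experiment.create(name='mnist_example')
    # Overrides reach the nested OptimizerConfigs created above
    experiment.configs(conf, {'optimizer.learning_rate': 1e-3})
    with experiment.start():
        conf.run()

if __name__ == '__main__':
    main()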