9import torch.nn as nn
10import torch.utils.data
11
12from labml import experiment, tracker
13from labml.configs import option
14from labml_nn.helpers.datasets import MNISTConfigs
15from labml_nn.helpers.device import DeviceConfigs
16from labml_nn.helpers.metrics import Accuracy
17from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs


class Model(nn.Module):
    """
    Small convolutional classifier.

    Two convolution + max-pool stages feed two fully connected layers that
    produce 10 output logits.
    """

    def __init__(self):
        super().__init__()
        # 1 input channel -> 20 feature maps, 5x5 kernel, stride 1
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.pool1 = nn.MaxPool2d(2)
        # 20 -> 50 feature maps, 5x5 kernel, stride 1
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.pool2 = nn.MaxPool2d(2)
        # 16 * 50 is the flattened feature size after `pool2`
        # (4 x 4 x 50 when the input is 28 x 28 -- assumed MNIST size, confirm
        # against the dataset configuration).
        self.fc1 = nn.Linear(16 * 50, 500)
        self.fc2 = nn.Linear(500, 10)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Return the 10 class logits for the image batch `x`."""
        x = self.activation(self.conv1(x))
        x = self.pool1(x)
        x = self.activation(self.conv2(x))
        x = self.pool2(x)
        # Flatten every sample to a vector of 16 * 50 features
        x = self.activation(self.fc1(x.view(-1, 16 * 50)))
        return self.fc2(x)


class Configs(MNISTConfigs, TrainValidConfigs):
    """
    Configurable training/validation loop for the model on MNIST.

    The `model` and `optimizer` values are supplied by `@option` functions
    registered against this class.
    """
    optimizer: torch.optim.Adam
    model: nn.Module
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    inner_iterations = 10

    accuracy_func = Accuracy()
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        """Configure the tracker indicators and register stateful modules."""
        # `loss.*` is tracked as a queue of size 20 and `accuracy.*` as a scalar;
        # the trailing `True` presumably enables console printing -- confirm
        # against the labml tracker documentation.
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        self.state_modules = [self.accuracy_func]

    def step(self, batch: any, batch_idx: BatchIndex):
        """
        Process one batch: forward pass, loss/accuracy tracking and, in
        training mode, a gradient step.

        :param batch: indexable pair whose first item is the input data and
            second item is the target
        :param batch_idx: batch index information; `is_last` is used to log
            model and optimizer statistics once per epoch
        """
        # Get the batch
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Add global step if we are in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Run the model
        output = self.model(data)

        # Calculate the loss
        loss = self.loss_func(output, target)
        # Calculate the accuracy
        self.accuracy_func(output, target)
        # Log the loss
        tracker.add("loss.", loss)

        # Optimize if we are in training mode
        if self.mode.is_train:
            # Calculate the gradients
            loss.backward()

            # Take optimizer step
            self.optimizer.step()
            # Log the parameter and gradient L2 norms once per epoch.
            # This must happen before the gradients are cleared below.
            if batch_idx.is_last:
                tracker.add('model', self.model)
                tracker.add('optimizer', (self.optimizer, {'model': self.model}))
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save logs
        tracker.save()


# The option functions below provide the model and a configurable optimizer.
# We can change the optimizer type and hyper-parameters using configurations.
@option(Configs.model)
def model(c: Configs):
    """Create the `Model` and move it to the device configured in `c`."""
    network = Model()
    return network.to(c.device)
105
106
@option(Configs.optimizer)
def _optimizer(c: Configs):
    """
    Create a configurable optimizer.

    We can change the optimizer type and hyper-parameters using configurations.
    """
    optimizer_conf = OptimizerConfigs()
    # Hand the model's parameters to the optimizer configuration
    optimizer_conf.parameters = c.model.parameters()
    return optimizer_conf


def main():
    """Create the experiment, apply the configuration overrides and run it."""
    conf = Configs()
    conf.inner_iterations = 10
    experiment.create(name='mnist_ada_belief')

    overrides = {
        'inner_iterations': 10,
        # Specify the optimizer
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 1.5e-4,
    }
    experiment.configs(conf, overrides)

    # Register the model with the experiment
    experiment.add_pytorch_models({'model': conf.model})
    with experiment.start():
        conf.run()
129
130
# Run the experiment only when this file is executed as a script.
if __name__ == '__main__':
    main()