from typing import Any

import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from labml import experiment, tracker
from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.seed import SeedConfigs
from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
from labml_nn.normalization.batch_norm import BatchNorm
class Net(Module):
    """
    Small convolutional classifier with batch normalization.

    Architecture: two stages of conv -> `BatchNorm` -> ReLU -> 2x2 max-pool,
    followed by a fully-connected layer (with `BatchNorm` and ReLU) and a final
    linear layer producing 10 raw scores (no softmax is applied).

    The input has 1 channel. The hard-coded flatten size of 4 * 4 * 50 corresponds
    to 28x28 inputs: 5x5 convolutions with stride 1 and no padding, each followed
    by a 2x2 pool, give 28 -> 24 -> 12 -> 8 -> 4.
    """

    def __init__(self):
        super().__init__()
        # NOTE: submodules are created in this exact order; it fixes both the
        # parameter registration order and the sequence of random initialisations.
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5, stride=1)
        self.bn1 = BatchNorm(20)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1)
        self.bn2 = BatchNorm(50)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        # Applied to the 2-D output of `fc1`; presumably the project's BatchNorm
        # accepts both (N, C, H, W) and (N, C) inputs — confirm.
        self.bn3 = BatchNorm(500)
        self.fc2 = nn.Linear(500, 10)

    def __call__(self, x: torch.Tensor):
        # Written as `__call__` following the labml_helpers `Module` convention
        # (presumably wired up as `forward` by that base class — confirm).
        conv_stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in conv_stages:
            activated = F.relu(norm(conv(x)))
            x = F.max_pool2d(activated, 2, 2)

        flat = x.view(-1, 4 * 4 * 50)
        hidden = F.relu(self.bn3(self.fc1(flat)))
        logits = self.fc2(hidden)
        return logits
class Configs(MNISTConfigs, TrainValidConfigs):
    """
    Configurations for training `Net` on MNIST.

    Combines the MNIST dataset configs with the train/validation loop configs.
    `init` sets up tracking, and `step` processes a single batch in either
    training or validation mode (checked through `self.mode.is_train`).
    """
    # Provided by the `_optimizer` option; `main` selects Adam through the
    # `optimizer.optimizer` override.
    optimizer: torch.optim.Adam
    # Provided by the `model` option.
    model: nn.Module
    set_seed = SeedConfigs()
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    inner_iterations = 10

    accuracy_func = Accuracy()
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        """
        Set up indicator tracking, model output hooks and stateful metrics.
        """
        # `loss.*` indicators are tracked as a queue of size 20 and `accuracy.*`
        # as scalars; the trailing `True` presumably enables printing — confirm.
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        # Hook the model's outputs under the 'model' prefix; `step` toggles
        # `is_log_activations` on the mode so these are recorded selectively.
        hook_model_outputs(self.mode, self.model, 'model')
        # Register the accuracy metric with the train/valid loop (presumably so
        # its accumulated state is reset per mode — confirm in TrainValidConfigs).
        self.state_modules = [self.accuracy_func]

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        Run the model on one batch and log loss and accuracy. In training mode,
        also backpropagate and take an optimizer step.

        :param batch: indexable pair; item 0 is the model input and item 1 the
            targets. Both are moved to `self.device`.
        :param batch_idx: position of the batch; activations and the model are
            only logged when `batch_idx.is_last` is true.
        """
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # The global step is advanced by the number of samples (`len(data)`),
        # not by the number of batches.
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Record activations only for the last batch.
        with self.mode.update(is_log_activations=batch_idx.is_last):
            output = self.model(data)

        loss = self.loss_func(output, target)
        self.accuracy_func(output, target)
        tracker.add("loss.", loss)

        if self.mode.is_train:
            loss.backward()

            self.optimizer.step()
            # The model is logged before `zero_grad`, presumably so that its
            # gradients are still populated when recorded.
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()
@option(Configs.model)
def model(c: Configs):
    """Build a fresh `Net` and place it on the configured device."""
    net = Net()
    return net.to(c.device)
84
85
@option(Configs.optimizer)
def _optimizer(c: Configs):
    """
    Build the optimizer sub-configuration over the model's parameters.

    Returns an `OptimizerConfigs`; the concrete optimizer is chosen via the
    `optimizer.optimizer` config key (set to 'Adam' in `main`).
    """
    # Imported locally; presumably to keep the dependency scoped to this option.
    from labml_helpers.optimizer import OptimizerConfigs

    optimizer_configs = OptimizerConfigs()
    # `parameters()` yields a one-shot generator; it is handed over as-is.
    optimizer_configs.parameters = c.model.parameters()
    return optimizer_configs
92
93
def main():
    """Create the experiment, apply config overrides, and run training."""
    configs = Configs()
    experiment.create(name='mnist_labml_helpers')
    # Select Adam as the concrete optimizer inside the optimizer sub-config.
    experiment.configs(configs, {'optimizer.optimizer': 'Adam'})
    # Seed before `configs.model` is accessed below, presumably so that weight
    # initialisation is reproducible.
    configs.set_seed.set()
    experiment.add_pytorch_models({'model': configs.model})
    with experiment.start():
        configs.run()


if __name__ == '__main__':
    main()