import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from labml import experiment, tracker
from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.seed import SeedConfigs
from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
from labml_nn.normalization.batch_norm import BatchNorm
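

# A simple convolutional network for MNIST classification that uses the
# `BatchNorm` implementation from `labml_nn` after each convolution and after
# the first fully connected layer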
class Net(Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.bn1 = BatchNorm(20)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.bn2 = BatchNorm(50)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.bn3 = BatchNorm(500)
        self.fc2 = nn.Linear(500, 10)

    def __call__(self, x: torch.Tensor):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        # After two 5x5 convolutions and two 2x2 poolings a 28x28 input is
        # reduced to 50 channels of 4x4, hence the `4 * 4 * 50` features
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.bn3(self.fc1(x)))
        return self.fc2(x)
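

# Trainer configurations: combine the MNIST data configurations with the
# training and validation loop configurations from `labml_helpers`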
class Configs(MNISTConfigs, TrainValidConfigs):
    optimizer: torch.optim.Adam
    model: nn.Module
    set_seed = SeedConfigs()
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    inner_iterations = 10

    accuracy_func = Accuracy()
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        # Track the loss with a queue of the last 20 values (printed to the console)
        tracker.set_queue("loss.*", 20, True)
        # Print the accuracy as a scalar
        tracker.set_scalar("accuracy.*", True)
        # Hook the model outputs so activations can be logged
        hook_model_outputs(self.mode, self.model, 'model')
        # Register the accuracy metric so the trainer manages its state
        self.state_modules = [self.accuracy_func]
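
    # A single training or validation step; the `TrainValidConfigs` loop calls
    # this for every batch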
    def step(self, batch: any, batch_idx: BatchIndex):
        # Move the data to the configured device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Increment the global step (number of samples processed) in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Run the model; log activations on the last batch of each epoch
        with self.mode.update(is_log_activations=batch_idx.is_last):
            output = self.model(data)

        # Compute and log the loss and the accuracy
        loss = self.loss_func(output, target)
        self.accuracy_func(output, target)
        tracker.add("loss.", loss)

        # Optimize only in training mode
        if self.mode.is_train:
            loss.backward()
            self.optimizer.step()
            # Log model parameters and gradients on the last batch of each epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()
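

# Create the default model for `Configs.model`: build the network and move it
# to the configured device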
@option(Configs.model)
def model(c: Configs):
    return Net().to(c.device)


# Configure the optimizer through `OptimizerConfigs`, so the optimizer type can
# be picked by name from the experiment configurations
@option(Configs.optimizer)
def _optimizer(c: Configs):
    from labml_helpers.optimizer import OptimizerConfigs
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf


def main():
    # Create the experiment configurations
    conf = Configs()
    # Create the experiment
    experiment.create(name='mnist_labml_helpers')
    # Load configurations and pick the Adam optimizer by name
    experiment.configs(conf, {'optimizer.optimizer': 'Adam'})
    # Set the random seed
    conf.set_seed.set()
    # Register the model for saving and loading
    experiment.add_pytorch_models(dict(model=conf.model))
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
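
# To try this example, install the dependencies (e.g. `pip install labml-nn`,
# which also brings in `labml` and `labml_helpers`) and run this file directly.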