This is annotated PyTorch code to classify MNIST digits.
It implements the experiment described in the paper Dynamic Routing Between Capsules.
```python
from typing import Any

import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from labml import experiment, tracker
from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.metrics.accuracy import AccuracyDirect
from labml_helpers.module import Module
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
from labml_nn.capsule_networks import Squash, Router, MarginLoss
```

```python
class MNISTCapsuleNetworkModel(Module):
    def __init__(self):
        super().__init__()
```

The first convolution layer has $256$ convolution kernels of size $9 \times 9$.
```python
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
```

The second layer (Primary Capsules) is a convolutional capsule layer with $32$ channels of convolutional $8D$ capsules ($8$ features per capsule). That is, each primary capsule contains $8$ convolutional units with a $9 \times 9$ kernel and a stride of $2$. To implement this, we create a convolutional layer with $32 \times 8$ channels and reshape and permute its output to get capsules of $8$ features each.
```python
        self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
        self.squash = Squash()
```
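`Squash` is imported from `labml_nn.capsule_networks` above, so its internals are not shown here. As a rough sketch, the squashing non-linearity from the paper, $\mathrm{squash}(\mathbf{s}) = \frac{\|\mathbf{s}\|^2}{1 + \|\mathbf{s}\|^2} \frac{\mathbf{s}}{\|\mathbf{s}\|}$, shrinks each capsule vector's length into $[0, 1)$ while keeping its direction (the epsilon here is an assumption for numerical stability, and the function name is illustrative):

```python
import torch

def squash_sketch(s: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # s: [batch_size, n_capsules, n_features]
    s2 = (s ** 2).sum(dim=-1, keepdim=True)  # squared capsule lengths ||s||^2
    # shrink length by ||s||^2 / (1 + ||s||^2), keep direction s / ||s||
    return (s2 / (1 + s2)) * (s / torch.sqrt(s2 + eps))
```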
The routing layer gets the $32 \times 6 \times 6$ primary capsules and produces $10$ capsules. Each of the primary capsules has $8$ features, while the output capsules (Digit Capsules) have $16$ features. The routing algorithm iterates $3$ times.

```python
        self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)
```
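`Router` is also imported from `labml_nn.capsule_networks`. A minimal sketch of the dynamic routing it implements, following Procedure 1 of the paper and reusing `squash_sketch` from above; the class and parameter names here are illustrative, not the library's API:

```python
import torch
import torch.nn as nn

class RouterSketch(nn.Module):
    def __init__(self, in_caps: int, out_caps: int, in_d: int, out_d: int, iterations: int):
        super().__init__()
        self.iterations = iterations
        # weights that map each input capsule into each output capsule's space
        self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d) * 0.01)

    def forward(self, u: torch.Tensor) -> torch.Tensor:
        # u: [batch_size, in_caps, in_d]
        # prediction vectors u_hat: [batch_size, in_caps, out_caps, out_d]
        u_hat = torch.einsum('ijnm,bin->bijm', self.weight, u)
        # routing logits b, initialized to zero
        b = u.new_zeros(u.shape[0], u_hat.shape[1], u_hat.shape[2])
        for _ in range(self.iterations):
            c = b.softmax(dim=-1)                        # routing coefficients
            s = torch.einsum('bij,bijm->bjm', c, u_hat)  # weighted sum over input capsules
            v = squash_sketch(s)                         # squash to get output capsules
            b = b + torch.einsum('bijm,bjm->bij', u_hat, v)  # agreement update
        return v  # [batch_size, out_caps, out_d]
```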
This is the decoder mentioned in the paper. It takes the outputs of the $10$ digit capsules, each with $16$ features, and reproduces the image. It goes through linear layers of sizes $512$ and $1024$ with $\mathrm{ReLU}$ activations.

```python
        self.decoder = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )
```

`data` are the MNIST images, with shape `[batch_size, 1, 28, 28]`.

```python
    def forward(self, data: torch.Tensor):
```

Pass through the first convolution layer. The output of this layer has shape `[batch_size, 256, 20, 20]`.
```python
        x = F.relu(self.conv1(data))
```

Pass through the second convolution layer. The output of this layer has shape `[batch_size, 32 * 8, 6, 6]`. *Note that this layer has a stride length of $2$.*
```python
        x = self.conv2(x)
```

Reshape and permute to get the capsules.

```python
        caps = x.view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)
```

Squash the capsules.
```python
        caps = self.squash(caps)
```

Take them through the router to get digit capsules. This has shape `[batch_size, 10, 16]`.

```python
        caps = self.digit_capsules(caps)
```

Get masks for reconstruction.
```python
        with torch.no_grad():
```

The prediction made by the capsule network is the capsule with the longest length.

```python
            pred = (caps ** 2).sum(-1).argmax(-1)
```

Create a mask to mask out all the other capsules.

```python
            mask = torch.eye(10, device=data.device)[pred]
```
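Indexing the identity matrix with the predicted class indices gives one-hot rows; a quick illustration:

```python
pred = torch.tensor([3, 0])
mask = torch.eye(10)[pred]  # shape [2, 10]
# mask[0] is one-hot at index 3; mask[1] is one-hot at index 0
```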
Mask the digit capsules to get only the capsule that made the prediction, and take it through the decoder to get the reconstruction.

```python
        reconstructions = self.decoder((caps * mask[:, :, None]).view(x.shape[0], -1))
```

Reshape the reconstruction to match the image dimensions.
```python
        reconstructions = reconstructions.view(-1, 1, 28, 28)

        return caps, reconstructions, pred
```
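A quick shape check of the model, with zero tensors standing in for MNIST images (this snippet is illustrative and not part of the original code):

```python
import torch

model = MNISTCapsuleNetworkModel()
caps, reconstructions, pred = model(torch.zeros(4, 1, 28, 28))
print(caps.shape)             # torch.Size([4, 10, 16])
print(reconstructions.shape)  # torch.Size([4, 1, 28, 28])
print(pred.shape)             # torch.Size([4])
```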
Configurations with MNIST data and Train & Validation setup.

```python
class Configs(MNISTConfigs, SimpleTrainValidConfigs):
    epochs: int = 10
    model: nn.Module = 'capsule_network_model'
    reconstruction_loss = nn.MSELoss()
    margin_loss = MarginLoss(n_labels=10)
    accuracy = AccuracyDirect()

    def init(self):
```

Print losses and accuracy to screen.
```python
        tracker.set_scalar('loss.*', True)
        tracker.set_scalar('accuracy.*', True)
```

We need to set the metrics so that they are calculated for the epoch, for both training and validation.

```python
        self.state_modules = [self.accuracy]
```

This method gets called by the trainer.
```python
    def step(self, batch: Any, batch_idx: BatchIndex):
```

Set the model mode.

```python
        self.model.train(self.mode.is_train)
```

Get the images and labels and move them to the model's device.

```python
        data, target = batch[0].to(self.device), batch[1].to(self.device)
```

Increment step in training mode.
```python
        if self.mode.is_train:
            tracker.add_global_step(len(data))
```

Whether to log activations.

```python
        with self.mode.update(is_log_activations=batch_idx.is_last):
```

Run the model.
```python
            caps, reconstructions, pred = self.model(data)
```

Calculate the total loss: the margin loss plus the reconstruction loss, scaled down by $0.0005$ so that it does not dominate the margin loss.

```python
        loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
        tracker.add("loss.", loss)
```
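`MarginLoss` comes from `labml_nn.capsule_networks`. A sketch of the margin loss from the paper, $L_k = T_k \max(0, m^+ - \|\mathbf{v}_k\|)^2 + \lambda (1 - T_k) \max(0, \|\mathbf{v}_k\| - m^-)^2$, assuming the paper's values $m^+ = 0.9$, $m^- = 0.1$, $\lambda = 0.5$; the function name and the batch-mean reduction are illustrative:

```python
import torch

def margin_loss_sketch(caps: torch.Tensor, target: torch.Tensor, n_labels: int = 10,
                       m_pos: float = 0.9, m_neg: float = 0.1, lambda_: float = 0.5) -> torch.Tensor:
    # caps: [batch_size, n_labels, n_features]; target: [batch_size] class indices
    v_norm = torch.sqrt((caps ** 2).sum(dim=-1))           # capsule lengths ||v_k||
    t = torch.eye(n_labels, device=target.device)[target]  # one-hot labels T_k
    loss = (t * torch.relu(m_pos - v_norm) ** 2
            + lambda_ * (1 - t) * torch.relu(v_norm - m_neg) ** 2)
    return loss.sum(dim=-1).mean()  # sum over capsules, average over the batch
```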
Call the accuracy metric.

```python
        self.accuracy(pred, target)

        if self.mode.is_train:
            loss.backward()

            self.optimizer.step()
```

Log parameters and gradients.

```python
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

            tracker.save()
```

Set the model.
```python
@option(Configs.model)
def capsule_network_model(c: Configs):
    return MNISTCapsuleNetworkModel().to(c.device)
```

Run the experiment.
```python
def main():
    experiment.create(name='capsule_network_mnist')
    conf = Configs()
    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
                              'optimizer.learning_rate': 1e-3})

    experiment.add_pytorch_models({'model': conf.model})

    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
```