Classify MNIST digits with Capsule Networks

This is annotated PyTorch code for classifying MNIST digits with a Capsule Network.

It implements the experiment described in the paper Dynamic Routing Between Capsules.

from typing import Any

import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from labml import experiment, tracker
from labml.configs import option
from labml_nn.capsule_networks import Squash, Router, MarginLoss
from labml_nn.helpers.datasets import MNISTConfigs
from labml_nn.helpers.metrics import AccuracyDirect
from labml_nn.helpers.trainer import SimpleTrainValidConfigs, BatchIndex

Model for classifying MNIST digits

class MNISTCapsuleNetworkModel(nn.Module):
    def __init__(self):
        super().__init__()

The first convolution layer has 256 channels of 9 × 9 convolution kernels, with stride 1.

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)

The second layer (Primary Capsules) is a convolutional capsule layer with 32 channels of convolutional capsules (8 features per capsule). That is, each primary capsule contains 8 convolutional units with a 9 × 9 kernel and a stride of 2. In order to implement this we create a convolutional layer with 32 × 8 channels, and reshape and permute its output to get capsules of 8 features each.

        self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
        self.squash = Squash()
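
The Squash module (imported from labml_nn.capsule_networks) shrinks each capsule vector so that its length lies in [0, 1) while keeping its direction, so the length can act as a probability. As a rough sketch of what it computes, with illustrative names (squash_sketch, eps; the real implementation is labml_nn.capsule_networks.Squash):

import torch

def squash_sketch(s: torch.Tensor, dim: int = -1, eps: float = 1e-8) -> torch.Tensor:
    # ||s||^2 along the capsule feature dimension
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    # v = (||s||^2 / (1 + ||s||^2)) * (s / ||s||), as given in the paper
    return (squared_norm / (1 + squared_norm)) * s / torch.sqrt(squared_norm + eps)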

The routing layer gets the primary capsules and produces the digit capsules. Each of the 32 × 6 × 6 primary capsules has 8 features, while each of the 10 output capsules (Digit Capsules) has 16 features. The routing algorithm iterates 3 times.

        self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)
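
Router implements the routing-by-agreement algorithm from the paper. Below is a minimal sketch of that algorithm under illustrative names (RouterSketch and its internals are assumptions; squash_sketch is the sketch from above). The real implementation is labml_nn.capsule_networks.Router.

import torch
import torch.nn as nn
import torch.nn.functional as F

class RouterSketch(nn.Module):
    # in_caps capsules of in_d features are routed to out_caps capsules of out_d features
    def __init__(self, in_caps: int, out_caps: int, in_d: int, out_d: int, iterations: int):
        super().__init__()
        self.iterations = iterations
        # Transformation matrices W_ij mapping input capsule i to a prediction for output capsule j
        self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d) * 0.01)

    def forward(self, u: torch.Tensor) -> torch.Tensor:
        # u: [batch, in_caps, in_d] -> predictions u_hat: [batch, in_caps, out_caps, out_d]
        u_hat = torch.einsum('ijnm,bin->bijm', self.weight, u)
        # Routing logits b_ij start at zero
        b = u.new_zeros(u.shape[0], u.shape[1], self.weight.shape[1])
        for _ in range(self.iterations):
            c = F.softmax(b, dim=-1)                         # coupling coefficients over outputs
            s = torch.einsum('bij,bijm->bjm', c, u_hat)      # weighted sum of predictions
            v = squash_sketch(s)                             # squash to get output capsules
            b = b + torch.einsum('bijm,bjm->bij', u_hat, v)  # raise logits by agreement
        return v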

This is the decoder mentioned in the paper. It takes the outputs of the 10 digit capsules, each with 16 features, and reproduces the image. It goes through linear layers of sizes 512 and 1024 with ReLU activations, followed by a final linear layer of size 784 (28 × 28) with a sigmoid activation.

        self.decoder = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )
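
Note that the decoder input is the flattened 16 × 10 = 160 masked capsule features, and its output is a flattened 28 × 28 = 784 image. A quick check of this (illustrative):

import torch
import torch.nn as nn

decoder = nn.Sequential(
    nn.Linear(16 * 10, 512), nn.ReLU(),
    nn.Linear(512, 1024), nn.ReLU(),
    nn.Linear(1024, 784), nn.Sigmoid())
print(decoder(torch.zeros(1, 16 * 10)).shape)  # torch.Size([1, 784])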

data is a batch of MNIST images, with shape [batch_size, 1, 28, 28]

    def forward(self, data: torch.Tensor):

Pass through the first convolution layer. The output of this layer has shape [batch_size, 256, 20, 20] (with a 9 × 9 kernel and stride 1, 28 − 9 + 1 = 20).

        x = F.relu(self.conv1(data))

Pass through the second convolution layer. The output of this has shape [batch_size, 32 * 8, 6, 6]. Note that this layer has a stride length of 2 (with a 9 × 9 kernel, ⌊(20 − 9) / 2⌋ + 1 = 6).

        x = self.conv2(x)

Reshape and permute to get the capsules

        caps = x.view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)
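
A quick shape check of this view and permute (illustrative):

import torch

x = torch.zeros(4, 32 * 8, 6, 6)                  # conv2 output for a batch of 4
caps = x.view(4, 8, 32 * 6 * 6).permute(0, 2, 1)  # 32 * 6 * 6 = 1152 capsules of 8 features
print(caps.shape)  # torch.Size([4, 1152, 8])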

Squash the capsules

        caps = self.squash(caps)

Take them through the router to get digit capsules. This has shape [batch_size, 10, 16].

        caps = self.digit_capsules(caps)

Get masks for reconstruction

        with torch.no_grad():

The prediction of the capsule network is the capsule with the longest length

            pred = (caps ** 2).sum(-1).argmax(-1)

Create a mask to mask out all the other capsules

            mask = torch.eye(10, device=data.device)[pred]
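
Indexing an identity matrix with a tensor of class indices is a compact way to build one-hot masks. For example:

import torch

pred = torch.tensor([3, 0])
mask = torch.eye(10)[pred]  # [2, 10]; row 0 is one-hot at index 3, row 1 at index 0
print(mask[0].argmax().item(), mask[1].argmax().item())  # 3 0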

Mask the digit capsules to get only the capsule that made the prediction, and take it through the decoder to get the reconstruction

        reconstructions = self.decoder((caps * mask[:, :, None]).view(x.shape[0], -1))

Reshape the reconstruction to match the image dimensions

        reconstructions = reconstructions.view(-1, 1, 28, 28)

        return caps, reconstructions, pred
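
A quick smoke test of the output shapes (illustrative):

import torch

model = MNISTCapsuleNetworkModel()
caps, reconstructions, pred = model(torch.zeros(2, 1, 28, 28))
print(caps.shape, reconstructions.shape, pred.shape)
# torch.Size([2, 10, 16]) torch.Size([2, 1, 28, 28]) torch.Size([2])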

Configurations with MNIST data and Train & Validation setup

class Configs(MNISTConfigs, SimpleTrainValidConfigs):
    epochs: int = 10
    model: nn.Module = 'capsule_network_model'
    reconstruction_loss = nn.MSELoss()
    margin_loss = MarginLoss(n_labels=10)
    accuracy = AccuracyDirect()

    def init(self):

Print losses and accuracy to screen

        tracker.set_scalar('loss.*', True)
        tracker.set_scalar('accuracy.*', True)

We need to set the metric state modules so that they are computed for each epoch, for both training and validation

        self.state_modules = [self.accuracy]

This method is called by the trainer for each batch

    def step(self, batch: Any, batch_idx: BatchIndex):

Set the model mode

        self.model.train(self.mode.is_train)

Get the images and labels and move them to the model's device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Increment step in training mode

        if self.mode.is_train:
            tracker.add_global_step(len(data))

Run the model

        caps, reconstructions, pred = self.model(data)

Calculate the total loss: the margin loss plus the reconstruction loss. The reconstruction loss is scaled down by 0.0005 so that it doesn't dominate the margin loss, as in the paper.

        loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
        tracker.add("loss.", loss)
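
For reference, the margin loss from the paper is L_k = T_k max(0, m⁺ − ‖v_k‖)² + λ (1 − T_k) max(0, ‖v_k‖ − m⁻)² with m⁺ = 0.9, m⁻ = 0.1 and λ = 0.5, summed over the 10 digit capsules. A minimal sketch, assuming the illustrative name margin_loss_sketch (the real implementation is labml_nn.capsule_networks.MarginLoss):

import torch
import torch.nn.functional as F

def margin_loss_sketch(caps: torch.Tensor, target: torch.Tensor,
                       m_pos: float = 0.9, m_neg: float = 0.1, lam: float = 0.5) -> torch.Tensor:
    # caps: [batch, 10, 16], target: [batch] of class indices
    v_norm = torch.sqrt((caps ** 2).sum(-1))       # capsule lengths, [batch, 10]
    t = torch.eye(10, device=caps.device)[target]  # one-hot labels, [batch, 10]
    loss = t * F.relu(m_pos - v_norm) ** 2 + lam * (1 - t) * F.relu(v_norm - m_neg) ** 2
    return loss.sum(-1).mean()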

Call the accuracy metric

        self.accuracy(pred, target)

        if self.mode.is_train:
            loss.backward()

            self.optimizer.step()

Log parameters and gradients

            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

            tracker.save()

Set the model

@option(Configs.model)
def capsule_network_model(c: Configs):
    return MNISTCapsuleNetworkModel().to(c.device)

Run the experiment

def main():
    experiment.create(name='capsule_network_mnist')
    conf = Configs()
    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
                              'optimizer.learning_rate': 1e-3})

    experiment.add_pytorch_models({'model': conf.model})

    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()