This is a PyTorch implementation/tutorial of Dynamic Routing Between Capsules.
A capsule network is a neural network architecture that embeds features as capsules and routes them, with a voting mechanism, to the next layer of capsules.
Unlike our other model implementations, this one includes a sample, because some concepts are difficult to understand from the modules alone. This is the annotated code for a model that uses capsules to classify the MNIST dataset.
This file holds the implementations of the core modules of Capsule Networks.
I used jindongwang/Pytorch-CapsuleNet to clear up some of my confusion about the paper.
Here's a notebook for training a Capsule Network on the MNIST dataset.
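```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
```

This is the squashing function from the paper, given by equation (1):

$$\mathbf{v}_j = \frac{\lVert \mathbf{s}_j \rVert^2}{1 + \lVert \mathbf{s}_j \rVert^2} \frac{\mathbf{s}_j}{\lVert \mathbf{s}_j \rVert}$$

$\frac{\mathbf{s}_j}{\lVert \mathbf{s}_j \rVert}$ normalizes the length of all the capsules to one. The factor $\frac{\lVert \mathbf{s}_j \rVert^2}{1 + \lVert \mathbf{s}_j \rVert^2}$ then sets the new length, which is always below one. Capsules much shorter than one shrink to almost zero length, and long capsules end up with a length just below one.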
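To see the effect of the two factors numerically, here is a small standalone sketch. The name `squash` is our own (hypothetical) functional version of the squashing step. It applies the formula above along the last dimension to three 2-D capsules with lengths 0.1, 1 and 10.

```python
import torch


def squash(s: torch.Tensor, epsilon: float = 1e-8) -> torch.Tensor:
    """Hypothetical functional form of the squashing formula, over the last dim."""
    # Squared length ||s_j||^2, kept as a trailing dimension for broadcasting
    s2 = (s ** 2).sum(dim=-1, keepdim=True)
    # Length factor ||s||^2 / (1 + ||s||^2) times the unit vector s / ||s||
    return (s2 / (1 + s2)) * (s / torch.sqrt(s2 + epsilon))


# Three 2-D capsules with lengths 0.1, 1 and 10
s = torch.tensor([[0.1, 0.0], [0.0, 1.0], [6.0, 8.0]])
lengths = squash(s).norm(dim=-1)
print(lengths)  # roughly [0.0099, 0.5, 0.99]
```

The capsule of length 0.1 comes out with length $0.01 / 1.01 \approx 0.0099$, length 1 becomes $0.5$, and length 10 becomes $100 / 101 \approx 0.99$. The direction of each capsule is unchanged.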
```python
class Squash(nn.Module):
    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, s: torch.Tensor):
        # The shape of `s` is [batch_size, n_capsules, n_features]
        # Squared length ||s_j||^2, kept as a trailing dimension for broadcasting
        s2 = (s ** 2).sum(dim=-1, keepdim=True)
        # We add an epsilon when calculating ||s_j|| to make sure it doesn't become zero.
        # If it becomes zero, the division produces `nan` values and training fails.
        return (s2 / (1 + s2)) * (s / torch.sqrt(s2 + self.epsilon))
```

This is the routing mechanism described in the paper. You can use multiple routing layers in your models.

This module combines calculating $\mathbf{s}_j$ for this layer with the routing algorithm described in Procedure 1.
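As a rough, unbatched sketch of Procedure 1, assume the predictions $\hat{\mathbf{u}}_{j|i}$ have already been computed. The routing loop can then be written with explicit loops over capsules. The helper names `squash` and `route` and the tensor layout `[in_caps, out_caps, out_d]` are assumptions made for illustration. The softmax is taken over the capsules $j$ in this layer, so that each lower capsule's coupling coefficients sum to one, as in the paper.

```python
import torch


def squash(s: torch.Tensor, epsilon: float = 1e-8) -> torch.Tensor:
    # Squashing non-linearity applied along the last dimension
    s2 = (s ** 2).sum(dim=-1, keepdim=True)
    return (s2 / (1 + s2)) * (s / torch.sqrt(s2 + epsilon))


def route(u_hat: torch.Tensor, iterations: int) -> torch.Tensor:
    """Hypothetical loop-based sketch of Procedure 1 for one example.

    u_hat has shape [in_caps, out_caps, out_d]; u_hat[i, j] is the
    prediction of lower capsule i for upper capsule j.
    """
    in_caps, out_caps, _ = u_hat.shape
    # Log priors b_ij start at zero
    b = torch.zeros(in_caps, out_caps)
    v = None
    for _ in range(iterations):
        # Coupling coefficients c_ij: softmax over the upper capsules j
        c = torch.softmax(b, dim=1)
        # s_j = sum_i c_ij * u_hat_{j|i}
        s = torch.stack([(c[:, j, None] * u_hat[:, j]).sum(dim=0) for j in range(out_caps)])
        # v_j = squash(s_j)
        v = squash(s)
        # Agreement u_hat_{j|i} . v_j is added to the logits b_ij
        for i in range(in_caps):
            for j in range(out_caps):
                b[i, j] += torch.dot(u_hat[i, j], v[j])
    return v


torch.manual_seed(0)
u_hat = torch.randn(6, 3, 4)  # 6 lower capsules, 3 upper capsules, 4 features
v = route(u_hat, iterations=3)
print(v.shape)  # torch.Size([3, 4])
```

The explicit loops are slow. They are only meant to make each step of the procedure visible, and a batched version would replace them with `torch.einsum`.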
```python
class Router(nn.Module):
    def __init__(self, in_caps: int, out_caps: int, in_d: int, out_d: int, iterations: int):
        """
        * `in_caps` is the number of capsules, and `in_d` is the number of
          features per capsule, from the layer below.
        * `out_caps` and `out_d` are the same for this layer.
        * `iterations` is the number of routing iterations, symbolized by r in the paper.
        """
        super().__init__()
        self.in_caps = in_caps
        self.out_caps = out_caps
        self.iterations = iterations
        # Softmax over the capsules j in this layer (the last dimension of b)
        self.softmax = nn.Softmax(dim=-1)
        self.squash = Squash()

        # This is the weight matrix W_ij. It maps each capsule in the lower layer
        # to each capsule in this layer.
        self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d), requires_grad=True)

    def forward(self, u: torch.Tensor):
        """
        The shape of `u` is [batch_size, n_capsules, n_features].
        These are the capsules from the lower layer.
        """
        # Here j is used to index capsules in this layer, whilst i is used to
        # index capsules in the layer below (previous).
        # u_hat_{j|i} = W_ij u_i is the prediction of capsule i for capsule j
        u_hat = torch.einsum('ijnm,bin->bijm', self.weight, u)

        # Initial logits b_ij are the log prior probabilities that capsule i
        # should be coupled with capsule j. We initialize these at zero.
        b = u.new_zeros(u.shape[0], self.in_caps, self.out_caps)

        v = None

        # Iterate
        for _ in range(self.iterations):
            # Routing softmax: c_ij = exp(b_ij) / sum_k exp(b_ik)
            c = self.softmax(b)
            # s_j = sum_i c_ij u_hat_{j|i}
            s = torch.einsum('bij,bijm->bjm', c, u_hat)
            # v_j = squash(s_j)
            v = self.squash(s)
            # Agreement a_ij = v_j . u_hat_{j|i}
            a = torch.einsum('bjm,bijm->bij', v, u_hat)
            # Update the logits: b_ij <- b_ij + a_ij
            b = b + a

        return v
```

A separate margin loss is used for each output capsule, and the total loss is their sum. The length of each output capsule is the probability that its class is present in the input.
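The loss for each output capsule or class $k$ is

$$\mathcal{L}_k = T_k \max(0, m^{+} - \lVert\mathbf{v}_k\rVert)^2 + \lambda (1 - T_k) \max(0, \lVert\mathbf{v}_k\rVert - m^{-})^2$$

$T_k$ is $1$ if the class $k$ is present and $0$ otherwise. The first component of the loss is $0$ when the class is not present, and the second component is $0$ when the class is present. The $\max(0, \cdot)$ keeps predictions from being pushed to the extremes: a present class is only penalized while $\lVert\mathbf{v}_k\rVert < m^{+}$, and an absent class only while $\lVert\mathbf{v}_k\rVert > m^{-}$. In the paper, $m^{+}$ is set to $0.9$ and $m^{-}$ to $0.1$.

The $\lambda$ down-weighting ($0.5$ in the paper) stops the lengths of all the capsules from shrinking during the initial phase of training.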
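Here is a worked example of the loss for a single input with three classes, using the paper's values $m^{+} = 0.9$, $m^{-} = 0.1$ and $\lambda = 0.5$. The capsule lengths below are made-up numbers chosen to keep the arithmetic easy to follow, and the true class is assumed to be class 0.

```python
import torch
import torch.nn.functional as F

m_pos, m_neg, lambda_ = 0.9, 0.1, 0.5
# Made-up lengths ||v_k|| of three output capsules for one example
v_norm = torch.tensor([0.95, 0.3, 0.05])
# T_k: only class 0 is present
t = torch.tensor([1.0, 0.0, 0.0])

# L_k with squared hinge terms, as in the equation above
per_class = t * F.relu(m_pos - v_norm) ** 2 + lambda_ * (1 - t) * F.relu(v_norm - m_neg) ** 2
print(per_class)        # class 0: 0, class 1: 0.5 * 0.2**2 = 0.02, class 2: 0
print(per_class.sum())  # about 0.02
```

Only class 1 contributes. It is absent, but its capsule is longer than $m^{-}$, which gives $0.5 \times (0.3 - 0.1)^2 = 0.02$. Class 0 is already longer than $m^{+}$, and class 2 is already shorter than $m^{-}$, so both contribute zero.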
```python
class MarginLoss(nn.Module):
    def __init__(self, *, n_labels: int, lambda_: float = 0.5, m_positive: float = 0.9, m_negative: float = 0.1):
        super().__init__()

        self.m_negative = m_negative
        self.m_positive = m_positive
        self.lambda_ = lambda_
        self.n_labels = n_labels

    def forward(self, v: torch.Tensor, labels: torch.Tensor):
        """
        `v`, the v_j, are the squashed output capsules. This has shape
        [batch_size, n_labels, n_features]; that is, there is a capsule for each label.

        `labels` are the labels, and have shape [batch_size].
        """
        # ||v_j||, the length of each output capsule
        v_norm = torch.sqrt((v ** 2).sum(dim=-1))

        # `labels` becomes one-hot encoded labels of shape [batch_size, n_labels]
        labels = torch.eye(self.n_labels, device=labels.device)[labels]

        # `loss` has shape [batch_size, n_labels]. We have parallelized the
        # computation of L_k for all k. The hinge terms are squared, following
        # the margin loss in the paper.
        loss = labels * F.relu(self.m_positive - v_norm) ** 2 + \
               self.lambda_ * (1.0 - labels) * F.relu(v_norm - self.m_negative) ** 2

        # Sum over the classes and take the mean over the batch
        return loss.sum(dim=-1).mean()
```
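The length of each output capsule is read as the probability that its class is present, so a natural way to make predictions is to pick the class with the longest capsule. The sketch below assumes that rule and uses hand-made output capsules of shape `[batch_size, n_labels, n_features]`. It is an illustration only, not necessarily how the MNIST notebook computes accuracy.

```python
import torch

# Hand-made output capsules: 2 examples, 3 classes, 4 features each
v = torch.zeros(2, 3, 4)
v[0, 1, 0] = 0.8   # example 0: the capsule for class 1 has length 0.8
v[1, 2, :] = 0.3   # example 1: the capsule for class 2 has length 0.6
v[1, 0, 0] = 0.2   # example 1: the capsule for class 0 has length 0.2

# Capsule lengths ||v_k||, shape [batch_size, n_labels]
lengths = v.norm(dim=-1)
# Assumed prediction rule: the class whose capsule is longest
predictions = lengths.argmax(dim=-1)
print(predictions)  # tensor([1, 2])

labels = torch.tensor([1, 0])
accuracy = (predictions == labels).float().mean()
print(accuracy)  # tensor(0.5000)
```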