Files
Varuna Jayasiri c4d2e8cd22 docs
2025-07-31 08:48:07 +05:30

585 lines
46 KiB
HTML
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="Code for training Capsule Networks on MNIST dataset"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Classify MNIST digits with Capsule Networks"/>
<meta name="twitter:description" content="Code for training Capsule Networks on MNIST dataset"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/capsule_networks/mnist.html"/>
<meta property="og:title" content="Classify MNIST digits with Capsule Networks"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Classify MNIST digits with Capsule Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Classify MNIST digits with Capsule Networks"/>
<meta property="og:description" content="Code for training Capsule Networks on MNIST dataset"/>
<title>Classify MNIST digits with Capsule Networks</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/capsule_networks/mnist.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
// Reuse the existing global dataLayer queue if one is already defined; otherwise start an empty one.
window.dataLayer = window.dataLayer || [];
// Queue a command by pushing the call's `arguments` object onto dataLayer as a single entry.
function gtag() {
dataLayer.push(arguments);
}
// Queue the 'js' command with the current timestamp.
gtag('js', new Date());
// Queue the 'config' command for the same ID that the gtag.js loader script above requests.
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">capsule_networks</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/capsule_networks/mnist.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Classify MNIST digits with Capsule Networks</h1>
<p>This is annotated PyTorch code to classify MNIST digits with Capsule Networks.</p>
<p>This code implements the experiment described in the paper <a href="https://arxiv.org/abs/1710.09829">Dynamic Routing Between Capsules</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">14</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span>
<span class="lineno">15</span>
<span class="lineno">16</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="lineno">17</span><span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
<span class="lineno">18</span><span class="kn">import</span> <span class="nn">torch.utils.data</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span><span class="p">,</span> <span class="n">tracker</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml_nn.capsule_networks</span> <span class="kn">import</span> <span class="n">Squash</span><span class="p">,</span> <span class="n">Router</span><span class="p">,</span> <span class="n">MarginLoss</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_nn.helpers.datasets</span> <span class="kn">import</span> <span class="n">MNISTConfigs</span>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_nn.helpers.metrics</span> <span class="kn">import</span> <span class="n">AccuracyDirect</span>
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml_nn.helpers.trainer</span> <span class="kn">import</span> <span class="n">SimpleTrainValidConfigs</span><span class="p">,</span> <span class="n">BatchIndex</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Model for classifying MNIST digits</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">27</span><span class="k">class</span> <span class="nc">MNISTCapsuleNetworkModel</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">32</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">33</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>First convolution layer has <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eql" style=""><span class="mord" style="">2</span></span><span class="mord">56</span></span></span></span></span>, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">9</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">9</span></span></span></span></span> convolution kernels </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">35</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">9</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>The second layer (Primary Capsules) is a convolutional capsule layer with <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord coloredeq eqm" style="">3</span></span><span class="mord" style=""><span class="mord coloredeq eql" style="">2</span></span></span></span></span></span></span> channels of convolutional <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord" style="">8</span></span><span class="mord mathnormal" style="margin-right:0.02778em;">D</span></span></span></span></span> capsules (<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord" style="">8</span></span></span></span></span></span> features per capsule). That is, each primary capsule contains 8 convolutional units with a 9 × 9 kernel and a stride of 2. 
In order to implement this we create a convolutional layer with <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord coloredeq eqm" style="">3</span></span><span class="mord" style=""><span class="mord coloredeq eql" style="">2</span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord" style="">8</span></span></span></span></span></span> channels and reshape and permute its output to get the capsules of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord" style="">8</span></span></span></span></span></span> features each. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">41</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">32</span> <span class="o">*</span> <span class="mi">8</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">9</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="lineno">42</span> <span class="bp">self</span><span class="o">.</span><span class="n">squash</span> <span class="o">=</span> <span class="n">Squash</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Routing layer gets the <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord coloredeq eqm" style="">3</span></span><span class="mord" style=""><span class="mord coloredeq eql" style="">2</span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">6</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">6</span></span></span></span></span> primary capsules and produces <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style="">10</span></span></span></span></span></span> capsules. 
Each of the primary capsules has <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord" style="">8</span></span></span></span></span></span> features, while output capsules (Digit Capsules) have <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqi" style=""><span class="mord" style="">16</span></span></span></span></span></span> features. The routing algorithm iterates <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqm" style=""><span class="mord" style="">3</span></span></span></span></span></span> times. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">48</span> <span class="bp">self</span><span class="o">.</span><span class="n">digit_capsules</span> <span class="o">=</span> <span class="n">Router</span><span class="p">(</span><span class="mi">32</span> <span class="o">*</span> <span class="mi">6</span> <span class="o">*</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>This is the decoder mentioned in the paper. It takes the outputs of the <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style="">10</span></span></span></span></span></span> digit capsules, each with <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqi" style=""><span class="mord" style="">16</span></span></span></span></span></span> features to reproduce the image. It goes through linear layers of sizes <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">51</span><span class="mord coloredeq eql" style=""><span class="mord" style="">2</span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style="">10</span></span><span class="mord coloredeq eql" style=""><span class="mord" style="">2</span></span><span class="mord">4</span></span></span></span></span> with <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.00773em;">R</span><span class="mord mathnormal">e</span><span class="mord mathnormal" style="margin-right:0.10903em;">LU</span></span></span></span></span> activations. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">53</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="lineno">54</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">16</span> <span class="o">*</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">512</span><span class="p">),</span>
<span class="lineno">55</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
<span class="lineno">56</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">1024</span><span class="p">),</span>
<span class="lineno">57</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
<span class="lineno">58</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">1024</span><span class="p">,</span> <span class="mi">784</span><span class="p">),</span>
<span class="lineno">59</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sigmoid</span><span class="p">()</span>
<span class="lineno">60</span> <span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p> <code class="highlight"><span></span><span class="n">data</span></code>
are the MNIST images, with shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Pass through the first convolution layer. Output of this layer has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">20</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">data</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Pass through the second convolution layer. Output of this has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">6</span><span class="p">]</span></code>
. <em>Note that this layer has a stride length of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eql" style=""><span class="mord" style="">2</span></span></span></span></span></span></em>. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">72</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Reshape and permute to get the capsules </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">75</span> <span class="n">caps</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">6</span> <span class="o">*</span> <span class="mi">6</span><span class="p">)</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Squash the capsules </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">77</span> <span class="n">caps</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">squash</span><span class="p">(</span><span class="n">caps</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Take them through the router to get digit capsules. This has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">16</span><span class="p">]</span></code>
. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">80</span> <span class="n">caps</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">digit_capsules</span><span class="p">(</span><span class="n">caps</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Get masks for reconstruction </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">83</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>The prediction by the capsule network is the capsule with the longest length </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">85</span> <span class="n">pred</span> <span class="o">=</span> <span class="p">(</span><span class="n">caps</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Create a mask to mask out all the other capsules </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">87</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">data</span><span class="o">.</span><span class="n">device</span><span class="p">)[</span><span class="n">pred</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Mask the digit capsules to get only the capsule that made the prediction and take it through decoder to get reconstruction </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span> <span class="n">reconstructions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">((</span><span class="n">caps</span> <span class="o">*</span> <span class="n">mask</span><span class="p">[:,</span> <span class="p">:,</span> <span class="kc">None</span><span class="p">])</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Reshape the reconstruction to match the image dimensions </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">93</span> <span class="n">reconstructions</span> <span class="o">=</span> <span class="n">reconstructions</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">)</span>
<span class="lineno">94</span>
<span class="lineno">95</span> <span class="k">return</span> <span class="n">caps</span><span class="p">,</span> <span class="n">reconstructions</span><span class="p">,</span> <span class="n">pred</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p> Configurations with MNIST data and Train &amp; Validation setup</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">98</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">MNISTConfigs</span><span class="p">,</span> <span class="n">SimpleTrainValidConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">102</span> <span class="n">epochs</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span>
<span class="lineno">103</span> <span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span> <span class="o">=</span> <span class="s1">&#39;capsule_network_model&#39;</span>
<span class="lineno">104</span> <span class="n">reconstruction_loss</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MSELoss</span><span class="p">()</span>
<span class="lineno">105</span> <span class="n">margin_loss</span> <span class="o">=</span> <span class="n">MarginLoss</span><span class="p">(</span><span class="n">n_labels</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="lineno">106</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">AccuracyDirect</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">108</span> <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Print losses and accuracy to screen </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s1">&#39;loss.*&#39;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
<span class="lineno">111</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s1">&#39;accuracy.*&#39;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>We need to set the metrics to calculate them for the epoch for training and validation </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">114</span> <span class="bp">self</span><span class="o">.</span><span class="n">state_modules</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">accuracy</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p> This method gets called by the trainer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="n">BatchIndex</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Set the model mode </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">121</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Get the images and labels and move them to the model&#x27;s device </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Increment step in training mode </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">127</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
<span class="lineno">128</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Run the model </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">131</span> <span class="n">caps</span><span class="p">,</span> <span class="n">reconstructions</span><span class="p">,</span> <span class="n">pred</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Calculate the total loss </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">134</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">margin_loss</span><span class="p">(</span><span class="n">caps</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="o">+</span> <span class="mf">0.0005</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">reconstruction_loss</span><span class="p">(</span><span class="n">reconstructions</span><span class="p">,</span> <span class="n">data</span><span class="p">)</span>
<span class="lineno">135</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;loss.&quot;</span><span class="p">,</span> <span class="n">loss</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Call accuracy metric </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">138</span> <span class="bp">self</span><span class="o">.</span><span class="n">accuracy</span><span class="p">(</span><span class="n">pred</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
<span class="lineno">139</span>
<span class="lineno">140</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
<span class="lineno">141</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="lineno">142</span>
<span class="lineno">143</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Log parameters and gradients </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">145</span> <span class="k">if</span> <span class="n">batch_idx</span><span class="o">.</span><span class="n">is_last</span><span class="p">:</span>
<span class="lineno">146</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;model&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
<span class="lineno">147</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="lineno">148</span>
<span class="lineno">149</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Set the model </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">152</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
<span class="lineno">153</span><span class="k">def</span> <span class="nf">capsule_network_model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">155</span> <span class="k">return</span> <span class="n">MNISTCapsuleNetworkModel</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p> Run the experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">158</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;capsule_network_mnist&#39;</span><span class="p">)</span>
<span class="lineno">163</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span>
<span class="lineno">164</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span><span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Adam&#39;</span><span class="p">,</span>
<span class="lineno">165</span> <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">1e-3</span><span class="p">})</span>
<span class="lineno">166</span>
<span class="lineno">167</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">&#39;model&#39;</span><span class="p">:</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">})</span>
<span class="lineno">168</span>
<span class="lineno">169</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
<span class="lineno">170</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span>
<span class="lineno">171</span>
<span class="lineno">172</span>
<span class="lineno">173</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">174</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../interactive.js?v=1"></script>
<script>
// Hand every image that is a direct child of a <p> element to handleImage,
// in document order.
function handleImages() {
    var nodes = document.querySelectorAll('p>img')
    var idx = 0
    while (idx < nodes.length) {
        handleImage(nodes[idx])
        idx += 1
    }
}
/**
 * Make an image open in a modal element when it is clicked.
 *
 * Centre-aligns the image's parent element, then builds a detached
 * `div#modal` containing a wrapper div with an enlarged copy of the image
 * and a `span.close` control. Clicking the image appends the modal to
 * `document.body` and copies the image's `src` into it; clicking the close
 * control detaches the modal again.
 *
 * NOTE(review): the modal's appearance is assumed to come from the page
 * stylesheet (`#modal` / `.close` rules) — confirm against pylit.css.
 *
 * @param {HTMLImageElement} img  Image to enhance; it must have a parent element.
 */
function handleImage(img) {
    img.parentElement.style.textAlign = 'center'

    var modal = document.createElement('div')
    modal.id = 'modal'

    var modalContent = document.createElement('div')
    modal.appendChild(modalContent)

    var modalImage = document.createElement('img')
    modalContent.appendChild(modalImage)

    var span = document.createElement('span')
    span.classList.add('close')
    span.textContent = 'x'
    modal.appendChild(span)

    // addEventListener instead of assigning `onclick`, so a click handler
    // registered elsewhere on the same image is not overwritten.
    img.addEventListener('click', function () {
        document.body.appendChild(modal)
        // Copy the source at click time so a later change to img.src is picked up.
        modalImage.src = img.src
    })

    span.addEventListener('click', function () {
        // removeChild throws when the node is not a child, so only detach
        // the modal if it is currently attached.
        if (modal.parentNode) {
            modal.parentNode.removeChild(modal)
        }
    })
}
// Run once when this inline script executes; only images matched by the
// query at that moment receive a click handler.
handleImages()
</script>
</body>
</html>