Varuna Jayasiri f1fe7087f1 footer
2021-08-19 15:21:18 +05:30


<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="PyTorch implementation and tutorial of Capsule Networks. A capsule network is a neural network architecture that embeds features as capsules and routes them with a voting mechanism to the next layer of capsules."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Capsule Networks"/>
<meta name="twitter:description" content="PyTorch implementation and tutorial of Capsule Networks. A capsule network is a neural network architecture that embeds features as capsules and routes them with a voting mechanism to the next layer of capsules."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/capsule_networks/index.html"/>
<meta property="og:title" content="Capsule Networks"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="PyTorch implementation and tutorial of Capsule Networks. A capsule network is a neural network architecture that embeds features as capsules and routes them with a voting mechanism to the next layer of capsules."/>
<title>Capsule Networks</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/capsule_networks/index.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/capsule_networks/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Capsule Networks</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation/tutorial of
<a href="https://papers.labml.ai/paper/1710.09829">Dynamic Routing Between Capsules</a>.</p>
<p>A capsule network is a neural network architecture that embeds features
as capsules and routes them to the next layer of capsules with a voting mechanism.</p>
<p>Unlike our other model implementations, this one includes a sample, because
some of the concepts are difficult to understand from the modules alone.
<a href="mnist.html">This is the annotated code for a model that uses capsules to classify the MNIST dataset</a>.</p>
<p>This file holds the implementations of the core modules of Capsule Networks.</p>
<p>I used <a href="https://github.com/jindongwang/Pytorch-CapsuleNet">jindongwang/Pytorch-CapsuleNet</a> to clarify some
points in the paper that I found confusing.</p>
<p>Here&rsquo;s a notebook for training a Capsule Network on the MNIST dataset.</p>
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/capsule_networks/mnist.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
<a href="https://app.labml.ai/run/e7c08e08586711ebb3e30242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">33</span><span></span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="lineno">34</span><span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
<span class="lineno">35</span><span class="kn">import</span> <span class="nn">torch.utils.data</span>
<span class="lineno">36</span>
<span class="lineno">37</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Squash</h2>
<p>This is the <strong>squashing</strong> function from the paper, given by equation $(1)$.</p>
<p>
<script type="math/tex; mode=display">\mathbf{v}_j = \frac{{\lVert \mathbf{s}_j \rVert}^2}{1 + {\lVert \mathbf{s}_j \rVert}^2}
\frac{\mathbf{s}_j}{\lVert \mathbf{s}_j \rVert}</script>
</p>
<p>$\frac{\mathbf{s}_j}{\lVert \mathbf{s}_j \rVert}$
normalizes every capsule to unit length, whilst
$\frac{{\lVert \mathbf{s}_j \rVert}^2}{1 + {\lVert \mathbf{s}_j \rVert}^2}$
rescales that length to lie between zero and one: short capsules shrink to almost zero
and long capsules to slightly below one.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">40</span><span class="k">class</span> <span class="nc">Squash</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">55</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-8</span><span class="p">):</span>
<span class="lineno">56</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">57</span> <span class="bp">self</span><span class="o">.</span><span class="n">epsilon</span> <span class="o">=</span> <span class="n">epsilon</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>The shape of <code>s</code> is <code>[batch_size, n_capsules, n_features]</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">59</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">s</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>${\lVert \mathbf{s}_j \rVert}^2$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span> <span class="n">s2</span> <span class="o">=</span> <span class="p">(</span><span class="n">s</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>We add an epsilon under the square root when calculating $\lVert \mathbf{s}_j \rVert$ so that the denominator never becomes zero.
A zero denominator produces <code>nan</code> values and training fails.
<script type="math/tex; mode=display">\mathbf{v}_j = \frac{{\lVert \mathbf{s}_j \rVert}^2}{1 + {\lVert \mathbf{s}_j \rVert}^2}
\frac{\mathbf{s}_j}{\sqrt{{\lVert \mathbf{s}_j \rVert}^2 + \epsilon}}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">71</span> <span class="k">return</span> <span class="p">(</span><span class="n">s2</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">s2</span><span class="p">))</span> <span class="o">*</span> <span class="p">(</span><span class="n">s</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">s2</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">epsilon</span><span class="p">))</span></pre></div>
</div>
</div>
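<div class='section'>
<div class='docs'>
<p>To check the squashing formula by hand, here is a minimal dependency-free sketch in plain Python
that applies the epsilon-stabilized equation to a single capsule vector.
The helper names <code>squash</code> and <code>length</code> are illustrative and are not part of the module above;
the sketch assumes one capsule is given as a list of floats.
It shows a short capsule shrinking to almost zero, a long capsule landing just below one,
and the all-zero capsule staying finite thanks to $\epsilon$.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import math


def squash(s, epsilon=1e-8):
    """Squash one capsule vector `s` (a list of floats) with the epsilon-stabilized formula."""
    s2 = sum(x * x for x in s)        # squared length of s
    scale = s2 / (1.0 + s2)           # length-dependent factor, always in [0, 1)
    norm = math.sqrt(s2 + epsilon)    # epsilon keeps the division finite when s is all zeros
    return [scale * x / norm for x in s]


def length(v):
    """Euclidean length of a vector given as a list of floats."""
    return math.sqrt(sum(x * x for x in v))


short = squash([0.1, 0.0])     # length 0.1 before squashing
long_ = squash([30.0, 40.0])   # length 50 before squashing
zero = squash([0.0, 0.0])      # would be 0 / 0 without epsilon

print(length(short))   # roughly 0.0099, that is 0.01 / 1.01
print(length(long_))   # roughly 0.9996, that is 2500 / 2501
print(zero)            # [0.0, 0.0]
```
</pre></div>
</div>
</div>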
<div class='section' id='section-6'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<h2>Routing Algorithm</h2>
<p>This is the routing mechanism described in the paper.
You can use multiple routing layers in your models.</p>
<p>This combines calculating $\mathbf{s}_j$ for this layer and
the routing algorithm described in <em>Procedure 1</em>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">74</span><span class="k">class</span> <span class="nc">Router</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p><code>in_caps</code> is the number of capsules, and <code>in_d</code> is the number of features per capsule, in the layer below.
<code>out_caps</code> and <code>out_d</code> are the corresponding values for this layer.</p>
<p><code>iterations</code> is the number of routing iterations, symbolized by $r$ in the paper.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">85</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_caps</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_caps</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">in_d</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_d</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">iterations</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">92</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">93</span> <span class="bp">self</span><span class="o">.</span><span class="n">in_caps</span> <span class="o">=</span> <span class="n">in_caps</span>
<span class="lineno">94</span> <span class="bp">self</span><span class="o">.</span><span class="n">out_caps</span> <span class="o">=</span> <span class="n">out_caps</span>
<span class="lineno">95</span> <span class="bp">self</span><span class="o">.</span><span class="n">iterations</span> <span class="o">=</span> <span class="n">iterations</span>
<span class="lineno">96</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">97</span> <span class="bp">self</span><span class="o">.</span><span class="n">squash</span> <span class="o">=</span> <span class="n">Squash</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>This is the weight matrix $\mathbf{W}_{ij}$. It maps each capsule in the
lower layer to each capsule in this layer.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">101</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">in_caps</span><span class="p">,</span> <span class="n">out_caps</span><span class="p">,</span> <span class="n">in_d</span><span class="p">,</span> <span class="n">out_d</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>The shape of <code>u</code> is <code>[batch_size, n_capsules, n_features]</code>.
These are the capsules from the lower layer.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">103</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">u</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>
<script type="math/tex; mode=display">\hat{\mathbf{u}}_{j|i} = \mathbf{W}_{ij} \mathbf{u}_i</script>
Here $j$ is used to index capsules in this layer, whilst $i$ is
used to index capsules in the layer below (previous).</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">112</span> <span class="n">u_hat</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;ijnm,bin-&gt;bijm&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="n">u</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Initial logits $b_{ij}$ are the log prior probabilities that capsule $i$
should be coupled with $j$.
We initialize these at zero.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">117</span> <span class="n">b</span> <span class="o">=</span> <span class="n">u</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">u</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">in_caps</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">out_caps</span><span class="p">)</span>
<span class="lineno">118</span>
<span class="lineno">119</span> <span class="n">v</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Iterate</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">iterations</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Routing softmax, normalized over the capsules $j$ in this layer <script type="math/tex; mode=display">c_{ij} = \frac{\exp({b_{ij}})}{\sum_k\exp({b_{ik}})}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="n">c</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">b</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>
<script type="math/tex; mode=display">\mathbf{s}_j = \sum_i{c_{ij} \hat{\mathbf{u}}_{j|i}}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">126</span> <span class="n">s</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;bij,bijm-&gt;bjm&#39;</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">u_hat</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>
<script type="math/tex; mode=display">\mathbf{v}_j = squash(\mathbf{s}_j)</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">128</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">squash</span><span class="p">(</span><span class="n">s</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>
<script type="math/tex; mode=display">a_{ij} = \mathbf{v}_j \cdot \hat{\mathbf{u}}_{j|i}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">130</span> <span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;bjm,bijm-&gt;bij&#39;</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">u_hat</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>
<script type="math/tex; mode=display">b_{ij} \gets b_{ij} + \mathbf{v}_j \cdot \hat{\mathbf{u}}_{j|i}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">132</span> <span class="n">b</span> <span class="o">=</span> <span class="n">b</span> <span class="o">+</span> <span class="n">a</span>
<span class="lineno">133</span>
<span class="lineno">134</span> <span class="k">return</span> <span class="n">v</span></pre></div>
</div>
</div>
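<div class='section'>
<div class='docs'>
<p>The routing loop can be traced on a tiny example without tensors.
The following is a sketch in plain Python, not the module above: the function <code>route</code> and
its helpers are hypothetical names, it handles a single sample rather than a batch, and it starts from
prediction vectors $\hat{\mathbf{u}}_{j|i}$ that are assumed to be already computed.
The softmax is taken over the capsules $j$ in this layer, following the formula for $c_{ij}$.
In the example, all three lower capsules agree on their prediction for capsule $0$ and disagree on capsule $1$,
so the coupling coefficients drift towards capsule $0$.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import math


def squash(s, epsilon=1e-8):
    s2 = sum(x * x for x in s)
    return [(s2 / (1.0 + s2)) * x / math.sqrt(s2 + epsilon) for x in s]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(logits):
    m = max(logits)                              # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def route(u_hat, iterations=3):
    """u_hat[i][j] is the prediction of lower capsule i for capsule j in this layer.

    Returns the squashed capsules v and the coupling coefficients c of the last iteration.
    """
    n_in, n_out, d = len(u_hat), len(u_hat[0]), len(u_hat[0][0])
    b = [[0.0] * n_out for _ in range(n_in)]     # logits b_ij start at zero
    v, c = None, None
    for _ in range(iterations):
        # c_ij: softmax of b_ij over the capsules j in this layer
        c = [softmax(row) for row in b]
        # s_j = sum_i c_ij * u_hat_{j|i}
        s = [[sum(c[i][j] * u_hat[i][j][k] for i in range(n_in)) for k in range(d)]
             for j in range(n_out)]
        # v_j = squash(s_j)
        v = [squash(s_j) for s_j in s]
        # b_ij grows by the agreement a_ij = v_j . u_hat_{j|i}
        for i in range(n_in):
            for j in range(n_out):
                b[i][j] += dot(v[j], u_hat[i][j])
    return v, c


# Three lower capsules, two capsules in this layer, two features each.
u_hat = [
    [[1.0, 0.0], [1.0, 0.0]],
    [[1.0, 0.0], [-1.0, 0.0]],
    [[1.0, 0.0], [0.0, 1.0]],
]
v, c = route(u_hat)
print([round(row[0], 3) for row in c])   # coupling to capsule 0, above 0.5 for every lower capsule
print(math.sqrt(dot(v[0], v[0])), math.sqrt(dot(v[1], v[1])))   # capsule 0 ends up longer
```
</pre></div>
</div>
</div>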
<div class='section' id='section-19'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<h2>Margin loss for class existence</h2>
<p>A separate margin loss is used for each output capsule, and the total loss is their sum.
The length of each output capsule is the probability that the corresponding class is present in the input.</p>
<p>Loss for each output capsule or class $k$ is,
<script type="math/tex; mode=display">\mathcal{L}_k = T_k \max(0, m^{+} - \lVert\mathbf{v}_k\rVert)^2 +
\lambda (1 - T_k) \max(0, \lVert\mathbf{v}_k\rVert - m^{-})^2</script>
</p>
<p>$T_k$ is $1$ if the class $k$ is present and $0$ otherwise.
The first component of the loss is $0$ when the class is not present,
and the second component is $0$ if the class is present.
The $\max(0, x)$ clips each term at zero once the length is past its margin,
so predictions are not pushed to extremes.
$m^{+}$ is set to be $0.9$ and $m^{-}$ to be $0.1$ in the paper.</p>
<p>The $\lambda$ down-weighting is used to stop the length of all capsules from
falling during the initial phase of training.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">137</span><span class="k">class</span> <span class="nc">MarginLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">157</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <span class="n">n_labels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lambda_</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">m_positive</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.9</span><span class="p">,</span> <span class="n">m_negative</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">):</span>
<span class="lineno">158</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">159</span>
<span class="lineno">160</span> <span class="bp">self</span><span class="o">.</span><span class="n">m_negative</span> <span class="o">=</span> <span class="n">m_negative</span>
<span class="lineno">161</span> <span class="bp">self</span><span class="o">.</span><span class="n">m_positive</span> <span class="o">=</span> <span class="n">m_positive</span>
<span class="lineno">162</span> <span class="bp">self</span><span class="o">.</span><span class="n">lambda_</span> <span class="o">=</span> <span class="n">lambda_</span>
<span class="lineno">163</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_labels</span> <span class="o">=</span> <span class="n">n_labels</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p><code>v</code>, $\mathbf{v}_j$ are the squashed output capsules.
This has shape <code>[batch_size, n_labels, n_features]</code>; that is, there is a capsule for each label.</p>
<p><code>labels</code> are the class labels, with shape <code>[batch_size]</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">165</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">labels</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>
<script type="math/tex; mode=display">\lVert \mathbf{v}_j \rVert</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">173</span> <span class="n">v_norm</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">((</span><span class="n">v</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>
<script type="math/tex; mode=display">\mathcal{L}</script>
<code>labels</code> is now the one-hot encoding of the labels, with shape <code>[batch_size, n_labels]</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">177</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_labels</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">labels</span><span class="o">.</span><span class="n">device</span><span class="p">)[</span><span class="n">labels</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>
<script type="math/tex; mode=display">\mathcal{L}_k = T_k \max(0, m^{+} - \lVert\mathbf{v}_k\rVert)^2 +
\lambda (1 - T_k) \max(0, \lVert\mathbf{v}_k\rVert - m^{-})^2</script>
<code>loss</code> has shape <code>[batch_size, n_labels]</code>. We have parallelized the computation
of $\mathcal{L}_k$ for all $k$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">183</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">labels</span> <span class="o">*</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">m_positive</span> <span class="o">-</span> <span class="n">v_norm</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> \
<span class="lineno">184</span> <span class="bp">self</span><span class="o">.</span><span class="n">lambda_</span> <span class="o">*</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">labels</span><span class="p">)</span> <span class="o">*</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">v_norm</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">m_negative</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>
<script type="math/tex; mode=display">\sum_k \mathcal{L}_k</script>
The per-class losses are summed for each sample and then averaged over the batch.
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">187</span> <span class="k">return</span> <span class="n">loss</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
</div>
</div>
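<div class='section'>
<div class='docs'>
<p>As a worked check of the margin loss formula, here is a small plain-Python sketch for a single sample.
The function <code>margin_loss</code> is a hypothetical helper, not the module above: it assumes the capsule
lengths $\lVert\mathbf{v}_k\rVert$ are already computed and takes the index of the true class,
then evaluates $\sum_k \mathcal{L}_k$ exactly as written in the equation, including the squares.
The default values mirror the defaults of <code>MarginLoss</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def margin_loss(lengths, target, lambda_=0.5, m_positive=0.9, m_negative=0.1):
    """Margin loss for one sample.

    lengths[k] is the length of output capsule k; target is the index of the true class.
    """
    total = 0.0
    for k, v_norm in enumerate(lengths):
        t_k = 1.0 if k == target else 0.0
        present = max(0.0, m_positive - v_norm) ** 2   # penalizes a short capsule for the true class
        absent = max(0.0, v_norm - m_negative) ** 2    # penalizes a long capsule for a wrong class
        total += t_k * present + lambda_ * (1.0 - t_k) * absent
    return total


# Three classes; class 0 is the true one.
print(margin_loss([0.95, 0.05, 0.10], target=0))   # 0.0, every length is on the right side of its margin
print(margin_loss([0.60, 0.05, 0.30], target=0))   # about 0.11, that is 0.3**2 + 0.5 * 0.2**2
```
</pre></div>
</div>
</div>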
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>