Evidential Deep Learning to Quantify Classification Uncertainty (#85)

This commit is contained in:
Varuna Jayasiri
2021-08-21 10:25:32 +05:30
committed by GitHub
parent 387b6dfd1e
commit b6607524b8
19 changed files with 2720 additions and 17 deletions


@@ -154,15 +154,19 @@ implementations.</p>
<ul>
<li><a href="adaptive_computation/ponder_net/index.html">PonderNet</a></li>
</ul>
<h4><a href="uncertainty/index.html">Uncertainty</a></h4>
<ul>
<li><a href="uncertainty/evidence/index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a></li>
</ul>
<h3>Installation</h3>
<pre><code class="bash">pip install labml-nn
</code></pre>
<h3>Citing LabML</h3>
<p>If you use LabML for academic research, please cite the library using the following BibTeX entry.</p>
<p>If you use this for academic research, please cite it using the following BibTeX entry.</p>
<pre><code class="bibtex">@misc{labml,
author = {Varuna Jayasiri, Nipun Wijerathne},
title = {LabML: A library to organize machine learning experiments},
title = {labml.ai Annotated Paper Implementations},
year = {2020},
url = {https://nn.labml.ai/},
}


@@ -268,7 +268,10 @@ and set a new function to calculate the model.</p>
<p>Load configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">75</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span><span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Adam&#39;</span><span class="p">})</span></pre></div>
<div class="highlight"><pre><span class="lineno">75</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span>
<span class="lineno">76</span> <span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Adam&#39;</span><span class="p">,</span>
<span class="lineno">77</span> <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">0.001</span><span class="p">,</span>
<span class="lineno">78</span> <span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
@@ -279,8 +282,8 @@ and set a new function to calculate the model.</p>
<p>Start the experiment and run the training loop</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">77</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
<span class="lineno">78</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">80</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
<span class="lineno">81</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
@@ -291,8 +294,8 @@ and set a new function to calculate the model.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">83</span> <span class="n">main</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">85</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">86</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>


@@ -114,6 +114,9 @@
"1704.03477": [
"https://nn.labml.ai/sketch_rnn/index.html"
],
"1806.01768": [
"https://nn.labml.ai/uncertainty/evidence/index.html"
],
"1509.06461": [
"https://nn.labml.ai/rl/dqn/index.html"
],


@@ -204,7 +204,7 @@
<url>
<loc>https://nn.labml.ai/normalization/batch_norm/mnist.html</loc>
<lastmod>2021-08-19T16:30:00+00:00</lastmod>
<lastmod>2021-08-20T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
@@ -281,7 +281,7 @@
<url>
<loc>https://nn.labml.ai/index.html</loc>
<lastmod>2021-08-12T16:30:00+00:00</lastmod>
<lastmod>2021-08-21T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
@@ -797,6 +797,27 @@
</url>
<url>
<loc>https://nn.labml.ai/uncertainty/evidence/index.html</loc>
<lastmod>2021-08-21T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/uncertainty/evidence/experiment.html</loc>
<lastmod>2021-08-21T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/uncertainty/index.html</loc>
<lastmod>2021-08-21T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/rl/game.html</loc>
<lastmod>2020-12-10T16:30:00+00:00</lastmod>


@@ -0,0 +1,867 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This trains an EDL model on MNIST"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Evidential Deep Learning to Quantify Classification Uncertainty Experiment"/>
<meta name="twitter:description" content="This trains an EDL model on MNIST"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/uncertainty/evidence/experiment.html"/>
<meta property="og:title" content="Evidential Deep Learning to Quantify Classification Uncertainty Experiment"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Evidential Deep Learning to Quantify Classification Uncertainty Experiment"/>
<meta property="og:description" content="This trains an EDL model on MNIST"/>
<title>Evidential Deep Learning to Quantify Classification Uncertainty Experiment</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/uncertainty/evidence/experiment.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">uncertainty</a>
<a class="parent" href="index.html">evidence</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/uncertainty/evidence/experiment.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1><a href="index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a> Experiment</h1>
<p>This trains a model based on <a href="index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a>
on the MNIST dataset.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">14</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span>
<span class="lineno">15</span>
<span class="lineno">16</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="lineno">17</span><span class="kn">import</span> <span class="nn">torch.utils.data</span>
<span class="lineno">18</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">experiment</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span><span class="p">,</span> <span class="n">calculate</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_helpers.schedule</span> <span class="kn">import</span> <span class="n">Schedule</span><span class="p">,</span> <span class="n">RelativePiecewise</span>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_helpers.train_valid</span> <span class="kn">import</span> <span class="n">BatchIndex</span>
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.mnist</span> <span class="kn">import</span> <span class="n">MNISTConfigs</span>
<span class="lineno">25</span><span class="kn">from</span> <span class="nn">labml_nn.uncertainty.evidence</span> <span class="kn">import</span> <span class="n">KLDivergenceLoss</span><span class="p">,</span> <span class="n">TrackStatistics</span><span class="p">,</span> <span class="n">MaximumLikelihoodLoss</span><span class="p">,</span> \
<span class="lineno">26</span> <span class="n">CrossEntropyBayesRisk</span><span class="p">,</span> <span class="n">SquaredErrorBayesRisk</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>LeNet-based model for MNIST classification</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">29</span><span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">34</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
<span class="lineno">35</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>First $5 \times 5$ convolution layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">37</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>ReLU activation</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">39</span> <span class="bp">self</span><span class="o">.</span><span class="n">act1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>$2 \times 2$ max-pooling</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">41</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_pool1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Second $5 \times 5$ convolution layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">43</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>ReLU activation</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">45</span> <span class="bp">self</span><span class="o">.</span><span class="n">act2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>$2 \times 2$ max-pooling</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_pool2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>First fully-connected layer that maps to $500$ features</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">50</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">500</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>ReLU activation</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">51</span> <span class="bp">self</span><span class="o">.</span><span class="n">act3</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Final fully-connected layer to output evidence for the $10$ classes.
ReLU or Softplus activation is applied to this output, outside the model, to get the
non-negative evidence</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">55</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">500</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span></pre></div>
</div>
</div>
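<p>As an illustration of the activation mentioned above (a plain-Python sketch, not the repository's torch code), both ReLU and Softplus map raw logits to non-negative evidence; Softplus does so smoothly without zeroing out negative logits entirely:</p>

```python
import math

def relu(logits):
    # Clamp negative logits to zero
    return [max(0.0, v) for v in logits]

def softplus(logits):
    # Smooth non-negative map: log(1 + e^v)
    return [math.log1p(math.exp(v)) for v in logits]

logits = [-1.5, 0.0, 2.0]
print(relu(logits))      # every output is >= 0
print(softplus(logits))  # every output is > 0
```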
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Dropout for the hidden layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">57</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">p</span><span class="o">=</span><span class="n">dropout</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<ul>
<li><code>x</code> is the batch of MNIST images of shape <code>[batch_size, 1, 28, 28]</code></li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">59</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Apply first convolution and max pooling.
The result has shape <code>[batch_size, 20, 12, 12]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_pool1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">act1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Apply second convolution and max pooling.
The result has shape <code>[batch_size, 50, 4, 4]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_pool2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">act2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Flatten the tensor to shape <code>[batch_size, 50 * 4 * 4]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Apply hidden layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">72</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act3</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Apply dropout</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">74</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Apply final layer and return</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<h2>Configurations</h2>
<p>We use <a href="../../experiments/mnist.html#MNISTConfigs"><code>MNISTConfigs</code></a> configurations.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">MNISTConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p><a href="index.html#KLDivergenceLoss">KL Divergence regularization</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">87</span> <span class="n">kl_div_loss</span> <span class="o">=</span> <span class="n">KLDivergenceLoss</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>KL Divergence regularization coefficient schedule</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">89</span> <span class="n">kl_div_coef</span><span class="p">:</span> <span class="n">Schedule</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>KL Divergence regularization coefficient schedule breakpoints, as (relative training progress, value) pairs</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span> <span class="n">kl_div_coef_schedule</span> <span class="o">=</span> <span class="p">[(</span><span class="mi">0</span><span class="p">,</span> <span class="mf">0.</span><span class="p">),</span> <span class="p">(</span><span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mf">1.</span><span class="p">)]</span></pre></div>
</div>
</div>
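<p>The schedule above can be read as: the coefficient is $0$ at the start of training, $0.01$ at $20\%$ of training, and $1$ at the end. A minimal sketch of how such a piecewise schedule is evaluated, assuming linear interpolation between the relative breakpoints (illustrative only, not the <code>RelativePiecewise</code> implementation):</p>

```python
def piecewise(schedule, t):
    """Linearly interpolate a [(fraction, value), ...] schedule
    at training progress t in [0, 1]."""
    for (t0, v0), (t1, v1) in zip(schedule, schedule[1:]):
        if t <= t1:
            # Interpolate within the segment [t0, t1]
            return v0 + (v1 - v0) * (t - t0) / (t1 - t0)
    return schedule[-1][1]

kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]
print(piecewise(kl_div_coef_schedule, 0.0))  # start of training
print(piecewise(kl_div_coef_schedule, 0.2))  # 20% through
print(piecewise(kl_div_coef_schedule, 0.6))  # halfway through the second segment
```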
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p><a href="index.html#TrackStatistics">Stats module</a> for tracking</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">93</span> <span class="n">stats</span> <span class="o">=</span> <span class="n">TrackStatistics</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Dropout</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">95</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Module to convert the model outputs to non-negative evidence</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">97</span> <span class="n">outputs_to_evidence</span><span class="p">:</span> <span class="n">Module</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<h3>Initialization</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">99</span> <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Set tracker configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">104</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s2">&quot;loss.*&quot;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
<span class="lineno">105</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s2">&quot;accuracy.*&quot;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
<span class="lineno">106</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_histogram</span><span class="p">(</span><span class="s1">&#39;u.*&#39;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
<span class="lineno">107</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_histogram</span><span class="p">(</span><span class="s1">&#39;prob.*&#39;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
<span class="lineno">108</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s1">&#39;annealing_coef.*&#39;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
<span class="lineno">109</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s1">&#39;kl_div_loss.*&#39;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">112</span> <span class="bp">self</span><span class="o">.</span><span class="n">state_modules</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<h3>Training or validation step</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">114</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="n">BatchIndex</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Training/Evaluation mode</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">120</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Move data to the device</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">123</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>One-hot encoded targets</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">126</span> <span class="n">eye</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="mi">10</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">127</span> <span class="n">target</span> <span class="o">=</span> <span class="n">eye</span><span class="p">[</span><span class="n">target</span><span class="p">]</span></pre></div>
</div>
</div>
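<p>The <code>torch.eye(10)[target]</code> indexing above picks rows of the identity matrix, since row $k$ of the identity matrix is exactly the one-hot vector for class $k$. A plain-Python equivalent (illustrative only):</p>

```python
def one_hot(targets, n_classes=10):
    # Row i of the identity matrix is the one-hot vector for class i,
    # so indexing rows by the target labels one-hot encodes them
    eye = [[1.0 if i == j else 0.0 for j in range(n_classes)]
           for i in range(n_classes)]
    return [eye[t] for t in targets]

print(one_hot([3, 0], n_classes=4))
```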
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Update global step (number of samples processed) when in training mode</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">130</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
<span class="lineno">131</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Get model outputs</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">134</span> <span class="n">outputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Get evidences $e_k \ge 0$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="n">evidence</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">outputs_to_evidence</span><span class="p">(</span><span class="n">outputs</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>Calculate loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">139</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">evidence</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Calculate KL Divergence regularization loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">141</span> <span class="n">kl_div_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_div_loss</span><span class="p">(</span><span class="n">evidence</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
<span class="lineno">142</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;loss.&quot;</span><span class="p">,</span> <span class="n">loss</span><span class="p">)</span>
<span class="lineno">143</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;kl_div_loss.&quot;</span><span class="p">,</span> <span class="n">kl_div_loss</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>KL Divergence loss coefficient $\lambda_t$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">146</span> <span class="n">annealing_coef</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="mf">1.</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_div_coef</span><span class="p">(</span><span class="n">tracker</span><span class="o">.</span><span class="n">get_global_step</span><span class="p">()))</span>
<span class="lineno">147</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;annealing_coef.&quot;</span><span class="p">,</span> <span class="n">annealing_coef</span><span class="p">)</span></pre></div>
</div>
</div>
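<p>A minimal sketch of the annealing idea, assuming a simple linear warm-up (the warm-up length here is a made-up stand-in for the configured <code>kl_div_coef</code> schedule):</p>

```python
# lambda_t = min(1, schedule(t)): the KL regularization is phased in
# gradually so it does not dominate early training.
def annealing_coef(step: int, warmup_steps: int = 10_000) -> float:
    return min(1., step / warmup_steps)

print(annealing_coef(0))       # 0.0
print(annealing_coef(5_000))   # 0.5
print(annealing_coef(20_000))  # 1.0
```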
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Total loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">150</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">loss</span> <span class="o">+</span> <span class="n">annealing_coef</span> <span class="o">*</span> <span class="n">kl_div_loss</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Track statistics</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">153</span> <span class="bp">self</span><span class="o">.</span><span class="n">stats</span><span class="p">(</span><span class="n">evidence</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Train the model</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>Calculate gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">158</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Take optimizer step</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">160</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Clear the gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p>Save the tracked metrics</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">165</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<h3>Create model</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
<span class="lineno">169</span><span class="k">def</span> <span class="nf">mnist_model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">173</span> <span class="k">return</span> <span class="n">Model</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<h3>KL Divergence Loss Coefficient Schedule</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">kl_div_coef</span><span class="p">)</span>
<span class="lineno">177</span><span class="k">def</span> <span class="nf">kl_div_coef</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>Create a <a href="https://docs.labml.ai/api/helpers.html#labml_helpers.schedule.Piecewise">relative piecewise schedule</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">183</span> <span class="k">return</span> <span class="n">RelativePiecewise</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">kl_div_coef_schedule</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">epochs</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">train_dataset</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p><a href="index.html#MaximumLikelihoodLoss">Maximum Likelihood Loss</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">187</span><span class="n">calculate</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">loss_func</span><span class="p">,</span> <span class="s1">&#39;max_likelihood_loss&#39;</span><span class="p">,</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">MaximumLikelihoodLoss</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p><a href="index.html#CrossEntropyBayesRisk">Cross Entropy Bayes Risk</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">189</span><span class="n">calculate</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">loss_func</span><span class="p">,</span> <span class="s1">&#39;cross_entropy_bayes_risk&#39;</span><span class="p">,</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">CrossEntropyBayesRisk</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p><a href="index.html#SquaredErrorBayesRisk">Squared Error Bayes Risk</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span><span class="n">calculate</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">loss_func</span><span class="p">,</span> <span class="s1">&#39;squared_error_bayes_risk&#39;</span><span class="p">,</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">SquaredErrorBayesRisk</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<p>ReLU to calculate evidence</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">194</span><span class="n">calculate</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">outputs_to_evidence</span><span class="p">,</span> <span class="s1">&#39;relu&#39;</span><span class="p">,</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='docs'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<p>Softplus to calculate evidence</p>
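<p>Both map raw model outputs to non-negative evidence; a quick comparison with arbitrary logits:</p>

```python
import torch
import torch.nn as nn

# Evidence must satisfy e_k >= 0: ReLU clips negative logits to zero,
# while Softplus maps every logit to a strictly positive value.
logits = torch.tensor([-2.0, 0.0, 3.0])

relu_evidence = nn.ReLU()(logits)
softplus_evidence = nn.Softplus()(logits)

print(relu_evidence)                  # tensor([0., 0., 3.])
print((softplus_evidence > 0).all())  # tensor(True)
```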
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">196</span><span class="n">calculate</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">outputs_to_evidence</span><span class="p">,</span> <span class="s1">&#39;softplus&#39;</span><span class="p">,</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softplus</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">199</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<p>Create experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">201</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;evidence_mnist&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<p>Create configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">203</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<p>Load configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span>
<span class="lineno">206</span> <span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Adam&#39;</span><span class="p">,</span>
<span class="lineno">207</span> <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">0.001</span><span class="p">,</span>
<span class="lineno">208</span> <span class="s1">&#39;optimizer.weight_decay&#39;</span><span class="p">:</span> <span class="mf">0.005</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
<p>&lsquo;loss_func&rsquo;: &lsquo;max_likelihood_loss&rsquo;,
&lsquo;loss_func&rsquo;: &lsquo;cross_entropy_bayes_risk&rsquo;,</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">212</span> <span class="s1">&#39;loss_func&#39;</span><span class="p">:</span> <span class="s1">&#39;squared_error_bayes_risk&#39;</span><span class="p">,</span>
<span class="lineno">213</span>
<span class="lineno">214</span> <span class="s1">&#39;outputs_to_evidence&#39;</span><span class="p">:</span> <span class="s1">&#39;softplus&#39;</span><span class="p">,</span>
<span class="lineno">215</span>
<span class="lineno">216</span> <span class="s1">&#39;dropout&#39;</span><span class="p">:</span> <span class="mf">0.5</span><span class="p">,</span>
<span class="lineno">217</span> <span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
<p>Start the experiment and run the training loop</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">219</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
<span class="lineno">220</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">224</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">225</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

View File

@ -0,0 +1,798 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A PyTorch implementation/tutorial of the paper Evidential Deep Learning to Quantify Classification Uncertainty."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Evidential Deep Learning to Quantify Classification Uncertainty"/>
<meta name="twitter:description" content="A PyTorch implementation/tutorial of the paper Evidential Deep Learning to Quantify Classification Uncertainty."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/uncertainty/evidence/index.html"/>
<meta property="og:title" content="Evidential Deep Learning to Quantify Classification Uncertainty"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Evidential Deep Learning to Quantify Classification Uncertainty"/>
<meta property="og:description" content="A PyTorch implementation/tutorial of the paper Evidential Deep Learning to Quantify Classification Uncertainty."/>
<title>Evidential Deep Learning to Quantify Classification Uncertainty</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/uncertainty/evidence/index.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">uncertainty</a>
<a class="parent" href="index.html">evidence</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/uncertainty/evidence/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Evidential Deep Learning to Quantify Classification Uncertainty</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the paper
<a href="https://papers.labml.ai/paper/1806.01768">Evidential Deep Learning to Quantify Classification Uncertainty</a>.</p>
<p><a href="https://en.wikipedia.org/wiki/Dempster%E2%80%93Shafer_theory">Dempster-Shafer Theory of Evidence</a>
assigns belief masses to sets of classes (unlike assigning a probability to a single class).
The sum of the masses over all subsets is $1$.
Individual class probabilities (plausibilities) can be derived from these masses.</p>
<p>Assigning a mass to the set of all classes means the sample could be any one of the classes; i.e. saying &ldquo;I don&rsquo;t know&rdquo;.</p>
<p>If there are $K$ classes, we assign masses $b_k \ge 0$ to each of the classes and
an overall uncertainty mass $u \ge 0$ to all classes.</p>
<p>
<script type="math/tex; mode=display">u + \sum_{k=1}^K b_k = 1</script>
</p>
<p>Belief masses $b_k$ and $u$ can be computed from evidence $e_k \ge 0$, as $b_k = \frac{e_k}{S}$
and $u = \frac{K}{S}$ where $S = \sum_{k=1}^K (e_k + 1)$.
The paper uses the term evidence as a measure of the amount of support
collected from data in favor of classifying a sample into a certain class.</p>
<p>This corresponds to a <a href="https://en.wikipedia.org/wiki/Dirichlet_distribution">Dirichlet distribution</a>
with parameters $\color{cyan}{\alpha_k} = e_k + 1$, and
$\color{cyan}{\alpha_0} = S = \sum_{k=1}^K \color{cyan}{\alpha_k}$ is known as the Dirichlet strength.
Dirichlet distribution $D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}})$
is a distribution over categorical distributions; i.e. you can sample class probability vectors
from it.
The expected probability for class $k$ is $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$.</p>
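<p>These quantities are straightforward to compute; a small sketch with a made-up evidence vector for $K = 4$ classes:</p>

```python
import torch

# Hypothetical evidence vector for K = 4 classes
evidence = torch.tensor([2.0, 0.0, 1.0, 0.0])
k = evidence.shape[0]

alpha = evidence + 1.              # alpha_k = e_k + 1
strength = alpha.sum()             # S = sum_k alpha_k
belief = evidence / strength       # b_k = e_k / S
uncertainty = k / strength         # u = K / S
p_hat = alpha / strength           # expected class probabilities

# The masses sum to one: u + sum_k b_k = 1
print(belief.sum() + uncertainty)  # approximately 1
```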
<p>We get the model to output evidences
<script type="math/tex; mode=display">\mathbf{e} = \color{cyan}{\mathbf{\alpha}} - 1 = f(\mathbf{x} | \Theta)</script>
for a given input $\mathbf{x}$.
We use a function such as
<a href="https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html">ReLU</a> or a
<a href="https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html">Softplus</a>
at the final layer to get $f(\mathbf{x} | \Theta) \ge 0$.</p>
<p>The paper proposes a few loss functions to train the model, which we have implemented below.</p>
<p>Here is the <a href="experiment.html">training code <code>experiment.py</code></a> to train a model on MNIST dataset.</p>
<p><a href="https://app.labml.ai/run/f82b2bfc01ba11ecbb2aa16a33570106"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">54</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">55</span>
<span class="lineno">56</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">tracker</span>
<span class="lineno">57</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<p><a id="MaximumLikelihoodLoss"></a></p>
<h2>Type II Maximum Likelihood Loss</h2>
<p>The distribution $D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}})$ is a prior on the likelihood
$\text{Multi}(\mathbf{y} \vert \mathbf{p})$,
and the negative log marginal likelihood is calculated by integrating over class probabilities
$\mathbf{p}$.</p>
<p>If target probabilities (one-hot targets) are $y_k$ for a given sample the loss is,</p>
<p>
<script type="math/tex; mode=display">\begin{align}
\mathcal{L}(\Theta)
&= -\log \Bigg(
\int
\prod_{k=1}^K p_k^{y_k}
\frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
\prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
d\mathbf{p}
\Bigg ) \\
&= \sum_{k=1}^K y_k \bigg( \log S - \log \color{cyan}{\alpha_k} \bigg)
\end{align}</script>
</p>
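<p>The closed form can be sanity-checked numerically: for a one-hot target on class $k$, the marginal likelihood is $\mathbb{E}[p_k] = \frac{\color{cyan}{\alpha_k}}{S}$, so the loss is $\log S - \log \color{cyan}{\alpha_k} = -\log \mathbb{E}[p_k]$. A sketch with made-up Dirichlet parameters:</p>

```python
import torch
from torch.distributions import Dirichlet

# Closed form log S - log alpha_0 vs. a Monte Carlo estimate of
# -log E[p_0] under Dir(alpha), for a one-hot target on class 0.
alpha = torch.tensor([10.0, 1.0, 1.0])
strength = alpha.sum()

closed_form = strength.log() - alpha[0].log()

torch.manual_seed(0)
samples = Dirichlet(alpha).sample((100_000,))  # p ~ Dir(alpha)
monte_carlo = -samples[:, 0].mean().log()

print(closed_form)  # ~0.1823
print(monte_carlo)  # close to the closed form
```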
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">60</span><span class="k">class</span> <span class="nc">MaximumLikelihoodLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul>
<li><code>evidence</code> is $\mathbf{e} \ge 0$ with shape <code>[batch_size, n_classes]</code></li>
<li><code>target</code> is $\mathbf{y}$ with shape <code>[batch_size, n_classes]</code></li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">84</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">evidence</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>$\color{cyan}{\alpha_k} = e_k + 1$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">90</span> <span class="n">alpha</span> <span class="o">=</span> <span class="n">evidence</span> <span class="o">+</span> <span class="mf">1.</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>$S = \sum_{k=1}^K \color{cyan}{\alpha_k}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">92</span> <span class="n">strength</span> <span class="o">=</span> <span class="n">alpha</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \log S - \log \color{cyan}{\alpha_k} \bigg)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">95</span> <span class="n">loss</span> <span class="o">=</span> <span class="p">(</span><span class="n">target</span> <span class="o">*</span> <span class="p">(</span><span class="n">strength</span><span class="o">.</span><span class="n">log</span><span class="p">()[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">-</span> <span class="n">alpha</span><span class="o">.</span><span class="n">log</span><span class="p">()))</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Mean loss over the batch</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">98</span> <span class="k">return</span> <span class="n">loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p><a id="CrossEntropyBayesRisk"></a></p>
<h2>Bayes Risk with Cross Entropy Loss</h2>
<p>Bayes risk is the expected cost of making incorrect estimates.
It takes a cost function that gives the cost of making an incorrect estimate
and integrates it over all possible outcomes, weighted by the probability distribution.</p>
<p>Here the cost function is cross-entropy loss, for one-hot coded $\mathbf{y}$
<script type="math/tex; mode=display">\sum_{k=1}^K -y_k \log p_k</script>
</p>
<p>We integrate this cost over all $\mathbf{p}$</p>
<p>
<script type="math/tex; mode=display">\begin{align}
\mathcal{L}(\Theta)
&= \int
  \Big[ \sum_{k=1}^K -y_k \log p_k \Big]
 \frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
 \prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
 d\mathbf{p} \\
&= \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{cyan}{\alpha_k} ) \bigg)
\end{align}</script>
</p>
<p>where $\psi(\cdot)$ is the digamma function.</p>
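<p>The identity $\mathbb{E}[-\log p_k] = \psi(S) - \psi(\color{cyan}{\alpha_k})$ for $\mathbf{p} \sim D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}})$ can be checked numerically (the parameters below are made up):</p>

```python
import torch
from torch.distributions import Dirichlet

# Closed form psi(S) - psi(alpha_0) vs. a Monte Carlo estimate of
# E[-log p_0] under Dir(alpha), for a one-hot target on class 0.
alpha = torch.tensor([10.0, 1.0, 1.0])
strength = alpha.sum()

closed_form = torch.digamma(strength) - torch.digamma(alpha[0])

torch.manual_seed(0)
samples = Dirichlet(alpha).sample((100_000,))  # p ~ Dir(alpha)
monte_carlo = -samples[:, 0].log().mean()

print(closed_form)  # ~0.1909
print(monte_carlo)  # close to the closed form
```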
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">101</span><span class="k">class</span> <span class="nc">CrossEntropyBayesRisk</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<ul>
<li><code>evidence</code> is $\mathbf{e} \ge 0$ with shape <code>[batch_size, n_classes]</code></li>
<li><code>target</code> is $\mathbf{y}$ with shape <code>[batch_size, n_classes]</code></li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">130</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">evidence</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>$\color{cyan}{\alpha_k} = e_k + 1$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="n">alpha</span> <span class="o">=</span> <span class="n">evidence</span> <span class="o">+</span> <span class="mf">1.</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>$S = \sum_{k=1}^K \color{cyan}{\alpha_k}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">138</span> <span class="n">strength</span> <span class="o">=</span> <span class="n">alpha</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{cyan}{\alpha_k} ) \bigg)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">141</span> <span class="n">loss</span> <span class="o">=</span> <span class="p">(</span><span class="n">target</span> <span class="o">*</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">digamma</span><span class="p">(</span><span class="n">strength</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">digamma</span><span class="p">(</span><span class="n">alpha</span><span class="p">)))</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Mean loss over the batch</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">144</span> <span class="k">return</span> <span class="n">loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
</div>
</div>
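Putting the steps above together, here is a minimal standalone sketch of this digamma-based Bayes risk loss. The toy evidence and one-hot target tensors are made up for illustration; shapes are <code>[batch_size, n_classes]</code> as in the sections above.

```python
import torch

# Toy batch (illustrative values, not from the experiment):
# one-hot target for class 0, strong evidence for class 0.
target = torch.tensor([[1.0, 0.0, 0.0]])
evidence = torch.tensor([[5.0, 0.0, 0.0]])

alpha = evidence + 1.                        # alpha_k = e_k + 1
strength = alpha.sum(dim=-1)                 # S = sum_k alpha_k
# L = sum_k y_k * (psi(S) - psi(alpha_k)), averaged over the batch
loss = (target * (torch.digamma(strength)[:, None]
                  - torch.digamma(alpha))).sum(dim=-1).mean()
print(loss)
```

As evidence for the correct class grows, $\psi(\color{cyan}{\alpha_k})$ approaches $\psi(S)$ and the loss shrinks toward zero.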
<div class='section' id='section-13'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p><a id="SquaredErrorBayesRisk"></a></p>
<h2>Bayes Risk with Squared Error Loss</h2>
<p>Here the cost function is squared error,
<script type="math/tex; mode=display">\sum_{k=1}^K (y_k - p_k)^2 = \Vert \mathbf{y} - \mathbf{p} \Vert_2^2</script>
</p>
<p>We integrate this cost over all $\mathbf{p}$</p>
<p>
<script type="math/tex; mode=display">\begin{align}
\mathcal{L}(\Theta)
&= \int
  \Big[ \sum_{k=1}^K (y_k - p_k)^2 \Big]
  \frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
  \prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
d\mathbf{p} \\
&= \sum_{k=1}^K \mathbb{E} \Big[ y_k^2 -2 y_k p_k + p_k^2 \Big] \\
&= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big)
\end{align}</script>
</p>
<p>Where <script type="math/tex; mode=display">\mathbb{E}[p_k] = \hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}</script>
is the expected probability when sampled from the Dirichlet distribution
and <script type="math/tex; mode=display">\mathbb{E}[p_k^2] = \mathbb{E}[p_k]^2 + \text{Var}(p_k)</script>
where
<script type="math/tex; mode=display">\text{Var}(p_k) = \frac{\color{cyan}{\alpha_k}(S - \color{cyan}{\alpha_k})}{S^2 (S + 1)}
= \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}</script>
is the variance.</p>
<p>This gives,
<script type="math/tex; mode=display">\begin{align}
\mathcal{L}(\Theta)
&= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big) \\
&= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k]^2 + \text{Var}(p_k) \Big) \\
&= \sum_{k=1}^K \Big( \big( y_k -\mathbb{E}[p_k] \big)^2 + \text{Var}(p_k) \Big) \\
&= \sum_{k=1}^K \Big( ( y_k -\hat{p}_k)^2 + \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1} \Big)
\end{align}</script>
</p>
<p>The first part of the equation, $\big(y_k -\mathbb{E}[p_k]\big)^2$, is the squared error term, and
the second part is the variance.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">147</span><span class="k">class</span> <span class="nc">SquaredErrorBayesRisk</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<ul>
<li><code>evidence</code> is $\mathbf{e} \ge 0$ with shape <code>[batch_size, n_classes]</code></li>
<li><code>target</code> is $\mathbf{y}$ with shape <code>[batch_size, n_classes]</code></li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">evidence</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>$\color{cyan}{\alpha_k} = e_k + 1$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">alpha</span> <span class="o">=</span> <span class="n">evidence</span> <span class="o">+</span> <span class="mf">1.</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>$S = \sum_{k=1}^K \color{cyan}{\alpha_k}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">199</span> <span class="n">strength</span> <span class="o">=</span> <span class="n">alpha</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>$\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">201</span> <span class="n">p</span> <span class="o">=</span> <span class="n">alpha</span> <span class="o">/</span> <span class="n">strength</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Error $(y_k -\hat{p}_k)^2$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">204</span> <span class="n">err</span> <span class="o">=</span> <span class="p">(</span><span class="n">target</span> <span class="o">-</span> <span class="n">p</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Variance $\text{Var}(p_k) = \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">206</span> <span class="n">var</span> <span class="o">=</span> <span class="n">p</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">p</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">strength</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Sum of them</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">209</span> <span class="n">loss</span> <span class="o">=</span> <span class="p">(</span><span class="n">err</span> <span class="o">+</span> <span class="n">var</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Mean loss over the batch</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">212</span> <span class="k">return</span> <span class="n">loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
</div>
</div>
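The forward computation above can be sketched as a standalone function. This is a minimal illustration with made-up toy tensors, not the module itself:

```python
import torch

def squared_error_bayes_risk(evidence: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Squared-error Bayes risk; evidence and target have shape [batch, K]."""
    alpha = evidence + 1.                        # alpha_k = e_k + 1
    strength = alpha.sum(dim=-1)                 # S = sum_k alpha_k
    p = alpha / strength[:, None]                # p_hat_k = alpha_k / S
    err = (target - p) ** 2                      # (y_k - p_hat_k)^2
    var = p * (1 - p) / (strength[:, None] + 1)  # Var(p_k)
    return (err + var).sum(dim=-1).mean()

# Toy comparison (illustrative values): confident correct evidence
# gives a lower loss than no evidence at all.
target = torch.tensor([[1.0, 0.0, 0.0]])
confident = squared_error_bayes_risk(torch.tensor([[20.0, 0.0, 0.0]]), target)
uncertain = squared_error_bayes_risk(torch.tensor([[0.0, 0.0, 0.0]]), target)
print(confident.item(), uncertain.item())
```

Both the error and the variance terms fall as correct evidence accumulates, since $\hat{p}_k \to y_k$ and $S$ grows.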
<div class='section' id='section-22'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p><a id="KLDivergenceLoss"></a></p>
<h2>KL Divergence Regularization Loss</h2>
<p>This tries to shrink the total evidence to zero if the sample cannot be correctly classified.</p>
<p>First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \color{cyan}{\alpha_k}$, the
Dirichlet parameters after removing the correct evidence.</p>
<p>
<script type="math/tex; mode=display">\begin{align}
&KL \Big[ D(\mathbf{p} \vert \mathbf{\tilde{\alpha}}) \Big \Vert
D(\mathbf{p} \vert \langle 1, \dots, 1 \rangle) \Big] \\
&= \log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}
{\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg)
+ \sum_{k=1}^K (\tilde{\alpha}_k - 1)
\Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]
\end{align}</script>
</p>
<p>where $\Gamma(\cdot)$ is the gamma function,
$\psi(\cdot)$ is the digamma function, and
$\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">215</span><span class="k">class</span> <span class="nc">KLDivergenceLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<ul>
<li><code>evidence</code> is $\mathbf{e} \ge 0$ with shape <code>[batch_size, n_classes]</code></li>
<li><code>target</code> is $\mathbf{y}$ with shape <code>[batch_size, n_classes]</code></li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">238</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">evidence</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>$\color{cyan}{\alpha_k} = e_k + 1$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">244</span> <span class="n">alpha</span> <span class="o">=</span> <span class="n">evidence</span> <span class="o">+</span> <span class="mf">1.</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Number of classes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">246</span> <span class="n">n_classes</span> <span class="o">=</span> <span class="n">evidence</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Remove non-misleading evidence
<script type="math/tex; mode=display">\tilde{\alpha}_k = y_k + (1 - y_k) \color{cyan}{\alpha_k}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">249</span> <span class="n">alpha_tilde</span> <span class="o">=</span> <span class="n">target</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">target</span><span class="p">)</span> <span class="o">*</span> <span class="n">alpha</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>$\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">251</span> <span class="n">strength_tilde</span> <span class="o">=</span> <span class="n">alpha_tilde</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>The first term
<script type="math/tex; mode=display">\begin{align}
&\log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}
{\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg) \\
&= \log \Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)
- \log \Gamma(K)
- \sum_{k=1}^K \log \Gamma(\tilde{\alpha}_k)
\end{align}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">261</span> <span class="n">first</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">lgamma</span><span class="p">(</span><span class="n">alpha_tilde</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span>
<span class="lineno">262</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">lgamma</span><span class="p">(</span><span class="n">alpha_tilde</span><span class="o">.</span><span class="n">new_tensor</span><span class="p">(</span><span class="nb">float</span><span class="p">(</span><span class="n">n_classes</span><span class="p">)))</span>
<span class="lineno">263</span> <span class="o">-</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">lgamma</span><span class="p">(</span><span class="n">alpha_tilde</span><span class="p">))</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>The second term
<script type="math/tex; mode=display">\sum_{k=1}^K (\tilde{\alpha}_k - 1)
\Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">268</span> <span class="n">second</span> <span class="o">=</span> <span class="p">(</span>
<span class="lineno">269</span> <span class="p">(</span><span class="n">alpha_tilde</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span>
<span class="lineno">270</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">digamma</span><span class="p">(</span><span class="n">alpha_tilde</span><span class="p">)</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">digamma</span><span class="p">(</span><span class="n">strength_tilde</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">])</span>
<span class="lineno">271</span> <span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Sum of the terms</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">274</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">first</span> <span class="o">+</span> <span class="n">second</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Mean loss over the batch</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">277</span> <span class="k">return</span> <span class="n">loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
</div>
</div>
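The whole regularization term can be sketched as one standalone function, assuming the lgamma/digamma form derived above and toy tensors of shape <code>[batch, K]</code>:

```python
import torch

def kl_divergence_loss(evidence: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """KL between Dir(p | alpha_tilde) and the uniform Dirichlet Dir(p | <1,...,1>)."""
    alpha = evidence + 1.
    n_classes = evidence.shape[-1]
    # Remove the correct (non-misleading) evidence
    alpha_tilde = target + (1 - target) * alpha
    strength_tilde = alpha_tilde.sum(dim=-1)
    # log Gamma(S_tilde) - log Gamma(K) - sum_k log Gamma(alpha_tilde_k)
    first = (torch.lgamma(strength_tilde)
             - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
             - torch.lgamma(alpha_tilde).sum(dim=-1))
    # sum_k (alpha_tilde_k - 1) * (psi(alpha_tilde_k) - psi(S_tilde))
    second = ((alpha_tilde - 1)
              * (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
              ).sum(dim=-1)
    return (first + second).mean()

# Illustrative check: evidence only for the correct class leaves
# alpha_tilde = <1, 1, 1>, so the KL term is zero; misleading evidence is penalized.
target = torch.tensor([[1.0, 0.0, 0.0]])
no_misleading = kl_divergence_loss(torch.tensor([[9.0, 0.0, 0.0]]), target)
misleading = kl_divergence_loss(torch.tensor([[0.0, 9.0, 0.0]]), target)
print(no_misleading.item(), misleading.item())
```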
<div class='section' id='section-32'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p><a id="TrackStatistics"></a></p>
<h3>Track statistics</h3>
<p>This module computes statistics and tracks them with <a href="https://docs.labml.ai/api/tracker.html">labml <code>tracker</code></a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">280</span><span class="k">class</span> <span class="nc">TrackStatistics</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">287</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">evidence</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Number of classes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">289</span> <span class="n">n_classes</span> <span class="o">=</span> <span class="n">evidence</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Predictions that match the target (greedy prediction of the highest-probability class)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">291</span> <span class="n">match</span> <span class="o">=</span> <span class="n">evidence</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">target</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Track accuracy</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">293</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;accuracy.&#39;</span><span class="p">,</span> <span class="n">match</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">/</span> <span class="n">match</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>$\color{cyan}{\alpha_k} = e_k + 1$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">296</span> <span class="n">alpha</span> <span class="o">=</span> <span class="n">evidence</span> <span class="o">+</span> <span class="mf">1.</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>$S = \sum_{k=1}^K \color{cyan}{\alpha_k}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">298</span> <span class="n">strength</span> <span class="o">=</span> <span class="n">alpha</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>$\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">301</span> <span class="n">expected_probability</span> <span class="o">=</span> <span class="n">alpha</span> <span class="o">/</span> <span class="n">strength</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Expected probability of the selected (greedy, highest-probability) class</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">303</span> <span class="n">expected_probability</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">expected_probability</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Uncertainty mass $u = \frac{K}{S}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">306</span> <span class="n">uncertainty_mass</span> <span class="o">=</span> <span class="n">n_classes</span> <span class="o">/</span> <span class="n">strength</span></pre></div>
</div>
</div>
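A quick sketch of how the uncertainty mass behaves, with made-up evidence values for illustration:

```python
import torch

# Two toy samples: one with strong evidence, one with none at all.
evidence = torch.tensor([[9.0, 0.0, 0.0],
                         [0.0, 0.0, 0.0]])
alpha = evidence + 1.                  # alpha_k = e_k + 1
strength = alpha.sum(dim=-1)           # S per sample
n_classes = evidence.shape[-1]         # K
uncertainty_mass = n_classes / strength  # u = K / S
print(uncertainty_mass)
```

With no evidence $S = K$, so $u = 1$ (total uncertainty); as evidence accumulates, $u$ falls toward zero.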
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Track $u$ for correct predictions</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">309</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;u.succ.&#39;</span><span class="p">,</span> <span class="n">uncertainty_mass</span><span class="o">.</span><span class="n">masked_select</span><span class="p">(</span><span class="n">match</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>Track $u$ for incorrect predictions</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">311</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;u.fail.&#39;</span><span class="p">,</span> <span class="n">uncertainty_mass</span><span class="o">.</span><span class="n">masked_select</span><span class="p">(</span><span class="o">~</span><span class="n">match</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Track $\hat{p}_k$ for correct predictions</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">313</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;prob.succ.&#39;</span><span class="p">,</span> <span class="n">expected_probability</span><span class="o">.</span><span class="n">masked_select</span><span class="p">(</span><span class="n">match</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Track $\hat{p}_k$ for incorrect predictions</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">315</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;prob.fail.&#39;</span><span class="p">,</span> <span class="n">expected_probability</span><span class="o">.</span><span class="n">masked_select</span><span class="p">(</span><span class="o">~</span><span class="n">match</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>


@ -0,0 +1,144 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content=""/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Evidential Deep Learning to Quantify Classification Uncertainty"/>
<meta name="twitter:description" content=""/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/uncertainty/evidence/readme.html"/>
<meta property="og:title" content="Evidential Deep Learning to Quantify Classification Uncertainty"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Evidential Deep Learning to Quantify Classification Uncertainty"/>
<meta property="og:description" content=""/>
<title>Evidential Deep Learning to Quantify Classification Uncertainty</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/uncertainty/evidence/readme.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">uncertainty</a>
<a class="parent" href="index.html">evidence</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/uncertainty/evidence/readme.md">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1><a href="https://nn.labml.ai/uncertainty/evidence/index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a></h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the paper
<a href="https://papers.labml.ai/paper/1806.01768">Evidential Deep Learning to Quantify Classification Uncertainty</a>.</p>
<p>Here is the <a href="https://nn.labml.ai/uncertainty/evidence/experiment.html">training code <code>experiment.py</code></a> to train a model on MNIST dataset.</p>
<p><a href="https://app.labml.ai/run/f82b2bfc01ba11ecbb2aa16a33570106"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

docs/uncertainty/index.html Normal file

@ -0,0 +1,143 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A set of PyTorch implementations/tutorials related to uncertainty estimation"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Neural Networks with Uncertainty Estimation"/>
<meta name="twitter:description" content="A set of PyTorch implementations/tutorials related to uncertainty estimation"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/uncertainty/index.html"/>
<meta property="og:title" content="Neural Networks with Uncertainty Estimation"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Neural Networks with Uncertainty Estimation"/>
<meta property="og:description" content="A set of PyTorch implementations/tutorials related to uncertainty estimation"/>
<title>Neural Networks with Uncertainty Estimation</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/uncertainty/index.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">uncertainty</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/uncertainty/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Neural Networks with Uncertainty Estimation</h1>
<p>These are neural network architectures that estimate the uncertainty of the predictions.</p>
<ul>
<li><a href="evidence/index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a></li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

View File

@ -0,0 +1,143 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content=""/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Neural Networks with Uncertainty Estimation"/>
<meta name="twitter:description" content=""/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/uncertainty/readme.html"/>
<meta property="og:title" content="Neural Networks with Uncertainty Estimation"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Neural Networks with Uncertainty Estimation"/>
<meta property="og:description" content=""/>
<title>Neural Networks with Uncertainty Estimation</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/uncertainty/readme.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">uncertainty</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/uncertainty/readme.md">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1><a href="https://nn.labml.ai/uncertainty/index.html">Neural Networks with Uncertainty Estimation</a></h1>
<p>These are neural network architectures that estimate the uncertainty of the predictions.</p>
<ul>
<li><a href="https://nn.labml.ai/uncertainty/evidence/index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a></li>
</ul>
</div>
<div class='code'>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

View File

@ -94,6 +94,10 @@ Solving games with incomplete information such as poker with CFR.
* [PonderNet](adaptive_computation/ponder_net/index.html)
#### ✨ [Uncertainty](uncertainty/index.html)
* [Evidential Deep Learning to Quantify Classification Uncertainty](uncertainty/evidence/index.html)
### Installation
```bash
@ -102,12 +106,12 @@ pip install labml-nn
### Citing LabML
If you use LabML for academic research, please cite the library using the following BibTeX entry.
If you use this for academic research, please cite it using the following BibTeX entry.
```bibtex
@misc{labml,
author = {Varuna Jayasiri, Nipun Wijerathne},
title = {LabML: A library to organize machine learning experiments},
title = {labml.ai Annotated Paper Implementations},
year = {2020},
url = {https://nn.labml.ai/},
}

View File

@ -72,7 +72,10 @@ def main():
# Create configurations
conf = MNISTConfigs()
# Load configurations
experiment.configs(conf, {'optimizer.optimizer': 'Adam'})
experiment.configs(conf, {
'optimizer.optimizer': 'Adam',
'optimizer.learning_rate': 0.001,
})
# Start the experiment and run the training loop
with experiment.start():
conf.run()

View File

@ -0,0 +1,13 @@
"""
---
title: Neural Networks with Uncertainty Estimation
summary: >
A set of PyTorch implementations/tutorials related to uncertainty estimation
---
# Neural Networks with Uncertainty Estimation
These are neural network architectures that estimate the uncertainty of the predictions.
* [Evidential Deep Learning to Quantify Classification Uncertainty](evidence/index.html)
"""

View File

@ -0,0 +1,315 @@
"""
---
title: "Evidential Deep Learning to Quantify Classification Uncertainty"
summary: >
A PyTorch implementation/tutorial of the paper Evidential Deep Learning to Quantify Classification
Uncertainty.
---
# Evidential Deep Learning to Quantify Classification Uncertainty
This is a [PyTorch](https://pytorch.org) implementation of the paper
[Evidential Deep Learning to Quantify Classification Uncertainty](https://papers.labml.ai/paper/1806.01768).
[Dempster-Shafer Theory of Evidence](https://en.wikipedia.org/wiki/Dempster%E2%80%93Shafer_theory)
assigns belief masses to sets of classes (unlike assigning a probability to a single class).
The sum of the masses of all subsets is $1$.
Individual class probabilities (plausibilities) can be derived from these masses.
Assigning a mass to the set of all classes means the sample could be any one of the classes; i.e. saying "I don't know".
If there are $K$ classes, we assign masses $b_k \ge 0$ to each of the classes and
an overall uncertainty mass $u \ge 0$ to all classes.
$$u + \sum_{k=1}^K b_k = 1$$
Belief masses $b_k$ and $u$ can be computed from evidence $e_k \ge 0$, as $b_k = \frac{e_k}{S}$
and $u = \frac{K}{S}$ where $S = \sum_{k=1}^K (e_k + 1)$.
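These relations can be sketched numerically (the evidence values below are purely illustrative):

```python
import torch

# Hypothetical evidence for K = 3 classes
evidence = torch.tensor([4., 1., 0.])
n_classes = evidence.shape[0]

# S = sum_k (e_k + 1)
strength = (evidence + 1.).sum()
# Belief masses b_k = e_k / S and uncertainty mass u = K / S
belief = evidence / strength
uncertainty = n_classes / strength

# They form a valid mass assignment: u + sum_k b_k = 1
total_mass = uncertainty + belief.sum()
```

With more evidence the Dirichlet strength $S$ grows and the uncertainty mass $u = \frac{K}{S}$ shrinks.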
The paper uses the term evidence as a measure of the amount of support
collected from data in favor of classifying a sample into a certain class.
This corresponds to a [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution)
with parameters $\color{cyan}{\alpha_k} = e_k + 1$, and
$\color{cyan}{\alpha_0} = S = \sum_{k=1}^K \color{cyan}{\alpha_k}$ is known as the Dirichlet strength.
Dirichlet distribution $D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}})$
is a distribution over categorical distributions; i.e. you can sample class probabilities
from a Dirichlet distribution.
The expected probability for class $k$ is $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$.
We get the model to output evidences
$$\mathbf{e} = \color{cyan}{\mathbf{\alpha}} - 1 = f(\mathbf{x} | \Theta)$$
for a given input $\mathbf{x}$.
We use a function such as
[ReLU](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html) or a
[Softplus](https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html)
at the final layer to get $f(\mathbf{x} | \Theta) \ge 0$.
The paper proposes a few loss functions to train the model, which we have implemented below.
Here is the [training code `experiment.py`](experiment.html) to train a model on MNIST dataset.
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/f82b2bfc01ba11ecbb2aa16a33570106)
"""
import torch
from labml import tracker
from labml_helpers.module import Module
class MaximumLikelihoodLoss(Module):
"""
<a id="MaximumLikelihoodLoss"></a>
## Type II Maximum Likelihood Loss
The distribution $D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}})$ is a prior on the likelihood
$Multi(\mathbf{y} \vert \mathbf{p})$,
and the negative log marginal likelihood is calculated by integrating over class probabilities
$\mathbf{p}$.
If the target probabilities (one-hot targets) are $y_k$ for a given sample, the loss is,
\begin{align}
\mathcal{L}(\Theta)
&= -\log \Bigg(
\int
\prod_{k=1}^K p_k^{y_k}
\frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
\prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
d\mathbf{p}
\Bigg ) \\
&= \sum_{k=1}^K y_k \bigg( \log S - \log \color{cyan}{\alpha_k} \bigg)
\end{align}
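For a single sample this closed form can be evaluated directly (hypothetical numbers, using plain tensor ops rather than this module):

```python
import torch

# One sample, two classes; hypothetical evidence favouring class 0
evidence = torch.tensor([[3., 0.]])
target = torch.tensor([[1., 0.]])  # one-hot target

alpha = evidence + 1.                      # alpha = [4, 1]
strength = alpha.sum(dim=-1)               # S = 5
# sum_k y_k (log S - log alpha_k) = log 5 - log 4
loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1).mean()
```

Collecting more evidence for the correct class drives $\log S - \log \color{cyan}{\alpha_k}$ toward zero.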
"""
def forward(self, evidence: torch.Tensor, target: torch.Tensor):
"""
* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
"""
# $\color{cyan}{\alpha_k} = e_k + 1$
alpha = evidence + 1.
# $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
strength = alpha.sum(dim=-1)
# Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \log S - \log \color{cyan}{\alpha_k} \bigg)$
loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)
# Mean loss over the batch
return loss.mean()
class CrossEntropyBayesRisk(Module):
"""
<a id="CrossEntropyBayesRisk"></a>
## Bayes Risk with Cross Entropy Loss
Bayes risk is the expected cost of making incorrect estimates.
It takes a cost function that gives the cost of making an incorrect estimate
and integrates it over all possible outcomes, weighted by the probability distribution.
Here the cost function is cross-entropy loss, for one-hot coded $\mathbf{y}$
$$\sum_{k=1}^K -y_k \log p_k$$
We integrate this cost over all $\mathbf{p}$
\begin{align}
\mathcal{L}(\Theta)
&= \int
\Big[ \sum_{k=1}^K -y_k \log p_k \Big]
\frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
\prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
d\mathbf{p} \\
&= \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{cyan}{\alpha_k} ) \bigg)
\end{align}
where $\psi(\cdot)$ is the digamma function.
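The digamma identity $\mathbb{E}[-\log p_k] = \psi(S) - \psi(\color{cyan}{\alpha_k})$ under a Dirichlet can be checked with a quick Monte-Carlo estimate (illustrative parameters):

```python
import torch

torch.manual_seed(0)
alpha = torch.tensor([4., 2., 1.])       # hypothetical Dirichlet parameters
target = torch.tensor([1., 0., 0.])      # one-hot target

# Closed form: sum_k y_k (psi(S) - psi(alpha_k))
strength = alpha.sum()
closed = (target * (torch.digamma(strength) - torch.digamma(alpha))).sum()

# Monte-Carlo estimate of E[ sum_k -y_k log p_k ] with p ~ Dir(alpha)
p = torch.distributions.Dirichlet(alpha).sample((200_000,))
mc = (target * -p.log()).sum(dim=-1).mean()
```

The two estimates agree to a few decimal places, matching the derivation above.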
"""
def forward(self, evidence: torch.Tensor, target: torch.Tensor):
"""
* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
"""
# $\color{cyan}{\alpha_k} = e_k + 1$
alpha = evidence + 1.
# $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
strength = alpha.sum(dim=-1)
# Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{cyan}{\alpha_k} ) \bigg)$
loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)
# Mean loss over the batch
return loss.mean()
class SquaredErrorBayesRisk(Module):
"""
<a id="SquaredErrorBayesRisk"></a>
## Bayes Risk with Squared Error Loss
Here the cost function is squared error,
$$\sum_{k=1}^K (y_k - p_k)^2 = \Vert \mathbf{y} - \mathbf{p} \Vert_2^2$$
We integrate this cost over all $\mathbf{p}$
\begin{align}
\mathcal{L}(\Theta)
&= \int
\Big[ \sum_{k=1}^K (y_k - p_k)^2 \Big]
\frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
\prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
d\mathbf{p} \\
&= \sum_{k=1}^K \mathbb{E} \Big[ y_k^2 -2 y_k p_k + p_k^2 \Big] \\
&= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big)
\end{align}
Where $$\mathbb{E}[p_k] = \hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$$
is the expected probability when sampled from the Dirichlet distribution
and $$\mathbb{E}[p_k^2] = \mathbb{E}[p_k]^2 + \text{Var}(p_k)$$
where
$$\text{Var}(p_k) = \frac{\color{cyan}{\alpha_k}(S - \color{cyan}{\alpha_k})}{S^2 (S + 1)}
= \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}$$
is the variance.
This gives,
\begin{align}
\mathcal{L}(\Theta)
&= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big) \\
&= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k]^2 + \text{Var}(p_k) \Big) \\
&= \sum_{k=1}^K \Big( \big( y_k -\mathbb{E}[p_k] \big)^2 + \text{Var}(p_k) \Big) \\
&= \sum_{k=1}^K \Big( ( y_k -\hat{p}_k)^2 + \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1} \Big)
\end{align}
The first part of the equation, $\big(y_k -\mathbb{E}[p_k]\big)^2$, is the squared error term and
the second part is the variance.
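Since this loss is an exact expectation, it can be cross-checked against a Monte-Carlo estimate of $\mathbb{E}\big[\Vert \mathbf{y} - \mathbf{p} \Vert_2^2\big]$ (illustrative parameters):

```python
import torch

torch.manual_seed(0)
alpha = torch.tensor([4., 2., 1.])       # hypothetical Dirichlet parameters
target = torch.tensor([1., 0., 0.])      # one-hot target

# Closed form: sum_k (y_k - p_hat_k)^2 + p_hat_k (1 - p_hat_k) / (S + 1)
strength = alpha.sum()
p_hat = alpha / strength
closed = ((target - p_hat) ** 2 + p_hat * (1 - p_hat) / (strength + 1)).sum()

# Monte-Carlo estimate of E[ ||y - p||^2 ] with p ~ Dir(alpha)
p = torch.distributions.Dirichlet(alpha).sample((200_000,))
mc = ((target - p) ** 2).sum(dim=-1).mean()
```

Both evaluate to $\frac{5}{14}$ for these parameters, error term $\frac{2}{7}$ plus variance term $\frac{1}{14}$.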
"""
def forward(self, evidence: torch.Tensor, target: torch.Tensor):
"""
* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
"""
# $\color{cyan}{\alpha_k} = e_k + 1$
alpha = evidence + 1.
# $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
strength = alpha.sum(dim=-1)
# $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$
p = alpha / strength[:, None]
# Error $(y_k -\hat{p}_k)^2$
err = (target - p) ** 2
# Variance $\text{Var}(p_k) = \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}$
var = p * (1 - p) / (strength[:, None] + 1)
# Sum of them
loss = (err + var).sum(dim=-1)
# Mean loss over the batch
return loss.mean()
class KLDivergenceLoss(Module):
"""
<a id="KLDivergenceLoss"></a>
## KL Divergence Regularization Loss
This tries to shrink the total evidence to zero if the sample cannot be correctly classified.
First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \color{cyan}{\alpha_k}$, the
Dirichlet parameters after removing the non-misleading (correct) evidence.
\begin{align}
&KL \Big[ D(\mathbf{p} \vert \mathbf{\tilde{\alpha}}) \Big \Vert
D(\mathbf{p} \vert \langle 1, \dots, 1 \rangle) \Big] \\
&= \log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}
{\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg)
+ \sum_{k=1}^K (\tilde{\alpha}_k - 1)
\Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]
\end{align}
where $\Gamma(\cdot)$ is the gamma function,
$\psi(\cdot)$ is the digamma function and
$\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$.
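This closed form matches the generic Dirichlet-Dirichlet KL divergence, so it can be sanity-checked against `torch.distributions` (illustrative parameters):

```python
import torch
from torch.distributions import Dirichlet, kl_divergence

alpha_tilde = torch.tensor([1.0, 2.5, 4.0])  # hypothetical adjusted parameters
n_classes = alpha_tilde.shape[0]

# Manual closed form: the two terms above
first = (torch.lgamma(alpha_tilde.sum())
         - torch.lgamma(torch.tensor(float(n_classes)))
         - torch.lgamma(alpha_tilde).sum())
second = ((alpha_tilde - 1) *
          (torch.digamma(alpha_tilde) - torch.digamma(alpha_tilde.sum()))).sum()
manual = first + second

# Reference: KL( Dir(alpha_tilde) || Dir(<1, ..., 1>) )
reference = kl_divergence(Dirichlet(alpha_tilde), Dirichlet(torch.ones(n_classes)))
```

The uniform Dirichlet $D(\mathbf{p} \vert \langle 1, \dots, 1 \rangle)$ carries no evidence, so this KL term penalizes any leftover misleading evidence.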
"""
def forward(self, evidence: torch.Tensor, target: torch.Tensor):
"""
* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
"""
# $\color{cyan}{\alpha_k} = e_k + 1$
alpha = evidence + 1.
# Number of classes
n_classes = evidence.shape[-1]
# Remove non-misleading evidence
# $$\tilde{\alpha}_k = y_k + (1 - y_k) \color{cyan}{\alpha_k}$$
alpha_tilde = target + (1 - target) * alpha
# $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$
strength_tilde = alpha_tilde.sum(dim=-1)
# The first term
# \begin{align}
# &\log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}
# {\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg) \\
# &= \log \Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)
# - \log \Gamma(K)
# - \sum_{k=1}^K \log \Gamma(\tilde{\alpha}_k)
# \end{align}
first = (torch.lgamma(alpha_tilde.sum(dim=-1))
- torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
- (torch.lgamma(alpha_tilde)).sum(dim=-1))
# The second term
# $$\sum_{k=1}^K (\tilde{\alpha}_k - 1)
# \Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]$$
second = (
(alpha_tilde - 1) *
(torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
).sum(dim=-1)
# Sum of the terms
loss = first + second
# Mean loss over the batch
return loss.mean()
class TrackStatistics(Module):
"""
<a id="TrackStatistics"></a>
### Track statistics
This module computes statistics and tracks them with [labml `tracker`](https://docs.labml.ai/api/tracker.html).
"""
def forward(self, evidence: torch.Tensor, target: torch.Tensor):
# Number of classes
n_classes = evidence.shape[-1]
# Predictions that match the target (greedy selection of the highest-probability class)
match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))
# Track accuracy
tracker.add('accuracy.', match.sum() / match.shape[0])
# $\color{cyan}{\alpha_k} = e_k + 1$
alpha = evidence + 1.
# $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
strength = alpha.sum(dim=-1)
# $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$
expected_probability = alpha / strength[:, None]
# Expected probability of the selected (greedy highest probability) class
expected_probability, _ = expected_probability.max(dim=-1)
# Uncertainty mass $u = \frac{K}{S}$
uncertainty_mass = n_classes / strength
# Track $u$ for correct predictions
tracker.add('u.succ.', uncertainty_mass.masked_select(match))
# Track $u$ for incorrect predictions
tracker.add('u.fail.', uncertainty_mass.masked_select(~match))
# Track $\hat{p}_k$ for correct predictions
tracker.add('prob.succ.', expected_probability.masked_select(match))
# Track $\hat{p}_k$ for incorrect predictions
tracker.add('prob.fail.', expected_probability.masked_select(~match))

View File

@ -0,0 +1,225 @@
"""
---
title: "Evidential Deep Learning to Quantify Classification Uncertainty Experiment"
summary: >
This trains an EDL model on MNIST
---
# [Evidential Deep Learning to Quantify Classification Uncertainty](index.html) Experiment
This trains a model based on [Evidential Deep Learning to Quantify Classification Uncertainty](index.html)
on MNIST dataset.
"""
from typing import Any
import torch.nn as nn
import torch.utils.data
from labml import tracker, experiment
from labml.configs import option, calculate
from labml_helpers.module import Module
from labml_helpers.schedule import Schedule, RelativePiecewise
from labml_helpers.train_valid import BatchIndex
from labml_nn.experiments.mnist import MNISTConfigs
from labml_nn.uncertainty.evidence import KLDivergenceLoss, TrackStatistics, MaximumLikelihoodLoss, \
CrossEntropyBayesRisk, SquaredErrorBayesRisk
class Model(Module):
"""
## LeNet-based model for MNIST classification
"""
def __init__(self, dropout: float):
super().__init__()
# First $5 \times 5$ convolution layer
self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
# ReLU activation
self.act1 = nn.ReLU()
# $2 \times 2$ max-pooling
self.max_pool1 = nn.MaxPool2d(2, 2)
# Second $5 \times 5$ convolution layer
self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
# ReLU activation
self.act2 = nn.ReLU()
# $2 \times 2$ max-pooling
self.max_pool2 = nn.MaxPool2d(2, 2)
# First fully-connected layer that maps to $500$ features
self.fc1 = nn.Linear(50 * 4 * 4, 500)
# ReLU activation
self.act3 = nn.ReLU()
# Final fully connected layer to output evidence for $10$ classes.
# The ReLU or Softplus activation is applied to this outside the model to get the
# non-negative evidence
self.fc2 = nn.Linear(500, 10)
# Dropout for the hidden layer
self.dropout = nn.Dropout(p=dropout)
def __call__(self, x: torch.Tensor):
"""
* `x` is the batch of MNIST images of shape `[batch_size, 1, 28, 28]`
"""
# Apply first convolution and max pooling.
# The result has shape `[batch_size, 20, 12, 12]`
x = self.max_pool1(self.act1(self.conv1(x)))
# Apply second convolution and max pooling.
# The result has shape `[batch_size, 50, 4, 4]`
x = self.max_pool2(self.act2(self.conv2(x)))
# Flatten the tensor to shape `[batch_size, 50 * 4 * 4]`
x = x.view(x.shape[0], -1)
# Apply hidden layer
x = self.act3(self.fc1(x))
# Apply dropout
x = self.dropout(x)
# Apply final layer and return
return self.fc2(x)
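A quick shape check of the convolutional stack (a hypothetical mirror of the layers above, not the model class itself): a $28 \times 28$ input shrinks to $24 \to 12$ after the first conv + pool and $8 \to 4$ after the second, which is where the `50 * 4 * 4` input size of `fc1` comes from.

```python
import torch
import torch.nn as nn

# Mirror of the convolution / pooling stack, to confirm the flattened size
features = nn.Sequential(
    nn.Conv2d(1, 20, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(20, 50, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2, 2),
)
x = torch.zeros(1, 1, 28, 28)   # one dummy MNIST-sized image
out = features(x)               # shape [1, 50, 4, 4]
```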
class Configs(MNISTConfigs):
"""
## Configurations
We use [`MNISTConfigs`](../../experiments/mnist.html#MNISTConfigs) configurations.
"""
# [KL Divergence regularization](index.html#KLDivergenceLoss)
kl_div_loss = KLDivergenceLoss()
# KL Divergence regularization coefficient schedule
kl_div_coef: Schedule
# KL Divergence regularization coefficient schedule endpoints, as (relative progress, value) pairs
kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]
# [Stats module](index.html#TrackStatistics) for tracking
stats = TrackStatistics()
# Dropout
dropout: float = 0.5
# Module to convert the model output to non-negative evidence
outputs_to_evidence: Module
def init(self):
"""
### Initialization
"""
# Set tracker configurations
tracker.set_scalar("loss.*", True)
tracker.set_scalar("accuracy.*", True)
tracker.set_histogram('u.*', True)
tracker.set_histogram('prob.*', False)
tracker.set_scalar('annealing_coef.*', False)
tracker.set_scalar('kl_div_loss.*', False)
#
self.state_modules = []
def step(self, batch: Any, batch_idx: BatchIndex):
"""
### Training or validation step
"""
# Training/Evaluation mode
self.model.train(self.mode.is_train)
# Move data to the device
data, target = batch[0].to(self.device), batch[1].to(self.device)
# One-hot coded targets
eye = torch.eye(10).to(torch.float).to(self.device)
target = eye[target]
# Update global step (number of samples processed) when in training mode
if self.mode.is_train:
tracker.add_global_step(len(data))
# Get model outputs
outputs = self.model(data)
# Get evidences $e_k \ge 0$
evidence = self.outputs_to_evidence(outputs)
# Calculate loss
loss = self.loss_func(evidence, target)
# Calculate KL Divergence regularization loss
kl_div_loss = self.kl_div_loss(evidence, target)
tracker.add("loss.", loss)
tracker.add("kl_div_loss.", kl_div_loss)
# KL Divergence loss coefficient $\lambda_t$
annealing_coef = min(1., self.kl_div_coef(tracker.get_global_step()))
tracker.add("annealing_coef.", annealing_coef)
# Total loss
loss = loss + annealing_coef * kl_div_loss
# Track statistics
self.stats(evidence, target)
# Train the model
if self.mode.is_train:
# Calculate gradients
loss.backward()
# Take optimizer step
self.optimizer.step()
# Clear the gradients
self.optimizer.zero_grad()
# Save the tracked metrics
tracker.save()
@option(Configs.model)
def mnist_model(c: Configs):
"""
### Create model
"""
return Model(c.dropout).to(c.device)
@option(Configs.kl_div_coef)
def kl_div_coef(c: Configs):
"""
### KL Divergence Loss Coefficient Schedule
"""
# Create a [relative piecewise schedule](https://docs.labml.ai/api/helpers.html#labml_helpers.schedule.Piecewise)
return RelativePiecewise(c.kl_div_coef_schedule, c.epochs * len(c.train_dataset))
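The endpoints `[(0, 0.), (0.2, 0.01), (1, 1.)]` give the coefficient at relative training progress $0$, $0.2$ and $1$; values in between are assumed to be linearly interpolated. A plain-Python sketch of that interpolation (not the labml helper itself):

```python
def piecewise(points, t):
    """Linearly interpolate (relative_progress, value) endpoints at t in [0, 1]."""
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if t <= x1:
            return y0 + (y1 - y0) * (t - x0) / (x1 - x0)
    return points[-1][1]

schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]
```

So the KL regularization coefficient stays small early in training, letting the model collect evidence before misleading evidence is penalized heavily.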
# [Maximum Likelihood Loss](index.html#MaximumLikelihoodLoss)
calculate(Configs.loss_func, 'max_likelihood_loss', lambda: MaximumLikelihoodLoss())
# [Cross Entropy Bayes Risk](index.html#CrossEntropyBayesRisk)
calculate(Configs.loss_func, 'cross_entropy_bayes_risk', lambda: CrossEntropyBayesRisk())
# [Squared Error Bayes Risk](index.html#SquaredErrorBayesRisk)
calculate(Configs.loss_func, 'squared_error_bayes_risk', lambda: SquaredErrorBayesRisk())
# ReLU to calculate evidence
calculate(Configs.outputs_to_evidence, 'relu', lambda: nn.ReLU())
# Softplus to calculate evidence
calculate(Configs.outputs_to_evidence, 'softplus', lambda: nn.Softplus())
def main():
# Create experiment
experiment.create(name='evidence_mnist')
# Create configurations
conf = Configs()
# Load configurations
experiment.configs(conf, {
'optimizer.optimizer': 'Adam',
'optimizer.learning_rate': 0.001,
'optimizer.weight_decay': 0.005,
# 'loss_func': 'max_likelihood_loss',
# 'loss_func': 'cross_entropy_bayes_risk',
'loss_func': 'squared_error_bayes_risk',
'outputs_to_evidence': 'softplus',
'dropout': 0.5,
})
# Start the experiment and run the training loop
with experiment.start():
conf.run()
#
if __name__ == '__main__':
main()

View File

@ -0,0 +1,8 @@
# [Evidential Deep Learning to Quantify Classification Uncertainty](https://nn.labml.ai/uncertainty/evidence/index.html)
This is a [PyTorch](https://pytorch.org) implementation of the paper
[Evidential Deep Learning to Quantify Classification Uncertainty](https://papers.labml.ai/paper/1806.01768).
Here is the [training code `experiment.py`](https://nn.labml.ai/uncertainty/evidence/experiment.html) to train a model on MNIST dataset.
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/f82b2bfc01ba11ecbb2aa16a33570106)

View File

@ -0,0 +1,5 @@
# [Neural Networks with Uncertainty Estimation](https://nn.labml.ai/uncertainty/index.html)
These are neural network architectures that estimate the uncertainty of the predictions.
* [Evidential Deep Learning to Quantify Classification Uncertainty](https://nn.labml.ai/uncertainty/evidence/index.html)

View File

@ -99,6 +99,10 @@ Solving games with incomplete information such as poker with CFR.
* [PonderNet](https://nn.labml.ai/adaptive_computation/ponder_net/index.html)
#### ✨ [Uncertainty](https://nn.labml.ai/uncertainty/index.html)
* [Evidential Deep Learning to Quantify Classification Uncertainty](https://nn.labml.ai/uncertainty/evidence/index.html)
### Installation
```bash

View File

@ -1,6 +1,6 @@
torch>=1.7
labml>=0.4.94
labml-helpers>=0.4.77
labml>=0.4.132
labml-helpers>=0.4.81
torchvision
numpy>=1.16.3
matplotlib>=3.0.3

View File

@ -5,10 +5,10 @@ with open("readme.md", "r") as f:
setuptools.setup(
name='labml-nn',
version='0.4.109',
version='0.4.110',
author="Varuna Jayasiri, Nipun Wijerathne",
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
description="🧠 Implementations/tutorials of deep learning papers with side-by-side notes; including transformers (original, xl, switch, feedback, vit), optimizers(adam, radam, adabelief), gans(dcgan, cyclegan, stylegan2), reinforcement learning (ppo, dqn), capsnet, distillation, etc.",
description="🧑‍🏫 Implementations/tutorials of deep learning papers with side-by-side notes 📝; including transformers (original, xl, switch, feedback, vit), optimizers (adam, radam, adabelief), gans(dcgan, cyclegan, stylegan2), 🎮 reinforcement learning (ppo, dqn), capsnet, distillation, etc. 🧠",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/labmlai/annotated_deep_learning_paper_implementations",
@ -20,7 +20,7 @@ setuptools.setup(
'labml_helpers', 'labml_helpers.*',
'test',
'test.*')),
install_requires=['labml>=0.4.129',
install_requires=['labml>=0.4.132',
'labml-helpers>=0.4.81',
'torch',
'einops',