mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-06 15:22:21 +08:00
Evidential Deep Learning to Quantify Classification Uncertainty (#85)
This commit is contained in:
@ -154,15 +154,19 @@ implementations.</p>
|
||||
<ul>
|
||||
<li><a href="adaptive_computation/ponder_net/index.html">PonderNet</a></li>
|
||||
</ul>
|
||||
<h4>✨ <a href="uncertainty/index.html">Uncertainty</a></h4>
|
||||
<ul>
|
||||
<li><a href="uncertainty/evidence/index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a></li>
|
||||
</ul>
|
||||
<h3>Installation</h3>
|
||||
<pre><code class="bash">pip install labml-nn
|
||||
</code></pre>
|
||||
|
||||
<h3>Citing LabML</h3>
|
||||
<p>If you use LabML for academic research, please cite the library using the following BibTeX entry.</p>
|
||||
<p>If you use this for academic research, please cite it using the following BibTeX entry.</p>
|
||||
<pre><code class="bibtex">@misc{labml,
|
||||
author = {Varuna Jayasiri, Nipun Wijerathne},
|
||||
title = {LabML: A library to organize machine learning experiments},
|
||||
title = {labml.ai Annotated Paper Implementations},
|
||||
year = {2020},
|
||||
url = {https://nn.labml.ai/},
|
||||
}
|
||||
|
@ -268,7 +268,10 @@ and set a new function to calculate the model.</p>
|
||||
<p>Load configurations</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">75</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span><span class="s1">'optimizer.optimizer'</span><span class="p">:</span> <span class="s1">'Adam'</span><span class="p">})</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">75</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span>
|
||||
<span class="lineno">76</span> <span class="s1">'optimizer.optimizer'</span><span class="p">:</span> <span class="s1">'Adam'</span><span class="p">,</span>
|
||||
<span class="lineno">77</span> <span class="s1">'optimizer.learning_rate'</span><span class="p">:</span> <span class="mf">0.001</span><span class="p">,</span>
|
||||
<span class="lineno">78</span> <span class="p">})</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-17'>
|
||||
@ -279,8 +282,8 @@ and set a new function to calculate the model.</p>
|
||||
<p>Start the experiment and run the training loop</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">77</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
|
||||
<span class="lineno">78</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">80</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
|
||||
<span class="lineno">81</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-18'>
|
||||
@ -291,8 +294,8 @@ and set a new function to calculate the model.</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">82</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||
<span class="lineno">83</span> <span class="n">main</span><span class="p">()</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">85</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||
<span class="lineno">86</span> <span class="n">main</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
|
@ -114,6 +114,9 @@
|
||||
"1704.03477": [
|
||||
"https://nn.labml.ai/sketch_rnn/index.html"
|
||||
],
|
||||
"1806.01768": [
|
||||
"https://nn.labml.ai/uncertainty/evidence/index.html"
|
||||
],
|
||||
"1509.06461": [
|
||||
"https://nn.labml.ai/rl/dqn/index.html"
|
||||
],
|
||||
|
@ -204,7 +204,7 @@
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/normalization/batch_norm/mnist.html</loc>
|
||||
<lastmod>2021-08-19T16:30:00+00:00</lastmod>
|
||||
<lastmod>2021-08-20T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
@ -281,7 +281,7 @@
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/index.html</loc>
|
||||
<lastmod>2021-08-12T16:30:00+00:00</lastmod>
|
||||
<lastmod>2021-08-21T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
@ -797,6 +797,27 @@
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/uncertainty/evidence/index.html</loc>
|
||||
<lastmod>2021-08-21T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/uncertainty/evidence/experiment.html</loc>
|
||||
<lastmod>2021-08-21T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/uncertainty/index.html</loc>
|
||||
<lastmod>2021-08-21T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/rl/game.html</loc>
|
||||
<lastmod>2020-12-10T16:30:00+00:00</lastmod>
|
||||
|
867
docs/uncertainty/evidence/experiment.html
Normal file
867
docs/uncertainty/evidence/experiment.html
Normal file
@ -0,0 +1,867 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content="This trains is EDL model on MNIST"/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="Evidential Deep Learning to Quantify Classification Uncertainty Experiment"/>
|
||||
<meta name="twitter:description" content="This trains is EDL model on MNIST"/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/uncertainty/evidence/experiment.html"/>
|
||||
<meta property="og:title" content="Evidential Deep Learning to Quantify Classification Uncertainty Experiment"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="Evidential Deep Learning to Quantify Classification Uncertainty Experiment"/>
|
||||
<meta property="og:description" content="This trains is EDL model on MNIST"/>
|
||||
|
||||
<title>Evidential Deep Learning to Quantify Classification Uncertainty Experiment</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/uncertainty/evidence/experiment.html"/>
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../index.html">uncertainty</a>
|
||||
<a class="parent" href="index.html">evidence</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/uncertainty/evidence/experiment.py">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1><a href="index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a> Experiment</h1>
|
||||
<p>This trains a model based on <a href="index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a>
|
||||
on MNIST dataset.</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">14</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span>
|
||||
<span class="lineno">15</span>
|
||||
<span class="lineno">16</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
|
||||
<span class="lineno">17</span><span class="kn">import</span> <span class="nn">torch.utils.data</span>
|
||||
<span class="lineno">18</span>
|
||||
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">experiment</span>
|
||||
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span><span class="p">,</span> <span class="n">calculate</span>
|
||||
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
|
||||
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_helpers.schedule</span> <span class="kn">import</span> <span class="n">Schedule</span><span class="p">,</span> <span class="n">RelativePiecewise</span>
|
||||
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_helpers.train_valid</span> <span class="kn">import</span> <span class="n">BatchIndex</span>
|
||||
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.mnist</span> <span class="kn">import</span> <span class="n">MNISTConfigs</span>
|
||||
<span class="lineno">25</span><span class="kn">from</span> <span class="nn">labml_nn.uncertainty.evidence</span> <span class="kn">import</span> <span class="n">KLDivergenceLoss</span><span class="p">,</span> <span class="n">TrackStatistics</span><span class="p">,</span> <span class="n">MaximumLikelihoodLoss</span><span class="p">,</span> \
|
||||
<span class="lineno">26</span> <span class="n">CrossEntropyBayesRisk</span><span class="p">,</span> <span class="n">SquaredErrorBayesRisk</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<h2>LeNet-based model for MNIST classification</h2>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">29</span><span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">34</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
|
||||
<span class="lineno">35</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
<p>First $5x5$ convolution layer</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">37</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
<p>ReLU activation</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">39</span> <span class="bp">self</span><span class="o">.</span><span class="n">act1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
<p>$2x2$ max-pooling</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">41</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_pool1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
<p>Second $5x5$ convolution layer</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">43</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
<p>ReLU activation</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">45</span> <span class="bp">self</span><span class="o">.</span><span class="n">act2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-8'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-8'>#</a>
|
||||
</div>
|
||||
<p>$2x2$ max-pooling</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">47</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_pool2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-9'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-9'>#</a>
|
||||
</div>
|
||||
<p>First fully-connected layer that maps to $500$ features</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">49</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">50</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">500</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-10'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
<p>ReLU activation</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">51</span> <span class="bp">self</span><span class="o">.</span><span class="n">act3</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-11'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
<p>Final fully connected layer to output evidence for $10$ classes.
|
||||
The ReLU or Softplus activation is applied to this outside the model to get the
|
||||
non-negative evidence</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">55</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">500</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-12'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
<p>Dropout for the hidden layer</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">57</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">p</span><span class="o">=</span><span class="n">dropout</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-13'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-13'>#</a>
|
||||
</div>
|
||||
<ul>
|
||||
<li><code>x</code> is the batch of MNIST images of shape <code>[batch_size, 1, 28, 28]</code></li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">59</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-14'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-14'>#</a>
|
||||
</div>
|
||||
<p>Apply first convolution and max pooling.
|
||||
The result has shape <code>[batch_size, 20, 12, 12]</code></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">65</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_pool1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">act1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-15'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-15'>#</a>
|
||||
</div>
|
||||
<p>Apply second convolution and max pooling.
|
||||
The result has shape <code>[batch_size, 50, 4, 4]</code></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">68</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_pool2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">act2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-16'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-16'>#</a>
|
||||
</div>
|
||||
<p>Flatten the tensor to shape <code>[batch_size, 50 * 4 * 4]</code></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-17'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-17'>#</a>
|
||||
</div>
|
||||
<p>Apply hidden layer</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">72</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act3</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-18'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-18'>#</a>
|
||||
</div>
|
||||
<p>Apply dropout</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">74</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-19'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-19'>#</a>
|
||||
</div>
|
||||
<p>Apply final layer and return</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">76</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-20'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-20'>#</a>
|
||||
</div>
|
||||
<h2>Configurations</h2>
|
||||
<p>We use <a href="../../experiments/mnist.html#MNISTConfigs"><code>MNISTConfigs</code></a> configurations.</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">79</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">MNISTConfigs</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-21'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-21'>#</a>
|
||||
</div>
|
||||
<p><a href="index.html#KLDivergenceLoss">KL Divergence regularization</a></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">87</span> <span class="n">kl_div_loss</span> <span class="o">=</span> <span class="n">KLDivergenceLoss</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-22'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-22'>#</a>
|
||||
</div>
|
||||
<p>KL Divergence regularization coefficient schedule</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">89</span> <span class="n">kl_div_coef</span><span class="p">:</span> <span class="n">Schedule</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-23'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-23'>#</a>
|
||||
</div>
|
||||
<p>KL Divergence regularization coefficient schedule</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">91</span> <span class="n">kl_div_coef_schedule</span> <span class="o">=</span> <span class="p">[(</span><span class="mi">0</span><span class="p">,</span> <span class="mf">0.</span><span class="p">),</span> <span class="p">(</span><span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mf">1.</span><span class="p">)]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-24'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-24'>#</a>
|
||||
</div>
|
||||
<p><a href="index.html#TrackStatistics">Stats module</a> for tracking</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">93</span> <span class="n">stats</span> <span class="o">=</span> <span class="n">TrackStatistics</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-25'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-25'>#</a>
|
||||
</div>
|
||||
<p>Dropout</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">95</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-26'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-26'>#</a>
|
||||
</div>
|
||||
<p>Module to convert the model output to non-negative evidence</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">97</span> <span class="n">outputs_to_evidence</span><span class="p">:</span> <span class="n">Module</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-27'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-27'>#</a>
|
||||
</div>
|
||||
<h3>Initialization</h3>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">99</span> <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-28'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-28'>#</a>
|
||||
</div>
|
||||
<p>Set tracker configurations</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">104</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s2">"loss.*"</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
|
||||
<span class="lineno">105</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s2">"accuracy.*"</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
|
||||
<span class="lineno">106</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_histogram</span><span class="p">(</span><span class="s1">'u.*'</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
|
||||
<span class="lineno">107</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_histogram</span><span class="p">(</span><span class="s1">'prob.*'</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
|
||||
<span class="lineno">108</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s1">'annealing_coef.*'</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
|
||||
<span class="lineno">109</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s1">'kl_div_loss.*'</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-29'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-29'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">112</span> <span class="bp">self</span><span class="o">.</span><span class="n">state_modules</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-30'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-30'>#</a>
|
||||
</div>
|
||||
<h3>Training or validation step</h3>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">114</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="n">BatchIndex</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-31'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-31'>#</a>
|
||||
</div>
|
||||
<p>Training/Evaluation mode</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">120</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-32'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-32'>#</a>
|
||||
</div>
|
||||
<p>Move data to the device</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">123</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-33'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-33'>#</a>
|
||||
</div>
|
||||
<p>One-hot coded targets</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">126</span> <span class="n">eye</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="mi">10</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||||
<span class="lineno">127</span> <span class="n">target</span> <span class="o">=</span> <span class="n">eye</span><span class="p">[</span><span class="n">target</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-34'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-34'>#</a>
|
||||
</div>
|
||||
<p>Update global step (number of samples processed) when in training mode</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">130</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
|
||||
<span class="lineno">131</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-35'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-35'>#</a>
|
||||
</div>
|
||||
<p>Get model outputs</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">134</span> <span class="n">outputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-36'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-36'>#</a>
|
||||
</div>
|
||||
<p>Get evidences $e_k \ge 0$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">136</span> <span class="n">evidence</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">outputs_to_evidence</span><span class="p">(</span><span class="n">outputs</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-37'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-37'>#</a>
|
||||
</div>
|
||||
<p>Calculate loss</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">139</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">evidence</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-38'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-38'>#</a>
|
||||
</div>
|
||||
<p>Calculate KL Divergence regularization loss</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">141</span> <span class="n">kl_div_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_div_loss</span><span class="p">(</span><span class="n">evidence</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
|
||||
<span class="lineno">142</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">"loss."</span><span class="p">,</span> <span class="n">loss</span><span class="p">)</span>
|
||||
<span class="lineno">143</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">"kl_div_loss."</span><span class="p">,</span> <span class="n">kl_div_loss</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-39'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-39'>#</a>
|
||||
</div>
|
||||
<p>KL Divergence loss coefficient $\lambda_t$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">146</span> <span class="n">annealing_coef</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="mf">1.</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_div_coef</span><span class="p">(</span><span class="n">tracker</span><span class="o">.</span><span class="n">get_global_step</span><span class="p">()))</span>
|
||||
<span class="lineno">147</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">"annealing_coef."</span><span class="p">,</span> <span class="n">annealing_coef</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-40'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-40'>#</a>
|
||||
</div>
|
||||
<p>Total loss</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">150</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">loss</span> <span class="o">+</span> <span class="n">annealing_coef</span> <span class="o">*</span> <span class="n">kl_div_loss</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-41'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-41'>#</a>
|
||||
</div>
|
||||
<p>Track statistics</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">153</span> <span class="bp">self</span><span class="o">.</span><span class="n">stats</span><span class="p">(</span><span class="n">evidence</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-42'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-42'>#</a>
|
||||
</div>
|
||||
<p>Train the model</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">156</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-43'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-43'>#</a>
|
||||
</div>
|
||||
<p>Calculate gradients</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">158</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-44'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-44'>#</a>
|
||||
</div>
|
||||
<p>Take optimizer step</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">160</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-45'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-45'>#</a>
|
||||
</div>
|
||||
<p>Clear the gradients</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">162</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-46'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-46'>#</a>
|
||||
</div>
|
||||
<p>Save the tracked metrics</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">165</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-47'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-47'>#</a>
|
||||
</div>
|
||||
<h3>Create model</h3>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">168</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
|
||||
<span class="lineno">169</span><span class="k">def</span> <span class="nf">mnist_model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-48'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-48'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">173</span> <span class="k">return</span> <span class="n">Model</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-49'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-49'>#</a>
|
||||
</div>
|
||||
<h3>KL Divergence Loss Coefficient Schedule</h3>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">176</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">kl_div_coef</span><span class="p">)</span>
|
||||
<span class="lineno">177</span><span class="k">def</span> <span class="nf">kl_div_coef</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-50'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-50'>#</a>
|
||||
</div>
|
||||
<p>Create a <a href="https://docs.labml.ai/api/helpers.html#labml_helpers.schedule.Piecewise">relative piecewise schedule</a></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">183</span> <span class="k">return</span> <span class="n">RelativePiecewise</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">kl_div_coef_schedule</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">epochs</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">train_dataset</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-51'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-51'>#</a>
|
||||
</div>
|
||||
<p><a href="index.html#MaximumLikelihoodLoss">Maximum Likelihood Loss</a></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">187</span><span class="n">calculate</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">loss_func</span><span class="p">,</span> <span class="s1">'max_likelihood_loss'</span><span class="p">,</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">MaximumLikelihoodLoss</span><span class="p">())</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-52'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-52'>#</a>
|
||||
</div>
|
||||
<p><a href="index.html#CrossEntropyBayesRisk">Cross Entropy Bayes Risk</a></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">189</span><span class="n">calculate</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">loss_func</span><span class="p">,</span> <span class="s1">'cross_entropy_bayes_risk'</span><span class="p">,</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">CrossEntropyBayesRisk</span><span class="p">())</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-53'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-53'>#</a>
|
||||
</div>
|
||||
<p><a href="index.html#SquaredErrorBayesRisk">Squared Error Bayes Risk</a></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">191</span><span class="n">calculate</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">loss_func</span><span class="p">,</span> <span class="s1">'squared_error_bayes_risk'</span><span class="p">,</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">SquaredErrorBayesRisk</span><span class="p">())</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-54'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-54'>#</a>
|
||||
</div>
|
||||
<p>ReLU to calculate evidence</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">194</span><span class="n">calculate</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">outputs_to_evidence</span><span class="p">,</span> <span class="s1">'relu'</span><span class="p">,</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">())</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-55'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-55'>#</a>
|
||||
</div>
|
||||
<p>Softplus to calculate evidence</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">196</span><span class="n">calculate</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">outputs_to_evidence</span><span class="p">,</span> <span class="s1">'softplus'</span><span class="p">,</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softplus</span><span class="p">())</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-56'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-56'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">199</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-57'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-57'>#</a>
|
||||
</div>
|
||||
<p>Create experiment</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">201</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">'evidence_mnist'</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-58'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-58'>#</a>
|
||||
</div>
|
||||
<p>Create configurations</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">203</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-59'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-59'>#</a>
|
||||
</div>
|
||||
<p>Load configurations</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">205</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span>
|
||||
<span class="lineno">206</span> <span class="s1">'optimizer.optimizer'</span><span class="p">:</span> <span class="s1">'Adam'</span><span class="p">,</span>
|
||||
<span class="lineno">207</span> <span class="s1">'optimizer.learning_rate'</span><span class="p">:</span> <span class="mf">0.001</span><span class="p">,</span>
|
||||
<span class="lineno">208</span> <span class="s1">'optimizer.weight_decay'</span><span class="p">:</span> <span class="mf">0.005</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-60'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-60'>#</a>
|
||||
</div>
|
||||
<p>‘loss_func’: ‘max_likelihood_loss’,
|
||||
‘loss_func’: ‘cross_entropy_bayes_risk’,</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">212</span> <span class="s1">'loss_func'</span><span class="p">:</span> <span class="s1">'squared_error_bayes_risk'</span><span class="p">,</span>
|
||||
<span class="lineno">213</span>
|
||||
<span class="lineno">214</span> <span class="s1">'outputs_to_evidence'</span><span class="p">:</span> <span class="s1">'softplus'</span><span class="p">,</span>
|
||||
<span class="lineno">215</span>
|
||||
<span class="lineno">216</span> <span class="s1">'dropout'</span><span class="p">:</span> <span class="mf">0.5</span><span class="p">,</span>
|
||||
<span class="lineno">217</span> <span class="p">})</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-61'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-61'>#</a>
|
||||
</div>
|
||||
<p>Start the experiment and run the training loop</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">219</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
|
||||
<span class="lineno">220</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-62'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-62'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">224</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||
<span class="lineno">225</span> <span class="n">main</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
</script>
|
||||
<script>
|
||||
// Attach the click-to-zoom modal behavior to every image that sits
// directly inside a paragraph.
function handleImages() {
    const images = document.querySelectorAll('p>img')

    console.log(images);
    for (let idx = 0; idx < images.length; ++idx) {
        handleImage(images[idx])
    }
}
|
||||
|
||||
// Wire up a single image: center it, and build a modal overlay that
// shows a full-size copy when the image is clicked.
function handleImage(img) {
    // Center the image within its paragraph.
    img.parentElement.style.textAlign = 'center'

    // Modal structure: overlay > content > zoomed image, plus a close button.
    const overlay = document.createElement('div')
    overlay.id = 'modal'

    const content = document.createElement('div')
    const zoomed = document.createElement('img')
    content.appendChild(zoomed)
    overlay.appendChild(content)

    const closeBtn = document.createElement('span')
    closeBtn.classList.add('close')
    closeBtn.textContent = 'x'
    overlay.appendChild(closeBtn)

    // Clicking the image attaches the modal and points it at the same source.
    img.onclick = () => {
        console.log('clicked')
        document.body.appendChild(overlay)
        zoomed.src = img.src
    }

    // Clicking the close button detaches the modal again.
    closeBtn.onclick = () => {
        document.body.removeChild(overlay)
    }
}
|
||||
|
||||
handleImages()
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
798
docs/uncertainty/evidence/index.html
Normal file
798
docs/uncertainty/evidence/index.html
Normal file
@ -0,0 +1,798 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content="A PyTorch implementation/tutorial of the paper Evidential Deep Learning to Quantify Classification Uncertainty."/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="Evidential Deep Learning to Quantify Classification Uncertainty"/>
|
||||
<meta name="twitter:description" content="A PyTorch implementation/tutorial of the paper Evidential Deep Learning to Quantify Classification Uncertainty."/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/uncertainty/evidence/index.html"/>
|
||||
<meta property="og:title" content="Evidential Deep Learning to Quantify Classification Uncertainty"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="Evidential Deep Learning to Quantify Classification Uncertainty"/>
|
||||
<meta property="og:description" content="A PyTorch implementation/tutorial of the paper Evidential Deep Learning to Quantify Classification Uncertainty."/>
|
||||
|
||||
<title>Evidential Deep Learning to Quantify Classification Uncertainty</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/uncertainty/evidence/index.html"/>
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../index.html">uncertainty</a>
|
||||
<a class="parent" href="index.html">evidence</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/uncertainty/evidence/__init__.py">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>Evidential Deep Learning to Quantify Classification Uncertainty</h1>
|
||||
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the paper
|
||||
<a href="https://papers.labml.ai/paper/1806.01768">Evidential Deep Learning to Quantify Classification Uncertainty</a>.</p>
|
||||
<p><a href="https://en.wikipedia.org/wiki/Dempster%E2%80%93Shafer_theory">Dempster-Shafer Theory of Evidence</a>
|
||||
assigns belief masses to a set of classes (unlike assigning a probability to a single class).
|
||||
The sum of the masses of all subsets is $1$.
|
||||
Individual class probabilities (plausibilities) can be derived from these masses.</p>
|
||||
<p>Assigning a mass to the set of all classes means it can be any one of the classes; i.e. saying “I don’t know”.</p>
|
||||
<p>If there are $K$ classes, we assign masses $b_k \ge 0$ to each of the classes and
|
||||
an overall uncertainty mass $u \ge 0$ to all classes.</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">u + \sum_{k=1}^K b_k = 1</script>
|
||||
</p>
|
||||
<p>Belief masses $b_k$ and $u$ can be computed from evidence $e_k \ge 0$, as $b_k = \frac{e_k}{S}$
|
||||
and $u = \frac{K}{S}$ where $S = \sum_{k=1}^K (e_k + 1)$.
|
||||
Paper uses term evidence as a measure of the amount of support
|
||||
collected from data in favor of a sample to be classified into a certain class.</p>
|
||||
<p>This corresponds to a <a href="https://en.wikipedia.org/wiki/Dirichlet_distribution">Dirichlet distribution</a>
|
||||
with parameters $\color{cyan}{\alpha_k} = e_k + 1$, and
|
||||
$\color{cyan}{\alpha_0} = S = \sum_{k=1}^K \color{cyan}{\alpha_k}$ is known as the Dirichlet strength.
|
||||
Dirichlet distribution $D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}})$
|
||||
is a distribution over categorical distribution; i.e. you can sample class probabilities
|
||||
from a Dirichlet distribution.
|
||||
The expected probability for class $k$ is $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$.</p>
|
||||
<p>We get the model to output evidences
|
||||
<script type="math/tex; mode=display">\mathbf{e} = \color{cyan}{\mathbf{\alpha}} - 1 = f(\mathbf{x} | \Theta)</script>
|
||||
for a given input $\mathbf{x}$.
|
||||
We use a function such as
|
||||
<a href="https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html">ReLU</a> or a
|
||||
<a href="https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html">Softplus</a>
|
||||
at the final layer to get $f(\mathbf{x} | \Theta) \ge 0$.</p>
|
||||
<p>The paper proposes a few loss functions to train the model, which we have implemented below.</p>
|
||||
<p>Here is the <a href="experiment.html">training code <code>experiment.py</code></a> to train a model on MNIST dataset.</p>
|
||||
<p><a href="https://app.labml.ai/run/f82b2bfc01ba11ecbb2aa16a33570106"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">54</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="lineno">55</span>
|
||||
<span class="lineno">56</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">tracker</span>
|
||||
<span class="lineno">57</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<p><a id="MaximumLikelihoodLoss"></a></p>
|
||||
<h2>Type II Maximum Likelihood Loss</h2>
|
||||
<p>The distribution $D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}})$ is a prior on the likelihood
|
||||
$Multi(\mathbf{y} \vert \mathbf{p})$,
|
||||
and the negative log marginal likelihood is calculated by integrating over class probabilities
|
||||
$\mathbf{p}$.</p>
|
||||
<p>If target probabilities (one-hot targets) are $y_k$ for a given sample the loss is,</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
\mathcal{L}(\Theta)
|
||||
&= -\log \Bigg(
|
||||
\int
|
||||
\prod_{k=1}^K p_k^{y_k}
|
||||
\frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
|
||||
\prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
|
||||
d\mathbf{p}
|
||||
\Bigg ) \\
|
||||
&= \sum_{k=1}^K y_k \bigg( \log S - \log \color{cyan}{\alpha_k} \bigg)
|
||||
\end{align}</script>
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">60</span><span class="k">class</span> <span class="nc">MaximumLikelihoodLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
<ul>
|
||||
<li><code>evidence</code> is $\mathbf{e} \ge 0$ with shape <code>[batch_size, n_classes]</code></li>
|
||||
<li><code>target</code> is $\mathbf{y}$ with shape <code>[batch_size, n_classes]</code></li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">84</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">evidence</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
<p>$\color{cyan}{\alpha_k} = e_k + 1$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">90</span> <span class="n">alpha</span> <span class="o">=</span> <span class="n">evidence</span> <span class="o">+</span> <span class="mf">1.</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
<p>$S = \sum_{k=1}^K \color{cyan}{\alpha_k}$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">92</span> <span class="n">strength</span> <span class="o">=</span> <span class="n">alpha</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
<p>Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \log S - \log \color{cyan}{\alpha_k} \bigg)$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">95</span> <span class="n">loss</span> <span class="o">=</span> <span class="p">(</span><span class="n">target</span> <span class="o">*</span> <span class="p">(</span><span class="n">strength</span><span class="o">.</span><span class="n">log</span><span class="p">()[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">-</span> <span class="n">alpha</span><span class="o">.</span><span class="n">log</span><span class="p">()))</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
<p>Mean loss over the batch</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">98</span> <span class="k">return</span> <span class="n">loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
<p><a id="CrossEntropyBayesRisk"></a></p>
|
||||
<h2>Bayes Risk with Cross Entropy Loss</h2>
|
||||
<p>Bayes risk is the expected cost of making incorrect estimates.
|
||||
It takes a cost function that gives the cost of making an incorrect estimate
|
||||
and sums it over all possible outcomes based on probability distribution.</p>
|
||||
<p>Here the cost function is cross-entropy loss, for one-hot coded $\mathbf{y}$
|
||||
<script type="math/tex; mode=display">\sum_{k=1}^K -y_k \log p_k</script>
|
||||
</p>
|
||||
<p>We integrate this cost over all $\mathbf{p}$</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
\mathcal{L}(\Theta)
|
||||
&=
|
||||
\int
|
||||
\Big[ \sum_{k=1}^K -y_k \log p_k \Big]
|
||||
\frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
|
||||
\prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
|
||||
d\mathbf{p}
|
||||
\\
|
||||
&= \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{cyan}{\alpha_k} ) \bigg)
|
||||
\end{align}</script>
|
||||
</p>
|
||||
<p>where $\psi(\cdot)$ is the $digamma$ function.</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">101</span><span class="k">class</span> <span class="nc">CrossEntropyBayesRisk</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-8'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-8'>#</a>
|
||||
</div>
|
||||
<ul>
|
||||
<li><code>evidence</code> is $\mathbf{e} \ge 0$ with shape <code>[batch_size, n_classes]</code></li>
|
||||
<li><code>target</code> is $\mathbf{y}$ with shape <code>[batch_size, n_classes]</code></li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">130</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">evidence</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-9'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-9'>#</a>
|
||||
</div>
|
||||
<p>$\color{cyan}{\alpha_k} = e_k + 1$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">136</span> <span class="n">alpha</span> <span class="o">=</span> <span class="n">evidence</span> <span class="o">+</span> <span class="mf">1.</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-10'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
<p>$S = \sum_{k=1}^K \color{cyan}{\alpha_k}$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">138</span> <span class="n">strength</span> <span class="o">=</span> <span class="n">alpha</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-11'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
<p>Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{cyan}{\alpha_k} ) \bigg)$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">141</span> <span class="n">loss</span> <span class="o">=</span> <span class="p">(</span><span class="n">target</span> <span class="o">*</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">digamma</span><span class="p">(</span><span class="n">strength</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">digamma</span><span class="p">(</span><span class="n">alpha</span><span class="p">)))</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-12'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
<p>Mean loss over the batch</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">144</span> <span class="k">return</span> <span class="n">loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-13'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-13'>#</a>
|
||||
</div>
|
||||
<p><a id="SquaredErrorBayesRisk"></a></p>
|
||||
<h2>Bayes Risk with Squared Error Loss</h2>
|
||||
<p>Here the cost function is squared error,
|
||||
<script type="math/tex; mode=display">\sum_{k=1}^K (y_k - p_k)^2 = \Vert \mathbf{y} - \mathbf{p} \Vert_2^2</script>
|
||||
</p>
|
||||
<p>We integrate this cost over all $\mathbf{p}$</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
\mathcal{L}(\Theta)
|
||||
&=
|
||||
\int
|
||||
\Big[ \sum_{k=1}^K (y_k - p_k)^2 \Big]
|
||||
\frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
|
||||
\prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
|
||||
d\mathbf{p}
|
||||
\\
|
||||
&= \sum_{k=1}^K \mathbb{E} \Big[ y_k^2 -2 y_k p_k + p_k^2 \Big] \\
|
||||
&= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big)
|
||||
\end{align}</script>
|
||||
</p>
|
||||
<p>Where <script type="math/tex; mode=display">\mathbb{E}[p_k] = \hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}</script>
|
||||
is the expected probability when sampled from the Dirichlet distribution
|
||||
and <script type="math/tex; mode=display">\mathbb{E}[p_k^2] = \mathbb{E}[p_k]^2 + \text{Var}(p_k)</script>
|
||||
where
|
||||
<script type="math/tex; mode=display">\text{Var}(p_k) = \frac{\color{cyan}{\alpha_k}(S - \color{cyan}{\alpha_k})}{S^2 (S + 1)}
|
||||
= \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}</script>
|
||||
is the variance.</p>
|
||||
<p>This gives,
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
\mathcal{L}(\Theta)
|
||||
&= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big) \\
|
||||
&= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k]^2 + \text{Var}(p_k) \Big) \\
|
||||
&= \sum_{k=1}^K \Big( \big( y_k -\mathbb{E}[p_k] \big)^2 + \text{Var}(p_k) \Big) \\
|
||||
&= \sum_{k=1}^K \Big( ( y_k -\hat{p}_k)^2 + \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1} \Big)
|
||||
\end{align}</script>
|
||||
</p>
|
||||
<p>This first part of the equation $\big(y_k -\mathbb{E}[p_k]\big)^2$ is the error term and
|
||||
the second part is the variance.</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">147</span><span class="k">class</span> <span class="nc">SquaredErrorBayesRisk</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-14'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-14'>#</a>
|
||||
</div>
|
||||
<ul>
|
||||
<li><code>evidence</code> is $\mathbf{e} \ge 0$ with shape <code>[batch_size, n_classes]</code></li>
|
||||
<li><code>target</code> is $\mathbf{y}$ with shape <code>[batch_size, n_classes]</code></li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">191</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">evidence</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-15'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-15'>#</a>
|
||||
</div>
|
||||
<p>$\color{cyan}{\alpha_k} = e_k + 1$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">alpha</span> <span class="o">=</span> <span class="n">evidence</span> <span class="o">+</span> <span class="mf">1.</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-16'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-16'>#</a>
|
||||
</div>
|
||||
<p>$S = \sum_{k=1}^K \color{cyan}{\alpha_k}$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">199</span> <span class="n">strength</span> <span class="o">=</span> <span class="n">alpha</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-17'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-17'>#</a>
|
||||
</div>
|
||||
<p>$\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">201</span> <span class="n">p</span> <span class="o">=</span> <span class="n">alpha</span> <span class="o">/</span> <span class="n">strength</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-18'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-18'>#</a>
|
||||
</div>
|
||||
<p>Error $(y_k -\hat{p}_k)^2$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">204</span> <span class="n">err</span> <span class="o">=</span> <span class="p">(</span><span class="n">target</span> <span class="o">-</span> <span class="n">p</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-19'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-19'>#</a>
|
||||
</div>
|
||||
<p>Variance $\text{Var}(p_k) = \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">206</span> <span class="n">var</span> <span class="o">=</span> <span class="n">p</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">p</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">strength</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-20'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-20'>#</a>
|
||||
</div>
|
||||
<p>Sum of them</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">209</span> <span class="n">loss</span> <span class="o">=</span> <span class="p">(</span><span class="n">err</span> <span class="o">+</span> <span class="n">var</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-21'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-21'>#</a>
|
||||
</div>
|
||||
<p>Mean loss over the batch</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">212</span> <span class="k">return</span> <span class="n">loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-22'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-22'>#</a>
|
||||
</div>
|
||||
<p><a id="KLDivergenceLoss"></a></p>
|
||||
<h2>KL Divergence Regularization Loss</h2>
|
||||
<p>This tries to shrink the total evidence to zero if the sample cannot be correctly classified.</p>
|
||||
<p>First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \color{cyan}{\alpha_k}$ the
|
||||
Dirichlet parameters after removing the correct evidence.</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
&KL \Big[ D(\mathbf{p} \vert \mathbf{\tilde{\alpha}}) \Big \Vert
|
||||
D(\mathbf{p} \vert <1, \dots, 1>) \Big] \\
|
||||
&= \log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}
|
||||
{\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg)
|
||||
+ \sum_{k=1}^K (\tilde{\alpha}_k - 1)
|
||||
\Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]
|
||||
\end{align}</script>
|
||||
</p>
|
||||
<p>where $\Gamma(\cdot)$ is the gamma function,
|
||||
$\psi(\cdot)$ is the $digamma$ function and
|
||||
$\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">215</span><span class="k">class</span> <span class="nc">KLDivergenceLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-23'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-23'>#</a>
|
||||
</div>
|
||||
<ul>
|
||||
<li><code>evidence</code> is $\mathbf{e} \ge 0$ with shape <code>[batch_size, n_classes]</code></li>
|
||||
<li><code>target</code> is $\mathbf{y}$ with shape <code>[batch_size, n_classes]</code></li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">238</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">evidence</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-24'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-24'>#</a>
|
||||
</div>
|
||||
<p>$\color{cyan}{\alpha_k} = e_k + 1$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">244</span> <span class="n">alpha</span> <span class="o">=</span> <span class="n">evidence</span> <span class="o">+</span> <span class="mf">1.</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-25'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-25'>#</a>
|
||||
</div>
|
||||
<p>Number of classes</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">246</span> <span class="n">n_classes</span> <span class="o">=</span> <span class="n">evidence</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-26'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-26'>#</a>
|
||||
</div>
|
||||
<p>Remove non-misleading evidence
|
||||
<script type="math/tex; mode=display">\tilde{\alpha}_k = y_k + (1 - y_k) \color{cyan}{\alpha_k}</script>
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">249</span> <span class="n">alpha_tilde</span> <span class="o">=</span> <span class="n">target</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">target</span><span class="p">)</span> <span class="o">*</span> <span class="n">alpha</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-27'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-27'>#</a>
|
||||
</div>
|
||||
<p>$\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">251</span> <span class="n">strength_tilde</span> <span class="o">=</span> <span class="n">alpha_tilde</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-28'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-28'>#</a>
|
||||
</div>
|
||||
<p>The first term
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
&\log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}
|
||||
{\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg) \\
|
||||
&= \log \Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)
|
||||
- \log \Gamma(K)
|
||||
- \sum_{k=1}^K \log \Gamma(\tilde{\alpha}_k)
|
||||
\end{align}</script>
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">261</span> <span class="n">first</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">lgamma</span><span class="p">(</span><span class="n">alpha_tilde</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span>
|
||||
<span class="lineno">262</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">lgamma</span><span class="p">(</span><span class="n">alpha_tilde</span><span class="o">.</span><span class="n">new_tensor</span><span class="p">(</span><span class="nb">float</span><span class="p">(</span><span class="n">n_classes</span><span class="p">)))</span>
|
||||
<span class="lineno">263</span> <span class="o">-</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">lgamma</span><span class="p">(</span><span class="n">alpha_tilde</span><span class="p">))</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-29'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-29'>#</a>
|
||||
</div>
|
||||
<p>The second term
|
||||
<script type="math/tex; mode=display">\sum_{k=1}^K (\tilde{\alpha}_k - 1)
|
||||
\Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]</script>
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">268</span> <span class="n">second</span> <span class="o">=</span> <span class="p">(</span>
|
||||
<span class="lineno">269</span> <span class="p">(</span><span class="n">alpha_tilde</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span>
|
||||
<span class="lineno">270</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">digamma</span><span class="p">(</span><span class="n">alpha_tilde</span><span class="p">)</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">digamma</span><span class="p">(</span><span class="n">strength_tilde</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">])</span>
|
||||
<span class="lineno">271</span> <span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-30'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-30'>#</a>
|
||||
</div>
|
||||
<p>Sum of the terms</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">274</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">first</span> <span class="o">+</span> <span class="n">second</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-31'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-31'>#</a>
|
||||
</div>
|
||||
<p>Mean loss over the batch</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">277</span> <span class="k">return</span> <span class="n">loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-32'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-32'>#</a>
|
||||
</div>
|
||||
<p><a id="TrackStatistics"></a></p>
|
||||
<h3>Track statistics</h3>
|
||||
<p>This module computes statistics and tracks them with <a href="https://docs.labml.ai/api/tracker.html">labml <code>tracker</code></a>.</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">280</span><span class="k">class</span> <span class="nc">TrackStatistics</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-33'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-33'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">287</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">evidence</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-34'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-34'>#</a>
|
||||
</div>
|
||||
<p>Number of classes</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">289</span> <span class="n">n_classes</span> <span class="o">=</span> <span class="n">evidence</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-35'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-35'>#</a>
|
||||
</div>
|
||||
<p>Predictions that correctly match with the target (greedy sampling based on highest probability)</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">291</span> <span class="n">match</span> <span class="o">=</span> <span class="n">evidence</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">target</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-36'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-36'>#</a>
|
||||
</div>
|
||||
<p>Track accuracy</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">293</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">'accuracy.'</span><span class="p">,</span> <span class="n">match</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">/</span> <span class="n">match</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-37'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-37'>#</a>
|
||||
</div>
|
||||
<p>$\color{cyan}{\alpha_k} = e_k + 1$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">296</span> <span class="n">alpha</span> <span class="o">=</span> <span class="n">evidence</span> <span class="o">+</span> <span class="mf">1.</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-38'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-38'>#</a>
|
||||
</div>
|
||||
<p>$S = \sum_{k=1}^K \color{cyan}{\alpha_k}$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">298</span> <span class="n">strength</span> <span class="o">=</span> <span class="n">alpha</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-39'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-39'>#</a>
|
||||
</div>
|
||||
<p>$\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">301</span> <span class="n">expected_probability</span> <span class="o">=</span> <span class="n">alpha</span> <span class="o">/</span> <span class="n">strength</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-40'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-40'>#</a>
|
||||
</div>
|
||||
<p>Expected probability of the selected (greedy highest probability) class</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">303</span> <span class="n">expected_probability</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">expected_probability</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-41'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-41'>#</a>
|
||||
</div>
|
||||
<p>Uncertainty mass $u = \frac{K}{S}$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">306</span> <span class="n">uncertainty_mass</span> <span class="o">=</span> <span class="n">n_classes</span> <span class="o">/</span> <span class="n">strength</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-42'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-42'>#</a>
|
||||
</div>
|
||||
<p>Track $u$ for correct predictions</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">309</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">'u.succ.'</span><span class="p">,</span> <span class="n">uncertainty_mass</span><span class="o">.</span><span class="n">masked_select</span><span class="p">(</span><span class="n">match</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-43'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-43'>#</a>
|
||||
</div>
|
||||
<p>Track $u$ for incorrect predictions</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">311</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">'u.fail.'</span><span class="p">,</span> <span class="n">uncertainty_mass</span><span class="o">.</span><span class="n">masked_select</span><span class="p">(</span><span class="o">~</span><span class="n">match</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-44'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-44'>#</a>
|
||||
</div>
|
||||
<p>Track $\hat{p}_k$ for correct predictions</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">313</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">'prob.succ.'</span><span class="p">,</span> <span class="n">expected_probability</span><span class="o">.</span><span class="n">masked_select</span><span class="p">(</span><span class="n">match</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-45'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-45'>#</a>
|
||||
</div>
|
||||
<p>Track $\hat{p}_k$ for incorrect predictions</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">315</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">'prob.fail.'</span><span class="p">,</span> <span class="n">expected_probability</span><span class="o">.</span><span class="n">masked_select</span><span class="p">(</span><span class="o">~</span><span class="n">match</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
</script>
|
||||
<script>
|
||||
function handleImages() {
|
||||
var images = document.querySelectorAll('p>img')
|
||||
|
||||
console.log(images);
|
||||
for (var i = 0; i < images.length; ++i) {
|
||||
handleImage(images[i])
|
||||
}
|
||||
}
|
||||
|
||||
function handleImage(img) {
|
||||
img.parentElement.style.textAlign = 'center'
|
||||
|
||||
var modal = document.createElement('div')
|
||||
modal.id = 'modal'
|
||||
|
||||
var modalContent = document.createElement('div')
|
||||
modal.appendChild(modalContent)
|
||||
|
||||
var modalImage = document.createElement('img')
|
||||
modalContent.appendChild(modalImage)
|
||||
|
||||
var span = document.createElement('span')
|
||||
span.classList.add('close')
|
||||
span.textContent = 'x'
|
||||
modal.appendChild(span)
|
||||
|
||||
img.onclick = function () {
|
||||
console.log('clicked')
|
||||
document.body.appendChild(modal)
|
||||
modalImage.src = img.src
|
||||
}
|
||||
|
||||
span.onclick = function () {
|
||||
document.body.removeChild(modal)
|
||||
}
|
||||
}
|
||||
|
||||
handleImages()
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
144
docs/uncertainty/evidence/readme.html
Normal file
144
docs/uncertainty/evidence/readme.html
Normal file
@ -0,0 +1,144 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content=""/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="Evidential Deep Learning to Quantify Classification Uncertainty"/>
|
||||
<meta name="twitter:description" content=""/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/uncertainty/evidence/readme.html"/>
|
||||
<meta property="og:title" content="Evidential Deep Learning to Quantify Classification Uncertainty"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="Evidential Deep Learning to Quantify Classification Uncertainty"/>
|
||||
<meta property="og:description" content=""/>
|
||||
|
||||
<title>Evidential Deep Learning to Quantify Classification Uncertainty</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/uncertainty/evidence/readme.html"/>
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../index.html">uncertainty</a>
|
||||
<a class="parent" href="index.html">evidence</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/uncertainty/evidence/readme.md">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1><a href="https://nn.labml.ai/uncertainty/evidence/index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a></h1>
|
||||
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the paper
|
||||
<a href="https://papers.labml.ai/paper/1806.01768">Evidential Deep Learning to Quantify Classification Uncertainty</a>.</p>
|
||||
<p>Here is the <a href="https://nn.labml.ai/uncertainty/evidence/experiment.html">training code <code>experiment.py</code></a> to train a model on MNIST dataset.</p>
|
||||
<p><a href="https://app.labml.ai/run/f82b2bfc01ba11ecbb2aa16a33570106"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
</script>
|
||||
<script>
|
||||
function handleImages() {
|
||||
var images = document.querySelectorAll('p>img')
|
||||
|
||||
console.log(images);
|
||||
for (var i = 0; i < images.length; ++i) {
|
||||
handleImage(images[i])
|
||||
}
|
||||
}
|
||||
|
||||
function handleImage(img) {
|
||||
img.parentElement.style.textAlign = 'center'
|
||||
|
||||
var modal = document.createElement('div')
|
||||
modal.id = 'modal'
|
||||
|
||||
var modalContent = document.createElement('div')
|
||||
modal.appendChild(modalContent)
|
||||
|
||||
var modalImage = document.createElement('img')
|
||||
modalContent.appendChild(modalImage)
|
||||
|
||||
var span = document.createElement('span')
|
||||
span.classList.add('close')
|
||||
span.textContent = 'x'
|
||||
modal.appendChild(span)
|
||||
|
||||
img.onclick = function () {
|
||||
console.log('clicked')
|
||||
document.body.appendChild(modal)
|
||||
modalImage.src = img.src
|
||||
}
|
||||
|
||||
span.onclick = function () {
|
||||
document.body.removeChild(modal)
|
||||
}
|
||||
}
|
||||
|
||||
handleImages()
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
143
docs/uncertainty/index.html
Normal file
143
docs/uncertainty/index.html
Normal file
@ -0,0 +1,143 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content="A set of PyTorch implementations/tutorials related to uncertainty estimation"/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="Neural Networks with Uncertainty Estimation"/>
|
||||
<meta name="twitter:description" content="A set of PyTorch implementations/tutorials related to uncertainty estimation"/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/uncertainty/index.html"/>
|
||||
<meta property="og:title" content="Neural Networks with Uncertainty Estimation"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="Neural Networks with Uncertainty Estimation"/>
|
||||
<meta property="og:description" content="A set of PyTorch implementations/tutorials related to uncertainty estimation"/>
|
||||
|
||||
<title>Neural Networks with Uncertainty Estimation</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/uncertainty/index.html"/>
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="index.html">uncertainty</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/uncertainty/__init__.py">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>Neural Networks with Uncertainty Estimation</h1>
|
||||
<p>These are neural network architectures that estimate the uncertainty of the predictions.</p>
|
||||
<ul>
|
||||
<li><a href="evidence/index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a></li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
</script>
|
||||
<script>
|
||||
function handleImages() {
|
||||
var images = document.querySelectorAll('p>img')
|
||||
|
||||
console.log(images);
|
||||
for (var i = 0; i < images.length; ++i) {
|
||||
handleImage(images[i])
|
||||
}
|
||||
}
|
||||
|
||||
function handleImage(img) {
|
||||
img.parentElement.style.textAlign = 'center'
|
||||
|
||||
var modal = document.createElement('div')
|
||||
modal.id = 'modal'
|
||||
|
||||
var modalContent = document.createElement('div')
|
||||
modal.appendChild(modalContent)
|
||||
|
||||
var modalImage = document.createElement('img')
|
||||
modalContent.appendChild(modalImage)
|
||||
|
||||
var span = document.createElement('span')
|
||||
span.classList.add('close')
|
||||
span.textContent = 'x'
|
||||
modal.appendChild(span)
|
||||
|
||||
img.onclick = function () {
|
||||
console.log('clicked')
|
||||
document.body.appendChild(modal)
|
||||
modalImage.src = img.src
|
||||
}
|
||||
|
||||
span.onclick = function () {
|
||||
document.body.removeChild(modal)
|
||||
}
|
||||
}
|
||||
|
||||
handleImages()
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
143
docs/uncertainty/readme.html
Normal file
143
docs/uncertainty/readme.html
Normal file
@ -0,0 +1,143 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content=""/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="Neural Networks with Uncertainty Estimation"/>
|
||||
<meta name="twitter:description" content=""/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/uncertainty/readme.html"/>
|
||||
<meta property="og:title" content="Neural Networks with Uncertainty Estimation"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="Neural Networks with Uncertainty Estimation"/>
|
||||
<meta property="og:description" content=""/>
|
||||
|
||||
<title>Neural Networks with Uncertainty Estimation</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/uncertainty/readme.html"/>
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="index.html">uncertainty</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/uncertainty/readme.md">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1><a href="https://nn.labml.ai/uncertainty/index.html">Neural Networks with Uncertainty Estimation</a></h1>
|
||||
<p>These are neural network architectures that estimate the uncertainty of the predictions.</p>
|
||||
<ul>
|
||||
<li><a href="https://nn.labml.ai/uncertainty/evidence/index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a></li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class='code'>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
</script>
|
||||
<script>
|
||||
function handleImages() {
|
||||
var images = document.querySelectorAll('p>img')
|
||||
|
||||
console.log(images);
|
||||
for (var i = 0; i < images.length; ++i) {
|
||||
handleImage(images[i])
|
||||
}
|
||||
}
|
||||
|
||||
function handleImage(img) {
|
||||
img.parentElement.style.textAlign = 'center'
|
||||
|
||||
var modal = document.createElement('div')
|
||||
modal.id = 'modal'
|
||||
|
||||
var modalContent = document.createElement('div')
|
||||
modal.appendChild(modalContent)
|
||||
|
||||
var modalImage = document.createElement('img')
|
||||
modalContent.appendChild(modalImage)
|
||||
|
||||
var span = document.createElement('span')
|
||||
span.classList.add('close')
|
||||
span.textContent = 'x'
|
||||
modal.appendChild(span)
|
||||
|
||||
img.onclick = function () {
|
||||
console.log('clicked')
|
||||
document.body.appendChild(modal)
|
||||
modalImage.src = img.src
|
||||
}
|
||||
|
||||
span.onclick = function () {
|
||||
document.body.removeChild(modal)
|
||||
}
|
||||
}
|
||||
|
||||
handleImages()
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
@ -94,6 +94,10 @@ Solving games with incomplete information such as poker with CFR.
|
||||
|
||||
* [PonderNet](adaptive_computation/ponder_net/index.html)
|
||||
|
||||
#### ✨ [Uncertainty](uncertainty/index.html)
|
||||
|
||||
* [Evidential Deep Learning to Quantify Classification Uncertainty](uncertainty/evidence/index.html)
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
@ -102,12 +106,12 @@ pip install labml-nn
|
||||
|
||||
### Citing LabML
|
||||
|
||||
If you use LabML for academic research, please cite the library using the following BibTeX entry.
|
||||
If you use this for academic research, please cite it using the following BibTeX entry.
|
||||
|
||||
```bibtex
|
||||
@misc{labml,
|
||||
author = {Varuna Jayasiri, Nipun Wijerathne},
|
||||
title = {LabML: A library to organize machine learning experiments},
|
||||
title = {labml.ai Annotated Paper Implementations},
|
||||
year = {2020},
|
||||
url = {https://nn.labml.ai/},
|
||||
}
|
||||
|
@ -72,7 +72,10 @@ def main():
|
||||
# Create configurations
|
||||
conf = MNISTConfigs()
|
||||
# Load configurations
|
||||
experiment.configs(conf, {'optimizer.optimizer': 'Adam'})
|
||||
experiment.configs(conf, {
|
||||
'optimizer.optimizer': 'Adam',
|
||||
'optimizer.learning_rate': 0.001,
|
||||
})
|
||||
# Start the experiment and run the training loop
|
||||
with experiment.start():
|
||||
conf.run()
|
||||
|
13
labml_nn/uncertainty/__init__.py
Normal file
13
labml_nn/uncertainty/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
"""
|
||||
---
|
||||
title: Neural Networks with Uncertainty Estimation
|
||||
summary: >
|
||||
A set of PyTorch implementations/tutorials related to uncertainty estimation
|
||||
---
|
||||
|
||||
# Neural Networks with Uncertainty Estimation
|
||||
|
||||
These are neural network architectures that estimate the uncertainty of the predictions.
|
||||
|
||||
* [Evidential Deep Learning to Quantify Classification Uncertainty](evidence/index.html)
|
||||
"""
|
315
labml_nn/uncertainty/evidence/__init__.py
Normal file
315
labml_nn/uncertainty/evidence/__init__.py
Normal file
@ -0,0 +1,315 @@
|
||||
"""
|
||||
---
|
||||
title: "Evidential Deep Learning to Quantify Classification Uncertainty"
|
||||
summary: >
|
||||
A PyTorch implementation/tutorial of the paper Evidential Deep Learning to Quantify Classification
|
||||
Uncertainty.
|
||||
---
|
||||
|
||||
# Evidential Deep Learning to Quantify Classification Uncertainty
|
||||
|
||||
This is a [PyTorch](https://pytorch.org) implementation of the paper
|
||||
[Evidential Deep Learning to Quantify Classification Uncertainty](https://papers.labml.ai/paper/1806.01768).
|
||||
|
||||
[Dempster-Shafer Theory of Evidence](https://en.wikipedia.org/wiki/Dempster%E2%80%93Shafer_theory)
|
||||
assigns belief masses to a set of classes (unlike assigning a probability to a single class).
|
||||
Sum of the masses of all subsets is $1$.
|
||||
Individual class probabilities (plausibilities) can be derived from these masses.
|
||||
|
||||
Assigning a mass to the set of all classes means it can be any one of the classes; i.e. saying "I don't know".
|
||||
|
||||
If there are $K$ classes, we assign masses $b_k \ge 0$ to each of the classes and
|
||||
an overall uncertainty mass $u \ge 0$ to all classes.
|
||||
|
||||
$$u + \sum_{k=1}^K b_k = 1$$
|
||||
|
||||
Belief masses $b_k$ and $u$ can be computed from evidence $e_k \ge 0$, as $b_k = \frac{e_k}{S}$
|
||||
and $u = \frac{K}{S}$ where $S = \sum_{k=1}^K (e_k + 1)$.
|
||||
The paper uses the term evidence as a measure of the amount of support
|
||||
collected from data in favor of a sample to be classified into a certain class.
|
||||
|
||||
This corresponds to a [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution)
|
||||
with parameters $\color{cyan}{\alpha_k} = e_k + 1$, and
|
||||
$\color{cyan}{\alpha_0} = S = \sum_{k=1}^K \color{cyan}{\alpha_k}$ is known as the Dirichlet strength.
|
||||
Dirichlet distribution $D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}})$
|
||||
is a distribution over categorical distributions; i.e. you can sample class probabilities
|
||||
from a Dirichlet distribution.
|
||||
The expected probability for class $k$ is $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$.
|
||||
|
||||
We get the model to output evidences
|
||||
$$\mathbf{e} = \color{cyan}{\mathbf{\alpha}} - 1 = f(\mathbf{x} | \Theta)$$
|
||||
for a given input $\mathbf{x}$.
|
||||
We use a function such as
|
||||
[ReLU](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html) or a
|
||||
[Softplus](https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html)
|
||||
at the final layer to get $f(\mathbf{x} | \Theta) \ge 0$.
|
||||
|
||||
The paper proposes a few loss functions to train the model, which we have implemented below.
|
||||
|
||||
Here is the [training code `experiment.py`](experiment.html) to train a model on MNIST dataset.
|
||||
|
||||
[](https://app.labml.ai/run/f82b2bfc01ba11ecbb2aa16a33570106)
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from labml import tracker
|
||||
from labml_helpers.module import Module
|
||||
|
||||
|
||||
class MaximumLikelihoodLoss(Module):
|
||||
"""
|
||||
<a id="MaximumLikelihoodLoss"></a>
|
||||
## Type II Maximum Likelihood Loss
|
||||
|
||||
    The distribution $D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}})$ is a prior on the likelihood
|
||||
$Multi(\mathbf{y} \vert p)$,
|
||||
and the negative log marginal likelihood is calculated by integrating over class probabilities
|
||||
$\mathbf{p}$.
|
||||
|
||||
If target probabilities (one-hot targets) are $y_k$ for a given sample the loss is,
|
||||
|
||||
\begin{align}
|
||||
\mathcal{L}(\Theta)
|
||||
&= -\log \Bigg(
|
||||
\int
|
||||
\prod_{k=1}^K p_k^{y_k}
|
||||
\frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
|
||||
\prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
|
||||
d\mathbf{p}
|
||||
\Bigg ) \\
|
||||
&= \sum_{k=1}^K y_k \bigg( \log S - \log \color{cyan}{\alpha_k} \bigg)
|
||||
\end{align}
|
||||
"""
|
||||
def forward(self, evidence: torch.Tensor, target: torch.Tensor):
|
||||
"""
|
||||
* `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
|
||||
* `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
|
||||
"""
|
||||
# $\color{cyan}{\alpha_k} = e_k + 1$
|
||||
alpha = evidence + 1.
|
||||
# $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
|
||||
strength = alpha.sum(dim=-1)
|
||||
|
||||
# Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \log S - \log \color{cyan}{\alpha_k} \bigg)$
|
||||
loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)
|
||||
|
||||
# Mean loss over the batch
|
||||
return loss.mean()
|
||||
|
||||
|
||||
class CrossEntropyBayesRisk(Module):
    """
    <a id="CrossEntropyBayesRisk"></a>

    ## Bayes Risk with Cross Entropy Loss

    Bayes risk is the overall maximum cost of making incorrect estimates.
    It takes a cost function that gives the cost of making an incorrect estimate
    and sums it over all possible outcomes based on probability distribution.

    Here the cost function is cross-entropy loss, for one-hot coded $\mathbf{y}$
    $$\sum_{k=1}^K -y_k \log p_k$$

    We integrate this cost over all $\mathbf{p}$

    \begin{align}
    \mathcal{L}(\Theta)
    &= -\log \Bigg(
    \int
    \Big[ \sum_{k=1}^K -y_k \log p_k \Big]
    \frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
    \prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
    d\mathbf{p}
    \Bigg ) \\
    &= \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{cyan}{\alpha_k} ) \bigg)
    \end{align}

    where $\psi(\cdot)$ is the $digamma$ function.
    """

    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
        """
        * `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
        * `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
        """
        # Dirichlet parameters $\color{cyan}{\alpha_k} = e_k + 1$
        alpha = evidence + 1.
        # Dirichlet strength $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$,
        # kept as a `[batch_size, 1]` column so it broadcasts over classes
        strength = alpha.sum(dim=-1, keepdim=True)

        # Per-sample loss $\sum_{k=1}^K y_k \big( \psi(S) - \psi(\color{cyan}{\alpha_k}) \big)$
        sample_loss = (target * (torch.digamma(strength) - torch.digamma(alpha))).sum(dim=-1)

        # Average over the batch
        return sample_loss.mean()
|
||||
|
||||
|
||||
class SquaredErrorBayesRisk(Module):
    """
    <a id="SquaredErrorBayesRisk"></a>

    ## Bayes Risk with Squared Error Loss

    Here the cost function is squared error,
    $$\sum_{k=1}^K (y_k - p_k)^2 = \Vert \mathbf{y} - \mathbf{p} \Vert_2^2$$

    We integrate this cost over all $\mathbf{p}$

    \begin{align}
    \mathcal{L}(\Theta)
    &= -\log \Bigg(
    \int
    \Big[ \sum_{k=1}^K (y_k - p_k)^2 \Big]
    \frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
    \prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
    d\mathbf{p}
    \Bigg ) \\
    &= \sum_{k=1}^K \mathbb{E} \Big[ y_k^2 -2 y_k p_k + p_k^2 \Big] \\
    &= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big)
    \end{align}

    Where $$\mathbb{E}[p_k] = \hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$$
    is the expected probability when sampled from the Dirichlet distribution
    and $$\mathbb{E}[p_k^2] = \mathbb{E}[p_k]^2 + \text{Var}(p_k)$$
    where
    $$\text{Var}(p_k) = \frac{\color{cyan}{\alpha_k}(S - \color{cyan}{\alpha_k})}{S^2 (S + 1)}
    = \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}$$
    is the variance.

    This gives,
    \begin{align}
    \mathcal{L}(\Theta)
    &= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big) \\
    &= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k]^2 + \text{Var}(p_k) \Big) \\
    &= \sum_{k=1}^K \Big( \big( y_k -\mathbb{E}[p_k] \big)^2 + \text{Var}(p_k) \Big) \\
    &= \sum_{k=1}^K \Big( ( y_k -\hat{p}_k)^2 + \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1} \Big)
    \end{align}

    The first part of the equation $\big(y_k -\mathbb{E}[p_k]\big)^2$ is the error term and
    the second part is the variance.
    """

    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
        """
        * `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
        * `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
        """
        # Dirichlet parameters $\color{cyan}{\alpha_k} = e_k + 1$
        alpha = evidence + 1.
        # Dirichlet strength $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$,
        # kept as a `[batch_size, 1]` column so it broadcasts over classes
        strength = alpha.sum(dim=-1, keepdim=True)
        # Expected probabilities $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$
        p_hat = alpha / strength

        # Error term $(y_k - \hat{p}_k)^2$
        squared_error = (target - p_hat) ** 2
        # Variance term $\text{Var}(p_k) = \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}$
        variance = p_hat * (1 - p_hat) / (strength + 1)

        # Per-sample loss: sum of error and variance over classes
        sample_loss = (squared_error + variance).sum(dim=-1)

        # Average over the batch
        return sample_loss.mean()
|
||||
|
||||
|
||||
class KLDivergenceLoss(Module):
    """
    <a id="KLDivergenceLoss"></a>

    ## KL Divergence Regularization Loss

    This tries to shrink the total evidence to zero if the sample cannot be correctly classified.

    First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \color{cyan}{\alpha_k}$, the
    Dirichlet parameters after removing the evidence for the correct class.

    \begin{align}
    &KL \Big[ D(\mathbf{p} \vert \mathbf{\tilde{\alpha}}) \Big \Vert
    D(\mathbf{p} \vert <1, \dots, 1>\Big] \\
    &= \log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}
    {\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg)
    + \sum_{k=1}^K (\tilde{\alpha}_k - 1)
    \Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]
    \end{align}

    where $\Gamma(\cdot)$ is the gamma function,
    $\psi(\cdot)$ is the $digamma$ function and
    $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$
    """

    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
        """
        * `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
        * `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
        """
        # Dirichlet parameters $\color{cyan}{\alpha_k} = e_k + 1$
        alpha = evidence + 1.
        # Number of classes $K$
        n_classes = evidence.shape[-1]
        # Remove non-misleading evidence:
        # $$\tilde{\alpha}_k = y_k + (1 - y_k) \color{cyan}{\alpha_k}$$
        alpha_tilde = target + (1 - target) * alpha
        # $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$
        strength_tilde = alpha_tilde.sum(dim=-1)

        # The first term, expanded via log-gamma:
        # \begin{align}
        # &\log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}
        #     {\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg) \\
        # &= \log \Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)
        #   - \log \Gamma(K)
        #   - \sum_{k=1}^K \log \Gamma(\tilde{\alpha}_k)
        # \end{align}
        first = (torch.lgamma(strength_tilde)
                 - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
                 - torch.lgamma(alpha_tilde).sum(dim=-1))

        # The second term:
        # $$\sum_{k=1}^K (\tilde{\alpha}_k - 1)
        #   \Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]$$
        digamma_diff = torch.digamma(alpha_tilde) - torch.digamma(strength_tilde).unsqueeze(-1)
        second = ((alpha_tilde - 1) * digamma_diff).sum(dim=-1)

        # Average the per-sample KL divergence over the batch
        return (first + second).mean()
|
||||
|
||||
|
||||
class TrackStatistics(Module):
    """
    <a id="TrackStatistics"></a>

    ### Track statistics

    This module computes statistics and tracks them with [labml `tracker`](https://docs.labml.ai/api/tracker.html).
    """

    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
        # Number of classes $K$
        n_classes = evidence.shape[-1]
        # Whether the greedy prediction (class with the highest evidence)
        # matches the target class
        match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))
        # Track accuracy
        tracker.add('accuracy.', match.sum() / match.shape[0])

        # Dirichlet parameters $\color{cyan}{\alpha_k} = e_k + 1$
        alpha = evidence + 1.
        # Dirichlet strength $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
        strength = alpha.sum(dim=-1)

        # Expected probabilities $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$,
        # reduced to the expected probability of the greedily selected class
        expected_probability = (alpha / strength[:, None]).max(dim=-1).values

        # Uncertainty mass $u = \frac{K}{S}$
        uncertainty_mass = n_classes / strength

        # Track $u$ for correct predictions
        tracker.add('u.succ.', uncertainty_mass.masked_select(match))
        # Track $u$ for incorrect predictions
        tracker.add('u.fail.', uncertainty_mass.masked_select(~match))
        # Track $\hat{p}_k$ for correct predictions
        tracker.add('prob.succ.', expected_probability.masked_select(match))
        # Track $\hat{p}_k$ for incorrect predictions
        tracker.add('prob.fail.', expected_probability.masked_select(~match))
|
225
labml_nn/uncertainty/evidence/experiment.py
Normal file
225
labml_nn/uncertainty/evidence/experiment.py
Normal file
@ -0,0 +1,225 @@
|
||||
"""
|
||||
---
|
||||
title: "Evidential Deep Learning to Quantify Classification Uncertainty Experiment"
|
||||
summary: >
|
||||
This trains an EDL model on MNIST
|
||||
---
|
||||
|
||||
# [Evidential Deep Learning to Quantify Classification Uncertainty](index.html) Experiment
|
||||
|
||||
This trains a model based on [Evidential Deep Learning to Quantify Classification Uncertainty](index.html)
|
||||
on MNIST dataset.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.utils.data
|
||||
|
||||
from labml import tracker, experiment
|
||||
from labml.configs import option, calculate
|
||||
from labml_helpers.module import Module
|
||||
from labml_helpers.schedule import Schedule, RelativePiecewise
|
||||
from labml_helpers.train_valid import BatchIndex
|
||||
from labml_nn.experiments.mnist import MNISTConfigs
|
||||
from labml_nn.uncertainty.evidence import KLDivergenceLoss, TrackStatistics, MaximumLikelihoodLoss, \
|
||||
CrossEntropyBayesRisk, SquaredErrorBayesRisk
|
||||
|
||||
|
||||
class Model(Module):
    """
    ## LeNet-based model for MNIST classification
    """

    def __init__(self, dropout: float):
        """
        * `dropout` is the dropout probability applied to the hidden layer
        """
        super().__init__()
        # First $5x5$ convolution layer
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
        # ReLU activation
        self.act1 = nn.ReLU()
        # $2x2$ max-pooling
        self.max_pool1 = nn.MaxPool2d(2, 2)
        # Second $5x5$ convolution layer
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
        # ReLU activation
        self.act2 = nn.ReLU()
        # $2x2$ max-pooling
        self.max_pool2 = nn.MaxPool2d(2, 2)
        # First fully-connected layer that maps to $500$ features
        self.fc1 = nn.Linear(50 * 4 * 4, 500)
        # ReLU activation
        self.act3 = nn.ReLU()
        # Final fully connected layer to output evidence for $10$ classes.
        # The ReLU or Softplus activation is applied to this outside the model to get the
        # non-negative evidence
        self.fc2 = nn.Linear(500, 10)
        # Dropout for the hidden layer
        self.dropout = nn.Dropout(p=dropout)

    def __call__(self, x: torch.Tensor):
        """
        * `x` is the batch of MNIST images of shape `[batch_size, 1, 28, 28]`
        """
        # First convolution stage: conv -> ReLU -> max-pool,
        # producing shape `[batch_size, 20, 12, 12]`
        h = self.max_pool1(self.act1(self.conv1(x)))
        # Second convolution stage: conv -> ReLU -> max-pool,
        # producing shape `[batch_size, 50, 4, 4]`
        h = self.max_pool2(self.act2(self.conv2(h)))
        # Flatten to shape `[batch_size, 50 * 4 * 4]`
        h = h.view(h.shape[0], -1)
        # Hidden layer followed by dropout
        h = self.dropout(self.act3(self.fc1(h)))
        # Final layer produces the (pre-activation) evidence logits
        return self.fc2(h)
|
||||
|
||||
|
||||
class Configs(MNISTConfigs):
    """
    ## Configurations

    We use [`MNISTConfigs`](../../experiments/mnist.html#MNISTConfigs) configurations.
    """

    # [KL Divergence regularization](index.html#KLDivergenceLoss) applied on top of the main loss
    kl_div_loss = KLDivergenceLoss()
    # KL Divergence regularization coefficient schedule (the annealing coefficient $\lambda_t$)
    kl_div_coef: Schedule
    # Piecewise schedule points `(relative_training_progress, coefficient)` used to
    # build `kl_div_coef`
    kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]
    # [Stats module](index.html#TrackStatistics) for tracking
    stats = TrackStatistics()
    # Dropout probability for the model's hidden layer
    dropout: float = 0.5
    # Module to convert the model output to non-zero (non-negative) evidences,
    # e.g. `nn.ReLU` or `nn.Softplus`
    outputs_to_evidence: Module

    def init(self):
        """
        ### Initialization

        Configure which metrics the tracker records and how (scalar vs histogram).
        """
        # Set tracker configurations
        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)
        tracker.set_histogram('u.*', True)
        tracker.set_histogram('prob.*', False)
        tracker.set_scalar('annealing_coef.*', False)
        tracker.set_scalar('kl_div_loss.*', False)

        # No stateful modules to reset between epochs
        self.state_modules = []

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        ### Training or validation step

        Runs one batch through the model, computes the evidential loss plus the
        annealed KL-divergence regularization, and (in training mode) takes an
        optimizer step.
        """

        # Training/Evaluation mode
        self.model.train(self.mode.is_train)

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # One-hot coded targets (MNIST has 10 classes)
        eye = torch.eye(10).to(torch.float).to(self.device)
        target = eye[target]

        # Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Get model outputs
        outputs = self.model(data)
        # Get evidences $e_k \ge 0$
        evidence = self.outputs_to_evidence(outputs)

        # Calculate loss
        # NOTE(review): `loss_func` is selected by name via the `calculate(...)`
        # registrations in this file — presumably declared on `MNISTConfigs`; confirm.
        loss = self.loss_func(evidence, target)
        # Calculate KL Divergence regularization loss
        kl_div_loss = self.kl_div_loss(evidence, target)
        tracker.add("loss.", loss)
        tracker.add("kl_div_loss.", kl_div_loss)

        # KL Divergence loss coefficient $\lambda_t$ from the schedule,
        # clamped at 1 so the regularization never outweighs its final scale
        annealing_coef = min(1., self.kl_div_coef(tracker.get_global_step()))
        tracker.add("annealing_coef.", annealing_coef)

        # Total loss
        loss = loss + annealing_coef * kl_div_loss

        # Track statistics
        self.stats(evidence, target)

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Take optimizer step
            self.optimizer.step()
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()
|
||||
|
||||
|
||||
@option(Configs.model)
def mnist_model(c: Configs):
    """
    ### Create model

    Builds the LeNet-style [`Model`](#Model) with the configured dropout and
    moves it to the configured device.
    """
    model = Model(c.dropout)
    return model.to(c.device)
|
||||
|
||||
|
||||
@option(Configs.kl_div_coef)
def kl_div_coef(c: Configs):
    """
    ### KL Divergence Loss Coefficient Schedule

    Builds a [relative piecewise schedule](https://docs.labml.ai/api/helpers.html#labml_helpers.schedule.Piecewise)
    over the total number of training samples.
    """
    # Total number of samples processed over the whole run
    total_samples = c.epochs * len(c.train_dataset)
    return RelativePiecewise(c.kl_div_coef_schedule, total_samples)
|
||||
|
||||
|
||||
# Register the loss functions so they can be selected by name in `experiment.configs`.
# [Maximum Likelihood Loss](index.html#MaximumLikelihoodLoss)
calculate(Configs.loss_func, 'max_likelihood_loss', lambda: MaximumLikelihoodLoss())
# [Cross Entropy Bayes Risk](index.html#CrossEntropyBayesRisk)
calculate(Configs.loss_func, 'cross_entropy_bayes_risk', lambda: CrossEntropyBayesRisk())
# [Squared Error Bayes Risk](index.html#SquaredErrorBayesRisk)
calculate(Configs.loss_func, 'squared_error_bayes_risk', lambda: SquaredErrorBayesRisk())

# Register the evidence activations (applied to model outputs to get $e_k \ge 0$).
# ReLU to calculate evidence
calculate(Configs.outputs_to_evidence, 'relu', lambda: nn.ReLU())
# Softplus to calculate evidence
calculate(Configs.outputs_to_evidence, 'softplus', lambda: nn.Softplus())
|
||||
|
||||
|
||||
def main():
    """
    Create and run the MNIST evidential deep learning experiment.
    """
    # Create experiment
    experiment.create(name='evidence_mnist')
    # Create configurations
    conf = Configs()
    # Configuration overrides; the commented `loss_func` entries are the
    # alternative losses implemented in this package
    overrides = {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 0.001,
        'optimizer.weight_decay': 0.005,

        # 'loss_func': 'max_likelihood_loss',
        # 'loss_func': 'cross_entropy_bayes_risk',
        'loss_func': 'squared_error_bayes_risk',

        'outputs_to_evidence': 'softplus',

        'dropout': 0.5,
    }
    # Load configurations
    experiment.configs(conf, overrides)
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


#
if __name__ == '__main__':
    main()
|
8
labml_nn/uncertainty/evidence/readme.md
Normal file
8
labml_nn/uncertainty/evidence/readme.md
Normal file
@ -0,0 +1,8 @@
|
||||
# [Evidential Deep Learning to Quantify Classification Uncertainty](https://nn.labml.ai/uncertainty/evidence/index.html)
|
||||
|
||||
This is a [PyTorch](https://pytorch.org) implementation of the paper
|
||||
[Evidential Deep Learning to Quantify Classification Uncertainty](https://papers.labml.ai/paper/1806.01768).
|
||||
|
||||
Here is the [training code `experiment.py`](https://nn.labml.ai/uncertainty/evidence/experiment.html) to train a model on MNIST dataset.
|
||||
|
||||
[](https://app.labml.ai/run/f82b2bfc01ba11ecbb2aa16a33570106)
|
5
labml_nn/uncertainty/readme.md
Normal file
5
labml_nn/uncertainty/readme.md
Normal file
@ -0,0 +1,5 @@
|
||||
# [Neural Networks with Uncertainty Estimation](https://nn.labml.ai/uncertainty/index.html)
|
||||
|
||||
These are neural network architectures that estimate the uncertainty of the predictions.
|
||||
|
||||
* [Evidential Deep Learning to Quantify Classification Uncertainty](https://nn.labml.ai/uncertainty/evidence/index.html)
|
@ -99,6 +99,10 @@ Solving games with incomplete information such as poker with CFR.
|
||||
|
||||
* [PonderNet](https://nn.labml.ai/adaptive_computation/ponder_net/index.html)
|
||||
|
||||
#### ✨ [Uncertainty](https://nn.labml.ai/uncertainty/index.html)
|
||||
|
||||
* [Evidential Deep Learning to Quantify Classification Uncertainty](https://nn.labml.ai/uncertainty/evidence/index.html)
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
|
@ -1,6 +1,6 @@
|
||||
torch>=1.7
|
||||
labml>=0.4.94
|
||||
labml-helpers>=0.4.77
|
||||
labml>=0.4.132
|
||||
labml-helpers>=0.4.81
|
||||
torchvision
|
||||
numpy>=1.16.3
|
||||
matplotlib>=3.0.3
|
||||
|
6
setup.py
6
setup.py
@ -5,10 +5,10 @@ with open("readme.md", "r") as f:
|
||||
|
||||
setuptools.setup(
|
||||
name='labml-nn',
|
||||
version='0.4.109',
|
||||
version='0.4.110',
|
||||
author="Varuna Jayasiri, Nipun Wijerathne",
|
||||
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
|
||||
description="🧠 Implementations/tutorials of deep learning papers with side-by-side notes; including transformers (original, xl, switch, feedback, vit), optimizers(adam, radam, adabelief), gans(dcgan, cyclegan, stylegan2), reinforcement learning (ppo, dqn), capsnet, distillation, etc.",
|
||||
description="🧑🏫 Implementations/tutorials of deep learning papers with side-by-side notes 📝; including transformers (original, xl, switch, feedback, vit), optimizers (adam, radam, adabelief), gans(dcgan, cyclegan, stylegan2), 🎮 reinforcement learning (ppo, dqn), capsnet, distillation, etc. 🧠",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/labmlai/annotated_deep_learning_paper_implementations",
|
||||
@ -20,7 +20,7 @@ setuptools.setup(
|
||||
'labml_helpers', 'labml_helpers.*',
|
||||
'test',
|
||||
'test.*')),
|
||||
install_requires=['labml>=0.4.129',
|
||||
install_requires=['labml>=0.4.132',
|
||||
'labml-helpers>=0.4.81',
|
||||
'torch',
|
||||
'einops',
|
||||
|
Reference in New Issue
Block a user