<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This trains an EDL model on MNIST"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Evidential Deep Learning to Quantify Classification Uncertainty Experiment"/>
<meta name="twitter:description" content="This trains an EDL model on MNIST"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/uncertainty/evidence/experiment.html"/>
<meta property="og:title" content="Evidential Deep Learning to Quantify Classification Uncertainty Experiment"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Evidential Deep Learning to Quantify Classification Uncertainty Experiment"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Evidential Deep Learning to Quantify Classification Uncertainty Experiment"/>
<meta property="og:description" content="This trains an EDL model on MNIST"/>
<title>Evidential Deep Learning to Quantify Classification Uncertainty Experiment</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/uncertainty/evidence/experiment.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">uncertainty</a>
<a class="parent" href="index.html">evidence</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/uncertainty/evidence/experiment.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
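<h1><a href="index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a> Experiment</h1>
<p>This trains a model based on <a href="index.html">Evidential Deep Learning to Quantify Classification Uncertainty</a> on the MNIST dataset.</p>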
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">14</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span>
<span class="lineno">15</span>
<span class="lineno">16</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="lineno">17</span><span class="kn">import</span> <span class="nn">torch.utils.data</span>
<span class="lineno">18</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">experiment</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span><span class="p">,</span> <span class="n">calculate</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_helpers.schedule</span> <span class="kn">import</span> <span class="n">Schedule</span><span class="p">,</span> <span class="n">RelativePiecewise</span>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_helpers.train_valid</span> <span class="kn">import</span> <span class="n">BatchIndex</span>
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.mnist</span> <span class="kn">import</span> <span class="n">MNISTConfigs</span>
<span class="lineno">25</span><span class="kn">from</span> <span class="nn">labml_nn.uncertainty.evidence</span> <span class="kn">import</span> <span class="n">KLDivergenceLoss</span><span class="p">,</span> <span class="n">TrackStatistics</span><span class="p">,</span> <span class="n">MaximumLikelihoodLoss</span><span class="p">,</span> \
<span class="lineno">26</span> <span class="n">CrossEntropyBayesRisk</span><span class="p">,</span> <span class="n">SquaredErrorBayesRisk</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>LeNet-based model for MNIST classification</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">29</span><span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">34</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
<span class="lineno">35</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>First <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqe" style=""><span class="mord" style="">5</span><span class="mord mathnormal" style="">x</span><span class="mord" style="">5</span></span></span></span></span></span> convolution layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">37</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>ReLU activation</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">39</span> <span class="bp">self</span><span class="o">.</span><span class="n">act1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">2</span><span class="mord mathnormal" style="">x</span><span class="mord" style="">2</span></span></span></span></span></span>max-pooling</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">41</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_pool1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Second <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqe" style=""><span class="mord" style="">5</span><span class="mord mathnormal" style="">x</span><span class="mord" style="">5</span></span></span></span></span></span> convolution layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">43</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>ReLU activation</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">45</span> <span class="bp">self</span><span class="o">.</span><span class="n">act2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">2</span><span class="mord mathnormal" style="">x</span><span class="mord" style="">2</span></span></span></span></span></span>max-pooling</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_pool2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>First fully-connected layer, which maps to <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">500</span></span></span></span></span> features</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">50</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">500</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>ReLU activation</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">51</span> <span class="bp">self</span><span class="o">.</span><span class="n">act3</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Final fully-connected layer, which outputs evidence for <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">10</span></span></span></span></span> classes. A ReLU or Softplus activation is applied to this output outside the model to get non-negative evidence</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">55</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">500</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Dropout for the hidden layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">57</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">p</span><span class="o">=</span><span class="n">dropout</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
 is a batch of MNIST images of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">]</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">59</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Apply the first convolution and max pooling. The result has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">12</span><span class="p">,</span> <span class="mi">12</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_pool1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">act1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Apply the second convolution and max pooling. The result has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_pool2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">act2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Flatten the tensor to shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">50</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">*</span> <span class="mi">4</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
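<div class='section'>
<div class='docs'>
<p>The shapes quoted above follow from the usual output-size formula for an unpadded convolution or pooling window. The sketch below is plain Python arithmetic, not part of the experiment code; the helper name <code>conv_out</code> is hypothetical and assumes no padding and no dilation.</p>
</div>
<div class='code'>

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Side length after a convolution or pooling window (hypothetical helper)."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                                  # MNIST images are 28 x 28
side = conv_out(side, kernel=5)            # conv1: 28 to 24
side = conv_out(side, kernel=2, stride=2)  # max_pool1: 24 to 12
after_pool1 = side
side = conv_out(side, kernel=5)            # conv2: 12 to 8
side = conv_out(side, kernel=2, stride=2)  # max_pool2: 8 to 4

# 50 channels of 4 x 4 maps feed the first fully-connected layer
flat_features = 50 * side * side
print(after_pool1, side, flat_features)    # prints: 12 4 800
```

</div>
</div>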
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Apply the hidden layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">72</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act3</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Apply dropout</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">74</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
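<div class='section'>
<div class='docs'>
<p>As a rough illustration of what the dropout layer does, here is a minimal pure-Python sketch of inverted dropout: in training mode each value is zeroed with probability <code>p</code> and the survivors are scaled by <code>1 / (1 - p)</code>, while in evaluation mode the input passes through unchanged. The function name and its arguments are hypothetical, and it assumes <code>p</code> is strictly below 1; the experiment itself relies on <code>nn.Dropout</code>.</p>
</div>
<div class='code'>

```python
import random


def dropout(values, p, training, rng):
    """Inverted dropout over a flat list of floats (illustrative only)."""
    if not training:
        # Evaluation mode: identity
        return list(values)
    keep = 1.0 - p
    # Keep each value with probability (1 - p) and rescale the kept ones
    return [v / keep if rng.random() >= p else 0.0 for v in values]


xs = [1.0, 2.0, 3.0, 4.0]
train_out = dropout(xs, p=0.5, training=True, rng=random.Random(0))
eval_out = dropout(xs, p=0.5, training=False, rng=random.Random(0))
```

</div>
</div>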
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Apply the final layer and return</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<h2>Configurations</h2>
<p>We use <a href="../../experiments/mnist.html#MNISTConfigs"><code class="highlight"><span></span><span class="n">MNISTConfigs</span></code>
</a> configurations.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">MNISTConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p><a href="index.html#KLDivergenceLoss">KL Divergence regularization</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">87</span> <span class="n">kl_div_loss</span> <span class="o">=</span> <span class="n">KLDivergenceLoss</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>KL Divergence regularization coefficient schedule</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">89</span> <span class="n">kl_div_coef</span><span class="p">:</span> <span class="n">Schedule</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Endpoints of the KL Divergence regularization coefficient schedule</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span> <span class="n">kl_div_coef_schedule</span> <span class="o">=</span> <span class="p">[(</span><span class="mi">0</span><span class="p">,</span> <span class="mf">0.</span><span class="p">),</span> <span class="p">(</span><span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mf">1.</span><span class="p">)]</span></pre></div>
</div>
</div>
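<div class='section'>
<div class='docs'>
<p>To get a feel for the endpoints above, the sketch below interpolates linearly between them. It assumes that the first number in each pair is a fraction of total training progress and that the schedule is piecewise linear; the function <code>piecewise_linear</code> is a hypothetical stand-in and not the <code>RelativePiecewise</code> implementation.</p>
</div>
<div class='code'>

```python
def piecewise_linear(points, x):
    """Linear interpolation between (position, value) pairs, clamped at both ends."""
    if points[0][0] >= x:
        return points[0][1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x1 >= x:
            # x lies inside this segment
            return y0 + (y1 - y0) * (x - x0) / (x1 - x0)
    return points[-1][1]


kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]

# Coefficient at a few points of (assumed) relative training progress
for progress in (0.0, 0.1, 0.2, 0.6, 1.0):
    print(progress, piecewise_linear(kl_div_coef_schedule, progress))
```

</div>
<div class='docs'>
<p>Under these assumptions the coefficient stays very small for the first fifth of training and then rises to 1 by the end.</p>
</div>
</div>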
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p><a href="index.html#TrackStatistics">Stats module</a> for tracking</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">93</span> <span class="n">stats</span> <span class="o">=</span> <span class="n">TrackStatistics</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Dropout probability</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">95</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Module to convert the model output to non-negative evidence</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">97</span> <span class="n">outputs_to_evidence</span><span class="p">:</span> <span class="n">Module</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<h3>Initialization</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">99</span> <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Set tracker configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">104</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s2">&quot;loss.*&quot;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
<span class="lineno">105</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s2">&quot;accuracy.*&quot;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
<span class="lineno">106</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_histogram</span><span class="p">(</span><span class="s1">&#39;u.*&#39;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
<span class="lineno">107</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_histogram</span><span class="p">(</span><span class="s1">&#39;prob.*&#39;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
<span class="lineno">108</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s1">&#39;annealing_coef.*&#39;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
<span class="lineno">109</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s1">&#39;kl_div_loss.*&#39;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">112</span> <span class="bp">self</span><span class="o">.</span><span class="n">state_modules</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<h3>Training or validation step</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">114</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="n">BatchIndex</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Training/evaluation mode</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">120</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Move data to the device</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">123</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>One-hot encoded targets</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">126</span> <span class="n">eye</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="mi">10</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">127</span> <span class="n">target</span> <span class="o">=</span> <span class="n">eye</span><span class="p">[</span><span class="n">target</span><span class="p">]</span></pre></div>
</div>
</div>
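<div class='section'>
<div class='docs'>
<p>Indexing an identity matrix with class labels selects one row per label, which gives the one-hot encoding. The sketch below shows the same idea with plain Python lists instead of tensors; <code>one_hot</code> is a hypothetical helper.</p>
</div>
<div class='code'>

```python
def one_hot(labels, n_classes=10):
    """Pick rows of an identity matrix, one row per integer label."""
    eye = [[1.0 if i == j else 0.0 for j in range(n_classes)]
           for i in range(n_classes)]
    return [eye[label] for label in labels]


targets = one_hot([3, 0])
print(targets[0])  # 1.0 at index 3, 0.0 elsewhere
```

</div>
</div>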
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Update the global step (number of samples processed) when in training mode</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">130</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
<span class="lineno">131</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Get model outputs</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">134</span> <span class="n">outputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Get evidence <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.7859700000000001em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">e</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">≥</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">0</span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="n">evidence</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">outputs_to_evidence</span><span class="p">(</span><span class="n">outputs</span><span class="p">)</span></pre></div>
</div>
</div>
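<div class='section'>
<div class='docs'>
<p>The model's raw outputs can be negative, so an activation maps them to non-negative evidence. Below is a minimal sketch of the two options mentioned for this model, ReLU and Softplus, written with scalar Python functions rather than the module the experiment configures; the function names are illustrative.</p>
</div>
<div class='code'>

```python
import math


def relu(x):
    """max(0, x): negative outputs become zero evidence."""
    return max(0.0, x)


def softplus(x):
    """log(1 + exp(x)), written in a numerically stable form."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


outputs = [-2.0, 0.0, 3.0]
relu_evidence = [relu(o) for o in outputs]
softplus_evidence = [softplus(o) for o in outputs]
```

</div>
<div class='docs'>
<p>ReLU can return exactly zero, while Softplus is always strictly positive; both satisfy the non-negativity requirement.</p>
</div>
</div>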
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>Calculate the loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">139</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">evidence</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Calculate the KL Divergence regularization loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">141</span> <span class="n">kl_div_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_div_loss</span><span class="p">(</span><span class="n">evidence</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
<span class="lineno">142</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;loss.&quot;</span><span class="p">,</span> <span class="n">loss</span><span class="p">)</span>
<span class="lineno">143</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;kl_div_loss.&quot;</span><span class="p">,</span> <span class="n">kl_div_loss</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>KL divergence loss coefficient <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">λ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">146</span> <span class="n">annealing_coef</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="mf">1.</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_div_coef</span><span class="p">(</span><span class="n">tracker</span><span class="o">.</span><span class="n">get_global_step</span><span class="p">()))</span>
<span class="lineno">147</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;annealing_coef.&quot;</span><span class="p">,</span> <span class="n">annealing_coef</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Total loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">150</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">loss</span> <span class="o">+</span> <span class="n">annealing_coef</span> <span class="o">*</span> <span class="n">kl_div_loss</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Track statistics</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">153</span> <span class="bp">self</span><span class="o">.</span><span class="n">stats</span><span class="p">(</span><span class="n">evidence</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Train the model</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>Calculate the gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">158</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Take an optimizer step</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">160</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Clear the gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p>Save the tracked metrics</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">165</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
</div>
</div>
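<div class='section'>
<div class='docs'>
<p>A minimal sketch of how the annealed KL term enters the total loss in the step above, written in plain Python. <code>total_loss</code> is a hypothetical helper and the numbers are made up for illustration; the experiment itself works on tensors and reads the coefficient from its schedule.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def total_loss(loss: float, kl_div_loss: float, kl_div_coef: float) -> float:
    """Combine the main loss with the KL regularizer (hypothetical helper)."""
    # Cap the scheduled coefficient at 1, mirroring `min(1., ...)` in the step
    annealing_coef = min(1.0, kl_div_coef)
    return loss + annealing_coef * kl_div_loss


# With a small coefficient the regularizer contributes little
early = total_loss(loss=0.5, kl_div_loss=2.0, kl_div_coef=0.25)
# A scheduled value above 1 is capped, so the KL term is added in full
late = total_loss(loss=0.5, kl_div_loss=2.0, kl_div_coef=3.0)
print(early, late)
```
</pre></div>
</div>
</div>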
<div class='section' id='section-47'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<h3>Create model</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
<span class="lineno">169</span><span class="k">def</span> <span class="nf">mnist_model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">173</span> <span class="k">return</span> <span class="n">Model</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<h3>KL divergence loss coefficient schedule</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">kl_div_coef</span><span class="p">)</span>
<span class="lineno">177</span><span class="k">def</span> <span class="nf">kl_div_coef</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>Create a <a href="https://docs.labml.ai/api/helpers.html#labml_helpers.schedule.Piecewise">relative piecewise schedule</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">183</span> <span class="k">return</span> <span class="n">RelativePiecewise</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">kl_div_coef_schedule</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">epochs</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">train_dataset</span><span class="p">))</span></pre></div>
</div>
</div>
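<div class='section'>
<div class='docs'>
<p>To illustrate the idea of a relative piecewise schedule, here is a rough standalone sketch. It assumes endpoints are given as (fraction of training, value) pairs that are scaled by the total number of steps and linearly interpolated; <code>relative_piecewise</code> and the endpoint values are hypothetical and are not the <code>labml_helpers</code> implementation. In the code above the total is passed as <code>c.epochs * len(c.train_dataset)</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
from typing import List, Tuple


def relative_piecewise(endpoints: List[Tuple[float, float]], total_steps: int):
    """Hypothetical helper: piecewise-linear schedule with relative endpoints."""
    # Convert relative positions (fractions of training) to absolute steps
    absolute = [(frac * total_steps, value) for frac, value in endpoints]

    def schedule(step: int) -> float:
        # At or before the first endpoint, hold the first value
        if absolute[0][0] >= step:
            return absolute[0][1]
        # Linearly interpolate inside the segment that contains `step`
        for (x0, y0), (x1, y1) in zip(absolute[:-1], absolute[1:]):
            if x1 >= step:
                return y0 + (y1 - y0) * (step - x0) / (x1 - x0)
        # After the last endpoint, hold the last value
        return absolute[-1][1]

    return schedule


# Made-up schedule: ramp from 0 to 1 over the first half, then stay at 1
coef = relative_piecewise([(0.0, 0.0), (0.5, 1.0), (1.0, 1.0)], total_steps=100)
print(coef(0), coef(25), coef(100))
```
</pre></div>
</div>
</div>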
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p><a href="index.html#MaximumLikelihoodLoss">Maximum Likelihood Loss</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">187</span><span class="n">calculate</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">loss_func</span><span class="p">,</span> <span class="s1">&#39;max_likelihood_loss&#39;</span><span class="p">,</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">MaximumLikelihoodLoss</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p><a href="index.html#CrossEntropyBayesRisk">Cross Entropy Bayes Risk</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">189</span><span class="n">calculate</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">loss_func</span><span class="p">,</span> <span class="s1">&#39;cross_entropy_bayes_risk&#39;</span><span class="p">,</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">CrossEntropyBayesRisk</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p><a href="index.html#SquaredErrorBayesRisk">Squared Error Bayes Risk</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span><span class="n">calculate</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">loss_func</span><span class="p">,</span> <span class="s1">&#39;squared_error_bayes_risk&#39;</span><span class="p">,</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">SquaredErrorBayesRisk</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<p>ReLU to calculate evidence</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">194</span><span class="n">calculate</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">outputs_to_evidence</span><span class="p">,</span> <span class="s1">&#39;relu&#39;</span><span class="p">,</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='docs'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<p>Softplus to calculate evidence</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">196</span><span class="n">calculate</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">outputs_to_evidence</span><span class="p">,</span> <span class="s1">&#39;softplus&#39;</span><span class="p">,</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softplus</span><span class="p">())</span></pre></div>
</div>
</div>
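<div class='section'>
<div class='docs'>
<p>Both options map raw network outputs to non-negative evidence. The sketch below compares them using plain-Python stand-ins for <code>nn.ReLU</code> and <code>nn.Softplus</code> (with default settings); the helper functions and the input values are assumptions made for illustration only.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import math


def relu(x: float) -> float:
    """ReLU: negative outputs give exactly zero evidence."""
    return max(0.0, x)


def softplus(x: float) -> float:
    """Softplus: log(1 + exp(x)), which is always strictly positive."""
    return math.log1p(math.exp(x))


outputs = [-2.0, 0.0, 3.0]  # made-up network outputs
relu_evidence = [relu(x) for x in outputs]
softplus_evidence = [softplus(x) for x in outputs]
print(relu_evidence)
print(softplus_evidence)
```
</pre></div>
</div>
</div>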
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">199</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<p>Create experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">201</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;evidence_mnist&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<p>Create configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">203</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<p>Load configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span>
<span class="lineno">206</span> <span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Adam&#39;</span><span class="p">,</span>
<span class="lineno">207</span> <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">0.001</span><span class="p">,</span>
<span class="lineno">208</span> <span class="s1">&#39;optimizer.weight_decay&#39;</span><span class="p">:</span> <span class="mf">0.005</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
<p>Alternative loss functions: <code>'loss_func': 'max_likelihood_loss'</code> or <code>'loss_func': 'cross_entropy_bayes_risk'</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">212</span> <span class="s1">&#39;loss_func&#39;</span><span class="p">:</span> <span class="s1">&#39;squared_error_bayes_risk&#39;</span><span class="p">,</span>
<span class="lineno">213</span>
<span class="lineno">214</span> <span class="s1">&#39;outputs_to_evidence&#39;</span><span class="p">:</span> <span class="s1">&#39;softplus&#39;</span><span class="p">,</span>
<span class="lineno">215</span>
<span class="lineno">216</span> <span class="s1">&#39;dropout&#39;</span><span class="p">:</span> <span class="mf">0.5</span><span class="p">,</span>
<span class="lineno">217</span> <span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
<p>Start the experiment and run the training loop</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">219</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
<span class="lineno">220</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">224</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">225</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>