<!-- mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git; synced 2025-08-14 01:13:00 +08:00; 779 lines; 73 KiB; HTML -->
<!DOCTYPE html>
|
||
<html lang="en">
|
||
<head>
|
||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||
<meta name="description" content="PyTorch implementation and tutorial of the paper Distilling the Knowledge in a Neural Network."/>
|
||
|
||
<meta name="twitter:card" content="summary"/>
|
||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||
<meta name="twitter:title" content="Distilling the Knowledge in a Neural Network"/>
|
||
<meta name="twitter:description" content="PyTorch implementation and tutorial of the paper Distilling the Knowledge in a Neural Network."/>
|
||
<meta name="twitter:site" content="@labmlai"/>
|
||
<meta name="twitter:creator" content="@labmlai"/>
|
||
|
||
<meta property="og:url" content="https://nn.labml.ai/distillation/index.html"/>
|
||
<meta property="og:title" content="Distilling the Knowledge in a Neural Network"/>
|
||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||
<meta property="og:type" content="object"/>
|
||
<meta property="og:title" content="Distilling the Knowledge in a Neural Network"/>
|
||
<meta property="og:description" content="PyTorch implementation and tutorial of the paper Distilling the Knowledge in a Neural Network."/>
|
||
|
||
<title>Distilling the Knowledge in a Neural Network</title>
|
||
<link rel="shortcut icon" href="/icon.png"/>
|
||
<link rel="stylesheet" href="../pylit.css?v=1">
|
||
<link rel="canonical" href="https://nn.labml.ai/distillation/index.html"/>
|
||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
||
|
||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||
<script>
|
||
// Reuse the page's existing dataLayer array if one is already defined;
// otherwise start with an empty array.
window.dataLayer = window.dataLayer || [];
|
||
|
||
// Queue a command: push this call's `arguments` object (not a copy of it)
// onto dataLayer.
function gtag() {
|
||
dataLayer.push(arguments);
|
||
}
|
||
|
||
// Queue the 'js' command together with the current time.
gtag('js', new Date());
|
||
|
||
// Queue the 'config' command for 'G-4V3HC8HBLH', the same id that is passed
// in the gtag.js loader URL of the preceding <script async> tag.
gtag('config', 'G-4V3HC8HBLH');
|
||
</script>
|
||
</head>
|
||
<body>
|
||
<div id='container'>
|
||
<div id="background"></div>
|
||
<div class='section'>
|
||
<div class='docs'>
|
||
<p>
|
||
<a class="parent" href="/">home</a>
|
||
<a class="parent" href="index.html">distillation</a>
|
||
</p>
|
||
<p>
|
||
|
||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/distillation/__init__.py">
|
||
<img alt="GitHub"
|
||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||
style="max-width:100%;"/></a>
|
||
<a href="https://twitter.com/labmlai"
|
||
rel="nofollow">
|
||
<img alt="Twitter"
|
||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||
style="max-width:100%;"/></a>
|
||
</p>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-0'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-0'>#</a>
|
||
</div>
|
||
<h1>Distilling the Knowledge in a Neural Network</h1>
|
||
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation/tutorial of the paper <a href="https://papers.labml.ai/paper/1503.02531">Distilling the Knowledge in a Neural Network</a>.</p>
|
||
<p>It's a way of training a small network using the knowledge in a trained larger network; i.e. distilling the knowledge from the large network.</p>
|
||
<p>A large model with regularization or an ensemble of models (using dropout) generalizes better than a small model when trained directly on the data and labels. However, a small model can be trained to generalize better with the help of a large model. Smaller models are better in production: faster, less compute, less memory.</p>
|
||
<p>The output probabilities of a trained model give more information than the labels because the model assigns non-zero probabilities to incorrect classes as well. These probabilities tell us that a sample has a chance of belonging to certain classes. For instance, when classifying digits, given an image of the digit <em>7</em>, a model that generalizes well will give a high probability to 7 and a small but non-zero probability to 2, while assigning almost zero probability to other digits. Distillation uses this information to train a small model better.</p>
|
||
<h2>Soft Targets</h2>
|
||
<p>The probabilities are usually computed with a softmax operation,</p>
|
||
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eql" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:2.5488180000000003em;vertical-align:-1.1218180000000002em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.427em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mop"><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em;">∑</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16195399999999993em;"><span style="top:-2.40029em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" 
style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.43581800000000004em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">exp</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.04398em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mop">exp</span><span class="mopen">(</span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.04398em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eql" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span 
class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.1218180000000002em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></p>
|
||
<p>where <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eql" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> is the probability for class <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.65952em;vertical-align:0em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathnormal" style="">i</span></span></span></span></span> and <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.04398em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eql" style=""><span class="mord mathnormal mtight" 
style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> is the logit.</p>
|
||
<p>We train the small model to minimize the cross entropy or KL divergence between its output probability distribution and the large network's output probability distribution (soft targets).</p>
|
||
<p>One of the problems here is that the probabilities assigned to incorrect classes by the large network are often very small and contribute little to the loss. So the paper softens the probabilities by applying a temperature <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqk" style=""><span class="mord mathnormal" style="margin-right:0.13889em">T</span></span></span></span></span>,</p>
|
||
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eql" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:2.61953em;vertical-align:-1.13453em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.4849999999999999em;"><span style="top:-2.301288em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mop"><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em;">∑</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16195399999999993em;"><span style="top:-2.40029em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" 
style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.43581800000000004em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">exp</span><span class="mopen">(</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.808712em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="margin-right:0.13889em">T</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.50732em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.04398em;">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.04398em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" 
style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose">)</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.7350000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mop">exp</span><span class="mopen">(</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7114919999999999em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="margin-right:0.13889em">T</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.4101em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqj" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.04398em">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.04398em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight coloredeq eql" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" 
style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose">)</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.13453em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></p>
|
||
<p>where higher values for <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqk" style=""><span class="mord mathnormal" style="margin-right:0.13889em">T</span></span></span></span></span> will produce softer probabilities.</p>
|
||
<h2>Training</h2>
|
||
<p>The paper suggests adding a second loss term for predicting the actual labels when training the small model. We calculate the composite loss as the weighted sum of the two loss terms: soft targets and actual labels.</p>
|
||
<p>The dataset for distillation is called <em>the transfer set</em>, and the paper suggests using the same training data.</p>
|
||
<h2>Our experiment</h2>
|
||
<p>We train on the CIFAR-10 dataset. We <a href="large.html">train a large model</a> that has <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8388800000000001em;vertical-align:-0.19444em;"></span><span class="mord">14</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">728</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">266</span></span></span></span> parameters with dropout and it gives an accuracy of 85% on the validation set. A <a href="small.html">small model</a> with <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8388800000000001em;vertical-align:-0.19444em;"></span><span class="mord">437</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">034</span></span></span></span> parameters gives an accuracy of 80%.</p>
|
||
<p>We then train the small model with distillation from the large model, and it gives an accuracy of 82%; a 2 percentage point increase in accuracy.</p>
|
||
<p><a href="https://app.labml.ai/run/d6182e2adaf011eb927c91a2a1710932"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a></p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">74</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||
<span class="lineno">75</span><span class="kn">import</span> <span class="nn">torch.nn.functional</span>
|
||
<span class="lineno">76</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
|
||
<span class="lineno">77</span>
|
||
<span class="lineno">78</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span><span class="p">,</span> <span class="n">tracker</span>
|
||
<span class="lineno">79</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
|
||
<span class="lineno">80</span><span class="kn">from</span> <span class="nn">labml_helpers.train_valid</span> <span class="kn">import</span> <span class="n">BatchIndex</span>
|
||
<span class="lineno">81</span><span class="kn">from</span> <span class="nn">labml_nn.distillation.large</span> <span class="kn">import</span> <span class="n">LargeModel</span>
|
||
<span class="lineno">82</span><span class="kn">from</span> <span class="nn">labml_nn.distillation.small</span> <span class="kn">import</span> <span class="n">SmallModel</span>
|
||
<span class="lineno">83</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.cifar10</span> <span class="kn">import</span> <span class="n">CIFAR10Configs</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-1'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-1'>#</a>
|
||
</div>
|
||
<h2>Configurations</h2>
|
||
<p>This extends from <a href="../experiments/cifar10.html"><code class="highlight"><span></span><span class="n">CIFAR10Configs</span></code>
|
||
</a>, which defines all the dataset-related configurations, the optimizer, and a training loop.</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">86</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">CIFAR10Configs</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-2'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-2'>#</a>
|
||
</div>
|
||
<p>The small model </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">94</span> <span class="n">model</span><span class="p">:</span> <span class="n">SmallModel</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-3'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-3'>#</a>
|
||
</div>
|
||
<p>The large model </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">96</span> <span class="n">large</span><span class="p">:</span> <span class="n">LargeModel</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-4'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-4'>#</a>
|
||
</div>
|
||
<p>KL Divergence loss for soft targets </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">98</span> <span class="n">kl_div_loss</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">KLDivLoss</span><span class="p">(</span><span class="n">log_target</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-5'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-5'>#</a>
|
||
</div>
|
||
<p>Cross entropy loss for true label loss </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">100</span> <span class="n">loss_func</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-6'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-6'>#</a>
|
||
</div>
|
||
<p>Temperature, <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqk" style=""><span class="mord mathnormal" style="margin-right:0.13889em">T</span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">102</span> <span class="n">temperature</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">5.</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-7'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-7'>#</a>
|
||
</div>
|
||
<p>Weight for soft targets loss.</p>
|
||
<p>The gradients produced by soft targets get scaled by <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.190108em;vertical-align:-0.345em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.845108em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqg" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="margin-right:0.13889em">T</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7463142857142857em;"><span style="top:-2.786em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span>. 
To compensate for this the paper suggests scaling the soft targets loss by a factor of <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8141079999999999em;vertical-align:0em;"></span><span class="mord coloredeq eqg" style=""><span class="mord" style=""><span class="mord coloredeq eqk" style=""><span class="mord mathnormal" style="margin-right:0.13889em">T</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8141079999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span></span></span></span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">108</span> <span class="n">soft_targets_weight</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">100.</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-8'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-8'>#</a>
|
||
</div>
|
||
<p>Weight for true label cross entropy loss </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">110</span> <span class="n">label_loss_weight</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-9'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-9'>#</a>
|
||
</div>
|
||
<h3>Training/validation step</h3>
|
||
<p>We define a custom training/validation step to include distillation.</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">112</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="nb">any</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="n">BatchIndex</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-10'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-10'>#</a>
|
||
</div>
|
||
<p>Training/Evaluation mode for the small model </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">120</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-11'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-11'>#</a>
|
||
</div>
|
||
<p>Large model in evaluation mode </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">122</span> <span class="bp">self</span><span class="o">.</span><span class="n">large</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-12'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-12'>#</a>
|
||
</div>
|
||
<p>Move data to the device </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">125</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-13'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-13'>#</a>
|
||
</div>
|
||
<p>Update global step (number of samples processed) when in training mode </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">128</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
|
||
<span class="lineno">129</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">))</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-14'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-14'>#</a>
|
||
</div>
|
||
<p>Get the output logits, <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eql" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span>, from the large model </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">132</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
|
||
<span class="lineno">133</span> <span class="n">large_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">large</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-15'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-15'>#</a>
|
||
</div>
|
||
<p>Get the output logits, <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.04398em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eql" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span>, from the small model </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">136</span> <span class="n">output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-16'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-16'>#</a>
|
||
</div>
|
||
<p>Soft targets <span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eql" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:2.61953em;vertical-align:-1.13453em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.4849999999999999em;"><span style="top:-2.301288em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mop"><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em;">∑</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16195399999999993em;"><span style="top:-2.40029em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span 
class="vlist" style="height:0.43581800000000004em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">exp</span><span class="mopen">(</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.808712em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="margin-right:0.13889em">T</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.50732em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.03588em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose">)</span></span></span><span style="top:-3.23em;"><span 
class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.7350000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mop">exp</span><span class="mopen">(</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7114919999999999em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="margin-right:0.13889em">T</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.4101em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqi" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.03588em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight coloredeq eql" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" 
style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose">)</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.13453em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">soft_targets</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">log_softmax</span><span class="p">(</span><span class="n">large_logits</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">temperature</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-17'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-17'>#</a>
|
||
</div>
|
||
<p>Temperature adjusted probabilities of the small model <span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eql" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:2.61953em;vertical-align:-1.13453em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.4849999999999999em;"><span style="top:-2.301288em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mop"><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em;">∑</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16195399999999993em;"><span style="top:-2.40029em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord 
mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.43581800000000004em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">exp</span><span class="mopen">(</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.808712em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="margin-right:0.13889em">T</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.50732em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.04398em;">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.04398em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" 
style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose">)</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.7350000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mop">exp</span><span class="mopen">(</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7114919999999999em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="margin-right:0.13889em">T</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.4101em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqj" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.04398em">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.04398em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight coloredeq eql" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" 
style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose">)</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.13453em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">143</span> <span class="n">soft_prob</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">log_softmax</span><span class="p">(</span><span class="n">output</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">temperature</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-18'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-18'>#</a>
|
||
</div>
|
||
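<p>Calculate the soft targets loss with <code>kl_div_loss</code>; note that both arguments are log-probabilities (outputs of <code>log_softmax</code>) rather than the plain probabilities shown in the formulas above </p>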
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">146</span> <span class="n">soft_targets_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_div_loss</span><span class="p">(</span><span class="n">soft_prob</span><span class="p">,</span> <span class="n">soft_targets</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-19'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-19'>#</a>
|
||
</div>
|
||
<p>Calculate the true label loss </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">148</span> <span class="n">label_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-20'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-20'>#</a>
|
||
</div>
|
||
<p>Weighted sum of the two losses </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">150</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">soft_targets_weight</span> <span class="o">*</span> <span class="n">soft_targets_loss</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">label_loss_weight</span> <span class="o">*</span> <span class="n">label_loss</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-21'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-21'>#</a>
|
||
</div>
|
||
<p>Log the losses </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">152</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">({</span><span class="s2">"loss.kl_div."</span><span class="p">:</span> <span class="n">soft_targets_loss</span><span class="p">,</span>
|
||
<span class="lineno">153</span> <span class="s2">"loss.nll"</span><span class="p">:</span> <span class="n">label_loss</span><span class="p">,</span>
|
||
<span class="lineno">154</span> <span class="s2">"loss."</span><span class="p">:</span> <span class="n">loss</span><span class="p">})</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-22'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-22'>#</a>
|
||
</div>
|
||
<p>Calculate and log accuracy </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">157</span> <span class="bp">self</span><span class="o">.</span><span class="n">accuracy</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
|
||
<span class="lineno">158</span> <span class="bp">self</span><span class="o">.</span><span class="n">accuracy</span><span class="o">.</span><span class="n">track</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-23'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-23'>#</a>
|
||
</div>
|
||
<p>Train the model </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">161</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-24'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-24'>#</a>
|
||
</div>
|
||
<p>Calculate gradients </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">163</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-25'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-25'>#</a>
|
||
</div>
|
||
<p>Take an optimizer step </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">165</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-26'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-26'>#</a>
|
||
</div>
|
||
<p>Log the model parameters and gradients on the last batch of every epoch </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">167</span> <span class="k">if</span> <span class="n">batch_idx</span><span class="o">.</span><span class="n">is_last</span><span class="p">:</span>
|
||
<span class="lineno">168</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">'model'</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-27'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-27'>#</a>
|
||
</div>
|
||
<p>Clear the gradients </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">170</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-28'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-28'>#</a>
|
||
</div>
|
||
<p>Save the tracked metrics </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">173</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-29'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-29'>#</a>
|
||
</div>
|
||
<h3>Create large model</h3>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">176</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">large</span><span class="p">)</span>
|
||
<span class="lineno">177</span><span class="k">def</span> <span class="nf">_large_model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-30'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-30'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">181</span> <span class="k">return</span> <span class="n">LargeModel</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-31'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-31'>#</a>
|
||
</div>
|
||
<h3>Create small model</h3>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">184</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
|
||
<span class="lineno">185</span><span class="k">def</span> <span class="nf">_small_student_model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-32'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-32'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">189</span> <span class="k">return</span> <span class="n">SmallModel</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-33'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-33'>#</a>
|
||
</div>
|
||
<h3>Load <a href="large.html">trained large model</a></h3>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">192</span><span class="k">def</span> <span class="nf">get_saved_model</span><span class="p">(</span><span class="n">run_uuid</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-34'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-34'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">197</span> <span class="kn">from</span> <span class="nn">labml_nn.distillation.large</span> <span class="kn">import</span> <span class="n">Configs</span> <span class="k">as</span> <span class="n">LargeConfigs</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-35'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-35'>#</a>
|
||
</div>
|
||
<p>Put the experiment in evaluation mode (no recording) </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">200</span> <span class="n">experiment</span><span class="o">.</span><span class="n">evaluate</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-36'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-36'>#</a>
|
||
</div>
|
||
<p>Initialize configs of the large model training experiment </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">202</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">LargeConfigs</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-37'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-37'>#</a>
|
||
</div>
|
||
<p>Load saved configs </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">204</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="n">experiment</span><span class="o">.</span><span class="n">load_configs</span><span class="p">(</span><span class="n">run_uuid</span><span class="p">))</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-38'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-38'>#</a>
|
||
</div>
|
||
<p>Set models for saving/loading </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">206</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">'model'</span><span class="p">:</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">})</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-39'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-39'>#</a>
|
||
</div>
|
||
<p>Set which run and checkpoint to load </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">208</span> <span class="n">experiment</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">run_uuid</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-40'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-40'>#</a>
|
||
</div>
|
||
<p>Start the experiment - this will load the model, and prepare everything </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">210</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-41'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-41'>#</a>
|
||
</div>
|
||
<p>Return the model </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">213</span> <span class="k">return</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-42'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-42'>#</a>
|
||
</div>
|
||
<p> Train a small model with distillation</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">216</span><span class="k">def</span> <span class="nf">main</span><span class="p">(</span><span class="n">run_uuid</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-43'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-43'>#</a>
|
||
</div>
|
||
<p>Load saved model </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">221</span> <span class="n">large_model</span> <span class="o">=</span> <span class="n">get_saved_model</span><span class="p">(</span><span class="n">run_uuid</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-44'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-44'>#</a>
|
||
</div>
|
||
<p>Create experiment </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">223</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">'distillation'</span><span class="p">,</span> <span class="n">comment</span><span class="o">=</span><span class="s1">'cifar10'</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-45'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-45'>#</a>
|
||
</div>
|
||
<p>Create configurations </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">225</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-46'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-46'>#</a>
|
||
</div>
|
||
<p>Set the loaded large model </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">227</span> <span class="n">conf</span><span class="o">.</span><span class="n">large</span> <span class="o">=</span> <span class="n">large_model</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-47'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-47'>#</a>
|
||
</div>
|
||
<p>Load configurations </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">229</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span>
|
||
<span class="lineno">230</span> <span class="s1">'optimizer.optimizer'</span><span class="p">:</span> <span class="s1">'Adam'</span><span class="p">,</span>
|
||
<span class="lineno">231</span> <span class="s1">'optimizer.learning_rate'</span><span class="p">:</span> <span class="mf">2.5e-4</span><span class="p">,</span>
|
||
<span class="lineno">232</span> <span class="s1">'model'</span><span class="p">:</span> <span class="s1">'_small_student_model'</span><span class="p">,</span>
|
||
<span class="lineno">233</span> <span class="p">})</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-48'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-48'>#</a>
|
||
</div>
|
||
<p>Set model for saving/loading </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">235</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">'model'</span><span class="p">:</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">})</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-49'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-49'>#</a>
|
||
</div>
|
||
<p>Start experiment from scratch </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">237</span> <span class="n">experiment</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-50'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-50'>#</a>
|
||
</div>
|
||
<p>Start the experiment and run the training loop </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">239</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
|
||
<span class="lineno">240</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-51'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-51'>#</a>
|
||
</div>
|
||
<p> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">244</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||
<span class="lineno">245</span> <span class="n">main</span><span class="p">(</span><span class="s1">'d46cd53edaec11eb93c38d6538aee7d6'</span><span class="p">,</span> <span class="mi">1_000_000</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='footer'>
|
||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||
<a href="https://labml.ai">labml.ai</a>
|
||
</div>
|
||
</div>
|
||
<script src="../interactive.js?v=1"></script>
|
||
<script>
|
||
/**
 * Enhance every image that is a direct child of a paragraph.
 *
 * Looks up all elements matching the `p>img` selector and passes each one
 * to `handleImage`, in document order.
 */
function handleImages() {
    var images = document.querySelectorAll('p>img')

    for (var i = 0; i < images.length; ++i) {
        handleImage(images[i])
    }
}
|
||
|
||
/**
 * Make an image zoomable: clicking it shows a copy of the image in a modal overlay.
 *
 * The overlay is built once per image and is attached to `document.body` only
 * while it is open. It is dismissed by clicking the close control or by
 * pressing the Escape key.
 *
 * @param {HTMLImageElement} img - image to enhance; its parent element's
 *     text alignment is set to `center`.
 */
function handleImage(img) {
    img.parentElement.style.textAlign = 'center'

    var modal = document.createElement('div')
    // NOTE(review): the id is presumably the stylesheet hook for the overlay, so it is kept as-is.
    modal.id = 'modal'

    var modalContent = document.createElement('div')
    modal.appendChild(modalContent)

    var modalImage = document.createElement('img')
    modalContent.appendChild(modalImage)

    var closeButton = document.createElement('span')
    closeButton.classList.add('close')
    closeButton.textContent = 'x'
    modal.appendChild(closeButton)

    // Lets keyboard users dismiss the overlay; only registered while the modal is open.
    function closeOnEscape(event) {
        if (event.key === 'Escape') {
            closeModal()
        }
    }

    function closeModal() {
        document.removeEventListener('keydown', closeOnEscape)
        // remove() is a no-op when the modal is already detached,
        // whereas body.removeChild(modal) would throw in that case.
        modal.remove()
    }

    img.onclick = function () {
        // Fill in the picture before attaching so the overlay never shows an empty image.
        modalImage.src = img.src
        // Carry the text alternative over so the enlarged copy is not an unlabeled image.
        modalImage.alt = img.alt || ''
        document.body.appendChild(modal)
        // Registering the same listener again is ignored, so repeated clicks are safe.
        document.addEventListener('keydown', closeOnEscape)
    }

    closeButton.onclick = closeModal
}
|
||
|
||
// Enhance all matching images as soon as this inline script executes.
handleImages()
|
||
</script>
|
||
</body>
|
||
</html> |