Files
Varuna Jayasiri 2038b11d29 ja translation
2023-05-10 17:00:29 -04:00

780 lines
76 KiB
HTML
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<!DOCTYPE html>
<html lang="ja">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="PyTorchの実装と論文「ニューラルネットワークにおける知識の抽出」のチュートリアル。"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="ニューラルネットワークでの知識の抽出"/>
<meta name="twitter:description" content="PyTorchの実装と論文「ニューラルネットワークにおける知識の抽出」のチュートリアル。"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/distillation/index.html"/>
<meta property="og:title" content="ニューラルネットワークでの知識の抽出"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="ニューラルネットワークでの知識の抽出"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="ニューラルネットワークでの知識の抽出"/>
<meta property="og:description" content="PyTorchの実装と論文「ニューラルネットワークにおける知識の抽出」のチュートリアル。"/>
<title>ニューラルネットワークでの知識の抽出</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/distillation/index.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">distillation</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/distillation/__init__.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>ニューラルネットワークでの知識の抽出</h1>
<p>これは、<a href="https://papers.labml.ai/paper/1503.02531">論文「ニューラルネットワークにおける知識の抽出</a><a href="https://pytorch.org">のPyTorch実装/チュートリアルです</a></p>
<p>これは、トレーニング済みの大規模なネットワークの知識を使用して小規模ネットワークをトレーニングする方法です。つまり、大規模なネットワークから知識を抽出する方法です。</p>
<p>データやラベルで直接トレーニングした場合、正則化を行った大規模なモデルや (ドロップアウトを使用した) モデルのアンサンブルは、小さなモデルよりも一般化が容易です。ただし、小さいモデルでも、大きなモデルの助けを借りてより一般化しやすいようにトレーニングできます。本番環境では、モデルが小さいほど速く、処理能力が少なく、メモリも少なくて済みます。</p>
<p>トレーニング済みモデルの出力確率は、誤ったクラスにもゼロ以外の確率を割り当てるため、ラベルよりも多くの情報を提供します。これらの確率から、サンプルが特定のクラスに属している可能性があることがわかります。たとえば、数字を分類する際、<em>7 桁の画像が与えられた場合、一般化モデルでは 7</em> には高い確率、2 には小さいながらもゼロではない確率が与えられ、他の数字にはほぼゼロの確率を割り当てます。蒸留では、この情報を利用して小型モデルの学習効果を高めます</p>
<h2>ソフトターゲット</h2>
<p>確率は通常、ソフトマックス演算で計算されます。</p>
<p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eql" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:2.5488180000000003em;vertical-align:-1.1218180000000002em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.427em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mop"><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em;"></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16195399999999993em;"><span style="top:-2.40029em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.43581800000000004em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">exp</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.04398em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mop">exp</span><span class="mopen">(</span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.04398em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eql" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.1218180000000002em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></span></p>
<p>ここで<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eql" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.65952em;vertical-align:0em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathnormal" style="">i</span></span></span></span></span></span>はクラスの確率で、<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.04398em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eql" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>はロジットです。</p>
<p>出力確率分布と大規模ネットワークの出力確率分布 (ソフトターゲット) の間のクロスエントロピーまたは KL ダイバージェンスを最小化するようにスモールモデルをトレーニングします。</p>
<p>ここでの問題の 1 つは、大規模なネットワークによって誤ったクラスに割り当てられる確率がたいてい非常に小さく、損失の原因にならないことです。そこで、温度を当てることで確率を弱めます <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqk" style=""><span class="mord mathnormal" style="margin-right:0.13889em">T</span></span></span></span></span></span></p>
<p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:2.61953em;vertical-align:-1.13453em;"></span><span class="mord coloredeq eqb" style=""><span class="mord" style=""><span class="mord coloredeq eqh" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eql" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.4849999999999999em;"><span style="top:-2.301288em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mop" style=""><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em"></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16195399999999993em;"><span style="top:-2.40029em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.43581800000000004em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mop" style=""><span style="">e</span><span style="">x</span><span style="">p</span></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.808712em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqk" style="margin-right:0.13889em">T</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.50732em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.04398em">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.04398em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose" style="">)</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.7350000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mop" style=""><span style="">e</span><span style="">x</span><span style="">p</span></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7114919999999999em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqk" style="margin-right:0.13889em">T</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.4101em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqj" style=""><span class="mord mathnormal mtight" style="margin-right:0.04398em">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.04398em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eql" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose" style="">)</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.13453em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></span></span></p>
<p>値が大きいほど<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqk" style=""><span class="mord mathnormal" style="margin-right:0.13889em">T</span></span></span></span></span></span>、確率は低くなります。</p>
<h2>トレーニング</h2>
<p>論文では、小型モデルをトレーニングする際に実際のラベルを予測するために 2 つ目の損失項を追加することを提案しています。複合損失は、ソフトターゲットと実際のラベルという2つの損失項の加重合計として計算されます。</p>
<p><em>蒸留用のデータセットはトランスファーセットと呼ばれ</em>、論文では同じトレーニングデータを使用することを提案しています。</p>
<h2>私たちの実験</h2>
<p>CIFAR-10データセットでトレーニングします。<a href="large.html"><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8388800000000001em;vertical-align:-0.19444em;"></span><span class="mord">14</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">728</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">266</span></span></span></span></span>ドロップアウトのあるパラメータを持つ大規模なモデルをトレーニングすると</a>、検証セットの精度が 85% になります。<a href="small.html"><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8388800000000001em;vertical-align:-0.19444em;"></span><span class="mord">437</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">034</span></span></span></span></span>パラメータを含む小さなモデルでは</a>、80の精度が得られます。</p>
<p>次に、大きいモデルから蒸留して小さなモデルにトレーニングを行うと、精度は 82%、精度は 2% 向上します。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">72</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">73</span><span class="kn">import</span> <span class="nn">torch.nn.functional</span>
<span class="lineno">74</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">75</span>
<span class="lineno">76</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span><span class="p">,</span> <span class="n">tracker</span>
<span class="lineno">77</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
<span class="lineno">78</span><span class="kn">from</span> <span class="nn">labml_helpers.train_valid</span> <span class="kn">import</span> <span class="n">BatchIndex</span>
<span class="lineno">79</span><span class="kn">from</span> <span class="nn">labml_nn.distillation.large</span> <span class="kn">import</span> <span class="n">LargeModel</span>
<span class="lineno">80</span><span class="kn">from</span> <span class="nn">labml_nn.distillation.small</span> <span class="kn">import</span> <span class="n">SmallModel</span>
<span class="lineno">81</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.cifar10</span> <span class="kn">import</span> <span class="n">CIFAR10Configs</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>コンフィギュレーション</h2>
<p>これを拡張して<a href="../experiments/cifar10.html"><code class="highlight"><span></span><span class="n">CIFAR10Configs</span></code>
</a>、データセットに関連するすべての構成、オプティマイザー、およびトレーニングループを定義します。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">84</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">CIFAR10Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<p>スモールモデル</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">92</span> <span class="n">model</span><span class="p">:</span> <span class="n">SmallModel</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>ラージモデル</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">94</span> <span class="n">large</span><span class="p">:</span> <span class="n">LargeModel</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>ソフトターゲットのKLダイバージェンスロス</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">96</span> <span class="n">kl_div_loss</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">KLDivLoss</span><span class="p">(</span><span class="n">log_target</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>クロスエントロピー損失による真のラベルロス</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">98</span> <span class="n">loss_func</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>温度、<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqk" style=""><span class="mord mathnormal" style="margin-right:0.13889em">T</span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">100</span> <span class="n">temperature</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">5.</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>ソフトターゲットの重量が減ります。</p>
<p>ソフトターゲットによって生成されるグラデーションは、次の方法でスケーリングされます。<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.190108em;vertical-align:-0.345em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.845108em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqg" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqk" style="margin-right:0.13889em">T</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7463142857142857em;"><span style="top:-2.786em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span>これを補うために、この論文ではソフトターゲットの損失を次の倍にスケーリングすることを提案しています</p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8141079999999999em;vertical-align:0em;"></span><span class="mord coloredeq eqg" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqk" style="margin-right:0.13889em">T</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8141079999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span></span></span></span></span></span></span></span></span></span>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">106</span> <span class="n">soft_targets_weight</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">100.</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>真のラベルクロスエントロピー損失を実現する重量</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">108</span> <span class="n">label_loss_weight</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<h3>トレーニング/検証ステップ</h3>
<p>蒸留を含むカスタムトレーニング/検証ステップを定義します</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="nb">any</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="n">BatchIndex</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>小型モデルのトレーニング/評価モード</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">118</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>評価モードの大型モデル</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">120</span> <span class="bp">self</span><span class="o">.</span><span class="n">large</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>データをデバイスに移動</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">123</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>トレーニングモード時にグローバルステップ (処理されたサンプル数) を更新</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">126</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
<span class="lineno">127</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>ラージモデルから出力ロジットを取得 <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.03588em">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eql" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">130</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="lineno">131</span> <span class="n">large_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">large</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>小さいモデルから出力ロジットを取得 <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.04398em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eql" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">134</span> <span class="n">output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>ソフトターゲット <span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eql" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:2.61953em;vertical-align:-1.13453em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.4849999999999999em;"><span style="top:-2.301288em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mop"><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em;"></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16195399999999993em;"><span style="top:-2.40029em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.43581800000000004em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">exp</span><span class="mopen">(</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.808712em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="margin-right:0.13889em">T</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.50732em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.03588em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose">)</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.7350000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mop">exp</span><span class="mopen">(</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7114919999999999em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="margin-right:0.13889em">T</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.4101em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqi" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03588em">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.03588em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eql" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose">)</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.13453em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">138</span> <span class="n">soft_targets</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">log_softmax</span><span class="p">(</span><span class="n">large_logits</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">temperature</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>小型モデルの温度調整済み確率 <span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:2.61953em;vertical-align:-1.13453em;"></span><span class="mord coloredeq eqb" style=""><span class="mord" style=""><span class="mord coloredeq eqh" style=""><span class="mord mathnormal" style="margin-right:0.03588em">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eql" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.4849999999999999em;"><span style="top:-2.301288em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mop" style=""><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em"></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16195399999999993em;"><span style="top:-2.40029em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.43581800000000004em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em"></span><span class="mop" style=""><span style="">e</span><span style="">x</span><span style="">p</span></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.808712em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqk" style="margin-right:0.13889em">T</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.50732em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.04398em">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.04398em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose" style="">)</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.7350000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mop" style=""><span style="">e</span><span style="">x</span><span style="">p</span></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7114919999999999em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqk" style="margin-right:0.13889em">T</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.4101em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqj" style=""><span class="mord mathnormal mtight" style="margin-right:0.04398em">z</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3280857142857143em;"><span style="top:-2.357em;margin-left:-0.04398em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eql" style="">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose" style="">)</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.13453em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">141</span> <span class="n">soft_prob</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">log_softmax</span><span class="p">(</span><span class="n">output</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">temperature</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>ソフトターゲットの損失の計算</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">144</span> <span class="n">soft_targets_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_div_loss</span><span class="p">(</span><span class="n">soft_prob</span><span class="p">,</span> <span class="n">soft_targets</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>実際のラベルロスを計算</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">146</span> <span class="n">label_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>2 つの損失の加重和</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">148</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">soft_targets_weight</span> <span class="o">*</span> <span class="n">soft_targets_loss</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">label_loss_weight</span> <span class="o">*</span> <span class="n">label_loss</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>損失を記録する</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">150</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">({</span><span class="s2">&quot;loss.kl_div.&quot;</span><span class="p">:</span> <span class="n">soft_targets_loss</span><span class="p">,</span>
<span class="lineno">151</span> <span class="s2">&quot;loss.nll&quot;</span><span class="p">:</span> <span class="n">label_loss</span><span class="p">,</span>
<span class="lineno">152</span> <span class="s2">&quot;loss.&quot;</span><span class="p">:</span> <span class="n">loss</span><span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>精度の計算と記録</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">155</span> <span class="bp">self</span><span class="o">.</span><span class="n">accuracy</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
<span class="lineno">156</span> <span class="bp">self</span><span class="o">.</span><span class="n">accuracy</span><span class="o">.</span><span class="n">track</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>モデルのトレーニング</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">159</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>勾配の計算</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">161</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>最適化の一歩を踏み出す</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">163</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>各エポックの最後のバッチでモデルパラメータと勾配を記録します</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">165</span> <span class="k">if</span> <span class="n">batch_idx</span><span class="o">.</span><span class="n">is_last</span><span class="p">:</span>
<span class="lineno">166</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;model&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>グラデーションをクリア</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>追跡したメトリクスを保存する</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">171</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<h3>大きいモデルを作成</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">174</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">large</span><span class="p">)</span>
<span class="lineno">175</span><span class="k">def</span> <span class="nf">_large_model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">179</span> <span class="k">return</span> <span class="n">LargeModel</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<h3>小型モデルを作成</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">182</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
<span class="lineno">183</span><span class="k">def</span> <span class="nf">_small_student_model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">187</span> <span class="k">return</span> <span class="n">SmallModel</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<h3><a href="large.html">負荷トレーニング済みの大型モデル</a></h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">190</span><span class="k">def</span> <span class="nf">get_saved_model</span><span class="p">(</span><span class="n">run_uuid</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">195</span> <span class="kn">from</span> <span class="nn">labml_nn.distillation.large</span> <span class="kn">import</span> <span class="n">Configs</span> <span class="k">as</span> <span class="n">LargeConfigs</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>評価モード (記録なし)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">198</span> <span class="n">experiment</span><span class="o">.</span><span class="n">evaluate</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>ラージモデルトレーニング実験のコンフィグを初期化</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">200</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">LargeConfigs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>保存した設定を読み込む</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">202</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="n">experiment</span><span class="o">.</span><span class="n">load_configs</span><span class="p">(</span><span class="n">run_uuid</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>保存/読み込み用のモデルを設定</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">204</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">&#39;model&#39;</span><span class="p">:</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>どのランとチェックポイントをロードするかを設定</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">206</span> <span class="n">experiment</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">run_uuid</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>実験を開始します。これでモデルが読み込まれ、すべての準備が整います</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">208</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>モデルを返却する</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">211</span> <span class="k">return</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>蒸留による小型モデルのトレーニング</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">214</span><span class="k">def</span> <span class="nf">main</span><span class="p">(</span><span class="n">run_uuid</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>保存したモデルをロード</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">219</span> <span class="n">large_model</span> <span class="o">=</span> <span class="n">get_saved_model</span><span class="p">(</span><span class="n">run_uuid</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>実験を作成</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">221</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;distillation&#39;</span><span class="p">,</span> <span class="n">comment</span><span class="o">=</span><span class="s1">&#39;cifar10&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>構成の作成</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">223</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p>読み込んだラージモデルを設定する</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">225</span> <span class="n">conf</span><span class="o">.</span><span class="n">large</span> <span class="o">=</span> <span class="n">large_model</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<p>構成をロード</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">227</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span>
<span class="lineno">228</span> <span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Adam&#39;</span><span class="p">,</span>
<span class="lineno">229</span> <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">2.5e-4</span><span class="p">,</span>
<span class="lineno">230</span> <span class="s1">&#39;model&#39;</span><span class="p">:</span> <span class="s1">&#39;_small_student_model&#39;</span><span class="p">,</span>
<span class="lineno">231</span> <span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<p>保存/読み込み用のモデルを設定</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">233</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">&#39;model&#39;</span><span class="p">:</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>実験をゼロから始める</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">235</span> <span class="n">experiment</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>実験を開始し、トレーニングループを実行します</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">237</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
<span class="lineno">238</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">242</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">243</span> <span class="n">main</span><span class="p">(</span><span class="s1">&#39;d46cd53edaec11eb93c38d6538aee7d6&#39;</span><span class="p">,</span> <span class="mi">1_000_000</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src=../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>