Varuna Jayasiri 2038b11d29 ja translation
2023-05-10 17:00:29 -04:00

<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This is a reusable trainer for the CIFAR10 dataset"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="CIFAR10 Experiment"/>
<meta name="twitter:description" content="This is a reusable trainer for the CIFAR10 dataset"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/experiments/cifar10.html"/>
<meta property="og:title" content="CIFAR10 Experiment"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="CIFAR10 Experiment"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="This is a reusable trainer for the CIFAR10 dataset"/>
<title>CIFAR10 Experiment</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/experiments/cifar10.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">experiments</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/experiments/cifar10.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>CIFAR10 Experiment</h1>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">10</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span>
<span class="lineno">11</span>
<span class="lineno">12</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="lineno">13</span>
<span class="lineno">14</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">lab</span>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml_helpers.datasets.cifar10</span> <span class="kn">import</span> <span class="n">CIFAR10Configs</span> <span class="k">as</span> <span class="n">CIFAR10DatasetConfigs</span>
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.mnist</span> <span class="kn">import</span> <span class="n">MNISTConfigs</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Configurations</h2>
<p>This extends the CIFAR 10 dataset configurations from <a href="https://github.com/labmlai/labml/tree/master/helpers"><code class="highlight"><span></span><span class="n">labml_helpers</span></code></a> and <a href="mnist.html"><code class="highlight"><span></span><span class="n">MNISTConfigs</span></code></a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">21</span><span class="k">class</span> <span class="nc">CIFAR10Configs</span><span class="p">(</span><span class="n">CIFAR10DatasetConfigs</span><span class="p">,</span> <span class="n">MNISTConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<p>Use the CIFAR10 dataset by default</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">30</span>    <span class="n">dataset_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;CIFAR10&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<h3>Augmented CIFAR 10 train dataset</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">33</span><span class="nd">@option</span><span class="p">(</span><span class="n">CIFAR10Configs</span><span class="o">.</span><span class="n">train_dataset</span><span class="p">)</span>
<span class="lineno">34</span><span class="k">def</span> <span class="nf">cifar10_train_augmented</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">38</span>    <span class="kn">from</span> <span class="nn">torchvision.datasets</span> <span class="kn">import</span> <span class="n">CIFAR10</span>
<span class="lineno">39</span>    <span class="kn">from</span> <span class="nn">torchvision.transforms</span> <span class="kn">import</span> <span class="n">transforms</span>
<span class="lineno">40</span>    <span class="k">return</span> <span class="n">CIFAR10</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()),</span>
<span class="lineno">41</span>                   <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="lineno">42</span>                   <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="lineno">43</span>                   <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Pad and crop</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">45</span>                       <span class="n">transforms</span><span class="o">.</span><span class="n">RandomCrop</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Random horizontal flip</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span>                       <span class="n">transforms</span><span class="o">.</span><span class="n">RandomHorizontalFlip</span><span class="p">(),</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span>                       <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
<span class="lineno">50</span>                       <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">),</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">))</span>
<span class="lineno">51</span>                   <span class="p">]))</span></pre></div>
</div>
</div>
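<div class='section'>
<div class='docs'>
<p>To make the two augmentations above concrete, here is a minimal pure-Python sketch of what padding-then-cropping and random flipping do to an image. It is an illustration, not the <code>torchvision</code> implementation: the helper names <code>pad_and_crop</code> and <code>random_horizontal_flip</code> are hypothetical, the image is a single-channel list of rows, and zero padding is assumed. With <code>padding=4</code> a 32 by 32 image becomes 40 by 40, so the crop window can start at any offset from 0 to 8 in each direction, which shifts the content by up to four pixels.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import random


def pad_and_crop(image, size=32, padding=4, rng=random):
    """Zero-pad a 2D image (a list of rows) and cut a random size x size window."""
    padded_width = len(image[0]) + 2 * padding
    # Rows of zeros above the image
    padded = [[0] * padded_width for _ in range(padding)]
    # Original rows with zeros on the left and right
    padded += [[0] * padding + list(row) + [0] * padding for row in image]
    # Rows of zeros below the image
    padded += [[0] * padded_width for _ in range(padding)]
    # Top-left corner of the crop window, chosen uniformly at random
    top = rng.randint(0, len(padded) - size)
    left = rng.randint(0, padded_width - size)
    return [row[left:left + size] for row in padded[top:top + size]]


def random_horizontal_flip(image, p=0.5, rng=random):
    """Mirror every row with probability p, otherwise return the image unchanged."""
    if rng.random() >= p:
        return image
    return [row[::-1] for row in image]


# A toy 32 x 32 "image" whose pixel values encode their own position
image = [[r * 32 + c for c in range(32)] for r in range(32)]
augmented = random_horizontal_flip(pad_and_crop(image))
print(len(augmented), len(augmented[0]))
```
</pre></div>
</div>
</div>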
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<h3>Non-augmented CIFAR 10 validation dataset</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">54</span><span class="nd">@option</span><span class="p">(</span><span class="n">CIFAR10Configs</span><span class="o">.</span><span class="n">valid_dataset</span><span class="p">)</span>
<span class="lineno">55</span><span class="k">def</span> <span class="nf">cifar10_valid_no_augment</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">59</span>    <span class="kn">from</span> <span class="nn">torchvision.datasets</span> <span class="kn">import</span> <span class="n">CIFAR10</span>
<span class="lineno">60</span>    <span class="kn">from</span> <span class="nn">torchvision.transforms</span> <span class="kn">import</span> <span class="n">transforms</span>
<span class="lineno">61</span>    <span class="k">return</span> <span class="n">CIFAR10</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()),</span>
<span class="lineno">62</span>                   <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="lineno">63</span>                   <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="lineno">64</span>                   <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
<span class="lineno">65</span>                       <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
<span class="lineno">66</span>                       <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">),</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">))</span>
<span class="lineno">67</span>                   <span class="p">]))</span></pre></div>
</div>
</div>
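<div class='section'>
<div class='docs'>
<p>Both datasets end with the same <code>ToTensor</code> and <code>Normalize</code> steps, so validation images are scaled exactly like training images, only without the random augmentations. The arithmetic can be sketched in plain Python as below. The helper names <code>to_unit_range</code> and <code>normalize</code> are hypothetical stand-ins for what the two transforms do to a single pixel value: divide by 255, then subtract the mean 0.5 and divide by the standard deviation 0.5, which maps the range 0 to 255 onto -1 to 1.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def to_unit_range(pixel):
    """Scale an 8-bit pixel value (0..255) to the range 0..1."""
    return pixel / 255


def normalize(value, mean=0.5, std=0.5):
    """Shift and scale a value: (value - mean) / std."""
    return (value - mean) / std


# Darkest and brightest possible pixel values
darkest = normalize(to_unit_range(0))
brightest = normalize(to_unit_range(255))
print(darkest, brightest)
```
</pre></div>
</div>
</div>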
<div class='section' id='section-10'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<h3>VGG model for CIFAR-10 classification</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span><span class="k">class</span> <span class="nc">CIFAR10VGGModel</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Convolution and activation combined</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">75</span>    <span class="k">def</span> <span class="nf">conv_block</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span>        <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="lineno">80</span>            <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span>
<span class="lineno">81</span>            <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span>
<span class="lineno">82</span>        <span class="p">)</span></pre></div>
</div>
</div>
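<div class='section'>
<div class='docs'>
<p>A 3 by 3 convolution with <code>padding=1</code> and the default stride of 1 keeps the spatial size unchanged, so only the pooling layers shrink the feature map. The sketch below checks this with the standard convolution output-size formula, assuming a dilation of 1. The function name <code>conv2d_output_size</code> is a hypothetical helper introduced for illustration.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def conv2d_output_size(size, kernel_size=3, padding=1, stride=1):
    """Spatial output size of a convolution along one dimension (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


# Every feature-map size that can occur between the pooling layers
input_sizes = [32, 16, 8, 4, 2, 1]
output_sizes = [conv2d_output_size(size) for size in input_sizes]
print(output_sizes)
```
</pre></div>
</div>
</div>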
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">84</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">blocks</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]):</span>
<span class="lineno">85</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Five <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">2</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">2</span></span></span></span></span> pooling layers will produce an output of size <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">1</span></span></span></span></span>. The CIFAR 10 image size is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">32</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">32</span></span></span></span></span>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">89</span>        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">blocks</span><span class="p">)</span> <span class="o">==</span> <span class="mi">5</span>
<span class="lineno">90</span>        <span class="n">layers</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>RGB channels</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">92</span>        <span class="n">in_channels</span> <span class="o">=</span> <span class="mi">3</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Number of channels in each layer in each block</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">94</span>        <span class="k">for</span> <span class="n">block</span> <span class="ow">in</span> <span class="n">blocks</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Convolution and activation layers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">96</span>            <span class="k">for</span> <span class="n">channels</span> <span class="ow">in</span> <span class="n">block</span><span class="p">:</span>
<span class="lineno">97</span>                <span class="n">layers</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_block</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">channels</span><span class="p">)</span>
<span class="lineno">98</span>                <span class="n">in_channels</span> <span class="o">=</span> <span class="n">channels</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Max pooling at the end of each block</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">100</span>            <span class="n">layers</span> <span class="o">+=</span> <span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="n">kernel_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Create a sequential model with the layers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">103</span>        <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Final logits layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">105</span>        <span class="bp">self</span><span class="o">.</span><span class="n">fc</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span></pre></div>
</div>
</div>
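<div class='section'>
<div class='docs'>
<p>The constructor requires exactly five blocks because each block ends with a 2 by 2 max pooling of stride 2, which halves the spatial size. Starting from the 32 by 32 CIFAR 10 images, five halvings leave a 1 by 1 feature map, so the flattened feature vector has as many entries as the last convolution has channels. That is why the final linear layer can take <code>in_channels</code> as its input size. The sketch below works through the arithmetic; <code>max_pool_output_size</code> is a hypothetical helper that applies the standard pooling output-size formula without padding.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def max_pool_output_size(size, kernel_size=2, stride=2):
    """Spatial output size of a max pooling layer along one dimension (no padding)."""
    return (size - kernel_size) // stride + 1


# Track the feature-map size after each of the five pooling layers
sizes = [32]
for _ in range(5):
    sizes.append(max_pool_output_size(sizes[-1]))

print(sizes)
```
</pre></div>
</div>
</div>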
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">107</span>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>VGG layers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">109</span>        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Reshape for the classification layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">111</span>        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Final linear layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">113</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
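<div class='section'>
<div class='docs'>
<p>To see how a <code>blocks</code> argument translates into tensor shapes, the following sketch traces the forward pass without building any layers. It is only an illustration: <code>trace_vgg</code> is a hypothetical helper, and the block configuration used here is an assumed VGG-style example, not one taken from this page. For each block it records the channel count and spatial size after pooling, then reports the flattened feature count that reaches the final linear layer and the total number of weights and biases.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def trace_vgg(blocks, image_size=32, in_channels=3, n_classes=10):
    """Trace shapes and count parameters for a VGG-style stack of blocks."""
    shapes = []
    n_params = 0
    size = image_size
    for block in blocks:
        for channels in block:
            # 3 x 3 convolution weights plus one bias per output channel
            n_params += in_channels * channels * 3 * 3 + channels
            in_channels = channels
        # 2 x 2 max pooling with stride 2 halves the spatial size
        size //= 2
        shapes.append((in_channels, size, size))
    # Flattening (channels, size, size) gives this many features
    flat_features = in_channels * size * size
    # Final linear layer: weights plus one bias per class
    n_params += flat_features * n_classes + n_classes
    return shapes, flat_features, n_params


# Assumed example configuration: five blocks of two convolutions each
example_blocks = [[64, 64], [128, 128], [256, 256], [512, 512], [512, 512]]
shapes, flat_features, n_params = trace_vgg(example_blocks)
for shape in shapes:
    print(shape)
print(flat_features, n_params)
```
</pre></div>
</div>
</div>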
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>