mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-06 15:22:21 +08:00
unescape *
This commit is contained in:
@ -24,6 +24,8 @@
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/experiments/cifar10.html"/>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
||||
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
@ -67,6 +69,7 @@
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>CIFAR10 Experiment</h1>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">10</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span>
|
||||
@ -86,9 +89,10 @@
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<h2>Configurations</h2>
|
||||
<p>This extends from CIFAR 10 dataset configurations from
|
||||
<a href="https://github.com/labmlai/labml/tree/master/helpers"><code>labml_helpers</code></a>
|
||||
and <a href="mnist.html"><code>MNISTConfigs</code></a>.</p>
|
||||
<p>This extends from CIFAR 10 dataset configurations from <a href="https://github.com/labmlai/labml/tree/master/helpers"><code>labml_helpers</code>
|
||||
</a> and <a href="mnist.html"><code>MNISTConfigs</code>
|
||||
</a>.</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">21</span><span class="k">class</span> <span class="nc">CIFAR10Configs</span><span class="p">(</span><span class="n">CIFAR10DatasetConfigs</span><span class="p">,</span> <span class="n">MNISTConfigs</span><span class="p">):</span></pre></div>
|
||||
@ -99,7 +103,8 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
<p>Use CIFAR10 dataset by default</p>
|
||||
<p>Use CIFAR10 dataset by default </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">30</span> <span class="n">dataset_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">'CIFAR10'</span></pre></div>
|
||||
@ -111,6 +116,7 @@
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
<h3>Augmented CIFAR 10 train dataset</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">33</span><span class="nd">@option</span><span class="p">(</span><span class="n">CIFAR10Configs</span><span class="o">.</span><span class="n">train_dataset</span><span class="p">)</span>
|
||||
@ -138,7 +144,8 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
<p>Pad and crop</p>
|
||||
<p>Pad and crop </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">45</span> <span class="n">transforms</span><span class="o">.</span><span class="n">RandomCrop</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span></pre></div>
|
||||
@ -149,7 +156,8 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
<p>Random horizontal flip</p>
|
||||
<p>Random horizontal flip </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">47</span> <span class="n">transforms</span><span class="o">.</span><span class="n">RandomHorizontalFlip</span><span class="p">(),</span></pre></div>
|
||||
@ -160,7 +168,8 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
|
||||
<p> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">49</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
|
||||
@ -174,6 +183,7 @@
|
||||
<a href='#section-8'>#</a>
|
||||
</div>
|
||||
<h3>Non-augmented CIFAR 10 validation dataset</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">54</span><span class="nd">@option</span><span class="p">(</span><span class="n">CIFAR10Configs</span><span class="o">.</span><span class="n">valid_dataset</span><span class="p">)</span>
|
||||
@ -205,6 +215,7 @@
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
<h3>VGG model for CIFAR-10 classification</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">70</span><span class="k">class</span> <span class="nc">CIFAR10VGGModel</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
@ -215,7 +226,8 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
<p>Convolution and activation combined</p>
|
||||
<p> Convolution and activation combined</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">75</span> <span class="k">def</span> <span class="nf">conv_block</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">)</span> <span class="o">-></span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">:</span></pre></div>
|
||||
@ -252,8 +264,8 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-14'>#</a>
|
||||
</div>
|
||||
<p>5 $2 \times 2$ pooling layers will produce a output of size $1 \ times 1$.
|
||||
CIFAR 10 image size is $32 \times 32$</p>
|
||||
<p>5 <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">2</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">2</span></span></span></span> pooling layers will produce a output of size <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.65952em;vertical-align:0em;"></span><span class="mord">1</span><span class="mspace"> </span><span class="mord mathnormal">t</span><span class="mord mathnormal">im</span><span class="mord mathnormal">es</span><span class="mord">1</span></span></span></span>. CIFAR 10 image size is <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">32</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">32</span></span></span></span> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">89</span> <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">blocks</span><span class="p">)</span> <span class="o">==</span> <span class="mi">5</span>
|
||||
@ -265,7 +277,8 @@ CIFAR 10 image size is $32 \times 32$</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-15'>#</a>
|
||||
</div>
|
||||
<p>RGB channels</p>
|
||||
<p>RGB channels </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">92</span> <span class="n">in_channels</span> <span class="o">=</span> <span class="mi">3</span></pre></div>
|
||||
@ -276,7 +289,8 @@ CIFAR 10 image size is $32 \times 32$</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-16'>#</a>
|
||||
</div>
|
||||
<p>Number of channels in each layer in each block</p>
|
||||
<p>Number of channels in each layer in each block </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">94</span> <span class="k">for</span> <span class="n">block</span> <span class="ow">in</span> <span class="n">blocks</span><span class="p">:</span></pre></div>
|
||||
@ -287,7 +301,8 @@ CIFAR 10 image size is $32 \times 32$</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-17'>#</a>
|
||||
</div>
|
||||
<p>Convolution, Normalization and Activation layers</p>
|
||||
<p>Convolution, Normalization and Activation layers </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">96</span> <span class="k">for</span> <span class="n">channels</span> <span class="ow">in</span> <span class="n">block</span><span class="p">:</span>
|
||||
@ -300,7 +315,8 @@ CIFAR 10 image size is $32 \times 32$</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-18'>#</a>
|
||||
</div>
|
||||
<p>Max pooling at end of each block</p>
|
||||
<p>Max pooling at end of each block </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">100</span> <span class="n">layers</span> <span class="o">+=</span> <span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="n">kernel_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">)]</span></pre></div>
|
||||
@ -311,7 +327,8 @@ CIFAR 10 image size is $32 \times 32$</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-19'>#</a>
|
||||
</div>
|
||||
<p>Create a sequential model with the layers</p>
|
||||
<p>Create a sequential model with the layers </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">103</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">)</span></pre></div>
|
||||
@ -322,7 +339,8 @@ CIFAR 10 image size is $32 \times 32$</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-20'>#</a>
|
||||
</div>
|
||||
<p>Final logits layer</p>
|
||||
<p>Final logits layer </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">105</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span></pre></div>
|
||||
@ -344,7 +362,8 @@ CIFAR 10 image size is $32 \times 32$</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-22'>#</a>
|
||||
</div>
|
||||
<p>The VGG layers</p>
|
||||
<p>The VGG layers </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">109</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||
@ -355,7 +374,8 @@ CIFAR 10 image size is $32 \times 32$</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-23'>#</a>
|
||||
</div>
|
||||
<p>Reshape for classification layer</p>
|
||||
<p>Reshape for classification layer </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">111</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
@ -366,7 +386,8 @@ CIFAR 10 image size is $32 \times 32$</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-24'>#</a>
|
||||
</div>
|
||||
<p>Final linear layer</p>
|
||||
<p>Final linear layer </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">113</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||
@ -377,24 +398,6 @@ CIFAR 10 image size is $32 \times 32$</p>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
</script>
|
||||
<script>
|
||||
function handleImages() {
|
||||
var images = document.querySelectorAll('p>img')
|
||||
|
@ -24,6 +24,8 @@
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/experiments/index.html"/>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
||||
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
@ -66,24 +68,6 @@
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
</script>
|
||||
<script>
|
||||
function handleImages() {
|
||||
var images = document.querySelectorAll('p>img')
|
||||
|
@ -24,6 +24,8 @@
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/experiments/mnist.html"/>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
||||
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
@ -67,6 +69,7 @@
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>MNIST Experiment</h1>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">11</span><span></span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
|
||||
@ -87,9 +90,8 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<p><a id="MNISTConfigs"></p>
|
||||
<h2>Trainer configurations</h2>
|
||||
<p></a></p>
|
||||
<p> <a id="MNISTConfigs"> ## Trainer configurations </a></p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">24</span><span class="k">class</span> <span class="nc">MNISTConfigs</span><span class="p">(</span><span class="n">MNISTDatasetConfigs</span><span class="p">,</span> <span class="n">TrainValidConfigs</span><span class="p">):</span></pre></div>
|
||||
@ -100,7 +102,8 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
<p>Optimizer</p>
|
||||
<p>Optimizer </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">32</span> <span class="n">optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span></pre></div>
|
||||
@ -111,7 +114,8 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
<p>Training device</p>
|
||||
<p>Training device </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">34</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">DeviceConfigs</span><span class="p">()</span></pre></div>
|
||||
@ -122,7 +126,8 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
<p>Classification model</p>
|
||||
<p>Classification model </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">37</span> <span class="n">model</span><span class="p">:</span> <span class="n">Module</span></pre></div>
|
||||
@ -133,7 +138,8 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
<p>Number of epochs to train for</p>
|
||||
<p>Number of epochs to train for </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">39</span> <span class="n">epochs</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span></pre></div>
|
||||
@ -144,7 +150,8 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
<p>Number of times to switch between training and validation within an epoch</p>
|
||||
<p>Number of times to switch between training and validation within an epoch </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">42</span> <span class="n">inner_iterations</span> <span class="o">=</span> <span class="mi">10</span></pre></div>
|
||||
@ -155,7 +162,8 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
<p>Accuracy function</p>
|
||||
<p>Accuracy function </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">45</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">Accuracy</span><span class="p">()</span></pre></div>
|
||||
@ -166,7 +174,8 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-8'>#</a>
|
||||
</div>
|
||||
<p>Loss function</p>
|
||||
<p>Loss function </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">47</span> <span class="n">loss_func</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">()</span></pre></div>
|
||||
@ -178,6 +187,7 @@
|
||||
<a href='#section-9'>#</a>
|
||||
</div>
|
||||
<h3>Initialization</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">49</span> <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||||
@ -188,7 +198,8 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
<p>Set tracker configurations</p>
|
||||
<p>Set tracker configurations </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">54</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s2">"loss.*"</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
|
||||
@ -200,7 +211,8 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
<p>Add a hook to log module outputs</p>
|
||||
<p>Add a hook to log module outputs </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">57</span> <span class="n">hook_model_outputs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">,</span> <span class="s1">'model'</span><span class="p">)</span></pre></div>
|
||||
@ -211,10 +223,8 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
<p>Add accuracy as a state module.
|
||||
The name is probably confusing, since it’s meant to store
|
||||
states between training and validation for RNNs.
|
||||
This will keep the accuracy metric stats separate for training and validation.</p>
|
||||
<p>Add accuracy as a state module. The name is probably confusing, since it's meant to store states between training and validation for RNNs. This will keep the accuracy metric stats separate for training and validation. </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">62</span> <span class="bp">self</span><span class="o">.</span><span class="n">state_modules</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">accuracy</span><span class="p">]</span></pre></div>
|
||||
@ -226,6 +236,7 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<a href='#section-13'>#</a>
|
||||
</div>
|
||||
<h3>Training or validation step</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">64</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="nb">any</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="n">BatchIndex</span><span class="p">):</span></pre></div>
|
||||
@ -236,7 +247,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-14'>#</a>
|
||||
</div>
|
||||
<p>Training/Evaluation mode</p>
|
||||
<p>Training/Evaluation mode </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">70</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">)</span></pre></div>
|
||||
@ -247,7 +259,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-15'>#</a>
|
||||
</div>
|
||||
<p>Move data to the device</p>
|
||||
<p>Move data to the device </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">73</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||
@ -258,7 +271,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-16'>#</a>
|
||||
</div>
|
||||
<p>Update global step (number of samples processed) when in training mode</p>
|
||||
<p>Update global step (number of samples processed) when in training mode </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">76</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
|
||||
@ -270,7 +284,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-17'>#</a>
|
||||
</div>
|
||||
<p>Whether to capture model outputs</p>
|
||||
<p>Whether to capture model outputs </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">80</span> <span class="k">with</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">is_log_activations</span><span class="o">=</span><span class="n">batch_idx</span><span class="o">.</span><span class="n">is_last</span><span class="p">):</span></pre></div>
|
||||
@ -281,7 +296,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-18'>#</a>
|
||||
</div>
|
||||
<p>Get model outputs.</p>
|
||||
<p>Get model outputs. </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">82</span> <span class="n">output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
|
||||
@ -292,7 +308,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-19'>#</a>
|
||||
</div>
|
||||
<p>Calculate and log loss</p>
|
||||
<p>Calculate and log loss </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">85</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
|
||||
@ -304,7 +321,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-20'>#</a>
|
||||
</div>
|
||||
<p>Calculate and log accuracy</p>
|
||||
<p>Calculate and log accuracy </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">89</span> <span class="bp">self</span><span class="o">.</span><span class="n">accuracy</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
|
||||
@ -316,7 +334,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-21'>#</a>
|
||||
</div>
|
||||
<p>Train the model</p>
|
||||
<p>Train the model </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">93</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span></pre></div>
|
||||
@ -327,7 +346,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-22'>#</a>
|
||||
</div>
|
||||
<p>Calculate gradients</p>
|
||||
<p>Calculate gradients </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">95</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
|
||||
@ -338,7 +358,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-23'>#</a>
|
||||
</div>
|
||||
<p>Take optimizer step</p>
|
||||
<p>Take optimizer step </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">97</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
|
||||
@ -349,7 +370,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-24'>#</a>
|
||||
</div>
|
||||
<p>Log the model parameters and gradients on last batch of every epoch</p>
|
||||
<p>Log the model parameters and gradients on last batch of every epoch </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">99</span> <span class="k">if</span> <span class="n">batch_idx</span><span class="o">.</span><span class="n">is_last</span><span class="p">:</span>
|
||||
@ -361,7 +383,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-25'>#</a>
|
||||
</div>
|
||||
<p>Clear the gradients</p>
|
||||
<p>Clear the gradients </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">102</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
|
||||
@ -372,7 +395,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-26'>#</a>
|
||||
</div>
|
||||
<p>Save the tracked metrics</p>
|
||||
<p>Save the tracked metrics </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">105</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
|
||||
@ -384,6 +408,7 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<a href='#section-27'>#</a>
|
||||
</div>
|
||||
<h3>Default optimizer configurations</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">108</span><span class="nd">@option</span><span class="p">(</span><span class="n">MNISTConfigs</span><span class="o">.</span><span class="n">optimizer</span><span class="p">)</span>
|
||||
@ -409,24 +434,6 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
</script>
|
||||
<script>
|
||||
function handleImages() {
|
||||
var images = document.querySelectorAll('p>img')
|
||||
|
@ -24,6 +24,8 @@
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/experiments/nlp_autoregression.html"/>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
||||
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
@ -67,6 +69,7 @@
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>Auto-regressive NLP model trainer</h1>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">11</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Callable</span>
|
||||
@ -92,6 +95,7 @@
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<h3>Cross entropy loss</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">28</span><span class="k">class</span> <span class="nc">CrossEntropyLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
@ -127,11 +131,9 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
<p><a id="NLPAutoRegressionConfigs"></p>
|
||||
<h2>Trainer configurations</h2>
|
||||
<p></a></p>
|
||||
<p>This has the basic configurations for NLP auto-regressive task training.
|
||||
All the properties are configurable.</p>
|
||||
<p> <a id="NLPAutoRegressionConfigs"> ## Trainer configurations </a></p>
|
||||
<p>This has the basic configurations for NLP auto-regressive task training. All the properties are configurable.</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">41</span><span class="k">class</span> <span class="nc">NLPAutoRegressionConfigs</span><span class="p">(</span><span class="n">TrainValidConfigs</span><span class="p">):</span></pre></div>
|
||||
@ -142,7 +144,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
<p>Optimizer</p>
|
||||
<p>Optimizer </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">52</span> <span class="n">optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span></pre></div>
|
||||
@ -153,7 +156,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
<p>Training device</p>
|
||||
<p>Training device </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">54</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">DeviceConfigs</span><span class="p">()</span></pre></div>
|
||||
@ -164,7 +168,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
<p>Autoregressive model</p>
|
||||
<p>Autoregressive model </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">57</span> <span class="n">model</span><span class="p">:</span> <span class="n">Module</span></pre></div>
|
||||
@ -175,7 +180,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-8'>#</a>
|
||||
</div>
|
||||
<p>Text dataset</p>
|
||||
<p>Text dataset </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">59</span> <span class="n">text</span><span class="p">:</span> <span class="n">TextDataset</span></pre></div>
|
||||
@ -186,7 +192,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-9'>#</a>
|
||||
</div>
|
||||
<p>Batch size</p>
|
||||
<p>Batch size </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">61</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">16</span></pre></div>
|
||||
@ -197,7 +204,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
<p>Length of the sequence, or context size</p>
|
||||
<p>Length of the sequence, or context size </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">63</span> <span class="n">seq_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">512</span></pre></div>
|
||||
@ -208,7 +216,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
<p>Number of token in vocabulary</p>
|
||||
<p>Number of token in vocabulary </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">65</span> <span class="n">n_tokens</span><span class="p">:</span> <span class="nb">int</span></pre></div>
|
||||
@ -219,7 +228,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
<p>Tokenizer</p>
|
||||
<p>Tokenizer </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">67</span> <span class="n">tokenizer</span><span class="p">:</span> <span class="n">Callable</span> <span class="o">=</span> <span class="s1">'character'</span></pre></div>
|
||||
@ -230,7 +240,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-13'>#</a>
|
||||
</div>
|
||||
<p>Text prompt to start sampling (for illustration)</p>
|
||||
<p>Text prompt to start sampling (for illustration) </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">prompt</span><span class="p">:</span> <span class="nb">str</span></pre></div>
|
||||
@ -241,7 +252,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-14'>#</a>
|
||||
</div>
|
||||
<p>The token separator when sampling (blank for character level tokenization)</p>
|
||||
<p>The token separator when sampling (blank for character level tokenization) </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">72</span> <span class="n">prompt_separator</span><span class="p">:</span> <span class="nb">str</span></pre></div>
|
||||
@ -252,7 +264,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-15'>#</a>
|
||||
</div>
|
||||
<p>Whether to periodically save models</p>
|
||||
<p>Whether to periodically save models </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">75</span> <span class="n">is_save_models</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
|
||||
@ -263,7 +276,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-16'>#</a>
|
||||
</div>
|
||||
<p>Loss function</p>
|
||||
<p>Loss function </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">78</span> <span class="n">loss_func</span> <span class="o">=</span> <span class="n">CrossEntropyLoss</span><span class="p">()</span></pre></div>
|
||||
@ -274,7 +288,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-17'>#</a>
|
||||
</div>
|
||||
<p>Accuracy function</p>
|
||||
<p>Accuracy function </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">80</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">Accuracy</span><span class="p">()</span></pre></div>
|
||||
@ -285,7 +300,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-18'>#</a>
|
||||
</div>
|
||||
<p>Model embedding size</p>
|
||||
<p>Model embedding size </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">82</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">512</span></pre></div>
|
||||
@ -296,7 +312,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-19'>#</a>
|
||||
</div>
|
||||
<p>Gradient clipping</p>
|
||||
<p>Gradient clipping </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">84</span> <span class="n">grad_norm_clip</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span></pre></div>
|
||||
@ -307,7 +324,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-20'>#</a>
|
||||
</div>
|
||||
<p>Training data loader</p>
|
||||
<p>Training data loader </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">87</span> <span class="n">train_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="s1">'shuffled_train_loader'</span></pre></div>
|
||||
@ -318,7 +336,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-21'>#</a>
|
||||
</div>
|
||||
<p>Validation data loader</p>
|
||||
<p>Validation data loader </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">89</span> <span class="n">valid_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="s1">'shuffled_valid_loader'</span></pre></div>
|
||||
@ -330,6 +349,7 @@ All the properties are configurable.</p>
|
||||
<a href='#section-22'>#</a>
|
||||
</div>
|
||||
<h3>Initialization</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">91</span> <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||||
@ -340,7 +360,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-23'>#</a>
|
||||
</div>
|
||||
<p>Set tracker configurations</p>
|
||||
<p>Set tracker configurations </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">96</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s2">"accuracy.*"</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
|
||||
@ -352,7 +373,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-24'>#</a>
|
||||
</div>
|
||||
<p>Add a hook to log module outputs</p>
|
||||
<p>Add a hook to log module outputs </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">99</span> <span class="n">hook_model_outputs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">,</span> <span class="s1">'model'</span><span class="p">)</span></pre></div>
|
||||
@ -363,10 +385,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-25'>#</a>
|
||||
</div>
|
||||
<p>Add accuracy as a state module.
|
||||
The name is probably confusing, since it’s meant to store
|
||||
states between training and validation for RNNs.
|
||||
This will keep the accuracy metric stats separate for training and validation.</p>
|
||||
<p>Add accuracy as a state module. The name is probably confusing, since it's meant to store states between training and validation for RNNs. This will keep the accuracy metric stats separate for training and validation. </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">104</span> <span class="bp">self</span><span class="o">.</span><span class="n">state_modules</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">accuracy</span><span class="p">]</span></pre></div>
|
||||
@ -377,7 +397,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-26'>#</a>
|
||||
</div>
|
||||
<p>Override to calculate and log other metrics</p>
|
||||
<p>Override to calculate and log other metrics </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">106</span> <span class="k">def</span> <span class="nf">other_metrics</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
@ -400,6 +421,7 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<a href='#section-28'>#</a>
|
||||
</div>
|
||||
<h3>Training or validation step</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">110</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="nb">any</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="n">BatchIndex</span><span class="p">):</span></pre></div>
|
||||
@ -410,7 +432,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-29'>#</a>
|
||||
</div>
|
||||
<p>Set training/eval mode</p>
|
||||
<p>Set training/eval mode </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">116</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">)</span></pre></div>
|
||||
@ -421,7 +444,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-30'>#</a>
|
||||
</div>
|
||||
<p>Move data to the device</p>
|
||||
<p>Move data to the device </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">119</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||
@ -432,7 +456,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-31'>#</a>
|
||||
</div>
|
||||
<p>Update global step (number of tokens processed) when in training mode</p>
|
||||
<p>Update global step (number of tokens processed) when in training mode </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">122</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
|
||||
@ -444,7 +469,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-32'>#</a>
|
||||
</div>
|
||||
<p>Whether to capture model outputs</p>
|
||||
<p>Whether to capture model outputs </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">126</span> <span class="k">with</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">is_log_activations</span><span class="o">=</span><span class="n">batch_idx</span><span class="o">.</span><span class="n">is_last</span><span class="p">):</span></pre></div>
|
||||
@ -455,9 +481,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-33'>#</a>
|
||||
</div>
|
||||
<p>Get model outputs.
|
||||
It’s returning a tuple for states when using RNNs.
|
||||
This is not implemented yet. 😜</p>
|
||||
<p>Get model outputs. It's returning a tuple for states when using RNNs. This is not implemented yet. 😜 </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">130</span> <span class="n">output</span><span class="p">,</span> <span class="o">*</span><span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
|
||||
@ -468,7 +493,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-34'>#</a>
|
||||
</div>
|
||||
<p>Calculate and log loss</p>
|
||||
<p>Calculate and log loss </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">133</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
|
||||
@ -480,7 +506,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-35'>#</a>
|
||||
</div>
|
||||
<p>Calculate and log accuracy</p>
|
||||
<p>Calculate and log accuracy </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">137</span> <span class="bp">self</span><span class="o">.</span><span class="n">accuracy</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
|
||||
@ -494,7 +521,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-36'>#</a>
|
||||
</div>
|
||||
<p>Train the model</p>
|
||||
<p>Train the model </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">143</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span></pre></div>
|
||||
@ -505,7 +533,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-37'>#</a>
|
||||
</div>
|
||||
<p>Calculate gradients</p>
|
||||
<p>Calculate gradients </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">145</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
|
||||
@ -516,7 +545,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-38'>#</a>
|
||||
</div>
|
||||
<p>Clip gradients</p>
|
||||
<p>Clip gradients </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">147</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_grad_norm_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">max_norm</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">grad_norm_clip</span><span class="p">)</span></pre></div>
|
||||
@ -527,7 +557,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-39'>#</a>
|
||||
</div>
|
||||
<p>Take optimizer step</p>
|
||||
<p>Take optimizer step </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">149</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
|
||||
@ -538,7 +569,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-40'>#</a>
|
||||
</div>
|
||||
<p>Log the model parameters and gradients on last batch of every epoch</p>
|
||||
<p>Log the model parameters and gradients on last batch of every epoch </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">151</span> <span class="k">if</span> <span class="n">batch_idx</span><span class="o">.</span><span class="n">is_last</span><span class="p">:</span>
|
||||
@ -550,7 +582,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-41'>#</a>
|
||||
</div>
|
||||
<p>Clear the gradients</p>
|
||||
<p>Clear the gradients </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">154</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
|
||||
@ -561,7 +594,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-42'>#</a>
|
||||
</div>
|
||||
<p>Save the tracked metrics</p>
|
||||
<p>Save the tracked metrics </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">157</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
|
||||
@ -573,6 +607,7 @@ This is not implemented yet. 😜</p>
|
||||
<a href='#section-43'>#</a>
|
||||
</div>
|
||||
<h3>Sampling function to generate samples periodically while training</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">159</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||||
@ -583,7 +618,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-44'>#</a>
|
||||
</div>
|
||||
<p>Starting prompt</p>
|
||||
<p>Starting prompt </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">165</span> <span class="n">prompt</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prompt</span></pre></div>
|
||||
@ -594,7 +630,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-45'>#</a>
|
||||
</div>
|
||||
<p>Collect output for printing</p>
|
||||
<p>Collect output for printing </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">167</span> <span class="n">log</span> <span class="o">=</span> <span class="p">[(</span><span class="n">prompt</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">subtle</span><span class="p">)]</span></pre></div>
|
||||
@ -605,7 +642,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-46'>#</a>
|
||||
</div>
|
||||
<p>Sample 25 tokens</p>
|
||||
<p>Sample 25 tokens </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">169</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">iterate</span><span class="p">(</span><span class="s1">'Sample'</span><span class="p">,</span> <span class="mi">25</span><span class="p">):</span></pre></div>
|
||||
@ -616,7 +654,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-47'>#</a>
|
||||
</div>
|
||||
<p>Tokenize the prompt</p>
|
||||
<p>Tokenize the prompt </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">171</span> <span class="n">data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">text_to_i</span><span class="p">(</span><span class="n">prompt</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||||
@ -628,7 +667,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-48'>#</a>
|
||||
</div>
|
||||
<p>Get the model output</p>
|
||||
<p>Get the model output </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">174</span> <span class="n">output</span><span class="p">,</span> <span class="o">*</span><span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
|
||||
@ -639,7 +679,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-49'>#</a>
|
||||
</div>
|
||||
<p>Get the model prediction (greedy)</p>
|
||||
<p>Get the model prediction (greedy) </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">176</span> <span class="n">output</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span></pre></div>
|
||||
@ -650,7 +691,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-50'>#</a>
|
||||
</div>
|
||||
<p>Add the prediction to prompt</p>
|
||||
<p>Add the prediction to prompt </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">178</span> <span class="n">prompt</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prompt_separator</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">output</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]]</span></pre></div>
|
||||
@ -661,7 +703,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-51'>#</a>
|
||||
</div>
|
||||
<p>Add the prediction for logging</p>
|
||||
<p>Add the prediction for logging </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">180</span> <span class="n">log</span> <span class="o">+=</span> <span class="p">[(</span><span class="bp">self</span><span class="o">.</span><span class="n">prompt_separator</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">output</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span> <span class="n">Text</span><span class="o">.</span><span class="n">value</span><span class="p">)]</span></pre></div>
|
||||
@ -672,7 +715,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-52'>#</a>
|
||||
</div>
|
||||
<p>Print the sampled output</p>
|
||||
<p>Print the sampled output </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">183</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">log</span><span class="p">)</span></pre></div>
|
||||
@ -684,6 +728,7 @@ This is not implemented yet. 😜</p>
|
||||
<a href='#section-53'>#</a>
|
||||
</div>
|
||||
<h3>Default <a href="../optimizers/configs.html">optimizer configurations</a></h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">186</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="o">.</span><span class="n">optimizer</span><span class="p">)</span>
|
||||
@ -711,7 +756,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-55'>#</a>
|
||||
</div>
|
||||
<p>Get number of tokens</p>
|
||||
<p> Get number of tokens</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">200</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="o">.</span><span class="n">n_tokens</span><span class="p">)</span>
|
||||
@ -734,13 +780,12 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-57'>#</a>
|
||||
</div>
|
||||
<h3>Basic english tokenizer</h3>
|
||||
<p>We use character level tokenizer in this experiment.
|
||||
You can switch by setting,</p>
|
||||
<pre><code> 'tokenizer': 'basic_english',
|
||||
</code></pre>
|
||||
|
||||
<h3>Basic english tokenizer</h3>
|
||||
<p>We use character level tokenizer in this experiment. You can switch by setting,</p>
|
||||
<pre class="lang-"> 'tokenizer': 'basic_english',
|
||||
</pre>
|
||||
<p>as the configurations dictionary when starting the experiment.</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">208</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="o">.</span><span class="n">tokenizer</span><span class="p">)</span>
|
||||
@ -765,6 +810,7 @@ You can switch by setting,</p>
|
||||
<a href='#section-59'>#</a>
|
||||
</div>
|
||||
<h3>Character level tokenizer</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">227</span><span class="k">def</span> <span class="nf">character_tokenizer</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
|
||||
@ -787,6 +833,7 @@ You can switch by setting,</p>
|
||||
<a href='#section-61'>#</a>
|
||||
</div>
|
||||
<h3>Character level tokenizer configuration</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">234</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="o">.</span><span class="n">tokenizer</span><span class="p">)</span>
|
||||
@ -811,6 +858,7 @@ You can switch by setting,</p>
|
||||
</div>
|
||||
<h3>Tiny Shakespeare dataset</h3>
|
||||
<p>It will download from the url if not present</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">242</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="o">.</span><span class="n">text</span><span class="p">)</span>
|
||||
@ -837,6 +885,7 @@ You can switch by setting,</p>
|
||||
<a href='#section-65'>#</a>
|
||||
</div>
|
||||
<h3>Sequential training data loader</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">255</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="o">.</span><span class="n">train_loader</span><span class="p">)</span>
|
||||
@ -863,6 +912,7 @@ You can switch by setting,</p>
|
||||
<a href='#section-67'>#</a>
|
||||
</div>
|
||||
<h3>Sequential validation data loader</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">266</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="o">.</span><span class="n">valid_loader</span><span class="p">)</span>
|
||||
@ -889,8 +939,9 @@ You can switch by setting,</p>
|
||||
<a href='#section-69'>#</a>
|
||||
</div>
|
||||
<h3>Transpose batch</h3>
|
||||
<p><code>DataLoader</code> collects the batches on the first dimension.
|
||||
We need to transpose it to be sequence first.</p>
|
||||
<p><code>DataLoader</code>
|
||||
collects the batches on the first dimension. We need to transpose it to be sequence first.</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">277</span><span class="k">def</span> <span class="nf">transpose_batch</span><span class="p">(</span><span class="n">batch</span><span class="p">):</span></pre></div>
|
||||
@ -912,7 +963,9 @@ We need to transpose it to be sequence first.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-71'>#</a>
|
||||
</div>
|
||||
<p>Stack the batch along the second dimension <code>dim=1</code></p>
|
||||
<p>Stack the batch along the second dimension <code>dim=1</code>
|
||||
</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">287</span> <span class="n">src</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">transposed_data</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||||
@ -927,6 +980,7 @@ We need to transpose it to be sequence first.</p>
|
||||
<a href='#section-72'>#</a>
|
||||
</div>
|
||||
<h3>Shuffled training data loader</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">293</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="o">.</span><span class="n">train_loader</span><span class="p">)</span>
|
||||
@ -955,6 +1009,7 @@ We need to transpose it to be sequence first.</p>
|
||||
<a href='#section-74'>#</a>
|
||||
</div>
|
||||
<h3>Shuffled validation data loader</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">306</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="o">.</span><span class="n">valid_loader</span><span class="p">)</span>
|
||||
@ -982,24 +1037,6 @@ We need to transpose it to be sequence first.</p>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
</script>
|
||||
<script>
|
||||
function handleImages() {
|
||||
var images = document.querySelectorAll('p>img')
|
||||
|
@ -24,6 +24,8 @@
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/experiments/nlp_classification.html"/>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
||||
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
@ -67,6 +69,7 @@
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>NLP model trainer for classification</h1>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">11</span><span></span><span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">Counter</span>
|
||||
@ -92,11 +95,9 @@
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<p><a id="NLPClassificationConfigs"></p>
|
||||
<h2>Trainer configurations</h2>
|
||||
<p></a></p>
|
||||
<p>This has the basic configurations for NLP classification task training.
|
||||
All the properties are configurable.</p>
|
||||
<p> <a id="NLPClassificationConfigs"> ## Trainer configurations </a></p>
|
||||
<p>This has the basic configurations for NLP classification task training. All the properties are configurable.</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">29</span><span class="k">class</span> <span class="nc">NLPClassificationConfigs</span><span class="p">(</span><span class="n">TrainValidConfigs</span><span class="p">):</span></pre></div>
|
||||
@ -107,7 +108,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
<p>Optimizer</p>
|
||||
<p>Optimizer </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">40</span> <span class="n">optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span></pre></div>
|
||||
@ -118,7 +120,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
<p>Training device</p>
|
||||
<p>Training device </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">42</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">DeviceConfigs</span><span class="p">()</span></pre></div>
|
||||
@ -129,7 +132,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
<p>Autoregressive model</p>
|
||||
<p>Autoregressive model </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">45</span> <span class="n">model</span><span class="p">:</span> <span class="n">Module</span></pre></div>
|
||||
@ -140,7 +144,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
<p>Batch size</p>
|
||||
<p>Batch size </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">47</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">16</span></pre></div>
|
||||
@ -151,7 +156,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
<p>Length of the sequence, or context size</p>
|
||||
<p>Length of the sequence, or context size </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">49</span> <span class="n">seq_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">512</span></pre></div>
|
||||
@ -162,7 +168,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
<p>Vocabulary</p>
|
||||
<p>Vocabulary </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">51</span> <span class="n">vocab</span><span class="p">:</span> <span class="n">Vocab</span> <span class="o">=</span> <span class="s1">'ag_news'</span></pre></div>
|
||||
@ -173,7 +180,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-8'>#</a>
|
||||
</div>
|
||||
<p>Number of token in vocabulary</p>
|
||||
<p>Number of token in vocabulary </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">53</span> <span class="n">n_tokens</span><span class="p">:</span> <span class="nb">int</span></pre></div>
|
||||
@ -184,7 +192,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-9'>#</a>
|
||||
</div>
|
||||
<p>Number of classes</p>
|
||||
<p>Number of classes </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">55</span> <span class="n">n_classes</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="s1">'ag_news'</span></pre></div>
|
||||
@ -195,7 +204,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
<p>Tokenizer</p>
|
||||
<p>Tokenizer </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">57</span> <span class="n">tokenizer</span><span class="p">:</span> <span class="n">Callable</span> <span class="o">=</span> <span class="s1">'character'</span></pre></div>
|
||||
@ -206,7 +216,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
<p>Whether to periodically save models</p>
|
||||
<p>Whether to periodically save models </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">60</span> <span class="n">is_save_models</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
|
||||
@ -217,7 +228,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
<p>Loss function</p>
|
||||
<p>Loss function </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">63</span> <span class="n">loss_func</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">()</span></pre></div>
|
||||
@ -228,7 +240,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-13'>#</a>
|
||||
</div>
|
||||
<p>Accuracy function</p>
|
||||
<p>Accuracy function </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">65</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">Accuracy</span><span class="p">()</span></pre></div>
|
||||
@ -239,7 +252,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-14'>#</a>
|
||||
</div>
|
||||
<p>Model embedding size</p>
|
||||
<p>Model embedding size </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">67</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">512</span></pre></div>
|
||||
@ -250,7 +264,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-15'>#</a>
|
||||
</div>
|
||||
<p>Gradient clipping</p>
|
||||
<p>Gradient clipping </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">69</span> <span class="n">grad_norm_clip</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span></pre></div>
|
||||
@ -261,7 +276,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-16'>#</a>
|
||||
</div>
|
||||
<p>Training data loader</p>
|
||||
<p>Training data loader </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">72</span> <span class="n">train_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="s1">'ag_news'</span></pre></div>
|
||||
@ -272,7 +288,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-17'>#</a>
|
||||
</div>
|
||||
<p>Validation data loader</p>
|
||||
<p>Validation data loader </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">74</span> <span class="n">valid_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="s1">'ag_news'</span></pre></div>
|
||||
@ -284,6 +301,7 @@ All the properties are configurable.</p>
|
||||
<a href='#section-18'>#</a>
|
||||
</div>
|
||||
<h3>Initialization</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">76</span> <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||||
@ -294,7 +312,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-19'>#</a>
|
||||
</div>
|
||||
<p>Set tracker configurations</p>
|
||||
<p>Set tracker configurations </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">81</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s2">"accuracy.*"</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
|
||||
@ -306,7 +325,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-20'>#</a>
|
||||
</div>
|
||||
<p>Add a hook to log module outputs</p>
|
||||
<p>Add a hook to log module outputs </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">84</span> <span class="n">hook_model_outputs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">,</span> <span class="s1">'model'</span><span class="p">)</span></pre></div>
|
||||
@ -317,10 +337,8 @@ All the properties are configurable.</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-21'>#</a>
|
||||
</div>
|
||||
<p>Add accuracy as a state module.
|
||||
The name is probably confusing, since it’s meant to store
|
||||
states between training and validation for RNNs.
|
||||
This will keep the accuracy metric stats separate for training and validation.</p>
|
||||
<p>Add accuracy as a state module. The name is probably confusing, since it's meant to store states between training and validation for RNNs. This will keep the accuracy metric stats separate for training and validation. </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">89</span> <span class="bp">self</span><span class="o">.</span><span class="n">state_modules</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">accuracy</span><span class="p">]</span></pre></div>
|
||||
@ -332,6 +350,7 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<a href='#section-22'>#</a>
|
||||
</div>
|
||||
<h3>Training or validation step</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">91</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="nb">any</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="n">BatchIndex</span><span class="p">):</span></pre></div>
|
||||
@ -342,7 +361,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-23'>#</a>
|
||||
</div>
|
||||
<p>Move data to the device</p>
|
||||
<p>Move data to the device </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">97</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||
@ -353,7 +373,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-24'>#</a>
|
||||
</div>
|
||||
<p>Update global step (number of tokens processed) when in training mode</p>
|
||||
<p>Update global step (number of tokens processed) when in training mode </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">100</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
|
||||
@ -365,7 +386,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-25'>#</a>
|
||||
</div>
|
||||
<p>Whether to capture model outputs</p>
|
||||
<p>Whether to capture model outputs </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">104</span> <span class="k">with</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">is_log_activations</span><span class="o">=</span><span class="n">batch_idx</span><span class="o">.</span><span class="n">is_last</span><span class="p">):</span></pre></div>
|
||||
@ -376,9 +398,8 @@ This will keep the accuracy metric stats separate for training and validation.</
|
||||
<div class='section-link'>
|
||||
<a href='#section-26'>#</a>
|
||||
</div>
|
||||
<p>Get model outputs.
|
||||
It’s returning a tuple for states when using RNNs.
|
||||
This is not implemented yet. 😜</p>
|
||||
<p>Get model outputs. It's returning a tuple for states when using RNNs. This is not implemented yet. 😜 </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">108</span> <span class="n">output</span><span class="p">,</span> <span class="o">*</span><span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
|
||||
@ -389,7 +410,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-27'>#</a>
|
||||
</div>
|
||||
<p>Calculate and log loss</p>
|
||||
<p>Calculate and log loss </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">111</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
|
||||
@ -401,7 +423,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-28'>#</a>
|
||||
</div>
|
||||
<p>Calculate and log accuracy</p>
|
||||
<p>Calculate and log accuracy </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">115</span> <span class="bp">self</span><span class="o">.</span><span class="n">accuracy</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
|
||||
@ -413,7 +436,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-29'>#</a>
|
||||
</div>
|
||||
<p>Train the model</p>
|
||||
<p>Train the model </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">119</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span></pre></div>
|
||||
@ -424,7 +448,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-30'>#</a>
|
||||
</div>
|
||||
<p>Calculate gradients</p>
|
||||
<p>Calculate gradients </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">121</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
|
||||
@ -435,7 +460,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-31'>#</a>
|
||||
</div>
|
||||
<p>Clip gradients</p>
|
||||
<p>Clip gradients </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">123</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_grad_norm_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">max_norm</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">grad_norm_clip</span><span class="p">)</span></pre></div>
|
||||
@ -446,7 +472,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-32'>#</a>
|
||||
</div>
|
||||
<p>Take optimizer step</p>
|
||||
<p>Take optimizer step </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">125</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
|
||||
@ -457,7 +484,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-33'>#</a>
|
||||
</div>
|
||||
<p>Log the model parameters and gradients on last batch of every epoch</p>
|
||||
<p>Log the model parameters and gradients on last batch of every epoch </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">127</span> <span class="k">if</span> <span class="n">batch_idx</span><span class="o">.</span><span class="n">is_last</span><span class="p">:</span>
|
||||
@ -469,7 +497,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-34'>#</a>
|
||||
</div>
|
||||
<p>Clear the gradients</p>
|
||||
<p>Clear the gradients </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">130</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
|
||||
@ -480,7 +509,8 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-35'>#</a>
|
||||
</div>
|
||||
<p>Save the tracked metrics</p>
|
||||
<p>Save the tracked metrics </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">133</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
|
||||
@ -492,6 +522,7 @@ This is not implemented yet. 😜</p>
|
||||
<a href='#section-36'>#</a>
|
||||
</div>
|
||||
<h3>Default <a href="../optimizers/configs.html">optimizer configurations</a></h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">136</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPClassificationConfigs</span><span class="o">.</span><span class="n">optimizer</span><span class="p">)</span>
|
||||
@ -519,13 +550,12 @@ This is not implemented yet. 😜</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-38'>#</a>
|
||||
</div>
|
||||
<h3>Basic english tokenizer</h3>
|
||||
<p>We use character level tokenizer in this experiment.
|
||||
You can switch by setting,</p>
|
||||
<pre><code> 'tokenizer': 'basic_english',
|
||||
</code></pre>
|
||||
|
||||
<h3>Basic english tokenizer</h3>
|
||||
<p>We use character level tokenizer in this experiment. You can switch by setting,</p>
|
||||
<pre class="lang-"> 'tokenizer': 'basic_english',
|
||||
</pre>
|
||||
<p>as the configurations dictionary when starting the experiment.</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">150</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPClassificationConfigs</span><span class="o">.</span><span class="n">tokenizer</span><span class="p">)</span>
|
||||
@ -550,6 +580,7 @@ You can switch by setting,</p>
|
||||
<a href='#section-40'>#</a>
|
||||
</div>
|
||||
<h3>Character level tokenizer</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">169</span><span class="k">def</span> <span class="nf">character_tokenizer</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
|
||||
@ -571,7 +602,8 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-42'>#</a>
|
||||
</div>
|
||||
<p>Character level tokenizer configuration</p>
|
||||
<p> Character level tokenizer configuration</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">176</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPClassificationConfigs</span><span class="o">.</span><span class="n">tokenizer</span><span class="p">)</span>
|
||||
@ -594,7 +626,8 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-44'>#</a>
|
||||
</div>
|
||||
<p>Get number of tokens</p>
|
||||
<p> Get number of tokens</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">184</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPClassificationConfigs</span><span class="o">.</span><span class="n">n_tokens</span><span class="p">)</span>
|
||||
@ -618,6 +651,7 @@ You can switch by setting,</p>
|
||||
<a href='#section-46'>#</a>
|
||||
</div>
|
||||
<h2>Function to load data into batches</h2>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">192</span><span class="k">class</span> <span class="nc">CollateFunc</span><span class="p">:</span></pre></div>
|
||||
@ -628,13 +662,19 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-47'>#</a>
|
||||
</div>
|
||||
<ul>
|
||||
<li><code>tokenizer</code> is the tokenizer function</li>
|
||||
<li><code>vocab</code> is the vocabulary</li>
|
||||
<li><code>seq_len</code> is the length of the sequence</li>
|
||||
<li><code>padding_token</code> is the token used for padding when the <code>seq_len</code> is larger than the text length</li>
|
||||
<li><code>classifier_token</code> is the <code>[CLS]</code> token which we set at end of the input</li>
|
||||
</ul>
|
||||
<ul><li><code>tokenizer</code>
|
||||
is the tokenizer function </li>
|
||||
<li><code>vocab</code>
|
||||
is the vocabulary </li>
|
||||
<li><code>seq_len</code>
|
||||
is the length of the sequence </li>
|
||||
<li><code>padding_token</code>
|
||||
is the token used for padding when the <code>seq_len</code>
|
||||
is larger than the text length </li>
|
||||
<li><code>classifier_token</code>
|
||||
is the <code>[CLS]</code>
|
||||
token which we set at end of the input</li></ul>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">197</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">,</span> <span class="n">vocab</span><span class="p">:</span> <span class="n">Vocab</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">padding_token</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">classifier_token</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
||||
@ -660,9 +700,10 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-49'>#</a>
|
||||
</div>
|
||||
<ul>
|
||||
<li><code>batch</code> is the batch of data collected by the <code>DataLoader</code></li>
|
||||
</ul>
|
||||
<ul><li><code>batch</code>
|
||||
is the batch of data collected by the <code>DataLoader</code>
|
||||
</li></ul>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">211</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span></pre></div>
|
||||
@ -673,7 +714,9 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-50'>#</a>
|
||||
</div>
|
||||
<p>Input data tensor, initialized with <code>padding_token</code></p>
|
||||
<p>Input data tensor, initialized with <code>padding_token</code>
|
||||
</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">217</span> <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="p">)),</span> <span class="bp">self</span><span class="o">.</span><span class="n">padding_token</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">)</span></pre></div>
|
||||
@ -684,7 +727,8 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-51'>#</a>
|
||||
</div>
|
||||
<p>Empty labels tensor</p>
|
||||
<p>Empty labels tensor </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">219</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">)</span></pre></div>
|
||||
@ -695,7 +739,8 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-52'>#</a>
|
||||
</div>
|
||||
<p>Loop through the samples</p>
|
||||
<p>Loop through the samples </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">222</span> <span class="k">for</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">_label</span><span class="p">,</span> <span class="n">_text</span><span class="p">))</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">batch</span><span class="p">):</span></pre></div>
|
||||
@ -706,7 +751,8 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-53'>#</a>
|
||||
</div>
|
||||
<p>Set the label</p>
|
||||
<p>Set the label </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">224</span> <span class="n">labels</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">_label</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span></pre></div>
|
||||
@ -717,7 +763,8 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-54'>#</a>
|
||||
</div>
|
||||
<p>Tokenize the input text</p>
|
||||
<p>Tokenize the input text </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">226</span> <span class="n">_text</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">vocab</span><span class="p">[</span><span class="n">token</span><span class="p">]</span> <span class="k">for</span> <span class="n">token</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokenizer</span><span class="p">(</span><span class="n">_text</span><span class="p">)]</span></pre></div>
|
||||
@ -728,7 +775,9 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-55'>#</a>
|
||||
</div>
|
||||
<p>Truncate upto <code>seq_len</code></p>
|
||||
<p>Truncate upto <code>seq_len</code>
|
||||
</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">228</span> <span class="n">_text</span> <span class="o">=</span> <span class="n">_text</span><span class="p">[:</span><span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span><span class="p">]</span></pre></div>
|
||||
@ -739,7 +788,8 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-56'>#</a>
|
||||
</div>
|
||||
<p>Transpose and add to data</p>
|
||||
<p>Transpose and add to data </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">230</span> <span class="n">data</span><span class="p">[:</span><span class="nb">len</span><span class="p">(</span><span class="n">_text</span><span class="p">),</span> <span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">new_tensor</span><span class="p">(</span><span class="n">_text</span><span class="p">)</span></pre></div>
|
||||
@ -750,7 +800,9 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-57'>#</a>
|
||||
</div>
|
||||
<p>Set the final token in the sequence to <code>[CLS]</code></p>
|
||||
<p>Set the final token in the sequence to <code>[CLS]</code>
|
||||
</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">233</span> <span class="n">data</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">classifier_token</span></pre></div>
|
||||
@ -761,7 +813,8 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-58'>#</a>
|
||||
</div>
|
||||
|
||||
<p> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">236</span> <span class="k">return</span> <span class="n">data</span><span class="p">,</span> <span class="n">labels</span></pre></div>
|
||||
@ -773,8 +826,12 @@ You can switch by setting,</p>
|
||||
<a href='#section-59'>#</a>
|
||||
</div>
|
||||
<h3>AG News dataset</h3>
|
||||
<p>This loads the AG News dataset and the set the values for
|
||||
<code>n_classes</code>, <code>vocab</code>, <code>train_loader</code>, and <code>valid_loader</code>.</p>
|
||||
<p>This loads the AG News dataset and the set the values for <code>n_classes</code>
|
||||
, <code>vocab</code>
|
||||
, <code>train_loader</code>
|
||||
, and <code>valid_loader</code>
|
||||
.</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">239</span><span class="nd">@option</span><span class="p">([</span><span class="n">NLPClassificationConfigs</span><span class="o">.</span><span class="n">n_classes</span><span class="p">,</span>
|
||||
@ -789,7 +846,8 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-60'>#</a>
|
||||
</div>
|
||||
<p>Get training and validation datasets</p>
|
||||
<p>Get training and validation datasets </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">252</span> <span class="n">train</span><span class="p">,</span> <span class="n">valid</span> <span class="o">=</span> <span class="n">torchtext</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">AG_NEWS</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="nb">str</span><span class="p">(</span><span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">'ag_news'</span><span class="p">),</span> <span class="n">split</span><span class="o">=</span><span class="p">(</span><span class="s1">'train'</span><span class="p">,</span> <span class="s1">'test'</span><span class="p">))</span></pre></div>
|
||||
@ -800,7 +858,8 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-61'>#</a>
|
||||
</div>
|
||||
<p>Load data to memory</p>
|
||||
<p>Load data to memory </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">255</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">'Load data'</span><span class="p">):</span>
|
||||
@ -812,7 +871,8 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-62'>#</a>
|
||||
</div>
|
||||
<p>Create <a href="../utils.html#map_style_dataset">map-style datasets</a></p>
|
||||
<p>Create <a href="../utils.html#map_style_dataset">map-style datasets</a> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">259</span> <span class="n">train</span><span class="p">,</span> <span class="n">valid</span> <span class="o">=</span> <span class="n">MapStyleDataset</span><span class="p">(</span><span class="n">train</span><span class="p">),</span> <span class="n">MapStyleDataset</span><span class="p">(</span><span class="n">valid</span><span class="p">)</span></pre></div>
|
||||
@ -823,7 +883,8 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-63'>#</a>
|
||||
</div>
|
||||
<p>Get tokenizer</p>
|
||||
<p>Get tokenizer </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">262</span> <span class="n">tokenizer</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">tokenizer</span></pre></div>
|
||||
@ -834,7 +895,8 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-64'>#</a>
|
||||
</div>
|
||||
<p>Create a counter</p>
|
||||
<p>Create a counter </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">265</span> <span class="n">counter</span> <span class="o">=</span> <span class="n">Counter</span><span class="p">()</span></pre></div>
|
||||
@ -845,7 +907,8 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-65'>#</a>
|
||||
</div>
|
||||
<p>Collect tokens from training dataset</p>
|
||||
<p>Collect tokens from training dataset </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">267</span> <span class="k">for</span> <span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="n">line</span><span class="p">)</span> <span class="ow">in</span> <span class="n">train</span><span class="p">:</span>
|
||||
@ -857,7 +920,8 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-66'>#</a>
|
||||
</div>
|
||||
<p>Collect tokens from validation dataset</p>
|
||||
<p>Collect tokens from validation dataset </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">270</span> <span class="k">for</span> <span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="n">line</span><span class="p">)</span> <span class="ow">in</span> <span class="n">valid</span><span class="p">:</span>
|
||||
@ -869,7 +933,8 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-67'>#</a>
|
||||
</div>
|
||||
<p>Create vocabulary</p>
|
||||
<p>Create vocabulary </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">273</span> <span class="n">vocab</span> <span class="o">=</span> <span class="n">Vocab</span><span class="p">(</span><span class="n">counter</span><span class="p">,</span> <span class="n">min_freq</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
@ -880,7 +945,8 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-68'>#</a>
|
||||
</div>
|
||||
<p>Create training data loader</p>
|
||||
<p>Create training data loader </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">276</span> <span class="n">train_loader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||||
@ -892,7 +958,8 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-69'>#</a>
|
||||
</div>
|
||||
<p>Create validation data loader</p>
|
||||
<p>Create validation data loader </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">279</span> <span class="n">valid_loader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">valid</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||||
@ -904,7 +971,12 @@ You can switch by setting,</p>
|
||||
<div class='section-link'>
|
||||
<a href='#section-70'>#</a>
|
||||
</div>
|
||||
<p>Return <code>n_classes</code>, <code>vocab</code>, <code>train_loader</code>, and <code>valid_loader</code></p>
|
||||
<p>Return <code>n_classes</code>
|
||||
, <code>vocab</code>
|
||||
, <code>train_loader</code>
|
||||
, and <code>valid_loader</code>
|
||||
</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">283</span> <span class="k">return</span> <span class="mi">4</span><span class="p">,</span> <span class="n">vocab</span><span class="p">,</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">valid_loader</span></pre></div>
|
||||
@ -915,24 +987,6 @@ You can switch by setting,</p>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
</script>
|
||||
<script>
|
||||
function handleImages() {
|
||||
var images = document.querySelectorAll('p>img')
|
||||
|
Reference in New Issue
Block a user