Kartik a55a824f8a docs - cross-validation and early-stopping (#48)
Co-authored-by: sachdev.kartik <sachdev.kartik@gmail.com>
2021-05-03 14:05:14 +05:30


<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content=""/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="cv_train.py"/>
<meta name="twitter:description" content=""/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/cnn/utils/cv_train.html"/>
<meta property="og:title" content="cv_train.py"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content=""/>
<title>cv_train.py</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/cnn/utils/cv_train.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">cnn</a>
<a class="parent" href="index.html">utils</a>
</p>
<p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/cnn/utils/cv_train.py">
<img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a>
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
rel="nofollow">
<img alt="Join Slack"
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Cross-Validation &amp; Early Stopping</h1>
<p>An implementation of two fundamental techniques: <em>Cross-Validation</em> and <em>Early Stopping</em>.</p>
<h3>Cross-Validation</h3>
<p>
Collecting data is expensive, and in some cases one has no option but to train machine learning models on a limited amount of data.
This is where cross-validation is useful: every sample is used for both training and validation, though never in the same run. The steps are as follows:
</p>
<ol type="1">
<li>Split the data into K folds</li>
<li>Train a model on K-1 of the folds</li>
<li>Validate the model on the remaining fold</li>
<li>Repeat steps 2 and 3 until each fold has been used for validation exactly once</li>
<li>Average the performance over all K runs</li>
</ol>
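<p>The steps above can be sketched in plain Python. This is a minimal illustration rather than the implementation in this file: the hypothetical helpers <code>k_fold_indices</code> and <code>cross_validate</code>, the "model" (a mean predictor) and the score (mean squared error) are our own choices so that the example runs without any libraries.</p>

```python
import random
from statistics import mean


def k_fold_indices(n, k, seed=0):
    """Shuffle the indices 0..n-1 and split them into k nearly equal folds."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    # The first n % k folds get one extra sample
    sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    folds, start = [], 0
    for size in sizes:
        folds.append(idx[start:start + size])
        start += size
    return folds


def cross_validate(y, k=5):
    """Return the per-fold validation MSE and its average for a mean predictor."""
    folds = k_fold_indices(len(y), k)
    scores = []
    for i, val_idx in enumerate(folds):
        # Step 2: "train" on the other K-1 folds (here: just take their mean)
        train_idx = [j for f, fold in enumerate(folds) if f != i for j in fold]
        prediction = mean(y[j] for j in train_idx)
        # Step 3: validate on the held-out fold
        scores.append(mean((y[j] - prediction) ** 2 for j in val_idx))
    # Step 5: average the performance over all K runs
    return scores, mean(scores)


scores, avg = cross_validate([float(v) for v in range(20)], k=5)
print(len(scores), round(avg, 3))
```

<p>The fold sizes follow the same rule as scikit-learn's <code>KFold</code>: when <code>n</code> is not divisible by <code>k</code>, the first <code>n % k</code> folds get one extra sample.</p>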
<h3>Early Stopping</h3>
<p>Deep learning networks are prone to overfitting: an overfitted model performs well on the training set but generalizes poorly to unseen data.
In other words, overfitted models have low bias and high variance. The lower the bias, the better the model can fit the training data; the higher the variance, the more sensitive the model is to the particular training set it was given.</p>
<p>Formally, for a squared-error loss the expected error decomposes as:</p>
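<p><script type="math/tex; mode=display">\mathbb{E}\big[(y - \hat{f}(x))^2\big] = \mathrm{Bias}\big[\hat{f}(x)\big]^2 + \mathrm{Var}\big[\hat{f}(x)\big] + \sigma^2</script></p>
<p>where <script type="math/tex">\hat{f}</script> is the trained model and <script type="math/tex">\sigma^2</script> is the variance of the irreducible noise in the targets.</p>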
<p>Since making a model more flexible typically lowers its bias but raises its variance, one has to find a trade-off between the two.</p>
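<p>The decomposition can be checked numerically. In the toy setting below (our own choice, not from this file), targets are <script type="math/tex">y = \mu + \varepsilon</script> with Gaussian noise of variance <script type="math/tex">\sigma^2</script>, and a hypothetical "model" predicts the shrunk sample mean <script type="math/tex">c\,\bar{y}</script> of <script type="math/tex">n</script> training targets. Its squared bias is <script type="math/tex">((c-1)\mu)^2</script> and its variance is <script type="math/tex">c^2\sigma^2/n</script>, so a Monte Carlo estimate of the expected squared error should match squared bias + variance + noise.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma, n, c = 2.0, 1.0, 10, 0.8  # toy values chosen for illustration
trials = 200_000

# Each trial: draw a training set of size n, fit the shrunk-mean "model",
# then evaluate it on one fresh noisy target.
train = rng.normal(mu, sigma, size=(trials, n))
predictions = c * train.mean(axis=1)
targets = rng.normal(mu, sigma, size=trials)
mse = np.mean((targets - predictions) ** 2)

bias_sq = ((c - 1) * mu) ** 2       # squared bias: 0.16
variance = c ** 2 * sigma ** 2 / n  # variance of the prediction: 0.064
noise = sigma ** 2                  # irreducible noise: 1.0
print(f"Monte Carlo: {mse:.3f}  bias^2 + variance + noise: {bias_sq + variance + noise:.3f}")
```

<p>Choosing <code>c</code> below 1 adds bias but shrinks the variance term, which is the bias-variance trade-off in miniature.</p>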
<p>Early stopping is one way to find this trade-off. It helps find a good setting of the parameters, prevents overfitting to the training set, and saves computation time.
This can be visualized by plotting the training and validation loss over time:</p> <br>
<a href="https://www.deeplearningbook.org/contents/regularization.html"><img src="Cross-validation.png" alt="Training v/s Validation set Loss"></a>
<br>
<p>The training error keeps decreasing, but the validation error starts to increase after around 40 epochs.
Therefore, our goal is to stop training once the validation loss starts to increase, and to keep the parameters from the epoch with the lowest validation loss.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">3</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">4</span>
<span class="lineno">5</span><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">Subset</span>
<span class="lineno">6</span>
<span class="lineno">7</span><span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">KFold</span>
<span class="lineno">8</span><span class="kn">from</span> <span class="nn">torch.utils.data.sampler</span> <span class="kn">import</span> <span class="n">SubsetRandomSampler</span>
<span class="lineno">9</span><span class="kn">from</span> <span class="nn">models.cnn</span> <span class="kn">import</span> <span class="n">GetCNN</span>
<span class="lineno">10</span><span class="kn">from</span> <span class="nn">torchsummary</span> <span class="kn">import</span> <span class="n">summary</span>
<span class="lineno">11</span><span class="kn">import</span> <span class="nn">torch.optim</span> <span class="k">as</span> <span class="nn">optim</span>
<span class="lineno">12</span><span class="kn">import</span> <span class="nn">os</span>
<span class="lineno">13</span>
<span class="lineno">14</span><span class="kn">from</span> <span class="nn">torch.utils.tensorboard</span> <span class="kn">import</span> <span class="n">SummaryWriter</span>
<span class="lineno">15</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">datetime</span> <span class="kn">import</span> <span class="n">datetime</span>
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">glob</span> <span class="kn">import</span> <span class="n">glob</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h3>Cross-Validation</h3>
<p>The splitting of the training set into folds can be represented as:</p>
<img src="cv-folds.png" alt="CV folds">
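<p>Note that the folds are never copied: for each fold, <code>kf.split</code> yields index arrays, and <code>SubsetRandomSampler</code> makes a <code>DataLoader</code> draw only those indices, in a new random order every epoch. The toy dataset and the fixed split below are hypothetical, chosen so that each sample's label equals its index.</p>

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

# A toy dataset of 10 samples whose label is simply its own index
dataset = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))

# Hypothetical fold: indices 0-7 for training, 8-9 for validation
train_index, test_index = list(range(8)), [8, 9]

train_loader = DataLoader(dataset, batch_size=4, sampler=SubsetRandomSampler(train_index))
valid_loader = DataLoader(dataset, batch_size=4, sampler=SubsetRandomSampler(test_index))

# Each loader only ever sees its own fold
seen_train = sorted(int(y) for _, labels in train_loader for y in labels)
seen_valid = sorted(int(y) for _, labels in valid_loader for y in labels)
print(seen_train, seen_valid)
```

<p>Because the sampler decides the order, <code>shuffle=True</code> must not be passed to <code>DataLoader</code> together with a <code>sampler</code>; PyTorch treats the two options as mutually exclusive.</p>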
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">21</span><span class="k">def</span> <span class="nf">cross_val_train</span><span class="p">(</span><span class="n">cost</span><span class="p">,</span> <span class="n">trainset</span><span class="p">,</span> <span class="n">epochs</span><span class="p">,</span> <span class="n">splits</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="lineno">22</span>
<span class="lineno">23</span>    <span class="n">patience</span> <span class="o">=</span> <span class="mi">4</span>
<span class="lineno">24</span>    <span class="n">history</span> <span class="o">=</span> <span class="p">[]</span>
<span class="lineno">25</span>    <span class="n">kf</span> <span class="o">=</span> <span class="n">KFold</span><span class="p">(</span><span class="n">n_splits</span><span class="o">=</span><span class="n">splits</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="lineno">26</span>    <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">64</span>
<span class="lineno">27</span>    <span class="n">now</span> <span class="o">=</span> <span class="n">datetime</span><span class="o">.</span><span class="n">now</span><span class="p">()</span>
<span class="lineno">28</span>    <span class="n">date_time</span> <span class="o">=</span> <span class="n">now</span><span class="o">.</span><span class="n">strftime</span><span class="p">(</span><span class="s2">&quot;</span><span class="si">%d</span><span class="s2">-%m-%Y_%H:%M:%S&quot;</span><span class="p">)</span>
<span class="lineno">29</span>    <span class="n">directory</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="s1">&#39;./save/tensorboard-</span><span class="si">%s</span><span class="s1">/&#39;</span><span class="o">%</span><span class="p">(</span><span class="n">date_time</span><span class="p">))</span>
<span class="lineno">30</span>
<span class="lineno">31</span>    <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">directory</span><span class="p">):</span>
<span class="lineno">32</span>        <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">directory</span><span class="p">)</span>
<span class="lineno">33</span>
<span class="lineno">34</span>    <span class="k">for</span> <span class="n">fold</span><span class="p">,</span> <span class="p">(</span><span class="n">train_index</span><span class="p">,</span> <span class="n">test_index</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">kf</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">trainset</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">trainset</span><span class="o">.</span><span class="n">targets</span><span class="p">)):</span>  <span class="c1"># dataset required: the complete training set</span>
<span class="lineno">35</span>        <span class="n">comment</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">directory</span><span class="si">}</span><span class="s1">/fold-</span><span class="si">{</span><span class="n">fold</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="lineno">36</span>        <span class="n">writer</span> <span class="o">=</span> <span class="n">SummaryWriter</span><span class="p">(</span><span class="n">log_dir</span><span class="o">=</span><span class="n">comment</span><span class="p">)</span>
<span class="lineno">37</span>
<span class="lineno">38</span>        <span class="n">train_sampler</span> <span class="o">=</span> <span class="n">SubsetRandomSampler</span><span class="p">(</span><span class="n">train_index</span><span class="p">)</span>
<span class="lineno">39</span>        <span class="n">valid_sampler</span> <span class="o">=</span> <span class="n">SubsetRandomSampler</span><span class="p">(</span><span class="n">test_index</span><span class="p">)</span>
<span class="lineno">40</span>        <span class="n">traindata</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">trainset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">sampler</span><span class="o">=</span><span class="n">train_sampler</span><span class="p">,</span>
<span class="lineno">41</span>                                                <span class="n">num_workers</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="lineno">42</span>        <span class="n">valdata</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">trainset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">sampler</span><span class="o">=</span><span class="n">valid_sampler</span><span class="p">,</span>
<span class="lineno">43</span>                                              <span class="n">num_workers</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="lineno">44</span>
<span class="lineno">45</span>        <span class="n">net</span> <span class="o">=</span> <span class="n">GetCNN</span><span class="p">()</span>
<span class="lineno">46</span>        <span class="n">net</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">47</span>        <span class="k">if</span> <span class="n">fold</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>  <span class="c1"># Print the model details for the first fold only</span>
<span class="lineno">48</span>            <span class="n">summary</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<p>Specify optimizer</p>
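<p>For reference, a single Adam update can be written out by hand. This is a plain-Python sketch of the standard update rule for one scalar parameter, using the hyperparameters chosen here (<code>lr=0.0005</code>, <code>betas=(0.9, 0.95)</code>) and PyTorch's default <code>eps=1e-8</code>; the helper <code>adam_step</code> is hypothetical and is not <code>torch.optim.Adam</code> itself. Thanks to the bias correction, the first step moves the parameter by almost exactly <code>lr</code>, whatever the scale of the gradient.</p>

```python
import math


def adam_step(w, grad, m, v, t, lr=0.0005, betas=(0.9, 0.95), eps=1e-8):
    """One Adam update for a scalar parameter at step t (1-based); returns (w, m, v)."""
    beta1, beta2 = betas
    m = beta1 * m + (1 - beta1) * grad       # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2  # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)             # bias corrections for the
    v_hat = v / (1 - beta2 ** t)             # zero-initialized averages
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v


w, m, v = adam_step(w=1.0, grad=2.0, m=0.0, v=0.0, t=1)
print(w)  # approximately 0.9995
```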
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">52</span>        <span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.0005</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="p">(</span><span class="mf">0.9</span><span class="p">,</span> <span class="mf">0.95</span><span class="p">))</span>
<span class="lineno">53</span>        <span class="n">losses</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">epochs</span><span class="p">)</span>
<span class="lineno">54</span>        <span class="n">accuracies</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">epochs</span><span class="p">)</span>
<span class="lineno">55</span>        <span class="n">min_loss</span> <span class="o">=</span> <span class="kc">None</span>
<span class="lineno">56</span>        <span class="n">count</span> <span class="o">=</span> <span class="mi">0</span>
<span class="lineno">57</span>        <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">epochs</span><span class="p">):</span>
<span class="lineno">58</span>            <span class="n">valid_loss</span> <span class="o">=</span> <span class="mi">0</span>
<span class="lineno">59</span>            <span class="n">running_loss</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="lineno">60</span>            <span class="n">epoch_loss</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="lineno">61</span>            <span class="n">train_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">epochs</span><span class="p">)</span>
<span class="lineno">62</span>            <span class="n">train_steps</span> <span class="o">=</span> <span class="mf">0.0</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>Training steps</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span>            <span class="n">net</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>  <span class="c1"># Enable Dropout</span>
<span class="lineno">66</span>            <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">data</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">traindata</span><span class="p">,</span> <span class="mi">0</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Get the inputs; data is a list of [inputs, labels]</p>
<p>Move the inputs to <code>device</code> (e.g. the GPU) if one was given; otherwise use them as loaded</p>
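<p>The <code>device</code> argument comes from the caller. A common way to choose it (a sketch, not necessarily how this file's caller does it) is to prefer CUDA when PyTorch can see a GPU and to fall back to the CPU; moving both tensors of a batch with <code>.to(device)</code> keeps inputs, labels and model on the same device. The batch below is a hypothetical one shaped like CIFAR-10 images.</p>

```python
import torch

# Prefer the GPU when one is visible to PyTorch, otherwise use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A hypothetical batch: (batch, channels, height, width) images and class labels
data = [torch.randn(4, 3, 32, 32), torch.tensor([0, 1, 2, 3])]
images, labels = data[0].to(device), data[1].to(device)
print(images.device, labels.device)
```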
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span>                <span class="k">if</span> <span class="n">device</span><span class="p">:</span>
<span class="lineno">69</span>                    <span class="n">images</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">70</span>                <span class="k">else</span><span class="p">:</span>
<span class="lineno">71</span>                    <span class="n">images</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">data</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Forward + backward + optimize</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">74</span>                <span class="n">outputs</span> <span class="o">=</span> <span class="n">net</span><span class="p">(</span><span class="n">images</span><span class="p">)</span>
<span class="lineno">75</span>                <span class="n">loss</span> <span class="o">=</span> <span class="n">cost</span><span class="p">(</span><span class="n">outputs</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span>
<span class="lineno">76</span>                <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="lineno">77</span>                <span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Zero the parameter gradients</p>
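<p>Resetting is needed because PyTorch <em>accumulates</em> gradients: every call to <code>backward()</code> adds to the <code>.grad</code> already stored on each parameter. Calling <code>zero_grad()</code> right after <code>step()</code>, as done here, is therefore equivalent to the more common pattern of calling it before the forward pass. The toy example below (our own, not from this file) shows the accumulation directly.</p>

```python
import torch

w = torch.tensor(3.0, requires_grad=True)

# d(w^2)/dw = 2w = 6; two backward passes without zeroing add up to 12
(w ** 2).backward()
(w ** 2).backward()
accumulated = w.grad.item()
print(accumulated)  # 12.0

# zero_grad() clears the stored gradients before the next iteration
optimizer = torch.optim.SGD([w], lr=0.1)
optimizer.zero_grad()
print(w.grad)  # None in recent PyTorch versions, a zero tensor in older ones
```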
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span>                <span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Accumulate the batch losses and compute the mean training loss of the epoch</p>
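<p><code>loss_train</code> is the mean of the per-batch losses, not the mean over samples. Assuming <code>cost</code> averages over the batch (as <code>nn.CrossEntropyLoss</code> does by default), the two agree when all batches have the same size; when the last batch of a fold is smaller, it still counts as much as a full batch. A small illustration with hypothetical batch losses:</p>

```python
# Hypothetical per-batch mean losses and batch sizes (the last batch is partial)
batch_losses = [1.0, 1.0, 4.0]
batch_sizes = [64, 64, 2]

# What the training loop computes: every batch counts equally
per_batch_mean = sum(batch_losses) / len(batch_losses)

# Per-sample mean: weight each batch by the number of samples it contains
per_sample_mean = sum(loss * n for loss, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)

print(per_batch_mean, round(per_sample_mean, 4))  # 2.0 1.0462
```

<p>With batches of 64 the difference is usually small, but it grows when the last batch is tiny and its loss is unusual.</p>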
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span>                <span class="n">running_loss</span> <span class="o">+=</span> <span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="lineno">83</span>                <span class="n">epoch_loss</span> <span class="o">+=</span> <span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="lineno">84</span>                <span class="n">train_loss</span><span class="p">[</span><span class="n">epoch</span><span class="p">]</span> <span class="o">+=</span> <span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="lineno">85</span>                <span class="n">train_steps</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="lineno">86</span>
<span class="lineno">87</span>            <span class="n">loss_train</span> <span class="o">=</span> <span class="n">train_loss</span><span class="p">[</span><span class="n">epoch</span><span class="p">]</span> <span class="o">/</span> <span class="n">train_steps</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Validation and printing the metrics</p>
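<p>The <code>Test</code> function called here is defined elsewhere in the file. As a rough idea of what such a function needs to do, here is a hypothetical stand-in named <code>evaluate</code> (our own sketch, not the file's <code>Test</code>) that returns the same dictionary keys: it switches the network to evaluation mode, disables gradient tracking, and weights each batch by its size so the result is a true per-sample mean.</p>

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(net, cost, loader, device=None):
    """Hypothetical stand-in for Test(): sample-weighted validation loss and accuracy."""
    net.eval()  # dropout off, batch norm uses its running statistics
    total_loss, correct, count = 0.0, 0, 0
    with torch.no_grad():  # gradients are not needed for validation
        for images, labels in loader:
            if device:
                images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            # cost() is assumed to average over the batch, so re-weight by batch size
            total_loss += cost(outputs, labels).item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            count += labels.size(0)
    return {'val_loss': total_loss / count, 'val_acc': correct / count}


# Toy check: an untrained linear "classifier" on random data with 3 classes
torch.manual_seed(0)
X, Y = torch.randn(20, 5), torch.randint(0, 3, (20,))
net, cost = nn.Linear(5, 3), nn.CrossEntropyLoss()
metrics = evaluate(net, cost, DataLoader(TensorDataset(X, Y), batch_size=8))
print(metrics)
```

<p>Leaving the network in evaluation mode is harmless here because the training loop calls <code>net.train()</code> at the start of every epoch.</p>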
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">90</span>            <span class="n">loss_accuracy</span> <span class="o">=</span> <span class="n">Test</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">cost</span><span class="p">,</span> <span class="n">valdata</span><span class="p">,</span> <span class="n">device</span><span class="p">)</span>
<span class="lineno">91</span>
<span class="lineno">92</span>            <span class="n">losses</span><span class="p">[</span><span class="n">epoch</span><span class="p">]</span> <span class="o">=</span> <span class="n">loss_accuracy</span><span class="p">[</span><span class="s1">&#39;val_loss&#39;</span><span class="p">]</span>
<span class="lineno">93</span>            <span class="n">accuracies</span><span class="p">[</span><span class="n">epoch</span><span class="p">]</span> <span class="o">=</span> <span class="n">loss_accuracy</span><span class="p">[</span><span class="s1">&#39;val_acc&#39;</span><span class="p">]</span>
<span class="lineno">94</span>            <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Fold </span><span class="si">%d</span><span class="s2">, Epoch </span><span class="si">%d</span><span class="s2">, Train Loss </span><span class="si">%.4f</span><span class="s2"> Validation Loss: </span><span class="si">%.4f</span><span class="s2">, Validation Accuracy: </span><span class="si">%.4f</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="n">fold</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">epoch</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">loss_train</span><span class="p">,</span> <span class="n">losses</span><span class="p">[</span><span class="n">epoch</span><span class="p">],</span> <span class="n">accuracies</span><span class="p">[</span><span class="n">epoch</span><span class="p">]))</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Log the metrics to TensorBoard</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">97</span>            <span class="n">info</span> <span class="o">=</span> <span class="p">{</span>
<span class="lineno">98</span>                <span class="s2">&quot;Loss/train&quot;</span><span class="p">:</span> <span class="n">loss_train</span><span class="p">,</span>
<span class="lineno">99</span>                <span class="s2">&quot;Loss/valid&quot;</span><span class="p">:</span> <span class="n">losses</span><span class="p">[</span><span class="n">epoch</span><span class="p">],</span>
<span class="lineno">100</span>                <span class="s2">&quot;Accuracy/valid&quot;</span><span class="p">:</span> <span class="n">accuracies</span><span class="p">[</span><span class="n">epoch</span><span class="p">]</span>
<span class="lineno">101</span>            <span class="p">}</span>
<span class="lineno">102</span>
<span class="lineno">103</span>            <span class="k">for</span> <span class="n">tag</span><span class="p">,</span> <span class="n">item</span> <span class="ow">in</span> <span class="n">info</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="lineno">104</span>                <span class="n">writer</span><span class="o">.</span><span class="n">add_scalar</span><span class="p">(</span><span class="n">tag</span><span class="p">,</span> <span class="n">item</span><span class="p">,</span> <span class="n">global_step</span><span class="o">=</span><span class="n">epoch</span><span class="p">)</span>
<span class="lineno">105</span>
<span class="lineno">106</span>            <span class="k">if</span> <span class="n">min_loss</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">107</span>                <span class="n">min_loss</span> <span class="o">=</span> <span class="n">losses</span><span class="p">[</span><span class="n">epoch</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<h3>Early stopping</h3>
<p>Early stopping can be understood graphically by looking at how the weights change during training:</p>
<ul>
<li>Solid contour lines indicate the contours of the negative log-likelihood (training error)</li>
<li>The dashed line indicates the trajectory taken by the optimizer</li>
<li><script type="math/tex">w^*</script> denotes the weight setting corresponding to the minimum training error</li>
<li><script type="math/tex">\tilde{w}</script> denotes the final weight setting chosen by early stopping</li>
</ul>
<a href="https://www.deeplearningbook.org/contents/regularization.html"><img src="early-stopping.png" alt="early-stopping" hspace="100" ></a> <!--align="middle"-->
<br>
<a href="https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py"><em>code reference here</em></a>
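<p>The patience logic used in this function can be packaged as a small reusable class. This is a minimal sketch in the spirit of the referenced code, not a copy of it: the class name <code>EarlyStopping</code> and its <code>step</code> method are our own, and there is no minimum-improvement threshold. Like the code here, it treats a tie with the best loss as an improvement and resets the counter.</p>

```python
class EarlyStopping:
    """Minimal patience-based early stopping on a validation loss (a sketch)."""

    def __init__(self, patience=4):
        self.patience = patience
        self.min_loss = None
        self.counter = 0

    def step(self, val_loss):
        """Record one epoch; return True when training should stop."""
        if self.min_loss is None or val_loss <= self.min_loss:
            # New best (or tied) loss: this is where the checkpoint would be saved
            self.min_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


# Hypothetical validation losses: best at epoch 3, then four worse epochs
stopper = EarlyStopping(patience=4)
val_losses = [0.9, 0.7, 0.6, 0.65, 0.62, 0.61, 0.64, 0.5]
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        print(f"stopping after epoch {epoch}, best loss {stopper.min_loss}")
        break
```

<p>The last loss (0.5) is never reached: a larger patience spends extra epochs in exchange for a better chance of riding out such plateaus.</p>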
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span>            <span class="k">if</span> <span class="n">losses</span><span class="p">[</span><span class="n">epoch</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">min_loss</span><span class="p">:</span>
<span class="lineno">111</span>                <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Epoch loss: </span><span class="si">%.4f</span><span class="s2">, Min loss: </span><span class="si">%.4f</span><span class="s2">&quot;</span><span class="o">%</span><span class="p">(</span><span class="n">losses</span><span class="p">[</span><span class="n">epoch</span><span class="p">],</span> <span class="n">min_loss</span><span class="p">))</span>
<span class="lineno">112</span>                <span class="n">count</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="lineno">113</span>                <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Early stopping counter: </span><span class="si">{</span><span class="n">count</span><span class="si">}</span><span class="s1"> out of </span><span class="si">{</span><span class="n">patience</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="lineno">114</span>                <span class="k">if</span> <span class="n">count</span> <span class="o">&gt;=</span> <span class="n">patience</span><span class="p">:</span>
<span class="lineno">115</span>                    <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;############### EarlyStopping ##################&#39;</span><span class="p">)</span>
<span class="lineno">116</span>                    <span class="k">break</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Save the best model of this fold (lowest validation loss so far) and reset the early-stopping counter</p>
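<p>Because the checkpoint stores both state dicts, it can be restored later to resume training or to evaluate the best model of a fold. A minimal round-trip sketch, assuming a small <code>nn.Linear</code> stand-in for <code>GetCNN</code>, a hypothetical accuracy value, and a temporary directory instead of <code>./save/CV_models-*</code>:</p>

```python
import os
import tempfile

import torch
from torch import nn, optim

net = nn.Linear(4, 2)  # stand-in for GetCNN()
optimizer = optim.Adam(net.parameters(), lr=0.0005, betas=(0.9, 0.95))

path = os.path.join(tempfile.mkdtemp(), "fold-0-model.pt")
torch.save({'epoch': 3,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'accuracy': 0.81}, path)

# Later: rebuild the objects first, then load the saved states into them
checkpoint = torch.load(path)
restored = nn.Linear(4, 2)
restored.load_state_dict(checkpoint['state_dict'])
restored_opt = optim.Adam(restored.parameters(), lr=0.0005, betas=(0.9, 0.95))
restored_opt.load_state_dict(checkpoint['optimizer'])
print(checkpoint['epoch'], torch.equal(restored.weight, net.weight))
```

<p>A checkpoint written on a GPU machine can be opened on a CPU-only machine with <code>torch.load(path, map_location="cpu")</code>.</p>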
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">119</span>            <span class="k">elif</span> <span class="n">losses</span><span class="p">[</span><span class="n">epoch</span><span class="p">]</span> <span class="o">&lt;=</span> <span class="n">min_loss</span><span class="p">:</span>
<span class="lineno">120</span>                <span class="n">count</span> <span class="o">=</span> <span class="mi">0</span>
<span class="lineno">121</span>                <span class="n">save_best_model</span><span class="p">({</span>
<span class="lineno">122</span>                    <span class="s1">&#39;epoch&#39;</span><span class="p">:</span> <span class="n">epoch</span><span class="p">,</span>
<span class="lineno">123</span>                    <span class="s1">&#39;state_dict&#39;</span><span class="p">:</span> <span class="n">net</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span>
<span class="lineno">124</span>                    <span class="s1">&#39;optimizer&#39;</span><span class="p">:</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span>
<span class="lineno">125</span>                    <span class="s1">&#39;accuracy&#39;</span> <span class="p">:</span> <span class="n">accuracies</span><span class="p">[</span><span class="n">epoch</span><span class="p">]</span>
<span class="lineno">126</span>                <span class="p">},</span> <span class="n">fold</span><span class="o">=</span><span class="n">fold</span><span class="p">,</span> <span class="n">date_time</span><span class="o">=</span><span class="n">date_time</span><span class="p">)</span>
<span class="lineno">127</span>                <span class="n">min_loss</span> <span class="o">=</span> <span class="n">losses</span><span class="p">[</span><span class="n">epoch</span><span class="p">]</span>
<span class="lineno">128</span>
<span class="lineno">129</span>        <span class="n">history</span><span class="o">.</span><span class="n">append</span><span class="p">({</span><span class="s1">&#39;val_loss&#39;</span><span class="p">:</span> <span class="n">losses</span><span class="p">[</span><span class="n">epoch</span><span class="p">],</span> <span class="s1">&#39;val_acc&#39;</span><span class="p">:</span> <span class="n">accuracies</span><span class="p">[</span><span class="n">epoch</span><span class="p">]})</span>
<span class="lineno">130</span>    <span class="k">return</span> <span class="n">history</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">132</span><span class="k">def</span> <span class="nf">save_best_model</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="n">fold</span><span class="p">,</span> <span class="n">date_time</span><span class="p">):</span>
<span class="lineno">133</span>    <span class="n">directory</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="s2">&quot;./save/CV_models-</span><span class="si">%s</span><span class="s2">/&quot;</span><span class="o">%</span><span class="p">(</span><span class="n">date_time</span><span class="p">))</span>
<span class="lineno">134</span>    <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">directory</span><span class="p">):</span>
<span class="lineno">135</span>        <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">directory</span><span class="p">)</span>
<span class="lineno">136</span>    <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="s2">&quot;</span><span class="si">%s</span><span class="s2">/fold-</span><span class="si">%d</span><span class="s2">-model.pt&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="n">directory</span><span class="p">,</span> <span class="n">fold</span><span class="p">))</span></pre></div>
</div>
</div>
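<div class='section' id='section-12-sketch'>
<div class='docs'>
<p>The checkpoint written by <code>save_best_model</code> is whatever <code>state</code> dictionary the caller passes; <code>retreive_best_trial</code> later reads its <code>'state_dict'</code> and <code>'accuracy'</code> keys. A minimal round-trip sketch, assuming a stand-in <code>nn.Linear</code> model in place of <code>GetCNN</code>, a temporary directory instead of <code>./save</code>, and a hypothetical helper named <code>save_checkpoint</code> that is not part of this project:</p>
</div>
<div class='code'>

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(state, fold, root):
    # Hypothetical helper mirroring save_best_model, with a configurable root folder
    os.makedirs(root, exist_ok=True)  # also creates missing parent folders
    path = os.path.join(root, "fold-%d-model.pt" % fold)
    torch.save(state, path)
    return path


model = nn.Linear(4, 2)  # stand-in for GetCNN()
with tempfile.TemporaryDirectory() as tmp:
    root = os.path.join(tmp, "save", "CV_models-demo")
    path = save_checkpoint({"state_dict": model.state_dict(), "accuracy": 0.9}, 0, root)
    file_name = os.path.basename(path)

    # Reload into a fresh model, the way retreive_best_trial does
    checkpoint = torch.load(path)
    restored = nn.Linear(4, 2)
    restored.load_state_dict(checkpoint["state_dict"])
    saved_accuracy = checkpoint["accuracy"]
    weights_match = torch.equal(restored.weight, model.weight)
```

</div>
</div>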
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Load the checkpoint with the highest validation accuracy from the most recently modified <code>./save/CV_models-*</code> folder</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">138</span><span class="k">def</span> <span class="nf">retreive_best_trial</span><span class="p">():</span>
<span class="lineno">139</span>    <span class="n">PATH</span> <span class="o">=</span> <span class="s2">&quot;./save/&quot;</span>
<span class="lineno">140</span>    <span class="n">best_model</span> <span class="o">=</span> <span class="n">GetCNN</span><span class="p">()</span>
<span class="lineno">141</span>    <span class="n">content</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">listdir</span><span class="p">(</span><span class="n">PATH</span><span class="p">)</span>
<span class="lineno">142</span>    <span class="n">latest_time</span> <span class="o">=</span> <span class="mi">0</span>
<span class="lineno">143</span>    <span class="k">for</span> <span class="n">item</span> <span class="ow">in</span> <span class="n">content</span><span class="p">:</span>
<span class="lineno">144</span>        <span class="k">if</span> <span class="s1">&#39;CV_models&#39;</span> <span class="ow">in</span> <span class="n">item</span><span class="p">:</span>
<span class="lineno">145</span>            <span class="n">foldername</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">PATH</span><span class="p">,</span> <span class="n">item</span><span class="p">)</span>
<span class="lineno">146</span>            <span class="n">tm</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">getmtime</span><span class="p">(</span><span class="n">foldername</span><span class="p">)</span>
<span class="lineno">147</span>            <span class="k">if</span> <span class="n">tm</span> <span class="o">&gt;</span> <span class="n">latest_time</span><span class="p">:</span>
<span class="lineno">148</span>                <span class="n">latest_time</span> <span class="o">=</span> <span class="n">tm</span>  <span class="c1"># remember the newest modification time seen so far</span>
<span class="lineno">149</span>                <span class="n">latest_folder</span> <span class="o">=</span> <span class="n">foldername</span>
<span class="lineno">150</span>
<span class="lineno">151</span>    <span class="n">file_type</span> <span class="o">=</span> <span class="s1">&#39;/*.pt&#39;</span>
<span class="lineno">152</span>    <span class="n">files</span> <span class="o">=</span> <span class="n">glob</span><span class="p">(</span><span class="n">latest_folder</span> <span class="o">+</span> <span class="n">file_type</span><span class="p">)</span>
<span class="lineno">153</span>
<span class="lineno">154</span>    <span class="n">best_val_accuracy</span> <span class="o">=</span> <span class="mi">0</span>
<span class="lineno">155</span>    <span class="k">for</span> <span class="n">model_file</span> <span class="ow">in</span> <span class="n">files</span><span class="p">:</span>
<span class="lineno">156</span>        <span class="n">checkpoint</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">model_file</span><span class="p">)</span>
<span class="lineno">157</span>        <span class="k">if</span> <span class="n">checkpoint</span><span class="p">[</span><span class="s1">&#39;accuracy&#39;</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">best_val_accuracy</span><span class="p">:</span>
<span class="lineno">158</span>            <span class="n">best_model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">checkpoint</span><span class="p">[</span><span class="s1">&#39;state_dict&#39;</span><span class="p">])</span>
<span class="lineno">159</span>            <span class="n">best_val_accuracy</span> <span class="o">=</span> <span class="n">checkpoint</span><span class="p">[</span><span class="s1">&#39;accuracy&#39;</span><span class="p">]</span>  <span class="c1"># raise the bar for the remaining folds</span></pre></div>
</div>
</div>
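<div class='section' id='section-13-sketch'>
<div class='docs'>
<p>Both selections in <code>retreive_best_trial</code> are running-maximum loops: the comparison threshold has to be updated whenever a new maximum is found, otherwise the last match wins instead of the best one. A standard-library sketch of the two steps, assuming hypothetical helpers <code>latest_cv_folder</code> and <code>best_checkpoint</code>, with checkpoints passed as plain dictionaries (as <code>torch.load</code> would return them) to keep it independent of PyTorch:</p>
</div>
<div class='code'>

```python
import os
import tempfile


def latest_cv_folder(root):
    # Return the most recently modified sub-folder whose name contains 'CV_models'
    latest_folder, latest_time = None, float("-inf")
    for item in os.listdir(root):
        if "CV_models" in item:
            folder = os.path.join(root, item)
            mtime = os.path.getmtime(folder)
            if mtime > latest_time:
                latest_folder, latest_time = folder, mtime  # update the running maximum
    return latest_folder


def best_checkpoint(checkpoints):
    # Return the checkpoint dict with the highest 'accuracy' and that accuracy
    best, best_accuracy = None, float("-inf")
    for checkpoint in checkpoints:
        if checkpoint["accuracy"] > best_accuracy:
            best, best_accuracy = checkpoint, checkpoint["accuracy"]
    return best, best_accuracy


with tempfile.TemporaryDirectory() as root:
    for i, name in enumerate(["CV_models-a", "CV_models-b", "other"]):
        path = os.path.join(root, name)
        os.mkdir(path)
        os.utime(path, (1000 + i, 1000 + i))  # explicit mtimes: CV_models-b is the newer CV folder
    newest = os.path.basename(latest_cv_folder(root))

# Made-up per-fold results for illustration
folds = [{"fold": 0, "accuracy": 0.81}, {"fold": 1, "accuracy": 0.86}, {"fold": 2, "accuracy": 0.84}]
best, best_accuracy = best_checkpoint(folds)
```

<p>Starting from <code>None</code> also makes the "no matching folder" case explicit instead of raising a <code>NameError</code> on an unbound variable.</p>
</div>
</div>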
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Return the best model together with its validation accuracy</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span>    <span class="k">return</span> <span class="n">best_model</span><span class="p">,</span> <span class="n">best_val_accuracy</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Compute the loss and accuracy of a single validation batch</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">164</span><span class="k">def</span> <span class="nf">val_step</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">cost</span><span class="p">,</span> <span class="n">images</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Forward pass</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">166</span>    <span class="n">output</span> <span class="o">=</span> <span class="n">net</span><span class="p">(</span><span class="n">images</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Loss in batch</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span>    <span class="n">loss</span> <span class="o">=</span> <span class="n">cost</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Compute the batch accuracy and return it together with the detached loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">171</span>    <span class="n">_</span><span class="p">,</span> <span class="n">preds</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">172</span>    <span class="n">acc</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">preds</span> <span class="o">==</span> <span class="n">labels</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">preds</span><span class="p">))</span>
<span class="lineno">173</span>    <span class="n">acc_output</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;val_loss&#39;</span><span class="p">:</span> <span class="n">loss</span><span class="o">.</span><span class="n">detach</span><span class="p">(),</span> <span class="s1">&#39;val_acc&#39;</span><span class="p">:</span> <span class="n">acc</span><span class="p">}</span>
<span class="lineno">174</span>    <span class="k">return</span> <span class="n">acc_output</span></pre></div>
</div>
</div>
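<div class='section' id='section-18-sketch'>
<div class='docs'>
<p>The accuracy in <code>val_step</code> is the fraction of rows whose arg-max logit equals the label. An equivalent computation on made-up logits, assuming <code>nn.CrossEntropyLoss</code> as the <code>cost</code>; here the mean is taken directly on a float tensor instead of going through <code>.item()</code> and <code>torch.tensor</code>:</p>
</div>
<div class='code'>

```python
import torch
import torch.nn as nn

# Made-up logits for 4 samples and 3 classes
output = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1],
                       [0.3, 0.2, 0.9],
                       [1.2, 0.4, 0.1]])
labels = torch.tensor([0, 1, 0, 0])

loss = nn.CrossEntropyLoss()(output, labels)  # mean loss over the batch
_, preds = torch.max(output, dim=1)           # index of the largest logit in each row
acc = (preds == labels).float().mean()        # fraction of correct predictions
result = {"val_loss": loss.detach(), "val_acc": acc}
```

</div>
</div>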
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Evaluate the network over a full test or validation loader and return its mean loss and accuracy</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">177</span><span class="k">def</span> <span class="nf">Test</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">cost</span><span class="p">,</span> <span class="n">testloader</span><span class="p">,</span> <span class="n">device</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Switch to evaluation mode: dropout is disabled and batch normalization uses its running statistics</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">179</span>    <span class="n">net</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span></pre></div>
</div>
</div>
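<div class='section' id='section-20-sketch'>
<div class='docs'>
<p><code>net.eval()</code> clears the <code>training</code> flag on every submodule, and layers such as dropout change behaviour accordingly. A small check using a dropout-only stand-in network (not the CNN used in this project):</p>
</div>
<div class='code'>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Dropout(p=0.5))  # stand-in network containing only a dropout layer
x = torch.ones(1, 1000)

net.train()
train_out = net(x)  # roughly half the entries are zeroed, the rest scaled by 1 / (1 - p) = 2
net.eval()
eval_out = net(x)   # in eval mode dropout passes the input through unchanged
```

<p>Note that <code>eval()</code> does not turn off gradient tracking; <code>Test</code> relies on the <code>torch.no_grad()</code> block for that.</p>
</div>
</div>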
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Bookkeeping</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">182</span>    <span class="n">correct</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="lineno">183</span>    <span class="n">total</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="lineno">184</span>    <span class="n">loss</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="lineno">185</span>    <span class="n">train_steps</span> <span class="o">=</span> <span class="mf">0.0</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Run inference batch by batch without tracking gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">188</span>    <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="lineno">189</span>        <span class="k">for</span> <span class="n">data</span> <span class="ow">in</span> <span class="n">testloader</span><span class="p">:</span>
<span class="lineno">190</span>            <span class="k">if</span> <span class="n">device</span><span class="p">:</span>
<span class="lineno">191</span>                <span class="n">images</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">192</span>            <span class="k">else</span><span class="p">:</span>
<span class="lineno">193</span>                <span class="n">images</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="lineno">194</span>
<span class="lineno">195</span>            <span class="n">outputs</span> <span class="o">=</span> <span class="n">net</span><span class="p">(</span><span class="n">images</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Accumulate the batch loss and count the batches</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span>            <span class="n">loss</span> <span class="o">+=</span> <span class="n">cost</span><span class="p">(</span><span class="n">outputs</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span>
<span class="lineno">198</span>            <span class="n">train_steps</span> <span class="o">+=</span> <span class="mi">1</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Count correct predictions in each batch; once the loop has finished, average the accumulated loss over the number of batches and compute the accuracy over all samples</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">201</span>            <span class="n">_</span><span class="p">,</span> <span class="n">predicted</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">outputs</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="lineno">202</span>            <span class="n">total</span> <span class="o">+=</span> <span class="n">labels</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="lineno">203</span>            <span class="n">correct</span> <span class="o">+=</span> <span class="p">(</span><span class="n">predicted</span> <span class="o">==</span> <span class="n">labels</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="lineno">204</span>    <span class="n">loss</span> <span class="o">=</span> <span class="n">loss</span> <span class="o">/</span> <span class="n">train_steps</span>
<span class="lineno">205</span>
<span class="lineno">206</span>    <span class="n">accuracy</span> <span class="o">=</span> <span class="n">correct</span> <span class="o">/</span> <span class="n">total</span>
<span class="lineno">207</span>    <span class="n">loss_accuracy</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;val_loss&#39;</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span> <span class="s1">&#39;val_acc&#39;</span><span class="p">:</span> <span class="n">accuracy</span><span class="p">}</span>
<span class="lineno">208</span>    <span class="k">return</span> <span class="n">loss_accuracy</span></pre></div>
</div>
</div>
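<div class='section' id='section-24-sketch'>
<div class='docs'>
<p><code>Test</code> divides the summed per-batch mean losses by the number of batches, so a smaller final batch counts as much as a full one and the reported loss can differ slightly from the per-sample mean (the accuracy is already per-sample). A sketch of a sample-weighted alternative, assuming a hypothetical <code>evaluate</code> helper, a <code>cost</code> that returns the batch mean (the default <code>reduction='mean'</code> of <code>nn.CrossEntropyLoss</code>), and random toy data; it is not part of the original utilities:</p>
</div>
<div class='code'>

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(net, cost, loader, device=None):
    # Sample-weighted variant of Test: each batch contributes loss * batch_size
    net.eval()
    total, correct, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for images, labels in loader:
            if device:
                images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            batch = labels.size(0)
            loss_sum += cost(outputs, labels).item() * batch
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch
    return {"val_loss": loss_sum / total, "val_acc": correct / total}


torch.manual_seed(0)
x, y = torch.randn(10, 3), torch.randint(0, 2, (10,))
net, cost = nn.Linear(3, 2), nn.CrossEntropyLoss()
loader = DataLoader(TensorDataset(x, y), batch_size=4)  # batches of 4, 4 and 2 samples
metrics = evaluate(net, cost, loader)
full_loss = cost(net(x), y).item()  # mean loss over all 10 samples at once
```

<p>Because every batch is weighted by its size, <code>val_loss</code> matches the loss computed over the whole dataset in one pass, up to floating-point rounding.</p>
</div>
</div>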
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
</body>
</html>