mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-06 15:22:21 +08:00
docs - cross-validation and early-stopping (#48)
Co-authored-by: sachdev.kartik <sachdev.kartik@gmail.com>
This commit is contained in:
BIN
docs/cnn/utils/Cross-validation.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 72 KiB |
BIN
docs/cnn/utils/Underfitting.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 11 KiB |
BIN
docs/cnn/utils/cv-folds.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 54 KiB |
@@ -72,7 +72,38 @@
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Cross-Validation & Early Stopping</h1>
<p>Implementation of two fundamental techniques, namely <em>Cross-Validation</em> and <em>Early Stopping</em>.</p>
<h3>Cross-Validation</h3>
<p>
Collecting data is expensive, and in some cases one has no option but to train a machine learning model on a limited amount of data.
This is where Cross-Validation is useful. The steps are as follows:
<ol type="1">
<li> Split the data into K folds </li>
<li> Use K-1 folds to train a set of models </li>
<li> Validate the models on the remaining fold </li>
<li> Repeat (2) and (3) for all the folds </li>
<li> Average the performance over all runs </li>
</ol>
</p>
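The steps above can be sketched in plain Python. This is a simplified illustration, not the `cross_val_train` implementation in this file; `train_fn` and `eval_fn` are hypothetical helpers standing in for the training and validation routines:

```python
def k_fold_average(data, k, train_fn, eval_fn):
    """Train on K-1 folds, validate on the held-out fold, average over all K runs."""
    fold_size = len(data) // k
    scores = []
    for i in range(k):
        val = data[i * fold_size:(i + 1) * fold_size]              # held-out fold
        train = data[:i * fold_size] + data[(i + 1) * fold_size:]  # remaining K-1 folds
        model = train_fn(train)                                    # step (2)
        scores.append(eval_fn(model, val))                         # step (3)
    return sum(scores) / len(scores)                               # step (5)
```

For example, with a toy "model" that just predicts the mean of its training data, each fold scores the gap between the training mean and the held-out fold's mean, and the returned value averages those K scores.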
<h3>Early-Stopping</h3>
Deep learning networks are prone to overfitting: although an overfitted model performs well on the training set, it generalizes poorly to unseen data.
In other words, overfitted models have low bias and high variance. The lower the bias, the better the model fits the training data; the higher the variance, the more sensitive the model is to the particular training data it saw.
<br>Formally, this can be written as:<br>
<p><script type="math/tex; mode=display"> loss = {bias}^2 + {variance} + noise </script></p>
<p>Therefore, the user has to find a trade-off between bias and variance.</p>
<p> </p>
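The decomposition can be illustrated numerically. The sketch below (an assumed toy setup, not part of this implementation) refits a trivial "model", the sample mean, on many freshly drawn training sets and measures its bias and variance empirically:

```python
import random

random.seed(0)
true_value = 2.0   # the quantity the "model" tries to predict
noise_std = 1.0    # observation noise

# Each trial draws a fresh 5-sample training set and fits the sample mean.
predictions = []
for _ in range(2000):
    sample = [true_value + random.gauss(0.0, noise_std) for _ in range(5)]
    predictions.append(sum(sample) / len(sample))

mean_pred = sum(predictions) / len(predictions)
bias_sq = (mean_pred - true_value) ** 2   # near zero: the sample mean is unbiased
variance = sum((p - mean_pred) ** 2 for p in predictions) / len(predictions)
# variance is close to noise_std**2 / 5, the theoretical variance of a 5-sample mean
```

Here bias² is essentially zero while the variance term dominates; a more constrained model would trade some bias for lower variance.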
<p> Early-Stopping is one way to find this trade-off. It helps to find a good setting of the parameters, preventing overfitting on the dataset while saving computation time.
This can be visualized through the following graph of training loss and validation loss over time: </p> <br>
<a href="https://www.deeplearningbook.org/contents/regularization.html"><img src="Cross-validation.png" alt="Training vs. validation set loss"></a>
<br>
<p> It can be seen that the training error continues to decrease, but the validation error starts to increase after around 40 epochs.
Therefore, our goal is to stop training once the validation loss begins to increase. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">3</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
@@ -97,7 +128,10 @@
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h3>Cross-Validation</h3>
<p> The splitting of the training set into folds can be represented as: </p>
<img src="cv-folds.png" alt="CV folds">
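The fold split pictured above can be sketched with plain index arithmetic. This is a simplified sketch (the `make_fold_indices` helper is hypothetical, not the `cross_val_train` code below); in the actual PyTorch code such index sets would typically be wrapped with something like `torch.utils.data.Subset`:

```python
def make_fold_indices(n_samples, n_splits):
    """Return one (train_indices, val_indices) pair per fold."""
    indices = list(range(n_samples))
    fold_size = n_samples // n_splits
    folds = []
    for i in range(n_splits):
        # the i-th contiguous slice is the validation fold ...
        val_idx = indices[i * fold_size:(i + 1) * fold_size]
        # ... and everything else is the training set
        train_idx = indices[:i * fold_size] + indices[(i + 1) * fold_size:]
        folds.append((train_idx, val_idx))
    return folds
```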
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">21</span><span class="k">def</span> <span class="nf">cross_val_train</span><span class="p">(</span><span class="n">cost</span><span class="p">,</span> <span class="n">trainset</span><span class="p">,</span> <span class="n">epochs</span><span class="p">,</span> <span class="n">splits</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
@@ -156,7 +190,7 @@
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>training steps</p>
<p>Training steps</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span> <span class="n">net</span><span class="o">.</span><span class="n">train</span><span class="p">()</span> <span class="c1"># Enable Dropout</span>
@@ -169,6 +203,7 @@
<a href='#section-4'>#</a>
</div>
<p>Get the inputs; data is a list of [inputs, labels]</p>
<p>Load the inputs on the GPU if available, else on the CPU</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="k">if</span> <span class="n">device</span><span class="p">:</span>
@@ -207,7 +242,7 @@
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Print loss</p>
<p>Calculate loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span> <span class="n">running_loss</span> <span class="o">+=</span> <span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
@@ -223,7 +258,7 @@
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Validation</p>
<p>Validation and printing the metrics</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">90</span> <span class="n">loss_accuracy</span> <span class="o">=</span> <span class="n">Test</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">cost</span><span class="p">,</span> <span class="n">valdata</span><span class="p">,</span> <span class="n">device</span><span class="p">)</span>
@@ -259,7 +294,17 @@
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Early stopping referred from https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py</p>
<h3>Early stopping</h3>
<p>Early stopping can be understood graphically, in terms of how the weights change during training.</p>
<ul>
<li> Solid contour lines indicate the contours of the negative log-likelihood (training error)</li>
<li> The dashed line indicates the trajectory taken by the optimizer</li>
<li> w∗ denotes the weight setting corresponding to the minimum training error </li>
<li> w denotes the final weight setting chosen after early-stopping </li>
</ul>
<a href="https://www.deeplearningbook.org/contents/regularization.html"><img src="early-stopping.png" alt="early-stopping" hspace="100" ></a> <!--align="middle"-->
<br>
<a href="https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py"><em>code reference here</em></a>
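The patience-based idea behind such implementations can be sketched as follows. This is a minimal assumed simplification, not the referenced pytorchtools code: training stops once the validation loss has failed to improve for a given number of consecutive epochs.

```python
def early_stop_epoch(val_losses, patience=3):
    """Return the epoch at which training would stop, or None if it never triggers.

    Stops once the validation loss has failed to improve for `patience`
    consecutive epochs.
    """
    best = float("inf")
    bad_epochs = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss       # new best validation loss: reset the counter
            bad_epochs = 0
        else:
            bad_epochs += 1   # no improvement this epoch
            if bad_epochs >= patience:
                return epoch
    return None
```

A larger `patience` tolerates noisy validation curves at the cost of extra training epochs.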
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span> <span class="k">if</span> <span class="n">losses</span><span class="p">[</span><span class="n">epoch</span><span class="p">]</span> <span class="o">></span> <span class="n">min_loss</span><span class="p">:</span>
@@ -313,7 +358,7 @@
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Retrieve the model which has the best accuracy over the validation set </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">138</span><span class="k">def</span> <span class="nf">retreive_best_trial</span><span class="p">():</span>
@@ -367,7 +412,7 @@
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>forward pass</p>
<p>Forward pass</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">166</span> <span class="n">output</span> <span class="o">=</span> <span class="n">net</span><span class="p">(</span><span class="n">images</span><span class="p">)</span></pre></div>
@@ -378,7 +423,7 @@
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>loss in batch</p>
<p>Loss in batch</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">cost</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span></pre></div>
@@ -389,7 +434,7 @@
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>update validation loss</p>
<p>Update validation loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">171</span> <span class="n">_</span><span class="p">,</span> <span class="n">preds</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
@@ -457,7 +502,7 @@
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>loss in batch</p>
<p>Loss in batch</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">loss</span> <span class="o">+=</span> <span class="n">cost</span><span class="p">(</span><span class="n">outputs</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span>
@@ -469,7 +514,7 @@
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>losses[epoch] += loss.item()</p>
<p>Calculate loss and accuracy over the validation set</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">201</span> <span class="n">_</span><span class="p">,</span> <span class="n">predicted</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">outputs</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
BIN
docs/cnn/utils/early-stopping.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 29 KiB |
BIN
docs/cnn/utils/ground_truth.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 27 KiB |
BIN
docs/cnn/utils/overfitting.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 12 KiB |