diff --git a/docs/cnn/utils/Cross-validation.png b/docs/cnn/utils/Cross-validation.png
new file mode 100644
index 00000000..f3a79622
Binary files /dev/null and b/docs/cnn/utils/Cross-validation.png differ
diff --git a/docs/cnn/utils/Underfitting.png b/docs/cnn/utils/Underfitting.png
new file mode 100644
index 00000000..c3ee9a2e
Binary files /dev/null and b/docs/cnn/utils/Underfitting.png differ
diff --git a/docs/cnn/utils/cv-folds.png b/docs/cnn/utils/cv-folds.png
new file mode 100644
index 00000000..4fd0b22b
Binary files /dev/null and b/docs/cnn/utils/cv-folds.png differ
diff --git a/docs/cnn/utils/cv_train.html b/docs/cnn/utils/cv_train.html
index 1bba00d4..29f6db26 100644
--- a/docs/cnn/utils/cv_train.html
+++ b/docs/cnn/utils/cv_train.html
@@ -72,7 +72,38 @@
-
+

Cross-Validation & Early Stopping

+

Implementation of two fundamental techniques: Cross-Validation and Early Stopping. +

Cross-Validation

+

+ Collecting data is expensive, and in some cases one has no option but to train a machine learning model on a limited amount of data. + This is where Cross-Validation is useful. The steps are as follows (a code sketch follows the list): +

  1. Split the data into K folds
  2. Use K-1 folds to train a set of models
  3. Validate the models on the remaining fold
  4. Repeat steps (2) and (3) for all the folds
  5. Average the performance over all runs
+
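A minimal sketch of this procedure in plain Python is given below. It is illustrative only: make_model, train_one_fold, and evaluate are hypothetical placeholders, not functions from this repository (the actual implementation is cross_val_train further down).

import numpy as np

def k_fold_cross_validate(dataset, k, make_model, train_one_fold, evaluate):
    # 1. split the shuffled sample indices into K folds
    indices = np.random.permutation(len(dataset))
    folds = np.array_split(indices, k)
    scores = []
    for i in range(k):
        # 2./3. train a fresh model on K-1 folds, validate on the held-out fold
        val_idx = folds[i]
        train_idx = np.concatenate([f for j, f in enumerate(folds) if j != i])
        model = make_model()
        train_one_fold(model, dataset, train_idx)
        scores.append(evaluate(model, dataset, val_idx))
    # 5. average the performance over all runs
    return sum(scores) / k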

+

Early-Stopping

+ Deep learning networks are prone to overfitting: although overfitted models perform well on the training set, they have poor generalization capabilities. + In other words, overfitted models have low bias and high variance. The lower the bias, the better the model can fit the data; the higher the variance, the more sensitive the model is to the training data. +
Formally, it can be represented by the bias-variance decomposition of the expected error:

$$\mathbb{E}\big[(y - \hat{f}(x))^2\big] = \mathrm{Bias}\big[\hat{f}(x)\big]^2 + \mathrm{Var}\big[\hat{f}(x)\big] + \sigma^2$$

where $\sigma^2$ is the irreducible noise in the data.
Therefore, the user has to find a tradeoff between bias and variance.

+

+
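As an assumed illustration (not code from this repository), fitting polynomials of increasing degree to noisy data makes the tradeoff concrete: a low-degree fit underfits (high bias), while a high-degree fit chases the noise (high variance).

import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(0, 1, 30)
y = np.sin(2 * np.pi * x) + rng.normal(0, 0.2, x.shape)  # noisy samples of a sine

for degree in (1, 3, 9):
    coeffs = np.polyfit(x, y, degree)                     # fit a polynomial model
    mse = np.mean((np.polyval(coeffs, x) - y) ** 2)       # error on the training points
    print(f"degree={degree}  train MSE={mse:.4f}")
# Training error shrinks as the degree grows, but the degree-9 fit is the most
# sensitive to this particular noise draw, i.e. it has the highest variance.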

+ Early-Stopping is one way to find this tradeoff. It helps to find a good setting of the parameters, prevents overfitting on the dataset, and saves computation time. + This can be visualized through the following graph of training loss and validation loss over time:


+ + + Training vs. Validation set Loss +
+

It can be seen that the training error continues to decrease, but the validation error starts to increase after around 40 epochs. + Therefore, our goal is to stop training once the validation loss starts increasing.

+ +
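A minimal sketch of this stopping rule with a patience counter follows; train_epoch and validate are hypothetical placeholders, and the repository's actual implementation appears further below.

def train_with_early_stopping(net, epochs, patience, train_epoch, validate):
    best_loss = float("inf")
    epochs_without_improvement = 0
    for epoch in range(epochs):
        train_epoch(net)                       # one pass over the training set
        val_loss = validate(net)               # loss on the held-out validation set
        if val_loss < best_loss:               # validation loss improved
            best_loss = val_loss
            epochs_without_improvement = 0
        else:                                  # no improvement this epoch
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print(f"Stopping early at epoch {epoch}")
                break
    return net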

+
import torch
@@ -97,7 +128,10 @@
                 
-                
+                    

Cross-Validation

+

The splitting of the training set into folds can be represented as:

+ CV folds +
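For illustration only (an assumption, not necessarily how the repository splits the data), the same fold structure can be produced with scikit-learn's KFold:

import numpy as np
from sklearn.model_selection import KFold

kf = KFold(n_splits=5, shuffle=True, random_state=0)
for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(100))):
    print(f"fold {fold}: {len(train_idx)} train / {len(val_idx)} val samples")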
def cross_val_train(cost, trainset, epochs, splits, device=None):
@@ -156,7 +190,7 @@
                 
-                

training steps

+

Training steps

            net.train()  # Enable Dropout
@@ -169,6 +203,7 @@
                     #
                 

Get the inputs; data is a list of [inputs, labels]

+

Load the inputs onto the GPU if available, else the CPU

                if device:
@@ -207,7 +242,7 @@
                 
-                

Print loss

+

Calculate loss

                running_loss += loss.item()
@@ -223,7 +258,7 @@
                 
-                

Validation

+

Validation and printing the metrics

            loss_accuracy = Test(net, cost, valdata, device)
@@ -259,7 +294,17 @@
                 
-                

Early stopping refered from https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py

+

Early stopping

+

Early stopping can also be understood graphically, by looking at how the weights change during the course of training.

+
  • Solid contour lines indicate the contours of the negative log-likelihood (training error)
  • The dashed line indicates the trajectory taken by the optimizer
  • w∗ denotes the weight setting corresponding to the minimum training error
  • w̃ denotes the final weight setting chosen by early stopping
+ early-stopping +
+ Code reference: https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
            if losses[epoch] > min_loss:
@@ -313,7 +358,7 @@
                 
-                
+                

Retrieve the model which has the best accuracy over the validation set

def retreive_best_trial():
@@ -367,7 +412,7 @@
                 
-                

forward pass

+

Forward pass

    output = net(images)
@@ -378,7 +423,7 @@
-

loss in batch

+

Loss in batch

    loss = cost(output, labels)
@@ -389,7 +434,7 @@
-

update validation loss

+

Update validation loss

    _, preds = torch.max(output, dim=1)
@@ -457,7 +502,7 @@
                 
-                

loss in batch

+

Loss in batch

            loss += cost(outputs, labels)
@@ -469,7 +514,7 @@
                 
-                

losses[epoch] += loss.item()

+

Calculate loss and accuracy over the validation set

            _, predicted = torch.max(outputs.data, 1)
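For context, the accuracy here comes from taking the argmax over the class logits; a self-contained sketch with assumed names (not the repository's exact code):

import torch

def accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float:
    # outputs: [batch, num_classes] logits; labels: [batch] class indices
    _, predicted = torch.max(outputs, dim=1)   # index of the highest logit
    return (predicted == labels).sum().item() / labels.size(0)

# usage: acc = accuracy(net(images), labels)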
diff --git a/docs/cnn/utils/early-stopping.png b/docs/cnn/utils/early-stopping.png
new file mode 100644
index 00000000..53b82f1d
Binary files /dev/null and b/docs/cnn/utils/early-stopping.png differ
diff --git a/docs/cnn/utils/ground_truth.png b/docs/cnn/utils/ground_truth.png
new file mode 100644
index 00000000..c4db603e
Binary files /dev/null and b/docs/cnn/utils/ground_truth.png differ
diff --git a/docs/cnn/utils/overfitting.png b/docs/cnn/utils/overfitting.png
new file mode 100644
index 00000000..a38ba41c
Binary files /dev/null and b/docs/cnn/utils/overfitting.png differ