diff --git a/docs/cnn/utils/Cross-validation.png b/docs/cnn/utils/Cross-validation.png
new file mode 100644
index 00000000..f3a79622
Binary files /dev/null and b/docs/cnn/utils/Cross-validation.png differ
diff --git a/docs/cnn/utils/Underfitting.png b/docs/cnn/utils/Underfitting.png
new file mode 100644
index 00000000..c3ee9a2e
Binary files /dev/null and b/docs/cnn/utils/Underfitting.png differ
diff --git a/docs/cnn/utils/cv-folds.png b/docs/cnn/utils/cv-folds.png
new file mode 100644
index 00000000..4fd0b22b
Binary files /dev/null and b/docs/cnn/utils/cv-folds.png differ
diff --git a/docs/cnn/utils/cv_train.html b/docs/cnn/utils/cv_train.html
index 1bba00d4..29f6db26 100644
--- a/docs/cnn/utils/cv_train.html
+++ b/docs/cnn/utils/cv_train.html
@@ -72,7 +72,38 @@
-
+Implementation of two fundamental techniques, namely Cross-Validation and Early Stopping
+
+ Getting data is expensive, and in some cases one has no option but to use a limited amount of data to train a machine learning model.
+ This is where Cross-Validation is useful. The steps are as follows:
+
Therefore, the user has to find a tradeoff between bias and variance.
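For reference, this tradeoff can be made precise with the textbook bias-variance decomposition of the expected squared error (a standard identity, not something from this file); here f is the true function, f-hat the learned model, and sigma^2 the irreducible noise variance:

```latex
\mathbb{E}\big[(y - \hat{f}(x))^2\big]
  = \underbrace{\big(\mathbb{E}[\hat{f}(x)] - f(x)\big)^2}_{\text{bias}^2}
  + \underbrace{\mathbb{E}\big[(\hat{f}(x) - \mathbb{E}[\hat{f}(x)])^2\big]}_{\text{variance}}
  + \sigma^2
```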
++
Early Stopping is one way to find this tradeoff. It helps to find a good setting of the parameters, prevents overfitting on the dataset, and saves computation time.
+ This can be visualized through the following graph of training loss and validation loss over time:
It can be seen that the training error continues to decrease, but the validation error starts to increase after around 40 epochs.
+ Therefore, our goal is to stop the training once the validation loss starts to increase.
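A minimal sketch of this stopping rule, assuming hypothetical `train_one_epoch` and `validate` helpers that each return an average loss for one epoch (this is only an illustration of the idea, not the `cross_val_train` implementation documented below):

```python
import torch

def train_with_early_stopping(net, train_one_epoch, validate, max_epochs=100, patience=5):
    """Stop training once the validation loss has not improved for `patience` epochs."""
    best_val_loss = float("inf")
    epochs_without_improvement = 0
    for epoch in range(max_epochs):
        train_one_epoch(net)                                 # one pass over the training set
        val_loss = validate(net)                             # average loss on the validation set
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            torch.save(net.state_dict(), "checkpoint.pt")    # remember the best weights so far
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:       # validation loss keeps rising: stop
                break
    net.load_state_dict(torch.load("checkpoint.pt"))         # restore the best weights
    return net
```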
+ + +3import torch
@@ -97,7 +128,10 @@
-
+ Cross-Validation
+ Splitting of the training set into folds can be represented as:
+
+
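As a rough sketch of how such a fold split could be produced with plain PyTorch (the `make_folds` helper is hypothetical and separate from the `cross_val_train` code documented below):

```python
import torch
from torch.utils.data import Subset

def make_folds(dataset, splits):
    """Yield (train_subset, val_subset) pairs, using each fold once as the validation set."""
    indices = torch.randperm(len(dataset)).tolist()          # shuffle the sample indices
    fold_size = len(dataset) // splits
    for k in range(splits):
        val_idx = indices[k * fold_size:(k + 1) * fold_size]              # k-th fold for validation
        train_idx = indices[:k * fold_size] + indices[(k + 1) * fold_size:]  # the rest for training
        yield Subset(dataset, train_idx), Subset(dataset, val_idx)

# Usage sketch: for train_ds, val_ds in make_folds(trainset, splits=5): ...
```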
21def cross_val_train(cost, trainset, epochs, splits, device=None):
@@ -156,7 +190,7 @@
- training steps
+ Training steps
65 net.train() # Enable Dropout
@@ -169,6 +203,7 @@
#
Get the inputs; data is a list of [inputs, labels]
+Load the inputs on the GPU if available, otherwise on the CPU
68 if device:
@@ -207,7 +242,7 @@
- Print loss
+ Calculate loss
82 running_loss += loss.item()
@@ -223,7 +258,7 @@
- Validation
+ Validation and printing the metrics
90 loss_accuracy = Test(net, cost, valdata, device)
@@ -259,7 +294,17 @@
- Early stopping refered from https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
+ Early stopping
+ Early stopping can be understood graphically, by looking at how the weights change during the course of training.
+
110 if losses[epoch] > min_loss:
@@ -313,7 +358,7 @@
-
+ Retrieve the model which has the best accuracy over the validation set
138def retreive_best_trial():
@@ -367,7 +412,7 @@
- forward pass
+ Forward pass
168 loss = cost(output, labels)
update validation loss
+Update validation loss
171 _, preds = torch.max(output, dim=1)
@@ -457,7 +502,7 @@
- loss in batch
+ Loss in batch
197 loss += cost(outputs, labels)
@@ -469,7 +514,7 @@
- losses[epoch] += loss.item()
+ Calculate loss and accuracy over the validation set
201 _, predicted = torch.max(outputs.data, 1)
diff --git a/docs/cnn/utils/early-stopping.png b/docs/cnn/utils/early-stopping.png
new file mode 100644
index 00000000..53b82f1d
Binary files /dev/null and b/docs/cnn/utils/early-stopping.png differ
diff --git a/docs/cnn/utils/ground_truth.png b/docs/cnn/utils/ground_truth.png
new file mode 100644
index 00000000..c4db603e
Binary files /dev/null and b/docs/cnn/utils/ground_truth.png differ
diff --git a/docs/cnn/utils/overfitting.png b/docs/cnn/utils/overfitting.png
new file mode 100644
index 00000000..a38ba41c
Binary files /dev/null and b/docs/cnn/utils/overfitting.png differ