Cross-Validation & Early Stopping

Implementations of two fundamental techniques: Cross-Validation and Early Stopping.

Cross-Validation

Collecting data is expensive, and in some cases one has no option but to train a machine learning model on a limited amount of data. This is where Cross-Validation is useful. The steps are as follows:

  1. Split the data into K folds
  2. Use K-1 folds to train a model
  3. Validate the model on the remaining fold
  4. Repeat steps 2 and 3 so that each fold is used as the validation fold once
  5. Average the performance over all runs (a minimal sketch follows this list)
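
As a minimal, self-contained sketch of these steps (separate from the PyTorch implementation further down), the loop can be written with scikit-learn's KFold; the random toy data and the LogisticRegression model here are placeholders, not part of this repository:

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold

# Toy stand-in data (placeholder): 100 samples, 8 features, binary labels
X = np.random.rand(100, 8)
y = np.random.randint(0, 2, size=100)

kf = KFold(n_splits=5, shuffle=True)              # Step 1: split into K folds
scores = []
for train_index, val_index in kf.split(X):
    model = LogisticRegression()
    model.fit(X[train_index], y[train_index])     # Step 2: train on K-1 folds
    scores.append(model.score(X[val_index], y[val_index]))  # Step 3: validate on the held-out fold
# Steps 4-5: the loop visits every fold once; average the scores
print("Mean CV accuracy: %.3f" % np.mean(scores))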

Early-Stopping

Deep learning networks are prone to overfitting: although an overfitted model performs well on the training set, it has poor generalization capabilities. In other words, overfitted models have low bias and high variance. The lower the bias, the higher the capability of the model to fit the data; the higher the variance, the higher the sensitivity of the model to the particular training data it saw.
Formally, the expected squared error of a predictor ŷ(x) for a target y can be decomposed as:

E[(y − ŷ(x))²] = Bias[ŷ(x)]² + Var[ŷ(x)] + σ²

where σ² is the irreducible noise in the data. Therefore, the user has to find a tradeoff between bias and variance.

Early-Stopping is one way to find this tradeoff. It helps to find a good setting of the parameters while preventing overfitting on the dataset and saving computation time. This can be visualized through the following graph of training loss and validation loss over time:


Training vs. Validation Set Loss

It can be seen that the training error continues to decrease, but the validation error starts to increase after around 40 epochs. Therefore, our goal is to stop training once the validation loss starts to increase.
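
As a minimal sketch of this stopping rule (the full version, with checkpoint saving, appears in cross_val_train below), a patience counter on the validation loss is enough; train_one_epoch and validate_one_epoch are hypothetical callables standing in for the real training and validation code:

def fit_with_early_stopping(train_one_epoch, validate_one_epoch, max_epochs=100, patience=4):
    # train_one_epoch() / validate_one_epoch() are hypothetical and return the epoch's losses
    best_val_loss = None
    counter = 0
    for epoch in range(max_epochs):
        train_loss = train_one_epoch()
        val_loss = validate_one_epoch()
        if best_val_loss is None or val_loss <= best_val_loss:
            best_val_loss = val_loss   # improvement: remember it (and save a checkpoint here)
            counter = 0
        else:
            counter += 1               # validation loss got worse for another epoch
            if counter >= patience:
                print("Early stopping at epoch %d" % (epoch + 1))
                break

The rest of this document walks through the full implementation, starting with the imports.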

import os
from datetime import datetime
from glob import glob

import torch
import torch.optim as optim
from torch.utils.data import Subset
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import KFold
from torchsummary import summary

from models.cnn import GetCNN

Cross-Validation

The splitting of the training set into folds can be represented as:

CV folds
def cross_val_train(cost, trainset, epochs, splits, device=None):

    patience = 4
    history = []
    kf = KFold(n_splits=splits, shuffle=True)
    batch_size = 64
    now = datetime.now()
    date_time = now.strftime("%d-%m-%Y_%H:%M:%S")
    directory = os.path.dirname('./save/tensorboard-%s/' % (date_time))

    if not os.path.exists(directory):
        os.mkdir(directory)

    for fold, (train_index, test_index) in enumerate(kf.split(trainset.data, trainset.targets)):  # requires the complete training set
        comment = f'{directory}/fold-{fold}'
        writer = SummaryWriter(log_dir=comment)

        train_sampler = SubsetRandomSampler(train_index)
        valid_sampler = SubsetRandomSampler(test_index)
        traindata = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler=train_sampler,
                                                num_workers=2)
        valdata = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler=valid_sampler,
                                              num_workers=2)

        net = GetCNN()
        net.to(device)
        if fold == 0:  # Print the model details only for the first fold
            summary(net, (3, 32, 32))

Specify optimizer

        optimizer = optim.Adam(net.parameters(), lr=0.0005, betas=(0.9, 0.95))
        losses = torch.zeros(epochs)
        accuracies = torch.zeros(epochs)
        min_loss = None
        count = 0
        for epoch in range(epochs):
            valid_loss = 0
            running_loss = 0.0
            epoch_loss = 0.0
            train_loss = torch.zeros(epochs)
            train_steps = 0.0

Training steps

            net.train()  # Enable Dropout
            for i, data in enumerate(traindata, 0):

Get the inputs; data is a list of [inputs, labels]

Load the inputs on the GPU if available, otherwise on the CPU

                if device:
                    images, labels = data[0].to(device), data[1].to(device)
                else:
                    images, labels = data

Forward + backward + optimize

                outputs = net(images)
                loss = cost(outputs, labels)
                loss.backward()
                optimizer.step()

Zero the parameter gradients

                optimizer.zero_grad()

Calculate loss

                running_loss += loss.item()
                epoch_loss += loss.item()
                train_loss[epoch] += loss.item()
                train_steps += 1

            loss_train = train_loss[epoch] / train_steps

Validation and printing the metrics

            loss_accuracy = Test(net, cost, valdata, device)

            losses[epoch] = loss_accuracy['val_loss']
            accuracies[epoch] = loss_accuracy['val_acc']
            print("Fold %d, Epoch %d, Train Loss: %.4f, Validation Loss: %.4f, Validation Accuracy: %.4f" % (fold+1, epoch+1, loss_train, losses[epoch], accuracies[epoch]))

TensorBoard

            info = {
                "Loss/train": loss_train,
                "Loss/valid": losses[epoch],
                "Accuracy/valid": accuracies[epoch]
            }

            for tag, item in info.items():
                writer.add_scalar(tag, item, global_step=epoch)

            if min_loss is None:
                min_loss = losses[epoch]
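
The scalars logged above can be inspected by pointing TensorBoard at the save directory used in this function, e.g. tensorboard --logdir ./save, which then lists each fold as a separate run.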

Early stopping

Early stopping can also be understood graphically, in terms of how the weights change during the course of training:

  • Solid contour lines indicate the contours of the negative log-likelihood (training error)
  • The dashed line indicates the trajectory taken by the optimizer
  • w∗ denotes the weight setting corresponding to the minimum training error
  • w denotes the final weight setting chosen by the model after early stopping
early-stopping
code reference here
            if losses[epoch] > min_loss:
                print("Epoch loss: %.4f, Min loss: %.4f" % (losses[epoch], min_loss))
                count += 1
                print(f'Early stopping counter: {count} out of {patience}')
                if count >= patience:
                    print('############### EarlyStopping ##################')
                    break

Saving best model

            elif losses[epoch] <= min_loss:
                count = 0
                save_best_model({
                    'epoch': epoch,
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'accuracy': accuracies[epoch]
                }, fold=fold, date_time=date_time)
                min_loss = losses[epoch]

            history.append({'val_loss': losses[epoch], 'val_acc': accuracies[epoch]})
    return history


def save_best_model(state, fold, date_time):
    directory = os.path.dirname("./save/CV_models-%s/" % (date_time))
    if not os.path.exists(directory):
        os.mkdir(directory)
    torch.save(state, "%s/fold-%d-model.pt" % (directory, fold))
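
A minimal usage sketch for cross_val_train: the CIFAR-10 dataset below is an assumption for illustration (suggested by the 3x32x32 input shape passed to summary), not something fixed by this module.

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# Assumed example dataset; any torchvision-style dataset exposing .data and .targets works
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                        transform=transforms.ToTensor())

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cost = nn.CrossEntropyLoss()
history = cross_val_train(cost, trainset, epochs=50, splits=5, device=device)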

Retrieve the model which has the best accuracy over the validation set

def retreive_best_trial():
    PATH = "./save/"
    best_model = GetCNN()

    # Find the most recently modified CV_models-* folder
    content = os.listdir(PATH)
    latest_time = 0
    latest_folder = None
    for item in content:
        if 'CV_models' in item:
            foldername = os.path.join(PATH, item)
            tm = os.path.getmtime(foldername)
            if tm > latest_time:
                latest_time = tm
                latest_folder = foldername

    file_type = '/*.pt'
    files = glob(latest_folder + file_type)

    # Keep the checkpoint with the highest validation accuracy across the folds
    best_val_accuracy = 0
    for model_file in files:
        checkpoint = torch.load(model_file)
        if checkpoint['accuracy'] > best_val_accuracy:
            best_model.load_state_dict(checkpoint['state_dict'])
            best_val_accuracy = checkpoint['accuracy']

    # Test(best_model,)
    return best_model, best_val_accuracy


def val_step(net, cost, images, labels):

Forward pass

    output = net(images)

Loss in batch

    loss = cost(output, labels)

Compute the batch accuracy and package the validation metrics

    _, preds = torch.max(output, dim=1)
    acc = torch.tensor(torch.sum(preds == labels).item() / len(preds))
    acc_output = {'val_loss': loss.detach(), 'val_acc': acc}
    return acc_output

Evaluation loop over a testloader/valloader

def Test(net, cost, testloader, device):

Disable Dropout

    net.eval()

Bookkeeping

    correct = 0.0
    total = 0.0
    loss = 0.0
    train_steps = 0.0

Infer the model

    with torch.no_grad():
        for data in testloader:
            if device:
                images, labels = data[0].to(device), data[1].to(device)
            else:
                images, labels = data[0], data[1]

            outputs = net(images)

Loss in batch

            loss += cost(outputs, labels)
            train_steps += 1

Calculate loss and accuracy over the validation set

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        loss = loss / train_steps

    accuracy = correct / total
    loss_accuracy = {'val_loss': loss, 'val_acc': accuracy}
    return loss_accuracy
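
Tying it together, one possible evaluation flow after cross-validation is to retrieve the best checkpoint and run Test on a held-out test loader; as above, the CIFAR-10 test set is only an assumed example:

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# Assumed held-out test set for illustration
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
                                       transform=transforms.ToTensor())
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
best_model, best_val_accuracy = retreive_best_trial()
best_model.to(device)

results = Test(best_model, nn.CrossEntropyLoss(), testloader, device)
print("Best validation accuracy: %.4f, Test accuracy: %.4f" % (best_val_accuracy, results['val_acc']))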