mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-04 06:16:05 +08:00
Implemented Cross Validation & Early Stopping (#38)
labml_nn/cnn/cross_validation.py (new file)
import torch
import torch.nn as nn
import torchvision.transforms as transforms

from utils.dataloader import *
from utils.cv_train import *
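
# LoadCifar10DatasetTrain/LoadCifar10DatasetTest presumably come from
# utils.dataloader; cross_val_train, retreive_best_trial and Test are
# defined in utils/cv_train.py (added below in this commit).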

# Check if a GPU is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device: " + str(device))

# CIFAR-10 dataset location
save = './data/Cifar10'

# Training transformations
transform_train = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Load the training dataset and dataloader
trainset = LoadCifar10DatasetTrain(save, transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=4)

# Test transformations (for inference later)
transform_test = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Load the test dataset and dataloader (for inference later)
testset = LoadCifar10DatasetTest(save, transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=4)

# Loss function
cost = nn.CrossEntropyLoss()

epochs = 25
splits = 4

# Training with k-fold cross-validation
history = cross_val_train(cost, trainset, epochs, splits, device=device)
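# history holds one {'val_loss', 'val_acc'} entry per fold (final-epoch values)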

# Inference: pick the best checkpoint across folds
best_model, best_val_accuracy = retreive_best_trial()
print("Best Validation Accuracy = %.3f" % (best_val_accuracy))

# Testing
accuracy = Test(best_model, cost, testloader, device=device)
print("Test Accuracy = %.3f" % (accuracy['val_acc']))

labml_nn/cnn/utils/cv_train.py (new file)
#!/usr/bin/env python3

import os
from datetime import datetime
from glob import glob

import torch
import torch.optim as optim
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
from sklearn.model_selection import KFold

from models.cnn import GetCNN
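

# k-fold cross-validation training: each fold trains a fresh CNN with Adam,
# validates after every epoch, logs to TensorBoard, checkpoints whenever the
# validation loss improves, and stops early after `patience` epochs without
# improvement.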
def cross_val_train(cost, trainset, epochs, splits, device=None):
    patience = 4
    history = []
    kf = KFold(n_splits=splits, shuffle=True)
    batch_size = 64
    now = datetime.now()
    date_time = now.strftime("%d-%m-%Y_%H:%M:%S")
    directory = os.path.dirname('./save/tensorboard-%s/' % (date_time))

    if not os.path.exists(directory):
        os.makedirs(directory)  # makedirs also creates the parent ./save

    for fold, (train_index, test_index) in enumerate(kf.split(trainset.data, trainset.targets)):  # requires the complete training set
        writer = SummaryWriter(log_dir=f'{directory}/fold-{fold}')

        # Restrict each DataLoader to this fold's train/validation indices
        train_sampler = SubsetRandomSampler(train_index)
        valid_sampler = SubsetRandomSampler(test_index)
        traindata = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler=train_sampler,
                                                num_workers=2)
        valdata = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler=valid_sampler,
                                              num_workers=2)

        # Fresh model for every fold
        net = GetCNN()
        net.to(device)
        if fold == 0:  # Print the model details once
            summary(net, (3, 32, 32))

        # Optimizer
        optimizer = optim.Adam(net.parameters(), lr=0.0005, betas=(0.9, 0.95))
        losses = torch.zeros(epochs)
        accuracies = torch.zeros(epochs)
        min_loss = None
        count = 0  # early-stopping counter
        for epoch in range(epochs):
            epoch_loss = 0.0
            train_steps = 0

            # Training steps
            net.train()  # Training mode (enables dropout, etc.)
            for i, data in enumerate(traindata, 0):
                # Get the inputs; data is a list of [inputs, labels]
                if device:
                    images, labels = data[0].to(device), data[1].to(device)
                else:
                    images, labels = data

                # Zero the parameter gradients, then forward + backward + optimize
                optimizer.zero_grad()
                outputs = net(images)
                loss = cost(outputs, labels)
                loss.backward()
                optimizer.step()

                # Accumulate the training loss
                epoch_loss += loss.item()
                train_steps += 1

            loss_train = epoch_loss / train_steps

            # Validation on the held-out split
            loss_accuracy = Test(net, cost, valdata, device)

            losses[epoch] = loss_accuracy['val_loss']
            accuracies[epoch] = loss_accuracy['val_acc']
            print("Fold %d, Epoch %d, Train Loss: %.4f, Validation Loss: %.4f, Validation Accuracy: %.4f"
                  % (fold + 1, epoch + 1, loss_train, losses[epoch], accuracies[epoch]))

            # TensorBoard logging
            info = {
                "Loss/train": loss_train,
                "Loss/valid": losses[epoch],
                "Accuracy/valid": accuracies[epoch]
            }
            for tag, item in info.items():
                writer.add_scalar(tag, item, global_step=epoch)
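
            # Track the minimum validation loss seen so far; on improvement,
            # reset the counter and checkpoint, otherwise count towards the
            # early-stopping patience.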
            if min_loss is None:
                min_loss = losses[epoch]

            # Early stopping, adapted from https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
            if losses[epoch] > min_loss:
                print("Epoch loss: %.4f, Min loss: %.4f" % (losses[epoch], min_loss))
                count += 1
                print(f'Early stopping counter: {count} out of {patience}')
                if count >= patience:
                    print('############### EarlyStopping ##################')
                    break

            # Save the best model so far for this fold
            else:
                count = 0
                save_best_model({
                    'epoch': epoch,
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'accuracy': accuracies[epoch]
                }, fold=fold, date_time=date_time)
                min_loss = losses[epoch]

        history.append({'val_loss': losses[epoch], 'val_acc': accuracies[epoch]})
    return history
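

# Save a checkpoint for this fold under ./save/CV_models-<timestamp>/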
def save_best_model(state, fold, date_time):
    directory = os.path.dirname("./save/CV_models-%s/" % (date_time))
    if not os.path.exists(directory):
        os.makedirs(directory)
    torch.save(state, "%s/fold-%d-model.pt" % (directory, fold))
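

# Scan the most recent CV_models-* folder and return the checkpoint (model
# and validation accuracy) with the highest recorded accuracy across folds.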
def retreive_best_trial():
    PATH = "./save/"
    best_model = GetCNN()

    # Find the most recently modified CV_models-* folder
    content = os.listdir(PATH)
    latest_time = 0
    latest_folder = None
    for item in content:
        if 'CV_models' in item:
            foldername = os.path.join(PATH, item)
            tm = os.path.getmtime(foldername)
            if tm > latest_time:
                latest_time = tm
                latest_folder = foldername

    files = glob(latest_folder + '/*.pt')

    # Load the checkpoint with the highest validation accuracy
    best_val_accuracy = 0
    for model_file in files:
        checkpoint = torch.load(model_file)
        if checkpoint['accuracy'] > best_val_accuracy:
            best_model.load_state_dict(checkpoint['state_dict'])
            best_val_accuracy = checkpoint['accuracy']

    return best_model, best_val_accuracy
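

# Loss and accuracy for a single validation batch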
def val_step(net, cost, images, labels):
    # Forward pass
    output = net(images)
    # Loss on this batch
    loss = cost(output, labels)

    # Accuracy on this batch
    _, preds = torch.max(output, dim=1)
    acc = torch.tensor(torch.sum(preds == labels).item() / len(preds))
    acc_output = {'val_loss': loss.detach(), 'val_acc': acc}
    return acc_output


# Evaluation loop over a test or validation DataLoader
def Test(net, cost, testloader, device):
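    # Runs the model in eval mode with gradients disabled; returns
    # {'val_loss': mean batch loss, 'val_acc': overall accuracy}.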
    net.eval()

    # Bookkeeping
    correct = 0.0
    total = 0.0
    loss = 0.0
    test_steps = 0

    # Run inference
    with torch.no_grad():
        for data in testloader:
            if device:
                images, labels = data[0].to(device), data[1].to(device)
            else:
                images, labels = data[0], data[1]

            outputs = net(images)
            # Accumulate the batch loss
            loss += cost(outputs, labels)
            test_steps += 1

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss = loss / test_steps

    accuracy = correct / total
    loss_accuracy = {'val_loss': loss, 'val_acc': accuracy}
    return loss_accuracy