implemented resnets and transfer learning

This commit is contained in:
sachdev.kartik
2021-02-25 04:36:37 +01:00
committed by Varuna Jayasiri
parent 4f31570f92
commit 94796efa44
11 changed files with 640 additions and 0 deletions

View File

@ -0,0 +1,117 @@
import torch
from torch.utils.data import DataLoader, ConcatDataset
# from sklearn.model_selection import KFold
# from torch.utils.data.sampler import SubsetRandomSampler
import matplotlib.pyplot as plt
from pylab import *
import os
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
class Trainer():
def __init__(self, net, opt, cost, name="default", lr=0.0005, use_lr_schedule =False , device=None):
self.net = net
self.opt = opt
self.cost = cost
self.device = device
self.epoch = 0
self.start_epoch = 0
self.name = name
self.lr = lr
self.use_lr_schedule = use_lr_schedule
if self.use_lr_schedule:
self.scheduler = ReduceLROnPlateau( self.opt, 'max', factor=0.1, patience=5, threshold=0.00001, verbose=True)
# self.scheduler = StepLR(self.opt, step_size=15, gamma=0.1)
# Train loop over epochs. Optinal use testloader to return test accuracy after each epoch
def Train(self, trainloader, epochs, testloader=None):
# Enable Dropout
# Record loss/accuracies
loss = torch.zeros(epochs)
self.epoch = 0
# If testloader is used, loss will be the accuracy
for epoch in range(self.start_epoch, self.start_epoch+epochs):
self.epoch = epoch+1
self.net.train() # Enable Dropout
for data in trainloader:
# Get the inputs; data is a list of [inputs, labels]
if self.device:
images, labels = data[0].to(self.device), data[1].to(self.device)
else:
images, labels = data
self.opt.zero_grad()
# Forward + backward + optimize
outputs = self.net(images)
epoch_loss = self.cost(outputs, labels)
epoch_loss.backward()
self.opt.step()
loss[epoch] += epoch_loss.item()
if testloader:
loss[epoch] = self.Test(testloader)
else:
loss[epoch] /= len(trainloader)
print("Epoch %d Learning rate %.6f %s: %.3f" % (
self.epoch, self.opt.param_groups[0]['lr'], "Accuracy" if testloader else "Loss", loss[epoch]))
#learning rate scheduler
if self.use_lr_schedule:
self.scheduler.step(loss[epoch])
# self.scheduler.step()
# Saving best model
if loss[epoch] >= torch.max(loss):
self.save_best_model({
'epoch': self.epoch,
'state_dict': self.net.state_dict(),
'optimizer': self.opt.state_dict(),
})
return loss
# Testing
def Test(self, testloader, ret="accuracy"):
# Disable Dropout
self.net.eval()
# Track correct and total
correct = 0.0
total = 0.0
with torch.no_grad():
for data in testloader:
if self.device:
images, labels = data[0].to(self.device), data[1].to(self.device)
else:
images, labels = data
outputs = self.net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
return correct / total
def save_best_model(self, state):
directory = os.path.dirname("./save/%s-best-model/"%(self.name))
if not os.path.exists(directory):
os.mkdir(directory)
torch.save(state, "%s/model.pt" %(directory))
def save_checkpoint(self, state):
directory = os.path.dirname("./save/%s-checkpoints/"%(self.name))
if not os.path.exists(directory):
os.mkdir(directory)
torch.save(state, "%s/model_epoch_%s.pt" %(directory, self.epoch))
# torch.save(state, "./save/checkpoints/model_epoch_%s.pt" % (self.epoch))