import os

import torch
from torch.utils.data import DataLoader, ConcatDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

# from sklearn.model_selection import KFold
# from torch.utils.data.sampler import SubsetRandomSampler

import matplotlib.pyplot as plt


class Trainer:
    def __init__(self, net, opt, cost, name="default", lr=0.0005,
                 use_lr_schedule=False, device=None):
        self.net = net
        self.opt = opt
        self.cost = cost
        self.device = device
        self.epoch = 0
        self.start_epoch = 0
        self.name = name

        self.lr = lr
        self.use_lr_schedule = use_lr_schedule
        if self.use_lr_schedule:
            # 'max' mode assumes the monitored metric is an accuracy
            # (i.e. Train() is called with a testloader), not a loss
            self.scheduler = ReduceLROnPlateau(self.opt, 'max', factor=0.1,
                                               patience=5, threshold=0.00001,
                                               verbose=True)
            # self.scheduler = StepLR(self.opt, step_size=15, gamma=0.1)

    # Train loop over epochs. Optionally pass a testloader to record test
    # accuracy (rather than training loss) after each epoch.
    def Train(self, trainloader, epochs, testloader=None):
        # Record per-epoch loss; if testloader is given, this holds accuracy
        loss = torch.zeros(epochs)
        self.epoch = 0

        # Index loss with a local counter so resuming from start_epoch > 0
        # cannot index past the end of the tensor
        for i in range(epochs):
            self.epoch = self.start_epoch + i + 1

            self.net.train()  # enable dropout (and batch-norm updates)
            for data in trainloader:
                # Get the inputs; data is a list of [inputs, labels]
                if self.device:
                    images, labels = data[0].to(self.device), data[1].to(self.device)
                else:
                    images, labels = data

                self.opt.zero_grad()

                # Forward + backward + optimize
                outputs = self.net(images)
                batch_loss = self.cost(outputs, labels)
                batch_loss.backward()
                self.opt.step()

                loss[i] += batch_loss.item()

            if testloader:
                # Overwrite the accumulated training loss with test accuracy
                loss[i] = self.Test(testloader)
            else:
                # Average the batch losses over the epoch
                loss[i] /= len(trainloader)

            print("Epoch %d Learning rate %.6f %s: %.3f" % (
                self.epoch, self.opt.param_groups[0]['lr'],
                "Accuracy" if testloader else "Loss", loss[i]))

            # Learning-rate scheduler; ReduceLROnPlateau ('max' mode) expects
            # the monitored metric, i.e. the accuracy from a testloader
            if self.use_lr_schedule:
                self.scheduler.step(loss[i])
                # self.scheduler.step()  # StepLR variant takes no metric

            # Save the best model so far. ">= max" assumes higher is better,
            # which holds for accuracy (testloader given), not training loss.
            if loss[i] >= torch.max(loss):
                self.save_best_model({
                    'epoch': self.epoch,
                    'state_dict': self.net.state_dict(),
                    'optimizer': self.opt.state_dict(),
                })

        return loss

    # Compute classification accuracy over testloader
    def Test(self, testloader):
        # Disable dropout / freeze batch-norm statistics
        self.net.eval()

        # Track correct and total predictions
        correct = 0.0
        total = 0.0
        with torch.no_grad():
            for data in testloader:
                if self.device:
                    images, labels = data[0].to(self.device), data[1].to(self.device)
                else:
                    images, labels = data

                outputs = self.net(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        return correct / total

    def save_best_model(self, state):
        directory = "./save/%s-best-model" % self.name
        # makedirs also creates the parent ./save directory if missing,
        # which plain os.mkdir would not
        os.makedirs(directory, exist_ok=True)
        torch.save(state, "%s/model.pt" % directory)

    def save_checkpoint(self, state):
        directory = "./save/%s-checkpoints" % self.name
        os.makedirs(directory, exist_ok=True)
        torch.save(state, "%s/model_epoch_%s.pt" % (directory, self.epoch))
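

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not from the original repository): a minimal,
# hypothetical example of driving Trainer on random data. The toy model,
# dataset shapes, and hyperparameters below are assumptions for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn
    from torch.utils.data import TensorDataset

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hypothetical 10-class problem on random 32-dimensional inputs
    inputs = torch.randn(256, 32)
    targets = torch.randint(0, 10, (256,))
    loader = DataLoader(TensorDataset(inputs, targets), batch_size=32, shuffle=True)

    net = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10)).to(device)
    opt = torch.optim.SGD(net.parameters(), lr=0.0005, momentum=0.9)
    cost = nn.CrossEntropyLoss()

    trainer = Trainer(net, opt, cost, name="toy", use_lr_schedule=True, device=device)
    # Passing a testloader makes Train() report (and schedule on) accuracy
    history = trainer.Train(loader, epochs=3, testloader=loader)
    print("Final accuracy: %.3f" % history[-1])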