import os

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

# Unused here, but kept for reference (e.g. k-fold cross-validation setups):
# from torch.utils.data import DataLoader, ConcatDataset
# from sklearn.model_selection import KFold
# from torch.utils.data.sampler import SubsetRandomSampler


class Trainer:
    """Wraps a network, optimizer, and cost function with train/test loops."""

    def __init__(self, net, opt, cost, name="default", lr=0.0005, use_lr_schedule=False, device=None):
        self.net = net
        self.opt = opt
        self.cost = cost
        self.device = device
        self.epoch = 0
        self.start_epoch = 0
        self.name = name

        self.lr = lr
        self.use_lr_schedule = use_lr_schedule
        if self.use_lr_schedule:
            # 'max' mode assumes the tracked metric is an accuracy (higher is
            # better), i.e. Train() should be given a testloader
            self.scheduler = ReduceLROnPlateau(self.opt, 'max', factor=0.1, patience=5, threshold=0.00001, verbose=True)
            # self.scheduler = StepLR(self.opt, step_size=15, gamma=0.1)

    # Training loop over epochs. Optionally pass a testloader to record test
    # accuracy after each epoch instead of the average training loss.
    def Train(self, trainloader, epochs, testloader=None):
        # Record one value per epoch; index with a local counter so a
        # non-zero self.start_epoch cannot run past the end of the tensor
        loss = torch.zeros(epochs)
        self.epoch = 0

        # If a testloader is given, `loss` holds accuracies instead of losses
        for i, epoch in enumerate(range(self.start_epoch, self.start_epoch + epochs)):
            self.epoch = epoch + 1

            self.net.train()  # Enable Dropout / BatchNorm training behaviour
            for data in trainloader:
                # Get the inputs; data is a list of [inputs, labels]
                if self.device:
                    images, labels = data[0].to(self.device), data[1].to(self.device)
                else:
                    images, labels = data

                self.opt.zero_grad()

                # Forward + backward + optimize
                outputs = self.net(images)
                batch_loss = self.cost(outputs, labels)
                batch_loss.backward()
                self.opt.step()

                loss[i] += batch_loss.item()

            if testloader:
                loss[i] = self.Test(testloader)
            else:
                loss[i] /= len(trainloader)

            print("Epoch %d Learning rate %.6f %s: %.3f" % (
                self.epoch, self.opt.param_groups[0]['lr'],
                "Accuracy" if testloader else "Loss", loss[i]))

            # Learning-rate scheduler: ReduceLROnPlateau steps on the metric
            if self.use_lr_schedule:
                self.scheduler.step(loss[i])
                # self.scheduler.step()  # for StepLR

            # Save the best model so far (assumes higher is better, i.e. the
            # recorded metric is an accuracy)
            if loss[i] >= torch.max(loss):
                self.save_best_model({
                    'epoch': self.epoch,
                    'state_dict': self.net.state_dict(),
                    'optimizer': self.opt.state_dict(),
                })

        return loss

    # Evaluation: returns the accuracy on the given loader
    def Test(self, testloader, ret="accuracy"):
        # Disable Dropout / use BatchNorm running statistics
        self.net.eval()

        # Track correct predictions and total samples
        correct = 0.0
        total = 0.0
        with torch.no_grad():
            for data in testloader:
                if self.device:
                    images, labels = data[0].to(self.device), data[1].to(self.device)
                else:
                    images, labels = data

                outputs = self.net(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        return correct / total

    def save_best_model(self, state):
        directory = "./save/%s-best-model" % self.name
        # makedirs also creates ./save itself if it does not exist yet
        os.makedirs(directory, exist_ok=True)
        torch.save(state, "%s/model.pt" % directory)

    def save_checkpoint(self, state):
        directory = "./save/%s-checkpoints" % self.name
        os.makedirs(directory, exist_ok=True)
        torch.save(state, "%s/model_epoch_%s.pt" % (directory, self.epoch))
        # torch.save(state, "./save/checkpoints/model_epoch_%s.pt" % (self.epoch))
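

# Minimal usage sketch (illustrative, not part of the original module): the
# toy model, data, and hyperparameters below are assumptions chosen only to
# show how Trainer, Train(), and Test() fit together.
if __name__ == "__main__":
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import DataLoader, TensorDataset

    # Random tensors standing in for a real image dataset
    images = torch.randn(256, 1, 8, 8)
    labels = torch.randint(0, 10, (256,))
    trainloader = DataLoader(TensorDataset(images, labels), batch_size=32, shuffle=True)
    testloader = DataLoader(TensorDataset(images, labels), batch_size=32)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = nn.Sequential(nn.Flatten(), nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 10)).to(device)
    opt = optim.Adam(net.parameters(), lr=0.0005)

    trainer = Trainer(net, opt, nn.CrossEntropyLoss(), name="demo",
                      use_lr_schedule=True, device=device)
    # With a testloader, the returned tensor holds per-epoch test accuracies
    accuracies = trainer.Train(trainloader, epochs=2, testloader=testloader)
    print("Final accuracy: %.3f" % accuracies[-1])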