mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 02:39:16 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			118 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			118 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| 
 | |
| 
 | |
| import torch
 | |
| from torch.utils.data import DataLoader, ConcatDataset
 | |
| # from sklearn.model_selection import KFold
 | |
| # from torch.utils.data.sampler import SubsetRandomSampler
 | |
| 
 | |
| import matplotlib.pyplot as plt
 | |
| from pylab import *
 | |
| import os
 | |
| 
 | |
| from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
 | |
| 
 | |
| 
 | |
| 
 | |
class Trainer():
    """Training/evaluation harness for a PyTorch classifier.

    Wraps a network, an optimizer and a loss function; runs the training
    and test loops, optionally steps a ReduceLROnPlateau scheduler, and
    saves checkpoints / best models under ./save/<name>-*/ .
    """

    def __init__(self, net, opt, cost, name="default", lr=0.0005, use_lr_schedule=False, device=None):
        """
        Args:
            net: model to train (torch.nn.Module).
            opt: optimizer already bound to net's parameters.
            cost: loss function, called as cost(outputs, labels).
            name: run name used for the ./save/ sub-directories.
            lr: nominal learning rate (stored for reference only; the
                optimizer's own lr is what is actually applied).
            use_lr_schedule: if True, attach a ReduceLROnPlateau scheduler.
            device: optional torch device; when set, each batch is moved to it.
        """
        self.net = net
        self.opt = opt
        self.cost = cost
        self.device = device
        self.epoch = 0
        self.start_epoch = 0
        self.name = name

        self.lr = lr
        self.use_lr_schedule = use_lr_schedule
        if self.use_lr_schedule:
            # NOTE(review): 'max' mode assumes the value passed to
            # scheduler.step() is an accuracy (higher is better). When
            # Train() is called without a testloader the metric is the
            # training loss, for which 'max' is the wrong mode — confirm
            # the intended usage with callers.
            # `verbose=True` was dropped: the kwarg is deprecated and has
            # been removed in recent torch releases.
            self.scheduler = ReduceLROnPlateau(
                self.opt, 'max', factor=0.1, patience=5, threshold=0.00001)
            # self.scheduler = StepLR(self.opt, step_size=15, gamma=0.1)

    # Train loop over epochs. Optional testloader returns test accuracy
    # after each epoch instead of the training loss.
    def Train(self, trainloader, epochs, testloader=None):
        """Train for `epochs` epochs.

        Args:
            trainloader: iterable of (inputs, labels) batches.
            epochs: number of epochs to run.
            testloader: optional; when given, each epoch's recorded metric
                is the test accuracy instead of the mean training loss.

        Returns:
            A 1-D tensor of length `epochs` with the per-epoch metric.

        Fixes vs. the original:
          * the metric is indexed relative to start_epoch, so a non-zero
            start_epoch no longer indexes past the end of the tensor;
          * "best model" now means lowest loss / highest accuracy, rather
            than always the highest value.
        """
        # Per-epoch metric (mean training loss, or test accuracy).
        loss = torch.zeros(epochs)
        self.epoch = 0
        best_metric = None  # best metric seen so far in this call

        for i in range(epochs):
            self.epoch = self.start_epoch + i + 1

            self.net.train()  # enable Dropout / BatchNorm updates
            for data in trainloader:
                # Get the inputs; data is a list of [inputs, labels]
                if self.device:
                    images, labels = data[0].to(self.device), data[1].to(self.device)
                else:
                    images, labels = data

                self.opt.zero_grad()
                # Forward + backward + optimize
                outputs = self.net(images)
                batch_loss = self.cost(outputs, labels)
                batch_loss.backward()
                self.opt.step()

                loss[i] += batch_loss.item()

            if testloader:
                # Report accuracy instead of training loss.
                loss[i] = self.Test(testloader)
            else:
                loss[i] /= len(trainloader)  # mean loss over batches

            print("Epoch %d Learning rate %.6f %s: %.3f" % (
                self.epoch, self.opt.param_groups[0]['lr'],
                "Accuracy" if testloader else "Loss", loss[i]))

            # Learning-rate scheduler (ReduceLROnPlateau consumes the metric).
            if self.use_lr_schedule:
                self.scheduler.step(loss[i])

            # Save the best model: highest accuracy, or lowest loss.
            metric = loss[i].item()
            improved = (best_metric is None
                        or (metric >= best_metric if testloader else metric <= best_metric))
            if improved:
                best_metric = metric
                self.save_best_model({
                    'epoch': self.epoch,
                    'state_dict': self.net.state_dict(),
                    'optimizer': self.opt.state_dict(),
                })

        return loss

    # Testing
    def Test(self, testloader, ret="accuracy"):
        """Evaluate on `testloader`; returns overall accuracy in [0, 1].

        `ret` is kept for backward compatibility but is currently unused.
        Returns 0.0 for an empty loader instead of dividing by zero.
        """
        self.net.eval()  # disable Dropout / use running BatchNorm stats

        # Track correct and total predictions.
        correct = 0.0
        total = 0.0
        with torch.no_grad():
            for data in testloader:
                if self.device:
                    images, labels = data[0].to(self.device), data[1].to(self.device)
                else:
                    images, labels = data

                outputs = self.net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Guard against an empty test set.
        return correct / total if total else 0.0

    def save_best_model(self, state):
        """Save `state` to ./save/<name>-best-model/model.pt."""
        directory = os.path.dirname("./save/%s-best-model/" % (self.name))
        # makedirs also creates a missing ./save parent (mkdir would fail),
        # and exist_ok avoids the racy exists()-then-mkdir check.
        os.makedirs(directory, exist_ok=True)
        torch.save(state, "%s/model.pt" % (directory))

    def save_checkpoint(self, state):
        """Save `state` to ./save/<name>-checkpoints/model_epoch_<epoch>.pt."""
        directory = os.path.dirname("./save/%s-checkpoints/" % (self.name))
        os.makedirs(directory, exist_ok=True)
        torch.save(state, "%s/model_epoch_%s.pt" % (directory, self.epoch))
        # torch.save(state, "./save/checkpoints/model_epoch_%s.pt" % (self.epoch))
 | 
