Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git, synced 2025-11-01 03:43:09 +08:00
	implemented resnets and transfer learning
This commit is contained in:

	sachdev.kartik (committed by Varuna Jayasiri)

	parent 4f31570f92
	commit 94796efa44
	117	labml_nn/resnets/utils/train.py	Normal file
@@ -0,0 +1,117 @@
import os

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau


class Trainer:
    def __init__(self, net, opt, cost, name="default", lr=0.0005, use_lr_schedule=False, device=None):
        self.net = net
        self.opt = opt
        self.cost = cost
        self.device = device
        self.epoch = 0
        self.start_epoch = 0
        self.name = name

        self.lr = lr
        self.use_lr_schedule = use_lr_schedule
        if self.use_lr_schedule:
            # 'max' mode cuts the learning rate by `factor` when the monitored
            # metric stops improving; it assumes the metric is an accuracy,
            # i.e. that a testloader is passed to Train()
            self.scheduler = ReduceLROnPlateau(self.opt, 'max', factor=0.1, patience=5,
                                               threshold=0.00001, verbose=True)

    # Training loop over epochs. Optionally pass a testloader to record
    # test accuracy (instead of average training loss) after each epoch.
    def Train(self, trainloader, epochs, testloader=None):
        # Per-epoch metric: accuracy if a testloader is given, else training loss
        metric = torch.zeros(epochs)

        for i in range(epochs):
            self.epoch = self.start_epoch + i + 1

            self.net.train()  # Enable dropout and batch-norm updates
            for data in trainloader:
                # Get the inputs; data is a list of [inputs, labels]
                if self.device:
                    images, labels = data[0].to(self.device), data[1].to(self.device)
                else:
                    images, labels = data

                self.opt.zero_grad()
                # Forward + backward + optimize
                outputs = self.net(images)
                batch_loss = self.cost(outputs, labels)
                batch_loss.backward()
                self.opt.step()

                metric[i] += batch_loss.item()

            if testloader:
                metric[i] = self.Test(testloader)
            else:
                metric[i] /= len(trainloader)

            print("Epoch %d Learning rate %.6f %s: %.3f" % (
                self.epoch, self.opt.param_groups[0]['lr'],
                "Accuracy" if testloader else "Loss", metric[i]))

            # Learning-rate scheduler steps on the monitored metric
            if self.use_lr_schedule:
                self.scheduler.step(metric[i])

            # Save the best model so far: highest accuracy with a testloader,
            # otherwise lowest training loss
            if testloader:
                is_best = metric[i] >= torch.max(metric[:i + 1])
            else:
                is_best = metric[i] <= torch.min(metric[:i + 1])
            if is_best:
                self.save_best_model({
                    'epoch': self.epoch,
                    'state_dict': self.net.state_dict(),
                    'optimizer': self.opt.state_dict(),
                })

        return metric

    # Evaluate classification accuracy on the given loader
    def Test(self, testloader):
        self.net.eval()  # Disable dropout; use running batch-norm statistics

        # Track correct and total predictions
        correct = 0.0
        total = 0.0
        with torch.no_grad():
            for data in testloader:
                if self.device:
                    images, labels = data[0].to(self.device), data[1].to(self.device)
                else:
                    images, labels = data

                outputs = self.net(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        return correct / total

    def save_best_model(self, state):
        directory = "./save/%s-best-model" % self.name
        # makedirs creates the intermediate './save' directory as well
        os.makedirs(directory, exist_ok=True)
        torch.save(state, "%s/model.pt" % directory)

    def save_checkpoint(self, state):
        directory = "./save/%s-checkpoints" % self.name
        os.makedirs(directory, exist_ok=True)
        torch.save(state, "%s/model_epoch_%s.pt" % (directory, self.epoch))
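
The commit title mentions transfer learning, so a brief usage sketch may help. The following is a hypothetical example, not part of this commit: it assumes torchvision's pretrained ResNet-18 and CIFAR-10, and the dataset, batch size, and hyperparameters are illustrative. The backbone is frozen and only a replaced classifier head is optimized, which is the usual transfer-learning setup a Trainer like this would drive.

# Hypothetical usage sketch -- not part of this commit
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transfer learning: start from ImageNet weights and freeze the backbone
net = torchvision.models.resnet18(pretrained=True)
for param in net.parameters():
    param.requires_grad = False
# Replace the classifier head for a 10-class problem (e.g. CIFAR-10)
net.fc = nn.Linear(net.fc.in_features, 10)
net = net.to(device)

transform = transforms.Compose([
    transforms.Resize(224),  # ResNet-18 expects ImageNet-sized inputs
    transforms.ToTensor(),
])
trainset = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testloader = DataLoader(testset, batch_size=64)

# Only the new head's parameters are optimized
opt = torch.optim.Adam(net.fc.parameters(), lr=0.0005)
cost = nn.CrossEntropyLoss()

trainer = Trainer(net, opt, cost, name="resnet18-cifar10",
                  use_lr_schedule=True, device=device)
accuracy = trainer.Train(trainloader, epochs=20, testloader=testloader)

Passing the testloader makes the returned per-epoch metric an accuracy, which matches the 'max' mode of the ReduceLROnPlateau scheduler constructed above.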