import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from torchsummary import summary
import torch.nn as nn

# from models.mlp import MLP
# from utils.utils import *
# from utils.train_dataset import *
# from nutsflow import Take, Consume
# from nutsml import *
from utils.dataloader import *
from models.cnn import CNN
from utils.train import Trainer

from utils.cv_train import *

# Check if a GPU is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device: " + str(device))

# CIFAR-10 dataset location
save = './data/Cifar10'

# Train transformations
transform_train = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
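# Note: ToTensor() scales pixels to [0, 1], and Normalize with mean 0.5 and
# std 0.5 then maps each RGB channel to [-1, 1] via (x - 0.5) / 0.5.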
| 
 | |
| # Load train dataset and dataloader
 | |
| trainset = LoadCifar10DatasetTrain(save, transform_train)
 | |
| trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
 | |
|                                           shuffle=True, num_workers=4)
 | |
| 
 | |
| # Transformations test (for inference later)
 | |
| transform_test = transforms.Compose(
 | |
|         [transforms.ToTensor(),
 | |
|          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
 | |
| 
 | |
| # Load test dataset and dataloader (for inference later)
 | |
| testset = LoadCifar10DatasetTest(save, transform_test)
 | |
| testloader = torch.utils.data.DataLoader(testset, batch_size=64,
 | |
|                                          shuffle=False, num_workers=4)
 | |
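
# The CIFAR-10 loader helpers used above come from utils.dataloader (imported
# with *) and are not defined in this file. A minimal sketch, assuming they are
# thin wrappers around torchvision.datasets.CIFAR10, would look like:
#
#   def LoadCifar10DatasetTrain(save, transform):
#       return torchvision.datasets.CIFAR10(root=save, train=True,
#                                           download=True, transform=transform)
#
#   def LoadCifar10DatasetTest(save, transform):
#       return torchvision.datasets.CIFAR10(root=save, train=False,
#                                           download=True, transform=transform)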
| 
 | |
| # Specify loss function
 | |
| cost = nn.CrossEntropyLoss()
 | |
| 
 | |
| epochs=25  #10
 | |
| splits = 4 #5
 | |
| 
 | |
| # Training - Cross-validation
 | |
| history = cross_val_train(cost, trainset, epochs, splits, device=device)
 | |
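
# cross_val_train comes from utils.cv_train; its implementation is not shown
# here. A rough sketch, assuming a standard k-fold scheme over `trainset` using
# the SubsetRandomSampler imported above (names and details are guesses):
#
#   indices = np.arange(len(trainset))
#   np.random.shuffle(indices)
#   for fold, val_idx in enumerate(np.array_split(indices, splits)):
#       train_idx = np.setdiff1d(indices, val_idx)
#       train_loader = torch.utils.data.DataLoader(
#           trainset, batch_size=64,
#           sampler=SubsetRandomSampler(train_idx), num_workers=4)
#       val_loader = torch.utils.data.DataLoader(
#           trainset, batch_size=64,
#           sampler=SubsetRandomSampler(val_idx), num_workers=4)
#       # train a fresh CNN() for `epochs` epochs on train_loader,
#       # track validation accuracy on val_loader, and keep the best model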
| 
 | |
| # Inference
 | |
| best_model, best_val_accuracy = retreive_best_trial()
 | |
| print("Best Validation Accuracy = %.3f"%(best_val_accuracy))
 | |
| 
 | |
| # Testing
 | |
| accuracy = Test(best_model, cost, testloader, device=device)
 | |
| print("Test Accuracy = %.3f"%(accuracy['val_acc']))