import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torchsummary import summary

# NOTE: the imports below keep their original relative order on purpose --
# later wildcard (`import *`) imports can shadow names bound by earlier ones.
from models.mlp import MLP
from utils.utils import *
from utils.train_dataset import *
from nutsflow import Take, Consume
from nutsml import *

from utils.dataloader import *
from models.cnn import CNN
from utils.train import Trainer

from utils.cv_train import *

Select the device: use the GPU if one is available

# Run on the first CUDA device when one is available, otherwise fall back
# to the CPU. `device` is passed to the training and test helpers below.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:  " + str(device))

CIFAR-10 dataset location

# Root directory handed to the CIFAR-10 train/test dataset loaders below.
save = './data/Cifar10'

Training-set transformations

# Training-set preprocessing: convert each image to a tensor, then normalize
# each of the three channels with mean 0.5 and std 0.5, i.e.
# x -> (x - 0.5) / 0.5. Assuming ToTensor yields values in [0, 1]
# (PIL / uint8 inputs), this maps pixel values to [-1, 1].
# No data augmentation is applied here.
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

Load the training dataset and dataloader

# Training dataset built by the project helper LoadCifar10DatasetTrain from
# the `save` directory, with `transform_train` applied.
trainset = LoadCifar10DatasetTrain(save, transform_train)

# Shuffled mini-batches of 64 loaded by 4 worker processes.
# NOTE(review): `trainloader` is not used by the visible code below --
# cross_val_train receives `trainset` directly. Kept for interactive use;
# confirm nothing else depends on it before removing.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=4)

Test-set transformations (used later for evaluation)

# Test-set preprocessing. It is deliberately identical to `transform_train`
# (ToTensor, then per-channel (x - 0.5) / 0.5), so test images are scaled
# exactly as the model saw them during training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

Load the test dataset and dataloader (used later for evaluation)

# Test dataset built by the project helper LoadCifar10DatasetTest from the
# `save` directory, with `transform_test` applied.
testset = LoadCifar10DatasetTest(save, transform_test)

# Unshuffled batches of 64 so evaluation order is deterministic.
# NOTE(review): num_workers=4 starts worker processes. When this file is run
# as a plain script on a platform whose multiprocessing start method is
# "spawn" (Windows, macOS by default), the top-level code should sit under an
# `if __name__ == "__main__":` guard -- otherwise workers re-import the main
# module. Not an issue when run as a notebook.
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=4)

Specify the loss function and training hyperparameters

# Loss used both for cross-validated training and for the final test run.
cost = nn.CrossEntropyLoss()

# Passed to cross_val_train as the epoch count and the number of
# cross-validation splits (presumably folds -- see utils.cv_train).
epochs = 25
splits = 4

Training with cross-validation

# Cross-validated training. The helper (presumably from utils.cv_train)
# receives the raw `trainset` rather than `trainloader`, so it likely builds
# its own per-split loaders -- verify there. Its return value is kept in
# `history` for later inspection.
history = cross_val_train(cost, trainset, epochs, splits, device=device)

Select the best model from cross-validation

# Fetch the best model and its validation accuracy -- presumably the best
# trial recorded by the cross-validation run above (helper not shown here).
# NOTE: `retreive_best_trial` (sic) is the helper's actual name; do not
# "correct" the spelling here without renaming the helper too.
best_model, best_val_accuracy = retreive_best_trial()
print("Best Validation Accuracy = %.3f" % best_val_accuracy)

Evaluate the best model on the test set

# Evaluate the selected model on the test loader. `Test` returns a mapping;
# the accuracy is read from its 'val_acc' key (the helper presumably reuses
# its validation-metric naming for test runs -- confirm in its source).
accuracy = Test(best_model, cost, testloader, device=device)
print("Test Accuracy = %.3f" % accuracy['val_acc'])