3import torch
4import torchvision
5import torchvision.transforms as transforms
6from torch.utils.data import Dataset, random_split
7import matplotlib.pyplot as plt
8import numpy as np
def LoadCifar10DatasetTrain(save, transform=None, download=True):
    """Load the CIFAR-10 training split.

    Args:
        save: root directory where the CIFAR-10 files are stored.
        transform: optional callable applied to each image; when None the
            dataset yields untransformed images.
        download: if True (the default, as before), fetch the archive into
            ``save`` when it is not already present. Pass False to forbid
            network access.

    Returns:
        ``torchvision.datasets.CIFAR10`` built with ``train=True``.
    """
    return torchvision.datasets.CIFAR10(root=save, train=True,
                                        download=download, transform=transform)
def LoadCifar10DatasetTest(save, transform=None, download=True):
    """Load the CIFAR-10 test split.

    Args:
        save: root directory where the CIFAR-10 files are stored.
        transform: optional callable applied to each image (defaults to None,
            matching LoadCifar10DatasetTrain).
        download: if True, fetch the archive into ``save`` when it is missing.
            The previous hard-coded ``download=False`` made this loader fail
            on a fresh ``save`` directory unless the training loader had run
            first. When the files are already present, no download happens.

    Returns:
        ``torchvision.datasets.CIFAR10`` built with ``train=False``.
    """
    return torchvision.datasets.CIFAR10(root=save, train=False,
                                        download=download, transform=transform)
def GetCustTransform():
    """Build the augmentation + normalization pipeline for CIFAR-10 training.

    The pipeline applies, in order:
      1. a random rotation (``degrees=20``),
      2. a random 32x32 crop taken after constant-mode padding of (2, 2),
      3. conversion to a tensor,
      4. per-channel normalization with mean 0.5 and std 0.5.
    """
    channel_stats = (0.5, 0.5, 0.5)
    steps = [
        transforms.RandomRotation(20),
        transforms.RandomCrop(32, padding=(2, 2), pad_if_needed=False,
                              padding_mode='constant'),
        transforms.ToTensor(),
        transforms.Normalize(channel_stats, channel_stats),
    ]
    return transforms.Compose(steps)
class _TransformedSubset(Dataset):
    """View of ``dataset`` restricted to ``indices`` that applies its own transform.

    Lets the train and validation parts of one split use different transforms
    while sharing a single underlying (untransformed) dataset object.
    Defined at module level so DataLoader worker processes can pickle it.
    """

    def __init__(self, dataset, indices, transform=None):
        self.dataset = dataset
        self.indices = list(indices)
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        image, target = self.dataset[self.indices[idx]]
        if self.transform is not None:
            image = self.transform(image)
        return image, target


def Dataloader_train_valid(save, batch_size, train_fraction=0.8, seed=None):
    """Split the CIFAR-10 training set into train and validation DataLoaders.

    The training part is augmented with GetCustTransform(). The validation
    part only gets ToTensor + Normalize, so validation metrics are measured on
    un-augmented images. (Previously both parts shared the augmenting
    transform, because the split was taken from an already-transformed
    dataset.)

    Args:
        save: root directory for the CIFAR-10 files.
        batch_size: samples per batch for both loaders.
        train_fraction: fraction of the training set kept for training; the
            remainder becomes the validation set. Default 0.8, as before.
        seed: optional int. When given, the train/validation split is
            reproducible. When None, the global torch RNG is used, as before.

    Returns:
        ``(trainloader, valloader)``. Only ``trainloader`` shuffles.

    Raises:
        ValueError: if ``train_fraction`` is not strictly between 0 and 1.
    """
    if not 0.0 < train_fraction < 1.0:
        raise ValueError(f"train_fraction must be in (0, 1), got {train_fraction}")

    # Load once without a transform; each side of the split applies its own.
    full_trainset = LoadCifar10DatasetTrain(save)
    n_total = len(full_trainset)
    n_train = int(n_total * train_fraction)

    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_split, val_split = random_split(
        full_trainset, [n_train, n_total - n_train], **split_kwargs)

    # Training augmentations come from GetCustTransform (see utils/dataloader.py).
    train_subset = _TransformedSubset(full_trainset, train_split.indices,
                                      GetCustTransform())
    # Validation uses the deterministic evaluation pipeline: no augmentation.
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    val_subset = _TransformedSubset(full_trainset, val_split.indices,
                                    val_transform)

    trainloader = torch.utils.data.DataLoader(train_subset, batch_size=batch_size,
                                              shuffle=True, num_workers=4)
    # Fixed order: shuffling validation data adds nothing and hurts reproducibility.
    valloader = torch.utils.data.DataLoader(val_subset, batch_size=batch_size,
                                            shuffle=False, num_workers=4)
    return trainloader, valloader
def Dataloader_train(save, batch_size):
    """Return a shuffling DataLoader over the whole augmented CIFAR-10 training set.

    Augmentations come from GetCustTransform (see utils/dataloader.py).

    Args:
        save: root directory for the CIFAR-10 files.
        batch_size: number of samples per batch.
    """
    dataset = LoadCifar10DatasetTrain(save, GetCustTransform())
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
    )
def Dataloader_test(save, batch_size):
    """Return a non-shuffled DataLoader over the CIFAR-10 test split.

    No augmentation is applied. Images are only converted to tensors and
    normalized with the same per-channel mean/std (0.5) used for training.

    Args:
        save: root directory for the CIFAR-10 files.
        batch_size: number of samples per batch. This argument was previously
            ignored in favour of a hard-coded 64.
    """
    # Deterministic evaluation pipeline.
    transform_test = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    testset = LoadCifar10DatasetTest(save, transform_test)
    # Fixed order so evaluation output is reproducible batch-by-batch.
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=4)
    return testloader
def imshow(im, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Display a normalized image tensor with matplotlib.

    Undoes ``Normalize(mean, std)`` (``x * std + mean``), moves channels last
    for ``plt.imshow``, and clips to [0, 1]. Without the clip, matplotlib would
    clip anyway and emit a warning for out-of-range float RGB values.

    Args:
        im: 3-D tensor laid out (C, H, W); presumably 3 channels matching
            ``mean``/``std``. It may live on any device or require grad.
        mean, std: per-channel statistics used when the image was normalized.
            The defaults match the Normalize((0.5,)*3, (0.5,)*3) used in this
            module.
    """
    image = im.detach().cpu().numpy().transpose(1, 2, 0)
    image = image * np.asarray(std) + np.asarray(mean)  # unnormalize
    plt.imshow(np.clip(image, 0.0, 1.0))
    plt.show()
def imretrun(im, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Return a normalized image tensor as an un-normalized, channels-last NumPy array.

    Undoes ``Normalize(mean, std)`` via ``x * std + mean``. No clipping is
    applied, so values outside [0, 1] are returned as-is.

    Args:
        im: 3-D tensor laid out (C, H, W). It may live on any device or
            require grad.
        mean, std: per-channel statistics used when the image was normalized.
            The defaults match the Normalize((0.5,)*3, (0.5,)*3) used in this
            module.

    Returns:
        ``numpy.ndarray`` of shape (H, W, C).
    """
    # detach/cpu make .numpy() legal. The arithmetic below allocates a new
    # array, so the input tensor is never mutated.
    image = im.detach().cpu().numpy().transpose(1, 2, 0)
    return image * np.asarray(std) + np.asarray(mean)  # unnormalize