3import torch
4import torchvision
5import torchvision.transforms as transforms
6from torch.utils.data import Dataset, random_split
7import matplotlib.pyplot as plt
8import numpy as np
def LoadCifar10DatasetTrain(save, transform=None):
    """Return the CIFAR-10 training split rooted at *save*.

    The archive is downloaded into *save* when it is not already there
    (``download=True``).

    Args:
        save: Root directory holding (or receiving) the CIFAR-10 files.
        transform: Optional callable applied to every image; ``None`` leaves
            the images untransformed.

    Returns:
        A ``torchvision.datasets.CIFAR10`` instance for the training split.
    """
    return torchvision.datasets.CIFAR10(
        root=save,
        train=True,
        download=True,
        transform=transform,
    )
def LoadCifar10DatasetTest(save, transform=None, download=True):
    """Return the CIFAR-10 test split rooted at *save*.

    Args:
        save: Root directory holding (or receiving) the CIFAR-10 files.
        transform: Optional callable applied to every image; ``None`` leaves
            the images untransformed (matches ``LoadCifar10DatasetTrain``).
        download: When True, fetch the archive if it is missing. torchvision
            checks existing files first and skips the download when they are
            intact, so this only touches the network on a fresh root. Pass
            False to forbid network access.

    Returns:
        A ``torchvision.datasets.CIFAR10`` instance for the test split.

    Raises:
        RuntimeError: raised by torchvision when ``download`` is False and the
            dataset files are missing or corrupted.
    """
    # With download hard-coded to False, loading the test split on a fresh
    # root failed unless the training split had been loaded first.
    return torchvision.datasets.CIFAR10(
        root=save,
        train=False,
        download=download,
        transform=transform,
    )
def GetCustTransform():
    """Build the training-time augmentation + normalization pipeline.

    Steps, in order:
      1. random rotation with an angle drawn from (-20, +20) degrees;
      2. pad 2 px on every side (constant fill) and take a random 32x32 crop;
      3. convert to a tensor;
      4. normalize each of the 3 channels with mean 0.5 / std 0.5.

    Returns:
        A ``transforms.Compose`` callable.
    """
    augmentations = [
        transforms.RandomRotation(20),
        transforms.RandomCrop(
            32,
            padding=(2, 2),
            pad_if_needed=False,
            padding_mode='constant',
        ),
    ]
    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    return transforms.Compose(augmentations + to_normalized_tensor)
def Dataloader_train_valid(save, batch_size, seed=None):
    """Split the CIFAR-10 training set 80/20 into train and validation loaders.

    The training part uses the augmentation pipeline from ``GetCustTransform``.
    The validation part covers the remaining images but without random
    augmentation (only ToTensor + Normalize), so validation metrics reflect
    evaluation conditions instead of randomly rotated/cropped inputs.

    Args:
        save: Root directory for the CIFAR-10 files (downloaded if missing).
        batch_size: Batch size used by both loaders.
        seed: Optional int. When given, the split is drawn from a dedicated
            ``torch.Generator`` seeded with it, making the split reproducible.
            When None, the global torch RNG is used.

    Returns:
        ``(trainloader, valloader)``. The train loader shuffles every epoch;
        the validation loader iterates in a fixed order.
    """
    # Two views of the same training images: one augmented, one plain.
    augmented_set = LoadCifar10DatasetTrain(save, GetCustTransform())
    plain_set = LoadCifar10DatasetTrain(save, transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]))

    n_total = len(augmented_set)
    n_train = int(n_total * 0.8)
    generator = (
        torch.Generator().manual_seed(seed)
        if seed is not None
        else torch.default_generator
    )
    train_subset, val_split = random_split(
        augmented_set, [n_train, n_total - n_train], generator=generator
    )
    # Reuse the validation indices on the un-augmented dataset: the subsets
    # stay disjoint, but validation images are not randomly transformed.
    val_subset = torch.utils.data.Subset(plain_set, val_split.indices)

    trainloader = torch.utils.data.DataLoader(
        train_subset, batch_size=batch_size, shuffle=True, num_workers=4
    )
    # Order does not affect validation metrics; keep it deterministic.
    valloader = torch.utils.data.DataLoader(
        val_subset, batch_size=batch_size, shuffle=False, num_workers=4
    )
    return trainloader, valloader
def Dataloader_train(save, batch_size):
    """Return a shuffling DataLoader over the full, augmented CIFAR-10 train split.

    Args:
        save: Root directory for the CIFAR-10 files (downloaded if missing).
        batch_size: Number of samples per batch.

    Returns:
        A ``torch.utils.data.DataLoader`` using 4 worker processes.
    """
    # Training augmentations are defined in GetCustTransform.
    augmentation = GetCustTransform()
    dataset = LoadCifar10DatasetTrain(save, augmentation)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
    )
def Dataloader_test(save, batch_size):
    """Return a non-shuffled DataLoader over the CIFAR-10 test split.

    Args:
        save: Root directory for the CIFAR-10 files, passed to
            ``LoadCifar10DatasetTest``.
        batch_size: Number of samples per batch.

    Returns:
        A ``torch.utils.data.DataLoader`` with a fixed iteration order and 4
        worker processes.
    """
    # No augmentation at test time: only the tensor conversion and the same
    # per-channel normalization used for training.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    testset = LoadCifar10DatasetTest(save, transform_test)
    # Use the caller-supplied batch size rather than a hard-coded value.
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=4
    )
    return testloader
def imshow(im):
    """Display a normalized channels-first image tensor with matplotlib.

    Args:
        im: A torch tensor laid out (C, H, W) and normalized with mean 0.5 /
            std 0.5 per channel, as produced by this module's transforms.
            Presumably 3 channels (the unnormalize constants have length 3).
    """
    # Share the tensor -> (H, W, C) numpy conversion and unnormalization with
    # imretrun so the two helpers cannot drift apart.
    image = imretrun(im)
    # Floating-point error can push values marginally outside [0, 1]; clip
    # explicitly instead of letting matplotlib clip and emit a warning.
    plt.imshow(np.clip(image, 0.0, 1.0))
    plt.show()
def imretrun(im):
    """Convert a normalized (C, H, W) tensor into an unnormalized (H, W, C) array.

    Undoes a per-channel Normalize(mean=0.5, std=0.5), i.e. ``x * 0.5 + 0.5``.
    The input tensor is copied first and is not modified.

    Args:
        im: A torch tensor with channels first (3 channels expected).

    Returns:
        A numpy array of shape (H, W, C), float64 since the unnormalize
        constants are float64 arrays.
    """
    mean = np.array((0.5, 0.5, 0.5))
    std = np.array((0.5, 0.5, 0.5))
    chw = im.detach().cpu().clone().numpy()
    hwc = np.transpose(chw, (1, 2, 0))
    return hwc * std + mean