mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-15 02:07:56 +08:00
83 lines
3.0 KiB
Python
83 lines
3.0 KiB
Python
#!/bin/python
|
|
|
|
import torch
|
|
import torchvision
|
|
import torchvision.transforms as transforms
|
|
from torch.utils.data import Dataset, random_split
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
|
|
def LoadCifar10DatasetTrain(save, transform=None):
    """Return the CIFAR-10 training split stored under ``save``.

    The dataset is requested with ``download=True``, so torchvision fetches the
    archive into ``save`` when it is not already present.

    Args:
        save: Root directory holding (or receiving) the CIFAR-10 files.
        transform: Optional callable applied to every image; ``None`` keeps the
            raw images.
    """
    return torchvision.datasets.CIFAR10(
        root=save,
        train=True,
        download=True,
        transform=transform,
    )
|
|
|
|
def LoadCifar10DatasetTest(save, transform=None):
    """Return the CIFAR-10 test split stored under ``save``.

    Behaves like ``LoadCifar10DatasetTrain`` but for ``train=False``: the
    archive is fetched into ``save`` when missing, so the test split can be
    loaded on a fresh machine without first building the training set.

    Args:
        save: Root directory holding (or receiving) the CIFAR-10 files.
        transform: Optional callable applied to every image; defaults to
            ``None`` (raw images), matching ``LoadCifar10DatasetTrain``.
    """
    # download=True is a no-op when the files are already present and intact;
    # download=False made a fresh `save` directory fail with a RuntimeError.
    testset = torchvision.datasets.CIFAR10(root=save, train=False,
                                           download=True, transform=transform)
    return testset
|
|
|
|
def GetCustTransform():
    """Build the augmentation + normalization pipeline used for training images.

    Steps, applied in order to each image:
      1. random rotation by an angle in [-20, 20] degrees;
      2. pad 2 px on left/right and top/bottom (constant fill), then take a
         random 32x32 crop;
      3. convert to a float tensor in [0, 1];
      4. normalize every channel with mean 0.5 and std 0.5, giving [-1, 1].
    """
    channel_stats = (0.5, 0.5, 0.5)
    augment = [
        transforms.RandomRotation(20),
        transforms.RandomCrop(32, (2, 2), pad_if_needed=False, padding_mode='constant'),
    ]
    to_model_input = [
        transforms.ToTensor(),
        transforms.Normalize(channel_stats, channel_stats),
    ]
    return transforms.Compose(augment + to_model_input)
|
|
|
|
class _TransformedSubset(Dataset):
    """Index-restricted view of ``dataset`` that applies its own image transform.

    Lets two splits share one underlying dataset object while each applies a
    different transform. Defined at module level (not inside a function) so
    DataLoader worker processes can pickle it.

    Assumes ``dataset[i]`` returns an ``(image, target)`` pair, as CIFAR-10 does.
    """

    def __init__(self, dataset, indices, transform=None):
        self.dataset = dataset
        self.indices = list(indices)
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        image, target = self.dataset[self.indices[idx]]
        if self.transform is not None:
            image = self.transform(image)
        return image, target


def Dataloader_train_valid(save, batch_size):
    """Build train/validation loaders from an 80/20 random split of CIFAR-10 train.

    Training samples get the augmentations from ``GetCustTransform``.
    Validation samples are only converted to tensors and normalized (mean 0.5,
    std 0.5 per channel), so validation metrics are not computed on randomly
    rotated or cropped images.

    Args:
        save: Root directory holding (or receiving) the CIFAR-10 files.
        batch_size: Batch size used by both loaders.

    Returns:
        ``(trainloader, valloader)``; the training loader shuffles, the
        validation loader iterates in a fixed order.
    """
    transform_train = GetCustTransform()
    transform_valid = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # Load the raw images once; each split applies its own transform on access.
    # random_split over a dataset built with training augmentations would have
    # applied those augmentations to the validation split as well.
    base = LoadCifar10DatasetTrain(save, transform=None)

    n_total = len(base)
    n_train = int(n_total * 0.8)
    perm = torch.randperm(n_total).tolist()

    train_subset = _TransformedSubset(base, perm[:n_train], transform_train)
    val_subset = _TransformedSubset(base, perm[n_train:], transform_valid)

    trainloader = torch.utils.data.DataLoader(train_subset, batch_size=batch_size,
                                              shuffle=True, num_workers=4)
    # Evaluation order does not affect the metrics, so validation data is not shuffled.
    valloader = torch.utils.data.DataLoader(val_subset, batch_size=batch_size,
                                            shuffle=False, num_workers=4)
    return trainloader, valloader
|
|
|
|
def Dataloader_train(save, batch_size):
    """Return a shuffled DataLoader over the full, augmented CIFAR-10 training set.

    Images are processed by ``GetCustTransform`` (random rotation/crop and
    normalization). Loading uses 4 worker processes.

    Args:
        save: Root directory holding (or receiving) the CIFAR-10 files.
        batch_size: Number of samples per batch.
    """
    dataset = LoadCifar10DatasetTrain(save, GetCustTransform())
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
    )
|
|
|
|
def Dataloader_test(save, batch_size):
    """Return a DataLoader over the CIFAR-10 test split, iterated in a fixed order.

    Test images get no augmentation: they are only converted to tensors and
    normalized per channel with mean 0.5 / std 0.5. Loading uses 4 worker
    processes.

    Args:
        save: Root directory of the CIFAR-10 files.
        batch_size: Number of samples per batch.
    """
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    testset = LoadCifar10DatasetTest(save, transform_test)

    # Honour the caller's batch size; it was previously ignored in favour of 64.
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=4)
    return testloader
|
|
|
|
def imshow(im):
    """Display an image tensor normalized with mean 0.5 / std 0.5 using matplotlib.

    The tensor is unnormalized with ``imretrun``, clipped to [0, 1], and shown
    with ``plt.show()``.

    Args:
        im: Tensor with 3 dims in channel-first order, presumably (3, H, W).
            It may live on any device.
    """
    # Float RGB data for plt.imshow must lie in [0, 1]. Out-of-range values
    # (e.g. from model outputs) would otherwise be clipped by matplotlib with
    # a warning on every call.
    image = np.clip(imretrun(im), 0.0, 1.0)
    plt.imshow(image)
    plt.show()
|
|
|
|
def imretrun(im):
    """Convert a normalized image tensor back to an HWC NumPy array.

    Inverts ``Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`` by mapping each
    value ``x`` to ``x * 0.5 + 0.5``. The result is not clipped.

    Args:
        im: Tensor with 3 dims in channel-first order, presumably (3, H, W).
            It may live on any device and may require grad.

    Returns:
        A new float64 ``np.ndarray`` with the channel axis moved last.
    """
    half = np.array((0.5, 0.5, 0.5))
    # Move the tensor to CPU and drop autograd before converting to NumPy.
    chw = im.cpu().clone().detach().numpy()
    hwc = np.transpose(chw, (1, 2, 0))
    return hwc * half + half