cycle gan auto download data

This commit is contained in:
Varuna Jayasiri
2021-01-23 15:53:57 +05:30
parent 83478823df
commit ebbe704d65
3 changed files with 42 additions and 37 deletions

View File

@ -11,20 +11,13 @@ summary: >
This is an implementation of paper
[Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593).
### Running the experiment
To train the model you need to download datasets from
`https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/[DATASET NAME].zip`
and extract them into folder `labml_nn/data/cycle_gan/[DATASET NAME]`.
You will also have to set the `dataset_name` configuration to `[DATASET NAME]`.
This defaults to `monet2photo`.
I've taken pieces of code from [https://github.com/eriklindernoren/PyTorch-GAN](https://github.com/eriklindernoren/PyTorch-GAN).
It is a very good resource if you want to check out other GAN variations too.
"""
import itertools
import random
from pathlib import PurePath, Path
import zipfile
from typing import Tuple
import torch
@ -32,10 +25,11 @@ import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid, save_image
from torchvision.utils import make_grid
from labml import lab, tracker, experiment, monit, configs
from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml.utils.download import download_file
from labml.utils.pytorch import get_modules
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
@ -201,13 +195,29 @@ class ImageDataset(Dataset):
Dataset to load images
"""
def __init__(self, root: PurePath, transforms_, unaligned: bool, mode: str):
root = Path(root)
@staticmethod
def download(dataset_name: str):
    """
    Download the archive for `dataset_name` and extract it into the `cycle_gan` data folder.

    The zip file is saved as `[dataset_name].zip` under `lab.get_data_path() / 'cycle_gan'`
    and is extracted into that same folder.
    """
    # URL of the dataset archive
    url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'
    # Folder that holds the downloaded archive and the extracted files
    root = lab.get_data_path() / 'cycle_gan'
    # `exist_ok=True` creates the folder only if it is missing, without a separate
    # existence check that could race with another process creating it
    root.mkdir(parents=True, exist_ok=True)
    # Download location of the archive
    archive = root / f'{dataset_name}.zip'
    # Download the archive
    download_file(url, archive)
    # Extract the archive
    with zipfile.ZipFile(archive, 'r') as f:
        f.extractall(root)
def __init__(self, dataset_name: str, transforms_, unaligned: bool, mode: str):
root = lab.get_data_path() / 'cycle_gan' / dataset_name
if not root.exists():
self.download(dataset_name)
self.transform = transforms.Compose(transforms_)
self.unaligned = unaligned
self.files_a = sorted(str(f) for f in (root / f'{mode}A').iterdir())
self.files_b = sorted(str(f) for f in (root / f'{mode}B').iterdir())
path_a = root / f'{mode}A'
path_b = root / f'{mode}B'
self.files_a = sorted(str(f) for f in path_a.iterdir())
self.files_b = sorted(str(f) for f in path_b.iterdir())
def __getitem__(self, index):
return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
@ -327,13 +337,8 @@ class Configs(BaseConfigs):
# Arrange images along y-axis
image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
# Create folder to store sampled images
images_path = Path(f'images/{self.dataset_name}')
if not images_path.exists():
images_path.mkdir(parents=True)
# Save grid
save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
# Show samples
plot_image(image_grid)
def initialize(self):
"""
@ -363,8 +368,6 @@ class Configs(BaseConfigs):
self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
# Location of the dataset
images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name
# Image transformations
transforms_ = [
@ -377,7 +380,7 @@ class Configs(BaseConfigs):
# Training data loader
self.dataloader = DataLoader(
ImageDataset(images_path, transforms_, True, 'train'),
ImageDataset(self.dataset_name, transforms_, True, 'train'),
batch_size=self.batch_size,
shuffle=True,
num_workers=self.data_loader_workers,
@ -385,7 +388,7 @@ class Configs(BaseConfigs):
# Validation data loader
self.valid_dataloader = DataLoader(
ImageDataset(images_path, transforms_, True, "test"),
ImageDataset(self.dataset_name, transforms_, True, "test"),
batch_size=5,
shuffle=True,
num_workers=self.data_loader_workers,
@ -634,14 +637,19 @@ def plot_image(img: torch.Tensor):
"""
from matplotlib import pyplot as plt
# Move tensor to CPU
img = img.cpu()
# Get min and max values of the image for normalization
img_min, img_max = img.min(), img.max()
# Scale image values to be [0...1]
img = (img - img_min) / (img_max - img_min + 1e-5)
# We have to change the order of dimensions to HWC.
img = img.permute(1, 2, 0)
# Show image
# Show Image
plt.imshow(img)
# We don't need axes
plt.axis('off')
# Display
plt.show()
@ -678,20 +686,17 @@ def evaluate():
# Start the experiment
with experiment.start():
# Load your own data, here we try test set.
# I was trying with yosemite photos, they look awesome.
# You can use `conf.dataset_name`, if you specified `dataset_name` as something you wanted to be calculated
# in the call to `experiment.configs`
images_path = lab.get_data_path() / 'cycle_gan' / 'summer2winter_yosemite'
# Image transformations
transforms_ = [
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
# Load dataset
dataset = ImageDataset(images_path, transforms_, True, 'train')
# Load your own data, here we try test set.
# I was trying with yosemite photos, they look awesome.
# You can use `conf.dataset_name`, if you specified `dataset_name` as something you wanted to be calculated
# in the call to `experiment.configs`
dataset = ImageDataset(conf.dataset_name, transforms_, True, 'train')
# Get an image from the dataset
x_image = dataset[10]['x']
# Display the image