cycle gan auto download data

This commit is contained in:
Varuna Jayasiri
2021-01-23 15:53:57 +05:30
parent 83478823df
commit ebbe704d65
3 changed files with 42 additions and 37 deletions

View File

@ -11,20 +11,13 @@ summary: >
This is an implementation of paper This is an implementation of paper
[Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593). [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593).
### Running the experiment
To train the model you need to download datasets from
`https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/[DATASET NAME].zip`
and extract them into folder `labml_nn/data/cycle_gan/[DATASET NAME]`.
You will also have to `dataset_name` configuration to `[DATASET NAME]`.
This defaults to `monet2photo`.
I've taken pieces of code from [https://github.com/eriklindernoren/PyTorch-GAN](https://github.com/eriklindernoren/PyTorch-GAN). I've taken pieces of code from [https://github.com/eriklindernoren/PyTorch-GAN](https://github.com/eriklindernoren/PyTorch-GAN).
It is a very good resource if you want to checkout other GAN variations too. It is a very good resource if you want to checkout other GAN variations too.
""" """
import itertools import itertools
import random import random
from pathlib import PurePath, Path import zipfile
from typing import Tuple from typing import Tuple
import torch import torch
@ -32,10 +25,11 @@ import torch.nn as nn
import torchvision.transforms as transforms import torchvision.transforms as transforms
from PIL import Image from PIL import Image
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid, save_image from torchvision.utils import make_grid
from labml import lab, tracker, experiment, monit, configs from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs from labml.configs import BaseConfigs
from labml.utils.download import download_file
from labml.utils.pytorch import get_modules from labml.utils.pytorch import get_modules
from labml_helpers.device import DeviceConfigs from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module from labml_helpers.module import Module
@ -201,13 +195,29 @@ class ImageDataset(Dataset):
Dataset to load images Dataset to load images
""" """
def __init__(self, root: PurePath, transforms_, unaligned: bool, mode: str): @staticmethod
root = Path(root) def download(dataset_name: str):
url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'
root = lab.get_data_path() / 'cycle_gan'
if not root.exists():
root.mkdir(parents=True)
archive = root / f'{dataset_name}.zip'
download_file(url, archive)
with zipfile.ZipFile(archive, 'r') as f:
f.extractall(root)
def __init__(self, dataset_name: str, transforms_, unaligned: bool, mode: str):
root = lab.get_data_path() / 'cycle_gan' / dataset_name
if not root.exists():
self.download(dataset_name)
self.transform = transforms.Compose(transforms_) self.transform = transforms.Compose(transforms_)
self.unaligned = unaligned self.unaligned = unaligned
self.files_a = sorted(str(f) for f in (root / f'{mode}A').iterdir()) path_a = root / f'{mode}A'
self.files_b = sorted(str(f) for f in (root / f'{mode}B').iterdir()) path_b = root / f'{mode}B'
self.files_a = sorted(str(f) for f in path_a.iterdir())
self.files_b = sorted(str(f) for f in path_b.iterdir())
def __getitem__(self, index): def __getitem__(self, index):
return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])), return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
@ -327,13 +337,8 @@ class Configs(BaseConfigs):
# Arrange images along y-axis # Arrange images along y-axis
image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1) image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
# Create folder to store sampled images # Show samples
images_path = Path(f'images/{self.dataset_name}') plot_image(image_grid)
if not images_path.exists():
images_path.mkdir(parents=True)
# Save grid
save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
def initialize(self): def initialize(self):
""" """
@ -363,8 +368,6 @@ class Configs(BaseConfigs):
self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs) self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR( self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs) self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
# Location of the dataset
images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name
# Image transformations # Image transformations
transforms_ = [ transforms_ = [
@ -377,7 +380,7 @@ class Configs(BaseConfigs):
# Training data loader # Training data loader
self.dataloader = DataLoader( self.dataloader = DataLoader(
ImageDataset(images_path, transforms_, True, 'train'), ImageDataset(self.dataset_name, transforms_, True, 'train'),
batch_size=self.batch_size, batch_size=self.batch_size,
shuffle=True, shuffle=True,
num_workers=self.data_loader_workers, num_workers=self.data_loader_workers,
@ -385,7 +388,7 @@ class Configs(BaseConfigs):
# Validation data loader # Validation data loader
self.valid_dataloader = DataLoader( self.valid_dataloader = DataLoader(
ImageDataset(images_path, transforms_, True, "test"), ImageDataset(self.dataset_name, transforms_, True, "test"),
batch_size=5, batch_size=5,
shuffle=True, shuffle=True,
num_workers=self.data_loader_workers, num_workers=self.data_loader_workers,
@ -634,14 +637,19 @@ def plot_image(img: torch.Tensor):
""" """
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
# Move tensor to CPU
img = img.cpu()
# Get min and max values of the image for normalization # Get min and max values of the image for normalization
img_min, img_max = img.min(), img.max() img_min, img_max = img.min(), img.max()
# Scale image values to be [0...1] # Scale image values to be [0...1]
img = (img - img_min) / (img_max - img_min + 1e-5) img = (img - img_min) / (img_max - img_min + 1e-5)
# We have to change the order of dimensions to HWC. # We have to change the order of dimensions to HWC.
img = img.permute(1, 2, 0) img = img.permute(1, 2, 0)
# Show image # Show Image
plt.imshow(img) plt.imshow(img)
# We don't need axes
plt.axis('off')
# Display
plt.show() plt.show()
@ -678,20 +686,17 @@ def evaluate():
# Start the experiment # Start the experiment
with experiment.start(): with experiment.start():
# Load your own data, here we try test set.
# I was trying with yosemite photos, they look awesome.
# You can use `conf.dataset_name`, if you specified `dataset_name` as something you wanted to be calculated
# in the call to `experiment.configs`
images_path = lab.get_data_path() / 'cycle_gan' / 'summer2winter_yosemite'
# Image transformations # Image transformations
transforms_ = [ transforms_ = [
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
] ]
# Load dataset # Load your own data, here we try test set.
dataset = ImageDataset(images_path, transforms_, True, 'train') # I was trying with yosemite photos, they look awesome.
# You can use `conf.dataset_name`, if you specified `dataset_name` as something you wanted to be calculated
# in the call to `experiment.configs`
dataset = ImageDataset(conf.dataset_name, transforms_, True, 'train')
# Get an images from dataset # Get an images from dataset
x_image = dataset[10]['x'] x_image = dataset[10]['x']
# Display the image # Display the image

View File

@ -1,5 +1,5 @@
torch>=1.6 torch>=1.6
labml>=0.4.86 labml>=0.4.94
labml-helpers>=0.4.72 labml-helpers>=0.4.72
torchvision torchvision
numpy>=1.16.3 numpy>=1.16.3

View File

@ -5,7 +5,7 @@ with open("readme.md", "r") as f:
setuptools.setup( setuptools.setup(
name='labml-nn', name='labml-nn',
version='0.4.79', version='0.4.80',
author="Varuna Jayasiri, Nipun Wijerathne", author="Varuna Jayasiri, Nipun Wijerathne",
author_email="vpjayasiri@gmail.com, hnipun@gmail.com", author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
description="A collection of PyTorch implementations of neural network architectures and layers.", description="A collection of PyTorch implementations of neural network architectures and layers.",
@ -20,7 +20,7 @@ setuptools.setup(
'labml_helpers', 'labml_helpers.*', 'labml_helpers', 'labml_helpers.*',
'test', 'test',
'test.*')), 'test.*')),
install_requires=['labml>=0.4.86', install_requires=['labml>=0.4.94',
'labml-helpers>=0.4.72', 'labml-helpers>=0.4.72',
'torch', 'torch',
'einops', 'einops',