mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-01 20:28:41 +08:00
✨ cycle gan auto download data
This commit is contained in:
@ -11,20 +11,13 @@ summary: >
|
||||
This is an implementation of paper
|
||||
[Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593).
|
||||
|
||||
### Running the experiment
|
||||
To train the model you need to download datasets from
|
||||
`https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/[DATASET NAME].zip`
|
||||
and extract them into folder `labml_nn/data/cycle_gan/[DATASET NAME]`.
|
||||
You will also have to `dataset_name` configuration to `[DATASET NAME]`.
|
||||
This defaults to `monet2photo`.
|
||||
|
||||
I've taken pieces of code from [https://github.com/eriklindernoren/PyTorch-GAN](https://github.com/eriklindernoren/PyTorch-GAN).
|
||||
It is a very good resource if you want to checkout other GAN variations too.
|
||||
"""
|
||||
|
||||
import itertools
|
||||
import random
|
||||
from pathlib import PurePath, Path
|
||||
import zipfile
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
@ -32,10 +25,11 @@ import torch.nn as nn
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision.utils import make_grid, save_image
|
||||
from torchvision.utils import make_grid
|
||||
|
||||
from labml import lab, tracker, experiment, monit, configs
|
||||
from labml import lab, tracker, experiment, monit
|
||||
from labml.configs import BaseConfigs
|
||||
from labml.utils.download import download_file
|
||||
from labml.utils.pytorch import get_modules
|
||||
from labml_helpers.device import DeviceConfigs
|
||||
from labml_helpers.module import Module
|
||||
@ -201,13 +195,29 @@ class ImageDataset(Dataset):
|
||||
Dataset to load images
|
||||
"""
|
||||
|
||||
def __init__(self, root: PurePath, transforms_, unaligned: bool, mode: str):
|
||||
root = Path(root)
|
||||
@staticmethod
|
||||
def download(dataset_name: str):
|
||||
url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'
|
||||
root = lab.get_data_path() / 'cycle_gan'
|
||||
if not root.exists():
|
||||
root.mkdir(parents=True)
|
||||
archive = root / f'{dataset_name}.zip'
|
||||
download_file(url, archive)
|
||||
with zipfile.ZipFile(archive, 'r') as f:
|
||||
f.extractall(root)
|
||||
|
||||
def __init__(self, dataset_name: str, transforms_, unaligned: bool, mode: str):
|
||||
root = lab.get_data_path() / 'cycle_gan' / dataset_name
|
||||
if not root.exists():
|
||||
self.download(dataset_name)
|
||||
|
||||
self.transform = transforms.Compose(transforms_)
|
||||
self.unaligned = unaligned
|
||||
|
||||
self.files_a = sorted(str(f) for f in (root / f'{mode}A').iterdir())
|
||||
self.files_b = sorted(str(f) for f in (root / f'{mode}B').iterdir())
|
||||
path_a = root / f'{mode}A'
|
||||
path_b = root / f'{mode}B'
|
||||
self.files_a = sorted(str(f) for f in path_a.iterdir())
|
||||
self.files_b = sorted(str(f) for f in path_b.iterdir())
|
||||
|
||||
def __getitem__(self, index):
|
||||
return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
|
||||
@ -327,13 +337,8 @@ class Configs(BaseConfigs):
|
||||
# Arrange images along y-axis
|
||||
image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
|
||||
|
||||
# Create folder to store sampled images
|
||||
images_path = Path(f'images/{self.dataset_name}')
|
||||
if not images_path.exists():
|
||||
images_path.mkdir(parents=True)
|
||||
|
||||
# Save grid
|
||||
save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
|
||||
# Show samples
|
||||
plot_image(image_grid)
|
||||
|
||||
def initialize(self):
|
||||
"""
|
||||
@ -363,8 +368,6 @@ class Configs(BaseConfigs):
|
||||
self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
|
||||
self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
|
||||
self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
|
||||
# Location of the dataset
|
||||
images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name
|
||||
|
||||
# Image transformations
|
||||
transforms_ = [
|
||||
@ -377,7 +380,7 @@ class Configs(BaseConfigs):
|
||||
|
||||
# Training data loader
|
||||
self.dataloader = DataLoader(
|
||||
ImageDataset(images_path, transforms_, True, 'train'),
|
||||
ImageDataset(self.dataset_name, transforms_, True, 'train'),
|
||||
batch_size=self.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=self.data_loader_workers,
|
||||
@ -385,7 +388,7 @@ class Configs(BaseConfigs):
|
||||
|
||||
# Validation data loader
|
||||
self.valid_dataloader = DataLoader(
|
||||
ImageDataset(images_path, transforms_, True, "test"),
|
||||
ImageDataset(self.dataset_name, transforms_, True, "test"),
|
||||
batch_size=5,
|
||||
shuffle=True,
|
||||
num_workers=self.data_loader_workers,
|
||||
@ -634,14 +637,19 @@ def plot_image(img: torch.Tensor):
|
||||
"""
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
# Move tensor to CPU
|
||||
img = img.cpu()
|
||||
# Get min and max values of the image for normalization
|
||||
img_min, img_max = img.min(), img.max()
|
||||
# Scale image values to be [0...1]
|
||||
img = (img - img_min) / (img_max - img_min + 1e-5)
|
||||
# We have to change the order of dimensions to HWC.
|
||||
img = img.permute(1, 2, 0)
|
||||
# Show image
|
||||
# Show Image
|
||||
plt.imshow(img)
|
||||
# We don't need axes
|
||||
plt.axis('off')
|
||||
# Display
|
||||
plt.show()
|
||||
|
||||
|
||||
@ -678,20 +686,17 @@ def evaluate():
|
||||
|
||||
# Start the experiment
|
||||
with experiment.start():
|
||||
# Load your own data, here we try test set.
|
||||
# I was trying with yosemite photos, they look awesome.
|
||||
# You can use `conf.dataset_name`, if you specified `dataset_name` as something you wanted to be calculated
|
||||
# in the call to `experiment.configs`
|
||||
images_path = lab.get_data_path() / 'cycle_gan' / 'summer2winter_yosemite'
|
||||
|
||||
# Image transformations
|
||||
transforms_ = [
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
]
|
||||
|
||||
# Load dataset
|
||||
dataset = ImageDataset(images_path, transforms_, True, 'train')
|
||||
# Load your own data, here we try test set.
|
||||
# I was trying with yosemite photos, they look awesome.
|
||||
# You can use `conf.dataset_name`, if you specified `dataset_name` as something you wanted to be calculated
|
||||
# in the call to `experiment.configs`
|
||||
dataset = ImageDataset(conf.dataset_name, transforms_, True, 'train')
|
||||
# Get an images from dataset
|
||||
x_image = dataset[10]['x']
|
||||
# Display the image
|
||||
|
||||
Reference in New Issue
Block a user