Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git, synced 2025-11-02 21:40:15 +08:00
✨ cycle gan auto download data
@@ -11,20 +11,13 @@ summary: >
 This is an implementation of paper
 [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593).

-### Running the experiment
-To train the model you need to download datasets from
-`https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/[DATASET NAME].zip`
-and extract them into folder `labml_nn/data/cycle_gan/[DATASET NAME]`.
-You will also have to `dataset_name` configuration to `[DATASET NAME]`.
-This defaults to `monet2photo`.
-
 I've taken pieces of code from [https://github.com/eriklindernoren/PyTorch-GAN](https://github.com/eriklindernoren/PyTorch-GAN).
 It is a very good resource if you want to checkout other GAN variations too.
 """

 import itertools
 import random
-from pathlib import PurePath, Path
+import zipfile
 from typing import Tuple

 import torch
@@ -32,10 +25,11 @@ import torch.nn as nn
 import torchvision.transforms as transforms
 from PIL import Image
 from torch.utils.data import DataLoader, Dataset
-from torchvision.utils import make_grid, save_image
+from torchvision.utils import make_grid

-from labml import lab, tracker, experiment, monit, configs
+from labml import lab, tracker, experiment, monit
 from labml.configs import BaseConfigs
+from labml.utils.download import download_file
 from labml.utils.pytorch import get_modules
 from labml_helpers.device import DeviceConfigs
 from labml_helpers.module import Module
@@ -201,13 +195,29 @@ class ImageDataset(Dataset):
     Dataset to load images
     """

-    def __init__(self, root: PurePath, transforms_, unaligned: bool, mode: str):
-        root = Path(root)
+    @staticmethod
+    def download(dataset_name: str):
+        url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'
+        root = lab.get_data_path() / 'cycle_gan'
+        if not root.exists():
+            root.mkdir(parents=True)
+        archive = root / f'{dataset_name}.zip'
+        download_file(url, archive)
+        with zipfile.ZipFile(archive, 'r') as f:
+            f.extractall(root)
+
+    def __init__(self, dataset_name: str, transforms_, unaligned: bool, mode: str):
+        root = lab.get_data_path() / 'cycle_gan' / dataset_name
+        if not root.exists():
+            self.download(dataset_name)
+
         self.transform = transforms.Compose(transforms_)
         self.unaligned = unaligned

-        self.files_a = sorted(str(f) for f in (root / f'{mode}A').iterdir())
-        self.files_b = sorted(str(f) for f in (root / f'{mode}B').iterdir())
+        path_a = root / f'{mode}A'
+        path_b = root / f'{mode}B'
+        self.files_a = sorted(str(f) for f in path_a.iterdir())
+        self.files_b = sorted(str(f) for f in path_b.iterdir())

     def __getitem__(self, index):
         return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
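For orientation, a minimal sketch of how the reworked `ImageDataset` might be exercised after this change. It assumes the snippet runs alongside this module (so `ImageDataset` and `lab` are in scope) and uses an illustrative transform list rather than the full pipeline configured elsewhere in the file; constructing the dataset by name is what triggers the download and extraction under `lab.get_data_path() / 'cycle_gan'` when the folder is missing.

    import torchvision.transforms as transforms

    # Illustrative transforms; the training pipeline in this file also resizes and crops.
    transforms_ = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]

    # `monet2photo` is the default dataset name; on first use the zip is fetched
    # and extracted because `root.exists()` is False at that point.
    dataset = ImageDataset('monet2photo', transforms_, unaligned=True, mode='train')
    # Each item is a dict; the 'x' entry is an image tensor from the A set.
    x = dataset[0]['x']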
@@ -327,13 +337,8 @@ class Configs(BaseConfigs):
         # Arrange images along y-axis
         image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)

-        # Create folder to store sampled images
-        images_path = Path(f'images/{self.dataset_name}')
-        if not images_path.exists():
-            images_path.mkdir(parents=True)
-
-        # Save grid
-        save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
+        # Show samples
+        plot_image(image_grid)

     def initialize(self):
         """
@@ -363,8 +368,6 @@ class Configs(BaseConfigs):
             self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
         self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
             self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
-        # Location of the dataset
-        images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name

         # Image transformations
         transforms_ = [
@@ -377,7 +380,7 @@ class Configs(BaseConfigs):

         # Training data loader
         self.dataloader = DataLoader(
-            ImageDataset(images_path, transforms_, True, 'train'),
+            ImageDataset(self.dataset_name, transforms_, True, 'train'),
             batch_size=self.batch_size,
             shuffle=True,
             num_workers=self.data_loader_workers,
@@ -385,7 +388,7 @@ class Configs(BaseConfigs):

         # Validation data loader
         self.valid_dataloader = DataLoader(
-            ImageDataset(images_path, transforms_, True, "test"),
+            ImageDataset(self.dataset_name, transforms_, True, "test"),
             batch_size=5,
             shuffle=True,
             num_workers=self.data_loader_workers,
@@ -634,14 +637,19 @@ def plot_image(img: torch.Tensor):
     """
     from matplotlib import pyplot as plt

+    # Move tensor to CPU
+    img = img.cpu()
     # Get min and max values of the image for normalization
     img_min, img_max = img.min(), img.max()
     # Scale image values to be [0...1]
     img = (img - img_min) / (img_max - img_min + 1e-5)
     # We have to change the order of dimensions to HWC.
     img = img.permute(1, 2, 0)
-    # Show image
+    # Show Image
     plt.imshow(img)
+    # We don't need axes
+    plt.axis('off')
+    # Display
     plt.show()

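A quick way to sanity-check the updated `plot_image` in isolation, as a sketch with random data (again assuming it runs alongside this module):

    import torch

    # A fake 3-channel 256x256 image in CHW order, like the tensors produced above.
    fake_img = torch.randn(3, 256, 256)
    # The updated function moves the tensor to the CPU, rescales it to [0, 1],
    # permutes it to HWC, and shows it with the axes turned off.
    plot_image(fake_img)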
@@ -678,20 +686,17 @@ def evaluate():

     # Start the experiment
     with experiment.start():
-        # Load your own data, here we try test set.
-        # I was trying with yosemite photos, they look awesome.
-        # You can use `conf.dataset_name`, if you specified `dataset_name` as something you wanted to be calculated
-        # in the call to `experiment.configs`
-        images_path = lab.get_data_path() / 'cycle_gan' / 'summer2winter_yosemite'
-
         # Image transformations
         transforms_ = [
             transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
         ]

-        # Load dataset
-        dataset = ImageDataset(images_path, transforms_, True, 'train')
+        # Load your own data, here we try test set.
+        # I was trying with yosemite photos, they look awesome.
+        # You can use `conf.dataset_name`, if you specified `dataset_name` as something you wanted to be calculated
+        # in the call to `experiment.configs`
+        dataset = ImageDataset(conf.dataset_name, transforms_, True, 'train')
         # Get an images from dataset
         x_image = dataset[10]['x']
         # Display the image
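With the hard-coded `summer2winter_yosemite` path gone, `evaluate()` reads the dataset from `conf.dataset_name`. A hedged sketch of overriding it at setup time, assuming the usual labml pattern of passing an override dict to `experiment.configs` (the experiment name below is made up for illustration):

    conf = Configs()
    experiment.create(name='cycle_gan_evaluate')
    # Override the default `monet2photo`; `evaluate()` then builds its
    # `ImageDataset` from this name and auto-downloads it if needed.
    experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})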
requirements.txt
@@ -1,5 +1,5 @@
 torch>=1.6
-labml>=0.4.86
+labml>=0.4.94
 labml-helpers>=0.4.72
 torchvision
 numpy>=1.16.3
setup.py
@@ -5,7 +5,7 @@ with open("readme.md", "r") as f:

 setuptools.setup(
     name='labml-nn',
-    version='0.4.79',
+    version='0.4.80',
     author="Varuna Jayasiri, Nipun Wijerathne",
     author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
     description="A collection of PyTorch implementations of neural network architectures and layers.",
@@ -20,7 +20,7 @@ setuptools.setup(
                                               'labml_helpers', 'labml_helpers.*',
                                               'test',
                                               'test.*')),
-    install_requires=['labml>=0.4.86',
+    install_requires=['labml>=0.4.94',
                       'labml-helpers>=0.4.72',
                       'torch',
                       'einops',