"""
|
|
---
|
|
title: Cycle GAN
|
|
summary: >
|
|
A simple PyTorch implementation/tutorial of Cycle GAN introduced in paper
|
|
Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.
|
|
---
|
|
|
|
# Cycle GAN
|
|
|
|
This is an implementation of paper
|
|
[Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593).
|
|
|
|
I've taken pieces of code from [https://github.com/eriklindernoren/PyTorch-GAN](https://github.com/eriklindernoren/PyTorch-GAN).
|
|
It is a very good resource if you want to checkout other GAN variations too.
|
|
|
|
Cycle GAN does image-to-image translation.
|
|
It trains a model to translate an image from given distribution to another, say images of class A and B
|
|
Images of a certain distribution could be things like images of a certain style, or nature.
|
|
The models do not need paired images between A and B.
|
|
Just a set of images of each class is enough.
|
|
This works very well on changing between image styles, lighting changes, pattern changes, etc.
|
|
For example, changing summer to winter, painting style to photos, and horses to zebras.
|
|
|
|
Cycle GAN trains two generator models and two discriminator models.
|
|
One generator translates images from A to B and the other from B to A.
|
|
The discriminators test whether the generated images look real.
|
|
|
|
This file contains the model code as well as training code.
|
|
We also have a Google Colab notebook.
|
|
|
|
[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/gan/cycle_gan.ipynb)
|
|
[](https://web.lab-ml.com/run?uuid=93b11a665d6811ebaac80242ac1c0002)
|
|
"""
|
|
|
|

import itertools
import random
import zipfile
from typing import Tuple

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml.utils.download import download_file
from labml.utils.pytorch import get_modules
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module


class GeneratorResNet(Module):
    """
    The generator is a residual network.
    """

    def __init__(self, input_channels: int, n_residual_blocks: int):
        super().__init__()
        # This first block runs a $7\times7$ convolution and maps the image to
        # a feature map.
        # The output feature map has the same height and width because we have
        # a padding of $3$.
        # Reflection padding is used because it gives better image quality at edges.
        #
        # `inplace=True` in `ReLU` saves a little bit of memory.
        out_features = 64
        layers = [
            nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

        # We down-sample with two $3 \times 3$ convolutions
        # with a stride of $2$
        for _ in range(2):
            out_features *= 2
            layers += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # We take this through `n_residual_blocks`.
        # This module is defined below.
        for _ in range(n_residual_blocks):
            layers += [ResidualBlock(out_features)]

        # Then the resulting feature map is up-sampled
        # to match the original image height and width.
        for _ in range(2):
            out_features //= 2
            layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Finally we map the feature map to an RGB image
        layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]

        # Create a sequential module with the layers
        self.layers = nn.Sequential(*layers)

        # Initialize weights to $\mathcal{N}(0, 0.02)$
        self.apply(weights_init_normal)

    def __call__(self, x):
        return self.layers(x)
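

# A quick shape sanity check (a minimal sketch): the generator is fully
# convolutional, so an image comes out at the same resolution it went in.
# The `_gen` name is illustrative, not part of the training code.
#
#     _gen = GeneratorResNet(input_channels=3, n_residual_blocks=9)
#     assert _gen(torch.zeros(1, 3, 256, 256)).shape == (1, 3, 256, 256)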


class ResidualBlock(Module):
    """
    This is the residual block, with two convolution layers.
    """

    def __init__(self, in_features: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
        )

    def __call__(self, x: torch.Tensor):
        return x + self.block(x)


class Discriminator(Module):
    """
    This is the discriminator.
    """

    def __init__(self, input_shape: Tuple[int, int, int]):
        super().__init__()
        channels, height, width = input_shape

        # Output of the discriminator is a map of probabilities,
        # whether each region of the image is real or generated
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        self.layers = nn.Sequential(
            # Each of these blocks will shrink the height and width by a factor of 2
            DiscriminatorBlock(channels, 64, normalize=False),
            DiscriminatorBlock(64, 128),
            DiscriminatorBlock(128, 256),
            DiscriminatorBlock(256, 512),
            # Zero pad on top and left to keep the output height and width the same
            # with the $4 \times 4$ kernel
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, padding=1)
        )

        # Initialize weights to $\mathcal{N}(0, 0.02)$
        self.apply(weights_init_normal)

    def forward(self, img):
        return self.layers(img)
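

# A quick shape sanity check (a minimal sketch): for a $3 \times 256 \times 256$
# input, the four down-sampling blocks give a $16 \times 16$ patch map,
# matching `output_shape`. The `_disc` name is illustrative.
#
#     _disc = Discriminator((3, 256, 256))
#     assert _disc(torch.zeros(1, 3, 256, 256)).shape == (1, 1, 16, 16)
#     assert _disc.output_shape == (1, 16, 16)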


class DiscriminatorBlock(Module):
    """
    This is the discriminator block module.
    It does a convolution, an optional normalization, and a leaky ReLU.

    It shrinks the height and width of the input feature map by half.
    """

    def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
        super().__init__()
        layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_filters))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.layers = nn.Sequential(*layers)

    def __call__(self, x: torch.Tensor):
        return self.layers(x)


def weights_init_normal(m):
    """
    Initialize convolution layer weights to $\mathcal{N}(0, 0.02)$
    """
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
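

# `Module.apply` calls this function on every submodule recursively, so the
# class-name check picks out the `nn.Conv2d` layers and leaves the rest alone.
# A minimal sketch (`_m` is an illustrative name):
#
#     _m = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.InstanceNorm2d(8))
#     _m.apply(weights_init_normal)  # only the convolution weights are re-initialized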


def load_image(path: str):
    """
    Load an image and convert it to RGB if it's in grey-scale.
    """
    image = Image.open(path)
    # Convert grey-scale (or palette) images to RGB
    if image.mode != 'RGB':
        image = image.convert('RGB')

    return image


class ImageDataset(Dataset):
    """
    ### Dataset to load images
    """

    @staticmethod
    def download(dataset_name: str):
        """
        #### Download dataset and extract data
        """
        # URL
        url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'
        # Download folder
        root = lab.get_data_path() / 'cycle_gan'
        if not root.exists():
            root.mkdir(parents=True)
        # Download destination
        archive = root / f'{dataset_name}.zip'
        # Download the file (generally ~100MB)
        download_file(url, archive)
        # Extract the archive
        with zipfile.ZipFile(archive, 'r') as f:
            f.extractall(root)

    def __init__(self, dataset_name: str, transforms_, mode: str):
        """
        #### Initialize the dataset

        * `dataset_name` is the name of the dataset
        * `transforms_` is the set of image transforms
        * `mode` is either `train` or `test`
        """
        # Dataset path
        root = lab.get_data_path() / 'cycle_gan' / dataset_name
        # Download if missing
        if not root.exists():
            self.download(dataset_name)

        # Image transforms
        self.transform = transforms.Compose(transforms_)

        # Get image paths
        path_a = root / f'{mode}A'
        path_b = root / f'{mode}B'
        self.files_a = sorted(str(f) for f in path_a.iterdir())
        self.files_b = sorted(str(f) for f in path_b.iterdir())

    def __getitem__(self, index):
        # Return a pair of images.
        # These pairs get batched together, but they do not act as pairs during training,
        # so it is fine that the same pairings keep coming up.
        return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
                "y": self.transform(load_image(self.files_b[index % len(self.files_b)]))}

    def __len__(self):
        # Number of images in the dataset
        return max(len(self.files_a), len(self.files_b))
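

# A minimal usage sketch (this downloads `monet2photo`, the default dataset,
# on first use; `_ds` and `_pair` are illustrative names):
#
#     _ds = ImageDataset('monet2photo', [transforms.ToTensor()], 'train')
#     _pair = _ds[0]  # {'x': an image from class A, 'y': an image from class B}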


class ReplayBuffer:
    """
    ### Replay Buffer

    The replay buffer is used to train the discriminators.
    Generated images are added to the replay buffer and sampled from it.

    The replay buffer returns the newly added image with a probability of $0.5$.
    Otherwise, it returns an older generated image and replaces that older image
    with the newly generated image.

    This is done to reduce model oscillation.
    """

    def __init__(self, max_size: int = 50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data: torch.Tensor):
        """Add/retrieve an image"""
        data = data.detach()
        res = []
        for element in data:
            if len(self.data) < self.max_size:
                self.data.append(element)
                res.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    res.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    res.append(element)
        return torch.stack(res)
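

# A minimal sketch of how the buffer behaves during training (`_buffer`,
# `_fake` and `_history` are illustrative names):
#
#     _buffer = ReplayBuffer(max_size=50)
#     _fake = torch.randn(4, 3, 256, 256)  # a freshly generated batch
#     # Returns a mix of new and older generated images for the discriminator
#     _history = _buffer.push_and_pop(_fake)
#     assert _history.shape == _fake.shape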


class Configs(BaseConfigs):
    """## Configurations"""

    # `DeviceConfigs` will pick a GPU if available
    device: torch.device = DeviceConfigs()

    # Hyper-parameters
    epochs: int = 200
    dataset_name: str = 'monet2photo'
    batch_size: int = 1

    data_loader_workers = 8

    learning_rate = 0.0002
    adam_betas = (0.5, 0.999)
    decay_start = 100

    # The paper suggests using a least-squares loss instead of
    # negative log-likelihood, as it is found to be more stable
    gan_loss = torch.nn.MSELoss()

    # L1 loss is used for cycle loss and identity loss
    cycle_loss = torch.nn.L1Loss()
    identity_loss = torch.nn.L1Loss()

    # Image dimensions
    img_height = 256
    img_width = 256
    img_channels = 3

    # Number of residual blocks in the generator
    n_residual_blocks = 9

    # Loss coefficients
    cyclic_loss_coefficient = 10.0
    identity_loss_coefficient = 5.

    sample_interval = 500

    # Models
    generator_xy: GeneratorResNet
    generator_yx: GeneratorResNet
    discriminator_x: Discriminator
    discriminator_y: Discriminator

    # Optimizers
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam

    # Learning rate schedules
    generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
    discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR

    # Data loaders
    dataloader: DataLoader
    valid_dataloader: DataLoader

    def sample_images(self, n: int):
        """Generate samples from the test set and plot them"""
        batch = next(iter(self.valid_dataloader))
        self.generator_xy.eval()
        self.generator_yx.eval()
        with torch.no_grad():
            data_x, data_y = batch['x'].to(self.generator_xy.device), batch['y'].to(self.generator_yx.device)
            gen_y = self.generator_xy(data_x)
            gen_x = self.generator_yx(data_y)

            # Arrange images along x-axis
            data_x = make_grid(data_x, nrow=5, normalize=True)
            data_y = make_grid(data_y, nrow=5, normalize=True)
            gen_x = make_grid(gen_x, nrow=5, normalize=True)
            gen_y = make_grid(gen_y, nrow=5, normalize=True)

            # Arrange images along y-axis
            image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)

        # Show samples
        plot_image(image_grid)

    def initialize(self):
        """
        ## Initialize models and data loaders
        """
        input_shape = (self.img_channels, self.img_height, self.img_width)

        # Create the models
        self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.discriminator_x = Discriminator(input_shape).to(self.device)
        self.discriminator_y = Discriminator(input_shape).to(self.device)

        # Create the optimizers
        self.generator_optimizer = torch.optim.Adam(
            itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)
        self.discriminator_optimizer = torch.optim.Adam(
            itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)

        # Create the learning rate schedules.
        # The learning rate stays flat until `decay_start` epochs,
        # and then linearly decays to $0$ at the end of training.
        decay_epochs = self.epochs - self.decay_start
        self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
        self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
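
        # For instance, with the defaults `epochs = 200` and `decay_start = 100`,
        # the multiplier is $1.0$ for epochs $0 \dots 100$, $0.5$ at epoch $150$,
        # and $0.01$ at epoch $199$; so the learning rate decays from
        # $2 \times 10^{-4}$ toward $0$ over the second half of training.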

        # Image transformations
        transforms_ = [
            transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC),
            transforms.RandomCrop((self.img_height, self.img_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

        # Training data loader
        self.dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, 'train'),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )

        # Validation data loader
        self.valid_dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, 'test'),
            batch_size=5,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )

    def run(self):
        """
        ## Training

        We aim to solve:
        $$G^{*}, F^{*} = \arg \min_{G,F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$

        where,
        \begin{align}
        \mathcal{L}(G, F, D_X, D_Y)
        &= \mathcal{L}_{GAN}(G, D_Y, X, Y) \\
        &+ \mathcal{L}_{GAN}(F, D_X, Y, X) \\
        &+ \lambda_1 \mathcal{L}_{cyc}(G, F) \\
        &+ \lambda_2 \mathcal{L}_{identity}(G, F) \\
        \\
        \mathcal{L}_{GAN}(G, D_Y, X, Y) + \mathcal{L}_{GAN}(F, D_X, Y, X)
        &= \mathbb{E}_{y \sim p_{data}(y)} \Big[\log D_Y(y)\Big] \\
        &+ \mathbb{E}_{x \sim p_{data}(x)} \bigg[\log\Big(1 - D_Y(G(x))\Big)\bigg] \\
        &+ \mathbb{E}_{x \sim p_{data}(x)} \Big[\log D_X(x)\Big] \\
        &+ \mathbb{E}_{y \sim p_{data}(y)} \bigg[\log\Big(1 - D_X(F(y))\Big)\bigg] \\
        \\
        \mathcal{L}_{cyc}(G, F)
        &= \mathbb{E}_{x \sim p_{data}(x)} \Big[\lVert F(G(x)) - x \rVert_1\Big] \\
        &+ \mathbb{E}_{y \sim p_{data}(y)} \Big[\lVert G(F(y)) - y \rVert_1\Big] \\
        \\
        \mathcal{L}_{identity}(G, F)
        &= \mathbb{E}_{x \sim p_{data}(x)} \Big[\lVert F(x) - x \rVert_1\Big] \\
        &+ \mathbb{E}_{y \sim p_{data}(y)} \Big[\lVert G(y) - y \rVert_1\Big] \\
        \end{align}

        $\mathcal{L}_{GAN}$ is the generative adversarial loss from the original
        GAN paper.

        $\mathcal{L}_{cyc}$ is the cyclic loss, where we try to get $F(G(x))$ to be similar to $x$,
        and $G(F(y))$ to be similar to $y$.
        Basically, if the two generators (transformations) are applied in series, they should give back the
        original image.
        This is the main contribution of the paper.
        It trains the generators to generate images of the other distribution that are similar to
        the original images.
        Without this loss, $G(x)$ could generate anything from the distribution of $Y$.
        Now it needs to generate something from the distribution of $Y$ that still has the properties of $x$,
        so that $F(G(x))$ can regenerate something like $x$.

        $\mathcal{L}_{identity}$ is the identity loss.
        It encourages the mappings to preserve the color composition between
        the input and the output.

        To solve for $G^{*}$ and $F^{*}$,
        the discriminators $D_X$ and $D_Y$ should **ascend** on the gradient,
        \begin{align}
        \nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^m
        &\Bigg[
        \log D_Y\Big(y^{(i)}\Big) \\
        &+ \log \Big(1 - D_Y\Big(G\Big(x^{(i)}\Big)\Big)\Big) \\
        &+ \log D_X\Big(x^{(i)}\Big) \\
        &+ \log\Big(1 - D_X\Big(F\Big(y^{(i)}\Big)\Big)\Big)
        \Bigg]
        \end{align}
        That is, descend on the *negative* log-likelihood loss.

        In order to stabilize the training, the negative log-likelihood objective
        is replaced by a least-squares loss:
        the least-squared error of the discriminator labelling real images with $1$,
        and generated images with $0$.
        So we want to descend on the gradient,
        \begin{align}
        \nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^m
        &\Bigg[
        \bigg(D_Y\Big(y^{(i)}\Big) - 1\bigg)^2 \\
        &+ D_Y\Big(G\Big(x^{(i)}\Big)\Big)^2 \\
        &+ \bigg(D_X\Big(x^{(i)}\Big) - 1\bigg)^2 \\
        &+ D_X\Big(F\Big(y^{(i)}\Big)\Big)^2
        \Bigg]
        \end{align}

        We use least-squares for the generators as well.
        The generators should *descend* on the gradient,
        \begin{align}
        \nabla_{\theta_{F, G}} \frac{1}{m} \sum_{i=1}^m
        &\Bigg[
        \bigg(D_Y\Big(G\Big(x^{(i)}\Big)\Big) - 1\bigg)^2 \\
        &+ \bigg(D_X\Big(F\Big(y^{(i)}\Big)\Big) - 1\bigg)^2 \\
        &+ \mathcal{L}_{cyc}(G, F)
        + \mathcal{L}_{identity}(G, F)
        \Bigg]
        \end{align}

        We use `generator_xy` for $G$ and `generator_yx` for $F$.
        We use `discriminator_x` for $D_X$ and `discriminator_y` for $D_Y$.
        """

        # Replay buffers to keep generated samples
        gen_x_buffer = ReplayBuffer()
        gen_y_buffer = ReplayBuffer()

        # Loop through epochs
        for epoch in monit.loop(self.epochs):
            # Loop through the dataset
            for i, batch in monit.enum('Train', self.dataloader):
                # Move images to the device
                data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)

                # true labels equal to $1$
                true_labels = torch.ones(data_x.size(0), *self.discriminator_x.output_shape,
                                         device=self.device, requires_grad=False)
                # false labels equal to $0$
                false_labels = torch.zeros(data_x.size(0), *self.discriminator_x.output_shape,
                                           device=self.device, requires_grad=False)

                # Train the generators.
                # This returns the generated images.
                gen_x, gen_y = self.optimize_generators(data_x, data_y, true_labels)

                # Train discriminators
                self.optimize_discriminator(data_x, data_y,
                                            gen_x_buffer.push_and_pop(gen_x), gen_y_buffer.push_and_pop(gen_y),
                                            true_labels, false_labels)

                # Save training statistics and increment the global step counter
                tracker.save()
                tracker.add_global_step(max(len(data_x), len(data_y)))

                # Sample images at intervals
                batches_done = epoch * len(self.dataloader) + i
                if batches_done % self.sample_interval == 0:
                    # Save models when sampling images
                    experiment.save_checkpoint()
                    # Sample images
                    self.sample_images(batches_done)

            # Update learning rates
            self.generator_lr_scheduler.step()
            self.discriminator_lr_scheduler.step()
            # New line in the console output
            tracker.new_line()

    def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):
        """
        ### Optimize the generators with identity, GAN and cycle losses
        """

        # Change to training mode
        self.generator_xy.train()
        self.generator_yx.train()

        # Identity loss
        # $$\lVert F(x^{(i)}) - x^{(i)} \rVert_1 +
        # \lVert G(y^{(i)}) - y^{(i)} \rVert_1$$
        loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
                         self.identity_loss(self.generator_xy(data_y), data_y))

        # Generate images $G(x)$ and $F(y)$
        gen_y = self.generator_xy(data_x)
        gen_x = self.generator_yx(data_y)

        # GAN loss
        # $$\bigg(D_Y\Big(G\Big(x^{(i)}\Big)\Big) - 1\bigg)^2
        # + \bigg(D_X\Big(F\Big(y^{(i)}\Big)\Big) - 1\bigg)^2$$
        loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
                    self.gan_loss(self.discriminator_x(gen_x), true_labels))

        # Cycle loss
        # $$
        # \lVert F(G(x^{(i)})) - x^{(i)} \rVert_1 +
        # \lVert G(F(y^{(i)})) - y^{(i)} \rVert_1
        # $$
        loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
                      self.cycle_loss(self.generator_xy(gen_x), data_y))

        # Total loss
        loss_generator = (loss_gan +
                          self.cyclic_loss_coefficient * loss_cycle +
                          self.identity_loss_coefficient * loss_identity)

        # Take a step in the optimizer
        self.generator_optimizer.zero_grad()
        loss_generator.backward()
        self.generator_optimizer.step()

        # Log losses
        tracker.add({'loss.generator': loss_generator,
                     'loss.generator.cycle': loss_cycle,
                     'loss.generator.gan': loss_gan,
                     'loss.generator.identity': loss_identity})

        # Return generated images
        return gen_x, gen_y

    def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
                               gen_x: torch.Tensor, gen_y: torch.Tensor,
                               true_labels: torch.Tensor, false_labels: torch.Tensor):
        """
        ### Optimize the discriminators with GAN loss
        """
        # GAN Loss
        # \begin{align}
        # \bigg(D_Y\Big(y^{(i)}\Big) - 1\bigg)^2
        # + D_Y\Big(G\Big(x^{(i)}\Big)\Big)^2 + \\
        # \bigg(D_X\Big(x^{(i)}\Big) - 1\bigg)^2
        # + D_X\Big(F\Big(y^{(i)}\Big)\Big)^2
        # \end{align}
        loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
                              self.gan_loss(self.discriminator_x(gen_x), false_labels) +
                              self.gan_loss(self.discriminator_y(data_y), true_labels) +
                              self.gan_loss(self.discriminator_y(gen_y), false_labels))

        # Take a step in the optimizer
        self.discriminator_optimizer.zero_grad()
        loss_discriminator.backward()
        self.discriminator_optimizer.step()

        # Log losses
        tracker.add({'loss.discriminator': loss_discriminator})


def train():
    """
    ## Train Cycle GAN
    """
    # Create configurations
    conf = Configs()
    # Create an experiment
    experiment.create(name='cycle_gan')
    # Calculate configurations.
    # It will calculate `conf.run` and all other configs required by it.
    experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
    conf.initialize()

    # Register models for saving and loading.
    # `get_modules` gives a dictionary of `nn.Module`s in `conf`.
    # You can also specify a custom dictionary of models.
    experiment.add_pytorch_models(get_modules(conf))
    # Start and watch the experiment
    with experiment.start():
        # Run the training
        conf.run()


def plot_image(img: torch.Tensor):
    """
    ### Plot an image with matplotlib
    """
    from matplotlib import pyplot as plt

    # Move tensor to CPU
    img = img.cpu()
    # Get min and max values of the image for normalization
    img_min, img_max = img.min(), img.max()
    # Scale image values to be $[0, 1]$
    img = (img - img_min) / (img_max - img_min + 1e-5)
    # We have to change the order of dimensions to HWC
    img = img.permute(1, 2, 0)
    # Show image
    plt.imshow(img)
    # We don't need axes
    plt.axis('off')
    # Display
    plt.show()


def evaluate():
    """
    ## Evaluate trained Cycle GAN
    """
    # Set the run UUID of the training run
    trained_run_uuid = 'f73c1164184711eb9190b74249275441'
    # Create configs object
    conf = Configs()
    # Create experiment
    experiment.create(name='cycle_gan_inference')
    # Load hyper-parameters set for training
    conf_dict = experiment.load_configs(trained_run_uuid)
    # Calculate configurations. We specify the generators `'generator_xy', 'generator_yx'`
    # so that it only loads those and their dependencies.
    # Configs like `device` and `img_channels` will be calculated, since these are required by
    # `generator_xy` and `generator_yx`.
    #
    # If you want other parameters like `dataset_name`, you should specify them here.
    # If you specify nothing, all the configurations will be calculated, including data loaders.
    # Calculation of configurations and their dependencies will happen when you call `experiment.start`
    experiment.configs(conf, conf_dict)
    conf.initialize()

    # Register models for saving and loading.
    # `get_modules` gives a dictionary of `nn.Module`s in `conf`.
    # You can also specify a custom dictionary of models.
    experiment.add_pytorch_models(get_modules(conf))
    # Specify which run to load from.
    # Loading will actually happen when you call `experiment.start`
    experiment.load(trained_run_uuid)

    # Start the experiment
    with experiment.start():
        # Image transformations
        transforms_ = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

        # Load your own data; here we use images from the training set.
        # I was trying with Yosemite photos; they look awesome.
        # You can use `conf.dataset_name`, if you specified `dataset_name` as something you wanted to be calculated
        # in the call to `experiment.configs`
        dataset = ImageDataset(conf.dataset_name, transforms_, 'train')
        # Get an image from the dataset
        x_image = dataset[10]['x']
        # Display the image
        plot_image(x_image)

        # Evaluation mode
        conf.generator_xy.eval()
        conf.generator_yx.eval()

        # We don't need gradients
        with torch.no_grad():
            # Add a batch dimension and move to the device we use
            data = x_image.unsqueeze(0).to(conf.device)
            generated_y = conf.generator_xy(data)

        # Display the generated image
        plot_image(generated_y[0].cpu())


if __name__ == '__main__':
    train()
    # evaluate()