made some changes

Author: KeshSam
Date: 2021-02-14 14:37:50 +00:00
Committed by: Varuna Jayasiri
Parent: fc437c936b
Commit: 33c3e8b8ce


@@ -8,14 +8,14 @@ summary: >
# Cycle GAN
-This is a [PyTorch](https://pytorch.org) implementation/tutorial of paper
+This is a [PyTorch](https://pytorch.org) implementation/tutorial of the paper
[Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593).
I've taken pieces of code from [eriklindernoren/PyTorch-GAN](https://github.com/eriklindernoren/PyTorch-GAN).
It is a very good resource if you want to check out other GAN variations too.
Cycle GAN does image-to-image translation.
-It trains a model to translate an image from given distribution to another, say images of class A and B
+It trains a model to translate an image from a given distribution to another, say, images of class A and B.
Images of a certain distribution could be things like images of a certain style, or nature.
The models do not need paired images between A and B.
Just a set of images of each class is enough.
@@ -26,7 +26,7 @@ Cycle GAN trains two generator models and two discriminator models.
One generator translates images from A to B and the other from B to A.
The discriminators test whether the generated images look real.
-This file contains the model code as well as training code.
+This file contains the model code as well as the training code.
We also have a Google Colab notebook.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/gan/cycle_gan.ipynb)
@@ -62,7 +62,7 @@ class GeneratorResNet(Module):
super().__init__()
# This first block runs a $7\times7$ convolution and maps the image to
# a feature map.
-# The output feature map has same height and width because we have
+# The output feature map has the same height and width because we have
# a padding of $3$.
# Reflection padding is used because it gives better image quality at edges.
#
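As a sanity check of that shape claim, here is a minimal sketch of such a first block (assuming instance normalization and ReLU after the convolution, as is typical for this architecture; channel counts are illustrative, not the repository's exact values):

```python
import torch
import torch.nn as nn

# Reflection-pad by 3 on each side, then a 7x7 convolution:
# output size = (H + 2*3 - 7) + 1 = H, so height and width are preserved.
block = nn.Sequential(
    nn.ReflectionPad2d(3),
    nn.Conv2d(3, 64, kernel_size=7),
    nn.InstanceNorm2d(64),
    nn.ReLU(inplace=True),
)

x = torch.randn(1, 3, 256, 256)
assert block(x).shape == (1, 64, 256, 256)
```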
@@ -145,7 +145,7 @@ class Discriminator(Module):
super().__init__()
channels, height, width = input_shape
-# Output of the discriminator is also map of probabilities*
+# Output of the discriminator is also a map of probabilities*
# whether each region of the image is real or generated
self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
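For instance, with a hypothetical $256\times256$ input, the shape arithmetic works out as follows (four halving blocks give a $2^4 = 16\times$ reduction, so each cell of the output map scores one patch of the input):

```python
# Four halving blocks shrink a 256x256 input to a 16x16 map of patch scores.
height, width = 256, 256
output_shape = (1, height // 2 ** 4, width // 2 ** 4)
print(output_shape)  # (1, 16, 16)
```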
@@ -171,7 +171,7 @@ class Discriminator(Module):
class DiscriminatorBlock(Module):
"""
This is the discriminator block module.
-It does a convolution, an optional normalization, and a leaky relu.
+It does a convolution, an optional normalization, and a leaky ReLU.
It shrinks the height and width of the input feature map by half.
"""
@@ -199,7 +199,7 @@ def weights_init_normal(m):
def load_image(path: str):
"""
-Loads an image and change to RGB if in grey-scale.
+Load an image and convert it to RGB if it is in grey-scale.
"""
image = Image.open(path)
if image.mode != 'RGB':
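The hunk cuts the function off at the mode check; a completed sketch of the same idea, using Pillow's standard `convert` (the repository's exact body may differ):

```python
from PIL import Image

def load_image(path: str) -> Image.Image:
    """Load an image and convert it to RGB if it is in grey-scale."""
    image = Image.open(path)
    if image.mode != 'RGB':
        # Pillow's convert handles grey-scale ('L'), palette ('P'), etc.
        image = image.convert('RGB')
    return image
```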
@@ -257,7 +257,7 @@ class ImageDataset(Dataset):
def __getitem__(self, index):
# Return a pair of images.
-# These pairs get batched together, and they do not act like pairs in training
+# These pairs get batched together, and they do not act like pairs in training.
# So it is kind of ok that we always keep giving the same pair.
return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
"y": self.transform(load_image(self.files_b[index % len(self.files_b)]))}
@@ -275,8 +275,8 @@ class ReplayBuffer:
Generated images are added to the replay buffer and sampled from it.
The replay buffer returns the newly added image with a probability of $0.5$.
-Otherwise it sends an older generated image and and replaces the older image
-with the new generated image.
+Otherwise, it sends an older generated image and replaces the older image
+with the newly generated image.
This is done to reduce model oscillation.
"""
@@ -404,7 +404,7 @@ class Configs(BaseConfigs):
# Create the learning rate schedules.
# The learning rate starts flat until `decay_start` epochs,
-# and then linearly reduces to $0$ at end of training.
+# and then linearly reduces to $0$ at the end of training.
decay_epochs = self.epochs - self.decay_start
self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
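A quick numeric check of the multiplier that `lr_lambda` produces, with hypothetical values `epochs = 200` and `decay_start = 100`:

```python
epochs, decay_start = 200, 100
decay_epochs = epochs - decay_start

def lr_lambda(e):
    return 1.0 - max(0, e - decay_start) / decay_epochs

# Flat at 1.0 until `decay_start`, then linear down to 0 at the last epoch.
print(lr_lambda(0), lr_lambda(100), lr_lambda(150), lr_lambda(200))
# 1.0 1.0 0.5 0.0
```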
@@ -478,10 +478,10 @@ class Configs(BaseConfigs):
Basically, if the two generators (transformations) are applied in series, it should give back the
original image.
This is the main contribution of this paper.
-It train the generators to generate an image of the other distribution that is similar to
+It trains the generators to generate an image of the other distribution that is similar to
the original image.
Without this loss $G(x)$ could generate anything that's from the distribution of $Y$.
-Now it needs to generate something from the distribution of $Y$ but still have properties of $x$,
+Now it needs to generate something from the distribution of $Y$ that still has the properties of $x$,
so that $F(G(x))$ can re-generate something like $x$.
$\mathcal{L}_{cyc}$ is the cycle consistency loss.
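In code, the cycle consistency term is just an $L_1$ distance between the round-tripped image and the original, as in the paper; a hedged sketch (`generator_xy` and `generator_yx` stand in for $G$ and $F$):

```python
import torch

cycle_loss = torch.nn.L1Loss()

def cycle_consistency(generator_xy, generator_yx, x: torch.Tensor) -> torch.Tensor:
    # F(G(x)) should reconstruct x: translating X -> Y -> X must round-trip.
    return cycle_loss(generator_yx(generator_xy(x)), x)
```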
@@ -503,7 +503,7 @@ class Configs(BaseConfigs):
In order to stabilize the training, the negative log-likelihood objective
was replaced by a least-squared loss:
-the least-squared error of discriminator labelling real images with 1,
+the least-squared error of the discriminator labelling real images with 1,
and generated images with 0.
So we want to descend along the gradient of,
\begin{align}
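The equation itself is cut off by the hunk boundary, but the objective described amounts to a mean-squared error against the labels $1$ (real) and $0$ (generated); a sketch under those assumptions (names are illustrative):

```python
import torch

gan_loss = torch.nn.MSELoss()

def discriminator_loss(discriminator, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    # Least-squares GAN: push D(real) towards 1 and D(fake) towards 0.
    pred_real = discriminator(real)
    pred_fake = discriminator(fake.detach())
    return (gan_loss(pred_real, torch.ones_like(pred_real)) +
            gan_loss(pred_fake, torch.zeros_like(pred_fake)))
```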
@@ -681,7 +681,7 @@ def train():
def plot_image(img: torch.Tensor):
"""
-### Plots an image with matplotlib
+### Plot an image with matplotlib
"""
from matplotlib import pyplot as plt
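The hunk stops at the import; a completed sketch, assuming the images are normalized to $[-1, 1]$ by the `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` transform used later in this file:

```python
import torch
from matplotlib import pyplot as plt

def plot_image(img: torch.Tensor):
    # Undo the [-1, 1] normalization and move channels last for imshow.
    img = img.mul(0.5).add(0.5).clamp(0, 1)
    plt.imshow(img.permute(1, 2, 0).cpu().numpy())
    plt.axis('off')
    plt.show()
```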
@@ -705,7 +705,7 @@ def evaluate():
"""
## Evaluate trained Cycle GAN
"""
-# Set the run uuid from the training run
+# Set the run UUID from the training run
trained_run_uuid = 'f73c1164184711eb9190b74249275441'
# Create configs object
conf = Configs()
@@ -715,11 +715,11 @@ def evaluate():
conf_dict = experiment.load_configs(trained_run_uuid)
# Calculate configurations. We specify the generators `'generator_xy', 'generator_yx'`
# so that it only loads those and their dependencies.
-# Configs like `device` and `img_channels` will be calculated since these are required by
+# Configs like `device` and `img_channels` will be calculated, since these are required by
# `generator_xy` and `generator_yx`.
#
# If you want other parameters like `dataset_name`, you should specify them here.
-# If you specify nothing all the configurations will be calculated including data loaders.
+# If you specify nothing, all the configurations will be calculated, including data loaders.
# Calculation of configurations and their dependencies will happen when you call `experiment.start`
experiment.configs(conf, conf_dict)
conf.initialize()
@@ -740,12 +740,12 @@ def evaluate():
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
-# Load your own data, here we try test set.
-# I was trying with yosemite photos, they look awesome.
+# Load your own data. Here we try the test set.
+# I was trying with Yosemite photos; they look awesome.
# You can use `conf.dataset_name` if you specified `dataset_name` as something you wanted to be calculated
# in the call to `experiment.configs`.
dataset = ImageDataset(conf.dataset_name, transforms_, 'train')
-# Get an images from dataset
+# Get an image from the dataset
x_image = dataset[10]['x']
# Display the image
plot_image(x_image)
@@ -754,7 +754,7 @@ def evaluate():
conf.generator_xy.eval()
conf.generator_yx.eval()
-# We dont need gradients
+# We don't need gradients
with torch.no_grad():
# Add batch dimension and move to the device we use
data = x_image.unsqueeze(0).to(conf.device)
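The diff ends here; presumably the remaining steps just run the generator on the batch and plot the result, roughly (a hedged sketch continuing the code above, still under `torch.no_grad()`):

```python
# Translate X -> Y with the loaded generator.
generated_y = conf.generator_xy(data)
# Remove the batch dimension and display the translated image.
plot_image(generated_y[0].cpu())
```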