mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
made some changes
@@ -8,14 +8,14 @@ summary: >

 # Cycle GAN

-This is a [PyTorch](https://pytorch.org) implementation/tutorial of paper
+This is a [PyTorch](https://pytorch.org) implementation/tutorial of the paper
 [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593).

 I've taken pieces of code from [eriklindernoren/PyTorch-GAN](https://github.com/eriklindernoren/PyTorch-GAN).
 It is a very good resource if you want to checkout other GAN variations too.

 Cycle GAN does image-to-image translation.
-It trains a model to translate an image from given distribution to another, say images of class A and B
+It trains a model to translate an image from given distribution to another, say, images of class A and B.
 Images of a certain distribution could be things like images of a certain style, or nature.
 The models do not need paired images between A and B.
 Just a set of images of each class is enough.
@@ -26,7 +26,7 @@ Cycle GAN trains two generator models and two discriminator models.
 One generator translates images from A to B and the other from B to A.
 The discriminators test whether the generated images look real.

-This file contains the model code as well as training code.
+This file contains the model code as well as the training code.
 We also have a Google Colab notebook.

 [](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/gan/cycle_gan.ipynb)
@@ -62,7 +62,7 @@ class GeneratorResNet(Module):
         super().__init__()
         # This first block runs a $7\times7$ convolution and maps the image to
         # a feature map.
-        # The output feature map has same height and width because we have
+        # The output feature map has the same height and width because we have
         # a padding of $3$.
         # Reflection padding is used because it gives better image quality at edges.
         #
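For readers following this hunk, a minimal sketch of the kind of first block the comments describe; the channel sizes (3 in, 64 out) and the normalization/activation choices are assumptions, not quoted from this file.

```python
import torch.nn as nn

# Sketch (assumed channel sizes): reflection padding of 3 plus a 7x7 convolution keeps
# the height and width unchanged, and reflection padding avoids border artefacts.
first_block = nn.Sequential(
    nn.ReflectionPad2d(3),
    nn.Conv2d(3, 64, kernel_size=7),
    nn.InstanceNorm2d(64),
    nn.ReLU(inplace=True),
)
```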
@@ -145,7 +145,7 @@ class Discriminator(Module):
         super().__init__()
         channels, height, width = input_shape

-        # Output of the discriminator is also map of probabilities*
+        # Output of the discriminator is also a map of probabilities*
         # whether each region of the image is real or generated
         self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
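A quick sanity check of the output shape in this hunk (the 256x256 input size is a hypothetical example): four stride-2 blocks each halve the spatial size, so the discriminator emits a PatchGAN-style map of per-region real/fake probabilities.

```python
# Hypothetical input size, only to illustrate the `height // 2 ** 4` arithmetic above.
input_shape = (3, 256, 256)
channels, height, width = input_shape
output_shape = (1, height // 2 ** 4, width // 2 ** 4)
print(output_shape)  # (1, 16, 16): one score per region of the 16x16 output map
```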
@@ -171,7 +171,7 @@ class Discriminator(Module):
 class DiscriminatorBlock(Module):
     """
     This is the discriminator block module.
-    It does a convolution, an optional normalization, and a leaky relu.
+    It does a convolution, an optional normalization, and a leaky ReLU.

     It shrinks the height and width of the input feature map by half.
     """
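The docstring above maps naturally onto a small module like the following sketch; the 4x4 kernel, stride 2, and InstanceNorm are common choices in CycleGAN-style discriminators and are assumptions here, not quoted from this file.

```python
import torch.nn as nn

class DiscriminatorBlock(nn.Module):
    """Sketch: a stride-2 convolution (halves height and width), optional normalization, leaky ReLU."""
    def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
        super().__init__()
        layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_filters))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
```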
@@ -199,7 +199,7 @@ def weights_init_normal(m):

 def load_image(path: str):
     """
-    Loads an image and change to RGB if in grey-scale.
+    Load an image and change to RGB if in grey-scale.
     """
     image = Image.open(path)
     if image.mode != 'RGB':
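A self-contained sketch of what a helper like this needs to do; only the first two lines of the body appear in the hunk, so the conversion call is an assumption.

```python
from PIL import Image

def load_image(path: str) -> Image.Image:
    # Open the file and make sure the result always has three channels.
    image = Image.open(path)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    return image
```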
@@ -257,7 +257,7 @@ class ImageDataset(Dataset):

     def __getitem__(self, index):
         # Return a pair of images.
-        # These pairs get batched together, and they do not act like pairs in training
+        # These pairs get batched together, and they do not act like pairs in training.
         # So it is kind of ok that we always keep giving the same pair.
         return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
                 "y": self.transform(load_image(self.files_b[index % len(self.files_b)]))}
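To see why these "pairs" carry no real correspondence, consider this toy illustration (file names are made up): each set is indexed independently with a modulo, so x and y simply wrap around their own lists.

```python
# Toy illustration of the unpaired indexing used above.
files_a = ['a0.jpg', 'a1.jpg', 'a2.jpg']
files_b = ['b0.jpg', 'b1.jpg', 'b2.jpg', 'b3.jpg', 'b4.jpg']
for index in range(6):
    print(files_a[index % len(files_a)], files_b[index % len(files_b)])
```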
@@ -275,8 +275,8 @@ class ReplayBuffer:
     Generated images are added to the replay buffer and sampled from it.

     The replay buffer returns the newly added image with a probability of $0.5$.
-    Otherwise it sends an older generated image and and replaces the older image
-    with the new generated image.
+    Otherwise, it sends an older generated image and replaces the older image
+    with the newly generated image.

     This is done to reduce model oscillation.
     """
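A minimal sketch of the behaviour this docstring describes, simplified to one image at a time; the buffer size of 50 is an assumption, not taken from this commit.

```python
import random
import torch

class ReplayBuffer:
    """Sketch: keep past generated images; return the new one with probability 0.5,
    otherwise return (and replace) a stored older one."""
    def __init__(self, max_size: int = 50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, image: torch.Tensor) -> torch.Tensor:
        if len(self.data) < self.max_size:
            # Buffer not full yet: store the image and return it as-is.
            self.data.append(image)
            return image
        if random.random() < 0.5:
            # Return the newly generated image.
            return image
        # Otherwise swap the new image in and return an older one.
        i = random.randrange(self.max_size)
        older, self.data[i] = self.data[i], image
        return older
```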
@@ -404,7 +404,7 @@ class Configs(BaseConfigs):

         # Create the learning rate schedules.
         # The learning rate stars flat until `decay_start` epochs,
-        # and then linearly reduces to $0$ at end of training.
+        # and then linearly reduce to $0$ at end of training.
         decay_epochs = self.epochs - self.decay_start
         self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
             self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
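The `lr_lambda` in this hunk gives a multiplier that stays at 1 until `decay_start` and then falls linearly to 0 at the final epoch; a quick check with hypothetical numbers:

```python
# Hypothetical values, only to illustrate the shape of the schedule above.
epochs, decay_start = 200, 100
decay_epochs = epochs - decay_start
lr_lambda = lambda e: 1.0 - max(0, e - decay_start) / decay_epochs
print([lr_lambda(e) for e in (0, 50, 100, 150, 200)])  # [1.0, 1.0, 1.0, 0.5, 0.0]
```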
@@ -478,10 +478,10 @@ class Configs(BaseConfigs):
         Basically if the two generators (transformations) are applied in series it should give back the
         original image.
         This is the main contribution of this paper.
-        It train the generators to generate an image of the other distribution that is similar to
+        It trains the generators to generate an image of the other distribution that is similar to
         the original image.
         Without this loss $G(x)$ could generate anything that's from the distribution of $Y$.
-        Now it needs to generate something from the distribution of $Y$ but still have properties of $x$,
+        Now it needs to generate something from the distribution of $Y$ but still has properties of $x$,
         so that $F(G(x)$ can re-generate something like $x$.

         $\mathcal{L}_{cyc}$ is the identity loss.
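In code, the cycle-consistency idea described above usually reduces to an L1 penalty on the round trip; a hedged sketch (the generator and tensor names are assumptions following the rest of the file):

```python
import torch.nn as nn

cycle_loss = nn.L1Loss()

def cycle_consistency(generator_xy, generator_yx, x, y):
    # F(G(x)) should reproduce x, and G(F(y)) should reproduce y.
    loss_x = cycle_loss(generator_yx(generator_xy(x)), x)
    loss_y = cycle_loss(generator_xy(generator_yx(y)), y)
    return loss_x + loss_y
```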
@@ -503,7 +503,7 @@ class Configs(BaseConfigs):

         In order to stabilize the training the negative log- likelihood objective
         was replaced by a least-squared loss -
-        the least-squared error of discriminator labelling real images with 1,
+        the least-squared error of discriminator, labelling real images with 1,
         and generated images with 0.
         So we want to descend on the gradient,
         \begin{align}
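The least-squares objective mentioned here can be written with a plain MSE loss, pushing discriminator outputs towards 1 on real images and 0 on generated ones; a sketch under those assumptions, not the exact code in this file:

```python
import torch
import torch.nn as nn

gan_loss = nn.MSELoss()

def discriminator_loss(discriminator, real, generated):
    # Least-squares GAN: label real images 1 and generated images 0.
    pred_real = discriminator(real)
    pred_fake = discriminator(generated.detach())
    return (gan_loss(pred_real, torch.ones_like(pred_real)) +
            gan_loss(pred_fake, torch.zeros_like(pred_fake)))
```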
@@ -681,7 +681,7 @@ def train():

 def plot_image(img: torch.Tensor):
     """
-    ### Plots an image with matplotlib
+    ### Plot an image with matplotlib
     """
     from matplotlib import pyplot as plt

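The body of `plot_image` is not part of this hunk; a plausible sketch, assuming the (0.5, 0.5, 0.5) normalization that appears later in this commit:

```python
import torch
from matplotlib import pyplot as plt

def plot_image(img: torch.Tensor):
    # Move to CPU, undo the assumed [-1, 1] normalization, and show in HWC order.
    img = img.detach().cpu() * 0.5 + 0.5
    plt.imshow(img.permute(1, 2, 0).clamp(0, 1).numpy())
    plt.axis('off')
    plt.show()
```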
@@ -705,7 +705,7 @@ def evaluate():
     """
     ## Evaluate trained Cycle GAN
     """
-    # Set the run uuid from the training run
+    # Set the run UUID from the training run
     trained_run_uuid = 'f73c1164184711eb9190b74249275441'
     # Create configs object
     conf = Configs()
@@ -715,11 +715,11 @@ def evaluate():
     conf_dict = experiment.load_configs(trained_run_uuid)
     # Calculate configurations. We specify the generators `'generator_xy', 'generator_yx'`
     # so that it only loads those and their dependencies.
-    # Configs like `device` and `img_channels` will be calculated since these are required by
+    # Configs like `device` and `img_channels` will be calculated, since these are required by
     # `generator_xy` and `generator_yx`.
     #
     # If you want other parameters like `dataset_name` you should specify them here.
-    # If you specify nothing all the configurations will be calculated including data loaders.
+    # If you specify nothing, all the configurations will be calculated, including data loaders.
     # Calculation of configurations and their dependencies will happen when you call `experiment.start`
     experiment.configs(conf, conf_dict)
     conf.initialize()
@@ -740,12 +740,12 @@ def evaluate():
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
     ]

-    # Load your own data, here we try test set.
-    # I was trying with yosemite photos, they look awesome.
+    # Load your own data. Here we try the test set.
+    # I was trying with Yosemite photos, they look awesome.
     # You can use `conf.dataset_name`, if you specified `dataset_name` as something you wanted to be calculated
     # in the call to `experiment.configs`
     dataset = ImageDataset(conf.dataset_name, transforms_, 'train')
-    # Get an images from dataset
+    # Get an image from dataset
     x_image = dataset[10]['x']
     # Display the image
     plot_image(x_image)
@@ -754,7 +754,7 @@ def evaluate():
     conf.generator_xy.eval()
     conf.generator_yx.eval()

-    # We dont need gradients
+    # We don't need gradients
     with torch.no_grad():
         # Add batch dimension and move to the device we use
         data = x_image.unsqueeze(0).to(conf.device)
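The hunk ends just before the translation itself; a hypothetical continuation, reusing the names from the context lines above:

```python
import torch

# Hypothetical continuation of the evaluation snippet: translate x -> y and plot the result.
with torch.no_grad():
    data = x_image.unsqueeze(0).to(conf.device)
    generated_y = conf.generator_xy(data)
# Drop the batch dimension before plotting.
plot_image(generated_y[0].cpu())
```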