This is a PyTorch implementation/tutorial of the paper Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.
I've taken pieces of code from eriklindernoren/PyTorch-GAN. It is a very good resource if you want to check out other GAN variations too.
Cycle GAN does image-to-image translation. It trains a model to translate an image from one distribution to another, say between images of class A and class B. Images of a distribution could share a particular style or depict a particular kind of scene. The models do not need paired images between A and B; a set of images from each class is enough. This works very well for changing image styles, lighting, patterns, etc. For example: summer to winter, painting styles to photos, and horses to zebras.
Cycle GAN trains two generator models and two discriminator models. One generator translates images from A to B and the other from B to A. The discriminators test whether the generated images look real.
This file contains the model code as well as the training code. We also have a Google Colab notebook.
import itertools
import random
import zipfile
from typing import Tuple

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import InterpolationMode
from torchvision.utils import make_grid

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml.utils.download import download_file
from labml.utils.pytorch import get_modules
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module

The generator is a residual network.
class GeneratorResNet(Module):
    def __init__(self, input_channels: int, n_residual_blocks: int):
        super().__init__()

This first block runs a convolution and maps the image to a feature map. The output feature map has the same height and width because we use a kernel size of $7$ with a padding of $3$. Reflection padding is used because it gives better image quality at edges. inplace=True in ReLU saves a little bit of memory.
        out_features = 64
        layers = [
            nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

We down-sample with two convolutions with a stride of $2$.
        for _ in range(2):
            out_features *= 2
            layers += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

We take this through n_residual_blocks residual blocks. This module is defined below.
        for _ in range(n_residual_blocks):
            layers += [ResidualBlock(out_features)]

Then the resulting feature map is up-sampled to match the original image height and width.
        for _ in range(2):
            out_features //= 2
            layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

Finally we map the feature map to an RGB image.

        layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]

Create a sequential module with the layers.
        self.layers = nn.Sequential(*layers)

Initialize weights to $\mathcal{N}(0, 0.02)$.

        self.apply(weights_init_normal)

    def forward(self, x):
        return self.layers(x)
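As a quick sanity check, here is a minimal sketch (not part of the original code) showing that the generator preserves the image shape; it assumes the GeneratorResNet and ResidualBlock classes from this file:

import torch

# Hypothetical smoke test: a 3-channel 256x256 image through the generator.
generator = GeneratorResNet(input_channels=3, n_residual_blocks=9)
x = torch.randn(1, 3, 256, 256)
y = generator(x)
# Same height and width as the input; Tanh keeps values in [-1, 1].
assert y.shape == (1, 3, 256, 256)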
This is the residual block, with two convolution layers.

class ResidualBlock(Module):
    def __init__(self, in_features: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor):
        return x + self.block(x)
This is the discriminator.

class Discriminator(Module):
    def __init__(self, input_shape: Tuple[int, int, int]):
        super().__init__()
        channels, height, width = input_shape

The output of the discriminator is a map of predictions, indicating for each region of the image whether it looks real or generated. Since each of the four blocks below halves the height and width, the output map is $\frac{1}{2^4}$ the size of the input.

        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        self.layers = nn.Sequential(

Each of these blocks will shrink the height and width by a factor of $2$.

            DiscriminatorBlock(channels, 64, normalize=False),
            DiscriminatorBlock(64, 128),
            DiscriminatorBlock(128, 256),
            DiscriminatorBlock(256, 512),

Zero-pad on the top and left so that the final $4 \times 4$ convolution (with a padding of $1$) keeps the output height and width the same.

            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, padding=1)
        )

Initialize weights to $\mathcal{N}(0, 0.02)$.

        self.apply(weights_init_normal)

    def forward(self, img):
        return self.layers(img)
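For intuition, a small sketch (an illustration, not part of the original code) of the discriminator's patch-wise output for the 256×256 images used below:

import torch

# Hypothetical smoke test: the discriminator outputs one score per region.
discriminator = Discriminator(input_shape=(3, 256, 256))
img = torch.randn(1, 3, 256, 256)
out = discriminator(img)
# 256 / 2**4 = 16, matching `output_shape`.
assert out.shape == (1, 1, 16, 16)
assert discriminator.output_shape == (1, 16, 16)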
This is the discriminator block module. It does a convolution, an optional normalization, and a leaky ReLU. It shrinks the height and width of the input feature map by half.

class DiscriminatorBlock(Module):
    def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
        super().__init__()
        layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_filters))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        return self.layers(x)
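A quick illustration of the halving (a sketch, not from the original code): a stride-2, kernel-4, padding-1 convolution maps height $h$ to $\lfloor (h + 2 - 4) / 2 \rfloor + 1 = h / 2$ for even $h$:

import torch

block = DiscriminatorBlock(3, 64)
x = torch.randn(1, 3, 32, 32)
assert block(x).shape == (1, 64, 16, 16)  # height and width halved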
Initialize convolution layer weights to $\mathcal{N}(0, 0.02)$.

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
Load an image and convert to RGB if in grey-scale.

def load_image(path: str):
    image = Image.open(path)
    if image.mode != 'RGB':
        # Convert to RGB. Note that `Image.paste` returns `None`,
        # so assigning its result would lose the image.
        image = image.convert('RGB')

    return image

class ImageDataset(Dataset):
    @staticmethod
    def download(dataset_name: str):

URL

        url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'

Download folder

        root = lab.get_data_path() / 'cycle_gan'
        if not root.exists():
            root.mkdir(parents=True)

Download destination

        archive = root / f'{dataset_name}.zip'

Download the file (generally ~100MB)

        download_file(url, archive)

Extract the archive

        with zipfile.ZipFile(archive, 'r') as f:
            f.extractall(root)
dataset_name is the name of the dataset, transforms_ is the set of image transforms, and mode is either train or test.

    def __init__(self, dataset_name: str, transforms_, mode: str):

Dataset path

        root = lab.get_data_path() / 'cycle_gan' / dataset_name

Download if missing

        if not root.exists():
            self.download(dataset_name)

Image transforms

        self.transform = transforms.Compose(transforms_)

Get image paths

        path_a = root / f'{mode}A'
        path_b = root / f'{mode}B'
        self.files_a = sorted(str(f) for f in path_a.iterdir())
        self.files_b = sorted(str(f) for f in path_b.iterdir())

    def __getitem__(self, index):

Return a pair of images. These pairs get batched together, but they do not act as pairs in training, so it is okay that the same two images always get paired up.
        return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
                "y": self.transform(load_image(self.files_b[index % len(self.files_b)]))}

    def __len__(self):

Number of images in the dataset: the larger of the two sets, since the modulo indexing above wraps the smaller one around.

        return max(len(self.files_a), len(self.files_b))
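Here is a toy illustration (hypothetical file names, not part of the original code) of how the modulo indexing wraps the smaller set:

files_a = ['a0.jpg', 'a1.jpg', 'a2.jpg']  # three images of class A
files_b = ['b0.jpg', 'b1.jpg']            # two images of class B

# __len__ is max(3, 2) = 3; the shorter list wraps around.
for index in range(3):
    print(files_a[index % 3], files_b[index % 2])
# a0.jpg b0.jpg
# a1.jpg b1.jpg
# a2.jpg b0.jpg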
Replay buffer is used to train the discriminator. Generated images are added to the replay buffer and sampled from it.

The replay buffer returns the newly added image with a probability of $0.5$. Otherwise, it returns an older generated image and replaces that older image with the newly generated one.

This is done to reduce model oscillation.

class ReplayBuffer:
    def __init__(self, max_size: int = 50):
        self.max_size = max_size
        self.data = []

Add/retrieve an image

    def push_and_pop(self, data: torch.Tensor):
        data = data.detach()
        res = []
        for element in data:
            if len(self.data) < self.max_size:
                self.data.append(element)
                res.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    res.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    res.append(element)
        return torch.stack(res)
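A quick demonstration of the buffer's behavior (a sketch with made-up tensors, not part of the original code):

import torch

buffer = ReplayBuffer(max_size=2)

# While the buffer is filling up, images pass straight through.
out = buffer.push_and_pop(torch.zeros(1, 3, 4, 4))
out = buffer.push_and_pop(torch.ones(1, 3, 4, 4))

# Once full, a new image is either returned as-is (probability 0.5),
# or swapped with a stored older image, which is returned instead.
out = buffer.push_and_pop(torch.full((1, 3, 4, 4), 2.))
print(out.unique())  # tensor([2.]) or tensor([0.]) / tensor([1.])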
class Configs(BaseConfigs):

DeviceConfigs will pick a GPU if available.
    device: torch.device = DeviceConfigs()

Hyper-parameters

    epochs: int = 200
    dataset_name: str = 'monet2photo'
    batch_size: int = 1

    data_loader_workers = 8

    learning_rate = 0.0002
    adam_betas = (0.5, 0.999)
    decay_start = 100

The paper suggests using a least-squares loss instead of negative log-likelihood, as it is found to be more stable.

    gan_loss = torch.nn.MSELoss()

L1 loss is used for cycle loss and identity loss.

    cycle_loss = torch.nn.L1Loss()
    identity_loss = torch.nn.L1Loss()

Image dimensions

    img_height = 256
    img_width = 256
    img_channels = 3

Number of residual blocks in the generator

    n_residual_blocks = 9

Loss coefficients

    cyclic_loss_coefficient = 10.0
    identity_loss_coefficient = 5.

    sample_interval = 500

Models
    generator_xy: GeneratorResNet
    generator_yx: GeneratorResNet
    discriminator_x: Discriminator
    discriminator_y: Discriminator

Optimizers

    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam

Learning rate schedules

    generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
    discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR

Data loaders

    dataloader: DataLoader
    valid_dataloader: DataLoader

Generate samples from the test set and save them.
    def sample_images(self, n: int):
        batch = next(iter(self.valid_dataloader))
        self.generator_xy.eval()
        self.generator_yx.eval()
        with torch.no_grad():
            data_x, data_y = batch['x'].to(self.generator_xy.device), batch['y'].to(self.generator_yx.device)
            gen_y = self.generator_xy(data_x)
            gen_x = self.generator_yx(data_y)

Arrange images along the x-axis

            data_x = make_grid(data_x, nrow=5, normalize=True)
            data_y = make_grid(data_y, nrow=5, normalize=True)
            gen_x = make_grid(gen_x, nrow=5, normalize=True)
            gen_y = make_grid(gen_y, nrow=5, normalize=True)

Arrange images along the y-axis

        image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)

Show samples

        plot_image(image_grid)

    def initialize(self):
        input_shape = (self.img_channels, self.img_height, self.img_width)

Create the models
        self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.discriminator_x = Discriminator(input_shape).to(self.device)
        self.discriminator_y = Discriminator(input_shape).to(self.device)

Create the optimizers

        self.generator_optimizer = torch.optim.Adam(
            itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)
        self.discriminator_optimizer = torch.optim.Adam(
            itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)

Create the learning rate schedules. The learning rate stays flat until decay_start epochs, and then decays linearly to $0$ at the end of training.

        decay_epochs = self.epochs - self.decay_start
        self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
        self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
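For example (a quick sketch, not in the original code), with the defaults epochs = 200 and decay_start = 100, the multiplier applied to the learning rate behaves like this:

epochs, decay_start = 200, 100
decay_epochs = epochs - decay_start
lr_lambda = lambda e: 1.0 - max(0, e - decay_start) / decay_epochs

print(lr_lambda(0), lr_lambda(100), lr_lambda(150), lr_lambda(199))
# 1.0 1.0 0.5 0.01  -> e.g. 0.0002 * 0.5 = 0.0001 at epoch 150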
Image transformations

        transforms_ = [
            transforms.Resize(int(self.img_height * 1.12), InterpolationMode.BICUBIC),
            transforms.RandomCrop((self.img_height, self.img_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

Training data loader
        self.dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, 'train'),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )

Validation data loader

        self.valid_dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, "test"),
            batch_size=5,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )

We aim to solve

$$G^*, F^* = \arg \min_{G, F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$

where $G$ translates images from $X \rightarrow Y$, $F$ translates images from $Y \rightarrow X$, $D_X$ tests if images are from $X$ space, $D_Y$ tests if images are from $Y$ space, and

$$\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{GAN}(G, D_Y, X, Y) + \mathcal{L}_{GAN}(F, D_X, Y, X) + \lambda_{cyc} \mathcal{L}_{cyc}(G, F) + \lambda_{id} \mathcal{L}_{id}(G, F)$$

$\mathcal{L}_{GAN}$ is the generative adversarial loss from the original GAN paper:

$$\mathcal{L}_{GAN}(G, D_Y, X, Y) = \mathbb{E}_{y \sim p_{data}(y)} \big[\log D_Y(y)\big] + \mathbb{E}_{x \sim p_{data}(x)} \big[\log \big(1 - D_Y(G(x))\big)\big]$$

$\mathcal{L}_{cyc}$ is the cyclic loss, where we try to get $F(G(x))$ to be similar to $x$, and $G(F(y))$ to be similar to $y$. Basically, if the two generators (transformations) are applied in series, we should get back the original image. This is the main contribution of the paper. It trains the generators to generate an image of the other distribution that is similar to the original image. Without this loss, $G(x)$ could generate anything from the distribution of $Y$. Now it needs to generate something from the distribution of $Y$ that still has the properties of $x$, so that $F(G(x))$ can re-generate something like $x$.

$$\mathcal{L}_{cyc}(G, F) = \mathbb{E}_{x \sim p_{data}(x)} \Big[\big\lVert F(G(x)) - x \big\rVert_1\Big] + \mathbb{E}_{y \sim p_{data}(y)} \Big[\big\lVert G(F(y)) - y \big\rVert_1\Big]$$

$\mathcal{L}_{id}$ is the identity loss. It was used to encourage the mapping to preserve color composition between the input and the output:

$$\mathcal{L}_{id}(G, F) = \mathbb{E}_{y \sim p_{data}(y)} \Big[\big\lVert G(y) - y \big\rVert_1\Big] + \mathbb{E}_{x \sim p_{data}(x)} \Big[\big\lVert F(x) - x \big\rVert_1\Big]$$

To solve for $G^*, F^*$, the discriminators $D_X$ and $D_Y$ should ascend on the gradient of the GAN objective; that is, descend on the negative log-likelihood loss. In order to stabilize the training, the negative log-likelihood objective was replaced by a least-squares loss: the squared error of the discriminator, labelling real images with $1$ and generated images with $0$. So the discriminators should descend on the gradient of

$$\mathbb{E}_{x \sim p_{data}(x)} \big[\big(D_X(x) - 1\big)^2\big] + \mathbb{E}_{y \sim p_{data}(y)} \big[D_X(F(y))^2\big] + \mathbb{E}_{y \sim p_{data}(y)} \big[\big(D_Y(y) - 1\big)^2\big] + \mathbb{E}_{x \sim p_{data}(x)} \big[D_Y(G(x))^2\big]$$

We use least-squares for the generators also. The generators should descend on the gradient of

$$\mathbb{E}_{x \sim p_{data}(x)} \big[\big(D_Y(G(x)) - 1\big)^2\big] + \mathbb{E}_{y \sim p_{data}(y)} \big[\big(D_X(F(y)) - 1\big)^2\big] + \lambda_{cyc} \mathcal{L}_{cyc}(G, F) + \lambda_{id} \mathcal{L}_{id}(G, F)$$

We use generator_xy for $G$, generator_yx for $F$, discriminator_x for $D_X$, and discriminator_y for $D_Y$.
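To make the least-squares objective concrete, here is a small sketch with made-up discriminator scores (the training code below does the same thing with true_labels and false_labels):

import torch

mse = torch.nn.MSELoss()

# Hypothetical 16x16 patch scores from a discriminator.
real_scores = torch.rand(1, 1, 16, 16)  # D(x) on real images
fake_scores = torch.rand(1, 1, 16, 16)  # D(G(x)) on generated images

ones = torch.ones_like(real_scores)
zeros = torch.zeros_like(fake_scores)

# Discriminator: push scores on real images to 1, on generated images to 0.
disc_loss = mse(real_scores, ones) + mse(fake_scores, zeros)

# Generators: push the discriminator's scores on generated images to 1.
gen_loss = mse(fake_scores, ones)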
    def run(self):

Replay buffers to keep generated samples

        gen_x_buffer = ReplayBuffer()
        gen_y_buffer = ReplayBuffer()

Loop through epochs

        for epoch in monit.loop(self.epochs):

Loop through the dataset

            for i, batch in monit.enum('Train', self.dataloader):

Move images to the device

                data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)

True labels equal to $1$

                true_labels = torch.ones(data_x.size(0), *self.discriminator_x.output_shape,
                                         device=self.device, requires_grad=False)

False labels equal to $0$

                false_labels = torch.zeros(data_x.size(0), *self.discriminator_x.output_shape,
                                           device=self.device, requires_grad=False)

Train the generators. This returns the generated images.

                gen_x, gen_y = self.optimize_generators(data_x, data_y, true_labels)

Train the discriminators

                self.optimize_discriminator(data_x, data_y,
                                            gen_x_buffer.push_and_pop(gen_x), gen_y_buffer.push_and_pop(gen_y),
                                            true_labels, false_labels)

Save training statistics and increment the global step counter

                tracker.save()
                tracker.add_global_step(max(len(data_x), len(data_y)))

Save images at intervals

                batches_done = epoch * len(self.dataloader) + i
                if batches_done % self.sample_interval == 0:

Save models when sampling images

                    experiment.save_checkpoint()

Sample images

                    self.sample_images(batches_done)

Update learning rates

            self.generator_lr_scheduler.step()
            self.discriminator_lr_scheduler.step()

New line

            tracker.new_line()

    def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):

Change to training mode
        self.generator_xy.train()
        self.generator_yx.train()

Identity loss

$$\big\lVert F(x) - x \big\rVert_1 + \big\lVert G(y) - y \big\rVert_1$$

        loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
                         self.identity_loss(self.generator_xy(data_y), data_y))

Generate images $G(x)$ and $F(y)$

        gen_y = self.generator_xy(data_x)
        gen_x = self.generator_yx(data_y)

GAN loss

$$\big(D_Y(G(x)) - 1\big)^2 + \big(D_X(F(y)) - 1\big)^2$$

        loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
                    self.gan_loss(self.discriminator_x(gen_x), true_labels))

Cycle loss

$$\big\lVert F(G(x)) - x \big\rVert_1 + \big\lVert G(F(y)) - y \big\rVert_1$$

        loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
                      self.cycle_loss(self.generator_xy(gen_x), data_y))

Total loss

        loss_generator = (loss_gan +
                          self.cyclic_loss_coefficient * loss_cycle +
                          self.identity_loss_coefficient * loss_identity)

Take a step in the optimizer

        self.generator_optimizer.zero_grad()
        loss_generator.backward()
        self.generator_optimizer.step()

Log losses

        tracker.add({'loss.generator': loss_generator,
                     'loss.generator.cycle': loss_cycle,
                     'loss.generator.gan': loss_gan,
                     'loss.generator.identity': loss_identity})

Return generated images

        return gen_x, gen_y

    def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
                               gen_x: torch.Tensor, gen_y: torch.Tensor,
                               true_labels: torch.Tensor, false_labels: torch.Tensor):

The discriminators should descend on the gradient of

$$\big(D_X(x) - 1\big)^2 + D_X(F(y))^2 + \big(D_Y(y) - 1\big)^2 + D_Y(G(x))^2$$

        loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
                              self.gan_loss(self.discriminator_x(gen_x), false_labels) +
                              self.gan_loss(self.discriminator_y(data_y), true_labels) +
                              self.gan_loss(self.discriminator_y(gen_y), false_labels))

Take a step in the optimizer

        self.discriminator_optimizer.zero_grad()
        loss_discriminator.backward()
        self.discriminator_optimizer.step()

Log losses

        tracker.add({'loss.discriminator': loss_discriminator})

def train():

Create configurations
    conf = Configs()

Create an experiment

    experiment.create(name='cycle_gan')

Calculate configurations. It will calculate conf.run and all other configs required by it.

    experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
    conf.initialize()

Register models for saving and loading. get_modules gives a dictionary of nn.Modules in conf. You can also specify a custom dictionary of models.

    experiment.add_pytorch_models(get_modules(conf))

Start and watch the experiment

    with experiment.start():

Run the training

        conf.run()

Plot an image with matplotlib.

def plot_image(img: torch.Tensor):
    from matplotlib import pyplot as plt

Move the tensor to the CPU
    img = img.cpu()

Get the min and max values of the image for normalization

    img_min, img_max = img.min(), img.max()

Scale the image values to $[0, 1]$

    img = (img - img_min) / (img_max - img_min + 1e-5)  # epsilon avoids division by zero for constant images

We have to change the order of dimensions to HWC.

    img = img.permute(1, 2, 0)

Show the image

    plt.imshow(img)

We don't need axes

    plt.axis('off')

Display

    plt.show()

def evaluate():

Set the run UUID from the training run
    trained_run_uuid = 'f73c1164184711eb9190b74249275441'

Create a configs object

    conf = Configs()

Create an experiment

    experiment.create(name='cycle_gan_inference')

Load the hyper-parameters set for training

    conf_dict = experiment.load_configs(trained_run_uuid)

Calculate configurations. We specify the generators 'generator_xy', 'generator_yx' so that it only loads those and their dependencies. Configs like device and img_channels will be calculated, since these are required by generator_xy and generator_yx.

If you want other parameters like dataset_name, you should specify them here. If you specify nothing, all the configurations will be calculated, including the data loaders. Calculation of configurations and their dependencies will happen when you call experiment.start.

    experiment.configs(conf, conf_dict)
    conf.initialize()

Register models for saving and loading. get_modules gives a dictionary of nn.Modules in conf. You can also specify a custom dictionary of models.

    experiment.add_pytorch_models(get_modules(conf))

Specify which run to load from. Loading will actually happen when you call experiment.start.

    experiment.load(trained_run_uuid)
Start the experiment

    with experiment.start():

Image transformations

        transforms_ = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

Load your own data; here we use the training set of the dataset. I was trying it with Yosemite photos; they look awesome. You can use conf.dataset_name, if you specified dataset_name as something you wanted to be calculated in the call to experiment.configs.

        dataset = ImageDataset(conf.dataset_name, transforms_, 'train')

Get an image from the dataset

        x_image = dataset[10]['x']

Display the image

        plot_image(x_image)

Evaluation mode

        conf.generator_xy.eval()
        conf.generator_yx.eval()

We don't need gradients

        with torch.no_grad():

Add a batch dimension and move to the device we use

            data = x_image.unsqueeze(0).to(conf.device)
            generated_y = conf.generator_xy(data)

Display the generated image

        plot_image(generated_y[0].cpu())


if __name__ == '__main__':
    train()
    # evaluate()