Cycle GAN

This is a PyTorch implementation/tutorial of the paper Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.

I've taken pieces of code from eriklindernoren/PyTorch-GAN. It is a very good resource if you want to check out other GAN variations too.

Cycle GAN does image-to-image translation. It trains models to translate an image from a given distribution to another; say, from images of class A to images of class B. A distribution here could be images of a certain style or of a certain kind of scene. The models do not need paired images between A and B; a set of images from each class is enough. This works very well for changing image styles, lighting, patterns, etc. For example, changing summer to winter, painting styles to photos, and horses to zebras.

Cycle GAN trains two generator models and two discriminator models. One generator translates images from A to B and the other from B to A. The discriminators test whether the generated images look real.

This file contains the model code as well as the training code. We also have a Google Colab notebook.

import itertools
import random
import zipfile
from typing import Tuple

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import InterpolationMode
from torchvision.utils import make_grid

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml.utils.download import download_file
from labml.utils.pytorch import get_modules
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module

The generator is a residual network.

class GeneratorResNet(Module):
    def __init__(self, input_channels: int, n_residual_blocks: int):
        super().__init__()

This first block runs a convolution and maps the image to a feature map. The output feature map has the same height and width because, with a kernel size of 7, we use a padding of 3. Reflection padding is used because it gives better image quality at edges.

inplace=True in ReLU saves a little bit of memory.

        out_features = 64
        layers = [
            nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features
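As a quick sanity check, here is a minimal sketch (assuming a hypothetical 256×256 RGB input) showing that a 7×7 convolution with reflection padding of 3 preserves the spatial dimensions:

```python
import torch
import torch.nn as nn

# 7x7 kernel with a padding of 3: output height and width match the input
conv = nn.Conv2d(3, 64, kernel_size=7, padding=3, padding_mode='reflect')
x = torch.randn(1, 3, 256, 256)  # hypothetical batch of one 256x256 RGB image
assert conv(x).shape == (1, 64, 256, 256)
```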

We down-sample with two convolutions with a stride of 2.

        for _ in range(2):
            out_features *= 2
            layers += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

We take this through n_residual_blocks residual blocks. The residual block module is defined below.

        for _ in range(n_residual_blocks):
            layers += [ResidualBlock(out_features)]

Then the resulting feature map is up-sampled to match the original image height and width.

        for _ in range(2):
            out_features //= 2
            layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

Finally we map the feature map to an RGB image. The tanh activation produces values in the range [-1, 1], matching the normalized input images.

        layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]

Create a sequential module with the layers

        self.layers = nn.Sequential(*layers)

Initialize weights to $\mathcal{N}(0, 0.02)$

        self.apply(weights_init_normal)

    def forward(self, x):
        return self.layers(x)
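A minimal usage sketch (assuming the imports above and a hypothetical 256×256 input): the down-sampling and up-sampling stages cancel out, so the translated image has the same shape as the input:

```python
gen = GeneratorResNet(input_channels=3, n_residual_blocks=9)
x = torch.randn(1, 3, 256, 256)  # hypothetical input batch
assert gen(x).shape == x.shape   # output is an image of the same size
```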

This is the residual block, with two convolution layers.

class ResidualBlock(Module):
    def __init__(self, in_features: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor):
        return x + self.block(x)

This is the discriminator.

class Discriminator(Module):
    def __init__(self, input_shape: Tuple[int, int, int]):
        super().__init__()
        channels, height, width = input_shape

The output of the discriminator is a map of probabilities, indicating whether each region of the image is real or generated.

        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        self.layers = nn.Sequential(

Each of these blocks will shrink the height and width by a factor of 2

            DiscriminatorBlock(channels, 64, normalize=False),
            DiscriminatorBlock(64, 128),
            DiscriminatorBlock(128, 256),
            DiscriminatorBlock(256, 512),

Zero-pad the top and left so that the output height and width stay the same with the 4×4 kernel: a 16×16 map becomes 17×17 after padding, and the convolution with padding of 1 gives 17 + 2 - 4 + 1 = 16.

            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, padding=1)
        )

Initialize weights to $\mathcal{N}(0, 0.02)$

        self.apply(weights_init_normal)

    def forward(self, img):
        return self.layers(img)
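A minimal sketch (assuming the Discriminator and DiscriminatorBlock classes in this file and a hypothetical 256×256 input): the four blocks halve the spatial size four times, so the discriminator outputs a 16×16 grid of patch scores rather than a single scalar:

```python
disc = Discriminator((3, 256, 256))
out = disc(torch.randn(1, 3, 256, 256))  # hypothetical input batch
assert out.shape == (1, 1, 16, 16)       # matches (batch,) + disc.output_shape
```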

This is the discriminator block module. It does a convolution, an optional normalization, and a leaky ReLU.

It shrinks the height and width of the input feature map by half.

class DiscriminatorBlock(Module):
    def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
        super().__init__()
        layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_filters))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        return self.layers(x)

Initialize convolution layer weights to $\mathcal{N}(0, 0.02)$

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)

Load an image and convert it to RGB if it is grey-scale.

def load_image(path: str):
    image = Image.open(path)
    if image.mode != 'RGB':
        # convert grey-scale (or palette) images to RGB
        image = image.convert('RGB')

    return image

Dataset to load images

class ImageDataset(Dataset):

Download dataset and extract data

    @staticmethod
    def download(dataset_name: str):

URL

        url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'

Download folder

        root = lab.get_data_path() / 'cycle_gan'
        if not root.exists():
            root.mkdir(parents=True)

Download destination

        archive = root / f'{dataset_name}.zip'

Download file (generally ~100MB)

        download_file(url, archive)

Extract the archive

        with zipfile.ZipFile(archive, 'r') as f:
            f.extractall(root)

Initialize the dataset

  • dataset_name is the name of the dataset
  • transforms_ is the set of image transforms
  • mode is either train or test
    def __init__(self, dataset_name: str, transforms_, mode: str):

Dataset path

        root = lab.get_data_path() / 'cycle_gan' / dataset_name

Download if missing

        if not root.exists():
            self.download(dataset_name)

Image transforms

        self.transform = transforms.Compose(transforms_)

Get image paths

        path_a = root / f'{mode}A'
        path_b = root / f'{mode}B'
        self.files_a = sorted(str(f) for f in path_a.iterdir())
        self.files_b = sorted(str(f) for f in path_b.iterdir())

    def __getitem__(self, index):

Return a pair of images. These pairs get batched together, but they are not treated as pairs during training, so it does not matter that the same images keep getting paired up.

        return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
                "y": self.transform(load_image(self.files_b[index % len(self.files_b)]))}

    def __len__(self):

The dataset length is the size of the larger of the two image sets.

        return max(len(self.files_a), len(self.files_b))
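The two sets can have different sizes; __getitem__ wraps the smaller one around with modulo indexing. A minimal sketch of the idea with hypothetical set sizes:

```python
files_a, files_b = ['a0', 'a1', 'a2'], ['b0', 'b1', 'b2', 'b3', 'b4']
length = max(len(files_a), len(files_b))  # dataset length is 5
pairs = [(files_a[i % len(files_a)], files_b[i % len(files_b)]) for i in range(length)]
# pairs == [('a0', 'b0'), ('a1', 'b1'), ('a2', 'b2'), ('a0', 'b3'), ('a1', 'b4')]
```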

Replay Buffer

Replay buffer is used to train the discriminator. Generated images are added to the replay buffer and sampled from it.

The replay buffer returns the newly added image with a probability of 0.5. Otherwise, it sends an older generated image and replaces the older image with the newly generated image.

This is done to reduce model oscillation.

class ReplayBuffer:
    def __init__(self, max_size: int = 50):
        self.max_size = max_size
        self.data = []

Add/retrieve an image

    def push_and_pop(self, data: torch.Tensor):
        data = data.detach()
        res = []
        for element in data:
            if len(self.data) < self.max_size:
                self.data.append(element)
                res.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    res.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    res.append(element)
        return torch.stack(res)
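A minimal usage sketch (hypothetical tensors, assuming the imports above): the buffer returns a batch of the same shape, in which some images may be older generated samples:

```python
buffer = ReplayBuffer(max_size=50)
fake = torch.randn(4, 3, 256, 256)  # hypothetical batch of generated images
mixed = buffer.push_and_pop(fake)   # some elements may come from earlier batches
assert mixed.shape == fake.shape
```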

Configurations

class Configs(BaseConfigs):

DeviceConfigs will pick a GPU if available

    device: torch.device = DeviceConfigs()

Hyper-parameters

    epochs: int = 200
    dataset_name: str = 'monet2photo'
    batch_size: int = 1

    data_loader_workers = 8

    learning_rate = 0.0002
    adam_betas = (0.5, 0.999)
    decay_start = 100

The paper suggests using a least-squares loss instead of the negative log-likelihood, as it is found to be more stable.

    gan_loss = torch.nn.MSELoss()

L1 loss is used for cycle loss and identity loss

    cycle_loss = torch.nn.L1Loss()
    identity_loss = torch.nn.L1Loss()
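A minimal sketch of how the least-squares GAN loss is applied (hypothetical tensors; the actual calls appear in optimize_generators and optimize_discriminator below). Real patches are regressed towards 1, generated patches towards 0:

```python
mse = torch.nn.MSELoss()
scores = torch.rand(1, 1, 16, 16)                  # hypothetical discriminator output
real_loss = mse(scores, torch.ones_like(scores))   # label real images with 1
fake_loss = mse(scores, torch.zeros_like(scores))  # label generated images with 0
```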

Image dimensions

    img_height = 256
    img_width = 256
    img_channels = 3

Number of residual blocks in the generator

    n_residual_blocks = 9

Loss coefficients

    cyclic_loss_coefficient = 10.0
    identity_loss_coefficient = 5.

    sample_interval = 500

Models

    generator_xy: GeneratorResNet
    generator_yx: GeneratorResNet
    discriminator_x: Discriminator
    discriminator_y: Discriminator

Optimizers

    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam

Learning rate schedules

    generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
    discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR

Data loaders

    dataloader: DataLoader
    valid_dataloader: DataLoader

Generate samples from the test set and show them

    def sample_images(self, n: int):
        batch = next(iter(self.valid_dataloader))
        self.generator_xy.eval()
        self.generator_yx.eval()
        with torch.no_grad():
            data_x, data_y = batch['x'].to(self.generator_xy.device), batch['y'].to(self.generator_yx.device)
            gen_y = self.generator_xy(data_x)
            gen_x = self.generator_yx(data_y)

Arrange images along x-axis

            data_x = make_grid(data_x, nrow=5, normalize=True)
            data_y = make_grid(data_y, nrow=5, normalize=True)
            gen_x = make_grid(gen_x, nrow=5, normalize=True)
            gen_y = make_grid(gen_y, nrow=5, normalize=True)

Arrange images along y-axis

            image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)

Show samples

        plot_image(image_grid)

Initialize models and data loaders

    def initialize(self):
        input_shape = (self.img_channels, self.img_height, self.img_width)

Create the models

        self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.discriminator_x = Discriminator(input_shape).to(self.device)
        self.discriminator_y = Discriminator(input_shape).to(self.device)

Create the optimizers

        self.generator_optimizer = torch.optim.Adam(
            itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)
        self.discriminator_optimizer = torch.optim.Adam(
            itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)

Create the learning rate schedules. The learning rate stays flat for the first decay_start epochs, and then decays linearly to 0 by the end of training.

        decay_epochs = self.epochs - self.decay_start
        self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
        self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
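A minimal sketch of the schedule's shape (assuming the default epochs = 200 and decay_start = 100):

```python
epochs, decay_start = 200, 100
decay_epochs = epochs - decay_start

def factor(e: int) -> float:
    return 1.0 - max(0, e - decay_start) / decay_epochs

assert factor(0) == 1.0    # flat at the full learning rate...
assert factor(100) == 1.0  # ...until decay_start
assert factor(150) == 0.5  # halfway through the decay
assert factor(200) == 0.0  # reaches zero at the end of training
```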

Image transformations

        transforms_ = [
            transforms.Resize(int(self.img_height * 1.12), InterpolationMode.BICUBIC),
            transforms.RandomCrop((self.img_height, self.img_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

Training data loader

        self.dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, 'train'),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )

Validation data loader

        self.valid_dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, "test"),
            batch_size=5,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )

Training

We aim to solve:

$$G^*, F^* = \arg \min_{G, F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$

where $G$ translates images from $X$ to $Y$, $F$ translates images from $Y$ to $X$, $D_X$ tests if images are from the $X$ space, $D_Y$ tests if images are from the $Y$ space, and

$$\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{GAN}(G, D_Y, X, Y) + \mathcal{L}_{GAN}(F, D_X, Y, X) + \lambda_1 \mathcal{L}_{cyc}(G, F) + \lambda_2 \mathcal{L}_{identity}(G, F)$$

$\mathcal{L}_{GAN}$ is the generative adversarial loss from the original GAN paper:

$$\mathcal{L}_{GAN}(G, D_Y, X, Y) = \mathbb{E}_{y \sim p_{data}(y)} \big[\log D_Y(y)\big] + \mathbb{E}_{x \sim p_{data}(x)} \big[\log \big(1 - D_Y(G(x))\big)\big]$$

$\mathcal{L}_{cyc}$ is the cyclic loss, where we try to get $F(G(x))$ to be similar to $x$, and $G(F(y))$ to be similar to $y$:

$$\mathcal{L}_{cyc}(G, F) = \mathbb{E}_{x \sim p_{data}(x)} \Big[\big\lVert F(G(x)) - x \big\rVert_1\Big] + \mathbb{E}_{y \sim p_{data}(y)} \Big[\big\lVert G(F(y)) - y \big\rVert_1\Big]$$

Basically, if the two generators (transformations) are applied in series, they should give back the original image. This is the main contribution of the paper. It trains the generators to generate an image of the other distribution that is similar to the original image. Without this loss, $G(x)$ could generate anything from the distribution of $Y$. Now it needs to generate something from the distribution of $Y$ that still has the properties of $x$, so that $F(G(x))$ can re-generate something like $x$.

$\mathcal{L}_{identity}$ is the identity loss. This was used to encourage the mapping to preserve color composition between the input and the output:

$$\mathcal{L}_{identity}(G, F) = \mathbb{E}_{y \sim p_{data}(y)} \Big[\big\lVert G(y) - y \big\rVert_1\Big] + \mathbb{E}_{x \sim p_{data}(x)} \Big[\big\lVert F(x) - x \big\rVert_1\Big]$$

To solve the min-max optimization, discriminators $D_X$ and $D_Y$ should ascend on the gradient,

$$\nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^{m} \Big[\log D_Y\big(y^{(i)}\big) + \log \Big(1 - D_Y\big(G(x^{(i)})\big)\Big) + \log D_X\big(x^{(i)}\big) + \log \Big(1 - D_X\big(F(y^{(i)})\big)\Big)\Big]$$

That is, descend on the negative log-likelihood loss.

In order to stabilize the training, the negative log-likelihood objective was replaced by a least-squares loss: the least-squared error of the discriminator, labelling real images with 1 and generated images with 0. So we want to descend on the gradient,

$$\nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^{m} \Big[\big(D_Y(y^{(i)}) - 1\big)^2 + D_Y\big(G(x^{(i)})\big)^2 + \big(D_X(x^{(i)}) - 1\big)^2 + D_X\big(F(y^{(i)})\big)^2\Big]$$

We use the least-squares error for the generators also. The generators should descend on the gradient,

$$\nabla_{\theta_{G, F}} \frac{1}{m} \sum_{i=1}^{m} \Big[\big(D_Y(G(x^{(i)})) - 1\big)^2 + \big(D_X(F(y^{(i)})) - 1\big)^2 + \lambda_1 \Big(\big\lVert F(G(x^{(i)})) - x^{(i)} \big\rVert_1 + \big\lVert G(F(y^{(i)})) - y^{(i)} \big\rVert_1\Big) + \lambda_2 \Big(\big\lVert G(y^{(i)}) - y^{(i)} \big\rVert_1 + \big\lVert F(x^{(i)}) - x^{(i)} \big\rVert_1\Big)\Big]$$

We use generator_xy for $G$ and generator_yx for $F$. We use discriminator_x for $D_X$ and discriminator_y for $D_Y$.

    def run(self):

Replay buffers to keep generated samples

        gen_x_buffer = ReplayBuffer()
        gen_y_buffer = ReplayBuffer()

Loop through epochs

        for epoch in monit.loop(self.epochs):

Loop through the dataset

            for i, batch in monit.enum('Train', self.dataloader):

Move images to the device

                data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)

True labels equal to $1$

                true_labels = torch.ones(data_x.size(0), *self.discriminator_x.output_shape,
                                         device=self.device, requires_grad=False)

False labels equal to $0$

                false_labels = torch.zeros(data_x.size(0), *self.discriminator_x.output_shape,
                                           device=self.device, requires_grad=False)

Train the generators. This returns the generated images.

                gen_x, gen_y = self.optimize_generators(data_x, data_y, true_labels)

Train discriminators

                self.optimize_discriminator(data_x, data_y,
                                            gen_x_buffer.push_and_pop(gen_x), gen_y_buffer.push_and_pop(gen_y),
                                            true_labels, false_labels)

Save training statistics and increment the global step counter

                tracker.save()
                tracker.add_global_step(max(len(data_x), len(data_y)))

Save images at intervals

                batches_done = epoch * len(self.dataloader) + i
                if batches_done % self.sample_interval == 0:

Save models when sampling images

                    experiment.save_checkpoint()

Sample images

                    self.sample_images(batches_done)

Update learning rates

            self.generator_lr_scheduler.step()
            self.discriminator_lr_scheduler.step()

New line

            tracker.new_line()

Optimize the generators with identity, GAN, and cycle losses.

    def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):

Change to training mode

        self.generator_xy.train()
        self.generator_yx.train()

Identity loss

        loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
                         self.identity_loss(self.generator_xy(data_y), data_y))

Generate images $G(x)$ and $F(y)$

        gen_y = self.generator_xy(data_x)
        gen_x = self.generator_yx(data_y)

GAN loss

        loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
                    self.gan_loss(self.discriminator_x(gen_x), true_labels))

Cycle loss

        loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
                      self.cycle_loss(self.generator_xy(gen_x), data_y))

Total loss

        loss_generator = (loss_gan +
                          self.cyclic_loss_coefficient * loss_cycle +
                          self.identity_loss_coefficient * loss_identity)

Take a step in the optimizer

        self.generator_optimizer.zero_grad()
        loss_generator.backward()
        self.generator_optimizer.step()

Log losses

        tracker.add({'loss.generator': loss_generator,
                     'loss.generator.cycle': loss_cycle,
                     'loss.generator.gan': loss_gan,
                     'loss.generator.identity': loss_identity})

Return generated images

        return gen_x, gen_y

Optimize the discriminators with the GAN loss.

    def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
                               gen_x: torch.Tensor, gen_y: torch.Tensor,
                               true_labels: torch.Tensor, false_labels: torch.Tensor):

GAN Loss

        loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
                              self.gan_loss(self.discriminator_x(gen_x), false_labels) +
                              self.gan_loss(self.discriminator_y(data_y), true_labels) +
                              self.gan_loss(self.discriminator_y(gen_y), false_labels))

Take a step in the optimizer

        self.discriminator_optimizer.zero_grad()
        loss_discriminator.backward()
        self.discriminator_optimizer.step()

Log losses

        tracker.add({'loss.discriminator': loss_discriminator})

Train Cycle GAN

def train():

Create configurations

    conf = Configs()

Create an experiment

    experiment.create(name='cycle_gan')

Calculate configurations. It will calculate conf.run and all other configs required by it.

    experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
    conf.initialize()

Register models for saving and loading. get_modules gives a dictionary of nn.Modules in conf. You can also specify a custom dictionary of models.

    experiment.add_pytorch_models(get_modules(conf))

Start and watch the experiment

    with experiment.start():

Run the training

        conf.run()

Plot an image with matplotlib

def plot_image(img: torch.Tensor):
    from matplotlib import pyplot as plt

Move tensor to CPU

    img = img.cpu()

Get min and max values of the image for normalization

    img_min, img_max = img.min(), img.max()

Scale image values to the range [0, 1]

    img = (img - img_min) / (img_max - img_min + 1e-5)

We have to change the order of dimensions to HWC.

    img = img.permute(1, 2, 0)

Show Image

    plt.imshow(img)

We don't need axes

    plt.axis('off')

Display

    plt.show()

Evaluate trained Cycle GAN

def evaluate():

Set the run UUID from the training run

    trained_run_uuid = 'f73c1164184711eb9190b74249275441'

Create configs object

    conf = Configs()

Create experiment

    experiment.create(name='cycle_gan_inference')

Load the hyper-parameters set for training

    conf_dict = experiment.load_configs(trained_run_uuid)

Calculate configurations. We specify the generators 'generator_xy' and 'generator_yx' so that it only loads those and their dependencies. Configs like device and img_channels will be calculated, since they are required by generator_xy and generator_yx.

If you want other parameters like dataset_name, you should specify them here. If you specify nothing, all the configurations will be calculated, including the data loaders. Calculation of configurations and their dependencies will happen when you call experiment.start.

    experiment.configs(conf, conf_dict)
    conf.initialize()

Register models for saving and loading. get_modules gives a dictionary of nn.Modules in conf. You can also specify a custom dictionary of models.

    experiment.add_pytorch_models(get_modules(conf))

Specify which run to load from. Loading will actually happen when you call experiment.start

    experiment.load(trained_run_uuid)

Start the experiment

    with experiment.start():

Image transformations

        transforms_ = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

Load your own data here; we use the dataset's training split. I was trying it with Yosemite photos; they look awesome. You can use conf.dataset_name if you specified dataset_name as something to be calculated in the call to experiment.configs.

        dataset = ImageDataset(conf.dataset_name, transforms_, 'train')

Get an image from dataset

        x_image = dataset[10]['x']

Display the image

        plot_image(x_image)

Evaluation mode

        conf.generator_xy.eval()
        conf.generator_yx.eval()

We don't need gradients

        with torch.no_grad():

Add batch dimension and move to the device we use

            data = x_image.unsqueeze(0).to(conf.device)
            generated_y = conf.generator_xy(data)

Display the generated image.

        plot_image(generated_y[0].cpu())


if __name__ == '__main__':
    train()
    # evaluate()