"""
# Cycle GAN

This is an implementation of paper
[Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593).

### Running the experiment

To train the model you need to download datasets from
`https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/[DATASET NAME].zip`
and extract them into folder `labml_nn/data/cycle_gan/[DATASET NAME]`.
You will also have to set the `dataset_name` configuration to `[DATASET NAME]`.
This defaults to `monet2photo`.

I've taken pieces of code from [https://github.com/eriklindernoren/PyTorch-GAN](https://github.com/eriklindernoren/PyTorch-GAN).
It is a very good resource if you want to check out other GAN variations too.
"""

import itertools
import random
from pathlib import PurePath, Path
from typing import Tuple

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.utils import make_grid
from torchvision.utils import save_image

from labml import lab, tracker, experiment, monit, configs
from labml.configs import BaseConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module


class GeneratorResNet(Module):
    """
    The generator is a residual network.

    * `input_shape` is `(channels, height, width)` of the images
    * `n_residual_blocks` is the number of residual blocks between
      the down-sampling and up-sampling stages
    """

    def __init__(self, input_shape: Tuple[int, int, int], n_residual_blocks: int):
        super().__init__()
        # The number of channels in the input image, which is 3 for RGB images.
        channels = input_shape[0]

        # This first block runs a $7\times7$ convolution and maps the image to
        # a feature map.
        # The output feature map has the same height and width because we have
        # a padding of $3$.
        # Reflection padding is used because it gives better image quality at edges.
        #
        # `inplace=True` in `ReLU` saves a little bit of memory.
        out_features = 64
        layers = [
            nn.Conv2d(channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

        # We down-sample with two $3 \times 3$ convolutions
        # with stride of 2
        for _ in range(2):
            out_features *= 2
            layers += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # We take this through `n_residual_blocks`.
        # This module is defined below.
        for _ in range(n_residual_blocks):
            layers += [ResidualBlock(out_features)]

        # Then the resulting feature map is up-sampled
        # to match the original image height and width.
        for _ in range(2):
            out_features //= 2
            layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Finally we map the feature map back to an image with `channels` channels;
        # `Tanh` keeps the output in $[-1, 1]$, matching the input normalization.
        layers += [nn.Conv2d(out_features, channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]

        # Create a sequential module with the layers
        self.layers = nn.Sequential(*layers)

        # Initialize convolution weights from a normal distribution
        # with mean $0$ and standard deviation $0.02$
        self.apply(weights_init_normal)

    def __call__(self, x):
        return self.layers(x)


class ResidualBlock(Module):
    """
    This is the residual block, with two convolution layers.
    The number of channels and the spatial size are preserved,
    so the block output is added directly to its input.
    """

    def __init__(self, in_features: int):
        super().__init__()
        # NOTE(review): this block ends with a ReLU after the second normalization;
        # the reference CycleGAN residual block has no activation there.
        # Kept as-is so existing checkpoints/results stay comparable.
        self.block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
        )

    def __call__(self, x: torch.Tensor):
        return x + self.block(x)


class Discriminator(Module):
    """
    This is the discriminator.

    * `input_shape` is `(channels, height, width)` of the images
    """

    def __init__(self, input_shape: Tuple[int, int, int]):
        super().__init__()
        channels, height, width = input_shape

        # The output of the discriminator is a map of scores, one per region (patch)
        # of the image, indicating whether that region is real or generated.
        # The four blocks below each halve the height and width.
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        self.layers = nn.Sequential(
            # Each of these blocks will shrink the height and width by a factor of 2
            DiscriminatorBlock(channels, 64, normalize=False),
            DiscriminatorBlock(64, 128),
            DiscriminatorBlock(128, 256),
            DiscriminatorBlock(256, 512),
            # Zero pad on top and left to keep the output height and width same
            # with the $4 \times 4$ kernel
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, padding=1)
        )

        # Initialize convolution weights from a normal distribution
        # with mean $0$ and standard deviation $0.02$
        self.apply(weights_init_normal)

    def __call__(self, img: torch.Tensor):
        return self.layers(img)


class DiscriminatorBlock(Module):
    """
    This is the discriminator block module.
    It does a convolution, an optional normalization, and a leaky relu.

    It shrinks the height and width of the input feature map by half.
    """

    def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
        super().__init__()
        layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_filters))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.layers = nn.Sequential(*layers)

    def __call__(self, x: torch.Tensor):
        return self.layers(x)


def weights_init_normal(m):
    r"""
    Initialize convolution layer weights to $\mathcal{N}(0, 0.02^2)$,
    i.e. mean $0$ and standard deviation $0.02$.

    Intended to be passed to `nn.Module.apply`; every module whose class name
    contains `"Conv"` gets its weight re-initialized, other modules are left untouched.
    """
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)


def load_image(path: str):
    """
    Loads an image and converts it to RGB if it is in any other mode
    (e.g. grey-scale, palette or RGBA).
    """
    # `convert` reads the pixel data (it returns a copy even when the image is
    # already RGB), so the file handle can be closed when the `with` block exits.
    with Image.open(path) as image:
        return image.convert('RGB')


class ImageDataset(Dataset):
    """
    Dataset to load images

    * `root` is the dataset folder containing `{mode}A` and `{mode}B` sub-folders
    * `transforms_` is a list of transforms composed and applied to every image
    * `unaligned` picks a random image from domain B for each item,
      instead of pairing images by index
    * `mode` is the split name, e.g. `train` or `test`
    """

    def __init__(self, root: PurePath, transforms_, unaligned: bool, mode: str):
        root = Path(root)
        self.transform = transforms.Compose(transforms_)
        self.unaligned = unaligned
        dir_a, dir_b = root / f'{mode}A', root / f'{mode}B'
        self.files_a = sorted(str(f) for f in dir_a.iterdir())
        self.files_b = sorted(str(f) for f in dir_b.iterdir())
        # Fail early with a clear message instead of a `ZeroDivisionError`
        # (or a silently empty dataset) later in `__getitem__`
        if not self.files_a or not self.files_b:
            raise ValueError(f"Expected images in both '{dir_a}' and '{dir_b}'")

    def __getitem__(self, index):
        # Domain A is indexed in order; wrap around since the two domains
        # can have a different number of images
        file_a = self.files_a[index % len(self.files_a)]
        if self.unaligned:
            # Random pairing so that the same A and B images are not always seen together
            file_b = self.files_b[random.randrange(len(self.files_b))]
        else:
            file_b = self.files_b[index % len(self.files_b)]
        return {"x": self.transform(load_image(file_a)),
                "y": self.transform(load_image(file_b))}

    def __len__(self):
        return max(len(self.files_a), len(self.files_b))


class ReplayBuffer:
    """
    Replay buffer is used to train the discriminator.
    Generated images are added to the replay buffer and sampled from it.

    Until the buffer is full, every new image is stored and returned as-is.
    After that, the replay buffer returns the newly added image with a probability of $0.5$.
    Otherwise it returns an older generated image and replaces that older image
    with the new generated image.

    This is done to reduce model oscillation.
    """

    def __init__(self, max_size: int = 50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data: torch.Tensor):
        """Add/retrieve an image"""
        # Stored images must not keep the generator's computation graph alive
        data = data.detach()
        res = []
        for element in data:
            if len(self.data) < self.max_size:
                self.data.append(element)
                res.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    res.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    res.append(element)
        return torch.stack(res)


class Configs(BaseConfigs):
    """## Configurations"""

    # `DeviceConfigs` will pick a GPU if available
    device: torch.device = DeviceConfigs()

    # Hyper-parameters
    epochs: int = 200
    dataset_name: str = 'monet2photo'
    batch_size: int = 1

    data_loader_workers = 8

    learning_rate = 0.0002
    adam_betas = (0.5, 0.999)
    # Epoch after which the learning rate starts decaying linearly
    decay_start = 100

    # The paper suggests using a least-squares loss instead of
    # negative log-likelihood, as it is found to be more stable.
    gan_loss = torch.nn.MSELoss()

    # L1 loss is used for cycle loss and identity loss
    cycle_loss = torch.nn.L1Loss()
    identity_loss = torch.nn.L1Loss()

    # Image dimensions
    img_height = 256
    img_width = 256
    img_channels = 3

    # Number of residual blocks in the generator
    n_residual_blocks = 9

    # Loss coefficients
    cyclic_loss_coefficient = 10.0
    identity_loss_coefficient = 5.

    # Save sample images every `sample_interval` batches
    sample_interval = 500

    # Models
    generator_xy: GeneratorResNet
    generator_yx: GeneratorResNet
    discriminator_x: Discriminator
    discriminator_y: Discriminator

    # Optimizers
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam

    # Learning rate schedules
    generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
    discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR

    # Data loaders
    dataloader: DataLoader
    valid_dataloader: DataLoader

    def sample_images(self, n: int):
        """
        Generate samples from the test set and save them to
        `images/{dataset_name}/{n}.png`.

        The saved grid has four rows: $x$, $G(x)$, $y$, $F(y)$.
        """
        batch = next(iter(self.valid_dataloader))
        self.generator_xy.eval()
        self.generator_yx.eval()
        with torch.no_grad():
            data_x, data_y = batch['x'].to(self.generator_xy.device), batch['y'].to(self.generator_yx.device)
            gen_y = self.generator_xy(data_x)
            gen_x = self.generator_yx(data_y)

            # Arrange images along x-axis
            data_x = make_grid(data_x, nrow=5, normalize=True)
            data_y = make_grid(data_y, nrow=5, normalize=True)
            gen_x = make_grid(gen_x, nrow=5, normalize=True)
            gen_y = make_grid(gen_y, nrow=5, normalize=True)

            # Arrange images along y-axis
            image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)

        # `save_image` does not create missing folders, so make sure the target exists
        Path(f"images/{self.dataset_name}").mkdir(parents=True, exist_ok=True)
        # Save grid
        save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)

    def run(self):
        r"""
        ## Training

        We aim to solve:
        $$G^{*}, F^{*} = \arg \min_{G,F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$

        where,

        \begin{align}
        \mathcal{L}(G, F, D_X, D_Y)
        &= \mathcal{L}_{GAN}(G, D_Y, X, Y) \\
        &+ \mathcal{L}_{GAN}(F, D_X, Y, X) \\
        &+ \lambda_1 \mathcal{L}_{cyc}(G, F) \\
        &+ \lambda_2 \mathcal{L}_{identity}(G, F) \\
        \\
        \mathcal{L}_{GAN}(G, D_Y, X, Y)
        &= \mathbb{E}_{y \sim p_{data}(y)} \Big[\log D_Y(y)\Big] \\
        &+ \mathbb{E}_{x \sim p_{data}(x)} \bigg[\log\Big(1 - D_Y(G(x))\Big)\bigg] \\
        \\
        \mathcal{L}_{GAN}(F, D_X, Y, X)
        &= \mathbb{E}_{x \sim p_{data}(x)} \Big[\log D_X(x)\Big] \\
        &+ \mathbb{E}_{y \sim p_{data}(y)} \bigg[\log\Big(1 - D_X(F(y))\Big)\bigg] \\
        \\
        \mathcal{L}_{cyc}(G, F)
        &= \mathbb{E}_{x \sim p_{data}(x)} \Big[\lVert F(G(x)) - x \rVert_1\Big] \\
        &+ \mathbb{E}_{y \sim p_{data}(y)} \Big[\lVert G(F(y)) - y \rVert_1\Big] \\
        \\
        \mathcal{L}_{identity}(G, F)
        &= \mathbb{E}_{x \sim p_{data}(x)} \Big[\lVert F(x) - x \rVert_1\Big] \\
        &+ \mathbb{E}_{y \sim p_{data}(y)} \Big[\lVert G(y) - y \rVert_1\Big] \\
        \end{align}

        $\lambda_1$ is `cyclic_loss_coefficient` and $\lambda_2$ is `identity_loss_coefficient`.

        $\mathcal{L}_{GAN}$ is the generative adversarial loss from the original
        GAN paper.

        $\mathcal{L}_{cyc}$ is the cyclic loss, where we try to get $F(G(x))$ to be similar to $x$,
        and $G(F(y))$ to be similar to $y$.
        Basically, if the two generators (transformations) are applied in series, they should
        give back the original image.
        This is the main contribution of this paper.
        It trains the generators to generate an image of the other distribution that is similar to
        the original image.
        Without this loss $G(x)$ could generate anything that's from the distribution of $Y$.
        Now it needs to generate something from the distribution of $Y$ but still have properties of $x$,
        so that $F(G(x))$ can re-generate something like $x$.

        $\mathcal{L}_{identity}$ is the identity loss.
        This was used to encourage the mapping to preserve color composition between
        the input and the output.

        To solve $G^{\*}, F^{\*}$,
        discriminators $D_X$ and $D_Y$ should **ascend** on the gradient,

        \begin{align}
        \nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^m
        &\Bigg[
        \log D_Y\Big(y^{(i)}\Big) \\
        &+ \log \Big(1 - D_Y\Big(G\Big(x^{(i)}\Big)\Big)\Big) \\
        &+ \log D_X\Big(x^{(i)}\Big) \\
        & +\log\Big(1 - D_X\Big(F\Big(y^{(i)}\Big)\Big)\Big)
        \Bigg]
        \end{align}

        That is, descend on the *negative* log-likelihood loss.

        In order to stabilize the training, the negative log-likelihood objective
        was replaced by a least-squares loss -
        the least-squares error of the discriminator labelling real images with 1,
        and generated images with 0.
        So we want to descend on the gradient,

        \begin{align}
        \nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^m
        &\Bigg[
            \bigg(D_Y\Big(y^{(i)}\Big) - 1\bigg)^2 \\
            &+ D_Y\Big(G\Big(x^{(i)}\Big)\Big)^2 \\
            &+ \bigg(D_X\Big(x^{(i)}\Big) - 1\bigg)^2 \\
            &+ D_X\Big(F\Big(y^{(i)}\Big)\Big)^2
        \Bigg]
        \end{align}

        We use least-squares for generators also.
        The generators should *descend* on the gradient,

        \begin{align}
        \nabla_{\theta_{F, G}} \frac{1}{m} \sum_{i=1}^m
        &\Bigg[
            \bigg(D_Y\Big(G\Big(x^{(i)}\Big)\Big) - 1\bigg)^2 \\
            &+ \bigg(D_X\Big(F\Big(y^{(i)}\Big)\Big) - 1\bigg)^2 \\
            &+ \lambda_1 \mathcal{L}_{cyc}(G, F)
            + \lambda_2 \mathcal{L}_{identity}(G, F)
        \Bigg]
        \end{align}

        We use `generator_xy` for $G$ and `generator_yx` for $F$.
        We use `discriminator_x` for $D_X$ and `discriminator_y` for $D_Y$.
        """

        # Replay buffers to keep generated samples
        gen_x_buffer = ReplayBuffer()
        gen_y_buffer = ReplayBuffer()

        # Loop through epochs
        for epoch in monit.loop(self.epochs):
            # Loop through the dataset
            for i, batch in enumerate(self.dataloader):
                # Move images to the device
                data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)

                # true labels equal to $1$
                true_labels = torch.ones(data_x.size(0), *self.discriminator_x.output_shape,
                                         device=self.device, requires_grad=False)
                # false labels equal to $0$
                false_labels = torch.zeros(data_x.size(0), *self.discriminator_x.output_shape,
                                           device=self.device, requires_grad=False)

                # Train the generators.
                # This returns the generated images.
                gen_x, gen_y = self.optimize_generators(data_x, data_y, true_labels)

                #  Train discriminators
                self.optimize_discriminator(data_x, data_y,
                                            gen_x_buffer.push_and_pop(gen_x), gen_y_buffer.push_and_pop(gen_y),
                                            true_labels, false_labels)

                # Save training statistics and increment the global step counter
                tracker.save()
                tracker.add_global_step(max(len(data_x), len(data_y)))

                # Save images at intervals
                batches_done = epoch * len(self.dataloader) + i
                if batches_done % self.sample_interval == 0:
                    self.sample_images(batches_done)

            # Update learning rates
            self.generator_lr_scheduler.step()
            self.discriminator_lr_scheduler.step()

    def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):
        """
        ### Optimize the generators with identity, gan and cycle losses.

        Returns the generated images $F(y)$ and $G(x)$ (in that order)
        so they can be used to train the discriminators.
        """

        #  Change to training mode
        self.generator_xy.train()
        self.generator_yx.train()

        # Identity loss
        # $$\lVert F(x^{(i)}) - x^{(i)} \rVert_1 +
        #   \lVert G(y^{(i)}) - y^{(i)} \rVert_1$$
        loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
                         self.identity_loss(self.generator_xy(data_y), data_y))

        # Generate images $G(x)$ and $F(y)$
        gen_y = self.generator_xy(data_x)
        gen_x = self.generator_yx(data_y)

        # GAN loss
        # $$\bigg(D_Y\Big(G\Big(x^{(i)}\Big)\Big) - 1\bigg)^2
        #  + \bigg(D_X\Big(F\Big(y^{(i)}\Big)\Big) - 1\bigg)^2$$
        loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
                    self.gan_loss(self.discriminator_x(gen_x), true_labels))

        # Cycle loss
        # $$
        # \lVert F(G(x^{(i)})) - x^{(i)} \rVert_1 +
        # \lVert G(F(y^{(i)})) - y^{(i)} \rVert_1
        # $$
        loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
                      self.cycle_loss(self.generator_xy(gen_x), data_y))

        # Total loss
        loss_generator = (loss_gan +
                          self.cyclic_loss_coefficient * loss_cycle +
                          self.identity_loss_coefficient * loss_identity)

        # Take a step in the optimizer
        self.generator_optimizer.zero_grad()
        loss_generator.backward()
        self.generator_optimizer.step()

        # Log losses
        tracker.add({'loss.generator': loss_generator,
                     'loss.generator.cycle': loss_cycle,
                     'loss.generator.gan': loss_gan,
                     'loss.generator.identity': loss_identity})

        # Return generated images
        return gen_x, gen_y

    def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
                               gen_x: torch.Tensor, gen_y: torch.Tensor,
                               true_labels: torch.Tensor, false_labels: torch.Tensor):
        """
        ### Optimize the discriminators with gan loss.

        `gen_x` and `gen_y` are generated images (taken from the replay buffers),
        labelled as fake; `data_x` and `data_y` are real images, labelled as real.
        """
        # GAN Loss
        # \begin{align}
        # \bigg(D_Y\Big(y ^ {(i)}\Big) - 1\bigg) ^ 2
        # + D_Y\Big(G\Big(x ^ {(i)}\Big)\Big) ^ 2 + \\
        # \bigg(D_X\Big(x ^ {(i)}\Big) - 1\bigg) ^ 2
        # + D_X\Big(F\Big(y ^ {(i)}\Big)\Big) ^ 2
        # \end{align}
        loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
                              self.gan_loss(self.discriminator_x(gen_x), false_labels) +
                              self.gan_loss(self.discriminator_y(data_y), true_labels) +
                              self.gan_loss(self.discriminator_y(gen_y), false_labels))

        # Take a step in the optimizer
        self.discriminator_optimizer.zero_grad()
        loss_discriminator.backward()
        self.discriminator_optimizer.step()

        # Log losses
        tracker.add({'loss.discriminator': loss_discriminator})


@configs.setup([Configs.generator_xy, Configs.generator_yx, Configs.discriminator_x, Configs.discriminator_y,
                Configs.generator_optimizer, Configs.discriminator_optimizer,
                Configs.generator_lr_scheduler, Configs.discriminator_lr_scheduler])
def setup_models(self: Configs):
    """
    ## Set up the models, optimizers and learning rate schedules
    """
    input_shape = (self.img_channels, self.img_height, self.img_width)

    # Create the models
    self.generator_xy = GeneratorResNet(input_shape, self.n_residual_blocks).to(self.device)
    self.generator_yx = GeneratorResNet(input_shape, self.n_residual_blocks).to(self.device)
    self.discriminator_x = Discriminator(input_shape).to(self.device)
    self.discriminator_y = Discriminator(input_shape).to(self.device)

    # Create the optimizers
    self.generator_optimizer = torch.optim.Adam(
        itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
        lr=self.learning_rate, betas=self.adam_betas)
    self.discriminator_optimizer = torch.optim.Adam(
        itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
        lr=self.learning_rate, betas=self.adam_betas)

    # Create the learning rate schedules.
    # The learning rate stays flat until `decay_start` epochs,
    # and then linearly reduces to $0$ at end of training.
    # `max(1, ...)` avoids a division by zero when `decay_start == epochs`;
    # whenever `decay_start >= epochs` the learning rate simply never decays.
    decay_epochs = max(1, self.epochs - self.decay_start)

    def lr_lambda(epoch: int) -> float:
        # Multiplicative factor applied to `learning_rate` at `epoch`
        return 1.0 - max(0, epoch - self.decay_start) / decay_epochs

    self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        self.generator_optimizer, lr_lambda=lr_lambda)
    self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        self.discriminator_optimizer, lr_lambda=lr_lambda)


@configs.setup([Configs.dataloader, Configs.valid_dataloader])
def setup_dataloader(self: Configs):
    """
    ## Set up the data loaders
    """

    # Location of the dataset
    images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name

    # Image transformations: resize slightly larger, then randomly crop back
    # to the target size, randomly flip, and normalize to $[-1, 1]$
    transforms_ = [
        transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC),
        transforms.RandomCrop((self.img_height, self.img_width)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]

    # Training data loader
    self.dataloader = DataLoader(
        ImageDataset(images_path, transforms_, True, 'train'),
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.data_loader_workers,
    )

    # Validation data loader
    self.valid_dataloader = DataLoader(
        ImageDataset(images_path, transforms_, True, "test"),
        batch_size=5,
        shuffle=True,
        num_workers=self.data_loader_workers,
    )


def main():
    conf = Configs()
    experiment.create(name='cycle_gan')
    experiment.configs(conf, 'run')
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()