""" Download datasets from https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/[DATASET NAME].zip and extract them into labml_nn/data/cycle_gan/[DATASET NAME] I've taken pieces of code from https://github.com/eriklindernoren/PyTorch-GAN """ import itertools import random from pathlib import PurePath, Path import torch import torch.nn as nn import torchvision.transforms as transforms from PIL import Image from torch.utils.data import DataLoader from torch.utils.data import Dataset from torchvision.utils import make_grid from torchvision.utils import save_image from labml import lab, tracker, experiment from labml_helpers.device import DeviceConfigs from labml_helpers.module import Module from labml_helpers.training_loop import TrainingLoopConfigs class ReplayBuffer: def __init__(self, max_size: int = 50): self.max_size = max_size self.data = [] def push_and_pop(self, data): to_return = [] for element in data: if len(self.data) < self.max_size: self.data.append(element) to_return.append(element) else: if random.uniform(0, 1) > 0.5: i = random.randint(0, self.max_size - 1) to_return.append(self.data[i].clone()) self.data[i] = element else: to_return.append(element) return torch.stack(to_return) def load_image(path: str): image = Image.open(path) if image.mode != 'RGB': image = Image.new("RGB", image.size).pase(image) return image class ImageDataset(Dataset): def __init__(self, root: PurePath, transforms_, unaligned: bool, mode: str): root = Path(root) self.transform = transforms.Compose(transforms_) self.unaligned = unaligned self.files_A = sorted(str(f) for f in (root / f'{mode}A').iterdir()) self.files_B = sorted(str(f) for f in (root / f'{mode}B').iterdir()) def __getitem__(self, index): return {"a": self.transform(load_image(self.files_A[index % len(self.files_A)])), "b": self.transform(load_image(self.files_B[index % len(self.files_B)]))} def __len__(self): return max(len(self.files_A), len(self.files_B)) def weights_init_normal(m): classname = m.__class__.__name__ if classname.find("Conv") != -1: torch.nn.init.normal_(m.weight.data, 0.0, 0.02) elif classname.find("BatchNorm2d") != -1: torch.nn.init.normal_(m.weight.data, 1.0, 0.02) class ResidualBlock(Module): def __init__(self, in_features: int): super().__init__() self.block = nn.Sequential( nn.ReflectionPad2d(1), nn.Conv2d(in_features, in_features, 3), nn.InstanceNorm2d(in_features), nn.ReLU(inplace=True), nn.ReflectionPad2d(1), nn.Conv2d(in_features, in_features, 3), nn.InstanceNorm2d(in_features), ) def __call__(self, x: torch.Tensor): return x + self.block(x) class GeneratorResNet(Module): def __init__(self, input_shape, num_residual_blocks): super().__init__() channels = input_shape[0] # Initial convolution block out_features = 64 layers = [ nn.ReflectionPad2d(channels), nn.Conv2d(channels, out_features, 7), nn.InstanceNorm2d(out_features), nn.ReLU(inplace=True), ] in_features = out_features # Downsampling for _ in range(2): out_features *= 2 layers += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), nn.InstanceNorm2d(out_features), nn.ReLU(inplace=True), ] in_features = out_features # Residual blocks for _ in range(num_residual_blocks): layers += [ResidualBlock(out_features)] # Upsampling for _ in range(2): out_features //= 2 layers += [ nn.Upsample(scale_factor=2), nn.Conv2d(in_features, out_features, 3, stride=1, padding=1), nn.InstanceNorm2d(out_features), nn.ReLU(inplace=True), ] in_features = out_features # Output layer layers += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()] self.layers = nn.Sequential(*layers) self.apply(weights_init_normal) def __call__(self, x): return self.layers(x) class DiscriminatorBlock(Module): def __init__(self, in_filters, out_filters, normalize=True): super().__init__() layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] if normalize: layers.append(nn.InstanceNorm2d(out_filters)) layers.append(nn.LeakyReLU(0.2, inplace=True)) self.layers = nn.Sequential(*layers) def __call__(self, x: torch.Tensor): return self.layers(x) class Discriminator(Module): def __init__(self, input_shape): super().__init__() channels, height, width = input_shape # Calculate output shape of image discriminator (PatchGAN) self.output_shape = (1, height // 2 ** 4, width // 2 ** 4) self.model = nn.Sequential( DiscriminatorBlock(channels, 64, normalize=False), DiscriminatorBlock(64, 128), DiscriminatorBlock(128, 256), DiscriminatorBlock(256, 512), nn.ZeroPad2d((1, 0, 1, 0)), nn.Conv2d(512, 1, 4, padding=1) ) self.apply(weights_init_normal) def forward(self, img): return self.model(img) def sample_images(n, dataset_name, valid_dataloader, generator_ab, generator_ba): """Saves a generated sample from the test set""" batch = next(iter(valid_dataloader)) generator_ab.eval() generator_ba.eval() with torch.no_grad(): real_a, real_b = batch['a'].to(generator_ab.device), batch['b'].to(generator_ba.device) fake_b = generator_ab(real_a) fake_a = generator_ba(real_b) # Arange images along x-axis real_a = make_grid(real_a, nrow=5, normalize=True) real_b = make_grid(real_b, nrow=5, normalize=True) fake_a = make_grid(fake_a, nrow=5, normalize=True) fake_b = make_grid(fake_b, nrow=5, normalize=True) # arange images along y-axis image_grid = torch.cat((real_a, fake_b, real_b, fake_a), 1) save_image(image_grid, "images/%s/%s.png" % (dataset_name, n), normalize=False) class Configs(TrainingLoopConfigs): device: torch.device = DeviceConfigs() loop_count: int = 200 dataset_name: str = 'monet2photo' batch_size: int = 1 data_loader_workers = 8 is_save_models = True learning_rate = 0.0002 adam_betas = (0.5, 0.999) decay_start = 100 identity_loss = torch.nn.L1Loss() cycle_loss = torch.nn.L1Loss() gan_loss = torch.nn.MSELoss() batch_step = 'cycle_gan_batch_step' img_height = 256 img_width = 256 img_channels = 3 n_residual_blocks = 9 cyclic_loss_coefficient = 10.0 identity_loss_coefficient = 5. sample_interval = 100 def run(self): images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name input_shape = (self.img_channels, self.img_height, self.img_width) generator_ab = GeneratorResNet(input_shape, self.n_residual_blocks).to(self.device) generator_ba = GeneratorResNet(input_shape, self.n_residual_blocks).to(self.device) discriminator_a = Discriminator(input_shape).to(self.device) discriminator_b = Discriminator(input_shape).to(self.device) generator_optimizer = torch.optim.Adam( itertools.chain(generator_ab.parameters(), generator_ba.parameters()), lr=self.learning_rate, betas=self.adam_betas) discriminator_optimizer = torch.optim.Adam( itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()), lr=self.learning_rate, betas=self.adam_betas) decay_epochs = self.loop_count - self.decay_start generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR( generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs) discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR( discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs) # Image transformations transforms_ = [ transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC), transforms.RandomCrop((self.img_height, self.img_width)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ] # Training data loader dataloader = DataLoader( ImageDataset(images_path, transforms_, True, 'train'), batch_size=self.batch_size, shuffle=True, num_workers=self.data_loader_workers, ) # Test data loader valid_dataloader = DataLoader( ImageDataset(images_path, transforms_, True, "test"), batch_size=5, shuffle=True, num_workers=self.data_loader_workers, ) # Buffers of previously generated samples fake_a_buffer = ReplayBuffer() fake_b_buffer = ReplayBuffer() for epoch in self.training_loop: for i, batch in enumerate(dataloader): # Set model input real_a, real_b = batch['a'].to(self.device), batch['b'].to(self.device) # adversarial ground truths valid = torch.ones(real_a.size(0), *discriminator_a.output_shape, device=self.device, requires_grad=False) fake = torch.zeros(real_a.size(0), *discriminator_a.output_shape, device=self.device, requires_grad=False) # Train generators generator_ab.train() generator_ba.train() # Identity loss loss_identity = self.identity_loss(generator_ba(real_a), real_a) + \ self.identity_loss(generator_ab(real_b), real_b) # GAN loss fake_b = generator_ab(real_a) fake_a = generator_ba(real_b) loss_gan = self.gan_loss(discriminator_b(fake_b), valid) + \ self.gan_loss(discriminator_a(fake_a), valid) loss_cycle = self.cycle_loss(generator_ba(fake_b), real_a) + \ self.cycle_loss(generator_ab(fake_a), real_b) # Total loss loss_generator = (loss_gan + self.cyclic_loss_coefficient * loss_cycle + self.identity_loss_coefficient * loss_identity) generator_optimizer.zero_grad() loss_generator.backward() generator_optimizer.step() # Train discriminators fake_a_replay = fake_a_buffer.push_and_pop(fake_a) fake_b_replay = fake_b_buffer.push_and_pop(fake_b) loss_discriminator = self.gan_loss(discriminator_a(real_a), valid) + \ self.gan_loss(discriminator_a(fake_a_replay.detach()), fake) + \ self.gan_loss(discriminator_b(real_b), valid) + \ self.gan_loss(discriminator_b(fake_b_replay.detach()), fake) discriminator_optimizer.zero_grad() loss_discriminator.backward() discriminator_optimizer.step() tracker.save({'loss.generator': loss_generator, 'loss.discriminator': loss_discriminator, 'loss.generator.cycle': loss_cycle, 'loss.generator.gan': loss_gan, 'loss.generator.identity': loss_identity}) # If at sample interval save image batches_done = epoch * len(dataloader) + i if batches_done % self.sample_interval == 0: sample_images(batches_done, self.dataset_name, valid_dataloader, generator_ab, generator_ba) tracker.add_global_step(max(len(real_a), len(real_b))) # Update learning rates generator_lr_scheduler.step() discriminator_lr_scheduler.step() def main(): conf = Configs() experiment.create(name='cycle_gan') experiment.configs(conf, 'run') with experiment.start(): conf.run() if __name__ == '__main__': main()