mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-11-01 03:43:09 +08:00 
			
		
		
		
	cycle gan
This commit is contained in:
		
							
								
								
									
										362
									
								
								labml_nn/gan/cycle_gan.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										362
									
								
								labml_nn/gan/cycle_gan.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,362 @@ | |||||||
|  | """ | ||||||
|  | Download datasets from https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/[DATASET NAME].zip | ||||||
|  | and extract them into labml_nn/data/cycle_gan/[DATASET NAME] | ||||||
|  |  | ||||||
|  | I've taken pieces of code from https://github.com/eriklindernoren/PyTorch-GAN | ||||||
|  | """ | ||||||
|  |  | ||||||
|  | import itertools | ||||||
|  | import random | ||||||
|  | from pathlib import PurePath, Path | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | import torchvision.transforms as transforms | ||||||
|  | from PIL import Image | ||||||
|  | from torch.utils.data import DataLoader | ||||||
|  | from torch.utils.data import Dataset | ||||||
|  | from torchvision.utils import make_grid | ||||||
|  | from torchvision.utils import save_image | ||||||
|  |  | ||||||
|  | from labml import lab, tracker, experiment | ||||||
|  | from labml_helpers.device import DeviceConfigs | ||||||
|  | from labml_helpers.module import Module | ||||||
|  | from labml_helpers.training_loop import TrainingLoopConfigs | ||||||
|  |  | ||||||
|  |  | ||||||
class ReplayBuffer:
    """
    Image pool from the CycleGAN paper: discriminators are trained on a
    mix of the latest generated images and previously generated ones,
    which reduces oscillation during training.
    """

    def __init__(self, max_size: int = 50):
        # Maximum number of generated images kept in the pool
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Insert a batch of generated images and return a same-sized batch
        drawn from a mix of new and stored images."""
        samples = []
        for img in data:
            if len(self.data) < self.max_size:
                # Pool not full yet: store the new image and return it as-is
                self.data.append(img)
                samples.append(img)
            elif random.uniform(0, 1) > 0.5:
                # Half the time, return a random stored image and
                # replace it in the pool with the new one
                idx = random.randint(0, self.max_size - 1)
                samples.append(self.data[idx].clone())
                self.data[idx] = img
            else:
                # Otherwise return the new image without storing it
                samples.append(img)
        return torch.stack(samples)
|  |  | ||||||
|  |  | ||||||
def load_image(path: str):
    """
    Load the image at `path` and make sure it is RGB.

    Bug fix: the original code did `Image.new("RGB", image.size).pase(image)` —
    `pase` is a typo for `paste`, and `Image.paste` works in place and returns
    `None`, so non-RGB images (grayscale, palette, RGBA) either raised
    `AttributeError` or made this function return `None`.
    """
    image = Image.open(path)
    if image.mode != 'RGB':
        # Paste onto a fresh RGB canvas and keep the canvas —
        # `paste` mutates `rgb` in place and returns None.
        rgb = Image.new("RGB", image.size)
        rgb.paste(image)
        image = rgb

    return image
|  |  | ||||||
|  |  | ||||||
class ImageDataset(Dataset):
    """
    Dataset of image pairs drawn from the two domains stored in
    `<root>/<mode>A` and `<root>/<mode>B` (e.g. `trainA` / `trainB`).
    """

    def __init__(self, root: PurePath, transforms_, unaligned: bool, mode: str):
        root = Path(root)
        # Compose the transform pipeline once up front
        self.transform = transforms.Compose(transforms_)
        # NOTE(review): `unaligned` is stored but never read in __getitem__;
        # both domains are indexed by the same wrapped index — verify intent.
        self.unaligned = unaligned

        # Sorted file paths for each domain, as strings
        self.files_A = sorted(str(f) for f in (root / f'{mode}A').iterdir())
        self.files_B = sorted(str(f) for f in (root / f'{mode}B').iterdir())

    def __getitem__(self, index):
        # Wrap the index so the shorter domain cycles
        path_a = self.files_A[index % len(self.files_A)]
        path_b = self.files_B[index % len(self.files_B)]
        return {"a": self.transform(load_image(path_a)),
                "b": self.transform(load_image(path_b))}

    def __len__(self):
        # Length of the longer domain, so every image is visited
        return max(len(self.files_A), len(self.files_B))
|  |  | ||||||
|  |  | ||||||
def weights_init_normal(m):
    """
    Weight initializer for `Module.apply`: convolution weights are drawn from
    N(0, 0.02) and BatchNorm2d scale parameters from N(1, 0.02).
    """
    name = type(m).__name__
    if "Conv" in name:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm2d" in name:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
|  |  | ||||||
|  |  | ||||||
class ResidualBlock(Module):
    """
    Residual block: two reflection-padded 3x3 convolutions with instance
    normalization, plus a skip connection from input to output.
    """

    def __init__(self, in_features: int):
        super().__init__()
        layers = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, kernel_size=3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, kernel_size=3),
            nn.InstanceNorm2d(in_features),
        ]
        self.block = nn.Sequential(*layers)

    def __call__(self, x: torch.Tensor):
        # Skip connection: output keeps the same shape as the input
        return x + self.block(x)
|  |  | ||||||
|  |  | ||||||
class GeneratorResNet(Module):
    """
    ResNet-style CycleGAN generator: an initial 7x7 convolution, two
    down-sampling steps, `num_residual_blocks` residual blocks at the
    bottleneck resolution, two up-sampling steps, and a final 7x7
    convolution mapping back to image channels with tanh output in [-1, 1].
    """

    def __init__(self, input_shape, num_residual_blocks):
        super().__init__()
        channels = input_shape[0]

        # Initial convolution block (reflection pad, then 7x7 conv to 64 features)
        # NOTE(review): padding amount equals `channels`; for 3-channel RGB this
        # matches the pad=3 needed by a 7x7 kernel — confirm for other channel counts.
        features = 64
        modules = [
            nn.ReflectionPad2d(channels),
            nn.Conv2d(channels, features, 7),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),
        ]

        # Two down-sampling steps, doubling the feature count each time
        for _ in range(2):
            modules += [
                nn.Conv2d(features, features * 2, 3, stride=2, padding=1),
                nn.InstanceNorm2d(features * 2),
                nn.ReLU(inplace=True),
            ]
            features *= 2

        # Transformation: residual blocks at the reduced resolution
        modules += [ResidualBlock(features) for _ in range(num_residual_blocks)]

        # Two up-sampling steps, halving the feature count each time
        for _ in range(2):
            modules += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(features, features // 2, 3, stride=1, padding=1),
                nn.InstanceNorm2d(features // 2),
                nn.ReLU(inplace=True),
            ]
            features //= 2

        # Output layer: back to the image channel count, squashed to [-1, 1]
        modules += [nn.ReflectionPad2d(channels), nn.Conv2d(features, channels, 7), nn.Tanh()]

        self.layers = nn.Sequential(*modules)

        # Initialize all conv/norm weights
        self.apply(weights_init_normal)

    def __call__(self, x):
        return self.layers(x)
|  |  | ||||||
|  |  | ||||||
class DiscriminatorBlock(Module):
    """
    Discriminator stage: 4x4 stride-2 convolution (halves each spatial
    dimension), optional instance normalization, then LeakyReLU(0.2).
    """

    def __init__(self, in_filters, out_filters, normalize=True):
        super().__init__()
        parts = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
        if normalize:
            parts.append(nn.InstanceNorm2d(out_filters))
        parts.append(nn.LeakyReLU(0.2, inplace=True))
        self.layers = nn.Sequential(*parts)

    def __call__(self, x: torch.Tensor):
        return self.layers(x)
|  |  | ||||||
|  |  | ||||||
class Discriminator(Module):
    """
    PatchGAN discriminator: four stride-2 blocks shrink each spatial
    dimension by 2**4, and a final 4x4 convolution produces a 1-channel
    map where each value classifies one overlapping image patch.
    """

    def __init__(self, input_shape):
        super().__init__()
        channels, height, width = input_shape

        # Output shape of the patch map (1 channel, spatial dims / 16)
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        blocks = [DiscriminatorBlock(channels, 64, normalize=False)]
        for in_filters, out_filters in ((64, 128), (128, 256), (256, 512)):
            blocks.append(DiscriminatorBlock(in_filters, out_filters))
        # Asymmetric zero padding keeps the final conv's output aligned
        # with `output_shape`
        blocks += [nn.ZeroPad2d((1, 0, 1, 0)), nn.Conv2d(512, 1, 4, padding=1)]
        self.model = nn.Sequential(*blocks)

        self.apply(weights_init_normal)

    def forward(self, img):
        return self.model(img)
|  |  | ||||||
|  |  | ||||||
def sample_images(n, dataset_name, valid_dataloader, generator_ab, generator_ba):
    """Save a grid of validation images and their translations to
    ``images/<dataset_name>/<n>.png``."""
    batch = next(iter(valid_dataloader))
    # Eval mode so InstanceNorm and similar layers behave deterministically
    generator_ab.eval()
    generator_ba.eval()
    with torch.no_grad():
        # NOTE(review): `.device` is presumably provided by the labml Module
        # base class — plain `nn.Module` has no such attribute; verify.
        real_a = batch['a'].to(generator_ab.device)
        real_b = batch['b'].to(generator_ba.device)
        fake_b = generator_ab(real_a)
        fake_a = generator_ba(real_b)

        # One grid row per image set, laid out along the x-axis
        rows = [make_grid(imgs, nrow=5, normalize=True)
                for imgs in (real_a, fake_b, real_b, fake_a)]

        # Stack the rows along the y-axis: A, A->B, B, B->A
        image_grid = torch.cat(rows, 1)
    # NOTE(review): assumes the images/<dataset_name>/ directory already exists
    save_image(image_grid, "images/%s/%s.png" % (dataset_name, n), normalize=False)
|  |  | ||||||
|  |  | ||||||
class Configs(TrainingLoopConfigs):
    """
    CycleGAN training configurations.

    Builds two generators (A->B and B->A) and two discriminators, and trains
    them with the standard CycleGAN objective: LSGAN adversarial loss,
    cycle-consistency loss, and identity loss, using an image replay buffer
    for the discriminators.
    """

    # Training device (chosen by labml's DeviceConfigs, e.g. CUDA when available)
    device: torch.device = DeviceConfigs()
    # Number of epochs the training loop runs
    loop_count: int = 200
    # Dataset folder name under labml's data path / 'cycle_gan'
    dataset_name: str = 'monet2photo'
    batch_size: int = 1

    # Worker processes for the data loaders
    data_loader_workers = 8
    is_save_models = True

    # Adam optimizer hyper-parameters
    learning_rate = 0.0002
    adam_betas = (0.5, 0.999)
    # Epoch at which the learning rate starts decaying linearly to zero
    decay_start = 100

    # L1 for identity and cycle-consistency, MSE (LSGAN) for the adversarial loss
    identity_loss = torch.nn.L1Loss()
    cycle_loss = torch.nn.L1Loss()
    gan_loss = torch.nn.MSELoss()

    batch_step = 'cycle_gan_batch_step'

    # Input image dimensions (channels, height, width)
    img_height = 256
    img_width = 256
    img_channels = 3

    # Residual blocks in each generator bottleneck
    n_residual_blocks = 9

    # Loss weights for the cycle-consistency and identity terms
    cyclic_loss_coefficient = 10.0
    identity_loss_coefficient = 5.

    # Save sample images every this many batches
    sample_interval = 100

    def run(self):
        """Build models, optimizers, schedulers and data loaders, then train."""
        # Expects data extracted to <data_path>/cycle_gan/<dataset_name>
        images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name

        # Two generators (A->B, B->A) and one discriminator per domain
        input_shape = (self.img_channels, self.img_height, self.img_width)
        generator_ab = GeneratorResNet(input_shape, self.n_residual_blocks).to(self.device)
        generator_ba = GeneratorResNet(input_shape, self.n_residual_blocks).to(self.device)
        discriminator_a = Discriminator(input_shape).to(self.device)
        discriminator_b = Discriminator(input_shape).to(self.device)

        # One optimizer over both generators, one over both discriminators
        generator_optimizer = torch.optim.Adam(
            itertools.chain(generator_ab.parameters(), generator_ba.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)
        discriminator_optimizer = torch.optim.Adam(
            itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)

        # Linear LR decay: factor 1.0 up to decay_start, then down to 0 at loop_count
        decay_epochs = self.loop_count - self.decay_start
        generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
        discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)

        # Image augmentations: upscale slightly, random-crop back, flip,
        # then normalize to [-1, 1] to match the generators' tanh output
        transforms_ = [
            transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC),
            transforms.RandomCrop((self.img_height, self.img_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

        # Training data loader
        dataloader = DataLoader(
            ImageDataset(images_path, transforms_, True, 'train'),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )
        # Test data loader (used only for periodic image samples)
        valid_dataloader = DataLoader(
            ImageDataset(images_path, transforms_, True, "test"),
            batch_size=5,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )

        # Replay buffers of previously generated samples for discriminator training
        fake_a_buffer = ReplayBuffer()
        fake_b_buffer = ReplayBuffer()

        # NOTE(review): assumes labml's training loop yields epoch indices — verify
        for epoch in self.training_loop:
            for i, batch in enumerate(dataloader):
                # Set model input
                real_a, real_b = batch['a'].to(self.device), batch['b'].to(self.device)

                # Adversarial ground truths, shaped like the PatchGAN output map
                valid = torch.ones(real_a.size(0), *discriminator_a.output_shape,
                                   device=self.device, requires_grad=False)
                fake = torch.zeros(real_a.size(0), *discriminator_a.output_shape,
                                   device=self.device, requires_grad=False)

                #  Train generators
                generator_ab.train()
                generator_ba.train()

                # Identity loss: a generator applied to its target domain
                # should leave the image unchanged
                loss_identity = self.identity_loss(generator_ba(real_a), real_a) + \
                                self.identity_loss(generator_ab(real_b), real_b)

                # GAN loss: translated images should fool the discriminators
                fake_b = generator_ab(real_a)
                fake_a = generator_ba(real_b)

                loss_gan = self.gan_loss(discriminator_b(fake_b), valid) + \
                           self.gan_loss(discriminator_a(fake_a), valid)

                # Cycle-consistency loss: A -> B -> A should recover the input
                loss_cycle = self.cycle_loss(generator_ba(fake_b), real_a) + \
                             self.cycle_loss(generator_ab(fake_a), real_b)

                # Total generator loss (weighted sum)
                loss_generator = (loss_gan + self.cyclic_loss_coefficient * loss_cycle
                                  + self.identity_loss_coefficient * loss_identity)

                generator_optimizer.zero_grad()
                loss_generator.backward()
                generator_optimizer.step()

                #  Train discriminators on real images and replayed fakes
                # (detached so no gradients flow into the generators)
                fake_a_replay = fake_a_buffer.push_and_pop(fake_a)
                fake_b_replay = fake_b_buffer.push_and_pop(fake_b)
                loss_discriminator = self.gan_loss(discriminator_a(real_a), valid) + \
                                     self.gan_loss(discriminator_a(fake_a_replay.detach()), fake) + \
                                     self.gan_loss(discriminator_b(real_b), valid) + \
                                     self.gan_loss(discriminator_b(fake_b_replay.detach()), fake)

                discriminator_optimizer.zero_grad()
                loss_discriminator.backward()
                discriminator_optimizer.step()

                # Log all loss components
                tracker.save({'loss.generator': loss_generator,
                              'loss.discriminator': loss_discriminator,
                              'loss.generator.cycle': loss_cycle,
                              'loss.generator.gan': loss_gan,
                              'loss.generator.identity': loss_identity})

                # If at sample interval, save a grid of sample translations
                batches_done = epoch * len(dataloader) + i
                if batches_done % self.sample_interval == 0:
                    sample_images(batches_done, self.dataset_name, valid_dataloader, generator_ab, generator_ba)

                tracker.add_global_step(max(len(real_a), len(real_b)))

            # Update learning rates once per epoch
            generator_lr_scheduler.step()
            discriminator_lr_scheduler.step()
|  |  | ||||||
|  |  | ||||||
def main():
    """Create the labml experiment and run CycleGAN training."""
    configurations = Configs()
    experiment.create(name='cycle_gan')
    # 'run' is the entry-point config the experiment executes
    experiment.configs(configurations, 'run')
    with experiment.start():
        configurations.run()
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri