mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-31 18:58:43 +08:00

cycle gan

labml_nn/gan/cycle_gan.py (new file, 362 lines)
| @@ -0,0 +1,362 @@ | ||||
| """ | ||||
| Download datasets from https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/[DATASET NAME].zip | ||||
| and extract them into labml_nn/data/cycle_gan/[DATASET NAME] | ||||
|  | ||||
| I've taken pieces of code from https://github.com/eriklindernoren/PyTorch-GAN | ||||
| """ | ||||
|  | ||||
| import itertools | ||||
| import random | ||||
| from pathlib import PurePath, Path | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torchvision.transforms as transforms | ||||
| from PIL import Image | ||||
| from torch.utils.data import DataLoader | ||||
| from torch.utils.data import Dataset | ||||
| from torchvision.utils import make_grid | ||||
| from torchvision.utils import save_image | ||||
|  | ||||
| from labml import lab, tracker, experiment | ||||
| from labml_helpers.device import DeviceConfigs | ||||
| from labml_helpers.module import Module | ||||
| from labml_helpers.training_loop import TrainingLoopConfigs | ||||
|  | ||||
|  | ||||
| class ReplayBuffer: | ||||
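|     """ | ||||
|     Buffer of previously generated images. | ||||
|  | ||||
|     `push_and_pop` returns, for each new image, either the image itself or (with probability | ||||
|     0.5, once the buffer is full) a randomly stored older image, which is then replaced by | ||||
|     the new one, so the discriminators also get trained on a history of generated samples. | ||||
|     """ | ||||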
|     def __init__(self, max_size: int = 50): | ||||
|         self.max_size = max_size | ||||
|         self.data = [] | ||||
|  | ||||
|     def push_and_pop(self, data): | ||||
|         to_return = [] | ||||
|         for element in data: | ||||
|             if len(self.data) < self.max_size: | ||||
|                 self.data.append(element) | ||||
|                 to_return.append(element) | ||||
|             else: | ||||
|                 if random.uniform(0, 1) > 0.5: | ||||
|                     i = random.randint(0, self.max_size - 1) | ||||
|                     to_return.append(self.data[i].clone()) | ||||
|                     self.data[i] = element | ||||
|                 else: | ||||
|                     to_return.append(element) | ||||
|         return torch.stack(to_return) | ||||
|  | ||||
|  | ||||
| def load_image(path: str): | ||||
|     """Load an image and convert it to RGB if it isn't already""" | ||||
|     image = Image.open(path) | ||||
|     if image.mode != 'RGB': | ||||
|         # `Image.paste` works in place and returns `None`, | ||||
|         # so paste onto a new RGB image and keep that | ||||
|         rgb_image = Image.new("RGB", image.size) | ||||
|         rgb_image.paste(image) | ||||
|         image = rgb_image | ||||
|  | ||||
|     return image | ||||
|  | ||||
|  | ||||
| class ImageDataset(Dataset): | ||||
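|     """Unpaired images from two domains, loaded from the `{mode}A` and `{mode}B` folders under `root`""" | ||||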
|     def __init__(self, root: PurePath, transforms_, unaligned: bool, mode: str): | ||||
|         root = Path(root) | ||||
|         self.transform = transforms.Compose(transforms_) | ||||
|         self.unaligned = unaligned | ||||
|  | ||||
|         self.files_A = sorted(str(f) for f in (root / f'{mode}A').iterdir()) | ||||
|         self.files_B = sorted(str(f) for f in (root / f'{mode}B').iterdir()) | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         # When `unaligned`, pair image `a` with a randomly chosen image from `b` | ||||
|         index_b = random.randint(0, len(self.files_B) - 1) if self.unaligned else index % len(self.files_B) | ||||
|         return {"a": self.transform(load_image(self.files_A[index % len(self.files_A)])), | ||||
|                 "b": self.transform(load_image(self.files_B[index_b]))} | ||||
|  | ||||
|     def __len__(self): | ||||
|         return max(len(self.files_A), len(self.files_B)) | ||||
|  | ||||
|  | ||||
| def weights_init_normal(m): | ||||
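|     """Initialize convolution weights from N(0, 0.02) and batch-norm weights from N(1, 0.02)""" | ||||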
|     classname = m.__class__.__name__ | ||||
|     if classname.find("Conv") != -1: | ||||
|         torch.nn.init.normal_(m.weight.data, 0.0, 0.02) | ||||
|     elif classname.find("BatchNorm2d") != -1: | ||||
|         torch.nn.init.normal_(m.weight.data, 1.0, 0.02) | ||||
|  | ||||
|  | ||||
| class ResidualBlock(Module): | ||||
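|     """ | ||||
|     Residual block: two 3x3 convolutions with reflection padding and instance normalization, | ||||
|     with the input added back to the output (skip connection). | ||||
|     """ | ||||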
|     def __init__(self, in_features: int): | ||||
|         super().__init__() | ||||
|         self.block = nn.Sequential( | ||||
|             nn.ReflectionPad2d(1), | ||||
|             nn.Conv2d(in_features, in_features, 3), | ||||
|             nn.InstanceNorm2d(in_features), | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.ReflectionPad2d(1), | ||||
|             nn.Conv2d(in_features, in_features, 3), | ||||
|             nn.InstanceNorm2d(in_features), | ||||
|         ) | ||||
|  | ||||
|     def __call__(self, x: torch.Tensor): | ||||
|         return x + self.block(x) | ||||
|  | ||||
|  | ||||
| class GeneratorResNet(Module): | ||||
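|     """ | ||||
|     ResNet-style generator: an initial 7x7 convolution block, two stride-2 downsampling blocks, | ||||
|     `num_residual_blocks` residual blocks, two upsampling blocks, and a final 7x7 convolution | ||||
|     with a tanh activation to produce the output image. | ||||
|     """ | ||||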
|     def __init__(self, input_shape, num_residual_blocks): | ||||
|         super().__init__() | ||||
|         channels = input_shape[0] | ||||
|  | ||||
|         # Initial convolution block | ||||
|         out_features = 64 | ||||
|         layers = [ | ||||
|             nn.ReflectionPad2d(channels), | ||||
|             nn.Conv2d(channels, out_features, 7), | ||||
|             nn.InstanceNorm2d(out_features), | ||||
|             nn.ReLU(inplace=True), | ||||
|         ] | ||||
|         in_features = out_features | ||||
|  | ||||
|         # Downsampling | ||||
|         for _ in range(2): | ||||
|             out_features *= 2 | ||||
|             layers += [ | ||||
|                 nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), | ||||
|                 nn.InstanceNorm2d(out_features), | ||||
|                 nn.ReLU(inplace=True), | ||||
|             ] | ||||
|             in_features = out_features | ||||
|  | ||||
|         # Residual blocks | ||||
|         for _ in range(num_residual_blocks): | ||||
|             layers += [ResidualBlock(out_features)] | ||||
|  | ||||
|         # Upsampling | ||||
|         for _ in range(2): | ||||
|             out_features //= 2 | ||||
|             layers += [ | ||||
|                 nn.Upsample(scale_factor=2), | ||||
|                 nn.Conv2d(in_features, out_features, 3, stride=1, padding=1), | ||||
|                 nn.InstanceNorm2d(out_features), | ||||
|                 nn.ReLU(inplace=True), | ||||
|             ] | ||||
|             in_features = out_features | ||||
|  | ||||
|         # Output layer | ||||
|         layers += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()] | ||||
|  | ||||
|         self.layers = nn.Sequential(*layers) | ||||
|  | ||||
|         self.apply(weights_init_normal) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         return self.layers(x) | ||||
|  | ||||
|  | ||||
| class DiscriminatorBlock(Module): | ||||
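|     """A discriminator block: a stride-2 convolution, optional instance normalization, and a leaky ReLU""" | ||||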
|     def __init__(self, in_filters, out_filters, normalize=True): | ||||
|         super().__init__() | ||||
|         layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] | ||||
|         if normalize: | ||||
|             layers.append(nn.InstanceNorm2d(out_filters)) | ||||
|         layers.append(nn.LeakyReLU(0.2, inplace=True)) | ||||
|         self.layers = nn.Sequential(*layers) | ||||
|  | ||||
|     def __call__(self, x: torch.Tensor): | ||||
|         return self.layers(x) | ||||
|  | ||||
|  | ||||
| class Discriminator(Module): | ||||
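|     """ | ||||
|     PatchGAN discriminator: four stride-2 blocks downsample the input by a factor of 2^4, | ||||
|     and a final 1-channel convolution scores each patch of the image as real or fake. | ||||
|     """ | ||||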
|     def __init__(self, input_shape): | ||||
|         super().__init__() | ||||
|         channels, height, width = input_shape | ||||
|  | ||||
|         # Calculate output shape of image discriminator (PatchGAN) | ||||
|         self.output_shape = (1, height // 2 ** 4, width // 2 ** 4) | ||||
|  | ||||
|         self.model = nn.Sequential( | ||||
|             DiscriminatorBlock(channels, 64, normalize=False), | ||||
|             DiscriminatorBlock(64, 128), | ||||
|             DiscriminatorBlock(128, 256), | ||||
|             DiscriminatorBlock(256, 512), | ||||
|             nn.ZeroPad2d((1, 0, 1, 0)), | ||||
|             nn.Conv2d(512, 1, 4, padding=1) | ||||
|         ) | ||||
|  | ||||
|         self.apply(weights_init_normal) | ||||
|  | ||||
|     def forward(self, img): | ||||
|         return self.model(img) | ||||
|  | ||||
|  | ||||
| def sample_images(n, dataset_name, valid_dataloader, generator_ab, generator_ba): | ||||
|     """Saves a generated sample from the test set""" | ||||
|     batch = next(iter(valid_dataloader)) | ||||
|     generator_ab.eval() | ||||
|     generator_ba.eval() | ||||
|     with torch.no_grad(): | ||||
|         real_a, real_b = batch['a'].to(generator_ab.device), batch['b'].to(generator_ba.device) | ||||
|         fake_b = generator_ab(real_a) | ||||
|         fake_a = generator_ba(real_b) | ||||
|  | ||||
|         # Arrange images along the x-axis | ||||
|         real_a = make_grid(real_a, nrow=5, normalize=True) | ||||
|         real_b = make_grid(real_b, nrow=5, normalize=True) | ||||
|         fake_a = make_grid(fake_a, nrow=5, normalize=True) | ||||
|         fake_b = make_grid(fake_b, nrow=5, normalize=True) | ||||
|  | ||||
|         # Arrange images along the y-axis | ||||
|         image_grid = torch.cat((real_a, fake_b, real_b, fake_a), 1) | ||||
|     # Make sure the output directory exists before saving | ||||
|     Path(f'images/{dataset_name}').mkdir(parents=True, exist_ok=True) | ||||
|     save_image(image_grid, f'images/{dataset_name}/{n}.png', normalize=False) | ||||
|  | ||||
|  | ||||
| class Configs(TrainingLoopConfigs): | ||||
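|     """Configuration options for the CycleGAN experiment""" | ||||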
|     device: torch.device = DeviceConfigs() | ||||
|     loop_count: int = 200 | ||||
|     dataset_name: str = 'monet2photo' | ||||
|     batch_size: int = 1 | ||||
|  | ||||
|     data_loader_workers = 8 | ||||
|     is_save_models = True | ||||
|  | ||||
|     learning_rate = 0.0002 | ||||
|     adam_betas = (0.5, 0.999) | ||||
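|     # Epoch at which the learning rate starts decaying linearly to zero | ||||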
|     decay_start = 100 | ||||
|  | ||||
|     identity_loss = torch.nn.L1Loss() | ||||
|     cycle_loss = torch.nn.L1Loss() | ||||
|     gan_loss = torch.nn.MSELoss() | ||||
|  | ||||
|     batch_step = 'cycle_gan_batch_step' | ||||
|  | ||||
|     img_height = 256 | ||||
|     img_width = 256 | ||||
|     img_channels = 3 | ||||
|  | ||||
|     n_residual_blocks = 9 | ||||
|  | ||||
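|     # Weights of the cycle consistency and identity terms in the generator loss | ||||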
|     cyclic_loss_coefficient = 10.0 | ||||
|     identity_loss_coefficient = 5. | ||||
|  | ||||
|     sample_interval = 100 | ||||
|  | ||||
|     def run(self): | ||||
|         images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name | ||||
|  | ||||
|         input_shape = (self.img_channels, self.img_height, self.img_width) | ||||
|         generator_ab = GeneratorResNet(input_shape, self.n_residual_blocks).to(self.device) | ||||
|         generator_ba = GeneratorResNet(input_shape, self.n_residual_blocks).to(self.device) | ||||
|         discriminator_a = Discriminator(input_shape).to(self.device) | ||||
|         discriminator_b = Discriminator(input_shape).to(self.device) | ||||
|  | ||||
|         generator_optimizer = torch.optim.Adam( | ||||
|             itertools.chain(generator_ab.parameters(), generator_ba.parameters()), | ||||
|             lr=self.learning_rate, betas=self.adam_betas) | ||||
|         discriminator_optimizer = torch.optim.Adam( | ||||
|             itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()), | ||||
|             lr=self.learning_rate, betas=self.adam_betas) | ||||
|  | ||||
|         decay_epochs = self.loop_count - self.decay_start | ||||
|         generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR( | ||||
|             generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs) | ||||
|         discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR( | ||||
|             discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs) | ||||
|  | ||||
|         # Image transformations | ||||
|         transforms_ = [ | ||||
|             transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC), | ||||
|             transforms.RandomCrop((self.img_height, self.img_width)), | ||||
|             transforms.RandomHorizontalFlip(), | ||||
|             transforms.ToTensor(), | ||||
|             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), | ||||
|         ] | ||||
|  | ||||
|         # Training data loader | ||||
|         dataloader = DataLoader( | ||||
|             ImageDataset(images_path, transforms_, True, 'train'), | ||||
|             batch_size=self.batch_size, | ||||
|             shuffle=True, | ||||
|             num_workers=self.data_loader_workers, | ||||
|         ) | ||||
|         # Test data loader | ||||
|         valid_dataloader = DataLoader( | ||||
|             ImageDataset(images_path, transforms_, True, "test"), | ||||
|             batch_size=5, | ||||
|             shuffle=True, | ||||
|             num_workers=self.data_loader_workers, | ||||
|         ) | ||||
|  | ||||
|         # Buffers of previously generated samples | ||||
|         fake_a_buffer = ReplayBuffer() | ||||
|         fake_b_buffer = ReplayBuffer() | ||||
|  | ||||
|         for epoch in self.training_loop: | ||||
|             for i, batch in enumerate(dataloader): | ||||
|                 # Set model input | ||||
|                 real_a, real_b = batch['a'].to(self.device), batch['b'].to(self.device) | ||||
|  | ||||
|                 # Adversarial ground truths | ||||
|                 valid = torch.ones(real_a.size(0), *discriminator_a.output_shape, | ||||
|                                    device=self.device, requires_grad=False) | ||||
|                 fake = torch.zeros(real_a.size(0), *discriminator_a.output_shape, | ||||
|                                    device=self.device, requires_grad=False) | ||||
|  | ||||
|                 #  Train generators | ||||
|                 generator_ab.train() | ||||
|                 generator_ba.train() | ||||
|  | ||||
|                 # Identity loss | ||||
|                 loss_identity = self.identity_loss(generator_ba(real_a), real_a) + \ | ||||
|                                 self.identity_loss(generator_ab(real_b), real_b) | ||||
|  | ||||
|                 # GAN loss | ||||
|                 fake_b = generator_ab(real_a) | ||||
|                 fake_a = generator_ba(real_b) | ||||
|  | ||||
|                 loss_gan = self.gan_loss(discriminator_b(fake_b), valid) + \ | ||||
|                            self.gan_loss(discriminator_a(fake_a), valid) | ||||
|  | ||||
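|                 # Cycle consistency loss: translating an image to the other domain | ||||
|                 # and back should reconstruct the original image | ||||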
|                 loss_cycle = self.cycle_loss(generator_ba(fake_b), real_a) + \ | ||||
|                              self.cycle_loss(generator_ab(fake_a), real_b) | ||||
|  | ||||
|                 # Total loss | ||||
|                 loss_generator = (loss_gan + self.cyclic_loss_coefficient * loss_cycle | ||||
|                                   + self.identity_loss_coefficient * loss_identity) | ||||
|  | ||||
|                 generator_optimizer.zero_grad() | ||||
|                 loss_generator.backward() | ||||
|                 generator_optimizer.step() | ||||
|  | ||||
|                 #  Train discriminators | ||||
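|                 # Sample fakes from the replay buffers so the discriminators also see | ||||
|                 # previously generated images rather than only the latest ones | ||||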
|                 fake_a_replay = fake_a_buffer.push_and_pop(fake_a) | ||||
|                 fake_b_replay = fake_b_buffer.push_and_pop(fake_b) | ||||
|                 loss_discriminator = self.gan_loss(discriminator_a(real_a), valid) + \ | ||||
|                                      self.gan_loss(discriminator_a(fake_a_replay.detach()), fake) + \ | ||||
|                                      self.gan_loss(discriminator_b(real_b), valid) + \ | ||||
|                                      self.gan_loss(discriminator_b(fake_b_replay.detach()), fake) | ||||
|  | ||||
|                 discriminator_optimizer.zero_grad() | ||||
|                 loss_discriminator.backward() | ||||
|                 discriminator_optimizer.step() | ||||
|  | ||||
|                 tracker.save({'loss.generator': loss_generator, | ||||
|                               'loss.discriminator': loss_discriminator, | ||||
|                               'loss.generator.cycle': loss_cycle, | ||||
|                               'loss.generator.gan': loss_gan, | ||||
|                               'loss.generator.identity': loss_identity}) | ||||
|  | ||||
|                 # If at sample interval save image | ||||
|                 batches_done = epoch * len(dataloader) + i | ||||
|                 if batches_done % self.sample_interval == 0: | ||||
|                     sample_images(batches_done, self.dataset_name, valid_dataloader, generator_ab, generator_ba) | ||||
|  | ||||
|                 tracker.add_global_step(max(len(real_a), len(real_b))) | ||||
|  | ||||
|             # Update learning rates | ||||
|             generator_lr_scheduler.step() | ||||
|             discriminator_lr_scheduler.step() | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     conf = Configs() | ||||
|     experiment.create(name='cycle_gan') | ||||
|     experiment.configs(conf, 'run') | ||||
|     with experiment.start(): | ||||
|         conf.run() | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     main() | ||||