diff --git a/labml_nn/gan/cycle_gan.py b/labml_nn/gan/cycle_gan.py index 53327793..7b98b01c 100644 --- a/labml_nn/gan/cycle_gan.py +++ b/labml_nn/gan/cycle_gan.py @@ -53,7 +53,7 @@ class GeneratorResNet(Module): # `inplace=True` in `ReLU` saves a little bit of memory. out_features = 64 layers = [ - nn.Conv2d(channels, out_features, kernel_size=7, padding=3, padding_mode='reflection'), + nn.Conv2d(channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'), nn.InstanceNorm2d(out_features), nn.ReLU(inplace=True), ] @@ -80,14 +80,15 @@ class GeneratorResNet(Module): for _ in range(2): out_features //= 2 layers += [ - nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2, padding=1), + nn.Upsample(scale_factor=2), + nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(out_features), nn.ReLU(inplace=True), ] in_features = out_features # Finally we map the feature map to an RGB image - layers += [nn.Conv2d(out_features, channels, 7, padding=3, padding_mode='reflection'), nn.Tanh()] + layers += [nn.Conv2d(out_features, channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()] # Create a sequential module with the layers self.layers = nn.Sequential(*layers) @@ -107,10 +108,10 @@ class ResidualBlock(Module): def __init__(self, in_features: int): super().__init__() self.block = nn.Sequential( - nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflection'), + nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'), nn.InstanceNorm2d(in_features), nn.ReLU(inplace=True), - nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflection'), + nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'), nn.InstanceNorm2d(in_features), nn.ReLU(inplace=True), ) @@ -158,6 +159,7 @@ class DiscriminatorBlock(Module): It shrinks the height and width of the input feature map by half. """ + def __init__(self, in_filters: int, out_filters: int, normalize: bool = True): super().__init__() layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)] @@ -194,6 +196,7 @@ class ImageDataset(Dataset): """ Dataset to load images """ + def __init__(self, root: PurePath, transforms_, unaligned: bool, mode: str): root = Path(root) self.transform = transforms.Compose(transforms_) @@ -221,6 +224,7 @@ class ReplayBuffer: This is done to reduce model oscillation. """ + def __init__(self, max_size: int = 50): self.max_size = max_size self.data = [] @@ -249,7 +253,7 @@ class Configs(BaseConfigs): dataset_name: str = 'monet2photo' batch_size: int = 1 - data_loader_workers = 8 + data_loader_workers = 0 is_save_models = True learning_rate = 0.0002 @@ -311,6 +315,57 @@ class Configs(BaseConfigs): image_grid = torch.cat((real_a, fake_b, real_b, fake_a), 1) save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False) + def optimize_generators(self, real_a: torch.Tensor, real_b: torch.Tensor, true_labels: torch.Tensor): + # Change to training mode + self.generator_ab.train() + self.generator_ba.train() + + # Identity loss + loss_identity = (self.identity_loss(self.generator_ba(real_a), real_a) + + self.identity_loss(self.generator_ab(real_b), real_b)) + + # Generate images + fake_b = self.generator_ab(real_a) + fake_a = self.generator_ba(real_b) + + # GAN loss + loss_gan = (self.gan_loss(self.discriminator_b(fake_b), true_labels) + + self.gan_loss(self.discriminator_a(fake_a), true_labels)) + + # Cycle loss + loss_cycle = (self.cycle_loss(self.generator_ba(fake_b), real_a) + + self.cycle_loss(self.generator_ab(fake_a), real_b)) + + # Total loss + loss_generator = (loss_gan + + self.cyclic_loss_coefficient * loss_cycle + + self.identity_loss_coefficient * loss_identity) + + self.generator_optimizer.zero_grad() + loss_generator.backward() + self.generator_optimizer.step() + + tracker.add({'loss.generator': loss_generator, + 'loss.generator.cycle': loss_cycle, + 'loss.generator.gan': loss_gan, + 'loss.generator.identity': loss_identity}) + + return fake_a, fake_b + + def optimize_discriminator(self, real_a: torch.Tensor, real_b: torch.Tensor, + fake_a: torch.Tensor, fake_b: torch.Tensor, + true_labels: torch.Tensor, false_labels: torch.Tensor): + loss_discriminator = (self.gan_loss(self.discriminator_a(real_a), true_labels) + + self.gan_loss(self.discriminator_a(fake_a), false_labels) + + self.gan_loss(self.discriminator_b(real_b), true_labels) + + self.gan_loss(self.discriminator_b(fake_b), false_labels)) + + self.discriminator_optimizer.zero_grad() + loss_discriminator.backward() + self.discriminator_optimizer.step() + + tracker.add({'loss.discriminator': loss_discriminator}) + def run(self): # Replay buffers to keep generated samples fake_a_buffer = ReplayBuffer() @@ -321,56 +376,23 @@ class Configs(BaseConfigs): # Move images to the device real_a, real_b = batch['a'].to(self.device), batch['b'].to(self.device) - # valid labels equal to $1$ - valid = torch.ones(real_a.size(0), *self.discriminator_a.output_shape, - device=self.device, requires_grad=False) - # fake labels equal to $0$ - fake = torch.zeros(real_a.size(0), *self.discriminator_a.output_shape, - device=self.device, requires_grad=False) + # true labels equal to $1$ + true_labels = torch.ones(real_a.size(0), *self.discriminator_a.output_shape, + device=self.device, requires_grad=False) + # false labels equal to $0$ + false_labels = torch.zeros(real_a.size(0), *self.discriminator_a.output_shape, + device=self.device, requires_grad=False) - # Train generators - self.generator_ab.train() - self.generator_ba.train() - - # Identity loss - loss_identity = self.identity_loss(self.generator_ba(real_a), real_a) + \ - self.identity_loss(self.generator_ab(real_b), real_b) - - # GAN loss - fake_b = self.generator_ab(real_a) - fake_a = self.generator_ba(real_b) - - loss_gan = self.gan_loss(self.discriminator_b(fake_b), valid) + \ - self.gan_loss(self.discriminator_a(fake_a), valid) - - loss_cycle = self.cycle_loss(self.generator_ba(fake_b), real_a) + \ - self.cycle_loss(self.generator_ab(fake_a), real_b) - - # Total loss - loss_generator = (loss_gan + self.cyclic_loss_coefficient * loss_cycle - + self.identity_loss_coefficient * loss_identity) - - self.generator_optimizer.zero_grad() - loss_generator.backward() - self.generator_optimizer.step() + # Train the generators + fake_a, fake_b = self.optimize_generators(real_a, real_b, true_labels) # Train discriminators - fake_a_replay = fake_a_buffer.push_and_pop(fake_a) - fake_b_replay = fake_b_buffer.push_and_pop(fake_b) - loss_discriminator = self.gan_loss(self.discriminator_a(real_a), valid) + \ - self.gan_loss(self.discriminator_a(fake_a_replay), fake) + \ - self.gan_loss(self.discriminator_b(real_b), valid) + \ - self.gan_loss(self.discriminator_b(fake_b_replay), fake) + self.optimize_discriminator(real_a, real_b, + fake_a_buffer.push_and_pop(fake_a), fake_b_buffer.push_and_pop(fake_b), + true_labels, false_labels) - self.discriminator_optimizer.zero_grad() - loss_discriminator.backward() - self.discriminator_optimizer.step() - - tracker.save({'loss.generator': loss_generator, - 'loss.discriminator': loss_discriminator, - 'loss.generator.cycle': loss_cycle, - 'loss.generator.gan': loss_gan, - 'loss.generator.identity': loss_identity}) + # Save training statistics + tracker.save() # If at sample interval save image batches_done = epoch * len(self.dataloader) + i