mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-11-04 14:29:43 +08:00 
			
		
		
		
	cleanup cycle gan code
This commit is contained in:
		@ -53,7 +53,7 @@ class GeneratorResNet(Module):
 | 
			
		||||
        # `inplace=True` in `ReLU` saves a little bit of memory.
 | 
			
		||||
        out_features = 64
 | 
			
		||||
        layers = [
 | 
			
		||||
            nn.Conv2d(channels, out_features, kernel_size=7, padding=3, padding_mode='reflection'),
 | 
			
		||||
            nn.Conv2d(channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
 | 
			
		||||
            nn.InstanceNorm2d(out_features),
 | 
			
		||||
            nn.ReLU(inplace=True),
 | 
			
		||||
        ]
 | 
			
		||||
@ -80,14 +80,15 @@ class GeneratorResNet(Module):
 | 
			
		||||
        for _ in range(2):
 | 
			
		||||
            out_features //= 2
 | 
			
		||||
            layers += [
 | 
			
		||||
                nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
 | 
			
		||||
                nn.Upsample(scale_factor=2),
 | 
			
		||||
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
 | 
			
		||||
                nn.InstanceNorm2d(out_features),
 | 
			
		||||
                nn.ReLU(inplace=True),
 | 
			
		||||
            ]
 | 
			
		||||
            in_features = out_features
 | 
			
		||||
 | 
			
		||||
        # Finally we map the feature map to an RGB image
 | 
			
		||||
        layers += [nn.Conv2d(out_features, channels, 7, padding=3, padding_mode='reflection'), nn.Tanh()]
 | 
			
		||||
        layers += [nn.Conv2d(out_features, channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]
 | 
			
		||||
 | 
			
		||||
        # Create a sequential module with the layers
 | 
			
		||||
        self.layers = nn.Sequential(*layers)
 | 
			
		||||
@ -107,10 +108,10 @@ class ResidualBlock(Module):
 | 
			
		||||
    def __init__(self, in_features: int):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.block = nn.Sequential(
 | 
			
		||||
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflection'),
 | 
			
		||||
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
 | 
			
		||||
            nn.InstanceNorm2d(in_features),
 | 
			
		||||
            nn.ReLU(inplace=True),
 | 
			
		||||
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflection'),
 | 
			
		||||
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
 | 
			
		||||
            nn.InstanceNorm2d(in_features),
 | 
			
		||||
            nn.ReLU(inplace=True),
 | 
			
		||||
        )
 | 
			
		||||
@ -158,6 +159,7 @@ class DiscriminatorBlock(Module):
 | 
			
		||||
 | 
			
		||||
    It shrinks the height and width of the input feature map by half.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
 | 
			
		||||
@ -194,6 +196,7 @@ class ImageDataset(Dataset):
 | 
			
		||||
    """
 | 
			
		||||
    Dataset to load images
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, root: PurePath, transforms_, unaligned: bool, mode: str):
 | 
			
		||||
        root = Path(root)
 | 
			
		||||
        self.transform = transforms.Compose(transforms_)
 | 
			
		||||
@ -221,6 +224,7 @@ class ReplayBuffer:
 | 
			
		||||
 | 
			
		||||
    This is done to reduce model oscillation.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, max_size: int = 50):
 | 
			
		||||
        self.max_size = max_size
 | 
			
		||||
        self.data = []
 | 
			
		||||
@ -249,7 +253,7 @@ class Configs(BaseConfigs):
 | 
			
		||||
    dataset_name: str = 'monet2photo'
 | 
			
		||||
    batch_size: int = 1
 | 
			
		||||
 | 
			
		||||
    data_loader_workers = 8
 | 
			
		||||
    data_loader_workers = 0
 | 
			
		||||
    is_save_models = True
 | 
			
		||||
 | 
			
		||||
    learning_rate = 0.0002
 | 
			
		||||
@ -311,6 +315,57 @@ class Configs(BaseConfigs):
 | 
			
		||||
            image_grid = torch.cat((real_a, fake_b, real_b, fake_a), 1)
 | 
			
		||||
        save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False)
 | 
			
		||||
 | 
			
		||||
    def optimize_generators(self, real_a: torch.Tensor, real_b: torch.Tensor, true_labels: torch.Tensor):
 | 
			
		||||
        #  Change to training mode
 | 
			
		||||
        self.generator_ab.train()
 | 
			
		||||
        self.generator_ba.train()
 | 
			
		||||
 | 
			
		||||
        # Identity loss
 | 
			
		||||
        loss_identity = (self.identity_loss(self.generator_ba(real_a), real_a) +
 | 
			
		||||
                         self.identity_loss(self.generator_ab(real_b), real_b))
 | 
			
		||||
 | 
			
		||||
        # Generate images
 | 
			
		||||
        fake_b = self.generator_ab(real_a)
 | 
			
		||||
        fake_a = self.generator_ba(real_b)
 | 
			
		||||
 | 
			
		||||
        # GAN loss
 | 
			
		||||
        loss_gan = (self.gan_loss(self.discriminator_b(fake_b), true_labels) +
 | 
			
		||||
                    self.gan_loss(self.discriminator_a(fake_a), true_labels))
 | 
			
		||||
 | 
			
		||||
        # Cycle loss
 | 
			
		||||
        loss_cycle = (self.cycle_loss(self.generator_ba(fake_b), real_a) +
 | 
			
		||||
                      self.cycle_loss(self.generator_ab(fake_a), real_b))
 | 
			
		||||
 | 
			
		||||
        # Total loss
 | 
			
		||||
        loss_generator = (loss_gan +
 | 
			
		||||
                          self.cyclic_loss_coefficient * loss_cycle +
 | 
			
		||||
                          self.identity_loss_coefficient * loss_identity)
 | 
			
		||||
 | 
			
		||||
        self.generator_optimizer.zero_grad()
 | 
			
		||||
        loss_generator.backward()
 | 
			
		||||
        self.generator_optimizer.step()
 | 
			
		||||
 | 
			
		||||
        tracker.add({'loss.generator': loss_generator,
 | 
			
		||||
                     'loss.generator.cycle': loss_cycle,
 | 
			
		||||
                     'loss.generator.gan': loss_gan,
 | 
			
		||||
                     'loss.generator.identity': loss_identity})
 | 
			
		||||
 | 
			
		||||
        return fake_a, fake_b
 | 
			
		||||
 | 
			
		||||
    def optimize_discriminator(self, real_a: torch.Tensor, real_b: torch.Tensor,
 | 
			
		||||
                               fake_a: torch.Tensor, fake_b: torch.Tensor,
 | 
			
		||||
                               true_labels: torch.Tensor, false_labels: torch.Tensor):
 | 
			
		||||
        loss_discriminator = (self.gan_loss(self.discriminator_a(real_a), true_labels) +
 | 
			
		||||
                              self.gan_loss(self.discriminator_a(fake_a), false_labels) +
 | 
			
		||||
                              self.gan_loss(self.discriminator_b(real_b), true_labels) +
 | 
			
		||||
                              self.gan_loss(self.discriminator_b(fake_b), false_labels))
 | 
			
		||||
 | 
			
		||||
        self.discriminator_optimizer.zero_grad()
 | 
			
		||||
        loss_discriminator.backward()
 | 
			
		||||
        self.discriminator_optimizer.step()
 | 
			
		||||
 | 
			
		||||
        tracker.add({'loss.discriminator': loss_discriminator})
 | 
			
		||||
 | 
			
		||||
    def run(self):
 | 
			
		||||
        # Replay buffers to keep generated samples
 | 
			
		||||
        fake_a_buffer = ReplayBuffer()
 | 
			
		||||
@ -321,56 +376,23 @@ class Configs(BaseConfigs):
 | 
			
		||||
                # Move images to the device
 | 
			
		||||
                real_a, real_b = batch['a'].to(self.device), batch['b'].to(self.device)
 | 
			
		||||
 | 
			
		||||
                # valid labels equal to $1$
 | 
			
		||||
                valid = torch.ones(real_a.size(0), *self.discriminator_a.output_shape,
 | 
			
		||||
                                   device=self.device, requires_grad=False)
 | 
			
		||||
                # fake labels equal to $0$
 | 
			
		||||
                fake = torch.zeros(real_a.size(0), *self.discriminator_a.output_shape,
 | 
			
		||||
                                   device=self.device, requires_grad=False)
 | 
			
		||||
                # true labels equal to $1$
 | 
			
		||||
                true_labels = torch.ones(real_a.size(0), *self.discriminator_a.output_shape,
 | 
			
		||||
                                         device=self.device, requires_grad=False)
 | 
			
		||||
                # false labels equal to $0$
 | 
			
		||||
                false_labels = torch.zeros(real_a.size(0), *self.discriminator_a.output_shape,
 | 
			
		||||
                                           device=self.device, requires_grad=False)
 | 
			
		||||
 | 
			
		||||
                #  Train generators
 | 
			
		||||
                self.generator_ab.train()
 | 
			
		||||
                self.generator_ba.train()
 | 
			
		||||
 | 
			
		||||
                # Identity loss
 | 
			
		||||
                loss_identity = self.identity_loss(self.generator_ba(real_a), real_a) + \
 | 
			
		||||
                                self.identity_loss(self.generator_ab(real_b), real_b)
 | 
			
		||||
 | 
			
		||||
                # GAN loss
 | 
			
		||||
                fake_b = self.generator_ab(real_a)
 | 
			
		||||
                fake_a = self.generator_ba(real_b)
 | 
			
		||||
 | 
			
		||||
                loss_gan = self.gan_loss(self.discriminator_b(fake_b), valid) + \
 | 
			
		||||
                           self.gan_loss(self.discriminator_a(fake_a), valid)
 | 
			
		||||
 | 
			
		||||
                loss_cycle = self.cycle_loss(self.generator_ba(fake_b), real_a) + \
 | 
			
		||||
                             self.cycle_loss(self.generator_ab(fake_a), real_b)
 | 
			
		||||
 | 
			
		||||
                # Total loss
 | 
			
		||||
                loss_generator = (loss_gan + self.cyclic_loss_coefficient * loss_cycle
 | 
			
		||||
                                  + self.identity_loss_coefficient * loss_identity)
 | 
			
		||||
 | 
			
		||||
                self.generator_optimizer.zero_grad()
 | 
			
		||||
                loss_generator.backward()
 | 
			
		||||
                self.generator_optimizer.step()
 | 
			
		||||
                # Train the generators
 | 
			
		||||
                fake_a, fake_b = self.optimize_generators(real_a, real_b, true_labels)
 | 
			
		||||
 | 
			
		||||
                #  Train discriminators
 | 
			
		||||
                fake_a_replay = fake_a_buffer.push_and_pop(fake_a)
 | 
			
		||||
                fake_b_replay = fake_b_buffer.push_and_pop(fake_b)
 | 
			
		||||
                loss_discriminator = self.gan_loss(self.discriminator_a(real_a), valid) + \
 | 
			
		||||
                                     self.gan_loss(self.discriminator_a(fake_a_replay), fake) + \
 | 
			
		||||
                                     self.gan_loss(self.discriminator_b(real_b), valid) + \
 | 
			
		||||
                                     self.gan_loss(self.discriminator_b(fake_b_replay), fake)
 | 
			
		||||
                self.optimize_discriminator(real_a, real_b,
 | 
			
		||||
                                            fake_a_buffer.push_and_pop(fake_a), fake_b_buffer.push_and_pop(fake_b),
 | 
			
		||||
                                            true_labels, false_labels)
 | 
			
		||||
 | 
			
		||||
                self.discriminator_optimizer.zero_grad()
 | 
			
		||||
                loss_discriminator.backward()
 | 
			
		||||
                self.discriminator_optimizer.step()
 | 
			
		||||
 | 
			
		||||
                tracker.save({'loss.generator': loss_generator,
 | 
			
		||||
                              'loss.discriminator': loss_discriminator,
 | 
			
		||||
                              'loss.generator.cycle': loss_cycle,
 | 
			
		||||
                              'loss.generator.gan': loss_gan,
 | 
			
		||||
                              'loss.generator.identity': loss_identity})
 | 
			
		||||
                # Save training statistics
 | 
			
		||||
                tracker.save()
 | 
			
		||||
 | 
			
		||||
                # If at sample interval save image
 | 
			
		||||
                batches_done = epoch * len(self.dataloader) + i
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user