mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-04 14:29:43 +08:00

✨ dcgan

This commit is contained in:
@@ -20,7 +20,7 @@ $p_{data}(\pmb{x})$ is the probability distribution over data,
 whilst $p_{\pmb{z}}(\pmb{z})$ probability distribution of $\pmb{z}$, which is set to
 gaussian noise.
 
-This file defines the loss functions. [Here](gan_mnist.html) is an MNIST example
+This file defines the loss functions. [Here](simple_mnist_experiment.html) is an MNIST example
 with two multilayer perceptron for the generator and discriminator.
 """
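For reference, the losses this docstring describes implement the standard GAN minimax objective from Goodfellow et al. (2014), written in the same notation:

$\min_G \max_D \; \mathbb{E}_{\pmb{x} \sim p_{data}(\pmb{x})}[\log D(\pmb{x})] + \mathbb{E}_{\pmb{z} \sim p_{\pmb{z}}(\pmb{z})}[\log(1 - D(G(\pmb{z})))]$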
labml_nn/gan/dcgan.py (new file, 99 lines)
@@ -0,0 +1,99 @@
+import torch.nn as nn
+
+from labml import experiment
+from labml.configs import calculate
+from labml_helpers.module import Module
+from labml_nn.gan.simple_mnist_experiment import Configs
+
+
+class Generator(Module):
+    """
+    ### Convolutional Generator Network
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.layers = nn.Sequential(
+            # This gives a 3x3 output
+            nn.ConvTranspose2d(100, 1024, 3, 1, 0, bias=False),
+            nn.BatchNorm2d(1024),
+            nn.ReLU(True),
+            # This gives a 7x7 output
+            nn.ConvTranspose2d(1024, 512, 3, 2, 0, bias=False),
+            nn.BatchNorm2d(512),
+            nn.ReLU(True),
+            # This gives a 14x14 output
+            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(256),
+            nn.ReLU(True),
+            # This gives a 28x28 output
+            nn.ConvTranspose2d(256, 1, 4, 2, 1, bias=False),
+            nn.Tanh()
+        )
+
+        self.apply(_weights_init)
+
+    def __call__(self, x):
+        # Reshape the latent vector into a 100-channel 1x1 feature map
+        x = x.unsqueeze(-1).unsqueeze(-1)
+        x = self.layers(x)
+        return x
+
+
+class Discriminator(Module):
+    """
+    ### Convolutional Discriminator Network
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.layers = nn.Sequential(
+            # This gives a 14x14 output
+            nn.Conv2d(1, 256, 4, 2, 1, bias=False),
+            nn.LeakyReLU(0.2, inplace=True),
+            # This gives a 7x7 output
+            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(512),
+            nn.LeakyReLU(0.2, inplace=True),
+            # This gives a 3x3 output
+            nn.Conv2d(512, 1024, 3, 2, 0, bias=False),
+            nn.BatchNorm2d(1024),
+            nn.LeakyReLU(0.2, inplace=True),
+            # This gives a 1x1 output (a single logit per image)
+            nn.Conv2d(1024, 1, 3, 1, 0, bias=False),
+        )
+        self.apply(_weights_init)
+
+    def forward(self, x):
+        x = self.layers(x)
+        return x.view(x.shape[0], -1)
+
+
+def _weights_init(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        nn.init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm') != -1:
+        nn.init.normal_(m.weight.data, 1.0, 0.02)
+        nn.init.constant_(m.bias.data, 0)
+
+
+# We import the [simple GAN experiment](simple_mnist_experiment.html) and change the
+# generator and discriminator networks
+calculate(Configs.generator, 'cnn', lambda c: Generator().to(c.device))
+calculate(Configs.discriminator, 'cnn', lambda c: Discriminator().to(c.device))
+
+
+def main():
+    conf = Configs()
+    experiment.create(name='mnist_dcgan', comment='test')
+    experiment.configs(conf,
+                       {'discriminator': 'cnn',
+                        'generator': 'cnn',
+                        'label_smoothing': 0.01},
+                       'run')
+    with experiment.start():
+        conf.run()
+
+
+if __name__ == '__main__':
+    main()
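A minimal shape check for the two networks above (not part of the commit; it assumes torch is installed and the classes are importable from labml_nn.gan.dcgan):

import torch

from labml_nn.gan.dcgan import Generator, Discriminator

generator = Generator()
discriminator = Discriminator()
# A batch of 16 latent vectors of size 100 should come out as 16 single-channel 28x28 images
z = torch.randn(16, 100)
fake = generator(z)
assert fake.shape == (16, 1, 28, 28)
# ... and the discriminator should map each image back to a single logit
logits = discriminator(fake)
assert logits.shape == (16, 1)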
@@ -16,6 +16,15 @@ from labml_helpers.train_valid import MODE_STATE, BatchStepProtocol, TrainValidC
 from labml_nn.gan import DiscriminatorLogitsLoss, GeneratorLogitsLoss
 
 
+def weights_init(m):
+    classname = m.__class__.__name__
+    if classname.find('Linear') != -1:
+        nn.init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm') != -1:
+        nn.init.normal_(m.weight.data, 1.0, 0.02)
+        nn.init.constant_(m.bias.data, 0)
+
+
 class Generator(Module):
     """
     ### Simple MLP Generator
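weights_init mirrors the DCGAN-style initialisation used in dcgan.py (layer weights from N(0, 0.02), BatchNorm scales from N(1, 0.02)) and is attached to the MLP models below through nn.Module.apply, which calls it on every submodule. A small illustration, assuming weights_init from this file is in scope:

import torch.nn as nn

net = nn.Sequential(nn.Linear(100, 256), nn.BatchNorm1d(256))
net.apply(weights_init)                  # visits the Sequential and both submodules
print(float(net[1].weight.data.std()))   # roughly 0.02, centred around 1.0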
@@ -35,6 +44,8 @@ class Generator(Module):
 
         self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 28 * 28), nn.Tanh())
 
+        self.apply(weights_init)
+
     def forward(self, x):
         return self.layers(x).view(x.shape[0], 1, 28, 28)
 
@@ -58,6 +69,7 @@ class Discriminator(Module):
             d_prev = size
 
         self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 1))
+        self.apply(weights_init)
 
     def forward(self, x):
         return self.layers(x.view(x.shape[0], -1))
@@ -70,8 +82,10 @@ class GANBatchStep(BatchStepProtocol):
                  discriminator_optimizer: Optional[torch.optim.Adam],
                  generator_optimizer: Optional[torch.optim.Adam],
                  discriminator_loss: DiscriminatorLogitsLoss,
-                 generator_loss: GeneratorLogitsLoss):
+                 generator_loss: GeneratorLogitsLoss,
+                 discriminator_k: int):
+
+        self.discriminator_k = discriminator_k
         self.generator = generator
         self.discriminator = discriminator
         self.generator_loss = generator_loss
@@ -100,25 +114,26 @@ class GANBatchStep(BatchStepProtocol):
 
         # Train the discriminator
         with monit.section("discriminator"):
-            latent = torch.randn(data.shape[0], 100, device=device)
-            if MODE_STATE.is_train:
-                self.discriminator_optimizer.zero_grad()
-            logits_true = self.discriminator(data)
-            logits_false = self.discriminator(self.generator(latent).detach())
-            loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
-            loss = loss_true + loss_false
-
-            # Log stuff
-            tracker.add("loss.discriminator.true.", loss_true)
-            tracker.add("loss.discriminator.false.", loss_false)
-            tracker.add("loss.discriminator.", loss)
-
-            # Train
-            if MODE_STATE.is_train:
-                loss.backward()
-                if MODE_STATE.is_log_parameters:
-                    pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
-                self.discriminator_optimizer.step()
+            for _ in range(self.discriminator_k):
+                latent = torch.randn(data.shape[0], 100, device=device)
+                if MODE_STATE.is_train:
+                    self.discriminator_optimizer.zero_grad()
+                logits_true = self.discriminator(data)
+                logits_false = self.discriminator(self.generator(latent).detach())
+                loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
+                loss = loss_true + loss_false
+
+                # Log stuff
+                tracker.add("loss.discriminator.true.", loss_true)
+                tracker.add("loss.discriminator.false.", loss_false)
+                tracker.add("loss.discriminator.", loss)
+
+                # Train
+                if MODE_STATE.is_train:
+                    loss.backward()
+                    if MODE_STATE.is_log_parameters:
+                        pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
+                    self.discriminator_optimizer.step()
 
         # Train the generator
         with monit.section("generator"):
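The loop introduced here follows Algorithm 1 of the original GAN paper: the discriminator takes discriminator_k gradient steps for every single generator step, reusing the same batch. Schematically (illustration only, with hypothetical helper names):

for batch in data_loader:                # one training batch
    for _ in range(discriminator_k):     # k discriminator updates ...
        train_discriminator_step(batch)  # hypothetical helper
    train_generator_step(batch)          # ... then one generator update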
@@ -156,6 +171,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
     discriminator_loss: DiscriminatorLogitsLoss
     batch_step = 'gan_batch_step'
     label_smoothing: float = 0.2
+    discriminator_k: int = 1
 
 
 @option(Configs.dataset_transforms)
@@ -173,11 +189,12 @@ def gan_batch_step(c: Configs):
                        discriminator_optimizer=c.discriminator_optimizer,
                        generator_optimizer=c.generator_optimizer,
                        discriminator_loss=c.discriminator_loss,
-                       generator_loss=c.generator_loss)
+                       generator_loss=c.generator_loss,
+                       discriminator_k=c.discriminator_k)
 
 
-calculate(Configs.generator, lambda c: Generator().to(c.device))
-calculate(Configs.discriminator, lambda c: Discriminator().to(c.device))
+calculate(Configs.generator, 'mlp', lambda c: Generator().to(c.device))
+calculate(Configs.discriminator, 'mlp', lambda c: Discriminator().to(c.device))
 calculate(Configs.generator_loss, lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
 calculate(Configs.discriminator_loss, lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))
 
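With the networks now registered under names ('mlp' here, 'cnn' in dcgan.py), an experiment selects between them through configuration overrides. A hedged sketch of running the DCGAN experiment with a larger discriminator_k (the keys mirror main() in dcgan.py; the value 5 is purely illustrative):

conf = Configs()
experiment.create(name='mnist_dcgan', comment='k=5')
experiment.configs(conf,
                   {'generator': 'cnn',        # DCGAN generator from dcgan.py
                    'discriminator': 'cnn',    # DCGAN discriminator from dcgan.py
                    'discriminator_k': 5,      # five discriminator steps per generator step
                    'label_smoothing': 0.01},
                   'run')
with experiment.start():
    conf.run()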