diff --git a/labml_nn/gan/__init__.py b/labml_nn/gan/__init__.py
index 475b75bf..25372851 100644
--- a/labml_nn/gan/__init__.py
+++ b/labml_nn/gan/__init__.py
@@ -20,7 +20,7 @@
 $p_{data}(\pmb{x})$ is the probability distribution over data,
 whilst $p_{\pmb{z}}(\pmb{z})$ probability distribution of $\pmb{z}$, which is set to
 gaussian noise.
 
-This file defines the loss functions. [Here](gan_mnist.html) is an MNIST example
+This file defines the loss functions. [Here](simple_mnist_experiment.html) is an MNIST example
 with two multilayer perceptron for the generator and discriminator.
 """
diff --git a/labml_nn/gan/dcgan.py b/labml_nn/gan/dcgan.py
new file mode 100644
index 00000000..9d390681
--- /dev/null
+++ b/labml_nn/gan/dcgan.py
@@ -0,0 +1,99 @@
+import torch.nn as nn
+
+from labml import experiment
+from labml.configs import calculate
+from labml_helpers.module import Module
+from labml_nn.gan.simple_mnist_experiment import Configs
+
+
+class Generator(Module):
+    """
+    ### Convolutional Generator Network
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.layers = nn.Sequential(
+            # Gives a 3x3 output
+            nn.ConvTranspose2d(100, 1024, 3, 1, 0, bias=False),
+            nn.BatchNorm2d(1024),
+            nn.ReLU(True),
+            # This gives a 7x7 output
+            nn.ConvTranspose2d(1024, 512, 3, 2, 0, bias=False),
+            nn.BatchNorm2d(512),
+            nn.ReLU(True),
+            # This gives a 14x14 output
+            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(256),
+            nn.ReLU(True),
+            # This gives a 28x28 output
+            nn.ConvTranspose2d(256, 1, 4, 2, 1, bias=False),
+            nn.Tanh()
+        )
+
+        self.apply(_weights_init)
+
+    def __call__(self, x):
+        x = x.unsqueeze(-1).unsqueeze(-1)
+        x = self.layers(x)
+        return x
+
+
+class Discriminator(Module):
+    """
+    ### Convolutional Discriminator Network
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.layers = nn.Sequential(
+            # This gives 14x14
+            nn.Conv2d(1, 256, 4, 2, 1, bias=False),
+            nn.LeakyReLU(0.2, inplace=True),
+            # This gives 7x7
+            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(512),
+            nn.LeakyReLU(0.2, inplace=True),
+            # This gives 3x3
+            nn.Conv2d(512, 1024, 3, 2, 0, bias=False),
+            nn.BatchNorm2d(1024),
+            nn.LeakyReLU(0.2, inplace=True),
+            # This gives 1x1
+            nn.Conv2d(1024, 1, 3, 1, 0, bias=False),
+        )
+        self.apply(_weights_init)
+
+    def forward(self, x):
+        x = self.layers(x)
+        return x.view(x.shape[0], -1)
+
+
+def _weights_init(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        nn.init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm') != -1:
+        nn.init.normal_(m.weight.data, 1.0, 0.02)
+        nn.init.constant_(m.bias.data, 0)
+
+
+# We import the [simple GAN experiment](simple_mnist_experiment.html) and change the
+# generator and discriminator networks
+calculate(Configs.generator, 'cnn', lambda c: Generator().to(c.device))
+calculate(Configs.discriminator, 'cnn', lambda c: Discriminator().to(c.device))
+
+
+def main():
+    conf = Configs()
+    experiment.create(name='mnist_dcgan', comment='test')
+    experiment.configs(conf,
+                       {'discriminator': 'cnn',
+                        'generator': 'cnn',
+                        'label_smoothing': 0.01},
+                       'run')
+    with experiment.start():
+        conf.run()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/labml_nn/gan/gan_mnist.py b/labml_nn/gan/simple_mnist_experiment.py
similarity index 78%
rename from labml_nn/gan/gan_mnist.py
rename to labml_nn/gan/simple_mnist_experiment.py
index 2e808d57..1941dc76 100644
--- a/labml_nn/gan/gan_mnist.py
+++ b/labml_nn/gan/simple_mnist_experiment.py
@@ -16,6 +16,15 @@ from labml_helpers.train_valid import MODE_STATE, BatchStepProtocol, TrainValidC
 from labml_nn.gan import DiscriminatorLogitsLoss, GeneratorLogitsLoss
 
 
+def weights_init(m):
+    classname = m.__class__.__name__
+    if classname.find('Linear') != -1:
+        nn.init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm') != -1:
+        nn.init.normal_(m.weight.data, 1.0, 0.02)
+        nn.init.constant_(m.bias.data, 0)
+
+
 class Generator(Module):
     """
     ### Simple MLP Generator
@@ -35,6 +44,8 @@ class Generator(Module):
         self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 28 * 28), nn.Tanh())
 
+        self.apply(weights_init)
+
     def forward(self, x):
         return self.layers(x).view(x.shape[0], 1, 28, 28)
@@ -58,6 +69,7 @@ class Discriminator(Module):
             d_prev = size
 
         self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 1))
+        self.apply(weights_init)
 
     def forward(self, x):
         return self.layers(x.view(x.shape[0], -1))
@@ -70,8 +82,10 @@ class GANBatchStep(BatchStepProtocol):
                 discriminator_optimizer: Optional[torch.optim.Adam],
                 generator_optimizer: Optional[torch.optim.Adam],
                 discriminator_loss: DiscriminatorLogitsLoss,
-                generator_loss: GeneratorLogitsLoss):
+                generator_loss: GeneratorLogitsLoss,
+                discriminator_k: int):
+        self.discriminator_k = discriminator_k
         self.generator = generator
         self.discriminator = discriminator
         self.generator_loss = generator_loss
@@ -100,25 +114,26 @@ class GANBatchStep(BatchStepProtocol):
 
         # Train the discriminator
         with monit.section("discriminator"):
-            latent = torch.randn(data.shape[0], 100, device=device)
-            if MODE_STATE.is_train:
-                self.discriminator_optimizer.zero_grad()
-            logits_true = self.discriminator(data)
-            logits_false = self.discriminator(self.generator(latent).detach())
-            loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
-            loss = loss_true + loss_false
+            for _ in range(self.discriminator_k):
+                latent = torch.randn(data.shape[0], 100, device=device)
+                if MODE_STATE.is_train:
+                    self.discriminator_optimizer.zero_grad()
+                logits_true = self.discriminator(data)
+                logits_false = self.discriminator(self.generator(latent).detach())
+                loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
+                loss = loss_true + loss_false
 
-            # Log stuff
tracker.add("loss.discriminator.true.", loss_true) - tracker.add("loss.discriminator.false.", loss_false) - tracker.add("loss.discriminator.", loss) + # Log stuff + tracker.add("loss.discriminator.true.", loss_true) + tracker.add("loss.discriminator.false.", loss_false) + tracker.add("loss.discriminator.", loss) - # Train - if MODE_STATE.is_train: - loss.backward() - if MODE_STATE.is_log_parameters: - pytorch_utils.store_model_indicators(self.discriminator, 'discriminator') - self.discriminator_optimizer.step() + # Train + if MODE_STATE.is_train: + loss.backward() + if MODE_STATE.is_log_parameters: + pytorch_utils.store_model_indicators(self.discriminator, 'discriminator') + self.discriminator_optimizer.step() # Train the generator with monit.section("generator"): @@ -156,6 +171,7 @@ class Configs(MNISTConfigs, TrainValidConfigs): discriminator_loss: DiscriminatorLogitsLoss batch_step = 'gan_batch_step' label_smoothing: float = 0.2 + discriminator_k: int = 1 @option(Configs.dataset_transforms) @@ -173,11 +189,12 @@ def gan_batch_step(c: Configs): discriminator_optimizer=c.discriminator_optimizer, generator_optimizer=c.generator_optimizer, discriminator_loss=c.discriminator_loss, - generator_loss=c.generator_loss) + generator_loss=c.generator_loss, + discriminator_k=c.discriminator_k) -calculate(Configs.generator, lambda c: Generator().to(c.device)) -calculate(Configs.discriminator, lambda c: Discriminator().to(c.device)) +calculate(Configs.generator, 'mlp', lambda c: Generator().to(c.device)) +calculate(Configs.discriminator, 'mlp', lambda c: Discriminator().to(c.device)) calculate(Configs.generator_loss, lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device)) calculate(Configs.discriminator_loss, lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))