dcgan

Varuna Jayasiri
2020-09-28 11:23:04 +05:30
parent 9bd29d165c
commit 7ef213f89c
3 changed files with 138 additions and 22 deletions

labml_nn/gan/__init__.py

@@ -20,7 +20,7 @@ $p_{data}(\pmb{x})$ is the probability distribution over data,
whilst $p_{\pmb{z}}(\pmb{z})$ is the probability distribution of $\pmb{z}$, which is set to
Gaussian noise.
-This file defines the loss functions. [Here](gan_mnist.html) is an MNIST example
+This file defines the loss functions. [Here](simple_mnist_experiment.html) is an MNIST example
with two multilayer perceptrons for the generator and discriminator.
"""

labml_nn/gan/dcgan.py (new file)

@@ -0,0 +1,99 @@
import torch.nn as nn

from labml import experiment
from labml.configs import calculate
from labml_helpers.module import Module
from labml_nn.gan.simple_mnist_experiment import Configs


class Generator(Module):
    """
    ### Convolutional Generator Network
    """
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            # This gives a 3x3 output
            nn.ConvTranspose2d(100, 1024, 3, 1, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            # This gives 7x7
            nn.ConvTranspose2d(1024, 512, 3, 2, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # This gives 14x14
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # This gives 28x28
            nn.ConvTranspose2d(256, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

        self.apply(_weights_init)

    def __call__(self, x):
        # Reshape the latent vector from (batch, 100) to (batch, 100, 1, 1)
        x = x.unsqueeze(-1).unsqueeze(-1)
        x = self.layers(x)
        return x
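The size comments follow the transposed-convolution formula, out = (in - 1) * stride - 2 * padding + kernel. A quick standalone check (not part of the commit):

def conv_transpose_out(size: int, kernel: int, stride: int, padding: int) -> int:
    return (size - 1) * stride - 2 * padding + kernel

# The latent vector is reshaped to 100x1x1, so the spatial size starts at 1
size = 1
for kernel, stride, padding in [(3, 1, 0), (3, 2, 0), (4, 2, 1), (4, 2, 1)]:
    size = conv_transpose_out(size, kernel, stride, padding)
    print(size)  # 3, 7, 14, 28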
class Discriminator(Module):
    """
    ### Convolutional Discriminator Network
    """
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            # This gives 14x14
            nn.Conv2d(1, 256, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # This gives 7x7
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # This gives 3x3
            nn.Conv2d(512, 1024, 3, 2, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            # This gives a single 1x1 logit per image
            nn.Conv2d(1024, 1, 3, 1, 0, bias=False),
        )

        self.apply(_weights_init)

    def forward(self, x):
        x = self.layers(x)
        return x.view(x.shape[0], -1)
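The same kind of check works for the discriminator, using the convolution formula out = floor((in - kernel + 2 * padding) / stride) + 1 (again standalone, not part of the commit):

def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    return (size - kernel + 2 * padding) // stride + 1

# MNIST inputs are 1x28x28
size = 28
for kernel, stride, padding in [(4, 2, 1), (4, 2, 1), (3, 2, 0), (3, 1, 0)]:
    size = conv_out(size, kernel, stride, padding)
    print(size)  # 14, 7, 3, 1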
def _weights_init(m):
    # DCGAN initialization (Radford et al.): convolution weights drawn from
    # N(0, 0.02), batch-norm scales from N(1, 0.02), batch-norm biases zeroed
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
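A quick sanity check of the initializer (a standalone sketch, assuming the function above is in scope):

conv = nn.Conv2d(1, 8, 3)
_weights_init(conv)
print(conv.weight.data.std())  # should be close to 0.02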
# We take the [simple GAN experiment](simple_mnist_experiment.html) and register the
# convolutional networks as the 'cnn' option for the generator and discriminator
calculate(Configs.generator, 'cnn', lambda c: Generator().to(c.device))
calculate(Configs.discriminator, 'cnn', lambda c: Discriminator().to(c.device))
def main():
    conf = Configs()
    experiment.create(name='mnist_dcgan', comment='test')
    experiment.configs(conf,
                       {'discriminator': 'cnn',
                        'generator': 'cnn',
                        'label_smoothing': 0.01},
                       'run')
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
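After training, images can be sampled by pushing random latent vectors through the generator. A minimal sketch; torchvision's save_image is an assumption here, not something this commit uses:

import torch
from torchvision.utils import save_image  # assumption: torchvision is installed

generator = Generator()  # in practice, load trained weights first
with torch.no_grad():
    samples = generator(torch.randn(64, 100))
# Tanh output lies in [-1, 1]; normalize=True rescales it to [0, 1] for saving
save_image(samples, 'dcgan_samples.png', normalize=True)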

labml_nn/gan/simple_mnist_experiment.py

@@ -16,6 +16,15 @@ from labml_helpers.train_valid import MODE_STATE, BatchStepProtocol, TrainValidC
from labml_nn.gan import DiscriminatorLogitsLoss, GeneratorLogitsLoss

+def weights_init(m):
+    classname = m.__class__.__name__
+    if classname.find('Linear') != -1:
+        nn.init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm') != -1:
+        nn.init.normal_(m.weight.data, 1.0, 0.02)
+        nn.init.constant_(m.bias.data, 0)
+

class Generator(Module):
    """
    ### Simple MLP Generator
@@ -35,6 +44,8 @@ class Generator(Module):
        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 28 * 28), nn.Tanh())
+        self.apply(weights_init)

    def forward(self, x):
        return self.layers(x).view(x.shape[0], 1, 28, 28)
@@ -58,6 +69,7 @@ class Discriminator(Module):
            d_prev = size
        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 1))
+        self.apply(weights_init)

    def forward(self, x):
        return self.layers(x.view(x.shape[0], -1))
@@ -70,8 +82,10 @@ class GANBatchStep(BatchStepProtocol):
                 discriminator_optimizer: Optional[torch.optim.Adam],
                 generator_optimizer: Optional[torch.optim.Adam],
                 discriminator_loss: DiscriminatorLogitsLoss,
-                 generator_loss: GeneratorLogitsLoss):
+                 generator_loss: GeneratorLogitsLoss,
+                 discriminator_k: int):
+        self.discriminator_k = discriminator_k
        self.generator = generator
        self.discriminator = discriminator
        self.generator_loss = generator_loss
@@ -100,25 +114,26 @@ class GANBatchStep(BatchStepProtocol):
        # Train the discriminator
        with monit.section("discriminator"):
-            latent = torch.randn(data.shape[0], 100, device=device)
-            if MODE_STATE.is_train:
-                self.discriminator_optimizer.zero_grad()
-            logits_true = self.discriminator(data)
-            logits_false = self.discriminator(self.generator(latent).detach())
-            loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
-            loss = loss_true + loss_false
+            for _ in range(self.discriminator_k):
+                latent = torch.randn(data.shape[0], 100, device=device)
+                if MODE_STATE.is_train:
+                    self.discriminator_optimizer.zero_grad()
+                logits_true = self.discriminator(data)
+                logits_false = self.discriminator(self.generator(latent).detach())
+                loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
+                loss = loss_true + loss_false

-            # Log stuff
-            tracker.add("loss.discriminator.true.", loss_true)
-            tracker.add("loss.discriminator.false.", loss_false)
-            tracker.add("loss.discriminator.", loss)
+                # Log stuff
+                tracker.add("loss.discriminator.true.", loss_true)
+                tracker.add("loss.discriminator.false.", loss_false)
+                tracker.add("loss.discriminator.", loss)

-            # Train
-            if MODE_STATE.is_train:
-                loss.backward()
-                if MODE_STATE.is_log_parameters:
-                    pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
-                self.discriminator_optimizer.step()
+                # Train
+                if MODE_STATE.is_train:
+                    loss.backward()
+                    if MODE_STATE.is_log_parameters:
+                        pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
+                    self.discriminator_optimizer.step()

        # Train the generator
        with monit.section("generator"):
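The new loop performs discriminator_k discriminator updates for every generator update, as in Algorithm 1 of Goodfellow et al. (2014). A stripped-down, runnable sketch of the pattern with illustrative stand-in models (not the repository's classes):

import torch
import torch.nn.functional as F

generator = torch.nn.Linear(100, 784)    # stand-in generator
discriminator = torch.nn.Linear(784, 1)  # stand-in discriminator
d_opt = torch.optim.Adam(discriminator.parameters())
g_opt = torch.optim.Adam(generator.parameters())
real, k = torch.randn(16, 784), 2

# k discriminator updates per generator update
for _ in range(k):
    z = torch.randn(16, 100)
    d_opt.zero_grad()
    loss = (F.binary_cross_entropy_with_logits(discriminator(real), torch.ones(16, 1)) +
            F.binary_cross_entropy_with_logits(discriminator(generator(z).detach()), torch.zeros(16, 1)))
    loss.backward()
    d_opt.step()

# A single generator update follows
g_opt.zero_grad()
z = torch.randn(16, 100)
F.binary_cross_entropy_with_logits(discriminator(generator(z)), torch.ones(16, 1)).backward()
g_opt.step()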
@@ -156,6 +171,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
    discriminator_loss: DiscriminatorLogitsLoss
    batch_step = 'gan_batch_step'
    label_smoothing: float = 0.2
+    discriminator_k: int = 1
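label_smoothing softens the discriminator's targets so it is not pushed to produce extreme logits. A sketch of one common one-sided scheme (an assumption; the repository's DiscriminatorLogitsLoss may differ in detail):

import torch
import torch.nn.functional as F

smoothing = 0.2
logits_true = torch.randn(16, 1)   # discriminator outputs on real images
logits_false = torch.randn(16, 1)  # discriminator outputs on generated images
# Real targets are softened from 1.0 to 1.0 - smoothing; fake targets stay at 0
loss_true = F.binary_cross_entropy_with_logits(
    logits_true, torch.full_like(logits_true, 1.0 - smoothing))
loss_false = F.binary_cross_entropy_with_logits(
    logits_false, torch.zeros_like(logits_false))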
@option(Configs.dataset_transforms)
@@ -173,11 +189,12 @@ def gan_batch_step(c: Configs):
                        discriminator_optimizer=c.discriminator_optimizer,
                        generator_optimizer=c.generator_optimizer,
                        discriminator_loss=c.discriminator_loss,
-                        generator_loss=c.generator_loss)
+                        generator_loss=c.generator_loss,
+                        discriminator_k=c.discriminator_k)

-calculate(Configs.generator, lambda c: Generator().to(c.device))
-calculate(Configs.discriminator, lambda c: Discriminator().to(c.device))
+calculate(Configs.generator, 'mlp', lambda c: Generator().to(c.device))
+calculate(Configs.discriminator, 'mlp', lambda c: Discriminator().to(c.device))
calculate(Configs.generator_loss, lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
calculate(Configs.discriminator_loss, lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))