mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-03 13:57:48 +08:00
✨ dcgan
labml_nn/gan/__init__.py
@@ -20,7 +20,7 @@ $p_{data}(\pmb{x})$ is the probability distribution over data,
 whilst $p_{\pmb{z}}(\pmb{z})$ probability distribution of $\pmb{z}$, which is set to
 gaussian noise.
 
-This file defines the loss functions. [Here](gan_mnist.html) is an MNIST example
+This file defines the loss functions. [Here](simple_mnist_experiment.html) is an MNIST example
 with two multilayer perceptron for the generator and discriminator.
 """
 
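The loss functions this docstring refers to are the usual GAN logits losses with label smoothing. As a rough stand-alone illustration (a minimal sketch, not the repo's actual `DiscriminatorLogitsLoss` / `GeneratorLogitsLoss` code; in particular, smoothing the generator target is an assumption here):

```python
import torch
import torch.nn.functional as F


def discriminator_loss(logits_true, logits_false, smoothing=0.2):
    # The discriminator should score real samples as 1 and generated samples as 0;
    # the real target is smoothed from 1.0 to (1.0 - smoothing) to avoid over-confident logits.
    loss_true = F.binary_cross_entropy_with_logits(
        logits_true, torch.full_like(logits_true, 1.0 - smoothing))
    loss_false = F.binary_cross_entropy_with_logits(
        logits_false, torch.zeros_like(logits_false))
    return loss_true, loss_false


def generator_loss(logits_false, smoothing=0.2):
    # Non-saturating generator loss: push the discriminator's score on fakes towards 1
    return F.binary_cross_entropy_with_logits(
        logits_false, torch.full_like(logits_false, 1.0 - smoothing))
```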
99  labml_nn/gan/dcgan.py  Normal file
@@ -0,0 +1,99 @@
import torch.nn as nn

from labml import experiment
from labml.configs import calculate
from labml_helpers.module import Module
from labml_nn.gan.simple_mnist_experiment import Configs

class Generator(Module):
    """
    ### Convolutional Generator Network
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            # Gives a 3x3 output
            nn.ConvTranspose2d(100, 1024, 3, 1, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            # This gives a 7x7 output
            nn.ConvTranspose2d(1024, 512, 3, 2, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # This gives a 14x14 output
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # This gives a 28x28 output
            nn.ConvTranspose2d(256, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

        self.apply(_weights_init)

    def __call__(self, x):
        x = x.unsqueeze(-1).unsqueeze(-1)
        x = self.layers(x)
        return x

class Discriminator(Module):
    """
    ### Convolutional Discriminator Network
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            # This gives 14x14
            nn.Conv2d(1, 256, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # This gives 7x7
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # This gives 3x3
            nn.Conv2d(512, 1024, 3, 2, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            # This gives a 1x1 output, i.e. a single logit per sample
            nn.Conv2d(1024, 1, 3, 1, 0, bias=False),
        )
        self.apply(_weights_init)

    def forward(self, x):
        x = self.layers(x)
        return x.view(x.shape[0], -1)

def _weights_init(m):
    # DCGAN initialization: convolution weights from N(0, 0.02),
    # batch-norm weights from N(1, 0.02) with zero bias
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# We import the [simple GAN experiment](simple_mnist_experiment.html) and change the
# generator and discriminator networks
calculate(Configs.generator, 'cnn', lambda c: Generator().to(c.device))
calculate(Configs.discriminator, 'cnn', lambda c: Discriminator().to(c.device))

def main():
    conf = Configs()
    experiment.create(name='mnist_dcgan', comment='test')
    experiment.configs(conf,
                       {'discriminator': 'cnn',
                        'generator': 'cnn',
                        'label_smoothing': 0.01},
                       'run')
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
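The spatial-size comments in `dcgan.py` are easy to sanity-check with a throwaway script. This sketch is not part of the commit; it drops BatchNorm and weight initialisation and only verifies the 1x1 → 3x3 → 7x7 → 14x14 → 28x28 progression of the generator and its inverse in the discriminator:

```python
import torch
import torch.nn as nn

# Transposed convolution output size (dilation 1, no output padding):
#   out = (in - 1) * stride - 2 * padding + kernel
# Convolution output size:
#   out = (in + 2 * padding - kernel) // stride + 1
gen = nn.Sequential(
    nn.ConvTranspose2d(100, 1024, 3, 1, 0),  # 1x1   -> 3x3
    nn.ConvTranspose2d(1024, 512, 3, 2, 0),  # 3x3   -> 7x7
    nn.ConvTranspose2d(512, 256, 4, 2, 1),   # 7x7   -> 14x14
    nn.ConvTranspose2d(256, 1, 4, 2, 1),     # 14x14 -> 28x28
)
disc = nn.Sequential(
    nn.Conv2d(1, 256, 4, 2, 1),     # 28x28 -> 14x14
    nn.Conv2d(256, 512, 4, 2, 1),   # 14x14 -> 7x7
    nn.Conv2d(512, 1024, 3, 2, 0),  # 7x7   -> 3x3
    nn.Conv2d(1024, 1, 3, 1, 0),    # 3x3   -> 1x1 (a single logit)
)

z = torch.randn(64, 100, 1, 1)  # a batch of 64 latent vectors
fake = gen(z)
print(fake.shape)        # torch.Size([64, 1, 28, 28])
print(disc(fake).shape)  # torch.Size([64, 1, 1, 1])
```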
labml_nn/gan/simple_mnist_experiment.py
@@ -16,6 +16,15 @@ from labml_helpers.train_valid import MODE_STATE, BatchStepProtocol, TrainValidConfigs
 from labml_nn.gan import DiscriminatorLogitsLoss, GeneratorLogitsLoss
 
 
+def weights_init(m):
+    classname = m.__class__.__name__
+    if classname.find('Linear') != -1:
+        nn.init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm') != -1:
+        nn.init.normal_(m.weight.data, 1.0, 0.02)
+        nn.init.constant_(m.bias.data, 0)
+
+
 class Generator(Module):
     """
     ### Simple MLP Generator
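As a quick check of the new `weights_init` above (illustrative only, not part of the commit), applying it to a small model should leave Linear weights with a standard deviation near 0.02 and BatchNorm weights centred at 1 with zero bias:

```python
import torch.nn as nn


def weights_init(m):
    # Same rule as the hunk above
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


model = nn.Sequential(nn.Linear(256, 512), nn.BatchNorm1d(512), nn.LeakyReLU(0.2))
model.apply(weights_init)
print(model[0].weight.std())      # ~0.02
print(model[1].weight.mean())     # ~1.0
print(model[1].bias.abs().max())  # 0.0
```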
@@ -35,6 +44,8 @@ class Generator(Module):
 
         self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 28 * 28), nn.Tanh())
 
+        self.apply(weights_init)
+
     def forward(self, x):
         return self.layers(x).view(x.shape[0], 1, 28, 28)
 
@@ -58,6 +69,7 @@ class Discriminator(Module):
             d_prev = size
 
         self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 1))
+        self.apply(weights_init)
 
     def forward(self, x):
         return self.layers(x.view(x.shape[0], -1))
@@ -70,8 +82,10 @@ class GANBatchStep(BatchStepProtocol):
                  discriminator_optimizer: Optional[torch.optim.Adam],
                  generator_optimizer: Optional[torch.optim.Adam],
                  discriminator_loss: DiscriminatorLogitsLoss,
-                 generator_loss: GeneratorLogitsLoss):
+                 generator_loss: GeneratorLogitsLoss,
+                 discriminator_k: int):
+        self.discriminator_k = discriminator_k
         self.generator = generator
         self.discriminator = discriminator
         self.generator_loss = generator_loss
@@ -100,25 +114,26 @@ class GANBatchStep(BatchStepProtocol):
 
         # Train the discriminator
         with monit.section("discriminator"):
-            latent = torch.randn(data.shape[0], 100, device=device)
-            if MODE_STATE.is_train:
-                self.discriminator_optimizer.zero_grad()
-            logits_true = self.discriminator(data)
-            logits_false = self.discriminator(self.generator(latent).detach())
-            loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
-            loss = loss_true + loss_false
-
-            # Log stuff
-            tracker.add("loss.discriminator.true.", loss_true)
-            tracker.add("loss.discriminator.false.", loss_false)
-            tracker.add("loss.discriminator.", loss)
-
-            # Train
-            if MODE_STATE.is_train:
-                loss.backward()
-                if MODE_STATE.is_log_parameters:
-                    pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
-                self.discriminator_optimizer.step()
+            for _ in range(self.discriminator_k):
+                latent = torch.randn(data.shape[0], 100, device=device)
+                if MODE_STATE.is_train:
+                    self.discriminator_optimizer.zero_grad()
+                logits_true = self.discriminator(data)
+                logits_false = self.discriminator(self.generator(latent).detach())
+                loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
+                loss = loss_true + loss_false
+
+                # Log stuff
+                tracker.add("loss.discriminator.true.", loss_true)
+                tracker.add("loss.discriminator.false.", loss_false)
+                tracker.add("loss.discriminator.", loss)
+
+                # Train
+                if MODE_STATE.is_train:
+                    loss.backward()
+                    if MODE_STATE.is_log_parameters:
+                        pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
+                    self.discriminator_optimizer.step()
 
         # Train the generator
         with monit.section("generator"):
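This hunk wraps the discriminator update in `for _ in range(self.discriminator_k)`, so the discriminator now takes `discriminator_k` gradient steps for every generator step (the k from the original GAN training algorithm). A self-contained toy version of that pattern, with made-up models and data rather than the experiment's `GANBatchStep`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Tiny stand-in networks, purely to demonstrate the update schedule
generator = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
discriminator = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1))
g_opt = torch.optim.Adam(generator.parameters(), lr=2e-4)
d_opt = torch.optim.Adam(discriminator.parameters(), lr=2e-4)
discriminator_k = 5  # k discriminator steps per generator step

for step in range(100):
    real = torch.randn(64, 2) * 0.5 + 2.0  # stand-in "real" data

    # k discriminator updates; the generator is held fixed via .detach()
    for _ in range(discriminator_k):
        z = torch.randn(64, 8)
        d_opt.zero_grad()
        loss_true = F.binary_cross_entropy_with_logits(
            discriminator(real), torch.ones(64, 1))
        loss_false = F.binary_cross_entropy_with_logits(
            discriminator(generator(z).detach()), torch.zeros(64, 1))
        (loss_true + loss_false).backward()
        d_opt.step()

    # one generator update
    z = torch.randn(64, 8)
    g_opt.zero_grad()
    g_loss = F.binary_cross_entropy_with_logits(
        discriminator(generator(z)), torch.ones(64, 1))
    g_loss.backward()
    g_opt.step()
```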
@@ -156,6 +171,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
     discriminator_loss: DiscriminatorLogitsLoss
     batch_step = 'gan_batch_step'
     label_smoothing: float = 0.2
+    discriminator_k: int = 1
 
 
 @option(Configs.dataset_transforms)
@@ -173,11 +189,12 @@ def gan_batch_step(c: Configs):
                         discriminator_optimizer=c.discriminator_optimizer,
                         generator_optimizer=c.generator_optimizer,
                         discriminator_loss=c.discriminator_loss,
-                        generator_loss=c.generator_loss)
+                        generator_loss=c.generator_loss,
+                        discriminator_k=c.discriminator_k)
 
 
-calculate(Configs.generator, lambda c: Generator().to(c.device))
-calculate(Configs.discriminator, lambda c: Discriminator().to(c.device))
+calculate(Configs.generator, 'mlp', lambda c: Generator().to(c.device))
+calculate(Configs.discriminator, 'mlp', lambda c: Discriminator().to(c.device))
 calculate(Configs.generator_loss, lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
 calculate(Configs.discriminator_loss, lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))