dcgan

This commit is contained in:
Varuna Jayasiri
2020-09-28 11:23:04 +05:30
parent 9bd29d165c
commit 7ef213f89c
3 changed files with 138 additions and 22 deletions

99
labml_nn/gan/dcgan.py Normal file
View File

@ -0,0 +1,99 @@
import torch.nn as nn
from labml import experiment
from labml.configs import calculate
from labml_helpers.module import Module
from labml_nn.gan.simple_mnist_experiment import Configs
class Generator(Module):
"""
### Convolutional Generator Network
"""
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
# Gives a 3x3 output
nn.ConvTranspose2d(100, 1024, 3, 1, 0, bias=False),
nn.BatchNorm2d(1024),
nn.ReLU(True),
# This gives a 7x7
nn.ConvTranspose2d(1024, 512, 3, 2, 0, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
# This give 14x14
nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
# This gives 28*28
nn.ConvTranspose2d(256, 1, 4, 2, 1, bias=False),
nn.Tanh()
)
self.apply(_weights_init)
def __call__(self, x):
x = x.unsqueeze(-1).unsqueeze(-1)
x = self.layers(x)
return x
class Discriminator(Module):
"""
### Convolutional Discriminator Network
"""
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
# This gives 14x14
nn.Conv2d(1, 256, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# This gives 7x7
nn.Conv2d(256, 512, 4, 2, 1, bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
# This gives 3x3
nn.Conv2d(512, 1024, 3, 2, 0, bias=False),
nn.BatchNorm2d(1024),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*4) x 8 x 8
nn.Conv2d(1024, 1, 3, 1, 0, bias=False),
)
self.apply(_weights_init)
def forward(self, x):
x = self.layers(x)
return x.view(x.shape[0], -1)
def _weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
# We import the [simple gan experiment]((simple_mnist_experiment.html) and change the
# generator and discriminator networks
calculate(Configs.generator, 'cnn', lambda c: Generator().to(c.device))
calculate(Configs.discriminator, 'cnn', lambda c: Discriminator().to(c.device))
def main():
conf = Configs()
experiment.create(name='mnist_dcgan', comment='test')
experiment.configs(conf,
{'discriminator': 'cnn',
'generator': 'cnn',
'label_smoothing': 0.01},
'run')
with experiment.start():
conf.run()
if __name__ == '__main__':
main()