diff --git a/labml_nn/__init__.py b/labml_nn/__init__.py
index dfc9abd6..da4c3ac4 100644
--- a/labml_nn/__init__.py
+++ b/labml_nn/__init__.py
@@ -28,6 +28,8 @@ and
 
 #### ✨ [Capsule Networks](http://lab-ml.com/labml_nn/capsule_networks/)
 
 #### ✨ [Generative Adversarial Networks](http://lab-ml.com/labml_nn/gan/)
+* [GAN with a multi-layer perceptron](http://lab-ml.com/labml_nn/gan/simple_mnist_experiment.html)
+* [GAN with a deep convolutional network](http://lab-ml.com/labml_nn/gan/dcgan.html)
 
 ### Installation
diff --git a/labml_nn/gan/dcgan.py b/labml_nn/gan/dcgan.py
index 9d390681..7092f18f 100644
--- a/labml_nn/gan/dcgan.py
+++ b/labml_nn/gan/dcgan.py
@@ -1,3 +1,10 @@
+"""
+This is an implementation of the paper
+[Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434).
+
+It is based on the [PyTorch DCGAN Tutorial](https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html).
+"""
+
 import torch.nn as nn
 
 from labml import experiment
@@ -9,24 +16,30 @@ from labml_nn.gan.simple_mnist_experiment import Configs
 class Generator(Module):
     """
     ### Convolutional Generator Network
+
+    This is similar to the de-convolutional network used for CelebA faces,
+    but modified for MNIST images.
+
+
     """
 
     def __init__(self):
         super().__init__()
+        # The input is $1 \times 1$ with 100 channels
         self.layers = nn.Sequential(
-            # Gives a 3x3 output
+            # This gives a $3 \times 3$ output
             nn.ConvTranspose2d(100, 1024, 3, 1, 0, bias=False),
             nn.BatchNorm2d(1024),
             nn.ReLU(True),
-            # This gives a 7x7
+            # This gives $7 \times 7$
             nn.ConvTranspose2d(1024, 512, 3, 2, 0, bias=False),
             nn.BatchNorm2d(512),
             nn.ReLU(True),
-            # This give 14x14
+            # This gives $14 \times 14$
             nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
             nn.BatchNorm2d(256),
             nn.ReLU(True),
-            # This gives 28*28
+            # This gives $28 \times 28$
             nn.ConvTranspose2d(256, 1, 4, 2, 1, bias=False),
             nn.Tanh()
         )
@@ -34,6 +47,7 @@ class Generator(Module):
         self.apply(_weights_init)
 
     def __call__(self, x):
+        # Change from shape `[batch_size, 100]` to `[batch_size, 100, 1, 1]`
         x = x.unsqueeze(-1).unsqueeze(-1)
         x = self.layers(x)
         return x
@@ -46,19 +60,20 @@
 
     def __init__(self):
         super().__init__()
+        # The input is $28 \times 28$ with one channel
         self.layers = nn.Sequential(
-            # This gives 14x14
+            # This gives $14 \times 14$
             nn.Conv2d(1, 256, 4, 2, 1, bias=False),
             nn.LeakyReLU(0.2, inplace=True),
-            # This gives 7x7
+            # This gives $7 \times 7$
             nn.Conv2d(256, 512, 4, 2, 1, bias=False),
             nn.BatchNorm2d(512),
             nn.LeakyReLU(0.2, inplace=True),
-            # This gives 3x3
+            # This gives $3 \times 3$
             nn.Conv2d(512, 1024, 3, 2, 0, bias=False),
             nn.BatchNorm2d(1024),
             nn.LeakyReLU(0.2, inplace=True),
-            # state size. (ndf*4) x 8 x 8
+            # This gives $1 \times 1$
             nn.Conv2d(1024, 1, 3, 1, 0, bias=False),
         )
         self.apply(_weights_init)
diff --git a/readme.md b/readme.md
index 13c6521e..c97443ba 100644
--- a/readme.md
+++ b/readme.md
@@ -25,6 +25,8 @@ and
 
 #### ✨ [Capsule Networks](http://lab-ml.com/labml_nn/capsule_networks/)
 
 #### ✨ [Generative Adversarial Networks](http://lab-ml.com/labml_nn/gan/)
+* [GAN with a multi-layer perceptron](http://lab-ml.com/labml_nn/gan/simple_mnist_experiment.html)
+* [GAN with a deep convolutional network](http://lab-ml.com/labml_nn/gan/dcgan.html)
 
 ### Installation
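A note on the size comments introduced above: with the default `dilation=1` and `output_padding=0`, PyTorch's `ConvTranspose2d` produces $H_{out} = (H_{in} - 1) s - 2p + k$, and `Conv2d` produces $H_{out} = \lfloor (H_{in} + 2p - k) / s \rfloor + 1$. For instance, the generator's $7 \times 7 \to 14 \times 14$ layer uses $k = 4$, $s = 2$, $p = 1$, so $H_{out} = (7 - 1) \cdot 2 - 2 + 4 = 14$; the discriminator's first layer maps $28 \times 28$ to $\lfloor (28 + 2 - 4) / 2 \rfloor + 1 = 14$.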
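As a quick sanity check of those dimensions, here is a minimal, self-contained sketch (not part of the patch). It restates the two layer stacks with plain `torch.nn`, dropping the batch norm and activation layers, the repo's `Module` wrapper, and `_weights_init`, since none of those change the spatial shape:

```python
import torch
import torch.nn as nn

# Generator stack from dcgan.py, convolutions only (the shape-preserving
# BatchNorm2d / ReLU / Tanh layers are omitted).
generator = nn.Sequential(
    nn.ConvTranspose2d(100, 1024, 3, 1, 0, bias=False),  # 1x1 -> 3x3
    nn.ConvTranspose2d(1024, 512, 3, 2, 0, bias=False),  # 3x3 -> 7x7
    nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),   # 7x7 -> 14x14
    nn.ConvTranspose2d(256, 1, 4, 2, 1, bias=False),     # 14x14 -> 28x28
)

# Discriminator stack, likewise convolutions only.
discriminator = nn.Sequential(
    nn.Conv2d(1, 256, 4, 2, 1, bias=False),     # 28x28 -> 14x14
    nn.Conv2d(256, 512, 4, 2, 1, bias=False),   # 14x14 -> 7x7
    nn.Conv2d(512, 1024, 3, 2, 0, bias=False),  # 7x7 -> 3x3
    nn.Conv2d(1024, 1, 3, 1, 0, bias=False),    # 3x3 -> 1x1
)

# Generator.__call__ un-squeezes [batch_size, 100] to
# [batch_size, 100, 1, 1]; build that shape directly here.
z = torch.randn(4, 100, 1, 1)
fake = generator(z)
assert fake.shape == (4, 1, 28, 28)
assert discriminator(fake).shape == (4, 1, 1, 1)
```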