mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-30 10:18:50 +08:00 
			
		
		
		
	Merge branch 'master' of github.com:lab-ml/transformers
merge
This commit is contained in:
		| @ -28,6 +28,8 @@ and | ||||
| #### ✨ [Capsule Networks](http://lab-ml.com/labml_nn/capsule_networks/) | ||||
|  | ||||
| #### ✨ [Generative Adversarial Networks](http://lab-ml.com/labml_nn/gan/) | ||||
| * [GAN with a multi-layer perceptron](http://lab-ml.com/labml_nn/gan/simple_mnist_experiment.html) | ||||
| * [GAN with deep convolutional network](http://lab-ml.com/labml_nn/gan/dcgan.html) | ||||
|  | ||||
| ### Installation | ||||
|  | ||||
|  | ||||
| @ -1,3 +1,10 @@ | ||||
| """ | ||||
| This is an implementation of paper | ||||
| [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434). | ||||
|  | ||||
| This implementation is based on the [PyTorch DCGAN Tutorial](https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html). | ||||
| """ | ||||
|  | ||||
| import torch.nn as nn | ||||
|  | ||||
| from labml import experiment | ||||
| @ -9,24 +16,30 @@ from labml_nn.gan.simple_mnist_experiment import Configs | ||||
| class Generator(Module): | ||||
|     """ | ||||
|     ### Convolutional Generator Network | ||||
|  | ||||
|     This is similar to the de-convolutional network used for CelebA faces, | ||||
|     but modified for MNIST images. | ||||
|  | ||||
|     <img src="https://pytorch.org/tutorials/_images/dcgan_generator.png" style="max-width:90%" /> | ||||
|     """ | ||||
|  | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         # The input is $1 \times 1$ with 100 channels | ||||
|         self.layers = nn.Sequential( | ||||
|             # Gives a 3x3 output | ||||
|             # This gives $3 \times 3$ output | ||||
|             nn.ConvTranspose2d(100, 1024, 3, 1, 0, bias=False), | ||||
|             nn.BatchNorm2d(1024), | ||||
|             nn.ReLU(True), | ||||
|             # This gives a 7x7 | ||||
|             # This gives $7 \times 7$ | ||||
|             nn.ConvTranspose2d(1024, 512, 3, 2, 0, bias=False), | ||||
|             nn.BatchNorm2d(512), | ||||
|             nn.ReLU(True), | ||||
|             # This give 14x14 | ||||
|             # This give $14 \times 14$ | ||||
|             nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), | ||||
|             nn.BatchNorm2d(256), | ||||
|             nn.ReLU(True), | ||||
|             # This gives 28*28 | ||||
|             # This gives $28 \times 28$ | ||||
|             nn.ConvTranspose2d(256, 1, 4, 2, 1, bias=False), | ||||
|             nn.Tanh() | ||||
|         ) | ||||
| @ -34,6 +47,7 @@ class Generator(Module): | ||||
|         self.apply(_weights_init) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         # Change from shape `[batch_size, 100]` to `[batch_size, 100, 1, 1]` | ||||
|         x = x.unsqueeze(-1).unsqueeze(-1) | ||||
|         x = self.layers(x) | ||||
|         return x | ||||
| @ -46,19 +60,20 @@ class Discriminator(Module): | ||||
|  | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         # The input is $28 \times 28$ with one channel | ||||
|         self.layers = nn.Sequential( | ||||
|             # This gives 14x14 | ||||
|             # This gives $14 \times 14$ | ||||
|             nn.Conv2d(1, 256, 4, 2, 1, bias=False), | ||||
|             nn.LeakyReLU(0.2, inplace=True), | ||||
|             # This gives 7x7 | ||||
|             # This gives $7 \times 7$ | ||||
|             nn.Conv2d(256, 512, 4, 2, 1, bias=False), | ||||
|             nn.BatchNorm2d(512), | ||||
|             nn.LeakyReLU(0.2, inplace=True), | ||||
|             # This gives 3x3 | ||||
|             # This gives $3 \times 3$ | ||||
|             nn.Conv2d(512, 1024, 3, 2, 0, bias=False), | ||||
|             nn.BatchNorm2d(1024), | ||||
|             nn.LeakyReLU(0.2, inplace=True), | ||||
|             # state size. (ndf*4) x 8 x 8 | ||||
|             # This gives $1 \times 1$ | ||||
|             nn.Conv2d(1024, 1, 3, 1, 0, bias=False), | ||||
|         ) | ||||
|         self.apply(_weights_init) | ||||
|  | ||||
| @ -25,6 +25,8 @@ and | ||||
| #### ✨ [Capsule Networks](http://lab-ml.com/labml_nn/capsule_networks/) | ||||
|  | ||||
| #### ✨ [Generative Adversarial Networks](http://lab-ml.com/labml_nn/gan/) | ||||
| * [GAN with a multi-layer perceptron](http://lab-ml.com/labml_nn/gan/simple_mnist_experiment.html) | ||||
| * [GAN with deep convolutional network](http://lab-ml.com/labml_nn/gan/dcgan.html) | ||||
|  | ||||
| ### Installation | ||||
|  | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri