from typing import Any

from torchvision import transforms

import torch
import torch.nn as nn
import torch.utils.data
from labml import tracker, monit, experiment
from labml.configs import option, calculate
from labml_nn.gan.original import DiscriminatorLogitsLoss, GeneratorLogitsLoss
from labml_nn.helpers.datasets import MNISTConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.optimizer import OptimizerConfigs
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

This has three linear layers of increasing size with LeakyReLU activations. The final layer has a $\tanh$ activation.

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        layer_sizes = [256, 512, 1024]
        layers = []
        d_prev = 100
        for size in layer_sizes:
            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 28 * 28), nn.Tanh())

        self.apply(weights_init)

    def forward(self, x):
        return self.layers(x).view(x.shape[0], 1, 28, 28)
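
As a quick illustrative sketch (not part of the module, and using hypothetical variable names), the generator maps a batch of 100-dimensional latent vectors to single-channel 28×28 images:

generator = Generator()
z = torch.randn(4, 100)                   # a batch of 4 latent vectors
images = generator(z)
assert images.shape == (4, 1, 28, 28)     # pixel values lie in [-1, 1] because of the tanh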

This has three linear layers of decreasing size with LeakyReLU activations. The final layer has a single output that gives the logit of whether the input is real or fake. You can get the probability by calculating the sigmoid of it.

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        layer_sizes = [1024, 512, 256]
        layers = []
        d_prev = 28 * 28
        for size in layer_sizes:
            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 1))
        self.apply(weights_init)

    def forward(self, x):
        return self.layers(x.view(x.shape[0], -1))
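
Since the discriminator returns a logit, the probability that an input is real is its sigmoid. A minimal sketch of this (illustrative only, with random tensors standing in for MNIST images):

discriminator = Discriminator()
images = torch.randn(4, 1, 28, 28)              # stand-in for a batch of MNIST images
p_real = torch.sigmoid(discriminator(images))   # probabilities in (0, 1), shape (4, 1)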

This extends the MNIST configurations to get the data loaders, and the training and validation loop configurations, to simplify our implementation.

class Configs(MNISTConfigs, TrainValidConfigs):
    device: torch.device = DeviceConfigs()
    dataset_transforms = 'mnist_gan_transforms'
    epochs: int = 10

    is_save_models = True
    discriminator: nn.Module = 'mlp'
    generator: nn.Module = 'mlp'
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    generator_loss: GeneratorLogitsLoss = 'original'
    discriminator_loss: DiscriminatorLogitsLoss = 'original'
    label_smoothing: float = 0.2
    discriminator_k: int = 1

Initializations

    def init(self):
        self.state_modules = []

        tracker.set_scalar("loss.generator.*", True)
        tracker.set_scalar("loss.discriminator.*", True)
        tracker.set_image("generated", True, 1 / 100)

Sample random noise $z$ from a standard normal distribution

    def sample_z(self, batch_size: int):
        return torch.randn(batch_size, 100, device=self.device)

Take a training step
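
For reference, each step optimizes the two networks against the minimax objective of the original GAN paper (Goodfellow et al., 2014),

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))]$$

alternating between a discriminator update on every batch and a generator update once every discriminator_k batches.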

    def step(self, batch: Any, batch_idx: BatchIndex):

Set model states

        self.generator.train(self.mode.is_train)
        self.discriminator.train(self.mode.is_train)

Get MNIST images

        data = batch[0].to(self.device)

Increment the global step in training mode

        if self.mode.is_train:
            tracker.add_global_step(len(data))

Train the discriminator

        with monit.section("discriminator"):

Get the discriminator loss

            loss = self.calc_discriminator_loss(data)

Train

            if self.mode.is_train:
                self.discriminator_optimizer.zero_grad()
                loss.backward()
                if batch_idx.is_last:
                    tracker.add('discriminator', self.discriminator)
                self.discriminator_optimizer.step()

Train the generator once every discriminator_k steps

        if batch_idx.is_interval(self.discriminator_k):
            with monit.section("generator"):
                loss = self.calc_generator_loss(data.shape[0])

Train

                if self.mode.is_train:
                    self.generator_optimizer.zero_grad()
                    loss.backward()
                    if batch_idx.is_last:
                        tracker.add('generator', self.generator)
                    self.generator_optimizer.step()

        tracker.save()

Calculate discriminator loss
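
For reference, with discriminator $D$ and generator $G$, the standard discriminator loss computed here (before label smoothing adjusts the targets) is

$$\mathcal{L}_D = -\mathbb{E}_{x \sim p_{data}}[\log D(x)] - \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))]$$

where the two terms correspond to loss_true and loss_false below. The generated images are detached, so no generator gradients are computed in this step.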

    def calc_discriminator_loss(self, data):
        latent = self.sample_z(data.shape[0])
        logits_true = self.discriminator(data)
        logits_false = self.discriminator(self.generator(latent).detach())
        loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
        loss = loss_true + loss_false

Log stuff

        tracker.add("loss.discriminator.true.", loss_true)
        tracker.add("loss.discriminator.false.", loss_false)
        tracker.add("loss.discriminator.", loss)

        return loss

Calculate generator loss
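
For reference, the generator is trained with the non-saturating loss suggested in the original GAN paper, which (up to label smoothing) maximizes the log-probability the discriminator assigns to generated images:

$$\mathcal{L}_G = -\mathbb{E}_{z \sim p(z)}[\log D(G(z))]$$

In the code below the generated images are not detached, so gradients flow through the discriminator back into the generator.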

    def calc_generator_loss(self, batch_size: int):
        latent = self.sample_z(batch_size)
        generated_images = self.generator(latent)
        logits = self.discriminator(generated_images)
        loss = self.generator_loss(logits)

Log stuff

        tracker.add('generated', generated_images[0:6])
        tracker.add("loss.generator.", loss)

        return loss


@option(Configs.dataset_transforms)
def mnist_gan_transforms():
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
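
The normalization maps the [0, 1] pixel range produced by ToTensor() to [-1, 1], which matches the range of the generator's tanh output. A small illustrative check (not part of the experiment):

x = torch.rand(1, 28, 28)                      # stand-in for a ToTensor() output in [0, 1]
y = transforms.Normalize((0.5,), (0.5,))(x)    # computes (x - 0.5) / 0.5
assert float(y.min()) >= -1.0 and float(y.max()) <= 1.0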


@option(Configs.discriminator_optimizer)
def _discriminator_optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.discriminator.parameters()
    opt_conf.learning_rate = 2.5e-4

Setting the exponential decay rate for the first moment of the gradient, $\beta_1$, to 0.5 is important. The default of 0.9 fails.

    opt_conf.betas = (0.5, 0.999)
    return opt_conf


@option(Configs.generator_optimizer)
def _generator_optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.generator.parameters()
    opt_conf.learning_rate = 2.5e-4

Setting the exponential decay rate for the first moment of the gradient, $\beta_1$, to 0.5 is important. The default of 0.9 fails.

    opt_conf.betas = (0.5, 0.999)
    return opt_conf
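
These settings are roughly equivalent to constructing the Adam optimizer directly; a hypothetical sketch (not how the OptimizerConfigs helper is implemented):

generator = Generator()
optimizer = torch.optim.Adam(generator.parameters(), lr=2.5e-4, betas=(0.5, 0.999))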


calculate(Configs.generator, 'mlp', lambda c: Generator().to(c.device))
calculate(Configs.discriminator, 'mlp', lambda c: Discriminator().to(c.device))
calculate(Configs.generator_loss, 'original', lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
calculate(Configs.discriminator_loss, 'original', lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))


def main():
    conf = Configs()
    experiment.create(name='mnist_gan', comment='test')
    experiment.configs(conf,
                       {'label_smoothing': 0.01})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()