from typing import Any

import torch
import torch.nn as nn
import torch.utils.data
from torchvision import transforms

from labml import tracker, monit, experiment
from labml.configs import option, calculate
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.gan.original import DiscriminatorLogitsLoss, GeneratorLogitsLoss


def weights_init(m):
    """
    Initialize the weights of `Linear` and `BatchNorm` layers.

    Meant to be passed to `nn.Module.apply`, which calls it on every
    sub-module. Modules are matched by class *name*, so any class whose name
    contains `'Linear'` or `'BatchNorm'` is affected (`Linear` is checked first):

    * `Linear`: weight ~ N(0, 0.02); the bias is left untouched.
    * `BatchNorm`: weight ~ N(1, 0.02); bias set to 0.

    Parameters that are absent (`None`), e.g. on a `BatchNorm` created with
    `affine=False`, are skipped instead of raising `AttributeError`.
    """
    classname = type(m).__name__
    if 'Linear' in classname:
        if getattr(m, 'weight', None) is not None:
            nn.init.normal_(m.weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if getattr(m, 'weight', None) is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            nn.init.constant_(m.bias, 0)


# ### Simple MLP Generator
#
# This has three hidden linear layers of increasing size with `LeakyReLU`
# activations. The final layer has a tanh activation.
class Generator(Module):
    """
    ### Simple MLP Generator

    This has three hidden linear layers of increasing size (256, 512, 1024)
    with `LeakyReLU` activations, followed by an output linear layer with a
    tanh activation.

    It maps a batch of 100-dimensional latent vectors to a batch of
    `1 x 28 x 28` images. The tanh keeps every pixel in `[-1, 1]`.
    """

    def __init__(self):
        super().__init__()
        layer_sizes = [256, 512, 1024]
        layers = []
        # Size of the latent vector fed to the first layer
        d_prev = 100
        for size in layer_sizes:
            layers += [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        # The output layer projects to 28 * 28 pixels, squashed by tanh
        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 28 * 28), nn.Tanh())

        self.apply(weights_init)

    def forward(self, x):
        """
        `x` is a batch of latent vectors of shape `[batch_size, 100]`.
        Returns images of shape `[batch_size, 1, 28, 28]`.
        """
        return self.layers(x).view(x.shape[0], 1, 28, 28)


# ### Simple MLP Discriminator
#
# This has three hidden linear layers of decreasing size with `LeakyReLU`
# activations. The final layer has a single output that gives the logit of
# whether the input is real or fake. You can get the probability by
# calculating the sigmoid of it.
class Discriminator(Module):
    """
    ### Simple MLP Discriminator

    This has three hidden linear layers of decreasing size (1024, 512, 256)
    with `LeakyReLU` activations, followed by a linear layer with a single
    output. That output is the logit of whether the input is real or fake.
    You can get the probability by calculating the sigmoid of it.
    """

    def __init__(self):
        super().__init__()
        layer_sizes = [1024, 512, 256]
        layers = []
        # Flattened 28 x 28 image
        d_prev = 28 * 28
        for size in layer_sizes:
            layers += [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        # Single logit per sample; no sigmoid here, the losses take logits
        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 1))
        self.apply(weights_init)

    def forward(self, x):
        """
        `x` is a batch of images. Everything after the first (batch) dimension
        is flattened, and it must total 28 * 28 = 784 features.
        Returns logits of shape `[batch_size, 1]`.
        """
        # `reshape` (unlike `view`) also accepts non-contiguous inputs
        return self.layers(x.reshape(x.shape[0], -1))


# ## Configurations
#
# This extends the MNIST configurations to get the data loaders, and the
# training and validation loop configurations, to simplify our implementation.
class Configs(MNISTConfigs, TrainValidConfigs):
    """
    ## Configurations

    This extends the MNIST configurations to get the data loaders, and the
    training and validation loop configurations, to simplify our implementation.
    """

    device: torch.device = DeviceConfigs()
    # Name of the option registered below with `@option(Configs.dataset_transforms)`
    dataset_transforms = 'mnist_gan_transforms'
    epochs: int = 10

    is_save_models = True
    # `'mlp'` options are registered with `calculate(...)` in this file
    discriminator: Module = 'mlp'
    generator: Module = 'mlp'
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    # `'original'` options are registered with `calculate(...)` in this file
    generator_loss: GeneratorLogitsLoss = 'original'
    discriminator_loss: DiscriminatorLogitsLoss = 'original'
    # Passed to both the generator and the discriminator logits losses
    label_smoothing: float = 0.2
    # Train the generator once in every `discriminator_k` batches
    discriminator_k: int = 1

    def init(self):
        """
        Initializations: hook model outputs and configure tracker indicators.
        """
        self.state_modules = []

        hook_model_outputs(self.mode, self.generator, 'generator')
        hook_model_outputs(self.mode, self.discriminator, 'discriminator')
        tracker.set_scalar("loss.generator.*", True)
        tracker.set_scalar("loss.discriminator.*", True)
        # NOTE(review): third argument looks like a logging density for images — confirm
        tracker.set_image("generated", True, 1 / 100)

    def sample_z(self, batch_size: int):
        """
        Sample `batch_size` latent vectors of size 100 from a standard normal
        distribution, on `self.device`.
        """
        return torch.randn(batch_size, 100, device=self.device)

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        Take a training step.

        This also runs during validation. Losses are computed in both modes,
        but backward passes and optimizer updates only happen when
        `self.mode.is_train`.
        """
        # Set model states
        self.generator.train(self.mode.is_train)
        self.discriminator.train(self.mode.is_train)

        # Get MNIST images (only `batch[0]` is used; presumably `(images, labels)`)
        data = batch[0].to(self.device)

        # Increment step in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Train the discriminator
        with monit.section("discriminator"):
            # Get discriminator loss
            loss = self.calc_discriminator_loss(data)

            # Train
            if self.mode.is_train:
                self.discriminator_optimizer.zero_grad()
                loss.backward()
                # Added after `backward()` and before `step()`, so this batch's
                # gradients are still on the module when it is logged
                if batch_idx.is_last:
                    tracker.add('discriminator', self.discriminator)
                self.discriminator_optimizer.step()

        # Train the generator once in every `discriminator_k` batches
        if batch_idx.is_interval(self.discriminator_k):
            with monit.section("generator"):
                loss = self.calc_generator_loss(data.shape[0])

                # Train
                if self.mode.is_train:
                    self.generator_optimizer.zero_grad()
                    # This backward also accumulates gradients in the
                    # discriminator's parameters. They are cleared by
                    # `discriminator_optimizer.zero_grad()` before the next
                    # discriminator update.
                    loss.backward()
                    if batch_idx.is_last:
                        tracker.add('generator', self.generator)
                    self.generator_optimizer.step()

        tracker.save()

    def calc_discriminator_loss(self, data):
        """
        Calculate discriminator loss on real `data` and an equally sized batch
        of generated images. Returns the sum of the real and fake terms.
        """
        latent = self.sample_z(data.shape[0])
        logits_true = self.discriminator(data)
        # Detach so this loss does not backpropagate into the generator
        logits_false = self.discriminator(self.generator(latent).detach())
        loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
        loss = loss_true + loss_false

        # Log stuff
        tracker.add("loss.discriminator.true.", loss_true)
        tracker.add("loss.discriminator.false.", loss_false)
        tracker.add("loss.discriminator.", loss)

        return loss

    def calc_generator_loss(self, batch_size: int):
        """
        Calculate generator loss on a fresh batch of `batch_size` generated images.
        """
        latent = self.sample_z(batch_size)
        generated_images = self.generator(latent)
        logits = self.discriminator(generated_images)
        loss = self.generator_loss(logits)

        # Log stuff: a few sample images and the loss
        tracker.add('generated', generated_images[0:6])
        tracker.add("loss.generator.", loss)

        return loss


@option(Configs.dataset_transforms)
def mnist_gan_transforms():
    """
    Convert MNIST images to tensors and normalize them with mean 0.5 and
    std 0.5. This maps `[0, 1]` pixels to `[-1, 1]`, the same range as the
    generator's tanh output.
    """
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
def _gan_adam_configs(parameters):
    """
    Build the Adam `OptimizerConfigs` shared by the generator and discriminator.

    Setting the exponential decay rate for the first moment of the gradient,
    beta_1, to 0.5 is important. Per the original author, the default of 0.9
    fails.
    """
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = parameters
    opt_conf.learning_rate = 2.5e-4
    opt_conf.betas = (0.5, 0.999)
    return opt_conf


@option(Configs.discriminator_optimizer)
def _discriminator_optimizer(c: Configs):
    """
    Adam optimizer configurations for the discriminator's parameters.
    """
    return _gan_adam_configs(c.discriminator.parameters())


@option(Configs.generator_optimizer)
def _generator_optimizer(c: Configs):
    """
    Adam optimizer configurations for the generator's parameters.
    """
    return _gan_adam_configs(c.generator.parameters())
# Register the `'mlp'` model options and the `'original'` loss options.
# Each lambda receives the `Configs` instance `c` and moves its result to `c.device`.
calculate(Configs.generator, 'mlp', lambda c: Generator().to(c.device))
calculate(Configs.discriminator, 'mlp', lambda c: Discriminator().to(c.device))
calculate(Configs.generator_loss, 'original', lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
calculate(Configs.discriminator_loss, 'original', lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))


def main():
    """
    Create the `mnist_gan` experiment, override `label_smoothing` to 0.01,
    and run the training/validation loop.
    """
    conf = Configs()
    experiment.create(name='mnist_gan', comment='test')
    experiment.configs(conf,
                       {'label_smoothing': 0.01})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()