from typing import Any

import torch
import torch.nn as nn
import torch.utils.data
from torchvision import transforms

from labml import tracker, monit, experiment
from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.gan import DiscriminatorLogitsLoss, GeneratorLogitsLoss


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


The generator has three linear layers of increasing size with LeakyReLU activations.
The final layer has a $\tanh$ activation.
class Generator(Module):
    def __init__(self):
        super().__init__()
        layer_sizes = [256, 512, 1024]
        layers = []
        d_prev = 100
        for size in layer_sizes:
            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 28 * 28), nn.Tanh())

        self.apply(weights_init)

    def forward(self, x):
        return self.layers(x).view(x.shape[0], 1, 28, 28)
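The generator maps a 100-dimensional latent vector to a $28 \times 28$ image with values in $[-1, 1]$.
A minimal usage sketch (illustrative only; the variable names here are assumptions, not part of the experiment):

generator = Generator()
latent = torch.randn(16, 100)   # a batch of 16 latent vectors
images = generator(latent)      # shape [16, 1, 28, 28], values in [-1, 1] from the tanh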
The discriminator has three linear layers of decreasing size with LeakyReLU activations.
The final layer has a single output that gives the logit of whether the input
is real or fake. You can get the probability by applying a sigmoid to it.
class Discriminator(Module):
    def __init__(self):
        super().__init__()
        layer_sizes = [1024, 512, 256]
        layers = []
        d_prev = 28 * 28
        for size in layer_sizes:
            layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size

        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 1))
        self.apply(weights_init)

    def forward(self, x):
        return self.layers(x.view(x.shape[0], -1))
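Since the discriminator outputs a logit, the probability that an input is real is its sigmoid.
Continuing the sketch above (illustrative only):

discriminator = Discriminator()
logits = discriminator(images)     # images from the generator sketch; logits have shape [16, 1]
prob_real = torch.sigmoid(logits)  # probability that each image is classified as real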
class Configs(MNISTConfigs, TrainValidConfigs):
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    discriminator: Module
    generator: Module
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    generator_loss: GeneratorLogitsLoss
    discriminator_loss: DiscriminatorLogitsLoss
    label_smoothing: float = 0.2
    discriminator_k: int = 1

    def init(self):
        self.state_modules = []
        self.generator = Generator().to(self.device)
        self.discriminator = Discriminator().to(self.device)
        self.generator_loss = GeneratorLogitsLoss(self.label_smoothing).to(self.device)
        self.discriminator_loss = DiscriminatorLogitsLoss(self.label_smoothing).to(self.device)

        hook_model_outputs(self.mode, self.generator, 'generator')
        hook_model_outputs(self.mode, self.discriminator, 'discriminator')
        tracker.set_scalar("loss.generator.*", True)
        tracker.set_scalar("loss.discriminator.*", True)
        tracker.set_image("generated", True, 1 / 100)

    def step(self, batch: Any, batch_idx: BatchIndex):
        self.generator.train(self.mode.is_train)
        self.discriminator.train(self.mode.is_train)

        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Increment step in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Train the discriminator
        with monit.section("discriminator"):
            for _ in range(self.discriminator_k):
                latent = torch.randn(data.shape[0], 100, device=self.device)
                logits_true = self.discriminator(data)
                logits_false = self.discriminator(self.generator(latent).detach())
                loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
                loss = loss_true + loss_false

                # Log the losses
                tracker.add("loss.discriminator.true.", loss_true)
                tracker.add("loss.discriminator.false.", loss_false)
                tracker.add("loss.discriminator.", loss)

                # Take an optimization step in training mode
                if self.mode.is_train:
                    self.discriminator_optimizer.zero_grad()
                    loss.backward()
                    if batch_idx.is_last:
                        tracker.add('discriminator', self.discriminator)
                    self.discriminator_optimizer.step()

        # Train the generator
        with monit.section("generator"):
            latent = torch.randn(data.shape[0], 100, device=self.device)
            generated_images = self.generator(latent)
            logits = self.discriminator(generated_images)
            loss = self.generator_loss(logits)

            # Log the generated images and the loss
            tracker.add('generated', generated_images[0:5])
            tracker.add("loss.generator.", loss)

            # Take an optimization step in training mode
            if self.mode.is_train:
                self.generator_optimizer.zero_grad()
                loss.backward()
                if batch_idx.is_last:
                    tracker.add('generator', self.generator)
                self.generator_optimizer.step()

        tracker.save()
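DiscriminatorLogitsLoss and GeneratorLogitsLoss come from labml_nn.gan. As a rough sketch of what
logits-based GAN losses with label smoothing typically compute (these helper functions are hypothetical
stand-ins, not the labml_nn implementation, which may differ in its details):

def discriminator_logits_loss_sketch(logits_true, logits_false, smoothing=0.2):
    # Real images get smoothed target labels (1 - smoothing); generated images get 0
    bce = nn.BCEWithLogitsLoss()
    loss_true = bce(logits_true, torch.full_like(logits_true, 1.0 - smoothing))
    loss_false = bce(logits_false, torch.zeros_like(logits_false))
    return loss_true, loss_false


def generator_logits_loss_sketch(logits_false, smoothing=0.2):
    # The generator is rewarded when the discriminator classifies its output as real
    bce = nn.BCEWithLogitsLoss()
    return bce(logits_false, torch.full_like(logits_false, 1.0 - smoothing))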
@option(Configs.dataset_transforms)
def mnist_transforms():
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
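Normalizing with mean $0.5$ and standard deviation $0.5$ maps ToTensor's $[0, 1]$ output to $[-1, 1]$,
matching the range of the generator's $\tanh$ output. A quick standalone check (illustrative only,
not part of the experiment):

x = torch.rand(1, 28, 28)  # simulated ToTensor output in [0, 1]
y = (x - 0.5) / 0.5        # what Normalize((0.5,), (0.5,)) computes
assert -1.0 <= y.min().item() and y.max().item() <= 1.0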
@option(Configs.discriminator_optimizer)
def _discriminator_optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.discriminator.parameters()
    opt_conf.learning_rate = 2.5e-4
    # Setting the exponential decay rate for the first moment of the gradient,
    # $\beta_1$, to $0.5$ is important. The default of $0.9$ fails.
    opt_conf.betas = (0.5, 0.999)
    return opt_conf
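If you were to configure this optimizer directly with PyTorch instead of OptimizerConfigs, the
equivalent setting would look roughly like this (a sketch, assuming a conf object whose models
are already initialized):

discriminator_optimizer = torch.optim.Adam(
    conf.discriminator.parameters(),
    lr=2.5e-4,
    betas=(0.5, 0.999),  # beta_1 = 0.5 instead of the default 0.9
)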
@option(Configs.generator_optimizer)
def _generator_optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.generator.parameters()
    opt_conf.learning_rate = 2.5e-4
    # Setting the exponential decay rate for the first moment of the gradient,
    # $\beta_1$, to $0.5$ is important. The default of $0.9$ fails.
    opt_conf.betas = (0.5, 0.999)
    return opt_conf


def main():
    conf = Configs()
    experiment.create(name='mnist_gan', comment='test')
    experiment.configs(conf,
                       {'label_smoothing': 0.01})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()