"""
Simple GAN experiment that trains a multilayer-perceptron generator and
discriminator on MNIST, using labml for experiment tracking and configuration.
"""
from typing import Any, Optional

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.utils.data
from torchvision import transforms

import labml.utils.pytorch as pytorch_utils
from labml import tracker, monit, experiment
from labml.configs import option, calculate
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.train_valid import MODE_STATE, BatchStepProtocol, TrainValidConfigs, hook_model_outputs, Mode
from labml_nn.gan import DiscriminatorLogitsLoss, GeneratorLogitsLoss

plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'


class Generator(Module):
    """MLP generator: maps a 100-dimensional latent vector to a 1x28x28 image."""

    def __init__(self):
        super().__init__()
        layer_sizes = [256, 512, 1024]
        layers = []
        d_prev = 100
        for size in layer_sizes:
            layers += [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size
        # Final projection to pixel space; Tanh matches inputs normalized to [-1, 1]
        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 28 * 28), nn.Tanh())

    def forward(self, x):
        x = self.layers(x)
        x = x.view(x.shape[0], 1, 28, 28)
        return x


class Discriminator(Module):
    """MLP discriminator: maps a flattened 28x28 image to a single real/fake logit."""

    def __init__(self):
        super().__init__()
        layer_sizes = [512, 256]
        layers = []
        d_prev = 28 * 28
        for size in layer_sizes:
            layers += [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
            d_prev = size
        self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 1))

    def forward(self, x):
        return self.layers(x.view(x.shape[0], -1))


class GANBatchStep(BatchStepProtocol):
    """Runs one training/validation step for both the generator and the discriminator."""

    def __init__(self, *,
                 discriminator: Module,
                 generator: Module,
                 discriminator_optimizer: Optional[torch.optim.Adam],
                 generator_optimizer: Optional[torch.optim.Adam],
                 discriminator_loss: DiscriminatorLogitsLoss,
                 generator_loss: GeneratorLogitsLoss):
        self.generator = generator
        self.discriminator = discriminator
        self.generator_loss = generator_loss
        self.discriminator_loss = discriminator_loss
        self.generator_optimizer = generator_optimizer
        self.discriminator_optimizer = discriminator_optimizer

        hook_model_outputs(self.generator, 'generator')
        hook_model_outputs(self.discriminator, 'discriminator')
        tracker.set_scalar("loss.generator.*", True)
        tracker.set_scalar("loss.discriminator.*", True)
        # Log samples of generated images at a 1/100 rate
        tracker.set_image("generated", True, 1 / 100)

    def prepare_for_iteration(self):
        if MODE_STATE.is_train:
            self.generator.train()
            self.discriminator.train()
        else:
            self.generator.eval()
            self.discriminator.eval()

    def process(self, batch: Any, state: Any):
        device = self.discriminator.device
        data, target = batch
        # Labels are unused for GAN training; only the images matter
        data, target = data.to(device), target.to(device)

        # Generator step: push the discriminator's logits for generated
        # samples towards the "real" label
        with monit.section("generator"):
            latent = torch.randn(data.shape[0], 100, device=device)
            if MODE_STATE.is_train:
                self.generator_optimizer.zero_grad()
            generated_images = self.generator(latent)
            tracker.add('generated', generated_images[0:5])
            logits = self.discriminator(generated_images)
            loss = self.generator_loss(logits)
            tracker.add("loss.generator.", loss)
            if MODE_STATE.is_train:
                loss.backward()
                if MODE_STATE.is_log_parameters:
                    pytorch_utils.store_model_indicators(self.generator, 'generator')
                self.generator_optimizer.step()

        # Discriminator step: classify real images against detached generator
        # samples, so gradients don't flow back into the generator
        with monit.section("discriminator"):
            latent = torch.randn(data.shape[0], 100, device=device)
            if MODE_STATE.is_train:
                self.discriminator_optimizer.zero_grad()
            logits_true = self.discriminator(data)
            logits_false = self.discriminator(self.generator(latent).detach())
            loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
            loss = loss_true + loss_false
            tracker.add("loss.discriminator.true.", loss_true)
tracker.add("loss.discriminator.false.", loss_false) tracker.add("loss.discriminator.", loss) if MODE_STATE.is_train: loss.backward() if MODE_STATE.is_log_parameters: pytorch_utils.store_model_indicators(self.discriminator, 'discriminator') self.discriminator_optimizer.step() return {'samples': len(data)}, None class Configs(MNISTConfigs, TrainValidConfigs): device: torch.device = DeviceConfigs() epochs: int = 10 is_save_models = True discriminator: Module generator: Module generator_optimizer: torch.optim.Adam discriminator_optimizer: torch.optim.Adam generator_loss: GeneratorLogitsLoss discriminator_loss: DiscriminatorLogitsLoss batch_step = 'gan_batch_step' label_smoothing: float = 0.2 @option(Configs.dataset_transforms) def mnist_transforms(): return transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) @option(Configs.batch_step) def gan_batch_step(c: Configs): return GANBatchStep(discriminator=c.discriminator, generator=c.generator, discriminator_optimizer=c.discriminator_optimizer, generator_optimizer=c.generator_optimizer, discriminator_loss=c.discriminator_loss, generator_loss=c.generator_loss) calculate(Configs.generator, lambda c: Generator().to(c.device)) calculate(Configs.discriminator, lambda c: Discriminator().to(c.device)) calculate(Configs.generator_loss, lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device)) calculate(Configs.discriminator_loss, lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device)) @option(Configs.discriminator_optimizer) def _discriminator_optimizer(c: Configs): opt_conf = OptimizerConfigs() opt_conf.optimizer = 'Adam' opt_conf.parameters = c.discriminator.parameters() opt_conf.learning_rate = 2.5e-4 opt_conf.betas = (0.5, 0.999) return opt_conf @option(Configs.generator_optimizer) def _generator_optimizer(c: Configs): opt_conf = OptimizerConfigs() opt_conf.optimizer = 'Adam' opt_conf.parameters = c.generator.parameters() opt_conf.learning_rate = 2.5e-4 opt_conf.betas = (0.5, 0.999) return opt_conf def main(): conf = Configs() experiment.create(name='mnist_gan', comment='test') experiment.configs(conf, {'label_smoothing': 0.01}, 'run') with experiment.start(): conf.run() if __name__ == '__main__': main()