diff --git a/labml_nn/gan/simple_mnist_experiment.py b/labml_nn/gan/simple_mnist_experiment.py
index 478fca99..10c1895d 100644
--- a/labml_nn/gan/simple_mnist_experiment.py
+++ b/labml_nn/gan/simple_mnist_experiment.py
@@ -2,7 +2,7 @@
 # Generative Adversarial Networks experiment with MNIST
 """
 
-from typing import Optional
+from typing import Any
 
 import torch
 import torch.nn as nn
@@ -16,7 +16,7 @@
 from labml_helpers.datasets.mnist import MNISTConfigs
 from labml_helpers.device import DeviceConfigs
 from labml_helpers.module import Module
 from labml_helpers.optimizer import OptimizerConfigs
-from labml_helpers.train_valid import MODE_STATE, BatchStepProtocol, TrainValidConfigs, hook_model_outputs
+from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
 from labml_nn.gan import DiscriminatorLogitsLoss, GeneratorLogitsLoss
@@ -79,89 +79,6 @@ class Discriminator(Module):
         return self.layers(x.view(x.shape[0], -1))
 
 
-class GANBatchStep(BatchStepProtocol):
-    def __init__(self, *,
-                 discriminator: Module,
-                 generator: Module,
-                 discriminator_optimizer: Optional[torch.optim.Adam],
-                 generator_optimizer: Optional[torch.optim.Adam],
-                 discriminator_loss: DiscriminatorLogitsLoss,
-                 generator_loss: GeneratorLogitsLoss,
-                 discriminator_k: int):
-
-        self.discriminator_k = discriminator_k
-        self.generator = generator
-        self.discriminator = discriminator
-        self.generator_loss = generator_loss
-        self.discriminator_loss = discriminator_loss
-        self.generator_optimizer = generator_optimizer
-        self.discriminator_optimizer = discriminator_optimizer
-
-        hook_model_outputs(self.generator, 'generator')
-        hook_model_outputs(self.discriminator, 'discriminator')
-        tracker.set_scalar("loss.generator.*", True)
-        tracker.set_scalar("loss.discriminator.*", True)
-        tracker.set_image("generated", True, 1 / 100)
-
-    def prepare_for_iteration(self):
-        if MODE_STATE.is_train:
-            self.generator.train()
-            self.discriminator.train()
-        else:
-            self.generator.eval()
-            self.discriminator.eval()
-
-    def process(self, batch: any, state: any):
-        device = self.discriminator.device
-        data, target = batch
-        data, target = data.to(device), target.to(device)
-
-        # Train the discriminator
-        with monit.section("discriminator"):
-            for _ in range(self.discriminator_k):
-                latent = torch.randn(data.shape[0], 100, device=device)
-                if MODE_STATE.is_train:
-                    self.discriminator_optimizer.zero_grad()
-                logits_true = self.discriminator(data)
-                logits_false = self.discriminator(self.generator(latent).detach())
-                loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
-                loss = loss_true + loss_false
-
-                # Log stuff
-                tracker.add("loss.discriminator.true.", loss_true)
-                tracker.add("loss.discriminator.false.", loss_false)
-                tracker.add("loss.discriminator.", loss)
-
-                # Train
-                if MODE_STATE.is_train:
-                    loss.backward()
-                    if MODE_STATE.is_log_parameters:
-                        pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
-                    self.discriminator_optimizer.step()
-
-        # Train the generator
-        with monit.section("generator"):
-            latent = torch.randn(data.shape[0], 100, device=device)
-            if MODE_STATE.is_train:
-                self.generator_optimizer.zero_grad()
-            generated_images = self.generator(latent)
-            logits = self.discriminator(generated_images)
-            loss = self.generator_loss(logits)
-
-            # Log stuff
-            tracker.add('generated', generated_images[0:5])
-            tracker.add("loss.generator.", loss)
-
-            # Train
-            if MODE_STATE.is_train:
-                loss.backward()
-                if MODE_STATE.is_log_parameters:
-                    pytorch_utils.store_model_indicators(self.generator, 'generator')
-                self.generator_optimizer.step()
-
-        return {'samples': len(data)}, None
-
-
 class Configs(MNISTConfigs, TrainValidConfigs):
     device: torch.device = DeviceConfigs()
     epochs: int = 10
@@ -173,10 +90,77 @@ class Configs(MNISTConfigs, TrainValidConfigs):
     discriminator_optimizer: torch.optim.Adam
     generator_loss: GeneratorLogitsLoss
     discriminator_loss: DiscriminatorLogitsLoss
-    batch_step = 'gan_batch_step'
     label_smoothing: float = 0.2
     discriminator_k: int = 1
+    log_params_updates: int = 2 ** 32  # 0 if not
+
+    def init(self):
+        self.state_modules = []
+        self.generator = Generator().to(self.device)
+        self.discriminator = Discriminator().to(self.device)
+        self.generator_loss = GeneratorLogitsLoss(self.label_smoothing).to(self.device)
+        self.discriminator_loss = DiscriminatorLogitsLoss(self.label_smoothing).to(self.device)
+
+        hook_model_outputs(self.mode, self.generator, 'generator')
+        hook_model_outputs(self.mode, self.discriminator, 'discriminator')
+        tracker.set_scalar("loss.generator.*", True)
+        tracker.set_scalar("loss.discriminator.*", True)
+        tracker.set_image("generated", True, 1 / 100)
+
+    def step(self, batch: Any, batch_idx: BatchIndex):
+        self.generator.train(self.mode.is_train)
+        self.discriminator.train(self.mode.is_train)
+
+        data, target = batch[0].to(self.device), batch[1].to(self.device)
+
+        # Increment step in training mode
+        if self.mode.is_train:
+            tracker.add_global_step(len(data))
+
+        # Train the discriminator
+        with monit.section("discriminator"):
+            for _ in range(self.discriminator_k):
+                latent = torch.randn(data.shape[0], 100, device=self.device)
+                logits_true = self.discriminator(data)
+                logits_false = self.discriminator(self.generator(latent).detach())
+                loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
+                loss = loss_true + loss_false
+
+                # Log stuff
+                tracker.add("loss.discriminator.true.", loss_true)
+                tracker.add("loss.discriminator.false.", loss_false)
+                tracker.add("loss.discriminator.", loss)
+
+                # Train
+                if self.mode.is_train:
+                    self.discriminator_optimizer.zero_grad()
+                    loss.backward()
+                    if batch_idx.is_interval(self.log_params_updates):
+                        pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
+                    self.discriminator_optimizer.step()
+
+        # Train the generator
+        with monit.section("generator"):
+            latent = torch.randn(data.shape[0], 100, device=self.device)
+            generated_images = self.generator(latent)
+            logits = self.discriminator(generated_images)
+            loss = self.generator_loss(logits)
+
+            # Log stuff
+            tracker.add('generated', generated_images[0:5])
+            tracker.add("loss.generator.", loss)
+
+            # Train
+            if self.mode.is_train:
+                self.generator_optimizer.zero_grad()
+                loss.backward()
+                if batch_idx.is_interval(self.log_params_updates):
+                    pytorch_utils.store_model_indicators(self.generator, 'generator')
+                self.generator_optimizer.step()
+
+        tracker.save()
+
 
 
 @option(Configs.dataset_transforms)
 def mnist_transforms():
@@ -186,23 +170,6 @@
     return transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))
     ])
 
 
-@option(Configs.batch_step)
-def gan_batch_step(c: Configs):
-    return GANBatchStep(discriminator=c.discriminator,
-                        generator=c.generator,
-                        discriminator_optimizer=c.discriminator_optimizer,
-                        generator_optimizer=c.generator_optimizer,
-                        discriminator_loss=c.discriminator_loss,
-                        generator_loss=c.generator_loss,
-                        discriminator_k=c.discriminator_k)
-
-
-calculate(Configs.generator, 'mlp', lambda c: Generator().to(c.device))
-calculate(Configs.discriminator, 'mlp', lambda c: Discriminator().to(c.device))
-calculate(Configs.generator_loss, lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
-calculate(Configs.discriminator_loss, lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))
-
-
 @option(Configs.discriminator_optimizer)
 def _discriminator_optimizer(c: Configs):
     opt_conf = OptimizerConfigs()
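
For orientation, below is a minimal sketch of how the refactored Configs might be driven end to end. It is not part of the commit above: the run name and overridden option values are illustrative assumptions, and it presumes labml's experiment.create / experiment.configs / experiment.start API together with the run() training loop that TrainValidConfigs supplies.

    # Hypothetical driver -- not part of the diff above.
    from labml import experiment


    def main():
        # Configs as patched in labml_nn/gan/simple_mnist_experiment.py
        conf = Configs()
        # 'mnist_gan' is an illustrative run name
        experiment.create(name='mnist_gan')
        # Override options declared on Configs; these values are examples
        experiment.configs(conf, {'label_smoothing': 0.2, 'discriminator_k': 1})
        with experiment.start():
            # run() comes from TrainValidConfigs and calls the step() defined in the patch
            conf.run()


    if __name__ == '__main__':
        main()

With this structure the GAN-specific training logic lives entirely in Configs.step(), so the old GANBatchStep protocol class and its gan_batch_step option are no longer needed.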