diff --git a/labml_nn/gan/__init__.py b/labml_nn/gan/__init__.py index 182f9dfd..1e160a1d 100644 --- a/labml_nn/gan/__init__.py +++ b/labml_nn/gan/__init__.py @@ -6,35 +6,41 @@ import torch.utils.data from labml_helpers.module import Module +def create_labels(n: int, r1: float, r2: float, device: torch.device = None): + return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2) + + class DiscriminatorLogitsLoss(Module): - def __init__(self): + def __init__(self, smoothing: float = 0.2): super().__init__() self.loss_true = nn.BCEWithLogitsLoss() self.loss_false = nn.BCEWithLogitsLoss() - self.register_buffer('labels_true', torch.ones(256, 1, requires_grad=False), False) - self.register_buffer('labels_false', torch.zeros(256, 1, requires_grad=False), False) + self.smoothing = smoothing + self.register_buffer('labels_true', create_labels(256, 1.0 - smoothing, 1.0), False) + self.register_buffer('labels_false', create_labels(256, 0.0, smoothing), False) def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor): if len(logits_true) > len(self.labels_true): self.register_buffer("labels_true", - self.labels_true.new_ones(len(logits_true), 1, requires_grad=False), False) + create_labels(len(logits_true), 1.0 - self.smoothing, 1.0, logits_true.device), False) if len(logits_false) > len(self.labels_false): self.register_buffer("labels_false", - self.labels_false.new_zeros(len(logits_false), 1, requires_grad=False), False) + create_labels(len(logits_false), 0.0, self.smoothing, logits_false.device), False) return self.loss_true(logits_true, self.labels_true[:len(logits_true)]), \ self.loss_false(logits_false, self.labels_false[:len(logits_false)]) class GeneratorLogitsLoss(Module): - def __init__(self): + def __init__(self, smoothing: float = 0.2): super().__init__() self.loss_true = nn.BCEWithLogitsLoss() - self.register_buffer('fake_labels', torch.ones(256, 1, requires_grad=False), False) + self.smoothing = smoothing + self.register_buffer('fake_labels', create_labels(256, 1.0 - smoothing, 1.0), False) def __call__(self, logits: torch.Tensor): if len(logits) > len(self.fake_labels): self.register_buffer("fake_labels", - self.fake_labels.new_ones(len(logits), 1, requires_grad=False), False) + create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False) return self.loss_true(logits, self.fake_labels[:len(logits)]) diff --git a/labml_nn/gan/mnist.py b/labml_nn/gan/mnist.py index d08875f6..88b2bd16 100644 --- a/labml_nn/gan/mnist.py +++ b/labml_nn/gan/mnist.py @@ -4,6 +4,7 @@ import matplotlib.pyplot as plt import torch import torch.nn as nn import torch.utils.data +from torchvision import transforms from labml import tracker, monit, experiment from labml.configs import option, calculate @@ -72,7 +73,7 @@ class GANBatchStep(BatchStepProtocol): self.discriminator_optimizer = discriminator_optimizer tracker.set_scalar("loss.generator.*", True) tracker.set_scalar("loss.discriminator.*", True) - tracker.set_image("generated", True) + tracker.set_image("generated", True, 1 / 100) def prepare_for_iteration(self): if MODE_STATE.is_train: @@ -92,7 +93,7 @@ class GANBatchStep(BatchStepProtocol): if MODE_STATE.is_train: self.generator_optimizer.zero_grad() generated_images = self.generator(latent) - # tracker.add('generated', generated_images[0:1]) + tracker.add('generated', generated_images[0:5]) logits = self.discriminator(generated_images) loss = self.generator_loss(logits) tracker.add("loss.generator.", loss) @@ -127,9 +128,18 @@ class Configs(MNISTConfigs, TrainValidConfigs): generator: Module generator_optimizer: torch.optim.Adam discriminator_optimizer: torch.optim.Adam - discriminator_loss = DiscriminatorLogitsLoss() - generator_loss = GeneratorLogitsLoss() + generator_loss: GeneratorLogitsLoss + discriminator_loss: DiscriminatorLogitsLoss batch_step = 'gan_batch_step' + label_smoothing: float = 0.2 + + +@option(Configs.dataset_transforms) +def mnist_transforms(): + return transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ]) @option(Configs.batch_step) @@ -144,6 +154,8 @@ def gan_batch_step(c: Configs): calculate(Configs.generator, lambda c: Generator().to(c.device)) calculate(Configs.discriminator, lambda c: Discriminator().to(c.device)) +calculate(Configs.generator_loss, lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device)) +calculate(Configs.discriminator_loss, lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device)) @option(Configs.discriminator_optimizer) @@ -166,8 +178,12 @@ def main(): experiment.configs(conf, {'generator_optimizer.learning_rate': 2.5e-4, 'generator_optimizer.optimizer': 'Adam', + 'generator_optimizer.betas': (0.5, 0.999), 'discriminator_optimizer.learning_rate': 2.5e-4, - 'discriminator_optimizer.optimizer': 'Adam'}, + 'discriminator_optimizer.optimizer': 'Adam', + 'discriminator_optimizer.betas': (0.5, 0.999), + 'label_smoothing': 0.01 + }, 'run') with experiment.start(): conf.run()