🥳 simple gan

This commit is contained in:
Varuna Jayasiri
2020-09-26 20:42:38 +05:30
parent 90815b9c65
commit 12dc9cc4f6
2 changed files with 35 additions and 13 deletions

View File

@ -6,35 +6,41 @@ import torch.utils.data
from labml_helpers.module import Module
def create_labels(n: int, r1: float, r2: float, device: torch.device = None):
return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)
class DiscriminatorLogitsLoss(Module):
def __init__(self):
def __init__(self, smoothing: float = 0.2):
super().__init__()
self.loss_true = nn.BCEWithLogitsLoss()
self.loss_false = nn.BCEWithLogitsLoss()
self.register_buffer('labels_true', torch.ones(256, 1, requires_grad=False), False)
self.register_buffer('labels_false', torch.zeros(256, 1, requires_grad=False), False)
self.smoothing = smoothing
self.register_buffer('labels_true', create_labels(256, 1.0 - smoothing, 1.0), False)
self.register_buffer('labels_false', create_labels(256, 0.0, smoothing), False)
def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
if len(logits_true) > len(self.labels_true):
self.register_buffer("labels_true",
self.labels_true.new_ones(len(logits_true), 1, requires_grad=False), False)
create_labels(len(logits_true), 1.0 - self.smoothing, 1.0, logits_true.device), False)
if len(logits_false) > len(self.labels_false):
self.register_buffer("labels_false",
self.labels_false.new_zeros(len(logits_false), 1, requires_grad=False), False)
create_labels(len(logits_false), 0.0, self.smoothing, logits_false.device), False)
return self.loss_true(logits_true, self.labels_true[:len(logits_true)]), \
self.loss_false(logits_false, self.labels_false[:len(logits_false)])
class GeneratorLogitsLoss(Module):
def __init__(self):
def __init__(self, smoothing: float = 0.2):
super().__init__()
self.loss_true = nn.BCEWithLogitsLoss()
self.register_buffer('fake_labels', torch.ones(256, 1, requires_grad=False), False)
self.smoothing = smoothing
self.register_buffer('fake_labels', create_labels(256, 1.0 - smoothing, 1.0), False)
def __call__(self, logits: torch.Tensor):
if len(logits) > len(self.fake_labels):
self.register_buffer("fake_labels",
self.fake_labels.new_ones(len(logits), 1, requires_grad=False), False)
create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)
return self.loss_true(logits, self.fake_labels[:len(logits)])

View File

@ -4,6 +4,7 @@ import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.utils.data
from torchvision import transforms
from labml import tracker, monit, experiment
from labml.configs import option, calculate
@ -72,7 +73,7 @@ class GANBatchStep(BatchStepProtocol):
self.discriminator_optimizer = discriminator_optimizer
tracker.set_scalar("loss.generator.*", True)
tracker.set_scalar("loss.discriminator.*", True)
tracker.set_image("generated", True)
tracker.set_image("generated", True, 1 / 100)
def prepare_for_iteration(self):
if MODE_STATE.is_train:
@ -92,7 +93,7 @@ class GANBatchStep(BatchStepProtocol):
if MODE_STATE.is_train:
self.generator_optimizer.zero_grad()
generated_images = self.generator(latent)
# tracker.add('generated', generated_images[0:1])
tracker.add('generated', generated_images[0:5])
logits = self.discriminator(generated_images)
loss = self.generator_loss(logits)
tracker.add("loss.generator.", loss)
@ -127,9 +128,18 @@ class Configs(MNISTConfigs, TrainValidConfigs):
generator: Module
generator_optimizer: torch.optim.Adam
discriminator_optimizer: torch.optim.Adam
discriminator_loss = DiscriminatorLogitsLoss()
generator_loss = GeneratorLogitsLoss()
generator_loss: GeneratorLogitsLoss
discriminator_loss: DiscriminatorLogitsLoss
batch_step = 'gan_batch_step'
label_smoothing: float = 0.2
@option(Configs.dataset_transforms)
def mnist_transforms():
return transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
@option(Configs.batch_step)
@ -144,6 +154,8 @@ def gan_batch_step(c: Configs):
calculate(Configs.generator, lambda c: Generator().to(c.device))
calculate(Configs.discriminator, lambda c: Discriminator().to(c.device))
calculate(Configs.generator_loss, lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
calculate(Configs.discriminator_loss, lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))
@option(Configs.discriminator_optimizer)
@ -166,8 +178,12 @@ def main():
experiment.configs(conf,
{'generator_optimizer.learning_rate': 2.5e-4,
'generator_optimizer.optimizer': 'Adam',
'generator_optimizer.betas': (0.5, 0.999),
'discriminator_optimizer.learning_rate': 2.5e-4,
'discriminator_optimizer.optimizer': 'Adam'},
'discriminator_optimizer.optimizer': 'Adam',
'discriminator_optimizer.betas': (0.5, 0.999),
'label_smoothing': 0.01
},
'run')
with experiment.start():
conf.run()