🥳 simple gan

This commit is contained in:
Varuna Jayasiri
2020-09-26 20:42:38 +05:30
parent 90815b9c65
commit 12dc9cc4f6
2 changed files with 35 additions and 13 deletions

View File

@ -6,35 +6,41 @@ import torch.utils.data
from labml_helpers.module import Module from labml_helpers.module import Module
def create_labels(n: int, r1: float, r2: float, device: torch.device = None):
return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)
class DiscriminatorLogitsLoss(Module): class DiscriminatorLogitsLoss(Module):
def __init__(self): def __init__(self, smoothing: float = 0.2):
super().__init__() super().__init__()
self.loss_true = nn.BCEWithLogitsLoss() self.loss_true = nn.BCEWithLogitsLoss()
self.loss_false = nn.BCEWithLogitsLoss() self.loss_false = nn.BCEWithLogitsLoss()
self.register_buffer('labels_true', torch.ones(256, 1, requires_grad=False), False) self.smoothing = smoothing
self.register_buffer('labels_false', torch.zeros(256, 1, requires_grad=False), False) self.register_buffer('labels_true', create_labels(256, 1.0 - smoothing, 1.0), False)
self.register_buffer('labels_false', create_labels(256, 0.0, smoothing), False)
def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor): def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
if len(logits_true) > len(self.labels_true): if len(logits_true) > len(self.labels_true):
self.register_buffer("labels_true", self.register_buffer("labels_true",
self.labels_true.new_ones(len(logits_true), 1, requires_grad=False), False) create_labels(len(logits_true), 1.0 - self.smoothing, 1.0, logits_true.device), False)
if len(logits_false) > len(self.labels_false): if len(logits_false) > len(self.labels_false):
self.register_buffer("labels_false", self.register_buffer("labels_false",
self.labels_false.new_zeros(len(logits_false), 1, requires_grad=False), False) create_labels(len(logits_false), 0.0, self.smoothing, logits_false.device), False)
return self.loss_true(logits_true, self.labels_true[:len(logits_true)]), \ return self.loss_true(logits_true, self.labels_true[:len(logits_true)]), \
self.loss_false(logits_false, self.labels_false[:len(logits_false)]) self.loss_false(logits_false, self.labels_false[:len(logits_false)])
class GeneratorLogitsLoss(Module): class GeneratorLogitsLoss(Module):
def __init__(self): def __init__(self, smoothing: float = 0.2):
super().__init__() super().__init__()
self.loss_true = nn.BCEWithLogitsLoss() self.loss_true = nn.BCEWithLogitsLoss()
self.register_buffer('fake_labels', torch.ones(256, 1, requires_grad=False), False) self.smoothing = smoothing
self.register_buffer('fake_labels', create_labels(256, 1.0 - smoothing, 1.0), False)
def __call__(self, logits: torch.Tensor): def __call__(self, logits: torch.Tensor):
if len(logits) > len(self.fake_labels): if len(logits) > len(self.fake_labels):
self.register_buffer("fake_labels", self.register_buffer("fake_labels",
self.fake_labels.new_ones(len(logits), 1, requires_grad=False), False) create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)
return self.loss_true(logits, self.fake_labels[:len(logits)]) return self.loss_true(logits, self.fake_labels[:len(logits)])

View File

@ -4,6 +4,7 @@ import matplotlib.pyplot as plt
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.utils.data import torch.utils.data
from torchvision import transforms
from labml import tracker, monit, experiment from labml import tracker, monit, experiment
from labml.configs import option, calculate from labml.configs import option, calculate
@ -72,7 +73,7 @@ class GANBatchStep(BatchStepProtocol):
self.discriminator_optimizer = discriminator_optimizer self.discriminator_optimizer = discriminator_optimizer
tracker.set_scalar("loss.generator.*", True) tracker.set_scalar("loss.generator.*", True)
tracker.set_scalar("loss.discriminator.*", True) tracker.set_scalar("loss.discriminator.*", True)
tracker.set_image("generated", True) tracker.set_image("generated", True, 1 / 100)
def prepare_for_iteration(self): def prepare_for_iteration(self):
if MODE_STATE.is_train: if MODE_STATE.is_train:
@ -92,7 +93,7 @@ class GANBatchStep(BatchStepProtocol):
if MODE_STATE.is_train: if MODE_STATE.is_train:
self.generator_optimizer.zero_grad() self.generator_optimizer.zero_grad()
generated_images = self.generator(latent) generated_images = self.generator(latent)
# tracker.add('generated', generated_images[0:1]) tracker.add('generated', generated_images[0:5])
logits = self.discriminator(generated_images) logits = self.discriminator(generated_images)
loss = self.generator_loss(logits) loss = self.generator_loss(logits)
tracker.add("loss.generator.", loss) tracker.add("loss.generator.", loss)
@ -127,9 +128,18 @@ class Configs(MNISTConfigs, TrainValidConfigs):
generator: Module generator: Module
generator_optimizer: torch.optim.Adam generator_optimizer: torch.optim.Adam
discriminator_optimizer: torch.optim.Adam discriminator_optimizer: torch.optim.Adam
discriminator_loss = DiscriminatorLogitsLoss() generator_loss: GeneratorLogitsLoss
generator_loss = GeneratorLogitsLoss() discriminator_loss: DiscriminatorLogitsLoss
batch_step = 'gan_batch_step' batch_step = 'gan_batch_step'
label_smoothing: float = 0.2
@option(Configs.dataset_transforms)
def mnist_transforms():
return transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
@option(Configs.batch_step) @option(Configs.batch_step)
@ -144,6 +154,8 @@ def gan_batch_step(c: Configs):
calculate(Configs.generator, lambda c: Generator().to(c.device)) calculate(Configs.generator, lambda c: Generator().to(c.device))
calculate(Configs.discriminator, lambda c: Discriminator().to(c.device)) calculate(Configs.discriminator, lambda c: Discriminator().to(c.device))
calculate(Configs.generator_loss, lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
calculate(Configs.discriminator_loss, lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))
@option(Configs.discriminator_optimizer) @option(Configs.discriminator_optimizer)
@ -166,8 +178,12 @@ def main():
experiment.configs(conf, experiment.configs(conf,
{'generator_optimizer.learning_rate': 2.5e-4, {'generator_optimizer.learning_rate': 2.5e-4,
'generator_optimizer.optimizer': 'Adam', 'generator_optimizer.optimizer': 'Adam',
'generator_optimizer.betas': (0.5, 0.999),
'discriminator_optimizer.learning_rate': 2.5e-4, 'discriminator_optimizer.learning_rate': 2.5e-4,
'discriminator_optimizer.optimizer': 'Adam'}, 'discriminator_optimizer.optimizer': 'Adam',
'discriminator_optimizer.betas': (0.5, 0.999),
'label_smoothing': 0.01
},
'run') 'run')
with experiment.start(): with experiment.start():
conf.run() conf.run()