Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git, synced 2025-08-16 19:01:12 +08:00
🥳 simple gan
@@ -6,35 +6,41 @@ import torch.utils.data

 from labml_helpers.module import Module


+def create_labels(n: int, r1: float, r2: float, device: torch.device = None):
+    return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)
+
+
 class DiscriminatorLogitsLoss(Module):
-    def __init__(self):
+    def __init__(self, smoothing: float = 0.2):
         super().__init__()
         self.loss_true = nn.BCEWithLogitsLoss()
         self.loss_false = nn.BCEWithLogitsLoss()
-        self.register_buffer('labels_true', torch.ones(256, 1, requires_grad=False), False)
-        self.register_buffer('labels_false', torch.zeros(256, 1, requires_grad=False), False)
+        self.smoothing = smoothing
+        self.register_buffer('labels_true', create_labels(256, 1.0 - smoothing, 1.0), False)
+        self.register_buffer('labels_false', create_labels(256, 0.0, smoothing), False)

     def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
         if len(logits_true) > len(self.labels_true):
             self.register_buffer("labels_true",
-                                 self.labels_true.new_ones(len(logits_true), 1, requires_grad=False), False)
+                                 create_labels(len(logits_true), 1.0 - self.smoothing, 1.0, logits_true.device), False)
         if len(logits_false) > len(self.labels_false):
             self.register_buffer("labels_false",
-                                 self.labels_false.new_zeros(len(logits_false), 1, requires_grad=False), False)
+                                 create_labels(len(logits_false), 0.0, self.smoothing, logits_false.device), False)

         return self.loss_true(logits_true, self.labels_true[:len(logits_true)]), \
                self.loss_false(logits_false, self.labels_false[:len(logits_false)])


 class GeneratorLogitsLoss(Module):
-    def __init__(self):
+    def __init__(self, smoothing: float = 0.2):
         super().__init__()
         self.loss_true = nn.BCEWithLogitsLoss()
-        self.register_buffer('fake_labels', torch.ones(256, 1, requires_grad=False), False)
+        self.smoothing = smoothing
+        self.register_buffer('fake_labels', create_labels(256, 1.0 - smoothing, 1.0), False)

     def __call__(self, logits: torch.Tensor):
         if len(logits) > len(self.fake_labels):
             self.register_buffer("fake_labels",
-                                 self.fake_labels.new_ones(len(logits), 1, requires_grad=False), False)
+                                 create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)

         return self.loss_true(logits, self.fake_labels[:len(logits)])
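For reference, here is a minimal self-contained sketch of what the smoothed losses above compute, in plain PyTorch without labml_helpers (the batch size of 8 and the random logits are illustrative only):

import torch
import torch.nn as nn

def create_labels(n: int, r1: float, r2: float, device: torch.device = None):
    # Targets sampled uniformly from [r1, r2]; label tensors never need gradients.
    return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)

smoothing = 0.2
bce = nn.BCEWithLogitsLoss()
logits_true = torch.randn(8, 1)   # discriminator outputs on real samples
logits_false = torch.randn(8, 1)  # discriminator outputs on generated samples

# Discriminator: real targets drawn from [0.8, 1.0], fake targets from [0.0, 0.2]
loss_true = bce(logits_true, create_labels(8, 1.0 - smoothing, 1.0))
loss_false = bce(logits_false, create_labels(8, 0.0, smoothing))

# Generator: wants generated samples classified as real, so targets in [0.8, 1.0]
generator_loss = bce(logits_false, create_labels(8, 1.0 - smoothing, 1.0))

The 256-row label buffers in the modules are just a cache: they are regrown on the incoming logits' device whenever a batch exceeds the cached size, and sliced to the batch length otherwise.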
@@ -4,6 +4,7 @@ import matplotlib.pyplot as plt
 import torch
 import torch.nn as nn
 import torch.utils.data
+from torchvision import transforms

 from labml import tracker, monit, experiment
 from labml.configs import option, calculate
@@ -72,7 +73,7 @@ class GANBatchStep(BatchStepProtocol):
         self.discriminator_optimizer = discriminator_optimizer
         tracker.set_scalar("loss.generator.*", True)
         tracker.set_scalar("loss.discriminator.*", True)
-        tracker.set_image("generated", True)
+        tracker.set_image("generated", True, 1 / 100)

     def prepare_for_iteration(self):
         if MODE_STATE.is_train:
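The new third argument to `tracker.set_image` presumably throttles persistence so that only about one in a hundred logged image batches is actually stored. A hypothetical stand-in for that behaviour (`maybe_store` is illustrative, not labml's API):

import random

def maybe_store(image, density: float = 1 / 100):
    # Keep roughly `density` of all logged images rather than every one,
    # so long training runs do not accumulate thousands of stored samples.
    if random.random() < density:
        pass  # persist `image` here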
@@ -92,7 +93,7 @@ class GANBatchStep(BatchStepProtocol):
         if MODE_STATE.is_train:
             self.generator_optimizer.zero_grad()
             generated_images = self.generator(latent)
-            # tracker.add('generated', generated_images[0:1])
+            tracker.add('generated', generated_images[0:5])
             logits = self.discriminator(generated_images)
             loss = self.generator_loss(logits)
             tracker.add("loss.generator.", loss)
@@ -127,9 +128,18 @@ class Configs(MNISTConfigs, TrainValidConfigs):
     generator: Module
     generator_optimizer: torch.optim.Adam
     discriminator_optimizer: torch.optim.Adam
-    discriminator_loss = DiscriminatorLogitsLoss()
-    generator_loss = GeneratorLogitsLoss()
+    generator_loss: GeneratorLogitsLoss
+    discriminator_loss: DiscriminatorLogitsLoss
     batch_step = 'gan_batch_step'
+    label_smoothing: float = 0.2
+
+
+@option(Configs.dataset_transforms)
+def mnist_transforms():
+    return transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.5,), (0.5,))
+    ])


 @option(Configs.batch_step)
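The new `mnist_transforms` option rescales the inputs: `ToTensor` produces values in [0, 1], and `Normalize((0.5,), (0.5,))` computes (x - 0.5) / 0.5, mapping them to [-1, 1]. A quick check:

import torch
from torchvision import transforms

normalize = transforms.Normalize((0.5,), (0.5,))
x = torch.rand(1, 28, 28)  # stand-in for a ToTensor'd MNIST image in [0, 1]
y = normalize(x)
print(y.min().item(), y.max().item())  # approximately -1.0 and 1.0

A [-1, 1] input range is the usual pairing with a tanh output layer on the generator; whether this generator ends in tanh is an assumption here.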
@@ -144,6 +154,8 @@ def gan_batch_step(c: Configs):

 calculate(Configs.generator, lambda c: Generator().to(c.device))
 calculate(Configs.discriminator, lambda c: Discriminator().to(c.device))
+calculate(Configs.generator_loss, lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
+calculate(Configs.discriminator_loss, lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))


 @option(Configs.discriminator_optimizer)
@@ -166,8 +178,12 @@ def main():
     experiment.configs(conf,
                        {'generator_optimizer.learning_rate': 2.5e-4,
                         'generator_optimizer.optimizer': 'Adam',
+                        'generator_optimizer.betas': (0.5, 0.999),
                         'discriminator_optimizer.learning_rate': 2.5e-4,
-                        'discriminator_optimizer.optimizer': 'Adam'},
+                        'discriminator_optimizer.optimizer': 'Adam',
+                        'discriminator_optimizer.betas': (0.5, 0.999),
+                        'label_smoothing': 0.01
+                        },
                        'run')
     with experiment.start():
         conf.run()
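The `betas` overrides follow the common GAN recipe from DCGAN (Radford et al., 2016): lowering Adam's beta1 from its default 0.9 to 0.5 cuts momentum, which tends to stabilise training on the rapidly shifting adversarial objective. The equivalent plain-PyTorch construction:

import torch

model = torch.nn.Linear(100, 1)  # placeholder for the generator or discriminator
optimizer = torch.optim.Adam(model.parameters(), lr=2.5e-4, betas=(0.5, 0.999))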