Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-08-16 10:51:23 +08:00)

Commit: 🥳 simple gan

This commit adds label smoothing to the simple GAN losses and tunes the training setup (Adam betas, MNIST normalization, generated-image logging).
@@ -6,35 +6,41 @@ import torch.utils.data
 from labml_helpers.module import Module
 
 
+def create_labels(n: int, r1: float, r2: float, device: torch.device = None):
+    return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)
+
+
 class DiscriminatorLogitsLoss(Module):
-    def __init__(self):
+    def __init__(self, smoothing: float = 0.2):
         super().__init__()
         self.loss_true = nn.BCEWithLogitsLoss()
         self.loss_false = nn.BCEWithLogitsLoss()
-        self.register_buffer('labels_true', torch.ones(256, 1, requires_grad=False), False)
-        self.register_buffer('labels_false', torch.zeros(256, 1, requires_grad=False), False)
+        self.smoothing = smoothing
+        self.register_buffer('labels_true', create_labels(256, 1.0 - smoothing, 1.0), False)
+        self.register_buffer('labels_false', create_labels(256, 0.0, smoothing), False)
 
     def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
         if len(logits_true) > len(self.labels_true):
             self.register_buffer("labels_true",
-                                 self.labels_true.new_ones(len(logits_true), 1, requires_grad=False), False)
+                                 create_labels(len(logits_true), 1.0 - self.smoothing, 1.0, logits_true.device), False)
         if len(logits_false) > len(self.labels_false):
             self.register_buffer("labels_false",
-                                 self.labels_false.new_zeros(len(logits_false), 1, requires_grad=False), False)
+                                 create_labels(len(logits_false), 0.0, self.smoothing, logits_false.device), False)
 
         return self.loss_true(logits_true, self.labels_true[:len(logits_true)]), \
                self.loss_false(logits_false, self.labels_false[:len(logits_false)])
 
 
 class GeneratorLogitsLoss(Module):
-    def __init__(self):
+    def __init__(self, smoothing: float = 0.2):
         super().__init__()
         self.loss_true = nn.BCEWithLogitsLoss()
-        self.register_buffer('fake_labels', torch.ones(256, 1, requires_grad=False), False)
+        self.smoothing = smoothing
+        self.register_buffer('fake_labels', create_labels(256, 1.0 - smoothing, 1.0), False)
 
     def __call__(self, logits: torch.Tensor):
         if len(logits) > len(self.fake_labels):
             self.register_buffer("fake_labels",
-                                 self.fake_labels.new_ones(len(logits), 1, requires_grad=False), False)
+                                 create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)
 
         return self.loss_true(logits, self.fake_labels[:len(logits)])
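
Note: the new create_labels draws targets uniformly from a range instead of using hard 0/1 labels, so real images get targets in [1 - smoothing, 1] and generated images get targets in [0, smoothing]. The generator loss reuses the real-label range for generated images, since the generator is trained to make the discriminator output "real". A minimal self-contained sketch of how these smoothed targets feed BCEWithLogitsLoss (only create_labels and the label ranges come from this commit; the random logits are stand-ins for discriminator outputs):

    import torch
    import torch.nn as nn

    def create_labels(n: int, r1: float, r2: float, device: torch.device = None):
        # Targets sampled uniformly from [r1, r2] instead of hard 0/1 labels
        return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)

    smoothing = 0.2
    labels_true = create_labels(8, 1.0 - smoothing, 1.0)   # real images: targets in [0.8, 1.0]
    labels_false = create_labels(8, 0.0, smoothing)        # generated images: targets in [0.0, 0.2]

    loss = nn.BCEWithLogitsLoss()
    logits = torch.randn(8, 1)  # stand-in discriminator logits
    print(loss(logits, labels_true).item(), loss(logits, labels_false).item())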
@@ -4,6 +4,7 @@ import matplotlib.pyplot as plt
 import torch
 import torch.nn as nn
 import torch.utils.data
+from torchvision import transforms
 
 from labml import tracker, monit, experiment
 from labml.configs import option, calculate
@@ -72,7 +73,7 @@ class GANBatchStep(BatchStepProtocol):
         self.discriminator_optimizer = discriminator_optimizer
         tracker.set_scalar("loss.generator.*", True)
         tracker.set_scalar("loss.discriminator.*", True)
-        tracker.set_image("generated", True)
+        tracker.set_image("generated", True, 1 / 100)
 
     def prepare_for_iteration(self):
         if MODE_STATE.is_train:
@@ -92,7 +93,7 @@ class GANBatchStep(BatchStepProtocol):
         if MODE_STATE.is_train:
             self.generator_optimizer.zero_grad()
         generated_images = self.generator(latent)
-        # tracker.add('generated', generated_images[0:1])
+        tracker.add('generated', generated_images[0:5])
         logits = self.discriminator(generated_images)
         loss = self.generator_loss(logits)
         tracker.add("loss.generator.", loss)
@@ -127,9 +128,18 @@ class Configs(MNISTConfigs, TrainValidConfigs):
     generator: Module
     generator_optimizer: torch.optim.Adam
     discriminator_optimizer: torch.optim.Adam
-    discriminator_loss = DiscriminatorLogitsLoss()
-    generator_loss = GeneratorLogitsLoss()
+    generator_loss: GeneratorLogitsLoss
+    discriminator_loss: DiscriminatorLogitsLoss
     batch_step = 'gan_batch_step'
+    label_smoothing: float = 0.2
+
+
+@option(Configs.dataset_transforms)
+def mnist_transforms():
+    return transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.5,), (0.5,))
+    ])
 
 
 @option(Configs.batch_step)
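
Note: the new mnist_transforms option rescales MNIST images from [0, 1] to [-1, 1], the usual pairing with a tanh-output generator (the generator's definition is outside this diff, so that pairing is an assumption). A quick check of what Normalize((0.5,), (0.5,)) computes:

    import torch
    from torchvision import transforms

    norm = transforms.Normalize((0.5,), (0.5,))  # per channel: (x - 0.5) / 0.5
    x = torch.tensor([[[0.0, 0.5, 1.0]]])        # a tiny 1x1x3 "image" in [0, 1]
    print(norm(x))                               # tensor([[[-1., 0., 1.]]])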
@@ -144,6 +154,8 @@ def gan_batch_step(c: Configs):
 
 calculate(Configs.generator, lambda c: Generator().to(c.device))
 calculate(Configs.discriminator, lambda c: Discriminator().to(c.device))
+calculate(Configs.generator_loss, lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
+calculate(Configs.discriminator_loss, lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))
 
 
 @option(Configs.discriminator_optimizer)
@@ -166,8 +178,12 @@ def main():
     experiment.configs(conf,
                        {'generator_optimizer.learning_rate': 2.5e-4,
                         'generator_optimizer.optimizer': 'Adam',
+                        'generator_optimizer.betas': (0.5, 0.999),
                         'discriminator_optimizer.learning_rate': 2.5e-4,
-                        'discriminator_optimizer.optimizer': 'Adam'},
+                        'discriminator_optimizer.optimizer': 'Adam',
+                        'discriminator_optimizer.betas': (0.5, 0.999),
+                        'label_smoothing': 0.01
+                        },
                        'run')
     with experiment.start():
         conf.run()
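
Note: both optimizers now set betas=(0.5, 0.999). Lowering Adam's default beta1 from 0.9 to 0.5 is the standard GAN recipe popularized by DCGAN; the reduced momentum tends to keep the two-player optimization from oscillating. A sketch of the equivalent plain-PyTorch setup (the Linear/Tanh networks are hypothetical stand-ins for the repo's Generator and Discriminator):

    import torch
    import torch.nn as nn

    generator = nn.Sequential(nn.Linear(100, 784), nn.Tanh())  # stand-in generator
    discriminator = nn.Linear(784, 1)                          # stand-in discriminator

    # lr and betas match the values configured in main() above
    generator_optimizer = torch.optim.Adam(generator.parameters(),
                                           lr=2.5e-4, betas=(0.5, 0.999))
    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(),
                                               lr=2.5e-4, betas=(0.5, 0.999))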