gan step

Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git

This commit restructures the MNIST GAN experiment: the standalone GANBatchStep
class (built on BatchStepProtocol and the global MODE_STATE) is removed, and its
training logic moves into Configs.init() and Configs.step(), which use the mode
tracked on the configs object and a BatchIndex argument instead.
@@ -2,7 +2,7 @@
 # Generative Adversarial Networks experiment with MNIST
 """
 
-from typing import Optional
+from typing import Any
 
 import torch
 import torch.nn as nn
@@ -16,7 +16,7 @@ from labml_helpers.datasets.mnist import MNISTConfigs
 from labml_helpers.device import DeviceConfigs
 from labml_helpers.module import Module
 from labml_helpers.optimizer import OptimizerConfigs
-from labml_helpers.train_valid import MODE_STATE, BatchStepProtocol, TrainValidConfigs, hook_model_outputs
+from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
 from labml_nn.gan import DiscriminatorLogitsLoss, GeneratorLogitsLoss
 
 
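A note on the import change above: MODE_STATE and BatchStepProtocol are gone, and the
rewritten step() below instead receives a BatchIndex and gates parameter logging with
batch_idx.is_interval(...). As intuition only, a simplified stand-in for that check
(a hypothetical sketch, not labml_helpers' actual implementation) might look like:

    class BatchIndexSketch:
        """Hypothetical, simplified stand-in for labml_helpers.train_valid.BatchIndex."""

        def __init__(self, idx: int, total: int):
            self.idx = idx      # current batch number within the epoch
            self.total = total  # total batches in the epoch

        def is_interval(self, interval: int) -> bool:
            # True on every `interval`-th batch; with the default
            # log_params_updates = 2 ** 32 below, this practically never fires.
            return (self.idx + 1) % interval == 0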
@@ -79,89 +79,6 @@ class Discriminator(Module):
         return self.layers(x.view(x.shape[0], -1))
 
 
-class GANBatchStep(BatchStepProtocol):
-    def __init__(self, *,
-                 discriminator: Module,
-                 generator: Module,
-                 discriminator_optimizer: Optional[torch.optim.Adam],
-                 generator_optimizer: Optional[torch.optim.Adam],
-                 discriminator_loss: DiscriminatorLogitsLoss,
-                 generator_loss: GeneratorLogitsLoss,
-                 discriminator_k: int):
-
-        self.discriminator_k = discriminator_k
-        self.generator = generator
-        self.discriminator = discriminator
-        self.generator_loss = generator_loss
-        self.discriminator_loss = discriminator_loss
-        self.generator_optimizer = generator_optimizer
-        self.discriminator_optimizer = discriminator_optimizer
-
-        hook_model_outputs(self.generator, 'generator')
-        hook_model_outputs(self.discriminator, 'discriminator')
-        tracker.set_scalar("loss.generator.*", True)
-        tracker.set_scalar("loss.discriminator.*", True)
-        tracker.set_image("generated", True, 1 / 100)
-
-    def prepare_for_iteration(self):
-        if MODE_STATE.is_train:
-            self.generator.train()
-            self.discriminator.train()
-        else:
-            self.generator.eval()
-            self.discriminator.eval()
-
-    def process(self, batch: any, state: any):
-        device = self.discriminator.device
-        data, target = batch
-        data, target = data.to(device), target.to(device)
-
-        # Train the discriminator
-        with monit.section("discriminator"):
-            for _ in range(self.discriminator_k):
-                latent = torch.randn(data.shape[0], 100, device=device)
-                if MODE_STATE.is_train:
-                    self.discriminator_optimizer.zero_grad()
-                logits_true = self.discriminator(data)
-                logits_false = self.discriminator(self.generator(latent).detach())
-                loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
-                loss = loss_true + loss_false
-
-                # Log stuff
-                tracker.add("loss.discriminator.true.", loss_true)
-                tracker.add("loss.discriminator.false.", loss_false)
-                tracker.add("loss.discriminator.", loss)
-
-                # Train
-                if MODE_STATE.is_train:
-                    loss.backward()
-                    if MODE_STATE.is_log_parameters:
-                        pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
-                    self.discriminator_optimizer.step()
-
-        # Train the generator
-        with monit.section("generator"):
-            latent = torch.randn(data.shape[0], 100, device=device)
-            if MODE_STATE.is_train:
-                self.generator_optimizer.zero_grad()
-            generated_images = self.generator(latent)
-            logits = self.discriminator(generated_images)
-            loss = self.generator_loss(logits)
-
-            # Log stuff
-            tracker.add('generated', generated_images[0:5])
-            tracker.add("loss.generator.", loss)
-
-            # Train
-            if MODE_STATE.is_train:
-                loss.backward()
-                if MODE_STATE.is_log_parameters:
-                    pytorch_utils.store_model_indicators(self.generator, 'generator')
-                self.generator_optimizer.step()
-
-        return {'samples': len(data)}, None
-
-
 class Configs(MNISTConfigs, TrainValidConfigs):
     device: torch.device = DeviceConfigs()
     epochs: int = 10
@@ -173,10 +90,77 @@ class Configs(MNISTConfigs, TrainValidConfigs):
     discriminator_optimizer: torch.optim.Adam
     generator_loss: GeneratorLogitsLoss
     discriminator_loss: DiscriminatorLogitsLoss
-    batch_step = 'gan_batch_step'
     label_smoothing: float = 0.2
     discriminator_k: int = 1
 
+    log_params_updates: int = 2 ** 32  # 0 if not
+
+    def init(self):
+        self.state_modules = []
+        self.generator = Generator().to(self.device)
+        self.discriminator = Discriminator().to(self.device)
+        self.generator_loss = GeneratorLogitsLoss(self.label_smoothing).to(self.device)
+        self.discriminator_loss = DiscriminatorLogitsLoss(self.label_smoothing).to(self.device)
+
+        hook_model_outputs(self.mode, self.generator, 'generator')
+        hook_model_outputs(self.mode, self.discriminator, 'discriminator')
+        tracker.set_scalar("loss.generator.*", True)
+        tracker.set_scalar("loss.discriminator.*", True)
+        tracker.set_image("generated", True, 1 / 100)
+
+    def step(self, batch: Any, batch_idx: BatchIndex):
+        self.generator.train(self.mode.is_train)
+        self.discriminator.train(self.mode.is_train)
+
+        data, target = batch[0].to(self.device), batch[1].to(self.device)
+
+        # Increment step in training mode
+        if self.mode.is_train:
+            tracker.add_global_step(len(data))
+
+        # Train the discriminator
+        with monit.section("discriminator"):
+            for _ in range(self.discriminator_k):
+                latent = torch.randn(data.shape[0], 100, device=self.device)
+                logits_true = self.discriminator(data)
+                logits_false = self.discriminator(self.generator(latent).detach())
+                loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
+                loss = loss_true + loss_false
+
+                # Log stuff
+                tracker.add("loss.discriminator.true.", loss_true)
+                tracker.add("loss.discriminator.false.", loss_false)
+                tracker.add("loss.discriminator.", loss)
+
+                # Train
+                if self.mode.is_train:
+                    self.discriminator_optimizer.zero_grad()
+                    loss.backward()
+                    if batch_idx.is_interval(self.log_params_updates):
+                        pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
+                    self.discriminator_optimizer.step()
+
+        # Train the generator
+        with monit.section("generator"):
+            latent = torch.randn(data.shape[0], 100, device=self.device)
+            generated_images = self.generator(latent)
+            logits = self.discriminator(generated_images)
+            loss = self.generator_loss(logits)
+
+            # Log stuff
+            tracker.add('generated', generated_images[0:5])
+            tracker.add("loss.generator.", loss)
+
+            # Train
+            if self.mode.is_train:
+                self.generator_optimizer.zero_grad()
+                loss.backward()
+                if batch_idx.is_interval(self.log_params_updates):
+                    pytorch_utils.store_model_indicators(self.generator, 'generator')
+                self.generator_optimizer.step()
+
+        tracker.save()
+
 
 @option(Configs.dataset_transforms)
 def mnist_transforms():
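For orientation, here is a minimal, self-contained sketch of the alternating update
that the new Configs.step() performs, with the labml tracking stripped out and plain
BCE-with-logits standing in for the labml_nn losses. The layer sizes and learning rate
are illustrative assumptions, not the experiment's actual models:

    import torch
    import torch.nn as nn

    # Hypothetical stand-ins; the experiment's Generator/Discriminator are MLPs
    # over flattened 28x28 MNIST images.
    generator = nn.Sequential(nn.Linear(100, 256), nn.LeakyReLU(0.2),
                              nn.Linear(256, 784), nn.Tanh())
    discriminator = nn.Sequential(nn.Linear(784, 256), nn.LeakyReLU(0.2),
                                  nn.Linear(256, 1))

    g_opt = torch.optim.Adam(generator.parameters(), lr=2.5e-4)
    d_opt = torch.optim.Adam(discriminator.parameters(), lr=2.5e-4)
    bce = nn.BCEWithLogitsLoss()
    discriminator_k = 1  # discriminator updates per generator update

    def train_step(data: torch.Tensor):
        batch_size = data.shape[0]
        # Discriminator: k updates; generated samples are detached so these
        # backward passes never touch the generator's parameters.
        for _ in range(discriminator_k):
            latent = torch.randn(batch_size, 100)
            logits_true = discriminator(data)
            logits_false = discriminator(generator(latent).detach())
            d_loss = (bce(logits_true, torch.ones_like(logits_true)) +
                      bce(logits_false, torch.zeros_like(logits_false)))
            d_opt.zero_grad()
            d_loss.backward()
            d_opt.step()

        # Generator: non-saturating loss — push D's logits on fakes toward "real".
        latent = torch.randn(batch_size, 100)
        logits = discriminator(generator(latent))
        g_loss = bce(logits, torch.ones_like(logits))
        g_opt.zero_grad()
        g_loss.backward()
        g_opt.step()
        return d_loss.item(), g_loss.item()

    # e.g. with a random batch of flattened MNIST-sized inputs:
    d_l, g_l = train_step(torch.randn(32, 784))

The one structural point the sketch preserves is the detach() during the discriminator's
k updates; it is what keeps the two optimizers' updates cleanly separated.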
@@ -186,23 +170,6 @@ def mnist_transforms():
     ])
 
 
-@option(Configs.batch_step)
-def gan_batch_step(c: Configs):
-    return GANBatchStep(discriminator=c.discriminator,
-                        generator=c.generator,
-                        discriminator_optimizer=c.discriminator_optimizer,
-                        generator_optimizer=c.generator_optimizer,
-                        discriminator_loss=c.discriminator_loss,
-                        generator_loss=c.generator_loss,
-                        discriminator_k=c.discriminator_k)
-
-
-calculate(Configs.generator, 'mlp', lambda c: Generator().to(c.device))
-calculate(Configs.discriminator, 'mlp', lambda c: Discriminator().to(c.device))
-calculate(Configs.generator_loss, lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
-calculate(Configs.discriminator_loss, lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))
-
-
 @option(Configs.discriminator_optimizer)
 def _discriminator_optimizer(c: Configs):
     opt_conf = OptimizerConfigs()
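The label_smoothing: float = 0.2 setting is passed to both losses in init(). Assuming
DiscriminatorLogitsLoss and GeneratorLogitsLoss follow the standard smoothed-targets
logits formulation (an assumption about labml_nn.gan; this diff does not show their
definitions), they behave roughly like:

    import torch
    import torch.nn as nn

    class SmoothedDiscriminatorLoss(nn.Module):
        # Rough sketch: real targets pulled down to 1 - s, fake targets pushed up to s.
        def __init__(self, smoothing: float = 0.2):
            super().__init__()
            self.smoothing = smoothing
            self.bce = nn.BCEWithLogitsLoss()

        def forward(self, logits_true, logits_false):
            loss_true = self.bce(logits_true,
                                 torch.full_like(logits_true, 1.0 - self.smoothing))
            loss_false = self.bce(logits_false,
                                  torch.full_like(logits_false, self.smoothing))
            return loss_true, loss_false  # summed by the caller, as in step()

    class SmoothedGeneratorLoss(nn.Module):
        # The generator wants D to call its samples real (target 1 - s).
        def __init__(self, smoothing: float = 0.2):
            super().__init__()
            self.smoothing = smoothing
            self.bce = nn.BCEWithLogitsLoss()

        def forward(self, logits):
            return self.bce(logits, torch.full_like(logits, 1.0 - self.smoothing))

Returning the (loss_true, loss_false) pair matches how step() logs the two components
separately before summing them into the discriminator's training loss.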