mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-02 21:40:15 +08:00
tracking hooks
This commit is contained in:
@ -6,13 +6,14 @@ import torch.nn as nn
|
||||
import torch.utils.data
|
||||
from torchvision import transforms
|
||||
|
||||
import labml.utils.pytorch as pytorch_utils
|
||||
from labml import tracker, monit, experiment
|
||||
from labml.configs import option, calculate
|
||||
from labml_helpers.datasets.mnist import MNISTConfigs
|
||||
from labml_helpers.device import DeviceConfigs
|
||||
from labml_helpers.module import Module
|
||||
from labml_helpers.optimizer import OptimizerConfigs
|
||||
from labml_helpers.train_valid import MODE_STATE, BatchStepProtocol, TrainValidConfigs
|
||||
from labml_helpers.train_valid import MODE_STATE, BatchStepProtocol, TrainValidConfigs, hook_model_outputs, Mode
|
||||
from labml_nn.gan import DiscriminatorLogitsLoss, GeneratorLogitsLoss
|
||||
|
||||
plt.rcParams['image.interpolation'] = 'nearest'
|
||||
@ -71,6 +72,9 @@ class GANBatchStep(BatchStepProtocol):
|
||||
self.discriminator_loss = discriminator_loss
|
||||
self.generator_optimizer = generator_optimizer
|
||||
self.discriminator_optimizer = discriminator_optimizer
|
||||
|
||||
hook_model_outputs(self.generator, 'generator')
|
||||
hook_model_outputs(self.discriminator, 'discriminator')
|
||||
tracker.set_scalar("loss.generator.*", True)
|
||||
tracker.set_scalar("loss.discriminator.*", True)
|
||||
tracker.set_image("generated", True, 1 / 100)
|
||||
@ -99,6 +103,8 @@ class GANBatchStep(BatchStepProtocol):
|
||||
tracker.add("loss.generator.", loss)
|
||||
if MODE_STATE.is_train:
|
||||
loss.backward()
|
||||
if MODE_STATE.is_log_parameters:
|
||||
pytorch_utils.store_model_indicators(self.generator, 'generator')
|
||||
self.generator_optimizer.step()
|
||||
|
||||
with monit.section("discriminator"):
|
||||
@ -114,6 +120,8 @@ class GANBatchStep(BatchStepProtocol):
|
||||
tracker.add("loss.discriminator.", loss)
|
||||
if MODE_STATE.is_train:
|
||||
loss.backward()
|
||||
if MODE_STATE.is_log_parameters:
|
||||
pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
|
||||
self.discriminator_optimizer.step()
|
||||
|
||||
return {'samples': len(data)}, None
|
||||
@ -163,8 +171,8 @@ def _discriminator_optimizer(c: Configs):
|
||||
opt_conf = OptimizerConfigs()
|
||||
opt_conf.optimizer = 'Adam'
|
||||
opt_conf.parameters = c.discriminator.parameters()
|
||||
opt_conf.learning_rate = 2.5e-4
|
||||
opt_conf.betas = (0.5, 0.999)
|
||||
opt_conf.learning_rate = 2.5e-5
|
||||
# opt_conf.betas = (0.5, 0.999)
|
||||
return opt_conf
|
||||
|
||||
|
||||
@ -173,8 +181,8 @@ def _generator_optimizer(c: Configs):
|
||||
opt_conf = OptimizerConfigs()
|
||||
opt_conf.optimizer = 'Adam'
|
||||
opt_conf.parameters = c.generator.parameters()
|
||||
opt_conf.learning_rate = 2.5e-4
|
||||
opt_conf.betas = (0.5, 0.999)
|
||||
opt_conf.learning_rate = 2.5e-5
|
||||
# opt_conf.betas = (0.5, 0.999)
|
||||
return opt_conf
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user