Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git, synced 2025-11-03 05:46:16 +08:00.

Commit: tracking hooks
@@ -6,13 +6,14 @@ import torch.nn as nn
 import torch.utils.data
 from torchvision import transforms
 
+import labml.utils.pytorch as pytorch_utils
 from labml import tracker, monit, experiment
 from labml.configs import option, calculate
 from labml_helpers.datasets.mnist import MNISTConfigs
 from labml_helpers.device import DeviceConfigs
 from labml_helpers.module import Module
 from labml_helpers.optimizer import OptimizerConfigs
-from labml_helpers.train_valid import MODE_STATE, BatchStepProtocol, TrainValidConfigs
+from labml_helpers.train_valid import MODE_STATE, BatchStepProtocol, TrainValidConfigs, hook_model_outputs, Mode
 from labml_nn.gan import DiscriminatorLogitsLoss, GeneratorLogitsLoss
 
 plt.rcParams['image.interpolation'] = 'nearest'
@@ -71,6 +72,9 @@ class GANBatchStep(BatchStepProtocol):
         self.discriminator_loss = discriminator_loss
         self.generator_optimizer = generator_optimizer
         self.discriminator_optimizer = discriminator_optimizer
+
+        hook_model_outputs(self.generator, 'generator')
+        hook_model_outputs(self.discriminator, 'discriminator')
         tracker.set_scalar("loss.generator.*", True)
         tracker.set_scalar("loss.discriminator.*", True)
         tracker.set_image("generated", True, 1 / 100)
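The two `hook_model_outputs` calls attach forward hooks to the generator and the discriminator so their module outputs get tracked on every step. The labml_helpers implementation is not part of this diff; the following is a minimal sketch of the technique in plain PyTorch, where `hook_outputs` and `log_stats` are hypothetical stand-ins for `hook_model_outputs` and labml's tracker:

# Sketch of output-tracking hooks in plain PyTorch. This illustrates the
# technique only; it is not the labml_helpers implementation.
import torch
import torch.nn as nn


def log_stats(name: str, value: torch.Tensor):
    # Hypothetical stand-in for tracker.add(...): report output statistics.
    print(f"{name}: mean={value.mean().item():.4f} std={value.std().item():.4f}")


def hook_outputs(model: nn.Module, model_name: str):
    # Register a forward hook on every submodule so each forward pass
    # logs that module's output under "<model_name>.<submodule>".
    for name, module in model.named_modules():
        if name == '':  # skip the root module itself
            continue

        def make_hook(tag: str):
            def hook(module, inputs, output):
                if isinstance(output, torch.Tensor):
                    log_stats(tag, output.detach())
            return hook

        module.register_forward_hook(make_hook(f"{model_name}.{name}"))


generator = nn.Sequential(nn.Linear(100, 256), nn.LeakyReLU(), nn.Linear(256, 784))
hook_outputs(generator, 'generator')
_ = generator(torch.randn(4, 100))  # hooks fire on this forward pass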
@@ -99,6 +103,8 @@ class GANBatchStep(BatchStepProtocol):
             tracker.add("loss.generator.", loss)
             if MODE_STATE.is_train:
                 loss.backward()
+                if MODE_STATE.is_log_parameters:
+                    pytorch_utils.store_model_indicators(self.generator, 'generator')
                 self.generator_optimizer.step()
 
         with monit.section("discriminator"):
@@ -114,6 +120,8 @@ class GANBatchStep(BatchStepProtocol):
             tracker.add("loss.discriminator.", loss)
             if MODE_STATE.is_train:
                 loss.backward()
+                if MODE_STATE.is_log_parameters:
+                    pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
                 self.discriminator_optimizer.step()
 
         return {'samples': len(data)}, None
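Both training branches now store model indicators after `loss.backward()` but before `optimizer.step()`, gated on `MODE_STATE.is_log_parameters` so the extra work only runs on logging steps. What `store_model_indicators` records is not visible in this diff; a plausible sketch, assuming it walks the named parameters and logs parameter and gradient statistics, is:

# Assumed behavior of store_model_indicators, sketched in plain PyTorch.
import torch
import torch.nn as nn


def store_model_indicators(model: nn.Module, model_name: str):
    # Record per-parameter statistics; labml would store histograms via
    # tracker.add, here L2 norms stand in for illustration.
    for name, param in model.named_parameters():
        print(f"{model_name}.{name}.param: {param.detach().norm().item():.4f}")
        if param.grad is not None:
            print(f"{model_name}.{name}.grad:  {param.grad.norm().item():.4f}")


# Usage mirrors the diff: call after loss.backward() and before
# optimizer.step(), and only when parameter logging is enabled.
model = nn.Linear(10, 1)
loss = model(torch.randn(8, 10)).pow(2).mean()
loss.backward()
store_model_indicators(model, 'discriminator')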
@@ -163,8 +171,8 @@ def _discriminator_optimizer(c: Configs):
     opt_conf = OptimizerConfigs()
     opt_conf.optimizer = 'Adam'
     opt_conf.parameters = c.discriminator.parameters()
-    opt_conf.learning_rate = 2.5e-4
-    opt_conf.betas = (0.5, 0.999)
+    opt_conf.learning_rate = 2.5e-5
+    # opt_conf.betas = (0.5, 0.999)
     return opt_conf
 
 
@@ -173,8 +181,8 @@ def _generator_optimizer(c: Configs):
     opt_conf = OptimizerConfigs()
     opt_conf.optimizer = 'Adam'
     opt_conf.parameters = c.generator.parameters()
-    opt_conf.learning_rate = 2.5e-4
-    opt_conf.betas = (0.5, 0.999)
+    opt_conf.learning_rate = 2.5e-5
+    # opt_conf.betas = (0.5, 0.999)
     return opt_conf
 
 
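The last two hunks make the same change to both optimizer configs: the learning rate drops tenfold from 2.5e-4 to 2.5e-5, and commenting out the `betas` line means Adam reverts to its default momentum terms. Assuming `OptimizerConfigs` maps straight onto `torch.optim.Adam`, the new discriminator setup is equivalent to this sketch:

# Plain-PyTorch equivalent of the new optimizer configuration (a sketch,
# assuming OptimizerConfigs forwards its fields to torch.optim.Adam).
import torch
import torch.nn as nn

discriminator = nn.Sequential(nn.Linear(784, 256), nn.LeakyReLU(), nn.Linear(256, 1))

# lr drops from 2.5e-4 to 2.5e-5; with the betas line commented out, Adam
# falls back to its defaults (0.9, 0.999) instead of (0.5, 0.999).
optimizer = torch.optim.Adam(discriminator.parameters(), lr=2.5e-5)

Note that (0.5, 0.999) is the DCGAN-style momentum setting commonly used for GAN training, so commenting it out is a behavioral change back to Adam's defaults, not just a cleanup.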