diff --git a/labml_nn/gan/mnist.py b/labml_nn/gan/mnist.py index 88b2bd16..20c24597 100644 --- a/labml_nn/gan/mnist.py +++ b/labml_nn/gan/mnist.py @@ -161,14 +161,20 @@ calculate(Configs.discriminator_loss, lambda c: DiscriminatorLogitsLoss(c.label_ @option(Configs.discriminator_optimizer) def _discriminator_optimizer(c: Configs): opt_conf = OptimizerConfigs() + opt_conf.optimizer = 'Adam' opt_conf.parameters = c.discriminator.parameters() + opt_conf.learning_rate = 2.5e-4 + opt_conf.betas = (0.5, 0.999) return opt_conf @option(Configs.generator_optimizer) def _generator_optimizer(c: Configs): opt_conf = OptimizerConfigs() + opt_conf.optimizer = 'Adam' opt_conf.parameters = c.generator.parameters() + opt_conf.learning_rate = 2.5e-4 + opt_conf.betas = (0.5, 0.999) return opt_conf @@ -176,14 +182,7 @@ def main(): conf = Configs() experiment.create(name='mnist_gan', comment='test') experiment.configs(conf, - {'generator_optimizer.learning_rate': 2.5e-4, - 'generator_optimizer.optimizer': 'Adam', - 'generator_optimizer.betas': (0.5, 0.999), - 'discriminator_optimizer.learning_rate': 2.5e-4, - 'discriminator_optimizer.optimizer': 'Adam', - 'discriminator_optimizer.betas': (0.5, 0.999), - 'label_smoothing': 0.01 - }, + {'label_smoothing': 0.01}, 'run') with experiment.start(): conf.run()