wassersteing gan experiment

This commit is contained in:
Varuna Jayasiri
2021-05-06 15:39:45 +05:30
parent bcb673cf21
commit da040fa94e
6 changed files with 845 additions and 6 deletions

View File

@ -17,7 +17,7 @@ import torch.nn as nn
from labml import experiment from labml import experiment
from labml.configs import calculate from labml.configs import calculate
from labml_helpers.module import Module from labml_helpers.module import Module
from labml_nn.gan.simple_mnist_experiment import Configs from labml_nn.gan.original.experiment import Configs
class Generator(Module): class Generator(Module):
@ -107,7 +107,7 @@ calculate(Configs.discriminator, 'cnn', lambda c: Discriminator().to(c.device))
def main(): def main():
conf = Configs() conf = Configs()
experiment.create(name='mnist_dcgan', comment='test') experiment.create(name='mnist_dcgan')
experiment.configs(conf, experiment.configs(conf,
{'discriminator': 'cnn', {'discriminator': 'cnn',
'generator': 'cnn', 'generator': 'cnn',

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -85,6 +85,7 @@ class Discriminator(Module):
class Configs(MNISTConfigs, TrainValidConfigs): class Configs(MNISTConfigs, TrainValidConfigs):
device: torch.device = DeviceConfigs() device: torch.device = DeviceConfigs()
dataset_transforms = 'mnist_gan_transforms'
epochs: int = 10 epochs: int = 10
is_save_models = True is_save_models = True
@ -146,7 +147,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
loss = self.generator_loss(logits) loss = self.generator_loss(logits)
# Log stuff # Log stuff
tracker.add('generated', generated_images[0:5]) tracker.add('generated', generated_images[0:6])
tracker.add("loss.generator.", loss) tracker.add("loss.generator.", loss)
# Train # Train
@ -161,7 +162,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
@option(Configs.dataset_transforms) @option(Configs.dataset_transforms)
def mnist_transforms(): def mnist_gan_transforms():
return transforms.Compose([ return transforms.Compose([
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)) transforms.Normalize((0.5,), (0.5,))

File diff suppressed because one or more lines are too long

View File

@ -1,5 +1,5 @@
# We import the [simple gan experiment]((simple_mnist_experiment.html) and change the # We import the [DCGAN experiment]((../dcgan.html) and change the
# generator and discriminator networks # loss functions
from labml import experiment from labml import experiment
from labml.configs import calculate from labml.configs import calculate