wassersteing gan experiment

This commit is contained in:
Varuna Jayasiri
2021-05-06 15:39:45 +05:30
parent bcb673cf21
commit da040fa94e
6 changed files with 845 additions and 6 deletions

View File

@ -17,7 +17,7 @@ import torch.nn as nn
from labml import experiment
from labml.configs import calculate
from labml_helpers.module import Module
from labml_nn.gan.simple_mnist_experiment import Configs
from labml_nn.gan.original.experiment import Configs
class Generator(Module):
@ -107,7 +107,7 @@ calculate(Configs.discriminator, 'cnn', lambda c: Discriminator().to(c.device))
def main():
conf = Configs()
experiment.create(name='mnist_dcgan', comment='test')
experiment.create(name='mnist_dcgan')
experiment.configs(conf,
{'discriminator': 'cnn',
'generator': 'cnn',

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -85,6 +85,7 @@ class Discriminator(Module):
class Configs(MNISTConfigs, TrainValidConfigs):
device: torch.device = DeviceConfigs()
dataset_transforms = 'mnist_gan_transforms'
epochs: int = 10
is_save_models = True
@ -146,7 +147,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
loss = self.generator_loss(logits)
# Log stuff
tracker.add('generated', generated_images[0:5])
tracker.add('generated', generated_images[0:6])
tracker.add("loss.generator.", loss)
# Train
@ -161,7 +162,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
@option(Configs.dataset_transforms)
def mnist_transforms():
def mnist_gan_transforms():
return transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))

File diff suppressed because one or more lines are too long

View File

@ -1,5 +1,5 @@
# We import the [simple gan experiment]((simple_mnist_experiment.html) and change the
# generator and discriminator networks
# We import the [DCGAN experiment]((../dcgan.html) and change the
# loss functions
from labml import experiment
from labml.configs import calculate