mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 17:41:37 +08:00
wassersteing gan experiment
This commit is contained in:
@ -17,7 +17,7 @@ import torch.nn as nn
|
||||
from labml import experiment
|
||||
from labml.configs import calculate
|
||||
from labml_helpers.module import Module
|
||||
from labml_nn.gan.simple_mnist_experiment import Configs
|
||||
from labml_nn.gan.original.experiment import Configs
|
||||
|
||||
|
||||
class Generator(Module):
|
||||
@ -107,7 +107,7 @@ calculate(Configs.discriminator, 'cnn', lambda c: Discriminator().to(c.device))
|
||||
|
||||
def main():
|
||||
conf = Configs()
|
||||
experiment.create(name='mnist_dcgan', comment='test')
|
||||
experiment.create(name='mnist_dcgan')
|
||||
experiment.configs(conf,
|
||||
{'discriminator': 'cnn',
|
||||
'generator': 'cnn',
|
286
labml_nn/gan/dcgan/experiment.ipynb
Normal file
286
labml_nn/gan/dcgan/experiment.ipynb
Normal file
File diff suppressed because one or more lines are too long
264
labml_nn/gan/original/experiment.ipynb
Normal file
264
labml_nn/gan/original/experiment.ipynb
Normal file
File diff suppressed because one or more lines are too long
@ -85,6 +85,7 @@ class Discriminator(Module):
|
||||
|
||||
class Configs(MNISTConfigs, TrainValidConfigs):
|
||||
device: torch.device = DeviceConfigs()
|
||||
dataset_transforms = 'mnist_gan_transforms'
|
||||
epochs: int = 10
|
||||
|
||||
is_save_models = True
|
||||
@ -146,7 +147,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
|
||||
loss = self.generator_loss(logits)
|
||||
|
||||
# Log stuff
|
||||
tracker.add('generated', generated_images[0:5])
|
||||
tracker.add('generated', generated_images[0:6])
|
||||
tracker.add("loss.generator.", loss)
|
||||
|
||||
# Train
|
||||
@ -161,7 +162,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
|
||||
|
||||
|
||||
@option(Configs.dataset_transforms)
|
||||
def mnist_transforms():
|
||||
def mnist_gan_transforms():
|
||||
return transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5,), (0.5,))
|
288
labml_nn/gan/wasserstein/experiment.ipynb
Normal file
288
labml_nn/gan/wasserstein/experiment.ipynb
Normal file
File diff suppressed because one or more lines are too long
@ -1,5 +1,5 @@
|
||||
# We import the [simple gan experiment]((simple_mnist_experiment.html) and change the
|
||||
# generator and discriminator networks
|
||||
# We import the [DCGAN experiment]((../dcgan.html) and change the
|
||||
# loss functions
|
||||
from labml import experiment
|
||||
|
||||
from labml.configs import calculate
|
||||
|
Reference in New Issue
Block a user