mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-16 10:51:23 +08:00
wassersteing gan experiment
This commit is contained in:
@ -17,7 +17,7 @@ import torch.nn as nn
|
|||||||
from labml import experiment
|
from labml import experiment
|
||||||
from labml.configs import calculate
|
from labml.configs import calculate
|
||||||
from labml_helpers.module import Module
|
from labml_helpers.module import Module
|
||||||
from labml_nn.gan.simple_mnist_experiment import Configs
|
from labml_nn.gan.original.experiment import Configs
|
||||||
|
|
||||||
|
|
||||||
class Generator(Module):
|
class Generator(Module):
|
||||||
@ -107,7 +107,7 @@ calculate(Configs.discriminator, 'cnn', lambda c: Discriminator().to(c.device))
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
conf = Configs()
|
conf = Configs()
|
||||||
experiment.create(name='mnist_dcgan', comment='test')
|
experiment.create(name='mnist_dcgan')
|
||||||
experiment.configs(conf,
|
experiment.configs(conf,
|
||||||
{'discriminator': 'cnn',
|
{'discriminator': 'cnn',
|
||||||
'generator': 'cnn',
|
'generator': 'cnn',
|
286
labml_nn/gan/dcgan/experiment.ipynb
Normal file
286
labml_nn/gan/dcgan/experiment.ipynb
Normal file
File diff suppressed because one or more lines are too long
264
labml_nn/gan/original/experiment.ipynb
Normal file
264
labml_nn/gan/original/experiment.ipynb
Normal file
File diff suppressed because one or more lines are too long
@ -85,6 +85,7 @@ class Discriminator(Module):
|
|||||||
|
|
||||||
class Configs(MNISTConfigs, TrainValidConfigs):
|
class Configs(MNISTConfigs, TrainValidConfigs):
|
||||||
device: torch.device = DeviceConfigs()
|
device: torch.device = DeviceConfigs()
|
||||||
|
dataset_transforms = 'mnist_gan_transforms'
|
||||||
epochs: int = 10
|
epochs: int = 10
|
||||||
|
|
||||||
is_save_models = True
|
is_save_models = True
|
||||||
@ -146,7 +147,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
|
|||||||
loss = self.generator_loss(logits)
|
loss = self.generator_loss(logits)
|
||||||
|
|
||||||
# Log stuff
|
# Log stuff
|
||||||
tracker.add('generated', generated_images[0:5])
|
tracker.add('generated', generated_images[0:6])
|
||||||
tracker.add("loss.generator.", loss)
|
tracker.add("loss.generator.", loss)
|
||||||
|
|
||||||
# Train
|
# Train
|
||||||
@ -161,7 +162,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
|
|||||||
|
|
||||||
|
|
||||||
@option(Configs.dataset_transforms)
|
@option(Configs.dataset_transforms)
|
||||||
def mnist_transforms():
|
def mnist_gan_transforms():
|
||||||
return transforms.Compose([
|
return transforms.Compose([
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize((0.5,), (0.5,))
|
transforms.Normalize((0.5,), (0.5,))
|
288
labml_nn/gan/wasserstein/experiment.ipynb
Normal file
288
labml_nn/gan/wasserstein/experiment.ipynb
Normal file
File diff suppressed because one or more lines are too long
@ -1,5 +1,5 @@
|
|||||||
# We import the [simple gan experiment]((simple_mnist_experiment.html) and change the
|
# We import the [DCGAN experiment]((../dcgan.html) and change the
|
||||||
# generator and discriminator networks
|
# loss functions
|
||||||
from labml import experiment
|
from labml import experiment
|
||||||
|
|
||||||
from labml.configs import calculate
|
from labml.configs import calculate
|
||||||
|
Reference in New Issue
Block a user