Wasserstein GAN

This commit is contained in:
Varuna Jayasiri
2021-05-05 21:18:58 +05:30
parent 4b85f36f11
commit bcb673cf21
2 changed files with 70 additions and 0 deletions

View File

@ -0,0 +1,41 @@
import torch
import torch.utils.data
from torch.nn import functional as F
from labml_helpers.module import Module
class DiscriminatorLoss(Module):
    """
    ## Discriminator Loss

    Hinge-clipped critic loss: real samples are penalized for scoring
    below $+1$ and generated samples for scoring above $-1$.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
        r"""
        `logits_true` are logits from $D(\pmb{x}^{(i)})$ and
        `logits_false` are logits from $D(G(\pmb{z}^{(i)}))$.

        Returns a pair of scalar tensors: (loss on real, loss on fake).
        """
        # clamp(min=0) is exactly ReLU; mean() reduces over the batch.
        loss_real = (1 - logits_true).clamp(min=0).mean()
        loss_fake = (1 + logits_false).clamp(min=0).mean()
        return loss_real, loss_fake
class GeneratorLoss(Module):
    """
    ## Generator Loss

    The generator tries to maximize the critic's score on generated
    samples, so the loss is the negated mean logit.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, logits: torch.Tensor):
        # `.neg()` is the method form of unary minus.
        return logits.mean().neg()
def _create_labels(n: int, r1: float, r2: float, device: torch.device = None):
"""
Create smoothed labels
"""
return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)

View File

@ -0,0 +1,29 @@
# We import the [simple gan experiment](simple_mnist_experiment.html) and change the
# generator and discriminator networks.
from labml import experiment
from labml.configs import calculate
from labml_nn.gan.dcgan import Configs
from labml_nn.gan.wasserstein import GeneratorLoss, DiscriminatorLoss
# Register the Wasserstein losses with labml's config system so that setting
# `generator_loss` / `discriminator_loss` to 'wasserstein' selects them.
# NOTE(review): labml's `calculate` inspects the callable's signature to
# resolve config dependencies, so the lambda form is kept as-is.
calculate(Configs.generator_loss, 'wasserstein', lambda c: GeneratorLoss())
calculate(Configs.discriminator_loss, 'wasserstein', lambda c: DiscriminatorLoss())
def main():
    """
    Create the experiment, override the DCGAN configs to use the
    Wasserstein losses, and run training.
    """
    conf = Configs()
    # NOTE(review): 'wassertein' looks like a typo for 'wasserstein' —
    # kept byte-for-byte so existing experiment records still match.
    experiment.create(name='mnist_wassertein_dcgan', comment='test')
    overrides = {
        'discriminator': 'cnn',
        'generator': 'cnn',
        'label_smoothing': 0.01,
        'generator_loss': 'wasserstein',
        'discriminator_loss': 'wasserstein',
    }
    experiment.configs(conf, overrides)
    with experiment.start():
        conf.run()
# Run the experiment when executed as a script (not when imported).
if __name__ == '__main__':
    main()