mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-08-14 17:41:37 +08:00)
✨ wasserstein gan
41
labml_nn/gan/wasserstein/__init__.py
Normal file
@@ -0,0 +1,41 @@
import torch
import torch.utils.data
from torch.nn import functional as F

from labml_helpers.module import Module


class DiscriminatorLoss(Module):
    """
    ## Discriminator Loss
    """

    def __init__(self):
        super().__init__()

    def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
        """
        `logits_true` are logits from $D(\pmb{x}^{(i)})$ and
        `logits_false` are logits from $D(G(\pmb{z}^{(i)}))$
        """
        return F.relu(1 - logits_true).mean(), F.relu(1 + logits_false).mean()


class GeneratorLoss(Module):
    """
    ## Generator Loss
    """

    def __init__(self):
        super().__init__()

    def __call__(self, logits: torch.Tensor):
        return -logits.mean()


def _create_labels(n: int, r1: float, r2: float, device: torch.device = None):
    """
    Create smoothed labels
    """
    return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)
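Below is a minimal sketch, not part of the commit, of how these loss modules might be wired into one training step. The names `discriminator`, `generator`, `d_optimizer`, `g_optimizer`, `real_batch` and `latent_dim` are hypothetical placeholders; the actual training loop lives in the experiment code that the file below reuses.

# Sketch of one training step using the loss modules above.
# `discriminator`, `generator`, the optimizers and `latent_dim` are
# placeholder names, not part of the commit.
import torch

from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss

discriminator_loss = DiscriminatorLoss()
generator_loss = GeneratorLoss()


def discriminator_step(discriminator, generator, d_optimizer, real_batch, latent_dim):
    d_optimizer.zero_grad()
    # Critic output on real samples, $D(\pmb{x}^{(i)})$
    logits_true = discriminator(real_batch)
    # Critic output on generated samples, $D(G(\pmb{z}^{(i)}))$;
    # detach so this step does not update the generator
    z = torch.randn(real_batch.shape[0], latent_dim, device=real_batch.device)
    logits_false = discriminator(generator(z).detach())
    # The loss module returns the real-sample and generated-sample terms
    # separately; their sum is minimized
    loss_true, loss_false = discriminator_loss(logits_true, logits_false)
    (loss_true + loss_false).backward()
    d_optimizer.step()


def generator_step(discriminator, generator, g_optimizer, batch_size, latent_dim, device):
    g_optimizer.zero_grad()
    z = torch.randn(batch_size, latent_dim, device=device)
    # Maximize the critic's output on generated samples
    loss = generator_loss(discriminator(generator(z)))
    loss.backward()
    g_optimizer.step()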
29
labml_nn/gan/wasserstein/experiment.py
Normal file
@@ -0,0 +1,29 @@
# We import the DCGAN configurations, which build on the
# [simple gan experiment](simple_mnist_experiment.html), and change the
# generator and discriminator loss functions to the Wasserstein losses
from labml import experiment

from labml.configs import calculate
from labml_nn.gan.dcgan import Configs
from labml_nn.gan.wasserstein import GeneratorLoss, DiscriminatorLoss

calculate(Configs.generator_loss, 'wasserstein', lambda c: GeneratorLoss())
calculate(Configs.discriminator_loss, 'wasserstein', lambda c: DiscriminatorLoss())


def main():
    conf = Configs()
    experiment.create(name='mnist_wasserstein_dcgan', comment='test')
    experiment.configs(conf,
                       {
                           'discriminator': 'cnn',
                           'generator': 'cnn',
                           'label_smoothing': 0.01,
                           'generator_loss': 'wasserstein',
                           'discriminator_loss': 'wasserstein',
                       })
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
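The two `calculate` calls register `'wasserstein'` as a named option for the `generator_loss` and `discriminator_loss` configuration items, which is why the string values in the `experiment.configs` dictionary resolve to the new loss modules. As an illustration of the same pattern, a sketch follows in which any other loss can be registered and selected by name; `CustomGeneratorLoss` and the option name `'custom'` are hypothetical, not part of the commit.

# Sketch of the option-registration pattern used above; the class and the
# option name 'custom' are hypothetical placeholders.
import torch

from labml.configs import calculate
from labml_helpers.module import Module
from labml_nn.gan.dcgan import Configs


class CustomGeneratorLoss(Module):
    def __call__(self, logits: torch.Tensor):
        # Any differentiable scalar of the critic's logits can go here
        return -logits.mean()


# Register the module under the name 'custom'; passing
# {'generator_loss': 'custom'} to `experiment.configs(...)` would then select it.
calculate(Configs.generator_loss, 'custom', lambda c: CustomGeneratorLoss())

With the registrations in place, the experiment itself is launched by calling `main()` (for example via `python -m labml_nn.gan.wasserstein.experiment`, assuming the package layout shown above).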