From bc86802ddc2220a41b30377489a26850c2d44fcc Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Fri, 7 May 2021 16:21:12 +0530 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=9A=20wasserstein=20gan?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/cnn/utils/cv_train.html | 69 +--- .../{cycle_gan.html => cycle_gan/index.html} | 13 +- docs/gan/{dcgan.html => dcgan/index.html} | 15 +- docs/gan/index.html | 238 +------------ .../experiment.html} | 210 +++++------ docs/gan/original/index.html | 327 ++++++++++++++++++ docs/gan/wasserstein/experiment.html | 217 ++++++++++++ docs/gan/wasserstein/index.html | 265 ++++++++++++++ docs/sitemap.xml | 90 +++-- labml_nn/gan/wasserstein/__init__.py | 127 +++++-- labml_nn/gan/wasserstein/experiment.py | 19 +- 11 files changed, 1140 insertions(+), 450 deletions(-) rename docs/gan/{cycle_gan.html => cycle_gan/index.html} (99%) rename docs/gan/{dcgan.html => dcgan/index.html} (98%) rename docs/gan/{simple_mnist_experiment.html => original/experiment.html} (81%) create mode 100644 docs/gan/original/index.html create mode 100644 docs/gan/wasserstein/experiment.html create mode 100644 docs/gan/wasserstein/index.html diff --git a/docs/cnn/utils/cv_train.html b/docs/cnn/utils/cv_train.html index 29f6db26..1bba00d4 100644 --- a/docs/cnn/utils/cv_train.html +++ b/docs/cnn/utils/cv_train.html @@ -72,38 +72,7 @@ -

Cross-Validation & Early Stopping

-

Implementation of two fundamental techniques: Cross-Validation and Early Stopping.

Cross-Validation

-

Getting data is expensive, and in some cases one has no option but to train a machine learning model on a limited amount of data. This is where cross-validation is useful. The steps are as follows (a short sketch follows the list):

    -
  1. Split the data into K folds
  2. Use K-1 folds to train a set of models
  3. Validate the models on the remaining fold
  4. Repeat (2) and (3) for all the folds
  5. Average the performance over all runs
-
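As a rough sketch of these steps (this is not the `cross_val_train` function documented below; `train_one_model` and `evaluate` are assumed, illustrative callables supplied by the caller):

import numpy as np

def k_fold_indices(n_samples: int, k: int, seed: int = 0):
    # Shuffle the sample indices and split them into k roughly equal folds
    rng = np.random.default_rng(seed)
    return np.array_split(rng.permutation(n_samples), k)

def cross_validate(train_one_model, evaluate, n_samples: int, k: int = 5):
    # train_one_model(train_idx) -> model and evaluate(model, val_idx) -> score
    # are assumed helpers
    folds = k_fold_indices(n_samples, k)
    scores = []
    for i, val_idx in enumerate(folds):
        # Train on the other k-1 folds, validate on the held-out fold
        train_idx = np.concatenate([f for j, f in enumerate(folds) if j != i])
        model = train_one_model(train_idx)
        scores.append(evaluate(model, val_idx))
    # Average the performance over all K runs
    return float(np.mean(scores))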

-

Early-Stopping

Deep learning networks are prone to overfitting: although an overfitted model performs well on the training set, it generalizes poorly. In other words, overfitted models have low bias and high variance. The lower the bias, the better the model can fit the data; the higher the variance, the more sensitive the model is to the particular training data.
Formally, the expected test error can be decomposed as $\text{Error} = \text{Bias}^2 + \text{Variance} + \text{Irreducible error}$.
-

-

Therefore, the user has to find a tradeoff between bias and variance.

-

-

Early stopping is one way to find this tradeoff. It helps to find a good setting of the parameters while preventing overfitting on the dataset and saving computation time. This can be visualized through the following graph of training loss and validation loss over time:


Figure: Training vs. validation set loss
-

It can be seen that the training error continues to decrease, but the validation error starts to increase after around 40 epochs. Therefore, our goal is to stop training once the validation loss starts to increase.
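A minimal sketch of such a stopping rule (a simplified patience-based check, not the early-stopping logic used in the code below; all names are illustrative):

def should_stop(val_losses, patience: int = 5) -> bool:
    # Stop when none of the last `patience` epochs improved on the best
    # validation loss seen before that window
    if len(val_losses) <= patience:
        return False
    best_before = min(val_losses[:-patience])
    return all(loss >= best_before for loss in val_losses[-patience:])

# Inside a training loop (illustrative):
#     val_losses.append(validate(model))
#     if should_stop(val_losses):
#         break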

- -

- +
3import torch
@@ -128,10 +97,7 @@
                 
-                    

Cross-Validation

-

Splitting of the training set into folds can be represented as:

- CV folds - +
21def cross_val_train(cost, trainset, epochs, splits, device=None):
@@ -190,7 +156,7 @@
                 
-                

Training steps

+

training steps

65            net.train()  # Enable Dropout
@@ -203,7 +169,6 @@
                     #
                 

Get the inputs; data is a list of [inputs, labels]

-

Load the inputs in GPU if available else CPU

68                if device:
@@ -242,7 +207,7 @@
                 
-                

Calculate loss

+

Print loss

82                running_loss += loss.item()
@@ -258,7 +223,7 @@
                 
-                

Validation and printing the metrics

+

Validation

90            loss_accuracy = Test(net, cost, valdata, device)
@@ -294,17 +259,7 @@
                 
-                

Early stopping

-

Early stopping can be understood graphically by looking at how the weights change during the course of training.

-
    -
  • Solid contour lines indicate the contours of the negative log-likelihood (training error)
  • The dashed line indicates the trajectory taken by the optimizer
  • w∗ denotes the weight setting corresponding to the minimum training error
  • w denotes the final weight setting chosen by the model after early stopping
Figure: early stopping

Early stopping referenced from https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py

110            if losses[epoch] > min_loss:
@@ -358,7 +313,7 @@
                 
-                

Retrieve the model which has the best accuracy over the validation set

+
138def retreive_best_trial():
@@ -412,7 +367,7 @@
                 
-                

Forward pass

+

forward pass

166    output = net(images)
@@ -423,7 +378,7 @@ -

Loss in batch

+

loss in batch

168    loss = cost(output, labels)
@@ -434,7 +389,7 @@ -

Update validation loss

+

update validation loss

171    _, preds = torch.max(output, dim=1)
@@ -502,7 +457,7 @@
                 
-                

Loss in batch

+

loss in batch

197            loss += cost(outputs, labels)
@@ -514,7 +469,7 @@
                 
-                

Calculate loss and accuracy over the validation set

+

losses[epoch] += loss.item()

201            _, predicted = torch.max(outputs.data, 1)
diff --git a/docs/gan/cycle_gan.html b/docs/gan/cycle_gan/index.html
similarity index 99%
rename from docs/gan/cycle_gan.html
rename to docs/gan/cycle_gan/index.html
index 538c6099..3fb584f1 100644
--- a/docs/gan/cycle_gan.html
+++ b/docs/gan/cycle_gan/index.html
@@ -12,7 +12,7 @@
     
     
 
-    
+    
     
     
     
@@ -22,8 +22,8 @@
 
     Cycle GAN
     
-    
-    
+    
+    
     
     
     
     
-

-

$p_{data}(\pmb{x})$ is the probability distribution over data, -whilst $p_{\pmb{z}}(\pmb{z})$ probability distribution of $\pmb{z}$, which is set to -gaussian noise.

-

This file defines the loss functions. Here is an MNIST example -with two multilayer perceptron for the generator and discriminator.

-
-
-
34import torch
-35import torch.nn as nn
-36import torch.utils.data
-37import torch.utils.data
-38
-39from labml_helpers.module import Module
-
-
-
-
- -

Discriminator Loss

-

Discriminator should ascend on the gradient,

-

- -

-

$m$ is the mini-batch size and $(i)$ is used to index samples in the mini-batch. -$\pmb{x}$ are samples from $p_{data}$ and $\pmb{z}$ are samples from $p_z$.

-
-
-
42class DiscriminatorLogitsLoss(Module):
-
-
-
-
- - -
-
-
57    def __init__(self, smoothing: float = 0.2):
-58        super().__init__()
-
-
-
-
- -

We use PyTorch Binary Cross Entropy Loss, which is -$-\sum\Big[y \log(\hat{y}) + (1 - y) \log(1 - \hat{y})\Big]$, -where $y$ are the labels and $\hat{y}$ are the predictions. -Note the negative sign. -We use labels equal to $1$ for $\pmb{x}$ from $p_{data}$ -and labels equal to $0$ for $\pmb{x}$ from $p_{G}.$ -Then descending on the sum of these is the same as ascending on -the above gradient.

-

BCEWithLogitsLoss combines softmax and binary cross entropy loss.

-
-
-
69        self.loss_true = nn.BCEWithLogitsLoss()
-70        self.loss_false = nn.BCEWithLogitsLoss()
-
-
-
-
- -

We use label smoothing because it seems to work better in some cases

-
-
-
73        self.smoothing = smoothing
-
-
-
-
- -

Labels are registered as buffered and persistence is set to False.

-
-
-
76        self.register_buffer('labels_true', _create_labels(256, 1.0 - smoothing, 1.0), False)
-77        self.register_buffer('labels_false', _create_labels(256, 0.0, smoothing), False)
-
-
-
-
- -

logits_true are logits from $D(\pmb{x}^{(i)})$ and -logits_false are logits from $D(G(\pmb{z}^{(i)}))$

-
-
-
79    def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
-
-
-
-
- - -
-
-
84        if len(logits_true) > len(self.labels_true):
-85            self.register_buffer("labels_true",
-86                                 _create_labels(len(logits_true), 1.0 - self.smoothing, 1.0, logits_true.device), False)
-87        if len(logits_false) > len(self.labels_false):
-88            self.register_buffer("labels_false",
-89                                 _create_labels(len(logits_false), 0.0, self.smoothing, logits_false.device), False)
-90
-91        return (self.loss_true(logits_true, self.labels_true[:len(logits_true)]),
-92                self.loss_false(logits_false, self.labels_false[:len(logits_false)]))
-
-
-
-
- -

Generator Loss

-

Generator should descend on the gradient,

-

- -

-
-
-
95class GeneratorLogitsLoss(Module):
-
-
-
-
- - -
-
-
105    def __init__(self, smoothing: float = 0.2):
-106        super().__init__()
-107        self.loss_true = nn.BCEWithLogitsLoss()
-108        self.smoothing = smoothing
-
-
-
-
- -

We use labels equal to $1$ for $\pmb{x}$ from $p_{G}.$ -Then descending on this loss is the same as descending on -the above gradient.

-
-
-
112        self.register_buffer('fake_labels', _create_labels(256, 1.0 - smoothing, 1.0), False)
-
-
-
-
- - -
-
-
114    def __call__(self, logits: torch.Tensor):
-115        if len(logits) > len(self.fake_labels):
-116            self.register_buffer("fake_labels",
-117                                 _create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)
-118
-119        return self.loss_true(logits, self.fake_labels[:len(logits)])
-
-
-
-
- -

Create smoothed labels

-
-
-
122def _create_labels(n: int, r1: float, r2: float, device: torch.device = None):
-
-
-
-
- - -
-
-
126    return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)
-
-
+ + + +
+
+
+
+

+ home + gan + original +

+


+
+
+
+
+ +

Generative Adversarial Networks (GAN)

+

This is an implementation of +Generative Adversarial Networks.

+

The generator, $G(\pmb{z}; \theta_g)$, generates samples that match the distribution of the data, while the discriminator, $D(\pmb{x}; \theta_d)$, gives the probability that $\pmb{x}$ came from the data rather than from $G$.

+

We train $D$ and $G$ simultaneously on a two-player min-max game with value +function $V(G, D)$.

+

$$\min_G \max_D V(D, G) = \mathbb{E}_{\pmb{x} \sim p_{data}(\pmb{x})} \big[\log D(\pmb{x})\big] + \mathbb{E}_{\pmb{z} \sim p_{\pmb{z}}(\pmb{z})} \big[\log \big(1 - D(G(\pmb{z}))\big)\big]$$

+

$p_{data}(\pmb{x})$ is the probability distribution over the data, whilst $p_{\pmb{z}}(\pmb{z})$ is the probability distribution of $\pmb{z}$, which is set to Gaussian noise.

+

This file defines the loss functions. Here is an MNIST example with two multilayer perceptrons for the generator and discriminator.

+
+
+
34import torch
+35import torch.nn as nn
+36import torch.utils.data
+37import torch.utils.data
+38
+39from labml_helpers.module import Module
+
+
+
+
+ +

Discriminator Loss

+

Discriminator should ascend on the gradient,

+

$$\nabla_{\theta_d} \frac{1}{m} \sum_{i=1}^{m} \Big[\log D\big(\pmb{x}^{(i)}\big) + \log \Big(1 - D\big(G(\pmb{z}^{(i)})\big)\Big)\Big]$$

+

$m$ is the mini-batch size and $(i)$ is used to index samples in the mini-batch. +$\pmb{x}$ are samples from $p_{data}$ and $\pmb{z}$ are samples from $p_z$.

+
+
+
42class DiscriminatorLogitsLoss(Module):
+
+
+
+
+ + +
+
+
57    def __init__(self, smoothing: float = 0.2):
+58        super().__init__()
+
+
+
+
+ +

We use PyTorch Binary Cross Entropy Loss, which is +$-\sum\Big[y \log(\hat{y}) + (1 - y) \log(1 - \hat{y})\Big]$, +where $y$ are the labels and $\hat{y}$ are the predictions. +Note the negative sign. +We use labels equal to $1$ for $\pmb{x}$ from $p_{data}$ +and labels equal to $0$ for $\pmb{x}$ from $p_{G}.$ +Then descending on the sum of these is the same as ascending on +the above gradient.

+

BCEWithLogitsLoss combines a sigmoid layer and binary cross entropy loss.
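As a quick standalone check of what BCEWithLogitsLoss computes (not part of the repository code):

import torch
import torch.nn as nn

logits = torch.randn(4, 1)
targets = torch.rand(4, 1)

combined = nn.BCEWithLogitsLoss()(logits, targets)
manual = nn.BCELoss()(torch.sigmoid(logits), targets)
# The two values match (up to floating point precision)
print(combined.item(), manual.item())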

+
+
+
69        self.loss_true = nn.BCEWithLogitsLoss()
+70        self.loss_false = nn.BCEWithLogitsLoss()
+
+
+
+
+ +

We use label smoothing because it seems to work better in some cases

+
+
+
73        self.smoothing = smoothing
+
+
+
+
+ +

Labels are registered as buffers and persistence is set to False.

+
+
+
76        self.register_buffer('labels_true', _create_labels(256, 1.0 - smoothing, 1.0), False)
+77        self.register_buffer('labels_false', _create_labels(256, 0.0, smoothing), False)
+
+
+
+
+ +

logits_true are logits from $D(\pmb{x}^{(i)})$ and +logits_false are logits from $D(G(\pmb{z}^{(i)}))$

+
+
+
79    def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
+
+
+
+
+ + +
+
+
84        if len(logits_true) > len(self.labels_true):
+85            self.register_buffer("labels_true",
+86                                 _create_labels(len(logits_true), 1.0 - self.smoothing, 1.0, logits_true.device), False)
+87        if len(logits_false) > len(self.labels_false):
+88            self.register_buffer("labels_false",
+89                                 _create_labels(len(logits_false), 0.0, self.smoothing, logits_false.device), False)
+90
+91        return (self.loss_true(logits_true, self.labels_true[:len(logits_true)]),
+92                self.loss_false(logits_false, self.labels_false[:len(logits_false)]))
+
+
+
+
+ +

Generator Loss

+

Generator should descend on the gradient,

+

$$\nabla_{\theta_g} \frac{1}{m} \sum_{i=1}^{m} \log \Big(1 - D\big(G(\pmb{z}^{(i)})\big)\Big)$$

+
+
+
95class GeneratorLogitsLoss(Module):
+
+
+
+
+ + +
+
+
105    def __init__(self, smoothing: float = 0.2):
+106        super().__init__()
+107        self.loss_true = nn.BCEWithLogitsLoss()
+108        self.smoothing = smoothing
+
+
+
+
+ +

We use labels equal to $1$ for $\pmb{x}$ from $p_{G}.$ +Then descending on this loss is the same as descending on +the above gradient.

+
+
+
112        self.register_buffer('fake_labels', _create_labels(256, 1.0 - smoothing, 1.0), False)
+
+
+
+
+ + +
+
+
114    def __call__(self, logits: torch.Tensor):
+115        if len(logits) > len(self.fake_labels):
+116            self.register_buffer("fake_labels",
+117                                 _create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)
+118
+119        return self.loss_true(logits, self.fake_labels[:len(logits)])
+
+
+
+
+ +

Create smoothed labels

+
+
+
122def _create_labels(n: int, r1: float, r2: float, device: torch.device = None):
+
+
+
+
+ + +
+
+
126    return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)
+
+
+
+
+ + + + + \ No newline at end of file diff --git a/docs/gan/wasserstein/experiment.html b/docs/gan/wasserstein/experiment.html new file mode 100644 index 00000000..c0f76c87 --- /dev/null +++ b/docs/gan/wasserstein/experiment.html @@ -0,0 +1,217 @@ + + + + + + + + + + + + + + + + + + + + + + + WGAN experiment with MNIST + + + + + + + + +
+
+
+
+

+ home + gan + wasserstein +

+


+
+
+
+
+ +

WGAN experiment with MNIST

+
+
+
9from labml import experiment
+10
+11from labml.configs import calculate
+
+
+
+
+ +

Import configurations from DCGAN experiment

+
+
+
13from labml_nn.gan.dcgan import Configs
+
+
+
+
+ +

Import Wasserstein GAN losses

+
+
+
16from labml_nn.gan.wasserstein import GeneratorLoss, DiscriminatorLoss
+
+
+
+
+ +

Set configuration options for Wasserstein GAN losses

+
+
+
19calculate(Configs.generator_loss, 'wasserstein', lambda c: GeneratorLoss())
+20calculate(Configs.discriminator_loss, 'wasserstein', lambda c: DiscriminatorLoss())
+
+
+
+
+ + +
+
+
23def main():
+
+
+
+
+ +

Create configs object

+
+
+
25    conf = Configs()
+
+
+
+
+ +

Create experiment

+
+
+
27    experiment.create(name='mnist_wassertein_dcgan', comment='test')
+
+
+
+
+ +

Override configurations

+
+
+
29    experiment.configs(conf,
+30                       {
+31                           'discriminator': 'cnn',
+32                           'generator': 'cnn',
+33                           'label_smoothing': 0.01,
+34                           'generator_loss': 'wasserstein',
+35                           'discriminator_loss': 'wasserstein',
+36                       })
+
+
+
+
+ +

Start the experiment and run training loop

+
+
+
39    with experiment.start():
+40        conf.run()
+41
+42
+43if __name__ == '__main__':
+44    main()
+
+
+
+
+ + + + + \ No newline at end of file diff --git a/docs/gan/wasserstein/index.html b/docs/gan/wasserstein/index.html new file mode 100644 index 00000000..5c557bde --- /dev/null +++ b/docs/gan/wasserstein/index.html @@ -0,0 +1,265 @@ + + + + + + + + + + + + + + + + + + + + + + + Wasserstein GAN (WGAN) + + + + + + + + +
+
+
+
+

+ home + gan + wasserstein +

+


+
+
+
+
+ +

This is an implementation of +Wasserstein GAN.

+

The original GAN loss is based on Jensen-Shannon (JS) divergence +between the real distribution $\mathbb{P}_r$ and generated distribution $\mathbb{P}_g$. +The Wasserstein GAN is based on Earth Mover distance between these distributions.

+

$$W(\mathbb{P}_r, \mathbb{P}_g) = \underset{\gamma \in \Pi(\mathbb{P}_r, \mathbb{P}_g)} {\mathrm{inf}} \mathbb{E}_{(x,y) \sim \gamma} \Vert x - y \Vert$$

+

$\Pi(\mathbb{P}_r, \mathbb{P}_g)$ is the set of all joint distributions $\gamma(x, y)$ whose marginals are $\mathbb{P}_r$ and $\mathbb{P}_g$, respectively.

+

$\mathbb{E}_{(x,y) \sim \gamma} \Vert x - y \Vert$ is the expected transport cost for a given joint distribution (transport plan) $\gamma$, where $x$ and $y$ are points sampled jointly from $\gamma$.

+

So $W(\mathbb{P}_r, \mathbb{P}_g)$ is equal to the least earth mover cost over all joint distributions between the real distribution $\mathbb{P}_r$ and the generated distribution $\mathbb{P}_g$.
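As a toy illustration of the earth mover distance (not part of the implementation): for two sets of equally weighted 1-D samples, the optimal coupling simply matches sorted values, so $W_1$ reduces to the mean absolute difference between sorted samples.

import torch

def wasserstein_1d(samples_p: torch.Tensor, samples_q: torch.Tensor) -> torch.Tensor:
    # Assumes equal sample counts with equal weights; match sorted samples
    return (samples_p.sort().values - samples_q.sort().values).abs().mean()

p = torch.tensor([0., 1., 2.])
q = torch.tensor([1., 2., 3.])
print(wasserstein_1d(p, q))  # tensor(1.)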

+

The paper shows that the Jensen-Shannon (JS) divergence and other common measures of the difference between two probability distributions are not smooth. Therefore, if we do gradient descent on a parameterized probability distribution, it will not converge.

+

Based on the Kantorovich-Rubinstein duality, $$W(\mathbb{P}_r, \mathbb{P}_g) = \underset{\Vert f \Vert_L \le 1} {\mathrm{sup}} \mathbb{E}_{x \sim \mathbb{P}_r} [f(x)] - \mathbb{E}_{x \sim \mathbb{P}_g} [f(x)]$$

+

where $\Vert f \Vert_L \le 1$ are all 1-Lipschitz functions.

+

That is, it is equal to the greatest difference $$\mathbb{E}_{x \sim \mathbb{P}_r} [f(x)] - \mathbb{E}_{x \sim \mathbb{P}_g} [f(x)]$$ among all 1-Lipschitz functions.

+

For $K$-Lipschitz functions, $$W(\mathbb{P}_r, \mathbb{P}_g) = \underset{\Vert f \Vert_L \le K} {\mathrm{sup}} \mathbb{E}_{x \sim \mathbb{P}_r} \Bigg[\frac{1}{K} f(x) \Bigg] - \mathbb{E}_{x \sim \mathbb{P}_g} \Bigg[\frac{1}{K} f(x) \Bigg]$$

+

If all $K$-Lipschitz functions can be represented as $f_w$ where $f$ is parameterized by +$w \in \mathcal{W}$,

+

$$K \cdot W(\mathbb{P}_r, \mathbb{P}_g) = \max_{w \in \mathcal{W}} \mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)] - \mathbb{E}_{x \sim \mathbb{P}_g} [f_w(x)]$$

+

If $\mathbb{P}_{g}$ is represented by a generator $g_\theta(z)$ and $z$ is from a known distribution $z \sim p(z)$,

+

$$K \cdot W(\mathbb{P}_r, \mathbb{P}_\theta) = \max_{w \in \mathcal{W}} \mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)] - \mathbb{E}_{z \sim p(z)} [f_w(g_\theta(z))]$$

+

Now, to make $g_\theta$ converge to $\mathbb{P}_{r}$, we can do gradient descent on $\theta$ to minimize the above formula.

+

Similarly, we can find $\max_{w \in \mathcal{W}}$ by ascending on $w$, while keeping $K$ bounded. One way to keep $K$ bounded is to clip all the weights of the neural network that defines $f$ to a small range.
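The implementation below clips the loss with ReLUs; the paper's own way of keeping $K$ bounded is weight clipping, sketched here as an illustrative helper (clip_critic_weights is not part of this repository; the paper clips to $[-0.01, 0.01]$):

import torch

def clip_critic_weights(critic: torch.nn.Module, c: float = 0.01):
    # Clamp every parameter of f_w to [-c, +c] after each critic update,
    # which keeps f_w K-Lipschitz for some K depending on c
    with torch.no_grad():
        for p in critic.parameters():
            p.clamp_(-c, c)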

+

Here is the code to try this on a simple MNIST generation experiment.

+

Open In Colab

+
+
+
85import torch.utils.data
+86from torch.nn import functional as F
+87
+88from labml_helpers.module import Module
+
+
+
+
+ +

Discriminator Loss

+

We want to find $w$ to maximize $$\mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)] - \mathbb{E}_{z \sim p(z)} [f_w(g_\theta(z))],$$ so we minimize $$-\frac{1}{m} \sum_{i=1}^m f_w \big(x^{(i)} \big) + \frac{1}{m} \sum_{i=1}^m f_w \big( g_\theta(z^{(i)}) \big)$$

+
+
+
91class DiscriminatorLoss(Module):
+
+
+
+
+ +
    +
  • f_real is $f_w(x)$
  • f_fake is $f_w(g_\theta(z))$
+
+
+
102    def __call__(self, f_real: torch.Tensor, f_fake: torch.Tensor):
+
+
+
+
+ +

We use ReLUs to clip the loss to keep $f \in [-1, +1]$ range.

+
+
+
109        return F.relu(1 - f_real).mean(), F.relu(1 + f_fake).mean()
+
+
+
+
+ +

Generator Loss

+

We want to find $\theta$ to minimize $$\mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)] - \mathbb{E}_{z \sim p(z)} [f_w(g_\theta(z))].$$ The first component is independent of $\theta$, so we minimize $$-\frac{1}{m} \sum_{i=1}^m f_w \big( g_\theta(z^{(i)}) \big)$$

+
+
+
112class GeneratorLoss(Module):
+
+
+
+
+ +
    +
  • f_fake is $f_w(g_\theta(z))$
+
+
+
124    def __call__(self, f_fake: torch.Tensor):
+
+
+
+
+ + +
+
+
128        return -f_fake.mean()
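A rough sketch of how these two loss modules might be wired into a single training iteration (the actual loop is driven by the DCGAN Configs; the models, optimizers, and data below are assumed to be supplied by the caller):

disc_loss_fn = DiscriminatorLoss()
gen_loss_fn = GeneratorLoss()

def wgan_step(discriminator, generator, x_real, z, disc_opt, gen_opt):
    # Critic update: minimize the sum of the two components returned by
    # DiscriminatorLoss, i.e. ascend on f_w(x) - f_w(g_theta(z))
    disc_opt.zero_grad()
    loss_real, loss_fake = disc_loss_fn(discriminator(x_real),
                                        discriminator(generator(z).detach()))
    (loss_real + loss_fake).backward()
    disc_opt.step()

    # Generator update: minimize -f_w(g_theta(z))
    gen_opt.zero_grad()
    gen_loss_fn(discriminator(generator(z))).backward()
    gen_opt.step()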
+
+
+
+
+ + + + + \ No newline at end of file diff --git a/docs/sitemap.xml b/docs/sitemap.xml index 83fe00de..2bf1d6dd 100644 --- a/docs/sitemap.xml +++ b/docs/sitemap.xml @@ -7,36 +7,78 @@ http://www.sitemaps.org/schemas/sitemap/0.9/sitemap.xsd"> - https://nn.labml.ai/gan/cycle_gan.html - 2021-01-23T16:30:00+00:00 + https://nn.labml.ai/gan/wasserstein/experiment.html + 2021-05-06T16:30:00+00:00 + 1.00 + + + + + https://nn.labml.ai/gan/wasserstein/index.html + 2021-05-05T16:30:00+00:00 + 1.00 + + + + + https://nn.labml.ai/gan/wasserstein/experiment.html + 2021-05-06T16:30:00+00:00 + 1.00 + + + + + https://nn.labml.ai/gan/original/experiment.html + 2021-05-06T16:30:00+00:00 + 1.00 + + + + + https://nn.labml.ai/gan/original/index.html + 2021-05-05T16:30:00+00:00 + 1.00 + + + + + https://nn.labml.ai/gan/original/experiment.html + 2021-05-06T16:30:00+00:00 + 1.00 + + + + + https://nn.labml.ai/gan/dcgan/experiment.html + 2021-05-06T16:30:00+00:00 + 1.00 + + + + + https://nn.labml.ai/gan/dcgan/index.html + 2021-05-06T16:30:00+00:00 + 1.00 + + + + + https://nn.labml.ai/gan/cycle_gan/experiment.html + 2021-05-05T16:30:00+00:00 + 1.00 + + + + + https://nn.labml.ai/gan/cycle_gan/index.html + 2021-05-05T16:30:00+00:00 1.00 https://nn.labml.ai/gan/index.html - 2021-02-14T16:30:00+00:00 - 1.00 - - - - - https://nn.labml.ai/gan/simple_mnist_experiment.html - 2020-12-10T16:30:00+00:00 - 1.00 - - - - - https://nn.labml.ai/gan/dcgan.html - 2021-02-14T16:30:00+00:00 - 1.00 - - - - - https://nn.labml.ai/gan/cycle_gan.html - 2021-02-27T16:30:00+00:00 + 2021-05-05T16:30:00+00:00 1.00 diff --git a/labml_nn/gan/wasserstein/__init__.py b/labml_nn/gan/wasserstein/__init__.py index c70a96d4..9919db3f 100644 --- a/labml_nn/gan/wasserstein/__init__.py +++ b/labml_nn/gan/wasserstein/__init__.py @@ -1,4 +1,87 @@ -import torch +r""" +--- +title: Wasserstein GAN (WGAN) +summary: A simple PyTorch implementation/tutorial of Wasserstein Generative Adversarial Networks (WGAN) loss functions. +--- + +This is an implementation of +[Wasserstein GAN](https://arxiv.org/abs/1701.07875). + +The original GAN loss is based on Jensen-Shannon (JS) divergence +between the real distribution $\mathbb{P}_r$ and generated distribution $\mathbb{P}_g$. +The Wasserstein GAN is based on Earth Mover distance between these distributions. + +$$ +W(\mathbb{P}_r, \mathbb{P}_g) = + \underset{\gamma \in \Pi(\mathbb{P}_r, \mathbb{P}_g)} {\mathrm{inf}} + \mathbb{E}_{(x,y) \sim \gamma} + \Vert x - y \Vert +$$ + +$\Pi(\mathbb{P}_r, \mathbb{P}_g)$ is the set of all joint distributions, whose +marginal probabilities are $\gamma(x, y)$. + +$\mathbb{E}_{(x,y) \sim \gamma} \Vert x - y \Vert$ is the earth mover distance for +a given joint distribution ($x$ and $y$ are probabilities). + +So $W(\mathbb{P}_r, \mathbb{P}g)$ is equal to the least earth mover distance for +any joint distribution between the real distribution $\mathbb{P}_r$ and generated distribution $\mathbb{P}_g$. + +The paper shows that Jensen-Shannon (JS) divergence and other measures for difference between two probability +distributions are not smooth. And therefore if we are doing a gradient descent on one of the probability +distributions (parameterized) it will not converge. + +Based on Kantorovich-Rubinstein duality, +$$ +W(\mathbb{P}_r, \mathbb{P}_g) = + \underset{\Vert f \Vert_L \le 1} {\mathrm{sup}} + \mathbb{E}_{x \sim \mathbb{P}_r} [f(x)]- \mathbb{E}_{x \sim \mathbb{P}_g} [f(x)] +$$ + +where $\Vert f \Vert_L \le 1$ are all 1-Lipschitz functions. 
+ +That is, it is equal to the greatest difference +$$\mathbb{E}_{x \sim \mathbb{P}_r} [f(x)] - \mathbb{E}_{x \sim \mathbb{P}_g} [f(x)]$$ +among all 1-Lipschitz functions. + +For $K$-Lipschitz functions, +$$ +W(\mathbb{P}_r, \mathbb{P}_g) = + \underset{\Vert f \Vert_L \le K} {\mathrm{sup}} + \mathbb{E}_{x \sim \mathbb{P}_r} \Bigg[\frac{1}{K} f(x) \Bigg] + - \mathbb{E}_{x \sim \mathbb{P}_g} \Bigg[\frac{1}{K} f(x) \Bigg] +$$ + +If all $K$-Lipschitz functions can be represented as $f_w$ where $f$ is parameterized by +$w \in \mathcal{W}$, + +$$ +K \cdot W(\mathbb{P}_r, \mathbb{P}_g) = + \max_{w \in \mathcal{W}} + \mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)]- \mathbb{E}_{x \sim \mathbb{P}_g} [f_w(x)] +$$ + +If $(\mathbb{P}_{g})$ is represented by a generator $$g_\theta (z)$$ and $z$ is from a known +distribution $z \sim p(z)$, + +$$ +K \ cdot W(\mathbb{P}_r, \mathbb{P}_\theta) = + \max_{w \in \mathcal{W}} + \mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)]- \mathbb{E}_{z \sim p(z)} [f_w(g_\theta(z))] +$$ + +Now to converge $g_\theta$ with $\mathbb{P}_{r}$ we can gradient descent on $\theta$ +to minimize above formula. + +Similarly we can find $\max_{w \in \mathcal{W}}$ by ascending on $w$, +while keeping $K$ bounded. *One way to keep $K$ bounded is to clip all weights in the neural +network that defines $f$ clipped within a range.* + +Here is the code to try this on a [simple MNIST generation experiment](experiment.html). + +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/gan/wasserstein/experiment.ipynb) +""" + import torch.utils.data from torch.nn import functional as F @@ -8,34 +91,38 @@ from labml_helpers.module import Module class DiscriminatorLoss(Module): """ ## Discriminator Loss + + We want to find $w$ to maximize + $$\mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)]- \mathbb{E}_{z \sim p(z)} [f_w(g_\theta(z))]$$, + so we minimize, + $$-\frac{1}{m} \sum_{i=1}^m f_w \big(x^{(i)} \big) + + \frac{1}{m} \sum_{i=1}^m f_w \big( g_\theta(z^{(i)}) \big)$$ """ - def __init__(self): - super().__init__() - - def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor): + def __call__(self, f_real: torch.Tensor, f_fake: torch.Tensor): """ - `logits_true` are logits from $D(\pmb{x}^{(i)})$ and - `logits_false` are logits from $D(G(\pmb{z}^{(i)}))$ + * `f_real` is $f_w(x)$ + * `f_fake` is $f_w(g_\theta(z))$ """ - return F.relu(1 - logits_true).mean(), F.relu(1 + logits_false).mean() + # We use ReLUs to clip the loss to keep $f \in [-1, +1]$ range. 
+ return F.relu(1 - f_real).mean(), F.relu(1 + f_fake).mean() class GeneratorLoss(Module): """ ## Generator Loss + + We want to find $\theta$ to minimize + $$\mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)]- \mathbb{E}_{z \sim p(z)} [f_w(g_\theta(z))]$$ + The first component is independent of $\theta$, + so we minimize, + $$-\frac{1}{m} \sum_{i=1}^m f_w \big( g_\theta(z^{(i)}) \big)$$ + """ - def __init__(self): - super().__init__() - - def __call__(self, logits: torch.Tensor): - return -logits.mean() - - -def _create_labels(n: int, r1: float, r2: float, device: torch.device = None): - """ - Create smoothed labels - """ - return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2) + def __call__(self, f_fake: torch.Tensor): + """ + * `f_fake` is $f_w(g_\theta(z))$ + """ + return -f_fake.mean() diff --git a/labml_nn/gan/wasserstein/experiment.py b/labml_nn/gan/wasserstein/experiment.py index 3eb6f795..54a71fc3 100644 --- a/labml_nn/gan/wasserstein/experiment.py +++ b/labml_nn/gan/wasserstein/experiment.py @@ -1,18 +1,31 @@ -# We import the [DCGAN experiment]((../dcgan.html) and change the -# loss functions +""" +--- +title: WGAN experiment with MNIST +summary: This experiment generates MNIST images using convolutional neural network. +--- + +# WGAN experiment with MNIST +""" from labml import experiment from labml.configs import calculate +# Import configurations from [DCGAN experiment](../dcgan/index.html) from labml_nn.gan.dcgan import Configs + +# Import [Wasserstein GAN losses](./index.html) from labml_nn.gan.wasserstein import GeneratorLoss, DiscriminatorLoss +# Set configurations options for Wasserstein GAN losses calculate(Configs.generator_loss, 'wasserstein', lambda c: GeneratorLoss()) calculate(Configs.discriminator_loss, 'wasserstein', lambda c: DiscriminatorLoss()) def main(): + # Create configs object conf = Configs() + # Create experiment experiment.create(name='mnist_wassertein_dcgan', comment='test') + # Override configurations experiment.configs(conf, { 'discriminator': 'cnn', @@ -21,6 +34,8 @@ def main(): 'generator_loss': 'wasserstein', 'discriminator_loss': 'wasserstein', }) + + # Start the experiment and run training loop with experiment.start(): conf.run()