Varuna Jayasiri 591bfc3714 typosg
2021-05-07 16:45:38 +05:30


<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A simple PyTorch implementation/tutorial of Wasserstein Generative Adversarial Networks (WGAN) loss functions."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Wasserstein GAN (WGAN)"/>
<meta name="twitter:description" content="A simple PyTorch implementation/tutorial of Wasserstein Generative Adversarial Networks (WGAN) loss functions."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/gan/wasserstein/index.html"/>
<meta property="og:title" content="Wasserstein GAN (WGAN)"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="A simple PyTorch implementation/tutorial of Wasserstein Generative Adversarial Networks (WGAN) loss functions."/>
<title>Wasserstein GAN (WGAN)</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/gan/wasserstein/index.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">gan</a>
<a class="parent" href="index.html">wasserstein</a>
</p>
<p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/gan/wasserstein/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a>
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
rel="nofollow">
<img alt="Join Slack"
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Wasserstein GAN (WGAN)</h1>
<p>This is an implementation of
<a href="https://arxiv.org/abs/1701.07875">Wasserstein GAN</a>.</p>
<p>The original GAN loss is based on Jensen-Shannon (JS) divergence
between the real distribution $\mathbb{P}_r$ and generated distribution $\mathbb{P}_g$.
The Wasserstein GAN is instead based on the Earth Mover (Wasserstein-1) distance between these distributions.</p>
<p>
<script type="math/tex; mode=display">
W(\mathbb{P}_r, \mathbb{P}_g) =
\underset{\gamma \in \Pi(\mathbb{P}_r, \mathbb{P}_g)} {\mathrm{inf}}
\mathbb{E}_{(x,y) \sim \gamma}
\Vert x - y \Vert
</script>
</p>
<p>$\Pi(\mathbb{P}_r, \mathbb{P}_g)$ is the set of all joint distributions $\gamma(x, y)$ whose
marginals are $\mathbb{P}_r$ and $\mathbb{P}_g$.</p>
<p>$\mathbb{E}_{(x,y) \sim \gamma} \Vert x - y \Vert$ is the earth mover distance for
a given joint distribution ($x$ and $y$ are samples drawn jointly from $\gamma$).</p>
<p>So $W(\mathbb{P}_r, \mathbb{P}_g)$ is equal to the least earth mover distance over
all joint distributions between the real distribution $\mathbb{P}_r$ and generated distribution $\mathbb{P}_g$.</p>
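<p>As a small illustration of the infimum over joint distributions, here is a sketch on a toy pair of
discrete distributions. The setup is an assumption made for this example only: $\mathbb{P}_r$ puts mass 0.5 on each of
the points 0 and 2, and $\mathbb{P}_g$ puts mass 0.5 on each of 1 and 3. The helper names <code>coupling</code>,
<code>marginals</code> and <code>expected_cost</code> are hypothetical and not part of this implementation.</p>

```python
def coupling(t):
    """A joint distribution over (x, y) with x in {0, 2} and y in {1, 3}.

    Every t in [0, 0.5] keeps both marginals uniform, and every joint
    distribution with these two marginals has this form.
    """
    return {(0, 1): t, (0, 3): 0.5 - t, (2, 1): 0.5 - t, (2, 3): t}


def marginals(gamma):
    """Sum the joint mass over y (giving P_r) and over x (giving P_g)."""
    p_r, p_g = {}, {}
    for (x, y), mass in gamma.items():
        p_r[x] = p_r.get(x, 0.0) + mass
        p_g[y] = p_g.get(y, 0.0) + mass
    return p_r, p_g


def expected_cost(gamma):
    """E_{(x, y) ~ gamma} |x - y| for a discrete joint distribution."""
    return sum(mass * abs(x - y) for (x, y), mass in gamma.items())


# Scan a grid of valid couplings and keep the cheapest one.
candidates = [0.5 * i / 10 for i in range(11)]
costs = {t: expected_cost(coupling(t)) for t in candidates}
best_t = min(costs, key=costs.get)
w_estimate = costs[best_t]
```

<p>Here the cost is linear in <code>t</code>, so the grid minimum is the exact infimum: the coupling that moves
0 to 1 and 2 to 3 (<code>t = 0.5</code>) is the cheapest, while the independent coupling (<code>t = 0.25</code>) costs more.
For general distributions the set of couplings is far larger and cannot be scanned like this.</p>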
<p>The paper shows that Jensen-Shannon (JS) divergence and several other measures of the difference between two probability
distributions are not always continuous in the parameters of a parameterized distribution. So gradient descent on those
parameters can fail to converge, because the gradient gives no useful signal where the measure is flat or jumps.</p>
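<p>The following sketch compares the two measures on a deliberately simple case, assumed here for illustration:
both distributions are point masses on a five-point 1D grid, and the "parameter" is the grid index of the generated
point mass. The function names are hypothetical. The 1D earth mover distance is computed from the difference of the
cumulative distributions, which is valid for distributions on the same evenly spaced grid.</p>

```python
import math


def point_mass(index, size):
    """Distribution on a grid of `size` points with all mass at `index`."""
    return [1.0 if i == index else 0.0 for i in range(size)]


def js_divergence(p, q):
    """Jensen-Shannon divergence between two discrete distributions."""
    m = [(a + b) / 2 for a, b in zip(p, q)]

    def kl(a, b):
        # Terms with zero mass in `a` contribute nothing to the sum.
        return sum(x * math.log(x / y) for x, y in zip(a, b) if x > 0)

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def wasserstein_1d(p, q, spacing=1.0):
    """Earth mover distance on an evenly spaced 1D grid via CDF differences."""
    total, cdf_p, cdf_q = 0.0, 0.0, 0.0
    for a, b in zip(p, q):
        cdf_p += a
        cdf_q += b
        total += abs(cdf_p - cdf_q) * spacing
    return total


size = 5
p_real = point_mass(0, size)
js_values = [js_divergence(p_real, point_mass(k, size)) for k in range(size)]
w_values = [wasserstein_1d(p_real, point_mass(k, size)) for k in range(size)]
```

<p>In this toy case the JS divergence is $\log 2$ for every non-zero offset and drops to zero only when the two
point masses coincide, so it says nothing about which direction reduces the gap. The earth mover distance equals the
offset itself and shrinks steadily as the generated point mass moves toward the real one.</p>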
<p>Based on Kantorovich-Rubinstein duality,
<script type="math/tex; mode=display">
W(\mathbb{P}_r, \mathbb{P}_g) =
\underset{\Vert f \Vert_L \le 1} {\mathrm{sup}}
\mathbb{E}_{x \sim \mathbb{P}_r} [f(x)]- \mathbb{E}_{x \sim \mathbb{P}_g} [f(x)]
</script>
</p>
<p>where the supremum is taken over all 1-Lipschitz functions $f$, that is, functions with $\Vert f \Vert_L \le 1$.</p>
<p>That is, it is equal to the greatest difference
<script type="math/tex; mode=display">\mathbb{E}_{x \sim \mathbb{P}_r} [f(x)] - \mathbb{E}_{x \sim \mathbb{P}_g} [f(x)]</script>
among all 1-Lipschitz functions.</p>
<p>For $K$-Lipschitz functions,
<script type="math/tex; mode=display">
W(\mathbb{P}_r, \mathbb{P}_g) =
\underset{\Vert f \Vert_L \le K} {\mathrm{sup}}
\mathbb{E}_{x \sim \mathbb{P}_r} \Bigg[\frac{1}{K} f(x) \Bigg]
- \mathbb{E}_{x \sim \mathbb{P}_g} \Bigg[\frac{1}{K} f(x) \Bigg]
</script>
</p>
<p>If all $K$-Lipschitz functions can be represented as $f_w$ where $f$ is parameterized by
$w \in \mathcal{W}$,</p>
<p>
<script type="math/tex; mode=display">
K \cdot W(\mathbb{P}_r, \mathbb{P}_g) =
\max_{w \in \mathcal{W}}
\mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)]- \mathbb{E}_{x \sim \mathbb{P}_g} [f_w(x)]
</script>
</p>
<p>If $\mathbb{P}_{g}$ is represented by a generator $g_\theta (z)$ and $z$ is from a known
distribution $z \sim p(z)$,</p>
<p>
<script type="math/tex; mode=display">
K \cdot W(\mathbb{P}_r, \mathbb{P}_g) =
\max_{w \in \mathcal{W}}
\mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)]- \mathbb{E}_{z \sim p(z)} [f_w(g_\theta(z))]
</script>
</p>
<p>Now, to make the distribution of $g_\theta$ converge to $\mathbb{P}_{r}$, we can do gradient descent on $\theta$
to minimize the above formula.</p>
<p>Similarly, we can find $\max_{w \in \mathcal{W}}$ by gradient ascent on $w$,
while keeping $K$ bounded. <em>One way to keep $K$ bounded is to clip all weights in the neural
network that defines $f$ to within a fixed range.</em></p>
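<p>A minimal sketch of one ascent step on $w$ with weight clipping might look like the following. Everything in it
is an assumption for illustration rather than the code used in the experiment: the tiny <code>critic</code> and
<code>generator</code> networks, the 2D data, the RMSprop learning rate and the <code>clip_value</code> of 0.01.
It ascends on $w$ by minimizing the negated objective.</p>

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy networks: f_w maps 2D points to a score, g_theta maps 4D noise to 2D points.
critic = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1))
generator = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))

critic_opt = torch.optim.RMSprop(critic.parameters(), lr=5e-5)
clip_value = 0.01  # illustrative clipping range [-0.01, 0.01]


def critic_step(real):
    """One gradient step on w, followed by clipping every critic weight."""
    z = torch.randn(real.shape[0], 4)
    fake = generator(z).detach()  # the generator is held fixed in this step

    # Minimizing this is the same as maximizing E[f_w(x)] - E[f_w(g_theta(z))].
    loss = -critic(real).mean() + critic(fake).mean()

    critic_opt.zero_grad()
    loss.backward()
    critic_opt.step()

    # Clamp all parameters in place so they stay inside the fixed range.
    with torch.no_grad():
        for param in critic.parameters():
            param.clamp_(-clip_value, clip_value)

    return loss.item()


step_loss = critic_step(torch.randn(8, 2))
```

<p>Clipping after every optimizer step keeps the weights in a compact set, which bounds the Lipschitz constant of
the critic, although the resulting $K$ depends on the architecture and the clipping range.</p>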
<p>Here is the code to try this on a <a href="experiment.html">simple MNIST generation experiment</a>.</p>
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/gan/wasserstein/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">87</span><span></span><span class="kn">import</span> <span class="nn">torch.utils.data</span>
<span class="lineno">88</span><span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">functional</span> <span class="k">as</span> <span class="n">F</span>
<span class="lineno">89</span>
<span class="lineno">90</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Discriminator Loss</h2>
<p>We want to find $w$ to maximize
<script type="math/tex; mode=display">\mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)]- \mathbb{E}_{z \sim p(z)} [f_w(g_\theta(z))]</script>,
so we minimize,
<script type="math/tex; mode=display">-\frac{1}{m} \sum_{i=1}^m f_w \big(x^{(i)} \big) +
\frac{1}{m} \sum_{i=1}^m f_w \big( g_\theta(z^{(i)}) \big)</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">93</span><span class="k">class</span> <span class="nc">DiscriminatorLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul>
<li><code>f_real</code> is $f_w(x)$</li>
<li><code>f_fake</code> is $f_w(g_\theta(z))$</li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">104</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">f_real</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">f_fake</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>We use ReLUs to clip the loss terms, so there is no incentive to push $f$ outside the $[-1, +1]$ range.</p>
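<p>To see the effect of the ReLUs, here is a small sketch that evaluates the two loss terms on made-up critic
outputs. The tensors <code>f_real</code> and <code>f_fake</code> hold hypothetical values chosen for this example.</p>

```python
import torch
from torch.nn import functional as F

# Hypothetical critic outputs for three real and three generated samples.
f_real = torch.tensor([2.0, 0.5, -1.0])
f_fake = torch.tensor([-2.0, -0.5, 1.0])

# Per-sample terms before averaging.
real_terms = F.relu(1 - f_real)  # zero wherever f_real is at least 1
fake_terms = F.relu(1 + f_fake)  # zero wherever f_fake is at most -1

loss_real = real_terms.mean()
loss_fake = fake_terms.mean()
```

<p>The first sample in each tensor already lies beyond the margin, so its term is zero and it contributes no
gradient. Only samples inside the margin or on the wrong side of it affect the update.</p>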
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">111</span> <span class="k">return</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">f_real</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(),</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">f_fake</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<h2>Generator Loss</h2>
<p>We want to find $\theta$ to minimize
<script type="math/tex; mode=display">\mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)]- \mathbb{E}_{z \sim p(z)} [f_w(g_\theta(z))]</script>
The first component is independent of $\theta$,
so we minimize,
<script type="math/tex; mode=display">-\frac{1}{m} \sum_{i=1}^m f_w \big( g_\theta(z^{(i)}) \big)</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">114</span><span class="k">class</span> <span class="nc">GeneratorLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<ul>
<li><code>f_fake</code> is $f_w(g_\theta(z))$</li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">126</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">f_fake</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">130</span> <span class="k">return</span> <span class="o">-</span><span class="n">f_fake</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span></pre></div>
</div>
</div>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
</body>
</html>