unescape *

This commit is contained in:
Varuna Jayasiri
2021-10-21 11:46:06 +05:30
parent 4c5a706836
commit 8aa83ddf7b
179 changed files with 9727 additions and 35256 deletions

View File

@ -24,6 +24,8 @@
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/gan/wasserstein/experiment.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
@ -68,6 +70,7 @@
<a href='#section-0'>#</a>
</div>
<h1>WGAN experiment with MNIST</h1>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">9</span><span></span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
@ -80,7 +83,8 @@
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<p>Import configurations from <a href="../dcgan/index.html">DCGAN experiment</a></p>
<p>Import configurations from <a href="../dcgan/index.html">DCGAN experiment</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">13</span><span class="kn">from</span> <span class="nn">labml_nn.gan.dcgan</span> <span class="kn">import</span> <span class="n">Configs</span></pre></div>
@ -91,7 +95,8 @@
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<p>Import <a href="./index.html">Wasserstein GAN losses</a></p>
<p>Import <a href="./index.html">Wasserstein GAN losses</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml_nn.gan.wasserstein</span> <span class="kn">import</span> <span class="n">GeneratorLoss</span><span class="p">,</span> <span class="n">DiscriminatorLoss</span></pre></div>
@ -102,7 +107,8 @@
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>Set configurations options for Wasserstein GAN losses</p>
<p>Set configuration options for Wasserstein GAN losses </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">19</span><span class="n">calculate</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">generator_loss</span><span class="p">,</span> <span class="s1">&#39;wasserstein&#39;</span><span class="p">,</span> <span class="k">lambda</span> <span class="n">c</span><span class="p">:</span> <span class="n">GeneratorLoss</span><span class="p">())</span>
@ -125,7 +131,8 @@
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Create configs object</p>
<p>Create configs object </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">25</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
@ -136,7 +143,8 @@
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Create experiment</p>
<p>Create experiment </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">27</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;mnist_wassertein_dcgan&#39;</span><span class="p">,</span> <span class="n">comment</span><span class="o">=</span><span class="s1">&#39;test&#39;</span><span class="p">)</span></pre></div>
@ -147,7 +155,8 @@
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Override configurations</p>
<p>Override configurations </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">29</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span>
@ -165,7 +174,8 @@
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Start the experiment and run training loop</p>
<p>Start the experiment and run training loop </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">39</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
@ -181,24 +191,6 @@
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')

View File

@ -24,6 +24,8 @@
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/gan/wasserstein/gradient_penalty/experiment.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
@ -69,6 +71,7 @@
<a href='#section-0'>#</a>
</div>
<h1>WGAN-GP experiment with MNIST</h1>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">10</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
@ -81,7 +84,8 @@
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<p>Import configurations from <a href="../experiment.html">Wasserstein experiment</a></p>
<p>Import configurations from <a href="../experiment.html">Wasserstein experiment</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">14</span><span class="kn">from</span> <span class="nn">labml_nn.gan.wasserstein.experiment</span> <span class="kn">import</span> <span class="n">Configs</span> <span class="k">as</span> <span class="n">OriginalConfigs</span></pre></div>
@ -92,7 +96,8 @@
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml_nn.gan.wasserstein.gradient_penalty</span> <span class="kn">import</span> <span class="n">GradientPenalty</span></pre></div>
@ -104,8 +109,8 @@
<a href='#section-3'>#</a>
</div>
<h2>Configuration class</h2>
<p>We extend <a href="../../original/experiment.html">original GAN implementation</a> and override the discriminator (critic) loss
calculation to include gradient penalty.</p>
<p>We extend the <a href="../../original/experiment.html">original GAN implementation</a> and override the discriminator (critic) loss calculation to include gradient penalty.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">19</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">OriginalConfigs</span><span class="p">):</span></pre></div>
@ -116,7 +121,8 @@ calculation to include gradient penalty.</p>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Gradient penalty coefficient $\lambda$</p>
<p>Gradient penalty coefficient <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathnormal">λ</span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">28</span> <span class="n">gradient_penalty_coefficient</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10.0</span></pre></div>
@ -127,7 +133,8 @@ calculation to include gradient penalty.</p>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">30</span> <span class="n">gradient_penalty</span> <span class="o">=</span> <span class="n">GradientPenalty</span><span class="p">()</span></pre></div>
@ -138,8 +145,8 @@ calculation to include gradient penalty.</p>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>This overrides the original discriminator loss calculation and
includes gradient penalty.</p>
<p> This overrides the original discriminator loss calculation and includes gradient penalty.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">32</span> <span class="k">def</span> <span class="nf">calc_discriminator_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
@ -150,7 +157,8 @@ includes gradient penalty.</p>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Require gradients on $x$ to calculate gradient penalty</p>
<p>Require gradients on <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal">x</span></span></span></span> to calculate gradient penalty </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">38</span> <span class="n">data</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">()</span></pre></div>
@ -161,7 +169,8 @@ includes gradient penalty.</p>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Sample $z \sim p(z)$</p>
<p>Sample <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">∼</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="mclose">)</span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">40</span> <span class="n">latent</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample_z</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span></pre></div>
@ -172,7 +181,8 @@ includes gradient penalty.</p>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>$D(x)$</p>
<p><span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">D</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">42</span> <span class="n">f_real</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
@ -183,7 +193,8 @@ includes gradient penalty.</p>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>$D(G_\theta(z))$</p>
<p><span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">D</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">G</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="mclose">))</span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">44</span> <span class="n">f_fake</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">(</span><span class="n">latent</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span></pre></div>
@ -194,7 +205,8 @@ includes gradient penalty.</p>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Get discriminator losses</p>
<p>Get discriminator losses </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">46</span> <span class="n">loss_true</span><span class="p">,</span> <span class="n">loss_false</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_loss</span><span class="p">(</span><span class="n">f_real</span><span class="p">,</span> <span class="n">f_fake</span><span class="p">)</span></pre></div>
@ -205,7 +217,8 @@ includes gradient penalty.</p>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Calculate gradient penalties in training mode</p>
<p>Calculate gradient penalties in training mode </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">48</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
@ -219,7 +232,8 @@ includes gradient penalty.</p>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Skip gradient penalty otherwise</p>
<p>Skip gradient penalty otherwise </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">53</span> <span class="k">else</span><span class="p">:</span>
@ -231,7 +245,8 @@ includes gradient penalty.</p>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Log stuff</p>
<p>Log stuff </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">57</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;loss.discriminator.true.&quot;</span><span class="p">,</span> <span class="n">loss_true</span><span class="p">)</span>
@ -257,7 +272,8 @@ includes gradient penalty.</p>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Create configs object</p>
<p>Create configs object </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">66</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
@ -268,7 +284,8 @@ includes gradient penalty.</p>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Create experiment</p>
<p>Create experiment </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;mnist_wassertein_gp_dcgan&#39;</span><span class="p">)</span></pre></div>
@ -279,7 +296,8 @@ includes gradient penalty.</p>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Override configurations</p>
<p>Override configurations </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span>
@ -298,7 +316,8 @@ includes gradient penalty.</p>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Start the experiment and run training loop</p>
<p>Start the experiment and run training loop </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">81</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
@ -314,24 +333,6 @@ includes gradient penalty.</p>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')

File diff suppressed because one or more lines are too long

View File

@ -24,6 +24,8 @@
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/gan/wasserstein/gradient_penalty/readme.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
@ -69,19 +71,11 @@
<a href='#section-0'>#</a>
</div>
<h1><a href="https://nn.labml.ai/gan/wasserstein/gradient_penalty/index.html">Gradient Penalty for Wasserstein GAN (WGAN-GP)</a></h1>
<p>This is an implementation of
<a href="https://papers.labml.ai/paper/1704.00028">Improved Training of Wasserstein GANs</a>.</p>
<p><a href="https://nn.labml.ai/gan/wasserstein/index.html">WGAN</a> suggests
clipping weights to enforce Lipschitz constraint
on the discriminator network (critic).
This and other weight constraints like L2 norm clipping, weight normalization,
L1, L2 weight decay have problems:</p>
<ol>
<li>Limiting the capacity of the discriminator</li>
<li>Exploding and vanishing gradients (without <a href="https://nn.labml.ai/normalization/batch_norm/index.html">Batch Normalization</a>).</li>
</ol>
<p>The paper <a href="https://papers.labml.ai/paper/1704.00028">Improved Training of Wasserstein GANs</a>
proposal a better way to improve Lipschitz constraint, a gradient penalty.</p>
<p>This is an implementation of <a href="https://papers.labml.ai/paper/1704.00028">Improved Training of Wasserstein GANs</a>.</p>
<p><a href="https://nn.labml.ai/gan/wasserstein/index.html">WGAN</a> suggests clipping weights to enforce Lipschitz constraint on the discriminator network (critic). This and other weight constraints like L2 norm clipping, weight normalization, L1, L2 weight decay have problems:</p>
<ol>
<li>Limiting the capacity of the discriminator</li>
<li>Exploding and vanishing gradients (without <a href="https://nn.labml.ai/normalization/batch_norm/index.html">Batch Normalization</a>).</li>
</ol>
<p>The paper <a href="https://papers.labml.ai/paper/1704.00028">Improved Training of Wasserstein GANs</a> proposes a better way to improve Lipschitz constraint, a gradient penalty. </p>
</div>
<div class='code'>
@ -92,24 +86,6 @@ proposal a better way to improve Lipschitz constraint, a gradient penalty.</p>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')

File diff suppressed because one or more lines are too long

View File

@ -24,6 +24,8 @@
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/gan/wasserstein/readme.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
@ -68,8 +70,8 @@
<a href='#section-0'>#</a>
</div>
<h1><a href="https://nn.labml.ai/gan/wasserstein/index.html">Wasserstein GAN - WGAN</a></h1>
<p>This is an implementation of
<a href="https://papers.labml.ai/paper/1701.07875">Wasserstein GAN</a>.</p>
<p>This is an implementation of <a href="https://papers.labml.ai/paper/1701.07875">Wasserstein GAN</a>. </p>
</div>
<div class='code'>
@ -80,24 +82,6 @@
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')