unescape *

Varuna Jayasiri
2021-10-21 11:46:06 +05:30
parent 4c5a706836
commit 8aa83ddf7b
179 changed files with 9727 additions and 35256 deletions

@ -24,6 +24,8 @@
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/gan/original/experiment.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
@ -68,6 +70,7 @@
<a href='#section-0'>#</a>
</div>
<h1>Generative Adversarial Networks experiment with MNIST</h1>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">10</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span>
@ -110,8 +113,9 @@
<a href='#section-2'>#</a>
</div>
<h3>Simple MLP Generator</h3>
<p>This has three linear layers of increasing size with <code>LeakyReLU</code> activations.
The final layer has a $tanh$ activation.</p>
<p>This has three linear layers of increasing size with <code>LeakyReLU</code>
activations. The final layer has a <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathnormal">t</span><span class="mord mathnormal">anh</span></span></span></span> activation.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">36</span><span class="k">class</span> <span class="nc">Generator</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
@ -157,9 +161,9 @@ The final layer has a $tanh$ activation.</p>
<a href='#section-5'>#</a>
</div>
<h3>Simple MLP Discriminator</h3>
<p>This has three linear layers of decreasing size with <code>LeakyReLU</code> activations.
The final layer has a single output that gives the logit of whether the input
is real or fake. You can get the probability by calculating the sigmoid of it.</p>
<p>This has three linear layers of decreasing size with <code>LeakyReLU</code>
activations. The final layer has a single output that gives the logit of whether the input is real or fake. You can get the probability by calculating the sigmoid of it.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">61</span><span class="k">class</span> <span class="nc">Discriminator</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
@ -204,8 +208,8 @@ is real or fake. You can get the probability by calculating the sigmoid of it.</
<a href='#section-8'>#</a>
</div>
<h2>Configurations</h2>
<p>This extends the MNIST configurations, which provide the data loaders, and the training and validation loop
configurations to simplify our implementation.</p>
<p>This extends the MNIST configurations, which provide the data loaders, and the training and validation loop configurations to simplify our implementation.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">86</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">MNISTConfigs</span><span class="p">,</span> <span class="n">TrainValidConfigs</span><span class="p">):</span></pre></div>
@ -239,7 +243,8 @@ configurations to simplify our implementation.</p>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Initializations</p>
<p> Initializations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">108</span> <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
@ -267,9 +272,8 @@ configurations to simplify our implementation.</p>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>
<script type="math/tex; mode=display">z \sim p(z)</script>
</p>
<p> <span class="katex-display"><span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">∼</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="mclose">)</span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">120</span> <span class="k">def</span> <span class="nf">sample_z</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
@ -291,7 +295,8 @@ configurations to simplify our implementation.</p>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Take a training step</p>
<p> Take a training step</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">126</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="n">BatchIndex</span><span class="p">):</span></pre></div>
@ -302,7 +307,8 @@ configurations to simplify our implementation.</p>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Set model states</p>
<p>Set model states </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">132</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">)</span>
@ -314,7 +320,8 @@ configurations to simplify our implementation.</p>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Get MNIST images</p>
<p>Get MNIST images </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="n">data</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
@ -325,7 +332,8 @@ configurations to simplify our implementation.</p>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Increment step in training mode</p>
<p>Increment step in training mode </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">139</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
@ -337,7 +345,8 @@ configurations to simplify our implementation.</p>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Train the discriminator</p>
<p>Train the discriminator </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">143</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s2">&quot;discriminator&quot;</span><span class="p">):</span></pre></div>
@ -348,7 +357,8 @@ configurations to simplify our implementation.</p>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Get discriminator loss</p>
<p>Get discriminator loss </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">145</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">calc_discriminator_loss</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
@ -359,7 +369,8 @@ configurations to simplify our implementation.</p>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Train</p>
<p>Train </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">148</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
@ -375,7 +386,9 @@ configurations to simplify our implementation.</p>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Train the generator once in every <code>discriminator_k</code></p>
<p>Train the generator once in every <code>discriminator_k</code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="k">if</span> <span class="n">batch_idx</span><span class="o">.</span><span class="n">is_interval</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_k</span><span class="p">):</span>
@ -388,7 +401,8 @@ configurations to simplify our implementation.</p>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Train</p>
<p>Train </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">161</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
@ -406,7 +420,8 @@ configurations to simplify our implementation.</p>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Calculate discriminator loss</p>
<p> Calculate discriminator loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">170</span> <span class="k">def</span> <span class="nf">calc_discriminator_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span></pre></div>
@ -432,7 +447,8 @@ configurations to simplify our implementation.</p>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Log stuff</p>
<p>Log stuff </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">181</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;loss.discriminator.true.&quot;</span><span class="p">,</span> <span class="n">loss_true</span><span class="p">)</span>
@ -447,7 +463,8 @@ configurations to simplify our implementation.</p>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Calculate generator loss</p>
<p> Calculate generator loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">187</span> <span class="k">def</span> <span class="nf">calc_generator_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
@ -472,7 +489,8 @@ configurations to simplify our implementation.</p>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Log stuff</p>
<p>Log stuff </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;generated&#39;</span><span class="p">,</span> <span class="n">generated_images</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">6</span><span class="p">])</span>
@ -510,9 +528,10 @@ configurations to simplify our implementation.</p>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Setting the exponential decay rate for the first moment of the gradient,
$\beta_1$ to <code>0.5</code> is important.
Default of <code>0.9</code> fails.</p>
<p>Setting the exponential decay rate for the first moment of the gradient, <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> to <code>0.5</code>
is important. Default of <code>0.9</code>
fails. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">222</span> <span class="n">opt_conf</span><span class="o">.</span><span class="n">betas</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">)</span>
@ -540,9 +559,10 @@ Default of <code>0.9</code> fails.</p>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Setting the exponential decay rate for the first moment of the gradient,
$\beta_1$ to <code>0.5</code> is important.
Default of <code>0.9</code> fails.</p>
<p>Setting the exponential decay rate for the first moment of the gradient, <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> to <code>0.5</code>
is important. Default of <code>0.9</code>
fails. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">235</span> <span class="n">opt_conf</span><span class="o">.</span><span class="n">betas</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">)</span>
@ -581,24 +601,6 @@ Default of <code>0.9</code> fails.</p>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')