🤦‍♂️ fix

This commit is contained in:
Varuna Jayasiri
2021-10-21 11:47:05 +05:30
parent 8aa83ddf7b
commit 77bf55e03a
178 changed files with 35256 additions and 9727 deletions

View File

@ -24,8 +24,6 @@
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/gan/original/experiment.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
@ -70,7 +68,6 @@
<a href='#section-0'>#</a>
</div>
<h1>Generative Adversarial Networks experiment with MNIST</h1>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">10</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span>
@ -113,9 +110,8 @@
<a href='#section-2'>#</a>
</div>
<h3>Simple MLP Generator</h3>
<p>This has three linear layers of increasing size with <code>LeakyReLU</code>
activations. The final layer has a <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathnormal">t</span><span class="mord mathnormal">anh</span></span></span></span> activation.</p>
<p>This has three linear layers of increasing size with <code>LeakyReLU</code> activations.
The final layer has a $tanh$ activation.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">36</span><span class="k">class</span> <span class="nc">Generator</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
@ -161,9 +157,9 @@
<a href='#section-5'>#</a>
</div>
<h3>Simple MLP Discriminator</h3>
<p>This has three linear layers of decreasing size with <code>LeakyReLU</code>
activations. The final layer has a single output that gives the logit of whether input is real or fake. You can get the probability by calculating the sigmoid of it.</p>
<p>This has three linear layers of decreasing size with <code>LeakyReLU</code> activations.
The final layer has a single output that gives the logit of whether input
is real or fake. You can get the probability by calculating the sigmoid of it.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">61</span><span class="k">class</span> <span class="nc">Discriminator</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
@ -208,8 +204,8 @@
<a href='#section-8'>#</a>
</div>
<h2>Configurations</h2>
<p>This extends MNIST configurations to get the data loaders and Training and validation loop configurations to simplify our implementation.</p>
<p>This extends MNIST configurations to get the data loaders and Training and validation loop
configurations to simplify our implementation.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">86</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">MNISTConfigs</span><span class="p">,</span> <span class="n">TrainValidConfigs</span><span class="p">):</span></pre></div>
@ -243,8 +239,7 @@
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p> Initializations</p>
<p>Initializations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">108</span> <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
@ -272,8 +267,9 @@
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p> <span class="katex-display"><span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="mclose">)</span></span></span></span></span></p>
<p>
<script type="math/tex; mode=display">z \sim p(z)</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">120</span> <span class="k">def</span> <span class="nf">sample_z</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
@ -295,8 +291,7 @@
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p> Take a training step</p>
<p>Take a training step</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">126</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="n">BatchIndex</span><span class="p">):</span></pre></div>
@ -307,8 +302,7 @@
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Set model states </p>
<p>Set model states</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">132</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">)</span>
@ -320,8 +314,7 @@
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Get MNIST images </p>
<p>Get MNIST images</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="n">data</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
@ -332,8 +325,7 @@
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Increment step in training mode </p>
<p>Increment step in training mode</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">139</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
@ -345,8 +337,7 @@
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Train the discriminator </p>
<p>Train the discriminator</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">143</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s2">&quot;discriminator&quot;</span><span class="p">):</span></pre></div>
@ -357,8 +348,7 @@
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Get discriminator loss </p>
<p>Get discriminator loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">145</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">calc_discriminator_loss</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
@ -369,8 +359,7 @@
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Train </p>
<p>Train</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">148</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
@ -386,9 +375,7 @@
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Train the generator once in every <code>discriminator_k</code>
</p>
<p>Train the generator once in every <code>discriminator_k</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="k">if</span> <span class="n">batch_idx</span><span class="o">.</span><span class="n">is_interval</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_k</span><span class="p">):</span>
@ -401,8 +388,7 @@
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Train </p>
<p>Train</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">161</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
@ -420,8 +406,7 @@
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p> Calculate discriminator loss</p>
<p>Calculate discriminator loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">170</span> <span class="k">def</span> <span class="nf">calc_discriminator_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span></pre></div>
@ -447,8 +432,7 @@
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Log stuff </p>
<p>Log stuff</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">181</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;loss.discriminator.true.&quot;</span><span class="p">,</span> <span class="n">loss_true</span><span class="p">)</span>
@ -463,8 +447,7 @@
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p> Calculate generator loss</p>
<p>Calculate generator loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">187</span> <span class="k">def</span> <span class="nf">calc_generator_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
@ -489,8 +472,7 @@
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Log stuff </p>
<p>Log stuff</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;generated&#39;</span><span class="p">,</span> <span class="n">generated_images</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">6</span><span class="p">])</span>
@ -528,10 +510,9 @@
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Setting exponent decay rate for first moment of gradient, <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> to <code>0.5</code>
is important. Default of <code>0.9</code>
fails. </p>
<p>Setting exponent decay rate for first moment of gradient,
$\beta_1$ to <code>0.5</code> is important.
Default of <code>0.9</code> fails.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">222</span> <span class="n">opt_conf</span><span class="o">.</span><span class="n">betas</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">)</span>
@ -559,10 +540,9 @@
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Setting exponent decay rate for first moment of gradient, <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> to <code>0.5</code>
is important. Default of <code>0.9</code>
fails. </p>
<p>Setting exponent decay rate for first moment of gradient,
$\beta_1$ to <code>0.5</code> is important.
Default of <code>0.9</code> fails.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">235</span> <span class="n">opt_conf</span><span class="o">.</span><span class="n">betas</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">)</span>
@ -601,6 +581,24 @@
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')