mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 10:48:49 +08:00 
			
		
		
		
	unescape *
This commit is contained in:
		| @ -24,6 +24,8 @@ | ||||
|     <link rel="shortcut icon" href="/icon.png"/> | ||||
|     <link rel="stylesheet" href="../../pylit.css"> | ||||
|     <link rel="canonical" href="https://nn.labml.ai/gan/original/experiment.html"/> | ||||
|     <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous"> | ||||
|  | ||||
|     <!-- Global site tag (gtag.js) - Google Analytics --> | ||||
|     <script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script> | ||||
|     <script> | ||||
| @ -68,6 +70,7 @@ | ||||
|                 <a href='#section-0'>#</a> | ||||
|             </div> | ||||
|             <h1>Generative Adversarial Networks experiment with MNIST</h1> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">10</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span> | ||||
| @ -110,8 +113,9 @@ | ||||
|                 <a href='#section-2'>#</a> | ||||
|             </div> | ||||
|             <h3>Simple MLP Generator</h3> | ||||
| <p>This has three linear layers of increasing size with <code>LeakyReLU</code> activations. | ||||
| The final layer has a $tanh$ activation.</p> | ||||
| <p>This has three linear layers of increasing size with <code>LeakyReLU</code> | ||||
|  activations. The final layer has a <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathnormal">t</span><span class="mord mathnormal">anh</span></span></span></span> activation.</p> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">36</span><span class="k">class</span> <span class="nc">Generator</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> | ||||
| @ -157,9 +161,9 @@ The final layer has a $tanh$ activation.</p> | ||||
|                 <a href='#section-5'>#</a> | ||||
|             </div> | ||||
|             <h3>Simple MLP Discriminator</h3> | ||||
| <p>This has three  linear layers of decreasing size with <code>LeakyReLU</code> activations. | ||||
| The final layer has a single output that gives the logit of whether input | ||||
| is real or fake. You can get the probability by calculating the sigmoid of it.</p> | ||||
| <p>This has three linear layers of decreasing size with <code>LeakyReLU</code> | ||||
|  activations. The final layer has a single output that gives the logit of whether input is real or fake. You can get the probability by calculating the sigmoid of it.</p> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">61</span><span class="k">class</span> <span class="nc">Discriminator</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> | ||||
| @ -204,8 +208,8 @@ is real or fake. You can get the probability by calculating the sigmoid of it.</ | ||||
|                 <a href='#section-8'>#</a> | ||||
|             </div> | ||||
|             <h2>Configurations</h2> | ||||
| <p>This extends MNIST configurations to get the data loaders and Training and validation loop | ||||
| configurations to simplify our implementation.</p> | ||||
| <p>This extends MNIST configurations to get the data loaders and Training and validation loop configurations to simplify our implementation.</p> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">86</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">MNISTConfigs</span><span class="p">,</span> <span class="n">TrainValidConfigs</span><span class="p">):</span></pre></div> | ||||
| @ -239,7 +243,8 @@ configurations to simplify our implementation.</p> | ||||
|             <div class='section-link'> | ||||
|                 <a href='#section-10'>#</a> | ||||
|             </div> | ||||
|             <p>Initializations</p> | ||||
|             <p> Initializations</p> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">108</span>    <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div> | ||||
| @ -267,9 +272,8 @@ configurations to simplify our implementation.</p> | ||||
|             <div class='section-link'> | ||||
|                 <a href='#section-12'>#</a> | ||||
|             </div> | ||||
|             <p> | ||||
| <script type="math/tex; mode=display">z \sim p(z)</script> | ||||
| </p> | ||||
|             <p> <span class="katex-display"><span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">∼</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="mclose">)</span></span></span></span></span></p> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">120</span>    <span class="k">def</span> <span class="nf">sample_z</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div> | ||||
| @ -291,7 +295,8 @@ configurations to simplify our implementation.</p> | ||||
|             <div class='section-link'> | ||||
|                 <a href='#section-14'>#</a> | ||||
|             </div> | ||||
|             <p>Take a training step</p> | ||||
|             <p> Take a training step</p> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">126</span>    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="n">BatchIndex</span><span class="p">):</span></pre></div> | ||||
| @ -302,7 +307,8 @@ configurations to simplify our implementation.</p> | ||||
|             <div class='section-link'> | ||||
|                 <a href='#section-15'>#</a> | ||||
|             </div> | ||||
|             <p>Set model states</p> | ||||
|             <p>Set model states </p> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">132</span>        <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">)</span> | ||||
| @ -314,7 +320,8 @@ configurations to simplify our implementation.</p> | ||||
|             <div class='section-link'> | ||||
|                 <a href='#section-16'>#</a> | ||||
|             </div> | ||||
|             <p>Get MNIST images</p> | ||||
|             <p>Get MNIST images </p> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">136</span>        <span class="n">data</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div> | ||||
| @ -325,7 +332,8 @@ configurations to simplify our implementation.</p> | ||||
|             <div class='section-link'> | ||||
|                 <a href='#section-17'>#</a> | ||||
|             </div> | ||||
|             <p>Increment step in training mode</p> | ||||
|             <p>Increment step in training mode </p> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">139</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span> | ||||
| @ -337,7 +345,8 @@ configurations to simplify our implementation.</p> | ||||
|             <div class='section-link'> | ||||
|                 <a href='#section-18'>#</a> | ||||
|             </div> | ||||
|             <p>Train the discriminator</p> | ||||
|             <p>Train the discriminator </p> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">143</span>        <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s2">"discriminator"</span><span class="p">):</span></pre></div> | ||||
| @ -348,7 +357,8 @@ configurations to simplify our implementation.</p> | ||||
|             <div class='section-link'> | ||||
|                 <a href='#section-19'>#</a> | ||||
|             </div> | ||||
|             <p>Get discriminator loss</p> | ||||
|             <p>Get discriminator loss </p> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">145</span>            <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">calc_discriminator_loss</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div> | ||||
| @ -359,7 +369,8 @@ configurations to simplify our implementation.</p> | ||||
|             <div class='section-link'> | ||||
|                 <a href='#section-20'>#</a> | ||||
|             </div> | ||||
|             <p>Train</p> | ||||
|             <p>Train </p> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">148</span>            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span> | ||||
| @ -375,7 +386,9 @@ configurations to simplify our implementation.</p> | ||||
|             <div class='section-link'> | ||||
|                 <a href='#section-21'>#</a> | ||||
|             </div> | ||||
|             <p>Train the generator once in every <code>discriminator_k</code></p> | ||||
|             <p>Train the generator once in every <code>discriminator_k</code> | ||||
|  </p> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">156</span>        <span class="k">if</span> <span class="n">batch_idx</span><span class="o">.</span><span class="n">is_interval</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_k</span><span class="p">):</span> | ||||
| @ -388,7 +401,8 @@ configurations to simplify our implementation.</p> | ||||
|             <div class='section-link'> | ||||
|                 <a href='#section-22'>#</a> | ||||
|             </div> | ||||
|             <p>Train</p> | ||||
|             <p>Train </p> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">161</span>                <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span> | ||||
| @ -406,7 +420,8 @@ configurations to simplify our implementation.</p> | ||||
|             <div class='section-link'> | ||||
|                 <a href='#section-23'>#</a> | ||||
|             </div> | ||||
|             <p>Calculate discriminator loss</p> | ||||
|             <p> Calculate discriminator loss</p> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">170</span>    <span class="k">def</span> <span class="nf">calc_discriminator_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span></pre></div> | ||||
| @ -432,7 +447,8 @@ configurations to simplify our implementation.</p> | ||||
|             <div class='section-link'> | ||||
|                 <a href='#section-25'>#</a> | ||||
|             </div> | ||||
|             <p>Log stuff</p> | ||||
|             <p>Log stuff </p> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">181</span>        <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">"loss.discriminator.true."</span><span class="p">,</span> <span class="n">loss_true</span><span class="p">)</span> | ||||
| @ -447,7 +463,8 @@ configurations to simplify our implementation.</p> | ||||
|             <div class='section-link'> | ||||
|                 <a href='#section-26'>#</a> | ||||
|             </div> | ||||
|             <p>Calculate generator loss</p> | ||||
|             <p> Calculate generator loss</p> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">187</span>    <span class="k">def</span> <span class="nf">calc_generator_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div> | ||||
| @ -472,7 +489,8 @@ configurations to simplify our implementation.</p> | ||||
|             <div class='section-link'> | ||||
|                 <a href='#section-28'>#</a> | ||||
|             </div> | ||||
|             <p>Log stuff</p> | ||||
|             <p>Log stuff </p> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">197</span>        <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">'generated'</span><span class="p">,</span> <span class="n">generated_images</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">6</span><span class="p">])</span> | ||||
| @ -510,9 +528,10 @@ configurations to simplify our implementation.</p> | ||||
|             <div class='section-link'> | ||||
|                 <a href='#section-30'>#</a> | ||||
|             </div> | ||||
|             <p>Setting exponent decay rate for first moment of gradient, | ||||
| $\beta_1$ to <code>0.5</code> is important. | ||||
| Default of <code>0.9</code> fails.</p> | ||||
|             <p>Setting exponent decay rate for first moment of gradient, <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> to <code>0.5</code> | ||||
|  is important. Default of <code>0.9</code> | ||||
|  fails. </p> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">222</span>    <span class="n">opt_conf</span><span class="o">.</span><span class="n">betas</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">)</span> | ||||
| @ -540,9 +559,10 @@ Default of <code>0.9</code> fails.</p> | ||||
|             <div class='section-link'> | ||||
|                 <a href='#section-32'>#</a> | ||||
|             </div> | ||||
|             <p>Setting exponent decay rate for first moment of gradient, | ||||
| $\beta_1$ to <code>0.5</code> is important. | ||||
| Default of <code>0.9</code> fails.</p> | ||||
|             <p>Setting exponent decay rate for first moment of gradient, <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> to <code>0.5</code> | ||||
|  is important. Default of <code>0.9</code> | ||||
|  fails. </p> | ||||
|  | ||||
|         </div> | ||||
|         <div class='code'> | ||||
|             <div class="highlight"><pre><span class="lineno">235</span>    <span class="n">opt_conf</span><span class="o">.</span><span class="n">betas</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">)</span> | ||||
| @ -581,24 +601,6 @@ Default of <code>0.9</code> fails.</p> | ||||
|         <a href="https://labml.ai">labml.ai</a> | ||||
|     </div> | ||||
| </div> | ||||
| <script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML"> | ||||
| </script> | ||||
| <!-- MathJax configuration --> | ||||
| <script type="text/x-mathjax-config"> | ||||
|     MathJax.Hub.Config({ | ||||
|         tex2jax: { | ||||
|             inlineMath: [ ['$','$'] ], | ||||
|             displayMath: [ ['$$','$$'] ], | ||||
|             processEscapes: true, | ||||
|             processEnvironments: true | ||||
|         }, | ||||
|         // Center justify equations in code and markdown cells. Elsewhere | ||||
|         // we use CSS to left justify single line equations in code cells. | ||||
|         displayAlign: 'center', | ||||
|         "HTML-CSS": { fonts: ["TeX"] } | ||||
|     }); | ||||
|  | ||||
| </script> | ||||
| <script> | ||||
|     function handleImages() { | ||||
|         var images = document.querySelectorAll('p>img') | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri