Files
Varuna Jayasiri c4d2e8cd22 docs
2025-07-31 08:48:07 +05:30

643 lines
55 KiB
HTML
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This experiment generates MNIST images using multi-layer perceptron."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Generative Adversarial Networks experiment with MNIST"/>
<meta name="twitter:description" content="This experiment generates MNIST images using multi-layer perceptron."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/gan/original/experiment.html"/>
<meta property="og:title" content="Generative Adversarial Networks experiment with MNIST"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Generative Adversarial Networks experiment with MNIST"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="This experiment generates MNIST images using multi-layer perceptron."/>
<title>Generative Adversarial Networks experiment with MNIST</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/gan/original/experiment.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
// Google Analytics (gtag.js) bootstrap for this page.
// Create the global command queue if it does not exist yet, so the calls
// below can be recorded even though gtag.js is loaded with `async` (the
// <script async> tag just above) and may not have run yet; presumably
// gtag.js drains this queue once it loads.
window.dataLayer = window.dataLayer || [];
// Queue one gtag command by pushing it onto window.dataLayer.
// NOTE(review): this pushes the `arguments` object itself, not an array copy,
// exactly as in Google's published install snippet; gtag.js is understood to
// rely on that form, so do not rewrite this with rest parameters or as an
// arrow function (arrow functions have no own `arguments`).
function gtag() {
dataLayer.push(arguments);
}
// 'js' command with the current time (Google's snippet uses this to mark page load).
gtag('js', new Date());
// Configure the property whose ID is also passed in the gtag.js URL above.
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">gan</a>
<a class="parent" href="index.html">original</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/gan/original/experiment.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Generative Adversarial Networks experiment with MNIST</h1>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">10</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span>
<span class="lineno">11</span>
<span class="lineno">12</span><span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">transforms</span>
<span class="lineno">13</span>
<span class="lineno">14</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">15</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="lineno">16</span><span class="kn">import</span> <span class="nn">torch.utils.data</span>
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">monit</span><span class="p">,</span> <span class="n">experiment</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span><span class="p">,</span> <span class="n">calculate</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml_nn.gan.original</span> <span class="kn">import</span> <span class="n">DiscriminatorLogitsLoss</span><span class="p">,</span> <span class="n">GeneratorLogitsLoss</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_nn.helpers.datasets</span> <span class="kn">import</span> <span class="n">MNISTConfigs</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml_nn.helpers.device</span> <span class="kn">import</span> <span class="n">DeviceConfigs</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_nn.helpers.optimizer</span> <span class="kn">import</span> <span class="n">OptimizerConfigs</span>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_nn.helpers.trainer</span> <span class="kn">import</span> <span class="n">TrainValidConfigs</span><span class="p">,</span> <span class="n">BatchIndex</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">26</span><span class="k">def</span> <span class="nf">weights_init</span><span class="p">(</span><span class="n">m</span><span class="p">):</span>
<span class="lineno">27</span> <span class="n">classname</span> <span class="o">=</span> <span class="n">m</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span>
<span class="lineno">28</span> <span class="k">if</span> <span class="n">classname</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s1">&#39;Linear&#39;</span><span class="p">)</span> <span class="o">!=</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span>
<span class="lineno">29</span> <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.02</span><span class="p">)</span>
<span class="lineno">30</span> <span class="k">elif</span> <span class="n">classname</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s1">&#39;BatchNorm&#39;</span><span class="p">)</span> <span class="o">!=</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span>
<span class="lineno">31</span> <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.02</span><span class="p">)</span>
<span class="lineno">32</span> <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">constant_</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<h3>Simple MLP Generator</h3>
<p>This has three linear layers of increasing size with <code class="highlight"><span></span><span class="n">LeakyReLU</span></code>
activations. The final layer has a <span ><span class="katex"><span class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mop">tanh</span></span></span></span></span> activation.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">35</span><span class="k">class</span> <span class="nc">Generator</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">43</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">44</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">45</span> <span class="n">layer_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">256</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">1024</span><span class="p">]</span>
<span class="lineno">46</span> <span class="n">layers</span> <span class="o">=</span> <span class="p">[]</span>
<span class="lineno">47</span> <span class="n">d_prev</span> <span class="o">=</span> <span class="mi">100</span>
<span class="lineno">48</span> <span class="k">for</span> <span class="n">size</span> <span class="ow">in</span> <span class="n">layer_sizes</span><span class="p">:</span>
<span class="lineno">49</span> <span class="n">layers</span> <span class="o">=</span> <span class="n">layers</span> <span class="o">+</span> <span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_prev</span><span class="p">,</span> <span class="n">size</span><span class="p">),</span> <span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="mf">0.2</span><span class="p">)]</span>
<span class="lineno">50</span> <span class="n">d_prev</span> <span class="o">=</span> <span class="n">size</span>
<span class="lineno">51</span>
<span class="lineno">52</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_prev</span><span class="p">,</span> <span class="mi">28</span> <span class="o">*</span> <span class="mi">28</span><span class="p">),</span> <span class="n">nn</span><span class="o">.</span><span class="n">Tanh</span><span class="p">())</span>
<span class="lineno">53</span>
<span class="lineno">54</span> <span class="bp">self</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">weights_init</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="lineno">57</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<h3>Simple MLP Discriminator</h3>
<p>This has three linear layers of decreasing size with <code class="highlight"><span></span><span class="n">LeakyReLU</span></code>
activations. The final layer has a single output that gives the logit of whether the input is real or fake. You can get the probability by applying the sigmoid function to it.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">60</span><span class="k">class</span> <span class="nc">Discriminator</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">69</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">70</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">71</span> <span class="n">layer_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1024</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">256</span><span class="p">]</span>
<span class="lineno">72</span> <span class="n">layers</span> <span class="o">=</span> <span class="p">[]</span>
<span class="lineno">73</span> <span class="n">d_prev</span> <span class="o">=</span> <span class="mi">28</span> <span class="o">*</span> <span class="mi">28</span>
<span class="lineno">74</span> <span class="k">for</span> <span class="n">size</span> <span class="ow">in</span> <span class="n">layer_sizes</span><span class="p">:</span>
<span class="lineno">75</span> <span class="n">layers</span> <span class="o">=</span> <span class="n">layers</span> <span class="o">+</span> <span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_prev</span><span class="p">,</span> <span class="n">size</span><span class="p">),</span> <span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="mf">0.2</span><span class="p">)]</span>
<span class="lineno">76</span> <span class="n">d_prev</span> <span class="o">=</span> <span class="n">size</span>
<span class="lineno">77</span>
<span class="lineno">78</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_prev</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="lineno">79</span> <span class="bp">self</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">weights_init</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">81</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="lineno">82</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<h2>Configurations</h2>
<p>This extends the MNIST configurations (which provide the data loaders) and the training and validation loop configurations to simplify our implementation.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">85</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">MNISTConfigs</span><span class="p">,</span> <span class="n">TrainValidConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">93</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">DeviceConfigs</span><span class="p">()</span>
<span class="lineno">94</span> <span class="n">dataset_transforms</span> <span class="o">=</span> <span class="s1">&#39;mnist_gan_transforms&#39;</span>
<span class="lineno">95</span> <span class="n">epochs</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span>
<span class="lineno">96</span>
<span class="lineno">97</span> <span class="n">is_save_models</span> <span class="o">=</span> <span class="kc">True</span>
<span class="lineno">98</span> <span class="n">discriminator</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span> <span class="o">=</span> <span class="s1">&#39;mlp&#39;</span>
<span class="lineno">99</span> <span class="n">generator</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span> <span class="o">=</span> <span class="s1">&#39;mlp&#39;</span>
<span class="lineno">100</span> <span class="n">generator_optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span>
<span class="lineno">101</span> <span class="n">discriminator_optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span>
<span class="lineno">102</span> <span class="n">generator_loss</span><span class="p">:</span> <span class="n">GeneratorLogitsLoss</span> <span class="o">=</span> <span class="s1">&#39;original&#39;</span>
<span class="lineno">103</span> <span class="n">discriminator_loss</span><span class="p">:</span> <span class="n">DiscriminatorLogitsLoss</span> <span class="o">=</span> <span class="s1">&#39;original&#39;</span>
<span class="lineno">104</span> <span class="n">label_smoothing</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.2</span>
<span class="lineno">105</span> <span class="n">discriminator_k</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p> Initializations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">107</span> <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">111</span> <span class="bp">self</span><span class="o">.</span><span class="n">state_modules</span> <span class="o">=</span> <span class="p">[]</span>
<span class="lineno">112</span>
<span class="lineno">113</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s2">&quot;loss.generator.*&quot;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
<span class="lineno">114</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s2">&quot;loss.discriminator.*&quot;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
<span class="lineno">115</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_image</span><span class="p">(</span><span class="s2">&quot;generated&quot;</span><span class="p">,</span> <span class="kc">True</span><span class="p">,</span> <span class="mi">1</span> <span class="o">/</span> <span class="mi">100</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p> <span ><span class="katex-display"><span class="katex"><span class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">∼</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="mclose">)</span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">117</span> <span class="k">def</span> <span class="nf">sample_z</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">121</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p> Take a training step</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">123</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="n">BatchIndex</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Set model states </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">129</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">)</span>
<span class="lineno">130</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Get MNIST images </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">133</span> <span class="n">data</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Increment step in training mode </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
<span class="lineno">137</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Train the discriminator </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">140</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s2">&quot;discriminator&quot;</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Get discriminator loss </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">142</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">calc_discriminator_loss</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Train </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">145</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
<span class="lineno">146</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="lineno">147</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="lineno">148</span> <span class="k">if</span> <span class="n">batch_idx</span><span class="o">.</span><span class="n">is_last</span><span class="p">:</span>
<span class="lineno">149</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;discriminator&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">)</span>
<span class="lineno">150</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Train the generator once every <code class="highlight"><span></span><span class="n">discriminator_k</span></code> batches
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">153</span> <span class="k">if</span> <span class="n">batch_idx</span><span class="o">.</span><span class="n">is_interval</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_k</span><span class="p">):</span>
<span class="lineno">154</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s2">&quot;generator&quot;</span><span class="p">):</span>
<span class="lineno">155</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">calc_generator_loss</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Train </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">158</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
<span class="lineno">159</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="lineno">160</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="lineno">161</span> <span class="k">if</span> <span class="n">batch_idx</span><span class="o">.</span><span class="n">is_last</span><span class="p">:</span>
<span class="lineno">162</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;generator&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">)</span>
<span class="lineno">163</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="lineno">164</span>
<span class="lineno">165</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p> Calculate discriminator loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">167</span> <span class="k">def</span> <span class="nf">calc_discriminator_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">171</span> <span class="n">latent</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample_z</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="lineno">172</span> <span class="n">logits_true</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="lineno">173</span> <span class="n">logits_false</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">(</span><span class="n">latent</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span>
<span class="lineno">174</span> <span class="n">loss_true</span><span class="p">,</span> <span class="n">loss_false</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_loss</span><span class="p">(</span><span class="n">logits_true</span><span class="p">,</span> <span class="n">logits_false</span><span class="p">)</span>
<span class="lineno">175</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_true</span> <span class="o">+</span> <span class="n">loss_false</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Log the discriminator losses </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">178</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;loss.discriminator.true.&quot;</span><span class="p">,</span> <span class="n">loss_true</span><span class="p">)</span>
<span class="lineno">179</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;loss.discriminator.false.&quot;</span><span class="p">,</span> <span class="n">loss_false</span><span class="p">)</span>
<span class="lineno">180</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;loss.discriminator.&quot;</span><span class="p">,</span> <span class="n">loss</span><span class="p">)</span>
<span class="lineno">181</span>
<span class="lineno">182</span> <span class="k">return</span> <span class="n">loss</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p> Calculate generator loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">184</span> <span class="k">def</span> <span class="nf">calc_generator_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">188</span> <span class="n">latent</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample_z</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
<span class="lineno">189</span> <span class="n">generated_images</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">(</span><span class="n">latent</span><span class="p">)</span>
<span class="lineno">190</span> <span class="n">logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">(</span><span class="n">generated_images</span><span class="p">)</span>
<span class="lineno">191</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_loss</span><span class="p">(</span><span class="n">logits</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Log a few generated images and the generator loss </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">194</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;generated&#39;</span><span class="p">,</span> <span class="n">generated_images</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">6</span><span class="p">])</span>
<span class="lineno">195</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;loss.generator.&quot;</span><span class="p">,</span> <span class="n">loss</span><span class="p">)</span>
<span class="lineno">196</span>
<span class="lineno">197</span> <span class="k">return</span> <span class="n">loss</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">200</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">dataset_transforms</span><span class="p">)</span>
<span class="lineno">201</span><span class="k">def</span> <span class="nf">mnist_gan_transforms</span><span class="p">():</span>
<span class="lineno">202</span> <span class="k">return</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
<span class="lineno">203</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
<span class="lineno">204</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.5</span><span class="p">,),</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,))</span>
<span class="lineno">205</span> <span class="p">])</span>
<span class="lineno">206</span>
<span class="lineno">207</span>
<span class="lineno">208</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">discriminator_optimizer</span><span class="p">)</span>
<span class="lineno">209</span><span class="k">def</span> <span class="nf">_discriminator_optimizer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span>
<span class="lineno">210</span> <span class="n">opt_conf</span> <span class="o">=</span> <span class="n">OptimizerConfigs</span><span class="p">()</span>
<span class="lineno">211</span> <span class="n">opt_conf</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="s1">&#39;Adam&#39;</span>
<span class="lineno">212</span> <span class="n">opt_conf</span><span class="o">.</span><span class="n">parameters</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">discriminator</span><span class="o">.</span><span class="n">parameters</span><span class="p">()</span>
<span class="lineno">213</span> <span class="n">opt_conf</span><span class="o">.</span><span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">2.5e-4</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Setting the exponential decay rate for the first moment of the gradient, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqb" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05278em">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>, to <code class="highlight"><span></span><span class="mf">0.5</span></code>
is important. The default of <code class="highlight"><span></span><span class="mf">0.9</span></code>
fails. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">217</span> <span class="n">opt_conf</span><span class="o">.</span><span class="n">betas</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">)</span>
<span class="lineno">218</span> <span class="k">return</span> <span class="n">opt_conf</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">221</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">generator_optimizer</span><span class="p">)</span>
<span class="lineno">222</span><span class="k">def</span> <span class="nf">_generator_optimizer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span>
<span class="lineno">223</span> <span class="n">opt_conf</span> <span class="o">=</span> <span class="n">OptimizerConfigs</span><span class="p">()</span>
<span class="lineno">224</span> <span class="n">opt_conf</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="s1">&#39;Adam&#39;</span>
<span class="lineno">225</span> <span class="n">opt_conf</span><span class="o">.</span><span class="n">parameters</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">generator</span><span class="o">.</span><span class="n">parameters</span><span class="p">()</span>
<span class="lineno">226</span> <span class="n">opt_conf</span><span class="o">.</span><span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">2.5e-4</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Setting the exponential decay rate for the first moment of the gradient, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqb" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05278em">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>, to <code class="highlight"><span></span><span class="mf">0.5</span></code>
is important. The default of <code class="highlight"><span></span><span class="mf">0.9</span></code>
fails. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">230</span> <span class="n">opt_conf</span><span class="o">.</span><span class="n">betas</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">)</span>
<span class="lineno">231</span> <span class="k">return</span> <span class="n">opt_conf</span>
<span class="lineno">232</span>
<span class="lineno">233</span>
<span class="lineno">234</span><span class="n">calculate</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">generator</span><span class="p">,</span> <span class="s1">&#39;mlp&#39;</span><span class="p">,</span> <span class="k">lambda</span> <span class="n">c</span><span class="p">:</span> <span class="n">Generator</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">))</span>
<span class="lineno">235</span><span class="n">calculate</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">discriminator</span><span class="p">,</span> <span class="s1">&#39;mlp&#39;</span><span class="p">,</span> <span class="k">lambda</span> <span class="n">c</span><span class="p">:</span> <span class="n">Discriminator</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">))</span>
<span class="lineno">236</span><span class="n">calculate</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">generator_loss</span><span class="p">,</span> <span class="s1">&#39;original&#39;</span><span class="p">,</span> <span class="k">lambda</span> <span class="n">c</span><span class="p">:</span> <span class="n">GeneratorLogitsLoss</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">label_smoothing</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">))</span>
<span class="lineno">237</span><span class="n">calculate</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">discriminator_loss</span><span class="p">,</span> <span class="s1">&#39;original&#39;</span><span class="p">,</span> <span class="k">lambda</span> <span class="n">c</span><span class="p">:</span> <span class="n">DiscriminatorLogitsLoss</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">label_smoothing</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">240</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span>
<span class="lineno">241</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span>
<span class="lineno">242</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;mnist_gan&#39;</span><span class="p">,</span> <span class="n">comment</span><span class="o">=</span><span class="s1">&#39;test&#39;</span><span class="p">)</span>
<span class="lineno">243</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span>
<span class="lineno">244</span> <span class="p">{</span><span class="s1">&#39;label_smoothing&#39;</span><span class="p">:</span> <span class="mf">0.01</span><span class="p">})</span>
<span class="lineno">245</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
<span class="lineno">246</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span>
<span class="lineno">247</span>
<span class="lineno">248</span>
<span class="lineno">249</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">250</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
// Attach a click-to-enlarge modal to every image that is a direct child of a <p>.
function handleImages() {
    var images = document.querySelectorAll('p>img')
    for (var i = 0; i < images.length; ++i) {
        handleImage(images[i])
    }
}

// Center `img` in its paragraph and wire up a modal overlay that shows the image
// enlarged when clicked. The overlay is built once per image and is attached to
// <body> only while open; it closes via the 'x' control or the Escape key.
function handleImage(img) {
    img.parentElement.style.textAlign = 'center'

    // Overlay structure: #modal > div > img, plus a '.close' control.
    // The id/class names are kept as-is because page CSS presumably targets them.
    var modal = document.createElement('div')
    modal.id = 'modal'
    var modalContent = document.createElement('div')
    modal.appendChild(modalContent)
    var modalImage = document.createElement('img')
    modalContent.appendChild(modalImage)
    var span = document.createElement('span')
    span.classList.add('close')
    span.textContent = 'x'
    modal.appendChild(span)

    // Detach the overlay and stop listening for Escape.
    function closeModal() {
        // Guard: removeChild throws if the modal is no longer attached.
        if (modal.parentNode) {
            modal.parentNode.removeChild(modal)
        }
        document.removeEventListener('keydown', onKeyDown)
    }

    // Keyboard dismissal while the overlay is open.
    function onKeyDown(event) {
        if (event.key === 'Escape') {
            closeModal()
        }
    }

    img.addEventListener('click', function () {
        modalImage.src = img.src
        // Carry the source image's text alternative over to the enlarged copy.
        modalImage.alt = img.alt
        document.body.appendChild(modal)
        // Re-adding the same listener reference is a no-op, so repeated clicks are safe.
        document.addEventListener('keydown', onKeyDown)
    })
    span.addEventListener('click', closeModal)
}
handleImages()
</script>
</body>
</html>