Varuna Jayasiri e59200971b docs
2021-02-28 18:05:21 +05:30


<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A simple PyTorch implementation/tutorial of Cycle GAN introduced in the paper Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Cycle GAN"/>
<meta name="twitter:description" content="A simple PyTorch implementation/tutorial of Cycle GAN introduced in the paper Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/gan/cycle_gan.html"/>
<meta property="og:title" content="Cycle GAN"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="A simple PyTorch implementation/tutorial of Cycle GAN introduced in the paper Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks."/>
<title>Cycle GAN</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/gan/cycle_gan.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">gan</a>
</p>
<p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/gan/cycle_gan.py">
<img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a>
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
rel="nofollow">
<img alt="Join Slack"
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Cycle GAN</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation/tutorial of the paper
<a href="https://arxiv.org/abs/1703.10593">Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks</a>.</p>
<p>I&rsquo;ve taken pieces of code from <a href="https://github.com/eriklindernoren/PyTorch-GAN">eriklindernoren/PyTorch-GAN</a>.
It is a very good resource if you want to check out other GAN variations too.</p>
<p>Cycle GAN does image-to-image translation.
It trains a model to translate an image from one distribution to another, say, from images of class A to images of class B.
A distribution here could be, for example, images in a certain style, or photos of a certain kind of scene, such as nature.
The models do not need paired images between A and B;
a set of images of each class is enough.
This works very well for translating between image styles, lighting conditions, patterns, etc.
For example, changing summer scenes to winter, paintings to photos, and horses to zebras.</p>
<p>Cycle GAN trains two generator models and two discriminator models.
One generator translates images from A to B and the other from B to A.
One discriminator judges images of class A and the other images of class B; each tries to tell real images from generated ones.</p>
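<p>How the four models interact can be sketched as the loss the two generators minimize. This is a minimal sketch under stated assumptions, not the training code of this file: the tiny convolutional stand-ins for the generators and discriminators are hypothetical, and the least-squares adversarial loss, the L1 cycle-consistency loss and the weight <code>lambda_cyc = 10</code> are assumed choices.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny stand-ins, NOT the GeneratorResNet / Discriminator of this file:
# the generators keep the image shape, the discriminators output a map of scores.
gen_ab = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # translates A to B
gen_ba = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # translates B to A
disc_a = nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1)  # scores images of A
disc_b = nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1)  # scores images of B

# Assumed loss choices: least-squares (MSE) adversarial loss and an L1
# cycle-consistency loss weighted by an assumed lambda_cyc of 10.
adversarial_loss = nn.MSELoss()
cycle_loss = nn.L1Loss()
lambda_cyc = 10.0


def generator_loss(real_a, real_b):
    """Loss the two generators minimize for one batch of unpaired images."""
    fake_b = gen_ab(real_a)
    fake_a = gen_ba(real_b)

    # The generators want their fakes scored as real (target 1 everywhere)
    pred_b = disc_b(fake_b)
    pred_a = disc_a(fake_a)
    loss_gan = (adversarial_loss(pred_b, torch.ones_like(pred_b)) +
                adversarial_loss(pred_a, torch.ones_like(pred_a)))

    # Translating there and back should reconstruct the original image
    loss_cycle = (cycle_loss(gen_ba(fake_b), real_a) +
                  cycle_loss(gen_ab(fake_a), real_b))

    return loss_gan + lambda_cyc * loss_cycle


real_a = torch.randn(2, 3, 32, 32)  # a batch from class A
real_b = torch.randn(2, 3, 32, 32)  # an unrelated batch from class B
loss = generator_loss(real_a, real_b)
loss.backward()  # gradients reach both generators
print(loss.item())
```

<p>The cycle term is what removes the need for paired images: each translation is checked against its own round trip rather than against a ground-truth target.</p>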
<p>This file contains the model code as well as the training code.
We also have a Google Colab notebook.</p>
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/gan/cycle_gan.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
<a href="https://app.labml.ai/run/93b11a665d6811ebaac80242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">36</span><span></span><span class="kn">import</span> <span class="nn">itertools</span>
<span class="lineno">37</span><span class="kn">import</span> <span class="nn">random</span>
<span class="lineno">38</span><span class="kn">import</span> <span class="nn">zipfile</span>
<span class="lineno">39</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Tuple</span>
<span class="lineno">40</span>
<span class="lineno">41</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">42</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="lineno">43</span><span class="kn">import</span> <span class="nn">torchvision.transforms</span> <span class="k">as</span> <span class="nn">transforms</span>
<span class="lineno">44</span><span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
<span class="lineno">45</span><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">Dataset</span>
<span class="lineno">46</span><span class="kn">from</span> <span class="nn">torchvision.utils</span> <span class="kn">import</span> <span class="n">make_grid</span>
<span class="lineno">47</span>
<span class="lineno">48</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">lab</span><span class="p">,</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">experiment</span><span class="p">,</span> <span class="n">monit</span>
<span class="lineno">49</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">BaseConfigs</span>
<span class="lineno">50</span><span class="kn">from</span> <span class="nn">labml.utils.download</span> <span class="kn">import</span> <span class="n">download_file</span>
<span class="lineno">51</span><span class="kn">from</span> <span class="nn">labml.utils.pytorch</span> <span class="kn">import</span> <span class="n">get_modules</span>
<span class="lineno">52</span><span class="kn">from</span> <span class="nn">labml_helpers.device</span> <span class="kn">import</span> <span class="n">DeviceConfigs</span>
<span class="lineno">53</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<p>The generator is a residual network.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span><span class="k">class</span> <span class="nc">GeneratorResNet</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">61</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_residual_blocks</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">62</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>The first block runs a $7\times7$ convolution that maps the image to
a feature map.
The output feature map has the same height and width because we use
a padding of $3$ on each side.
Reflection padding is used because it gives better image quality at the edges.</p>
<p><code>inplace=True</code> in <code>ReLU</code> saves a little bit of memory.</p>
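<p>The size claim follows from the PyTorch output-size formula for a convolution with dilation 1, $\lfloor (H + 2p - k)/s \rfloor + 1$: with $k = 7$, $p = 3$ and $s = 1$ it gives back $H$. The pure-Python sketch below checks this and shows what reflection padding puts at the border; the helper names <code>conv2d_out_size</code> and <code>reflect_pad_1d</code> are hypothetical.</p>

```python
def conv2d_out_size(size, kernel_size, stride=1, padding=0):
    """Output height/width of a Conv2d with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


def reflect_pad_1d(row, pad):
    """Mirror a row about its edges without repeating the edge value,
    which is what padding_mode='reflect' does along each spatial axis."""
    left = row[1:pad + 1][::-1]
    right = row[-pad - 1:-1][::-1]
    return left + row + right


# A 7x7 kernel with padding 3 keeps a 256x256 image at 256x256
print(conv2d_out_size(256, kernel_size=7, padding=3))  # 256

# Reflection padding fills the border with mirrored image content, not zeros
print(reflect_pad_1d([1, 2, 3, 4], 2))  # [3, 2, 1, 2, 3, 4, 3, 2]
```

<p>Zero padding would instead surround the image with a frame of zeros, which the convolution sees as a hard artificial edge.</p>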
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span>        <span class="n">out_features</span> <span class="o">=</span> <span class="mi">64</span>
<span class="lineno">71</span>        <span class="n">layers</span> <span class="o">=</span> <span class="p">[</span>
<span class="lineno">72</span>            <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">input_channels</span><span class="p">,</span> <span class="n">out_features</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">7</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding_mode</span><span class="o">=</span><span class="s1">&#39;reflect&#39;</span><span class="p">),</span>
<span class="lineno">73</span>            <span class="n">nn</span><span class="o">.</span><span class="n">InstanceNorm2d</span><span class="p">(</span><span class="n">out_features</span><span class="p">),</span>
<span class="lineno">74</span>            <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span>
<span class="lineno">75</span>        <span class="p">]</span>
<span class="lineno">76</span>        <span class="n">in_features</span> <span class="o">=</span> <span class="n">out_features</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>We down-sample with two $3 \times 3$ convolutions
with a stride of 2, doubling the number of channels each time.</p>
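<p>Tracing the shapes with the convolution output-size formula shows the effect of the two steps: the channels double while the height and width halve. The sketch assumes a $256 \times 256$ input; <code>trace_downsampling</code> is a hypothetical helper.</p>

```python
def conv2d_out_size(size, kernel_size, stride=1, padding=0):
    """Output height/width of a Conv2d with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


def trace_downsampling(size, channels=64, n_steps=2):
    """Track (channels, height/width) through the stride-2 convolutions."""
    shapes = [(channels, size)]
    for _ in range(n_steps):
        channels *= 2  # mirrors `out_features *= 2`
        size = conv2d_out_size(size, kernel_size=3, stride=2, padding=1)
        shapes.append((channels, size))
    return shapes


print(trace_downsampling(256))  # [(64, 256), (128, 128), (256, 64)]
```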
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">80</span>        <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">):</span>
<span class="lineno">81</span>            <span class="n">out_features</span> <span class="o">*=</span> <span class="mi">2</span>
<span class="lineno">82</span>            <span class="n">layers</span> <span class="o">+=</span> <span class="p">[</span>
<span class="lineno">83</span>                <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_features</span><span class="p">,</span> <span class="n">out_features</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span>
<span class="lineno">84</span>                <span class="n">nn</span><span class="o">.</span><span class="n">InstanceNorm2d</span><span class="p">(</span><span class="n">out_features</span><span class="p">),</span>
<span class="lineno">85</span>                <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span>
<span class="lineno">86</span>            <span class="p">]</span>
<span class="lineno">87</span>            <span class="n">in_features</span> <span class="o">=</span> <span class="n">out_features</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>We then pass this through <code>n_residual_blocks</code> residual blocks.
The <code>ResidualBlock</code> module is defined below.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span>        <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_residual_blocks</span><span class="p">):</span>
<span class="lineno">92</span>            <span class="n">layers</span> <span class="o">+=</span> <span class="p">[</span><span class="n">ResidualBlock</span><span class="p">(</span><span class="n">out_features</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Then the resulting feature map is up-sampled
to match the original image height and width.</p>
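<p>A similar hypothetical trace shows the up-sampling path undoing the down-sampling: <code>nn.Upsample(scale_factor=2)</code> doubles the height and width, and the $3 \times 3$ convolution with padding 1 and stride 1 keeps them while halving the channels. The starting shape assumes a $256 \times 256$ input image; <code>trace_upsampling</code> is a hypothetical helper.</p>

```python
def conv2d_out_size(size, kernel_size, stride=1, padding=0):
    """Output height/width of a Conv2d with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


def trace_upsampling(size, channels=256, n_steps=2):
    """Track (channels, height/width) through the up-sampling steps."""
    shapes = [(channels, size)]
    for _ in range(n_steps):
        channels //= 2  # mirrors `out_features //= 2`
        size = size * 2  # nn.Upsample(scale_factor=2)
        size = conv2d_out_size(size, kernel_size=3, stride=1, padding=1)  # unchanged
        shapes.append((channels, size))
    return shapes


print(trace_upsampling(64))  # [(256, 64), (128, 128), (64, 256)]
```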
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">96</span>        <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">):</span>
<span class="lineno">97</span>            <span class="n">out_features</span> <span class="o">//=</span> <span class="mi">2</span>
<span class="lineno">98</span>            <span class="n">layers</span> <span class="o">+=</span> <span class="p">[</span>
<span class="lineno">99</span>                <span class="n">nn</span><span class="o">.</span><span class="n">Upsample</span><span class="p">(</span><span class="n">scale_factor</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
<span class="lineno">100</span>                <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_features</span><span class="p">,</span> <span class="n">out_features</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span>
<span class="lineno">101</span>                <span class="n">nn</span><span class="o">.</span><span class="n">InstanceNorm2d</span><span class="p">(</span><span class="n">out_features</span><span class="p">),</span>
<span class="lineno">102</span>                <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span>
<span class="lineno">103</span>            <span class="p">]</span>
<span class="lineno">104</span>            <span class="n">in_features</span> <span class="o">=</span> <span class="n">out_features</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
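<p>Finally, we map the feature map back to an image with <code>input_channels</code> channels (RGB), and a <code>tanh</code> squashes the values into $[-1, 1]$.</p>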
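<p>Because of the final <code>tanh</code>, every output value lies in $(-1, 1)$. The sketch below assumes, for illustration only, that training images were normalized to $[-1, 1]$, so a generated value $x$ can be turned back into an 8-bit pixel with $(x + 1) / 2 \cdot 255$; <code>to_pixel</code> is a hypothetical helper.</p>

```python
import math


def to_pixel(value):
    """Hypothetical helper: map a value in [-1, 1] to an 8-bit pixel in [0, 255],
    assuming images were normalized to [-1, 1] for training."""
    return round((value + 1) / 2 * 255)


# tanh squashes any pre-activation, however large, into (-1, 1)
for pre_activation in (-10.0, -1.0, 0.0, 1.0, 10.0):
    out = math.tanh(pre_activation)
    print(pre_activation, round(out, 4), to_pixel(out))
```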
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">107</span>        <span class="n">layers</span> <span class="o">+=</span> <span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">out_features</span><span class="p">,</span> <span class="n">input_channels</span><span class="p">,</span> <span class="mi">7</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding_mode</span><span class="o">=</span><span class="s1">&#39;reflect&#39;</span><span class="p">),</span> <span class="n">nn</span><span class="o">.</span><span class="n">Tanh</span><span class="p">()]</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Create a sequential module with the layers.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span>        <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Initialize convolution weights to $\mathcal{N}(0, 0.02^2)$, i.e. a standard deviation of $0.02$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">113</span>        <span class="bp">self</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">weights_init_normal</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">115</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="lineno">116</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>This is the residual block, with two convolution layers.</p>
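<p>The block computes $x + F(x)$, where $F$ is the two convolution layers with their normalizations and activations. Since each $3 \times 3$ convolution uses padding 1, $F(x)$ has the shape of $x$, which the addition requires; this is also why any number of these blocks can be stacked. The sketch below restates the block as a plain <code>torch.nn.Module</code> (the class name <code>TinyResidualBlock</code> is hypothetical) and checks the shape.</p>

```python
import torch
import torch.nn as nn


class TinyResidualBlock(nn.Module):
    """Hypothetical restatement of ResidualBlock: output = x + F(x)."""

    def __init__(self, channels):
        super().__init__()
        # F: two shape-preserving 3x3 convolutions with normalization and ReLU
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # The skip connection needs F(x) to have exactly the shape of x
        return x + self.block(x)


# Because each block keeps the shape, blocks can be stacked freely
stack = nn.Sequential(*[TinyResidualBlock(8) for _ in range(3)])
x = torch.randn(1, 8, 16, 16)
y = stack(x)
print(y.shape)  # torch.Size([1, 8, 16, 16])
```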
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">119</span><span class="k">class</span> <span class="nc">ResidualBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_features</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">125</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">126</span>        <span class="bp">self</span><span class="o">.</span><span class="n">block</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="lineno">127</span>            <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_features</span><span class="p">,</span> <span class="n">in_features</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding_mode</span><span class="o">=</span><span class="s1">&#39;reflect&#39;</span><span class="p">),</span>
<span class="lineno">128</span>            <span class="n">nn</span><span class="o">.</span><span class="n">InstanceNorm2d</span><span class="p">(</span><span class="n">in_features</span><span class="p">),</span>
<span class="lineno">129</span>            <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span>
<span class="lineno">130</span>            <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_features</span><span class="p">,</span> <span class="n">in_features</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding_mode</span><span class="o">=</span><span class="s1">&#39;reflect&#39;</span><span class="p">),</span>
<span class="lineno">131</span>            <span class="n">nn</span><span class="o">.</span><span class="n">InstanceNorm2d</span><span class="p">(</span><span class="n">in_features</span><span class="p">),</span>
<span class="lineno">132</span>            <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span>
<span class="lineno">133</span>        <span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">135</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">136</span>        <span class="k">return</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">block</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
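<p>This is the discriminator.
Instead of a single score for the whole image, it outputs a map of scores, one for each region of the image.</p>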
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">139</span><span class="k">class</span> <span class="nc">Discriminator</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">144</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]):</span>
<span class="lineno">145</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">146</span>        <span class="n">channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span> <span class="o">=</span> <span class="n">input_shape</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
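<p>The output of the discriminator is a map with one score per region of the image,
indicating whether that region is real or generated.
These are raw scores rather than probabilities, since the last layer has no sigmoid.</p>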
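<p>With the four halving blocks below, the score map is $2^4 = 16$ times smaller than the input in each dimension. A hypothetical helper mirroring <code>self.output_shape</code>, assuming a $256 \times 256$ input:</p>

```python
def discriminator_output_shape(height, width, n_blocks=4):
    """Mirror of `self.output_shape`: each of the `n_blocks` discriminator
    blocks halves the height and width."""
    return (1, height // 2 ** n_blocks, width // 2 ** n_blocks)


shape = discriminator_output_shape(256, 256)
n_scores = shape[1] * shape[2]
print(shape, n_scores)  # (1, 16, 16) 256
```

<p>So for a $256 \times 256$ image the discriminator gives 256 scores, each judging one (overlapping) region rather than the whole image.</p>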
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">150</span>        <span class="bp">self</span><span class="o">.</span><span class="n">output_shape</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">height</span> <span class="o">//</span> <span class="mi">2</span> <span class="o">**</span> <span class="mi">4</span><span class="p">,</span> <span class="n">width</span> <span class="o">//</span> <span class="mi">2</span> <span class="o">**</span> <span class="mi">4</span><span class="p">)</span>
<span class="lineno">151</span>
<span class="lineno">152</span>        <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Each of these blocks shrinks the height and width by a factor of 2.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">154</span>            <span class="n">DiscriminatorBlock</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="n">normalize</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span>
<span class="lineno">155</span>            <span class="n">DiscriminatorBlock</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">),</span>
<span class="lineno">156</span>            <span class="n">DiscriminatorBlock</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">256</span><span class="p">),</span>
<span class="lineno">157</span>            <span class="n">DiscriminatorBlock</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">512</span><span class="p">),</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Zero-pad one row on the top and one column on the left so that, together with <code>padding=1</code>,
the $4 \times 4$ kernel keeps the output height and width the same as the input.</p>
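<p>The arithmetic: padding one row on top and one column on the left grows $H$ to $H + 1$, and the $4 \times 4$ convolution with padding 1 then gives $(H + 1) + 2 - 4 + 1 = H$. The check below builds the same two final layers with 8 input channels instead of 512 to keep it small; that change is the only assumption.</p>

```python
import torch
import torch.nn as nn

# Final layers of the discriminator, with 8 input channels instead of 512
head = nn.Sequential(
    nn.ZeroPad2d((1, 0, 1, 0)),  # (left, right, top, bottom): +1 column left, +1 row top
    nn.Conv2d(8, 1, kernel_size=4, padding=1),
)
features = torch.randn(1, 8, 16, 16)
print(head(features).shape)  # torch.Size([1, 1, 16, 16])

# Without the extra zero padding, the 4x4 kernel loses one row and one column:
# (16 + 2 * 1 - 4) // 1 + 1 = 15
plain = nn.Conv2d(8, 1, kernel_size=4, padding=1)
print(plain(features).shape)  # torch.Size([1, 1, 15, 15])
```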
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">160</span>            <span class="n">nn</span><span class="o">.</span><span class="n">ZeroPad2d</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)),</span>
<span class="lineno">161</span>            <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">162</span>        <span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Initialize convolution weights to $\mathcal{N}(0, 0.02^2)$, i.e. a standard deviation of $0.02$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">165</span>        <span class="bp">self</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">weights_init_normal</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">167</span>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">img</span><span class="p">):</span>
<span class="lineno">168</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">(</span><span class="n">img</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>This is the discriminator block module.
It does a convolution, an optional normalization, and a leaky ReLU.</p>
<p>It shrinks the height and width of the input feature map by half.</p>
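<p>With a $4 \times 4$ kernel, stride 2 and padding 1, the output size is $\lfloor (H + 2 - 4)/2 \rfloor + 1 = H/2$ for even $H$. The sketch below rebuilds the block as a hypothetical function <code>discriminator_block</code> and runs four of them, with the channel sizes used in <code>Discriminator</code>, on an assumed $64 \times 64$ RGB input.</p>

```python
import torch
import torch.nn as nn


def discriminator_block(in_filters, out_filters, normalize=True):
    """Hypothetical functional restatement of DiscriminatorBlock."""
    layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
    if normalize:
        layers.append(nn.InstanceNorm2d(out_filters))
    layers.append(nn.LeakyReLU(0.2, inplace=True))
    return nn.Sequential(*layers)


# Four blocks, as in Discriminator: 3x64x64 becomes 512x4x4
x = torch.randn(1, 3, 64, 64)
for in_f, out_f, norm in [(3, 64, False), (64, 128, True), (128, 256, True), (256, 512, True)]:
    x = discriminator_block(in_f, out_f, norm)(x)
    print(tuple(x.shape))
```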
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">171</span><span class="k">class</span> <span class="nc">DiscriminatorBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">179</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_filters</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_filters</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">normalize</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
<span class="lineno">180</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">181</span>        <span class="n">layers</span> <span class="o">=</span> <span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_filters</span><span class="p">,</span> <span class="n">out_filters</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">)]</span>
<span class="lineno">182</span>        <span class="k">if</span> <span class="n">normalize</span><span class="p">:</span>
<span class="lineno">183</span>            <span class="n">layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">InstanceNorm2d</span><span class="p">(</span><span class="n">out_filters</span><span class="p">))</span>
<span class="lineno">184</span>        <span class="n">layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span>
<span class="lineno">185</span>        <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">187</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">188</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Initialize convolution layer weights from $\mathcal{N}(0, 0.02)$, a normal distribution with mean $0$ and standard deviation $0.02$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span><span class="k">def</span> <span class="nf">weights_init_normal</span><span class="p">(</span><span class="n">m</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">195</span> <span class="n">classname</span> <span class="o">=</span> <span class="n">m</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span>
<span class="lineno">196</span> <span class="k">if</span> <span class="n">classname</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">&quot;Conv&quot;</span><span class="p">)</span> <span class="o">!=</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span>
<span class="lineno">197</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.02</span><span class="p">)</span></pre></div>
</div>
</div>
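<div class='section'>
<div class='docs'>
<p>A minimal sketch of applying this initializer, assuming it is used with PyTorch's <code>nn.Module.apply</code>, which calls the function on the module itself and on every submodule. The small <code>Sequential</code> model below is hypothetical and only for illustration. Because the rule only looks for "Conv" in the class name, it also covers <code>ConvTranspose2d</code>, while <code>InstanceNorm2d</code> and <code>LeakyReLU</code> are left untouched.</p>
</div>
<div class='code'>

```python
import torch
import torch.nn as nn


def weights_init_normal(m: nn.Module):
    # Same rule as above: any module whose class name contains "Conv"
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)


# Hypothetical toy model, only for illustration
model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
    nn.InstanceNorm2d(64),
    nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
)

torch.manual_seed(0)
# `apply` visits every submodule (and the model itself) recursively
model.apply(weights_init_normal)

# Empirical standard deviations of the re-initialized weights
conv_std = model[0].weight.std().item()
deconv_std = model[3].weight.std().item()
```

</div>
</div>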
<div class='section' id='section-26'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Load an image and convert it to RGB if it is in any other mode (for example, grey-scale).</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">200</span><span class="k">def</span> <span class="nf">load_image</span><span class="p">(</span><span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">204</span> <span class="n">image</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">path</span><span class="p">)</span>
<span class="lineno">205</span> <span class="k">if</span> <span class="n">image</span><span class="o">.</span><span class="n">mode</span> <span class="o">!=</span> <span class="s1">&#39;RGB&#39;</span><span class="p">:</span>
<span class="lineno">206</span> <span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="s2">&quot;RGB&quot;</span><span class="p">)</span>
<span class="lineno">207</span>
<span class="lineno">208</span> <span class="k">return</span> <span class="n">image</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<h3>Dataset to load images</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">211</span><span class="k">class</span> <span class="nc">ImageDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<h4>Download dataset and extract data</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">216</span> <span class="nd">@staticmethod</span>
<span class="lineno">217</span> <span class="k">def</span> <span class="nf">download</span><span class="p">(</span><span class="n">dataset_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>URL</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">222</span> <span class="n">url</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/</span><span class="si">{</span><span class="n">dataset_name</span><span class="si">}</span><span class="s1">.zip&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Download folder</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">224</span> <span class="n">root</span> <span class="o">=</span> <span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">&#39;cycle_gan&#39;</span>
<span class="lineno">225</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">root</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
<span class="lineno">226</span> <span class="n">root</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Download destination</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">228</span> <span class="n">archive</span> <span class="o">=</span> <span class="n">root</span> <span class="o">/</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">dataset_name</span><span class="si">}</span><span class="s1">.zip&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Download file (generally ~100MB)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">230</span> <span class="n">download_file</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">archive</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Extract the archive</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">232</span> <span class="k">with</span> <span class="n">zipfile</span><span class="o">.</span><span class="n">ZipFile</span><span class="p">(</span><span class="n">archive</span><span class="p">,</span> <span class="s1">&#39;r&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="lineno">233</span> <span class="n">f</span><span class="o">.</span><span class="n">extractall</span><span class="p">(</span><span class="n">root</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<h4>Initialize the dataset</h4>
<ul>
<li><code>dataset_name</code> is the name of the dataset</li>
<li><code>transforms_</code> is the set of image transforms</li>
<li><code>mode</code> is either <code>train</code> or <code>test</code></li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">235</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">transforms_</span><span class="p">,</span> <span class="n">mode</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Dataset path</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">244</span> <span class="n">root</span> <span class="o">=</span> <span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">&#39;cycle_gan&#39;</span> <span class="o">/</span> <span class="n">dataset_name</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>Download if missing</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">246</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">root</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
<span class="lineno">247</span> <span class="bp">self</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">dataset_name</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Image transforms</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">250</span> <span class="bp">self</span><span class="o">.</span><span class="n">transform</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">(</span><span class="n">transforms_</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>Get image paths</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">253</span> <span class="n">path_a</span> <span class="o">=</span> <span class="n">root</span> <span class="o">/</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">mode</span><span class="si">}</span><span class="s1">A&#39;</span>
<span class="lineno">254</span> <span class="n">path_b</span> <span class="o">=</span> <span class="n">root</span> <span class="o">/</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">mode</span><span class="si">}</span><span class="s1">B&#39;</span>
<span class="lineno">255</span> <span class="bp">self</span><span class="o">.</span><span class="n">files_a</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">f</span><span class="p">)</span> <span class="k">for</span> <span class="n">f</span> <span class="ow">in</span> <span class="n">path_a</span><span class="o">.</span><span class="n">iterdir</span><span class="p">())</span>
<span class="lineno">256</span> <span class="bp">self</span><span class="o">.</span><span class="n">files_b</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">f</span><span class="p">)</span> <span class="k">for</span> <span class="n">f</span> <span class="ow">in</span> <span class="n">path_b</span><span class="o">.</span><span class="n">iterdir</span><span class="p">())</span></pre></div>
</div>
</div>
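<div class='section'>
<div class='docs'>
<p>The paths above imply that each archive extracts into a <code>{dataset_name}/</code> folder containing <code>{mode}A</code> and <code>{mode}B</code> subfolders (for example <code>trainA</code> and <code>trainB</code>). The standard-library sketch below assumes that layout: it builds a tiny stand-in archive in a temporary directory (the dataset name <code>toy</code> and the file names are made up, and nothing is downloaded), extracts it the same way as <code>download</code>, and lists the files the same way as <code>__init__</code>.</p>
</div>
<div class='code'>

```python
import tempfile
import zipfile
from pathlib import Path


def list_domain_files(root: Path, mode: str):
    # Mirrors the path logic above: `{mode}A` and `{mode}B` under the dataset root
    path_a = root / f'{mode}A'
    path_b = root / f'{mode}B'
    files_a = sorted(str(f) for f in path_a.iterdir())
    files_b = sorted(str(f) for f in path_b.iterdir())
    return files_a, files_b


with tempfile.TemporaryDirectory() as tmp:
    data_root = Path(tmp) / 'cycle_gan'
    data_root.mkdir(parents=True)
    archive = data_root / 'toy.zip'
    # Hypothetical archive with the assumed `{dataset_name}/{mode}{A|B}/` layout
    with zipfile.ZipFile(archive, 'w') as f:
        for name in ['toy/trainA/2.jpg', 'toy/trainA/1.jpg', 'toy/trainB/1.jpg']:
            f.writestr(name, b'')
    # Extract next to the archive, as `download` does
    with zipfile.ZipFile(archive, 'r') as f:
        f.extractall(data_root)
    files_a, files_b = list_domain_files(data_root / 'toy', 'train')
    # Keep only the file names, since the temporary directory is deleted afterwards
    names_a = [Path(p).name for p in files_a]
    names_b = [Path(p).name for p in files_b]
```

</div>
</div>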
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">258</span> <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
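<p>Return a pair of images, one from each domain.
Indices wrap around with a modulo, so the smaller image set is reused.
The pairs are batched together and are not treated as corresponding pairs during training,
so it is acceptable that a given index always yields the same pair.</p>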
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">262</span> <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;x&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">load_image</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">files_a</span><span class="p">[</span><span class="n">index</span> <span class="o">%</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">files_a</span><span class="p">)])),</span>
<span class="lineno">263</span> <span class="s2">&quot;y&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">load_image</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">files_b</span><span class="p">[</span><span class="n">index</span> <span class="o">%</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">files_b</span><span class="p">)]))}</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">265</span> <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>Number of items in the dataset: the size of the larger of the two image sets</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">267</span> <span class="k">return</span> <span class="nb">max</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">files_a</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">files_b</span><span class="p">))</span></pre></div>
</div>
</div>
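<div class='section'>
<div class='docs'>
<p>A plain-Python sketch of the indexing in <code>__getitem__</code> and <code>__len__</code>, with hypothetical file names standing in for images. With 2 images in domain A and 5 in domain B, the dataset has 5 items and the domain A images are cycled through.</p>
</div>
<div class='code'>

```python
def make_pairs(files_a, files_b):
    # Dataset length is the size of the larger set
    length = max(len(files_a), len(files_b))
    # Each index wraps around the smaller set with a modulo
    return [(files_a[i % len(files_a)], files_b[i % len(files_b)])
            for i in range(length)]


pairs = make_pairs(['a0', 'a1'], ['b0', 'b1', 'b2', 'b3', 'b4'])
```

</div>
</div>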
<div class='section' id='section-44'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<h3>Replay Buffer</h3>
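<p>The replay buffer is used when training the discriminators.
Generated images are added to the buffer, and the generated images shown to the discriminator are sampled from it.</p>
<p>Until the buffer holds <code>max_size</code> images, each newly generated image is stored and returned unchanged.
After that, the buffer returns the newly generated image with a probability of $0.5$;
otherwise, it returns a randomly chosen older image and replaces that image in the buffer with the new one.</p>
<p>This is done to reduce model oscillation: the discriminator is trained on a history of generated images rather than only on the outputs of the latest generator.</p>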
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">270</span><span class="k">class</span> <span class="nc">ReplayBuffer</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">284</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">50</span><span class="p">):</span>
<span class="lineno">285</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_size</span> <span class="o">=</span> <span class="n">max_size</span>
<span class="lineno">286</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p>Add a batch of generated images to the buffer and return a batch of the same size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">288</span> <span class="k">def</span> <span class="nf">push_and_pop</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">290</span> <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
<span class="lineno">291</span> <span class="n">res</span> <span class="o">=</span> <span class="p">[]</span>
<span class="lineno">292</span> <span class="k">for</span> <span class="n">element</span> <span class="ow">in</span> <span class="n">data</span><span class="p">:</span>
<span class="lineno">293</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">)</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_size</span><span class="p">:</span>
<span class="lineno">294</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">element</span><span class="p">)</span>
<span class="lineno">295</span> <span class="n">res</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">element</span><span class="p">)</span>
<span class="lineno">296</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">297</span> <span class="k">if</span> <span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mf">0.5</span><span class="p">:</span>
<span class="lineno">298</span> <span class="n">i</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_size</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="lineno">299</span> <span class="n">res</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="lineno">300</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">element</span>
<span class="lineno">301</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">302</span> <span class="n">res</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">element</span><span class="p">)</span>
<span class="lineno">303</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">res</span><span class="p">)</span></pre></div>
</div>
</div>
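<div class='section'>
<div class='docs'>
<p>A list-based sketch of the same buffer logic, with integers in place of image tensors (so no <code>detach</code>, <code>clone</code> or <code>torch.stack</code>). <code>ListReplayBuffer</code> is a hypothetical name used only here. One useful property: every new item either comes straight back out or goes into the buffer while an older item comes out, so the buffer contents plus the returned items always account for every item pushed.</p>
</div>
<div class='code'>

```python
import random


class ListReplayBuffer:
    """Hypothetical list-based stand-in for `ReplayBuffer.push_and_pop`."""

    def __init__(self, max_size: int = 50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, batch):
        res = []
        for element in batch:
            full = len(self.data) >= self.max_size
            if not full:
                # Buffer not full yet: store the new item and return it unchanged
                self.data.append(element)
                res.append(element)
            elif random.uniform(0, 1) > 0.5:
                # Return a random older item and put the new one in its place
                i = random.randint(0, self.max_size - 1)
                res.append(self.data[i])
                self.data[i] = element
            else:
                # Return the new item without storing it
                res.append(element)
        return res


random.seed(0)
buf = ListReplayBuffer(max_size=3)
first = buf.push_and_pop([0, 1, 2])          # fills the buffer
second = buf.push_and_pop([10, 11, 12, 13])  # mixes new and old items
```

</div>
</div>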
<div class='section' id='section-48'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<h2>Configurations</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">306</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">BaseConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p><code>DeviceConfigs</code> will pick a GPU if available</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">310</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">DeviceConfigs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>Hyper-parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">313</span> <span class="n">epochs</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">200</span>
<span class="lineno">314</span> <span class="n">dataset_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;monet2photo&#39;</span>
<span class="lineno">315</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span>
<span class="lineno">316</span>
<span class="lineno">317</span> <span class="n">data_loader_workers</span> <span class="o">=</span> <span class="mi">8</span>
<span class="lineno">318</span>
<span class="lineno">319</span> <span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.0002</span>
<span class="lineno">320</span> <span class="n">adam_betas</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">)</span>
<span class="lineno">321</span> <span class="n">decay_start</span> <span class="o">=</span> <span class="mi">100</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p>The paper suggests using a least-squares loss instead of the
negative log-likelihood, as it was found to be more stable during training.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">325</span> <span class="n">gan_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">MSELoss</span><span class="p">()</span></pre></div>
</div>
</div>
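<div class='section'>
<div class='docs'>
<p>A plain-Python sketch of the least-squares objective, assuming the usual targets of $1$ for real images and $0$ for generated ones. The discriminator is pushed to score real images as $1$ and generated images as $0$, while the generator is pushed to make the discriminator score its outputs as $1$. The scores below are made-up numbers, and <code>mse</code> is a hypothetical helper mimicking <code>torch.nn.MSELoss()</code> against a constant target.</p>
</div>
<div class='code'>

```python
def mse(preds, target):
    # Mean squared error against a constant target, like torch.nn.MSELoss()
    return sum((p - target) ** 2 for p in preds) / len(preds)


d_real = [0.9, 0.8]  # discriminator scores on real images (made up)
d_fake = [0.3, 0.1]  # discriminator scores on generated images (made up)

# Discriminator: real -> 1, generated -> 0
d_loss = mse(d_real, 1.0) + mse(d_fake, 0.0)
# Generator: wants its outputs to be scored as real (1)
g_loss = mse(d_fake, 1.0)
```

</div>
</div>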
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p>L1 loss is used for cycle loss and identity loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">328</span> <span class="n">cycle_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">L1Loss</span><span class="p">()</span>
<span class="lineno">329</span> <span class="n">identity_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">L1Loss</span><span class="p">()</span></pre></div>
</div>
</div>
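<div class='section'>
<div class='docs'>
<p>A plain-Python sketch of the two L1 terms, using hypothetical one-dimensional "generators" in place of the networks: <code>generator_xy</code> maps $X \to Y$ by doubling and <code>generator_yx</code> maps $Y \to X$ by halving with a small error. The cycle loss translates to the other domain and back and compares with the input. The identity loss, following the paper, feeds each generator an image that is already in its output domain and penalizes any change.</p>
</div>
<div class='code'>

```python
def l1(a, b):
    # Mean absolute error, like torch.nn.L1Loss()
    return sum(abs(p - q) for p, q in zip(a, b)) / len(a)


def generator_xy(x):
    # Hypothetical X -> Y mapping
    return [2 * v for v in x]


def generator_yx(y):
    # Hypothetical Y -> X mapping, slightly off from the exact inverse
    return [v / 2 + 0.1 for v in y]


x = [0.5, 1.0]
y = [2.0, 4.0]

# Cycle loss: x -> y -> x and y -> x -> y should reconstruct the inputs
cycle = l1(generator_yx(generator_xy(x)), x) + l1(generator_xy(generator_yx(y)), y)
# Identity loss: each generator should leave images from its target domain unchanged
identity = l1(generator_xy(y), y) + l1(generator_yx(x), x)
```

</div>
</div>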
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>Image dimensions</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">332</span> <span class="n">img_height</span> <span class="o">=</span> <span class="mi">256</span>
<span class="lineno">333</span> <span class="n">img_width</span> <span class="o">=</span> <span class="mi">256</span>
<span class="lineno">334</span> <span class="n">img_channels</span> <span class="o">=</span> <span class="mi">3</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<p>Number of residual blocks in the generator</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">337</span> <span class="n">n_residual_blocks</span> <span class="o">=</span> <span class="mi">9</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='docs'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<p>Loss coefficients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">340</span> <span class="n">cyclic_loss_coefficient</span> <span class="o">=</span> <span class="mf">10.0</span>
<span class="lineno">341</span> <span class="n">identity_loss_coefficient</span> <span class="o">=</span> <span class="mf">5.</span>
<span class="lineno">342</span>
<span class="lineno">343</span> <span class="n">sample_interval</span> <span class="o">=</span> <span class="mi">500</span></pre></div>
</div>
</div>
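<div class='section'>
<div class='docs'>
<p>A sketch of how the coefficients would weight the generator's loss terms, assuming the paper's form $\mathcal{L} = \mathcal{L}_{GAN} + \lambda_{cyc} \mathcal{L}_{cyc} + \lambda_{id} \mathcal{L}_{id}$ with $\lambda_{cyc} = 10$ and $\lambda_{id} = 5$. The training step itself is not shown here, so the exact combination used by this implementation is an assumption; <code>generator_objective</code> is a hypothetical helper and the component values are made up. With these weights, a small cycle error can outweigh the adversarial term.</p>
</div>
<div class='code'>

```python
def generator_objective(gan_loss, cycle_loss, identity_loss,
                        cyclic_loss_coefficient=10.0,
                        identity_loss_coefficient=5.0):
    # Weighted sum of the adversarial, cycle-consistency and identity terms
    return (gan_loss
            + cyclic_loss_coefficient * cycle_loss
            + identity_loss_coefficient * identity_loss)


total = generator_objective(gan_loss=0.65, cycle_loss=0.3, identity_loss=0.1)
```

</div>
</div>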
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
<p>Models</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">346</span> <span class="n">generator_xy</span><span class="p">:</span> <span class="n">GeneratorResNet</span>
<span class="lineno">347</span> <span class="n">generator_yx</span><span class="p">:</span> <span class="n">GeneratorResNet</span>
<span class="lineno">348</span> <span class="n">discriminator_x</span><span class="p">:</span> <span class="n">Discriminator</span>
<span class="lineno">349</span> <span class="n">discriminator_y</span><span class="p">:</span> <span class="n">Discriminator</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<p>Optimizers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">352</span> <span class="n">generator_optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span>
<span class="lineno">353</span> <span class="n">discriminator_optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<p>Learning rate schedules</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">356</span> <span class="n">generator_lr_scheduler</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">lr_scheduler</span><span class="o">.</span><span class="n">LambdaLR</span>
<span class="lineno">357</span> <span class="n">discriminator_lr_scheduler</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">lr_scheduler</span><span class="o">.</span><span class="n">LambdaLR</span></pre></div>
</div>
</div>
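<div class='section'>
<div class='docs'>
<p>The paper keeps the learning rate constant for the first 100 epochs and then decays it linearly to zero over the next 100, which matches <code>epochs = 200</code> and <code>decay_start = 100</code> in the hyper-parameters. The sketch below assumes the <code>LambdaLR</code> schedules follow that rule; <code>lr_factor</code> is a hypothetical helper returning the multiplier that <code>LambdaLR</code> applies to the base learning rate. It could be passed as <code>torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_factor)</code>, assuming the scheduler is stepped once per epoch.</p>
</div>
<div class='code'>

```python
def lr_factor(epoch: int, epochs: int = 200, decay_start: int = 100) -> float:
    # Constant (1.0) until `decay_start`, then linear decay to 0 at `epochs`
    return 1.0 - max(0, epoch - decay_start) / (epochs - decay_start)


factors = [lr_factor(e) for e in (0, 100, 150, 200)]
```

</div>
</div>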
<div class='section' id='section-59'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<p>Data loaders</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">360</span> <span class="n">dataloader</span><span class="p">:</span> <span class="n">DataLoader</span>
<span class="lineno">361</span> <span class="n">valid_dataloader</span><span class="p">:</span> <span class="n">DataLoader</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
<p>Generate samples from the test set and save them</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">363</span> <span class="k">def</span> <span class="nf">sample_images</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">365</span> <span class="n">batch</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_dataloader</span><span class="p">))</span>
<span class="lineno">366</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_xy</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
<span class="lineno">367</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_yx</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
<span class="lineno">368</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="lineno">369</span> <span class="n">data_x</span><span class="p">,</span> <span class="n">data_y</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="s1">&#39;x&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generator_xy</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">batch</span><span class="p">[</span><span class="s1">&#39;y&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generator_yx</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">370</span> <span class="n">gen_y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_xy</span><span class="p">(</span><span class="n">data_x</span><span class="p">)</span>
<span class="lineno">371</span> <span class="n">gen_x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_yx</span><span class="p">(</span><span class="n">data_y</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
<p>Arrange each batch into a grid along the x-axis, with up to 5 images per row</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">374</span> <span class="n">data_x</span> <span class="o">=</span> <span class="n">make_grid</span><span class="p">(</span><span class="n">data_x</span><span class="p">,</span> <span class="n">nrow</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">normalize</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="lineno">375</span> <span class="n">data_y</span> <span class="o">=</span> <span class="n">make_grid</span><span class="p">(</span><span class="n">data_y</span><span class="p">,</span> <span class="n">nrow</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">normalize</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="lineno">376</span> <span class="n">gen_x</span> <span class="o">=</span> <span class="n">make_grid</span><span class="p">(</span><span class="n">gen_x</span><span class="p">,</span> <span class="n">nrow</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">normalize</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="lineno">377</span> <span class="n">gen_y</span> <span class="o">=</span> <span class="n">make_grid</span><span class="p">(</span><span class="n">gen_y</span><span class="p">,</span> <span class="n">nrow</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">normalize</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='docs'>
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
<p>Arrange images along y-axis</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">380</span> <span class="n">image_grid</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">data_x</span><span class="p">,</span> <span class="n">gen_y</span><span class="p">,</span> <span class="n">data_y</span><span class="p">,</span> <span class="n">gen_x</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-64'>
<div class='docs'>
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
<p>Show samples</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">383</span> <span class="n">plot_image</span><span class="p">(</span><span class="n">image_grid</span><span class="p">)</span></pre></div>
</div>
</div>
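<p>As a rough check of the layout produced above, the sketch below computes the shape of <code>image_grid</code> in plain Python. It assumes <code>make_grid</code>'s default <code>padding=2</code> and mirrors how torchvision sizes its grid; the helper names and the $256 \times 256$ RGB image size are illustrative, not part of this implementation.</p>

```python
import math


def make_grid_shape(n_images, channels, height, width, nrow=5, padding=2):
    """Shape of the tensor returned by torchvision's make_grid (default padding assumed)."""
    xmaps = min(nrow, n_images)            # images per row
    ymaps = math.ceil(n_images / xmaps)    # number of rows
    return (channels,
            ymaps * (height + padding) + padding,
            xmaps * (width + padding) + padding)


def stacked_shape(grids):
    """Shape after torch.cat(grids, 1): heights add up, channels and widths must match."""
    c, _, w = grids[0]
    assert all(g[0] == c and g[2] == w for g in grids)
    return (c, sum(g[1] for g in grids), w)


# Four grids (data_x, gen_y, data_y, gen_x), each holding 5 RGB images of 256x256
grid = make_grid_shape(5, 3, 256, 256)
image_grid = stacked_shape([grid] * 4)
print(grid, image_grid)
```

<p>Each horizontal strip of the final image is one row of five images, stacked top to bottom in the order $x$, $G(x)$, $y$, $F(y)$.</p>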
<div class='section' id='section-65'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-65'>#</a>
</div>
<h2>Initialize models and data loaders</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">385</span> <span class="k">def</span> <span class="nf">initialize</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-66'>
<div class='docs'>
<div class='section-link'>
<a href='#section-66'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">389</span> <span class="n">input_shape</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">img_channels</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">img_height</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">img_width</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-67'>
<div class='docs'>
<div class='section-link'>
<a href='#section-67'>#</a>
</div>
<p>Create the models</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">392</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_xy</span> <span class="o">=</span> <span class="n">GeneratorResNet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">img_channels</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_residual_blocks</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">393</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_yx</span> <span class="o">=</span> <span class="n">GeneratorResNet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">img_channels</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_residual_blocks</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">394</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_x</span> <span class="o">=</span> <span class="n">Discriminator</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">395</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_y</span> <span class="o">=</span> <span class="n">Discriminator</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-68'>
<div class='docs'>
<div class='section-link'>
<a href='#section-68'>#</a>
</div>
<p>Create the optimizers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">398</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span>
<span class="lineno">399</span> <span class="n">itertools</span><span class="o">.</span><span class="n">chain</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generator_xy</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_yx</span><span class="o">.</span><span class="n">parameters</span><span class="p">()),</span>
<span class="lineno">400</span> <span class="n">lr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">adam_betas</span><span class="p">)</span>
<span class="lineno">401</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span>
<span class="lineno">402</span> <span class="n">itertools</span><span class="o">.</span><span class="n">chain</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_x</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_y</span><span class="o">.</span><span class="n">parameters</span><span class="p">()),</span>
<span class="lineno">403</span> <span class="n">lr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">adam_betas</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-69'>
<div class='docs'>
<div class='section-link'>
<a href='#section-69'>#</a>
</div>
<p>Create the learning rate schedules.
The learning rate stays flat for the first <code>decay_start</code> epochs,
and then decays linearly to $0$ by the end of training.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">408</span> <span class="n">decay_epochs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">decay_start</span>
<span class="lineno">409</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_lr_scheduler</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">lr_scheduler</span><span class="o">.</span><span class="n">LambdaLR</span><span class="p">(</span>
<span class="lineno">410</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_optimizer</span><span class="p">,</span> <span class="n">lr_lambda</span><span class="o">=</span><span class="k">lambda</span> <span class="n">e</span><span class="p">:</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="nb">max</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">e</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">decay_start</span><span class="p">)</span> <span class="o">/</span> <span class="n">decay_epochs</span><span class="p">)</span>
<span class="lineno">411</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_lr_scheduler</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">lr_scheduler</span><span class="o">.</span><span class="n">LambdaLR</span><span class="p">(</span>
<span class="lineno">412</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_optimizer</span><span class="p">,</span> <span class="n">lr_lambda</span><span class="o">=</span><span class="k">lambda</span> <span class="n">e</span><span class="p">:</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="nb">max</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">e</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">decay_start</span><span class="p">)</span> <span class="o">/</span> <span class="n">decay_epochs</span><span class="p">)</span></pre></div>
</div>
</div>
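<p>The multiplier passed to <code>LambdaLR</code> can be checked in plain Python. This sketch assumes hypothetical settings of <code>epochs = 200</code> and <code>decay_start = 100</code> (the flat-then-linear schedule described in the CycleGAN paper, not necessarily this implementation's defaults) and a hypothetical base rate. <code>LambdaLR</code> multiplies the initial learning rate by this factor, and its epoch counter advances once per <code>scheduler.step()</code> call.</p>

```python
def lr_factor(epoch, epochs=200, decay_start=100):
    """Learning-rate multiplier: flat until decay_start, then linear decay to 0 at `epochs`."""
    decay_epochs = epochs - decay_start
    return 1.0 - max(0, epoch - decay_start) / decay_epochs


learning_rate = 2e-4  # hypothetical base learning rate
for epoch in (0, 50, 100, 150, 199):
    print(epoch, learning_rate * lr_factor(epoch))
```

<p>Because <code>step()</code> is called at the end of each epoch, epoch $k$ (counting from $0$) trains with factor <code>lr_factor(k)</code>. With the settings above, the last epoch therefore still uses $1\%$ of the base rate; the rate only reaches $0$ after the final <code>step()</code>.</p>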
<div class='section' id='section-70'>
<div class='docs'>
<div class='section-link'>
<a href='#section-70'>#</a>
</div>
<p>Image transformations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">415</span> <span class="n">transforms_</span> <span class="o">=</span> <span class="p">[</span>
<span class="lineno">416</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Resize</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">img_height</span> <span class="o">*</span> <span class="mf">1.12</span><span class="p">),</span> <span class="n">Image</span><span class="o">.</span><span class="n">BICUBIC</span><span class="p">),</span>
<span class="lineno">417</span> <span class="n">transforms</span><span class="o">.</span><span class="n">RandomCrop</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">img_height</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">img_width</span><span class="p">)),</span>
<span class="lineno">418</span> <span class="n">transforms</span><span class="o">.</span><span class="n">RandomHorizontalFlip</span><span class="p">(),</span>
<span class="lineno">419</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
<span class="lineno">420</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">),</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">)),</span>
<span class="lineno">421</span> <span class="p">]</span></pre></div>
</div>
</div>
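<p>The arithmetic behind these transforms, sketched in plain Python with illustrative function names. <code>Resize</code> with a single integer scales the shorter image side to <code>int(img_height * 1.12)</code> (for a $256$ pixel target this gives $286$, the size used in the CycleGAN paper) before <code>RandomCrop</code> cuts the target size back out. <code>Normalize</code> with mean and standard deviation $0.5$ maps the $[0, 1]$ output of <code>ToTensor</code> to $[-1, 1]$.</p>

```python
def resize_size(img_height, scale=1.12):
    """Size passed to transforms.Resize: the shorter image side is scaled to this."""
    return int(img_height * scale)


def normalize(value, mean=0.5, std=0.5):
    """Per-channel transforms.Normalize: (value - mean) / std."""
    return (value - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Inverse of normalize, mapping [-1, 1] back to [0, 1] for display."""
    return value * std + mean


print(resize_size(256))                           # shorter side before cropping
print([normalize(v) for v in (0.0, 0.5, 1.0)])    # range after Normalize
```

<p>When sampling, <code>make_grid(..., normalize=True)</code> rescales the images back into $[0, 1]$ for display.</p>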
<div class='section' id='section-71'>
<div class='docs'>
<div class='section-link'>
<a href='#section-71'>#</a>
</div>
<p>Training data loader</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">424</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataloader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span>
<span class="lineno">425</span> <span class="n">ImageDataset</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_name</span><span class="p">,</span> <span class="n">transforms_</span><span class="p">,</span> <span class="s1">&#39;train&#39;</span><span class="p">),</span>
<span class="lineno">426</span> <span class="n">batch_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span>
<span class="lineno">427</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="lineno">428</span> <span class="n">num_workers</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">data_loader_workers</span><span class="p">,</span>
<span class="lineno">429</span> <span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-72'>
<div class='docs'>
<div class='section-link'>
<a href='#section-72'>#</a>
</div>
<p>Validation data loader</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">432</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_dataloader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span>
<span class="lineno">433</span> <span class="n">ImageDataset</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_name</span><span class="p">,</span> <span class="n">transforms_</span><span class="p">,</span> <span class="s2">&quot;test&quot;</span><span class="p">),</span>
<span class="lineno">434</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span>
<span class="lineno">435</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="lineno">436</span> <span class="n">num_workers</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">data_loader_workers</span><span class="p">,</span>
<span class="lineno">437</span> <span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-73'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-73'>#</a>
</div>
<h2>Training</h2>
<p>We aim to solve:
<script type="math/tex; mode=display">G^{*}, F^{*} = \arg \min_{G,F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)</script>
</p>
<p>where,
$G$ translates images from $X \rightarrow Y$,
$F$ translates images from $Y \rightarrow X$,
$D_X$ tests if images are from $X$ space,
$D_Y$ tests if images are from $Y$ space, and
<script type="math/tex; mode=display">\begin{align}
\mathcal{L}(G, F, D_X, D_Y)
&= \mathcal{L}_{GAN}(G, D_Y, X, Y) \\
&+ \mathcal{L}_{GAN}(F, D_X, Y, X) \\
&+ \lambda_1 \mathcal{L}_{cyc}(G, F) \\
&+ \lambda_2 \mathcal{L}_{identity}(G, F) \\
\\
\mathcal{L}_{GAN}(G, D_Y, X, Y)
&= \mathbb{E}_{y \sim p_{data}(y)} \Big[\log D_Y(y)\Big] \\
&+ \mathbb{E}_{x \sim p_{data}(x)} \bigg[\log\Big(1 - D_Y(G(x))\Big)\bigg] \\
\\
\mathcal{L}_{GAN}(F, D_X, Y, X)
&= \mathbb{E}_{x \sim p_{data}(x)} \Big[\log D_X(x)\Big] \\
&+ \mathbb{E}_{y \sim p_{data}(y)} \bigg[\log\Big(1 - D_X(F(y))\Big)\bigg] \\
\\
\mathcal{L}_{cyc}(G, F)
&= \mathbb{E}_{x \sim p_{data}(x)} \Big[\lVert F(G(x)) - x \rVert_1\Big] \\
&+ \mathbb{E}_{y \sim p_{data}(y)} \Big[\lVert G(F(y)) - y \rVert_1\Big] \\
\\
\mathcal{L}_{identity}(G, F)
&= \mathbb{E}_{x \sim p_{data}(x)} \Big[\lVert F(x) - x \rVert_1\Big] \\
&+ \mathbb{E}_{y \sim p_{data}(y)} \Big[\lVert G(y) - y \rVert_1\Big] \\
\end{align}</script>
</p>
<p>$\mathcal{L}_{GAN}$ is the generative adversarial loss from the original
GAN paper.</p>
<p>$\mathcal{L}_{cyc}$ is the cycle-consistency loss, which pushes $F(G(x))$ to be close to $x$
and $G(F(y))$ to be close to $y$.
In other words, applying the two generators (transformations) in series should give back the
original image.
This is the main contribution of the paper.
It trains each generator to produce an image from the other distribution that still
corresponds to the original image.
Without this loss, $G(x)$ could generate any image from the distribution of $Y$.
With it, $G(x)$ has to come from the distribution of $Y$ while keeping enough properties of $x$
that $F(G(x))$ can regenerate something like $x$.</p>
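<p>$\mathcal{L}_{identity}$ is the identity loss.
It encourages each generator to leave an image unchanged when that image is already in the generator's
target domain, which helps the mapping preserve color composition between the input and the output.</p>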
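<p>A toy illustration of $\mathcal{L}_{cyc}$ and $\mathcal{L}_{identity}$ in plain Python. Images are stood in for by short lists of pixel values and each expectation is replaced by a mean over a batch. <code>G</code> and <code>F</code> are hypothetical stand-ins (a brightness shift and its inverse), not the ResNet generators used in this implementation. The mean absolute error plays the role of $\lVert \cdot \rVert_1$, as PyTorch's <code>L1Loss</code> computes by default.</p>

```python
def l1(a, b):
    """Mean absolute error between two flattened images (like torch.nn.L1Loss())."""
    return sum(abs(p - q) for p, q in zip(a, b)) / len(a)


def batch_mean(values):
    return sum(values) / len(values)


# Hypothetical generators: G brightens (X -> Y), F darkens (Y -> X)
def G(img):
    return [p + 0.25 for p in img]


def F(img):
    return [p - 0.25 for p in img]


def cycle_loss(xs, ys):
    """E_x ||F(G(x)) - x||_1 + E_y ||G(F(y)) - y||_1"""
    return batch_mean([l1(F(G(x)), x) for x in xs]) + batch_mean([l1(G(F(y)), y) for y in ys])


def identity_loss(xs, ys):
    """E_x ||F(x) - x||_1 + E_y ||G(y) - y||_1"""
    return batch_mean([l1(F(x), x) for x in xs]) + batch_mean([l1(G(y), y) for y in ys])


xs = [[0.0, 0.5], [0.25, 0.75]]   # batch from X
ys = [[0.5, 1.0], [0.25, 0.5]]    # batch from Y
print(cycle_loss(xs, ys), identity_loss(xs, ys))
```

<p>This pair of generators is perfectly invertible, so its cycle loss is $0$, yet both generators still change images that are already in their target domain. That is exactly what the identity loss penalizes.</p>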
<p>To solve $G^{*}, F^{*}$,
discriminators $D_X$ and $D_Y$ should <strong>ascend</strong> on the gradient,
<script type="math/tex; mode=display">\begin{align}
\nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^m
&\Bigg[
\log D_Y\Big(y^{(i)}\Big) \\
&+ \log \Big(1 - D_Y\Big(G\Big(x^{(i)}\Big)\Big)\Big) \\
&+ \log D_X\Big(x^{(i)}\Big) \\
& +\log\Big(1 - D_X\Big(F\Big(y^{(i)}\Big)\Big)\Big)
\Bigg]
\end{align}</script>
That is, they should descend on the <em>negative</em> log-likelihood loss.</p>
<p>To stabilize training, the negative log-likelihood objective
is replaced by a least-squares loss:
the squared error of the discriminator output, with real images labelled $1$
and generated images labelled $0$.
So the discriminators should descend on the gradient,
<script type="math/tex; mode=display">\begin{align}
\nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^m
&\Bigg[
\bigg(D_Y\Big(y^{(i)}\Big) - 1\bigg)^2 \\
&+ D_Y\Big(G\Big(x^{(i)}\Big)\Big)^2 \\
&+ \bigg(D_X\Big(x^{(i)}\Big) - 1\bigg)^2 \\
&+ D_X\Big(F\Big(y^{(i)}\Big)\Big)^2
\Bigg]
\end{align}</script>
</p>
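<p>A minimal sketch of the least-squares objective for one discriminator, with scalar discriminator outputs in plain Python. The function names and numbers are illustrative; in this implementation the discriminator outputs a grid of values that is compared against <code>true_labels</code> and <code>false_labels</code> through <code>gan_loss</code>, presumably a mean-squared-error loss.</p>

```python
def mse(outputs, target):
    """Mean squared error between discriminator outputs and a constant label."""
    return sum((o - target) ** 2 for o in outputs) / len(outputs)


def discriminator_loss(d_real, d_fake):
    """Least-squares loss for one discriminator: real -> 1, generated -> 0."""
    return mse(d_real, 1.0) + mse(d_fake, 0.0)


# Hypothetical outputs of D_Y on a batch of real y and of generated G(x)
d_y_real = [0.9, 0.7]
d_y_fake = [0.2, 0.4]
print(discriminator_loss(d_y_real, d_y_fake))
```

<p>The full objective adds the same two terms for $D_X$, evaluated on real $x$ and generated $F(y)$. A discriminator that outputs exactly $1$ on real images and $0$ on generated ones has zero loss.</p>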
<p>We use the least-squares loss for the generators as well.
The generators should <em>descend</em> on the gradient,
<script type="math/tex; mode=display">\begin{align}
\nabla_{\theta_{F, G}} \frac{1}{m} \sum_{i=1}^m
&\Bigg[
\bigg(D_Y\Big(G\Big(x^{(i)}\Big)\Big) - 1\bigg)^2 \\
&+ \bigg(D_X\Big(F\Big(y^{(i)}\Big)\Big) - 1\bigg)^2 \\
&+ \lambda_1 \mathcal{L}_{cyc}(G, F)
+ \lambda_2 \mathcal{L}_{identity}(G, F)
\Bigg]
\end{align}</script>
</p>
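<p>Correspondingly, the generators' least-squares adversarial term asks both discriminators to output $1$ on generated images. The sketch below combines that term with the weighted cycle and identity terms in plain Python. The weights $\lambda_1 = 10$ and $\lambda_2 = 5$ are assumptions (values commonly used with CycleGAN), and the discriminator outputs and loss values passed in are made up for illustration.</p>

```python
def mse(outputs, target):
    """Mean squared error between discriminator outputs and a constant label."""
    return sum((o - target) ** 2 for o in outputs) / len(outputs)


def generator_loss(d_y_on_gen_y, d_x_on_gen_x, loss_cycle, loss_identity,
                   lambda_cyc=10.0, lambda_identity=5.0):
    """Least-squares GAN terms (target 1 for generated images) plus weighted cycle/identity terms."""
    loss_gan = mse(d_y_on_gen_y, 1.0) + mse(d_x_on_gen_x, 1.0)
    return loss_gan + lambda_cyc * loss_cycle + lambda_identity * loss_identity


# Hypothetical D_Y outputs on G(x), D_X outputs on F(y), and hypothetical L1 losses
print(generator_loss([0.5, 0.5], [0.0, 1.0], loss_cycle=0.1, loss_identity=0.2))
```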
<p>We use <code>generator_xy</code> for $G$ and <code>generator_yx</code> for $F$.
We use <code>discriminator_x</code> for $D_X$ and <code>discriminator_y</code> for $D_Y$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">439</span> <span class="k">def</span> <span class="nf">run</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-74'>
<div class='docs'>
<div class='section-link'>
<a href='#section-74'>#</a>
</div>
<p>Replay buffers that keep a history of generated images for training the discriminators</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">536</span> <span class="n">gen_x_buffer</span> <span class="o">=</span> <span class="n">ReplayBuffer</span><span class="p">()</span>
<span class="lineno">537</span> <span class="n">gen_y_buffer</span> <span class="o">=</span> <span class="n">ReplayBuffer</span><span class="p">()</span></pre></div>
</div>
</div>
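<p>For reference, a minimal sketch of the history-buffer idea in plain Python. It assumes the scheme described in the CycleGAN paper (a buffer of the $50$ most recently generated images) and, as a further assumption, a $50\%$ chance of returning a stored image in place of the new one once the buffer is full. The class name <code>HistoryBuffer</code> is hypothetical; this is not the <code>ReplayBuffer</code> class used here, which operates on batches of tensors.</p>

```python
import random


class HistoryBuffer:
    """Hypothetical history buffer of generated samples (sketch, not this file's ReplayBuffer)."""

    def __init__(self, max_size=50, seed=None):
        self.max_size = max_size
        self.data = []
        self.rng = random.Random(seed)

    def push_and_pop(self, batch):
        """Store each new sample and return a batch mixing new and older samples."""
        out = []
        for sample in batch:
            if len(self.data) < self.max_size:
                # Buffer not full yet: keep the sample and return it unchanged
                self.data.append(sample)
                out.append(sample)
            elif self.rng.random() < 0.5:
                # Return a random older sample and replace it with the new one
                i = self.rng.randrange(self.max_size)
                out.append(self.data[i])
                self.data[i] = sample
            else:
                # Return the new sample without storing it
                out.append(sample)
        return out


buffer = HistoryBuffer(max_size=3, seed=0)
print(buffer.push_and_pop(["g0", "g1"]))               # buffer still filling
print(buffer.push_and_pop(["g2", "g3", "g4", "g5"]))   # may include older samples
```

<p>Training the discriminators on a mix of current and older generated images is intended to reduce oscillation, since they cannot simply chase the generators' latest outputs.</p>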
<div class='section' id='section-75'>
<div class='docs'>
<div class='section-link'>
<a href='#section-75'>#</a>
</div>
<p>Loop through epochs</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">540</span> <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">loop</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">epochs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-76'>
<div class='docs'>
<div class='section-link'>
<a href='#section-76'>#</a>
</div>
<p>Loop through the dataset</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">542</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">enum</span><span class="p">(</span><span class="s1">&#39;Train&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataloader</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-77'>
<div class='docs'>
<div class='section-link'>
<a href='#section-77'>#</a>
</div>
<p>Move images to the device</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">544</span> <span class="n">data_x</span><span class="p">,</span> <span class="n">data_y</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="s1">&#39;x&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">batch</span><span class="p">[</span><span class="s1">&#39;y&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-78'>
<div class='docs'>
<div class='section-link'>
<a href='#section-78'>#</a>
</div>
<p>true labels equal to $1$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">547</span> <span class="n">true_labels</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">data_x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_x</span><span class="o">.</span><span class="n">output_shape</span><span class="p">,</span>
<span class="lineno">548</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-79'>
<div class='docs'>
<div class='section-link'>
<a href='#section-79'>#</a>
</div>
<p>false labels equal to $0$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">550</span> <span class="n">false_labels</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">data_x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_x</span><span class="o">.</span><span class="n">output_shape</span><span class="p">,</span>
<span class="lineno">551</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-80'>
<div class='docs'>
<div class='section-link'>
<a href='#section-80'>#</a>
</div>
<p>Train the generators.
This returns the generated images.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">555</span> <span class="n">gen_x</span><span class="p">,</span> <span class="n">gen_y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimize_generators</span><span class="p">(</span><span class="n">data_x</span><span class="p">,</span> <span class="n">data_y</span><span class="p">,</span> <span class="n">true_labels</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-81'>
<div class='docs'>
<div class='section-link'>
<a href='#section-81'>#</a>
</div>
<p>Train discriminators</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">558</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimize_discriminator</span><span class="p">(</span><span class="n">data_x</span><span class="p">,</span> <span class="n">data_y</span><span class="p">,</span>
<span class="lineno">559</span> <span class="n">gen_x_buffer</span><span class="o">.</span><span class="n">push_and_pop</span><span class="p">(</span><span class="n">gen_x</span><span class="p">),</span> <span class="n">gen_y_buffer</span><span class="o">.</span><span class="n">push_and_pop</span><span class="p">(</span><span class="n">gen_y</span><span class="p">),</span>
<span class="lineno">560</span> <span class="n">true_labels</span><span class="p">,</span> <span class="n">false_labels</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-82'>
<div class='docs'>
<div class='section-link'>
<a href='#section-82'>#</a>
</div>
<p>Save training statistics and increment the global step counter</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">563</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span>
<span class="lineno">564</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">(</span><span class="nb">max</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data_x</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">data_y</span><span class="p">)))</span></pre></div>
</div>
</div>
<div class='section' id='section-83'>
<div class='docs'>
<div class='section-link'>
<a href='#section-83'>#</a>
</div>
<p>Save images at intervals</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">567</span> <span class="n">batches_done</span> <span class="o">=</span> <span class="n">epoch</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataloader</span><span class="p">)</span> <span class="o">+</span> <span class="n">i</span>
<span class="lineno">568</span> <span class="k">if</span> <span class="n">batches_done</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-84'>
<div class='docs'>
<div class='section-link'>
<a href='#section-84'>#</a>
</div>
<p>Save models when sampling images</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">570</span> <span class="n">experiment</span><span class="o">.</span><span class="n">save_checkpoint</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-85'>
<div class='docs'>
<div class='section-link'>
<a href='#section-85'>#</a>
</div>
<p>Sample images</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">572</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample_images</span><span class="p">(</span><span class="n">batches_done</span><span class="p">)</span></pre></div>
</div>
</div>
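<p>A quick check of when sampling and checkpointing happen: <code>batches_done</code> counts batches across epochs, so images are sampled every <code>sample_interval</code> batches regardless of epoch boundaries. The dataset size and interval below are hypothetical.</p>

```python
def sample_steps(epochs, batches_per_epoch, sample_interval):
    """(epoch, i, batches_done) for every batch at which images are sampled."""
    steps = []
    for epoch in range(epochs):
        for i in range(batches_per_epoch):
            batches_done = epoch * batches_per_epoch + i
            if batches_done % sample_interval == 0:
                steps.append((epoch, i, batches_done))
    return steps


print(sample_steps(epochs=2, batches_per_epoch=3, sample_interval=2))
```

<p>Note that the very first batch (<code>batches_done = 0</code>) also triggers sampling and a checkpoint.</p>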
<div class='section' id='section-86'>
<div class='docs'>
<div class='section-link'>
<a href='#section-86'>#</a>
</div>
<p>Update learning rates</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">575</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_lr_scheduler</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="lineno">576</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_lr_scheduler</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-87'>
<div class='docs'>
<div class='section-link'>
<a href='#section-87'>#</a>
</div>
<p>Start a new line in the tracker's console output</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">578</span> <span class="n">tracker</span><span class="o">.</span><span class="n">new_line</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-88'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-88'>#</a>
</div>
<h3>Optimize the generators with identity, GAN and cycle losses.</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">580</span> <span class="k">def</span> <span class="nf">optimize_generators</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data_x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">data_y</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">true_labels</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-89'>
<div class='docs'>
<div class='section-link'>
<a href='#section-89'>#</a>
</div>
<p>Change to training mode</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">586</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_xy</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>
<span class="lineno">587</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_yx</span><span class="o">.</span><span class="n">train</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-90'>
<div class='docs'>
<div class='section-link'>
<a href='#section-90'>#</a>
</div>
<p>Identity loss
<script type="math/tex; mode=display">\lVert F(x^{(i)}) - x^{(i)} \rVert_1 +
\lVert G(y^{(i)}) - y^{(i)} \rVert_1</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">592</span> <span class="n">loss_identity</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">identity_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generator_yx</span><span class="p">(</span><span class="n">data_x</span><span class="p">),</span> <span class="n">data_x</span><span class="p">)</span> <span class="o">+</span>
<span class="lineno">593</span> <span class="bp">self</span><span class="o">.</span><span class="n">identity_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generator_xy</span><span class="p">(</span><span class="n">data_y</span><span class="p">),</span> <span class="n">data_y</span><span class="p">))</span></pre></div>
</div>
</div>
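<div class='section'>
<div class='docs'>
<p>A toy check of the identity term on plain Python lists. This is not code from this implementation. It assumes <code>self.identity_loss</code> is a mean absolute error like <code>torch.nn.L1Loss</code>, which the formula suggests, and <code>identity</code> and <code>brighten</code> are hypothetical generators. A generator that returns images from its own output domain unchanged gets zero identity loss. One that shifts every pixel by 0.1 is penalized by about 0.1 per direction.</p>
</div>
<div class='code'>
<div class="highlight"><pre># Toy sketch: lists stand in for image tensors.
def l1(a, b):
    # Mean absolute error; assumed to match torch.nn.L1Loss (mean reduction)
    return sum(abs(p - q) for p, q in zip(a, b)) / len(a)


def identity_loss(gen_xy, gen_yx, x, y):
    # ||F(x) - x||_1 + ||G(y) - y||_1: each generator gets an image
    # that is already in its output domain
    return l1(gen_yx(x), x) + l1(gen_xy(y), y)


def identity(v):
    # Hypothetical generator that returns its input unchanged
    return list(v)


def brighten(v):
    # Hypothetical generator that shifts every pixel by 0.1
    return [p + 0.1 for p in v]


x = [0.1, 0.5, 0.9]  # toy image from domain X
y = [0.2, 0.4, 0.6]  # toy image from domain Y

print(identity_loss(identity, identity, x, y))  # 0.0
print(identity_loss(brighten, brighten, x, y))  # about 0.2
</pre></div>
</div>
</div>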
<div class='section' id='section-91'>
<div class='docs'>
<div class='section-link'>
<a href='#section-91'>#</a>
</div>
<p>Generate images $G(x)$ and $F(y)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">596</span> <span class="n">gen_y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_xy</span><span class="p">(</span><span class="n">data_x</span><span class="p">)</span>
<span class="lineno">597</span> <span class="n">gen_x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_yx</span><span class="p">(</span><span class="n">data_y</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-92'>
<div class='docs'>
<div class='section-link'>
<a href='#section-92'>#</a>
</div>
<p>GAN loss
<script type="math/tex; mode=display">\bigg(D_Y\Big(G\Big(x^{(i)}\Big)\Big) - 1\bigg)^2
+ \bigg(D_X\Big(F\Big(y^{(i)}\Big)\Big) - 1\bigg)^2</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">602</span> <span class="n">loss_gan</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gan_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_y</span><span class="p">(</span><span class="n">gen_y</span><span class="p">),</span> <span class="n">true_labels</span><span class="p">)</span> <span class="o">+</span>
<span class="lineno">603</span> <span class="bp">self</span><span class="o">.</span><span class="n">gan_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_x</span><span class="p">(</span><span class="n">gen_x</span><span class="p">),</span> <span class="n">true_labels</span><span class="p">))</span></pre></div>
</div>
</div>
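<div class='section'>
<div class='docs'>
<p>For the generators, the least-squares GAN loss is smallest when the discriminators score generated images as real. This toy version on lists assumes <code>self.gan_loss</code> is a mean squared error like <code>torch.nn.MSELoss</code> and that <code>true_labels</code> is filled with ones. The discriminator scores are made-up numbers, not outputs of this model.</p>
</div>
<div class='code'>
<div class="highlight"><pre># Toy sketch: lists stand in for discriminator output tensors.
def mse(pred, target):
    # Mean squared error; assumed to match torch.nn.MSELoss (mean reduction)
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def generator_gan_loss(d_y_on_gen_y, d_x_on_gen_x):
    # (D_Y(G(x)) - 1)^2 + (D_X(F(y)) - 1)^2: both targets are "real" (1)
    return (mse(d_y_on_gen_y, [1.0] * len(d_y_on_gen_y))
            + mse(d_x_on_gen_x, [1.0] * len(d_x_on_gen_x)))


# Made-up discriminator scores for generated images
print(generator_gan_loss([1.0, 1.0], [1.0, 1.0]))  # 0.0
print(generator_gan_loss([0.0, 0.5], [0.5, 0.5]))  # 0.875
</pre></div>
</div>
</div>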
<div class='section' id='section-93'>
<div class='docs'>
<div class='section-link'>
<a href='#section-93'>#</a>
</div>
<p>Cycle loss
<script type="math/tex; mode=display">
\lVert F(G(x^{(i)})) - x^{(i)} \rVert_1 +
\lVert G(F(y^{(i)})) - y^{(i)} \rVert_1
</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">610</span> <span class="n">loss_cycle</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cycle_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generator_yx</span><span class="p">(</span><span class="n">gen_y</span><span class="p">),</span> <span class="n">data_x</span><span class="p">)</span> <span class="o">+</span>
<span class="lineno">611</span> <span class="bp">self</span><span class="o">.</span><span class="n">cycle_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generator_xy</span><span class="p">(</span><span class="n">gen_x</span><span class="p">),</span> <span class="n">data_y</span><span class="p">))</span></pre></div>
</div>
</div>
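<div class='section'>
<div class='docs'>
<p>A toy check of the cycle term. It assumes <code>self.cycle_loss</code> is a mean absolute error like <code>torch.nn.L1Loss</code>, and <code>darken</code>, <code>lighten</code> and <code>clip</code> are hypothetical generators acting on lists. Translating to the other domain and back should recover the input. A pair of exactly inverse maps therefore gives zero loss. A pair that is not inverse is penalized, even if each map on its own produces plausible values.</p>
</div>
<div class='code'>
<div class="highlight"><pre># Toy sketch: lists stand in for image tensors.
def l1(a, b):
    # Mean absolute error; assumed to match torch.nn.L1Loss (mean reduction)
    return sum(abs(p - q) for p, q in zip(a, b)) / len(a)


def cycle_loss(gen_xy, gen_yx, x, y):
    # ||F(G(x)) - x||_1 + ||G(F(y)) - y||_1
    return l1(gen_yx(gen_xy(x)), x) + l1(gen_xy(gen_yx(y)), y)


def darken(v):
    # Hypothetical G, mapping domain X to domain Y
    return [p * 0.5 for p in v]


def lighten(v):
    # Hypothetical F, mapping domain Y to domain X; exact inverse of darken
    return [p * 2.0 for p in v]


def clip(v):
    # Hypothetical F that is NOT the inverse of darken
    return [min(p, 0.3) for p in v]


x = [0.2, 0.4, 0.8]
y = [0.1, 0.2, 0.4]

print(cycle_loss(darken, lighten, x, y))  # 0.0
print(cycle_loss(darken, clip, x, y))     # about 0.4
</pre></div>
</div>
</div>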
<div class='section' id='section-94'>
<div class='docs'>
<div class='section-link'>
<a href='#section-94'>#</a>
</div>
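<p>Total loss: the GAN loss plus the cycle and identity losses, weighted by <code>cyclic_loss_coefficient</code> and <code>identity_loss_coefficient</code> respectively</p>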
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">614</span> <span class="n">loss_generator</span> <span class="o">=</span> <span class="p">(</span><span class="n">loss_gan</span> <span class="o">+</span>
<span class="lineno">615</span> <span class="bp">self</span><span class="o">.</span><span class="n">cyclic_loss_coefficient</span> <span class="o">*</span> <span class="n">loss_cycle</span> <span class="o">+</span>
<span class="lineno">616</span> <span class="bp">self</span><span class="o">.</span><span class="n">identity_loss_coefficient</span> <span class="o">*</span> <span class="n">loss_identity</span><span class="p">)</span></pre></div>
</div>
</div>
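<div class='section'>
<div class='docs'>
<p>A worked example of the weighted sum. The default weights of 10 and 5 below are illustrative assumptions, not values read from this implementation's configs (the CycleGAN paper uses λ = 10 for the cycle term). The three loss values are made up. Once it is scaled, even a small cycle loss can dominate the total.</p>
</div>
<div class='code'>
<div class="highlight"><pre>def generator_objective(loss_gan, loss_cycle, loss_identity,
                        cyclic_coefficient=10.0, identity_coefficient=5.0):
    # Illustrative default weights (assumptions, not this repo's configs)
    return (loss_gan
            + cyclic_coefficient * loss_cycle
            + identity_coefficient * loss_identity)


# Made-up loss values: the weighted cycle term (2.5) dominates the sum
print(generator_objective(0.5, 0.25, 0.125))  # 3.625
</pre></div>
</div>
</div>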
<div class='section' id='section-95'>
<div class='docs'>
<div class='section-link'>
<a href='#section-95'>#</a>
</div>
<p>Clear the old gradients, backpropagate, and take a step with the optimizer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">619</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="lineno">620</span> <span class="n">loss_generator</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="lineno">621</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-96'>
<div class='docs'>
<div class='section-link'>
<a href='#section-96'>#</a>
</div>
<p>Log losses</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">624</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">({</span><span class="s1">&#39;loss.generator&#39;</span><span class="p">:</span> <span class="n">loss_generator</span><span class="p">,</span>
<span class="lineno">625</span> <span class="s1">&#39;loss.generator.cycle&#39;</span><span class="p">:</span> <span class="n">loss_cycle</span><span class="p">,</span>
<span class="lineno">626</span> <span class="s1">&#39;loss.generator.gan&#39;</span><span class="p">:</span> <span class="n">loss_gan</span><span class="p">,</span>
<span class="lineno">627</span> <span class="s1">&#39;loss.generator.identity&#39;</span><span class="p">:</span> <span class="n">loss_identity</span><span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-97'>
<div class='docs'>
<div class='section-link'>
<a href='#section-97'>#</a>
</div>
<p>Return generated images</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">630</span> <span class="k">return</span> <span class="n">gen_x</span><span class="p">,</span> <span class="n">gen_y</span></pre></div>
</div>
</div>
<div class='section' id='section-98'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-98'>#</a>
</div>
<h3>Optimize the discriminators with the GAN loss.</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">632</span> <span class="k">def</span> <span class="nf">optimize_discriminator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data_x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">data_y</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">633</span> <span class="n">gen_x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">gen_y</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">634</span> <span class="n">true_labels</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">false_labels</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-99'>
<div class='docs'>
<div class='section-link'>
<a href='#section-99'>#</a>
</div>
<p>GAN loss: the discriminators should output 1 for real images and 0 for generated ones
<script type="math/tex; mode=display">\begin{align}
\bigg(D_Y\Big(y ^ {(i)}\Big) - 1\bigg) ^ 2
+ D_Y\Big(G\Big(x ^ {(i)}\Big)\Big) ^ 2 + \\
\bigg(D_X\Big(x ^ {(i)}\Big) - 1\bigg) ^ 2
+ D_X\Big(F\Big(y ^ {(i)}\Big)\Big) ^ 2
\end{align}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">645</span> <span class="n">loss_discriminator</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gan_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_x</span><span class="p">(</span><span class="n">data_x</span><span class="p">),</span> <span class="n">true_labels</span><span class="p">)</span> <span class="o">+</span>
<span class="lineno">646</span> <span class="bp">self</span><span class="o">.</span><span class="n">gan_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_x</span><span class="p">(</span><span class="n">gen_x</span><span class="p">),</span> <span class="n">false_labels</span><span class="p">)</span> <span class="o">+</span>
<span class="lineno">647</span> <span class="bp">self</span><span class="o">.</span><span class="n">gan_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_y</span><span class="p">(</span><span class="n">data_y</span><span class="p">),</span> <span class="n">true_labels</span><span class="p">)</span> <span class="o">+</span>
<span class="lineno">648</span> <span class="bp">self</span><span class="o">.</span><span class="n">gan_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_y</span><span class="p">(</span><span class="n">gen_y</span><span class="p">),</span> <span class="n">false_labels</span><span class="p">))</span></pre></div>
</div>
</div>
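<div class='section'>
<div class='docs'>
<p>The discriminators are trained toward the opposite targets: 1 for real images and 0 for generated ones. This toy version assumes <code>self.gan_loss</code> is a mean squared error like <code>torch.nn.MSELoss</code>, with <code>true_labels</code> holding ones and <code>false_labels</code> holding zeros. The scores are made up. A discriminator that outputs 0.5 for everything is undecided, and it pays 0.5 per discriminator.</p>
</div>
<div class='code'>
<div class="highlight"><pre># Toy sketch: lists stand in for discriminator output tensors.
def mse(pred, target):
    # Mean squared error; assumed to match torch.nn.MSELoss (mean reduction)
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def discriminator_loss(d_real, d_fake):
    # (D(real) - 1)^2 + D(fake)^2 for a single discriminator
    real_loss = mse(d_real, [1.0] * len(d_real))
    fake_loss = mse(d_fake, [0.0] * len(d_fake))
    return real_loss + fake_loss


def total_discriminator_loss(dx_real, dx_fake, dy_real, dy_fake):
    # D_X scores x and F(y); D_Y scores y and G(x)
    return (discriminator_loss(dx_real, dx_fake)
            + discriminator_loss(dy_real, dy_fake))


print(total_discriminator_loss([1.0], [0.0], [1.0], [0.0]))  # 0.0
print(total_discriminator_loss([0.5], [0.5], [0.5], [0.5]))  # 1.0
</pre></div>
</div>
</div>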
<div class='section' id='section-100'>
<div class='docs'>
<div class='section-link'>
<a href='#section-100'>#</a>
</div>
<p>Clear the old gradients, backpropagate, and take a step with the optimizer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">651</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="lineno">652</span> <span class="n">loss_discriminator</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="lineno">653</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-101'>
<div class='docs'>
<div class='section-link'>
<a href='#section-101'>#</a>
</div>
<p>Log losses</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">656</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">({</span><span class="s1">&#39;loss.discriminator&#39;</span><span class="p">:</span> <span class="n">loss_discriminator</span><span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-102'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-102'>#</a>
</div>
<h2>Train Cycle GAN</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">659</span><span class="k">def</span> <span class="nf">train</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-103'>
<div class='docs'>
<div class='section-link'>
<a href='#section-103'>#</a>
</div>
<p>Create configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">664</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-104'>
<div class='docs'>
<div class='section-link'>
<a href='#section-104'>#</a>
</div>
<p>Create an experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">666</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;cycle_gan&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-105'>
<div class='docs'>
<div class='section-link'>
<a href='#section-105'>#</a>
</div>
<p>Calculate configurations.
It will calculate <code>conf.run</code> and all other configs required by it.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">669</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span><span class="s1">&#39;dataset_name&#39;</span><span class="p">:</span> <span class="s1">&#39;summer2winter_yosemite&#39;</span><span class="p">})</span>
<span class="lineno">670</span> <span class="n">conf</span><span class="o">.</span><span class="n">initialize</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-106'>
<div class='docs'>
<div class='section-link'>
<a href='#section-106'>#</a>
</div>
<p>Register models for saving and loading.
<code>get_modules</code> returns a dictionary of the <code>nn.Module</code>s in <code>conf</code>.
You can also specify a custom dictionary of models.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">675</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">(</span><span class="n">get_modules</span><span class="p">(</span><span class="n">conf</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-107'>
<div class='docs'>
<div class='section-link'>
<a href='#section-107'>#</a>
</div>
<p>Start and watch the experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">677</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-108'>
<div class='docs'>
<div class='section-link'>
<a href='#section-108'>#</a>
</div>
<p>Run the training</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">679</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-109'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-109'>#</a>
</div>
<h3>Plot an image with matplotlib</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">682</span><span class="k">def</span> <span class="nf">plot_image</span><span class="p">(</span><span class="n">img</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-110'>
<div class='docs'>
<div class='section-link'>
<a href='#section-110'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">686</span> <span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">pyplot</span> <span class="k">as</span> <span class="n">plt</span></pre></div>
</div>
</div>
<div class='section' id='section-111'>
<div class='docs'>
<div class='section-link'>
<a href='#section-111'>#</a>
</div>
<p>Move tensor to CPU</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">689</span> <span class="n">img</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-112'>
<div class='docs'>
<div class='section-link'>
<a href='#section-112'>#</a>
</div>
<p>Get min and max values of the image for normalization</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">691</span> <span class="n">img_min</span><span class="p">,</span> <span class="n">img_max</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">min</span><span class="p">(),</span> <span class="n">img</span><span class="o">.</span><span class="n">max</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-113'>
<div class='docs'>
<div class='section-link'>
<a href='#section-113'>#</a>
</div>
<p>Scale image values to be [0&hellip;1]</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">693</span> <span class="n">img</span> <span class="o">=</span> <span class="p">(</span><span class="n">img</span> <span class="o">-</span> <span class="n">img_min</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">img_max</span> <span class="o">-</span> <span class="n">img_min</span> <span class="o">+</span> <span class="mf">1e-5</span><span class="p">)</span></pre></div>
</div>
</div>
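<div class='section'>
<div class='docs'>
<p>The scaling step is min-max normalization. Below is a list-based sketch of the same arithmetic with toy values. The <code>1e-5</code> in the denominator avoids a division by zero for a constant image. The cost is that the maximum lands slightly below 1.</p>
</div>
<div class='code'>
<div class="highlight"><pre>def min_max_scale(values, eps=1e-5):
    # (v - min) / (max - min + eps), the same formula as in plot_image
    v_min, v_max = min(values), max(values)
    return [(v - v_min) / (v_max - v_min + eps) for v in values]


print(min_max_scale([-1.0, 0.0, 1.0]))  # roughly [0.0, 0.5, 1.0]
print(min_max_scale([0.3, 0.3]))        # [0.0, 0.0]
</pre></div>
</div>
</div>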
<div class='section' id='section-114'>
<div class='docs'>
<div class='section-link'>
<a href='#section-114'>#</a>
</div>
<p>Change the order of dimensions from CHW, as used by PyTorch, to HWC, which is what <code>plt.imshow</code> expects.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">695</span> <span class="n">img</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
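<div class='section'>
<div class='docs'>
<p>PyTorch image tensors are laid out as (channels, height, width), while <code>plt.imshow</code> expects (height, width, channels). <code>permute(1, 2, 0)</code> therefore moves the channel axis to the end. Here is the same reordering on nested lists, as a toy sketch; <code>chw_to_hwc</code> is a hypothetical helper.</p>
</div>
<div class='code'>
<div class="highlight"><pre>def chw_to_hwc(img):
    # img[c][h][w] becomes out[h][w][c], like tensor.permute(1, 2, 0)
    channels, height, width = len(img), len(img[0]), len(img[0][0])
    return [[[img[c][h][w] for c in range(channels)]
             for w in range(width)]
            for h in range(height)]


# 3 channels (R, G, B), 1 row, 2 columns
img = [[[1, 2]], [[3, 4]], [[5, 6]]]
print(chw_to_hwc(img))  # [[[1, 3, 5], [2, 4, 6]]]
</pre></div>
</div>
</div>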
<div class='section' id='section-115'>
<div class='docs'>
<div class='section-link'>
<a href='#section-115'>#</a>
</div>
<p>Show the image</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">697</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">img</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-116'>
<div class='docs'>
<div class='section-link'>
<a href='#section-116'>#</a>
</div>
<p>We don&rsquo;t need axes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">699</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s1">&#39;off&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-117'>
<div class='docs'>
<div class='section-link'>
<a href='#section-117'>#</a>
</div>
<p>Display</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">701</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-118'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-118'>#</a>
</div>
<h2>Evaluate trained Cycle GAN</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">704</span><span class="k">def</span> <span class="nf">evaluate</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-119'>
<div class='docs'>
<div class='section-link'>
<a href='#section-119'>#</a>
</div>
<p>Set the run UUID from the training run</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">709</span> <span class="n">trained_run_uuid</span> <span class="o">=</span> <span class="s1">&#39;f73c1164184711eb9190b74249275441&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-120'>
<div class='docs'>
<div class='section-link'>
<a href='#section-120'>#</a>
</div>
<p>Create configs object</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">711</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-121'>
<div class='docs'>
<div class='section-link'>
<a href='#section-121'>#</a>
</div>
<p>Create experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">713</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;cycle_gan_inference&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-122'>
<div class='docs'>
<div class='section-link'>
<a href='#section-122'>#</a>
</div>
<p>Load the hyper-parameters used for training</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">715</span> <span class="n">conf_dict</span> <span class="o">=</span> <span class="n">experiment</span><span class="o">.</span><span class="n">load_configs</span><span class="p">(</span><span class="n">trained_run_uuid</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-123'>
<div class='docs'>
<div class='section-link'>
<a href='#section-123'>#</a>
</div>
<p>Calculate configurations using the hyper-parameters loaded from the training run.
No list of configs is specified here, so all the configurations will be calculated, including data loaders.
You could instead specify only the ones you need, such as the generators <code>'generator_xy', 'generator_yx'</code>.
Configs like <code>device</code> and <code>img_channels</code> would still be calculated, since these are required by
<code>generator_xy</code> and <code>generator_yx</code>.</p>
<p>Calculation of configurations and their dependencies will happen when you call <code>experiment.start</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">724</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="n">conf_dict</span><span class="p">)</span>
<span class="lineno">725</span> <span class="n">conf</span><span class="o">.</span><span class="n">initialize</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-124'>
<div class='docs'>
<div class='section-link'>
<a href='#section-124'>#</a>
</div>
<p>Register models for saving and loading.
<code>get_modules</code> returns a dictionary of the <code>nn.Module</code>s in <code>conf</code>.
You can also specify a custom dictionary of models.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">730</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">(</span><span class="n">get_modules</span><span class="p">(</span><span class="n">conf</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-125'>
<div class='docs'>
<div class='section-link'>
<a href='#section-125'>#</a>
</div>
<p>Specify which run to load from.
Loading will actually happen when you call <code>experiment.start</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">733</span> <span class="n">experiment</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">trained_run_uuid</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-126'>
<div class='docs'>
<div class='section-link'>
<a href='#section-126'>#</a>
</div>
<p>Start the experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">736</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-127'>
<div class='docs'>
<div class='section-link'>
<a href='#section-127'>#</a>
</div>
<p>Image transformations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">738</span> <span class="n">transforms_</span> <span class="o">=</span> <span class="p">[</span>
<span class="lineno">739</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
<span class="lineno">740</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">),</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">)),</span>
<span class="lineno">741</span> <span class="p">]</span></pre></div>
</div>
</div>
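<div class='section'>
<div class='docs'>
<p><code>ToTensor</code> produces pixel values in [0, 1]. <code>Normalize</code>, with a mean and standard deviation of 0.5 for each channel, computes <code>(v - 0.5) / 0.5</code>, which maps those values to [-1, 1]. Below is a toy sketch of that mapping and its inverse on single values. <code>denormalize</code> is a hypothetical helper, since <code>plot_image</code> rescales with min-max normalization instead.</p>
</div>
<div class='code'>
<div class="highlight"><pre>MEAN, STD = 0.5, 0.5


def normalize(v):
    # Same formula as transforms.Normalize: (v - mean) / std, [0, 1] to [-1, 1]
    return (v - MEAN) / STD


def denormalize(v):
    # Hypothetical inverse mapping, back from [-1, 1] to [0, 1]
    return v * STD + MEAN


print([normalize(v) for v in (0.0, 0.5, 1.0)])  # [-1.0, 0.0, 1.0]
print(denormalize(normalize(0.25)))             # 0.25
</pre></div>
</div>
</div>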
<div class='section' id='section-128'>
<div class='docs'>
<div class='section-link'>
<a href='#section-128'>#</a>
</div>
<p>Load your own data. Here we load images from the training split (<code>'train'</code>).
I was trying it with Yosemite photos; they look awesome.
You can use <code>conf.dataset_name</code> if you specified <code>dataset_name</code> as something you wanted to be calculated
in the call to <code>experiment.configs</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">747</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">ImageDataset</span><span class="p">(</span><span class="n">conf</span><span class="o">.</span><span class="n">dataset_name</span><span class="p">,</span> <span class="n">transforms_</span><span class="p">,</span> <span class="s1">&#39;train&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-129'>
<div class='docs'>
<div class='section-link'>
<a href='#section-129'>#</a>
</div>
<p>Get an image from domain X of the dataset</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">749</span> <span class="n">x_image</span> <span class="o">=</span> <span class="n">dataset</span><span class="p">[</span><span class="mi">10</span><span class="p">][</span><span class="s1">&#39;x&#39;</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-130'>
<div class='docs'>
<div class='section-link'>
<a href='#section-130'>#</a>
</div>
<p>Display the image</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">751</span> <span class="n">plot_image</span><span class="p">(</span><span class="n">x_image</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-131'>
<div class='docs'>
<div class='section-link'>
<a href='#section-131'>#</a>
</div>
<p>Evaluation mode</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">754</span> <span class="n">conf</span><span class="o">.</span><span class="n">generator_xy</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
<span class="lineno">755</span> <span class="n">conf</span><span class="o">.</span><span class="n">generator_yx</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-132'>
<div class='docs'>
<div class='section-link'>
<a href='#section-132'>#</a>
</div>
<p>We don&rsquo;t need gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">758</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-133'>
<div class='docs'>
<div class='section-link'>
<a href='#section-133'>#</a>
</div>
<p>Add a batch dimension, move the image to the device we use, and translate it to domain Y with <code>generator_xy</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">760</span> <span class="n">data</span> <span class="o">=</span> <span class="n">x_image</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">conf</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">761</span> <span class="n">generated_y</span> <span class="o">=</span> <span class="n">conf</span><span class="o">.</span><span class="n">generator_xy</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-134'>
<div class='docs'>
<div class='section-link'>
<a href='#section-134'>#</a>
</div>
<p>Display the generated image.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">764</span> <span class="n">plot_image</span><span class="p">(</span><span class="n">generated_y</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">cpu</span><span class="p">())</span>
<span class="lineno">765</span>
<span class="lineno">766</span>
<span class="lineno">767</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">768</span> <span class="n">train</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-135'>
<div class='docs'>
<div class='section-link'>
<a href='#section-135'>#</a>
</div>
<p>Call <code>evaluate()</code> instead of <code>train()</code> to run inference with a trained model.</p>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
</body>
</html>