mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-01 20:28:41 +08:00
1957 lines
147 KiB
HTML
1957 lines
147 KiB
HTML
<!DOCTYPE html>
|
|
<html lang="en">
|
|
<head>
|
|
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
|
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
|
<meta name="description" content="A simple PyTorch implementation/tutorial of Cycle GAN introduced in paper Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks."/>
|
|
|
|
<meta name="twitter:card" content="summary"/>
|
|
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
|
<meta name="twitter:title" content="Cycle GAN"/>
|
|
<meta name="twitter:description" content="A simple PyTorch implementation/tutorial of Cycle GAN introduced in paper Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks."/>
|
|
<meta name="twitter:site" content="@labmlai"/>
|
|
<meta name="twitter:creator" content="@labmlai"/>
|
|
|
|
<meta property="og:url" content="https://nn.labml.ai/gan/cycle_gan.html"/>
|
|
<meta property="og:title" content="Cycle GAN"/>
|
|
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
|
<meta property="og:site_name" content="LabML Neural Networks"/>
|
|
<meta property="og:type" content="object"/>
|
|
<meta property="og:title" content="Cycle GAN"/>
|
|
<meta property="og:description" content="A simple PyTorch implementation/tutorial of Cycle GAN introduced in paper Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks."/>
|
|
|
|
<title>Cycle GAN</title>
|
|
<link rel="shortcut icon" href="/icon.png"/>
|
|
<link rel="stylesheet" href="../pylit.css">
|
|
<link rel="canonical" href="https://nn.labml.ai/gan/cycle_gan.html"/>
|
|
<!-- Global site tag (gtag.js) - Google Analytics -->
|
|
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
|
<script>
|
|
// Google Analytics (gtag.js) bootstrap.
// Reuse the shared command queue if one already exists; otherwise start an
// empty one. The async loader script above consumes whatever is queued here.
window.dataLayer = window.dataLayer || [];

// Queue a single gtag command.
// The Arguments object is pushed exactly as received, not copied into an
// array. NOTE(review): gtag.js is commonly reported to rely on receiving a
// real Arguments object, so keep this form rather than using rest parameters.
function gtag() {
  window.dataLayer.push(arguments);
}

// Initial commands: record the load time, then configure the measurement ID.
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
|
|
</script>
|
|
</head>
|
|
<body>
|
|
<div id='container'>
|
|
<div id="background"></div>
|
|
<div class='section'>
|
|
<div class='docs'>
|
|
<p>
|
|
<a class="parent" href="/">home</a>
|
|
<a class="parent" href="index.html">gan</a>
|
|
</p>
|
|
<p>
|
|
|
|
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/gan/cycle_gan.py">
|
|
<img alt="Github"
|
|
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
|
|
style="max-width:100%;"/></a>
|
|
<a href="https://join.slack.com/t/labforml/shared_invite/zt-egj9zvq9-Dl3hhZqobexgT7aVKnD14g/"
|
|
rel="nofollow">
|
|
<img alt="Join Slack"
|
|
src="https://img.shields.io/badge/slack-chat-green.svg?logo=slack"
|
|
style="max-width:100%;"/></a>
|
|
<a href="https://twitter.com/labmlai"
|
|
rel="nofollow">
|
|
<img alt="Twitter"
|
|
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
|
style="max-width:100%;"/></a>
|
|
</p>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-0'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-0'>#</a>
|
|
</div>
|
|
<h1>Cycle GAN</h1>
|
|
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation/tutorial of the paper
|
|
<a href="https://arxiv.org/abs/1703.10593">Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks</a>.</p>
|
|
<p>I’ve taken pieces of code from <a href="https://github.com/eriklindernoren/PyTorch-GAN">eriklindernoren/PyTorch-GAN</a>.
|
|
It is a very good resource if you want to check out other GAN variations too.</p>
|
|
<p>Cycle GAN does image-to-image translation.
|
|
It trains a model to translate an image from one distribution to another, say, from images of class A to images of class B.
|
|
Images of a certain distribution could be, for example, images in a certain style or images of a certain kind of scene, such as nature.
|
|
The models do not need paired images between A and B.
|
|
Just a set of images of each class is enough.
|
|
This works very well on changing between image styles, lighting changes, pattern changes, etc.
|
|
For example, changing summer to winter, painting style to photos, and horses to zebras.</p>
|
|
<p>Cycle GAN trains two generator models and two discriminator models.
|
|
One generator translates images from A to B and the other from B to A.
|
|
The discriminators test whether the generated images look real.</p>
|
|
<p>This file contains the model code as well as the training code.
|
|
We also have a Google Colab notebook.</p>
|
|
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/gan/cycle_gan.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
|
|
<a href="https://app.labml.ai/run/93b11a665d6811ebaac80242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">36</span><span></span><span class="kn">import</span> <span class="nn">itertools</span>
|
|
<span class="lineno">37</span><span class="kn">import</span> <span class="nn">random</span>
|
|
<span class="lineno">38</span><span class="kn">import</span> <span class="nn">zipfile</span>
|
|
<span class="lineno">39</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Tuple</span>
|
|
<span class="lineno">40</span>
|
|
<span class="lineno">41</span><span class="kn">import</span> <span class="nn">torch</span>
|
|
<span class="lineno">42</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
|
|
<span class="lineno">43</span><span class="kn">import</span> <span class="nn">torchvision.transforms</span> <span class="k">as</span> <span class="nn">transforms</span>
|
|
<span class="lineno">44</span><span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
|
|
<span class="lineno">45</span><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">Dataset</span>
|
|
<span class="lineno">46</span><span class="kn">from</span> <span class="nn">torchvision.utils</span> <span class="kn">import</span> <span class="n">make_grid</span>
|
|
<span class="lineno">47</span>
|
|
<span class="lineno">48</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">lab</span><span class="p">,</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">experiment</span><span class="p">,</span> <span class="n">monit</span>
|
|
<span class="lineno">49</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">BaseConfigs</span>
|
|
<span class="lineno">50</span><span class="kn">from</span> <span class="nn">labml.utils.download</span> <span class="kn">import</span> <span class="n">download_file</span>
|
|
<span class="lineno">51</span><span class="kn">from</span> <span class="nn">labml.utils.pytorch</span> <span class="kn">import</span> <span class="n">get_modules</span>
|
|
<span class="lineno">52</span><span class="kn">from</span> <span class="nn">labml_helpers.device</span> <span class="kn">import</span> <span class="n">DeviceConfigs</span>
|
|
<span class="lineno">53</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-1'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-1'>#</a>
|
|
</div>
|
|
<p>The generator is a residual network.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">56</span><span class="k">class</span> <span class="nc">GeneratorResNet</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-2'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-2'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">61</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_residual_blocks</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
|
<span class="lineno">62</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-3'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-3'>#</a>
|
|
</div>
|
|
<p>This first block runs a $7\times7$ convolution and maps the image to
|
|
a feature map.
|
|
The output feature map has the same height and width because we have
|
|
a padding of $3$.
|
|
Reflection padding is used because it gives better image quality at edges.</p>
|
|
<p><code>inplace=True</code> in <code>ReLU</code> saves a little bit of memory.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">out_features</span> <span class="o">=</span> <span class="mi">64</span>
|
|
<span class="lineno">71</span> <span class="n">layers</span> <span class="o">=</span> <span class="p">[</span>
|
|
<span class="lineno">72</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">input_channels</span><span class="p">,</span> <span class="n">out_features</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">7</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding_mode</span><span class="o">=</span><span class="s1">'reflect'</span><span class="p">),</span>
|
|
<span class="lineno">73</span> <span class="n">nn</span><span class="o">.</span><span class="n">InstanceNorm2d</span><span class="p">(</span><span class="n">out_features</span><span class="p">),</span>
|
|
<span class="lineno">74</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span>
|
|
<span class="lineno">75</span> <span class="p">]</span>
|
|
<span class="lineno">76</span> <span class="n">in_features</span> <span class="o">=</span> <span class="n">out_features</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-4'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-4'>#</a>
|
|
</div>
|
|
<p>We down-sample with two $3 \times 3$ convolutions
|
|
with a stride of 2; each one halves the height and width and doubles the number of channels.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">80</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">):</span>
|
|
<span class="lineno">81</span> <span class="n">out_features</span> <span class="o">*=</span> <span class="mi">2</span>
|
|
<span class="lineno">82</span> <span class="n">layers</span> <span class="o">+=</span> <span class="p">[</span>
|
|
<span class="lineno">83</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_features</span><span class="p">,</span> <span class="n">out_features</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span>
|
|
<span class="lineno">84</span> <span class="n">nn</span><span class="o">.</span><span class="n">InstanceNorm2d</span><span class="p">(</span><span class="n">out_features</span><span class="p">),</span>
|
|
<span class="lineno">85</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span>
|
|
<span class="lineno">86</span> <span class="p">]</span>
|
|
<span class="lineno">87</span> <span class="n">in_features</span> <span class="o">=</span> <span class="n">out_features</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-5'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-5'>#</a>
|
|
</div>
|
|
<p>We then pass this through <code>n_residual_blocks</code> residual blocks.
|
|
This module is defined below.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">91</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_residual_blocks</span><span class="p">):</span>
|
|
<span class="lineno">92</span> <span class="n">layers</span> <span class="o">+=</span> <span class="p">[</span><span class="n">ResidualBlock</span><span class="p">(</span><span class="n">out_features</span><span class="p">)]</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-6'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-6'>#</a>
|
|
</div>
|
|
<p>Then the resulting feature map is up-sampled
|
|
to match the original image height and width.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">96</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">):</span>
|
|
<span class="lineno">97</span> <span class="n">out_features</span> <span class="o">//=</span> <span class="mi">2</span>
|
|
<span class="lineno">98</span> <span class="n">layers</span> <span class="o">+=</span> <span class="p">[</span>
|
|
<span class="lineno">99</span> <span class="n">nn</span><span class="o">.</span><span class="n">Upsample</span><span class="p">(</span><span class="n">scale_factor</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
|
|
<span class="lineno">100</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_features</span><span class="p">,</span> <span class="n">out_features</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span>
|
|
<span class="lineno">101</span> <span class="n">nn</span><span class="o">.</span><span class="n">InstanceNorm2d</span><span class="p">(</span><span class="n">out_features</span><span class="p">),</span>
|
|
<span class="lineno">102</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span>
|
|
<span class="lineno">103</span> <span class="p">]</span>
|
|
<span class="lineno">104</span> <span class="n">in_features</span> <span class="o">=</span> <span class="n">out_features</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-7'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-7'>#</a>
|
|
</div>
|
|
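<p>Finally, we map the feature map back to an image with <code>input_channels</code> channels (RGB for three-channel inputs), using a $7 \times 7$ convolution followed by $\tanh$ so that pixel values lie in $[-1, 1]$.</p>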
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">107</span> <span class="n">layers</span> <span class="o">+=</span> <span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">out_features</span><span class="p">,</span> <span class="n">input_channels</span><span class="p">,</span> <span class="mi">7</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding_mode</span><span class="o">=</span><span class="s1">'reflect'</span><span class="p">),</span> <span class="n">nn</span><span class="o">.</span><span class="n">Tanh</span><span class="p">()]</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-8'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-8'>#</a>
|
|
</div>
|
|
<p>Create a sequential module with the layers</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">110</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-9'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-9'>#</a>
|
|
</div>
|
|
<p>Initialize weights to $\mathcal{N}(0, 0.2)$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">113</span> <span class="bp">self</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">weights_init_normal</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-10'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-10'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">115</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
|
<span class="lineno">116</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-11'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-11'>#</a>
|
|
</div>
|
|
<p>This is the residual block, with two convolution layers.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">119</span><span class="k">class</span> <span class="nc">ResidualBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-12'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-12'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">124</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_features</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
|
<span class="lineno">125</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
|
<span class="lineno">126</span> <span class="bp">self</span><span class="o">.</span><span class="n">block</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
|
|
<span class="lineno">127</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_features</span><span class="p">,</span> <span class="n">in_features</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding_mode</span><span class="o">=</span><span class="s1">'reflect'</span><span class="p">),</span>
|
|
<span class="lineno">128</span> <span class="n">nn</span><span class="o">.</span><span class="n">InstanceNorm2d</span><span class="p">(</span><span class="n">in_features</span><span class="p">),</span>
|
|
<span class="lineno">129</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span>
|
|
<span class="lineno">130</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_features</span><span class="p">,</span> <span class="n">in_features</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding_mode</span><span class="o">=</span><span class="s1">'reflect'</span><span class="p">),</span>
|
|
<span class="lineno">131</span> <span class="n">nn</span><span class="o">.</span><span class="n">InstanceNorm2d</span><span class="p">(</span><span class="n">in_features</span><span class="p">),</span>
|
|
<span class="lineno">132</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span>
|
|
<span class="lineno">133</span> <span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-13'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-13'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">135</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
|
<span class="lineno">136</span> <span class="k">return</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">block</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-14'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-14'>#</a>
|
|
</div>
|
|
<p>This is the discriminator.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">139</span><span class="k">class</span> <span class="nc">Discriminator</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-15'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-15'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">144</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]):</span>
|
|
<span class="lineno">145</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
|
<span class="lineno">146</span> <span class="n">channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span> <span class="o">=</span> <span class="n">input_shape</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-16'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-16'>#</a>
|
|
</div>
|
|
<p>Output of the discriminator is also a map of scores (not normalized probabilities, since the last layer is a plain convolution) indicating
|
|
whether each region of the image is real or generated</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">150</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_shape</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">height</span> <span class="o">//</span> <span class="mi">2</span> <span class="o">**</span> <span class="mi">4</span><span class="p">,</span> <span class="n">width</span> <span class="o">//</span> <span class="mi">2</span> <span class="o">**</span> <span class="mi">4</span><span class="p">)</span>
|
|
<span class="lineno">151</span>
|
|
<span class="lineno">152</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-17'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-17'>#</a>
|
|
</div>
|
|
<p>Each of these blocks will shrink the height and width by a factor of 2</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">154</span> <span class="n">DiscriminatorBlock</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="n">normalize</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span>
|
|
<span class="lineno">155</span> <span class="n">DiscriminatorBlock</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">),</span>
|
|
<span class="lineno">156</span> <span class="n">DiscriminatorBlock</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">256</span><span class="p">),</span>
|
|
<span class="lineno">157</span> <span class="n">DiscriminatorBlock</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">512</span><span class="p">),</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-18'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-18'>#</a>
|
|
</div>
|
|
<p>Zero pad on the top and left to keep the output height and width the same
|
|
with the $4 \times 4$ kernel</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">160</span> <span class="n">nn</span><span class="o">.</span><span class="n">ZeroPad2d</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)),</span>
|
|
<span class="lineno">161</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="lineno">162</span> <span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-19'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-19'>#</a>
|
|
</div>
|
|
<p>Initialize weights to $\mathcal{N}(0, 0.2)$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">165</span> <span class="bp">self</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">weights_init_normal</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-20'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-20'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">167</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">img</span><span class="p">):</span>
|
|
<span class="lineno">168</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">(</span><span class="n">img</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-21'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-21'>#</a>
|
|
</div>
|
|
<p>This is the discriminator block module.
|
|
It does a convolution, an optional normalization, and a leaky ReLU.</p>
|
|
<p>It shrinks the height and width of the input feature map by half.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">171</span><span class="k">class</span> <span class="nc">DiscriminatorBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-22'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-22'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">179</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_filters</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_filters</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">normalize</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
|
|
<span class="lineno">180</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
|
<span class="lineno">181</span> <span class="n">layers</span> <span class="o">=</span> <span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_filters</span><span class="p">,</span> <span class="n">out_filters</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">)]</span>
|
|
<span class="lineno">182</span> <span class="k">if</span> <span class="n">normalize</span><span class="p">:</span>
|
|
<span class="lineno">183</span> <span class="n">layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">InstanceNorm2d</span><span class="p">(</span><span class="n">out_filters</span><span class="p">))</span>
|
|
<span class="lineno">184</span> <span class="n">layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span>
|
|
<span class="lineno">185</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-23'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-23'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">187</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
|
<span class="lineno">188</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-24'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-24'>#</a>
|
|
</div>
|
|
<p>Initialize convolution layer weights from $\mathcal{N}(0, 0.02)$, i.e. a normal distribution with mean $0$ and standard deviation $0.02$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">191</span><span class="k">def</span> <span class="nf">weights_init_normal</span><span class="p">(</span><span class="n">m</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-25'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-25'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">195</span> <span class="n">classname</span> <span class="o">=</span> <span class="n">m</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span>
|
|
<span class="lineno">196</span> <span class="k">if</span> <span class="n">classname</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">"Conv"</span><span class="p">)</span> <span class="o">!=</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span>
|
|
<span class="lineno">197</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.02</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-26'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-26'>#</a>
|
|
</div>
|
|
<p>Load an image and convert it to RGB if it is in any other mode (e.g. grey-scale).</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">200</span><span class="k">def</span> <span class="nf">load_image</span><span class="p">(</span><span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-27'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-27'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">204</span> <span class="n">image</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">path</span><span class="p">)</span>
|
|
<span class="lineno">205</span> <span class="k">if</span> <span class="n">image</span><span class="o">.</span><span class="n">mode</span> <span class="o">!=</span> <span class="s1">'RGB'</span><span class="p">:</span>
|
|
<span class="lineno">206</span> <span class="n">image</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="s2">"RGB"</span><span class="p">,</span> <span class="n">image</span><span class="o">.</span><span class="n">size</span><span class="p">)</span><span class="o">.</span><span class="n">paste</span><span class="p">(</span><span class="n">image</span><span class="p">)</span>
|
|
<span class="lineno">207</span>
|
|
<span class="lineno">208</span> <span class="k">return</span> <span class="n">image</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-28'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-28'>#</a>
|
|
</div>
|
|
<h3>Dataset to load images</h3>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">211</span><span class="k">class</span> <span class="nc">ImageDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-29'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-29'>#</a>
|
|
</div>
|
|
<h4>Download dataset and extract data</h4>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">216</span> <span class="nd">@staticmethod</span>
|
|
<span class="lineno">217</span> <span class="k">def</span> <span class="nf">download</span><span class="p">(</span><span class="n">dataset_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-30'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-30'>#</a>
|
|
</div>
|
|
<p>URL</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">222</span> <span class="n">url</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/</span><span class="si">{</span><span class="n">dataset_name</span><span class="si">}</span><span class="s1">.zip'</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-31'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-31'>#</a>
|
|
</div>
|
|
<p>Download folder</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">224</span> <span class="n">root</span> <span class="o">=</span> <span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">'cycle_gan'</span>
|
|
<span class="lineno">225</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">root</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
|
|
<span class="lineno">226</span> <span class="n">root</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-32'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-32'>#</a>
|
|
</div>
|
|
<p>Download destination</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">228</span> <span class="n">archive</span> <span class="o">=</span> <span class="n">root</span> <span class="o">/</span> <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">dataset_name</span><span class="si">}</span><span class="s1">.zip'</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-33'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-33'>#</a>
|
|
</div>
|
|
<p>Download file (generally ~100MB)</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">230</span> <span class="n">download_file</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">archive</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-34'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-34'>#</a>
|
|
</div>
|
|
<p>Extract the archive</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">232</span> <span class="k">with</span> <span class="n">zipfile</span><span class="o">.</span><span class="n">ZipFile</span><span class="p">(</span><span class="n">archive</span><span class="p">,</span> <span class="s1">'r'</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
|
|
<span class="lineno">233</span> <span class="n">f</span><span class="o">.</span><span class="n">extractall</span><span class="p">(</span><span class="n">root</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-35'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-35'>#</a>
|
|
</div>
|
|
<h4>Initialize the dataset</h4>
|
|
<ul>
|
|
<li><code>dataset_name</code> is the name of the dataset</li>
|
|
<li><code>transforms_</code> is the set of image transforms</li>
|
|
<li><code>mode</code> is either <code>train</code> or <code>test</code></li>
|
|
</ul>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">235</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">transforms_</span><span class="p">,</span> <span class="n">mode</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-36'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-36'>#</a>
|
|
</div>
|
|
<p>Dataset path</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">244</span> <span class="n">root</span> <span class="o">=</span> <span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">'cycle_gan'</span> <span class="o">/</span> <span class="n">dataset_name</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-37'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-37'>#</a>
|
|
</div>
|
|
<p>Download if missing</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">246</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">root</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
|
|
<span class="lineno">247</span> <span class="bp">self</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">dataset_name</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-38'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-38'>#</a>
|
|
</div>
|
|
<p>Image transforms</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">250</span> <span class="bp">self</span><span class="o">.</span><span class="n">transform</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">(</span><span class="n">transforms_</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-39'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-39'>#</a>
|
|
</div>
|
|
<p>Get image paths</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">253</span> <span class="n">path_a</span> <span class="o">=</span> <span class="n">root</span> <span class="o">/</span> <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">mode</span><span class="si">}</span><span class="s1">A'</span>
|
|
<span class="lineno">254</span> <span class="n">path_b</span> <span class="o">=</span> <span class="n">root</span> <span class="o">/</span> <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">mode</span><span class="si">}</span><span class="s1">B'</span>
|
|
<span class="lineno">255</span> <span class="bp">self</span><span class="o">.</span><span class="n">files_a</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">f</span><span class="p">)</span> <span class="k">for</span> <span class="n">f</span> <span class="ow">in</span> <span class="n">path_a</span><span class="o">.</span><span class="n">iterdir</span><span class="p">())</span>
|
|
<span class="lineno">256</span> <span class="bp">self</span><span class="o">.</span><span class="n">files_b</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">f</span><span class="p">)</span> <span class="k">for</span> <span class="n">f</span> <span class="ow">in</span> <span class="n">path_b</span><span class="o">.</span><span class="n">iterdir</span><span class="p">())</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-40'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-40'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">258</span> <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-41'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-41'>#</a>
|
|
</div>
|
|
<p>Return a pair of images.
|
|
Since the two image sets are unpaired, these pairs are only batched together and are never treated as corresponding images in training.
|
|
So it is fine that a given index always yields the same pair.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">262</span> <span class="k">return</span> <span class="p">{</span><span class="s2">"x"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">load_image</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">files_a</span><span class="p">[</span><span class="n">index</span> <span class="o">%</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">files_a</span><span class="p">)])),</span>
|
|
<span class="lineno">263</span> <span class="s2">"y"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">load_image</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">files_b</span><span class="p">[</span><span class="n">index</span> <span class="o">%</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">files_b</span><span class="p">)]))}</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-42'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-42'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">265</span> <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-43'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-43'>#</a>
|
|
</div>
|
|
<p>Number of images in the dataset</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">267</span> <span class="k">return</span> <span class="nb">max</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">files_a</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">files_b</span><span class="p">))</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-44'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-44'>#</a>
|
|
</div>
|
|
<h3>Replay Buffer</h3>
|
|
<p>Replay buffer is used to train the discriminator.
|
|
Generated images are added to the replay buffer and sampled from it.</p>
|
|
<p>Until the buffer is full, every newly generated image is stored and returned as is. Once the buffer is full, it returns the newly added image with a probability of $0.5$.
|
|
Otherwise, it sends an older generated image and replaces the older image
|
|
with the newly generated image.</p>
|
|
<p>This is done to reduce model oscillation.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">270</span><span class="k">class</span> <span class="nc">ReplayBuffer</span><span class="p">:</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-45'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-45'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">284</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">50</span><span class="p">):</span>
|
|
<span class="lineno">285</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_size</span> <span class="o">=</span> <span class="n">max_size</span>
|
|
<span class="lineno">286</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-46'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-46'>#</a>
|
|
</div>
|
|
<p>Add a batch of generated images to the buffer and retrieve a batch of the same size</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">288</span> <span class="k">def</span> <span class="nf">push_and_pop</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-47'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-47'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">290</span> <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
|
|
<span class="lineno">291</span> <span class="n">res</span> <span class="o">=</span> <span class="p">[]</span>
|
|
<span class="lineno">292</span> <span class="k">for</span> <span class="n">element</span> <span class="ow">in</span> <span class="n">data</span><span class="p">:</span>
|
|
<span class="lineno">293</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">)</span> <span class="o"><</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_size</span><span class="p">:</span>
|
|
<span class="lineno">294</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">element</span><span class="p">)</span>
|
|
<span class="lineno">295</span> <span class="n">res</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">element</span><span class="p">)</span>
|
|
<span class="lineno">296</span> <span class="k">else</span><span class="p">:</span>
|
|
<span class="lineno">297</span> <span class="k">if</span> <span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">></span> <span class="mf">0.5</span><span class="p">:</span>
|
|
<span class="lineno">298</span> <span class="n">i</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_size</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
|
|
<span class="lineno">299</span> <span class="n">res</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
|
|
<span class="lineno">300</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">element</span>
|
|
<span class="lineno">301</span> <span class="k">else</span><span class="p">:</span>
|
|
<span class="lineno">302</span> <span class="n">res</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">element</span><span class="p">)</span>
|
|
<span class="lineno">303</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">res</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-48'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-48'>#</a>
|
|
</div>
|
|
<h2>Configurations</h2>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">306</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">BaseConfigs</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-49'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-49'>#</a>
|
|
</div>
|
|
<p><code>DeviceConfigs</code> will pick a GPU if available</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">310</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">DeviceConfigs</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-50'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-50'>#</a>
|
|
</div>
|
|
<p>Hyper-parameters</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">313</span> <span class="n">epochs</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">200</span>
|
|
<span class="lineno">314</span> <span class="n">dataset_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">'monet2photo'</span>
|
|
<span class="lineno">315</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span>
|
|
<span class="lineno">316</span>
|
|
<span class="lineno">317</span> <span class="n">data_loader_workers</span> <span class="o">=</span> <span class="mi">8</span>
|
|
<span class="lineno">318</span>
|
|
<span class="lineno">319</span> <span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.0002</span>
|
|
<span class="lineno">320</span> <span class="n">adam_betas</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">)</span>
|
|
<span class="lineno">321</span> <span class="n">decay_start</span> <span class="o">=</span> <span class="mi">100</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-51'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-51'>#</a>
|
|
</div>
|
|
<p>The paper suggests using a least-squares loss instead of
|
|
negative log-likelihood, as it is found to be more stable.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">325</span> <span class="n">gan_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">MSELoss</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-52'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-52'>#</a>
|
|
</div>
|
|
<p>L1 loss is used for cycle loss and identity loss</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">328</span> <span class="n">cycle_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">L1Loss</span><span class="p">()</span>
|
|
<span class="lineno">329</span> <span class="n">identity_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">L1Loss</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-53'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-53'>#</a>
|
|
</div>
|
|
<p>Image dimensions</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">332</span> <span class="n">img_height</span> <span class="o">=</span> <span class="mi">256</span>
|
|
<span class="lineno">333</span> <span class="n">img_width</span> <span class="o">=</span> <span class="mi">256</span>
|
|
<span class="lineno">334</span> <span class="n">img_channels</span> <span class="o">=</span> <span class="mi">3</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-54'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-54'>#</a>
|
|
</div>
|
|
<p>Number of residual blocks in the generator</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">337</span> <span class="n">n_residual_blocks</span> <span class="o">=</span> <span class="mi">9</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-55'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-55'>#</a>
|
|
</div>
|
|
<p>Loss coefficients</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">340</span> <span class="n">cyclic_loss_coefficient</span> <span class="o">=</span> <span class="mf">10.0</span>
|
|
<span class="lineno">341</span> <span class="n">identity_loss_coefficient</span> <span class="o">=</span> <span class="mf">5.</span>
|
|
<span class="lineno">342</span>
|
|
<span class="lineno">343</span> <span class="n">sample_interval</span> <span class="o">=</span> <span class="mi">500</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-56'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-56'>#</a>
|
|
</div>
|
|
<p>Models</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">346</span> <span class="n">generator_xy</span><span class="p">:</span> <span class="n">GeneratorResNet</span>
|
|
<span class="lineno">347</span> <span class="n">generator_yx</span><span class="p">:</span> <span class="n">GeneratorResNet</span>
|
|
<span class="lineno">348</span> <span class="n">discriminator_x</span><span class="p">:</span> <span class="n">Discriminator</span>
|
|
<span class="lineno">349</span> <span class="n">discriminator_y</span><span class="p">:</span> <span class="n">Discriminator</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-57'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-57'>#</a>
|
|
</div>
|
|
<p>Optimizers</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">352</span> <span class="n">generator_optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span>
|
|
<span class="lineno">353</span> <span class="n">discriminator_optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-58'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-58'>#</a>
|
|
</div>
|
|
<p>Learning rate schedules</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">356</span> <span class="n">generator_lr_scheduler</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">lr_scheduler</span><span class="o">.</span><span class="n">LambdaLR</span>
|
|
<span class="lineno">357</span> <span class="n">discriminator_lr_scheduler</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">lr_scheduler</span><span class="o">.</span><span class="n">LambdaLR</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-59'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-59'>#</a>
|
|
</div>
|
|
<p>Data loaders</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">360</span> <span class="n">dataloader</span><span class="p">:</span> <span class="n">DataLoader</span>
|
|
<span class="lineno">361</span> <span class="n">valid_dataloader</span><span class="p">:</span> <span class="n">DataLoader</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-60'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-60'>#</a>
|
|
</div>
|
|
<p>Generate samples from test set and save them</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">363</span> <span class="k">def</span> <span class="nf">sample_images</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-61'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-61'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">365</span> <span class="n">batch</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_dataloader</span><span class="p">))</span>
|
|
<span class="lineno">366</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_xy</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
|
|
<span class="lineno">367</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_yx</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
|
|
<span class="lineno">368</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
|
|
<span class="lineno">369</span> <span class="n">data_x</span><span class="p">,</span> <span class="n">data_y</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="s1">'x'</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generator_xy</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">batch</span><span class="p">[</span><span class="s1">'y'</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generator_yx</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="lineno">370</span> <span class="n">gen_y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_xy</span><span class="p">(</span><span class="n">data_x</span><span class="p">)</span>
|
|
<span class="lineno">371</span> <span class="n">gen_x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_yx</span><span class="p">(</span><span class="n">data_y</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-62'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-62'>#</a>
|
|
</div>
|
|
<p>Arrange images along x-axis</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">374</span> <span class="n">data_x</span> <span class="o">=</span> <span class="n">make_grid</span><span class="p">(</span><span class="n">data_x</span><span class="p">,</span> <span class="n">nrow</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">normalize</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
|
<span class="lineno">375</span> <span class="n">data_y</span> <span class="o">=</span> <span class="n">make_grid</span><span class="p">(</span><span class="n">data_y</span><span class="p">,</span> <span class="n">nrow</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">normalize</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
|
<span class="lineno">376</span> <span class="n">gen_x</span> <span class="o">=</span> <span class="n">make_grid</span><span class="p">(</span><span class="n">gen_x</span><span class="p">,</span> <span class="n">nrow</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">normalize</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
|
<span class="lineno">377</span> <span class="n">gen_y</span> <span class="o">=</span> <span class="n">make_grid</span><span class="p">(</span><span class="n">gen_y</span><span class="p">,</span> <span class="n">nrow</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">normalize</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-63'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-63'>#</a>
|
|
</div>
|
|
<p>Arrange images along y-axis</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">380</span> <span class="n">image_grid</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">data_x</span><span class="p">,</span> <span class="n">gen_y</span><span class="p">,</span> <span class="n">data_y</span><span class="p">,</span> <span class="n">gen_x</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-64'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-64'>#</a>
|
|
</div>
|
|
<p>Show samples</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">383</span> <span class="n">plot_image</span><span class="p">(</span><span class="n">image_grid</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-65'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-65'>#</a>
|
|
</div>
|
|
<h2>Initialize models and data loaders</h2>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">385</span> <span class="k">def</span> <span class="nf">initialize</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-66'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-66'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">389</span> <span class="n">input_shape</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">img_channels</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">img_height</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">img_width</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-67'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-67'>#</a>
|
|
</div>
|
|
<p>Create the models</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">392</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_xy</span> <span class="o">=</span> <span class="n">GeneratorResNet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">img_channels</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_residual_blocks</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="lineno">393</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_yx</span> <span class="o">=</span> <span class="n">GeneratorResNet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">img_channels</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_residual_blocks</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="lineno">394</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_x</span> <span class="o">=</span> <span class="n">Discriminator</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="lineno">395</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_y</span> <span class="o">=</span> <span class="n">Discriminator</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-68'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-68'>#</a>
|
|
</div>
|
|
<p>Create the optimizers</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">398</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span>
|
|
<span class="lineno">399</span> <span class="n">itertools</span><span class="o">.</span><span class="n">chain</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generator_xy</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_yx</span><span class="o">.</span><span class="n">parameters</span><span class="p">()),</span>
|
|
<span class="lineno">400</span> <span class="n">lr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">adam_betas</span><span class="p">)</span>
|
|
<span class="lineno">401</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span>
|
|
<span class="lineno">402</span> <span class="n">itertools</span><span class="o">.</span><span class="n">chain</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_x</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_y</span><span class="o">.</span><span class="n">parameters</span><span class="p">()),</span>
|
|
<span class="lineno">403</span> <span class="n">lr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">adam_betas</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-69'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-69'>#</a>
|
|
</div>
|
|
<p>Create the learning rate schedules.
|
|
The learning rate stays flat until <code>decay_start</code> epochs,
|
|
and then decays linearly towards $0$ by the end of training.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">408</span> <span class="n">decay_epochs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">decay_start</span>
|
|
<span class="lineno">409</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_lr_scheduler</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">lr_scheduler</span><span class="o">.</span><span class="n">LambdaLR</span><span class="p">(</span>
|
|
<span class="lineno">410</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_optimizer</span><span class="p">,</span> <span class="n">lr_lambda</span><span class="o">=</span><span class="k">lambda</span> <span class="n">e</span><span class="p">:</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="nb">max</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">e</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">decay_start</span><span class="p">)</span> <span class="o">/</span> <span class="n">decay_epochs</span><span class="p">)</span>
|
|
<span class="lineno">411</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_lr_scheduler</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">lr_scheduler</span><span class="o">.</span><span class="n">LambdaLR</span><span class="p">(</span>
|
|
<span class="lineno">412</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_optimizer</span><span class="p">,</span> <span class="n">lr_lambda</span><span class="o">=</span><span class="k">lambda</span> <span class="n">e</span><span class="p">:</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="nb">max</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">e</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">decay_start</span><span class="p">)</span> <span class="o">/</span> <span class="n">decay_epochs</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-70'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-70'>#</a>
|
|
</div>
|
|
<p>Image transformations</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">415</span> <span class="n">transforms_</span> <span class="o">=</span> <span class="p">[</span>
|
|
<span class="lineno">416</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Resize</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">img_height</span> <span class="o">*</span> <span class="mf">1.12</span><span class="p">),</span> <span class="n">Image</span><span class="o">.</span><span class="n">BICUBIC</span><span class="p">),</span>
|
|
<span class="lineno">417</span> <span class="n">transforms</span><span class="o">.</span><span class="n">RandomCrop</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">img_height</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">img_width</span><span class="p">)),</span>
|
|
<span class="lineno">418</span> <span class="n">transforms</span><span class="o">.</span><span class="n">RandomHorizontalFlip</span><span class="p">(),</span>
|
|
<span class="lineno">419</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
|
|
<span class="lineno">420</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">),</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">)),</span>
|
|
<span class="lineno">421</span> <span class="p">]</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-71'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-71'>#</a>
|
|
</div>
|
|
<p>Training data loader</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">424</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataloader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span>
|
|
<span class="lineno">425</span> <span class="n">ImageDataset</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_name</span><span class="p">,</span> <span class="n">transforms_</span><span class="p">,</span> <span class="s1">'train'</span><span class="p">),</span>
|
|
<span class="lineno">426</span> <span class="n">batch_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="lineno">427</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
|
<span class="lineno">428</span> <span class="n">num_workers</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">data_loader_workers</span><span class="p">,</span>
|
|
<span class="lineno">429</span> <span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-72'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-72'>#</a>
|
|
</div>
|
|
<p>Validation data loader</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">432</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_dataloader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span>
|
|
<span class="lineno">433</span> <span class="n">ImageDataset</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_name</span><span class="p">,</span> <span class="n">transforms_</span><span class="p">,</span> <span class="s2">"test"</span><span class="p">),</span>
|
|
<span class="lineno">434</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span>
|
|
<span class="lineno">435</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
|
<span class="lineno">436</span> <span class="n">num_workers</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">data_loader_workers</span><span class="p">,</span>
|
|
<span class="lineno">437</span> <span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-73'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-73'>#</a>
|
|
</div>
|
|
<h2>Training</h2>
|
|
<p>We aim to solve:
|
|
<script type="math/tex; mode=display">G^{*}, F^{*} = \arg \min_{G,F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)</script>
|
|
</p>
|
|
<p>where,
|
|
$G$ translates images from $X \rightarrow Y$,
|
|
$F$ translates images from $Y \rightarrow X$,
|
|
$D_X$ tests if images are from $X$ space,
|
|
$D_Y$ tests if images are from $Y$ space, and
|
|
<script type="math/tex; mode=display">\begin{align}
|
|
\mathcal{L}(G, F, D_X, D_Y)
|
|
&= \mathcal{L}_{GAN}(G, D_Y, X, Y) \\
|
|
&+ \mathcal{L}_{GAN}(F, D_X, Y, X) \\
|
|
&+ \lambda_1 \mathcal{L}_{cyc}(G, F) \\
|
|
&+ \lambda_2 \mathcal{L}_{identity}(G, F) \\
|
|
\\
|
|
\mathcal{L}_{GAN}(G, D_Y, X, Y) + \mathcal{L}_{GAN}(F, D_X, Y, X)
|
|
&= \mathbb{E}_{y \sim p_{data}(y)} \Big[\log D_Y(y)\Big] \\
|
|
&+ \mathbb{E}_{x \sim p_{data}(x)} \bigg[\log\Big(1 - D_Y(G(x))\Big)\bigg] \\
|
|
&+ \mathbb{E}_{x \sim p_{data}(x)} \Big[\log D_X(x)\Big] \\
|
|
&+ \mathbb{E}_{y \sim p_{data}(y)} \bigg[\log\Big(1 - D_X(F(y))\Big)\bigg] \\
|
|
\\
|
|
\mathcal{L}_{cyc}(G, F)
|
|
&= \mathbb{E}_{x \sim p_{data}(x)} \Big[\lVert F(G(x)) - x \rVert_1\Big] \\
|
|
&+ \mathbb{E}_{y \sim p_{data}(y)} \Big[\lVert G(F(y)) - y \rVert_1\Big] \\
|
|
\\
|
|
\mathcal{L}_{identity}(G, F)
|
|
&= \mathbb{E}_{x \sim p_{data}(x)} \Big[\lVert F(x) - x \rVert_1\Big] \\
|
|
&+ \mathbb{E}_{y \sim p_{data}(y)} \Big[\lVert G(y) - y \rVert_1\Big] \\
|
|
\end{align}</script>
|
|
</p>
|
|
<p>$\mathcal{L}_{GAN}$ is the generative adversarial loss from the original
|
|
GAN paper.</p>
|
|
<p>$\mathcal{L}_{cyc}$ is the cyclic loss, where we try to get $F(G(x))$ to be similar to $x$,
|
|
and $G(F(y))$ to be similar to $y$.
|
|
Basically if the two generators (transformations) are applied in series it should give back the
|
|
original image.
|
|
This is the main contribution of this paper.
|
|
It trains the generators to generate an image of the other distribution that is similar to
|
|
the original image.
|
|
Without this loss $G(x)$ could generate anything that’s from the distribution of $Y$.
|
|
Now it needs to generate something from the distribution of $Y$ that still has the properties of $x$,
|
|
so that $F(G(x))$ can regenerate something like $x$.</p>
|
|
<p>$\mathcal{L}_{identity}$ is the identity loss.
|
|
This was used to encourage the mapping to preserve color composition between
|
|
the input and the output.</p>
|
|
<p>To solve $G^{*}, F^{*}$,
|
|
discriminators $D_X$ and $D_Y$ should <strong>ascend</strong> on the gradient,
|
|
<script type="math/tex; mode=display">\begin{align}
|
|
\nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^m
|
|
&\Bigg[
|
|
\log D_Y\Big(y^{(i)}\Big) \\
|
|
&+ \log \Big(1 - D_Y\Big(G\Big(x^{(i)}\Big)\Big)\Big) \\
|
|
&+ \log D_X\Big(x^{(i)}\Big) \\
|
|
& +\log\Big(1 - D_X\Big(F\Big(y^{(i)}\Big)\Big)\Big)
|
|
\Bigg]
|
|
\end{align}</script>
|
|
That is, descend on the <em>negative</em> log-likelihood loss.</p>
|
|
<p>In order to stabilize the training, the negative log-likelihood objective
|
|
was replaced by a least-squares loss —
|
|
the least-squares error of the discriminator, labelling real images with 1,
|
|
and generated images with 0.
|
|
So we want to descend on the gradient,
|
|
<script type="math/tex; mode=display">\begin{align}
|
|
\nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^m
|
|
&\Bigg[
|
|
\bigg(D_Y\Big(y^{(i)}\Big) - 1\bigg)^2 \\
|
|
&+ D_Y\Big(G\Big(x^{(i)}\Big)\Big)^2 \\
|
|
&+ \bigg(D_X\Big(x^{(i)}\Big) - 1\bigg)^2 \\
|
|
&+ D_X\Big(F\Big(y^{(i)}\Big)\Big)^2
|
|
\Bigg]
|
|
\end{align}</script>
|
|
</p>
|
|
<p>We use least-squares for generators also.
|
|
The generators should <em>descend</em> on the gradient,
|
|
<script type="math/tex; mode=display">\begin{align}
|
|
\nabla_{\theta_{F, G}} \frac{1}{m} \sum_{i=1}^m
|
|
&\Bigg[
|
|
\bigg(D_Y\Big(G\Big(x^{(i)}\Big)\Big) - 1\bigg)^2 \\
|
|
&+ \bigg(D_X\Big(F\Big(y^{(i)}\Big)\Big) - 1\bigg)^2 \\
|
|
&+ \lambda_1 \mathcal{L}_{cyc}(G, F)
|
|
+ \lambda_2 \mathcal{L}_{identity}(G, F)
|
|
\Bigg]
|
|
\end{align}</script>
|
|
</p>
|
|
<p>We use <code>generator_xy</code> for $G$ and <code>generator_yx</code> for $F$.
|
|
We use <code>discriminator_x</code> for $D_X$ and <code>discriminator_y</code> for $D_Y$.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">439</span> <span class="k">def</span> <span class="nf">run</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-74'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-74'>#</a>
|
|
</div>
|
|
<p>Replay buffers to keep generated samples</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">536</span> <span class="n">gen_x_buffer</span> <span class="o">=</span> <span class="n">ReplayBuffer</span><span class="p">()</span>
|
|
<span class="lineno">537</span> <span class="n">gen_y_buffer</span> <span class="o">=</span> <span class="n">ReplayBuffer</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-75'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-75'>#</a>
|
|
</div>
|
|
<p>Loop through epochs</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">540</span> <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">loop</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">epochs</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-76'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-76'>#</a>
|
|
</div>
|
|
<p>Loop through the dataset</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">542</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">enum</span><span class="p">(</span><span class="s1">'Train'</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataloader</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-77'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-77'>#</a>
|
|
</div>
|
|
<p>Move images to the device</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">544</span> <span class="n">data_x</span><span class="p">,</span> <span class="n">data_y</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="s1">'x'</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">batch</span><span class="p">[</span><span class="s1">'y'</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-78'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-78'>#</a>
|
|
</div>
|
|
<p>true labels equal to $1$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">547</span> <span class="n">true_labels</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">data_x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_x</span><span class="o">.</span><span class="n">output_shape</span><span class="p">,</span>
|
|
<span class="lineno">548</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-79'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-79'>#</a>
|
|
</div>
|
|
<p>false labels equal to $0$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">550</span> <span class="n">false_labels</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">data_x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_x</span><span class="o">.</span><span class="n">output_shape</span><span class="p">,</span>
|
|
<span class="lineno">551</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-80'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-80'>#</a>
|
|
</div>
|
|
<p>Train the generators.
|
|
This returns the generated images.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">555</span> <span class="n">gen_x</span><span class="p">,</span> <span class="n">gen_y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimize_generators</span><span class="p">(</span><span class="n">data_x</span><span class="p">,</span> <span class="n">data_y</span><span class="p">,</span> <span class="n">true_labels</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-81'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-81'>#</a>
|
|
</div>
|
|
<p>Train discriminators</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">558</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimize_discriminator</span><span class="p">(</span><span class="n">data_x</span><span class="p">,</span> <span class="n">data_y</span><span class="p">,</span>
|
|
<span class="lineno">559</span> <span class="n">gen_x_buffer</span><span class="o">.</span><span class="n">push_and_pop</span><span class="p">(</span><span class="n">gen_x</span><span class="p">),</span> <span class="n">gen_y_buffer</span><span class="o">.</span><span class="n">push_and_pop</span><span class="p">(</span><span class="n">gen_y</span><span class="p">),</span>
|
|
<span class="lineno">560</span> <span class="n">true_labels</span><span class="p">,</span> <span class="n">false_labels</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-82'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-82'>#</a>
|
|
</div>
|
|
<p>Save training statistics and increment the global step counter</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">563</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span>
|
|
<span class="lineno">564</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">(</span><span class="nb">max</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data_x</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">data_y</span><span class="p">)))</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-83'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-83'>#</a>
|
|
</div>
|
|
<p>Save images at intervals</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">567</span> <span class="n">batches_done</span> <span class="o">=</span> <span class="n">epoch</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataloader</span><span class="p">)</span> <span class="o">+</span> <span class="n">i</span>
|
|
<span class="lineno">568</span> <span class="k">if</span> <span class="n">batches_done</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-84'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-84'>#</a>
|
|
</div>
|
|
<p>Save models when sampling images</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">570</span> <span class="n">experiment</span><span class="o">.</span><span class="n">save_checkpoint</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-85'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-85'>#</a>
|
|
</div>
|
|
<p>Sample images</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">572</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample_images</span><span class="p">(</span><span class="n">batches_done</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-86'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-86'>#</a>
|
|
</div>
|
|
<p>Update learning rates</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">575</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_lr_scheduler</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
|
|
<span class="lineno">576</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_lr_scheduler</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-87'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-87'>#</a>
|
|
</div>
|
|
<p>Move the tracker&rsquo;s console output to a new line</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">578</span> <span class="n">tracker</span><span class="o">.</span><span class="n">new_line</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-88'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-88'>#</a>
|
|
</div>
|
|
<h3>Optimize the generators with identity, gan and cycle losses.</h3>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">580</span> <span class="k">def</span> <span class="nf">optimize_generators</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data_x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">data_y</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">true_labels</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-89'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-89'>#</a>
|
|
</div>
|
|
<p>Change to training mode</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">586</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_xy</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>
|
|
<span class="lineno">587</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_yx</span><span class="o">.</span><span class="n">train</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-90'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-90'>#</a>
|
|
</div>
|
|
<p>Identity loss
|
|
<script type="math/tex; mode=display">\lVert F(x^{(i)}) - x^{(i)} \rVert_1 +
|

\lVert G(y^{(i)}) - y^{(i)} \rVert_1</script>
|
|
</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">592</span> <span class="n">loss_identity</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">identity_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generator_yx</span><span class="p">(</span><span class="n">data_x</span><span class="p">),</span> <span class="n">data_x</span><span class="p">)</span> <span class="o">+</span>
|
|
<span class="lineno">593</span> <span class="bp">self</span><span class="o">.</span><span class="n">identity_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generator_xy</span><span class="p">(</span><span class="n">data_y</span><span class="p">),</span> <span class="n">data_y</span><span class="p">))</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-91'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-91'>#</a>
|
|
</div>
|
|
<p>Generate images $G(x)$ and $F(y)$</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">596</span> <span class="n">gen_y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_xy</span><span class="p">(</span><span class="n">data_x</span><span class="p">)</span>
|
|
<span class="lineno">597</span> <span class="n">gen_x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_yx</span><span class="p">(</span><span class="n">data_y</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-92'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-92'>#</a>
|
|
</div>
|
|
<p>GAN loss
|
|
<script type="math/tex; mode=display">\bigg(D_Y\Big(G\Big(x^{(i)}\Big)\Big) - 1\bigg)^2
|
|
+ \bigg(D_X\Big(F\Big(y^{(i)}\Big)\Big) - 1\bigg)^2</script>
|
|
</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">602</span> <span class="n">loss_gan</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gan_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_y</span><span class="p">(</span><span class="n">gen_y</span><span class="p">),</span> <span class="n">true_labels</span><span class="p">)</span> <span class="o">+</span>
|
|
<span class="lineno">603</span> <span class="bp">self</span><span class="o">.</span><span class="n">gan_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_x</span><span class="p">(</span><span class="n">gen_x</span><span class="p">),</span> <span class="n">true_labels</span><span class="p">))</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-93'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-93'>#</a>
|
|
</div>
|
|
<p>Cycle loss
|
|
<script type="math/tex; mode=display">
|
|
\lVert F(G(x^{(i)})) - x^{(i)} \rVert_1 +
|
|
\lVert G(F(y^{(i)})) - y^{(i)} \rVert_1
|
|
</script>
|
|
</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">610</span> <span class="n">loss_cycle</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cycle_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generator_yx</span><span class="p">(</span><span class="n">gen_y</span><span class="p">),</span> <span class="n">data_x</span><span class="p">)</span> <span class="o">+</span>
|
|
<span class="lineno">611</span> <span class="bp">self</span><span class="o">.</span><span class="n">cycle_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generator_xy</span><span class="p">(</span><span class="n">gen_x</span><span class="p">),</span> <span class="n">data_y</span><span class="p">))</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-94'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-94'>#</a>
|
|
</div>
|
|
<p>Total loss</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">614</span> <span class="n">loss_generator</span> <span class="o">=</span> <span class="p">(</span><span class="n">loss_gan</span> <span class="o">+</span>
|
|
<span class="lineno">615</span> <span class="bp">self</span><span class="o">.</span><span class="n">cyclic_loss_coefficient</span> <span class="o">*</span> <span class="n">loss_cycle</span> <span class="o">+</span>
|
|
<span class="lineno">616</span> <span class="bp">self</span><span class="o">.</span><span class="n">identity_loss_coefficient</span> <span class="o">*</span> <span class="n">loss_identity</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-95'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-95'>#</a>
|
|
</div>
|
|
<p>Take a step in the optimizer</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">619</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
|
|
<span class="lineno">620</span> <span class="n">loss_generator</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
|
|
<span class="lineno">621</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-96'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-96'>#</a>
|
|
</div>
|
|
<p>Log losses</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">624</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">({</span><span class="s1">'loss.generator'</span><span class="p">:</span> <span class="n">loss_generator</span><span class="p">,</span>
|
|
<span class="lineno">625</span> <span class="s1">'loss.generator.cycle'</span><span class="p">:</span> <span class="n">loss_cycle</span><span class="p">,</span>
|
|
<span class="lineno">626</span> <span class="s1">'loss.generator.gan'</span><span class="p">:</span> <span class="n">loss_gan</span><span class="p">,</span>
|
|
<span class="lineno">627</span> <span class="s1">'loss.generator.identity'</span><span class="p">:</span> <span class="n">loss_identity</span><span class="p">})</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-97'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-97'>#</a>
|
|
</div>
|
|
<p>Return generated images</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">630</span> <span class="k">return</span> <span class="n">gen_x</span><span class="p">,</span> <span class="n">gen_y</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-98'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-98'>#</a>
|
|
</div>
|
|
<h3>Optimize the discriminators with gan loss.</h3>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">632</span> <span class="k">def</span> <span class="nf">optimize_discriminator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data_x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">data_y</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="lineno">633</span> <span class="n">gen_x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">gen_y</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="lineno">634</span> <span class="n">true_labels</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">false_labels</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-99'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-99'>#</a>
|
|
</div>
|
|
<p>GAN Loss
|
|
<script type="math/tex; mode=display">\begin{align}
|
|
\bigg(D_Y\Big(y ^ {(i)}\Big) - 1\bigg) ^ 2
|
|
+ D_Y\Big(G\Big(x ^ {(i)}\Big)\Big) ^ 2 + \\
|
|
\bigg(D_X\Big(x ^ {(i)}\Big) - 1\bigg) ^ 2
|
|
+ D_X\Big(F\Big(y ^ {(i)}\Big)\Big) ^ 2
|
|
\end{align}</script>
|
|
</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">645</span> <span class="n">loss_discriminator</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gan_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_x</span><span class="p">(</span><span class="n">data_x</span><span class="p">),</span> <span class="n">true_labels</span><span class="p">)</span> <span class="o">+</span>
|
|
<span class="lineno">646</span> <span class="bp">self</span><span class="o">.</span><span class="n">gan_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_x</span><span class="p">(</span><span class="n">gen_x</span><span class="p">),</span> <span class="n">false_labels</span><span class="p">)</span> <span class="o">+</span>
|
|
<span class="lineno">647</span> <span class="bp">self</span><span class="o">.</span><span class="n">gan_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_y</span><span class="p">(</span><span class="n">data_y</span><span class="p">),</span> <span class="n">true_labels</span><span class="p">)</span> <span class="o">+</span>
|
|
<span class="lineno">648</span> <span class="bp">self</span><span class="o">.</span><span class="n">gan_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator_y</span><span class="p">(</span><span class="n">gen_y</span><span class="p">),</span> <span class="n">false_labels</span><span class="p">))</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-100'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-100'>#</a>
|
|
</div>
|
|
<p>Take a step in the optimizer</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">651</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
|
|
<span class="lineno">652</span> <span class="n">loss_discriminator</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
|
|
<span class="lineno">653</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-101'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-101'>#</a>
|
|
</div>
|
|
<p>Log losses</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">656</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">({</span><span class="s1">'loss.discriminator'</span><span class="p">:</span> <span class="n">loss_discriminator</span><span class="p">})</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-102'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-102'>#</a>
|
|
</div>
|
|
<h2>Train Cycle GAN</h2>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">659</span><span class="k">def</span> <span class="nf">train</span><span class="p">():</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-103'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-103'>#</a>
|
|
</div>
|
|
<p>Create configurations</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">664</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-104'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-104'>#</a>
|
|
</div>
|
|
<p>Create an experiment</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">666</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">'cycle_gan'</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-105'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-105'>#</a>
|
|
</div>
|
|
<p>Calculate configurations.
|
|
It will calculate <code>conf.run</code> and all other configs required by it.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">669</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span><span class="s1">'dataset_name'</span><span class="p">:</span> <span class="s1">'summer2winter_yosemite'</span><span class="p">})</span>
|
|
<span class="lineno">670</span> <span class="n">conf</span><span class="o">.</span><span class="n">initialize</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-106'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-106'>#</a>
|
|
</div>
|
|
<p>Register models for saving and loading.
|
|
<code>get_modules</code> gives a dictionary of <code>nn.Modules</code> in <code>conf</code>.
|
|
You can also specify a custom dictionary of models.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">675</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">(</span><span class="n">get_modules</span><span class="p">(</span><span class="n">conf</span><span class="p">))</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-107'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-107'>#</a>
|
|
</div>
|
|
<p>Start and watch the experiment</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">677</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-108'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-108'>#</a>
|
|
</div>
|
|
<p>Run the training</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">679</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-109'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-109'>#</a>
|
|
</div>
|
|
<h3>Plot an image with matplotlib</h3>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">682</span><span class="k">def</span> <span class="nf">plot_image</span><span class="p">(</span><span class="n">img</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-110'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-110'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">686</span> <span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">pyplot</span> <span class="k">as</span> <span class="n">plt</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-111'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-111'>#</a>
|
|
</div>
|
|
<p>Move tensor to CPU</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">689</span> <span class="n">img</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-112'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-112'>#</a>
|
|
</div>
|
|
<p>Get min and max values of the image for normalization</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">691</span> <span class="n">img_min</span><span class="p">,</span> <span class="n">img_max</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">min</span><span class="p">(),</span> <span class="n">img</span><span class="o">.</span><span class="n">max</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-113'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-113'>#</a>
|
|
</div>
|
|
<p>Scale image values to be [0…1]</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">693</span> <span class="n">img</span> <span class="o">=</span> <span class="p">(</span><span class="n">img</span> <span class="o">-</span> <span class="n">img_min</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">img_max</span> <span class="o">-</span> <span class="n">img_min</span> <span class="o">+</span> <span class="mf">1e-5</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-114'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-114'>#</a>
|
|
</div>
|
|
<p>We have to change the order of dimensions from CHW to HWC, which is what <code>plt.imshow</code> expects.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">695</span> <span class="n">img</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-115'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-115'>#</a>
|
|
</div>
|
|
<p>Show Image</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">697</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">img</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-116'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-116'>#</a>
|
|
</div>
|
|
<p>We don’t need axes</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">699</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s1">'off'</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-117'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-117'>#</a>
|
|
</div>
|
|
<p>Display</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">701</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-118'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-118'>#</a>
|
|
</div>
|
|
<h2>Evaluate trained Cycle GAN</h2>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">704</span><span class="k">def</span> <span class="nf">evaluate</span><span class="p">():</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-119'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-119'>#</a>
|
|
</div>
|
|
<p>Set the run UUID from the training run</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">709</span> <span class="n">trained_run_uuid</span> <span class="o">=</span> <span class="s1">'f73c1164184711eb9190b74249275441'</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-120'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-120'>#</a>
|
|
</div>
|
|
<p>Create configs object</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">711</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-121'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-121'>#</a>
|
|
</div>
|
|
<p>Create experiment</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">713</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">'cycle_gan_inference'</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-122'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-122'>#</a>
|
|
</div>
|
|
<p>Load hyper parameters set for training</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">715</span> <span class="n">conf_dict</span> <span class="o">=</span> <span class="n">experiment</span><span class="o">.</span><span class="n">load_configs</span><span class="p">(</span><span class="n">trained_run_uuid</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-123'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-123'>#</a>
|
|
</div>
|
|
<p>Calculate configurations from the hyper-parameters loaded from the training run.
|

To compute only the generators and their dependencies, you could specify <code>'generator_xy', 'generator_yx'</code>;
|

configs like <code>device</code> and <code>img_channels</code> would then be calculated, since these are required by
|

<code>generator_xy</code> and <code>generator_yx</code>.</p>
|

<p>Here no subset is specified, so all the configurations are calculated, including data loaders and <code>dataset_name</code>,
|

which is used below to load the dataset.
|
|
Calculation of configurations and their dependencies will happen when you call <code>experiment.start</code></p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">724</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="n">conf_dict</span><span class="p">)</span>
|
|
<span class="lineno">725</span> <span class="n">conf</span><span class="o">.</span><span class="n">initialize</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-124'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-124'>#</a>
|
|
</div>
|
|
<p>Register models for saving and loading.
|
|
<code>get_modules</code> gives a dictionary of <code>nn.Modules</code> in <code>conf</code>.
|
|
You can also specify a custom dictionary of models.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">730</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">(</span><span class="n">get_modules</span><span class="p">(</span><span class="n">conf</span><span class="p">))</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-125'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-125'>#</a>
|
|
</div>
|
|
<p>Specify which run to load from.
|
|
Loading will actually happen when you call <code>experiment.start</code></p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">733</span> <span class="n">experiment</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">trained_run_uuid</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-126'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-126'>#</a>
|
|
</div>
|
|
<p>Start the experiment</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">736</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-127'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-127'>#</a>
|
|
</div>
|
|
<p>Image transformations</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">738</span> <span class="n">transforms_</span> <span class="o">=</span> <span class="p">[</span>
|
|
<span class="lineno">739</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
|
|
<span class="lineno">740</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">),</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">)),</span>
|
|
<span class="lineno">741</span> <span class="p">]</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-128'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-128'>#</a>
|
|
</div>
|
|
<p>Load your own data. Here we use the training split (<code>'train'</code>).
|
|
I tried it with Yosemite photos; the results look awesome.
|
|
You can use <code>conf.dataset_name</code> if you listed <code>dataset_name</code> among the configs to be calculated
|
|
in the call to <code>experiment.configs</code></p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">747</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">ImageDataset</span><span class="p">(</span><span class="n">conf</span><span class="o">.</span><span class="n">dataset_name</span><span class="p">,</span> <span class="n">transforms_</span><span class="p">,</span> <span class="s1">'train'</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-129'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-129'>#</a>
|
|
</div>
|
|
<p>Get an image from dataset</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">749</span> <span class="n">x_image</span> <span class="o">=</span> <span class="n">dataset</span><span class="p">[</span><span class="mi">10</span><span class="p">][</span><span class="s1">'x'</span><span class="p">]</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-130'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-130'>#</a>
|
|
</div>
|
|
<p>Display the image</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">751</span> <span class="n">plot_image</span><span class="p">(</span><span class="n">x_image</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-131'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-131'>#</a>
|
|
</div>
|
|
<p>Evaluation mode</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">754</span> <span class="n">conf</span><span class="o">.</span><span class="n">generator_xy</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
|
|
<span class="lineno">755</span> <span class="n">conf</span><span class="o">.</span><span class="n">generator_yx</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-132'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-132'>#</a>
|
|
</div>
|
|
<p>We don’t need gradients</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">758</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-133'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-133'>#</a>
|
|
</div>
|
|
<p>Add a batch dimension, move the image to the device we use, and translate it with <code>generator_xy</code></p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">760</span> <span class="n">data</span> <span class="o">=</span> <span class="n">x_image</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">conf</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="lineno">761</span> <span class="n">generated_y</span> <span class="o">=</span> <span class="n">conf</span><span class="o">.</span><span class="n">generator_xy</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-134'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-134'>#</a>
|
|
</div>
|
|
<p>Display the generated image.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">764</span> <span class="n">plot_image</span><span class="p">(</span><span class="n">generated_y</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">cpu</span><span class="p">())</span>
|
|
<span class="lineno">765</span>
|
|
<span class="lineno">766</span>
|
|
<span class="lineno">767</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
|
<span class="lineno">768</span> <span class="n">train</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-135'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-135'>#</a>
|
|
</div>
|
|
<p>Alternatively, call <code>evaluate()</code> here instead of <code>train()</code>.</p>
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre></pre></div>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
|
</script>
|
|
<!-- MathJax configuration -->
|
|
<script type="text/x-mathjax-config">
|
|
(function () {
  // tex2jax preprocessor: $...$ marks inline math and $$...$$ marks display
  // math. processEscapes lets \$ produce a literal dollar sign, and
  // processEnvironments handles \begin{...}...\end{...} outside delimiters.
  var tex2jaxOptions = {
    inlineMath: [['$', '$']],
    displayMath: [['$$', '$$']],
    processEscapes: true,
    processEnvironments: true
  };

  MathJax.Hub.Config({
    tex2jax: tex2jaxOptions,
    // Display equations are centered in code and markdown cells; CSS
    // left-justifies single-line equations in code cells elsewhere.
    displayAlign: 'center',
    'HTML-CSS': { fonts: ['TeX'] }
  });
})();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
</script>
|
|
</body>
|
|
</html> |