Varuna Jayasiri c4d2e8cd22 docs
2025-07-31 08:48:07 +05:30


<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="An annotated PyTorch implementation of StyleGAN2 model training code."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="StyleGAN 2 Model Training"/>
<meta name="twitter:description" content="An annotated PyTorch implementation of StyleGAN2 model training code."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/gan/stylegan/experiment.html"/>
<meta property="og:title" content="StyleGAN 2 Model Training"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="StyleGAN 2 Model Training"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="An annotated PyTorch implementation of StyleGAN2 model training code."/>
<title>StyleGAN 2 Model Training</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/gan/stylegan/experiment.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">gan</a>
<a class="parent" href="index.html">stylegan</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/gan/stylegan/experiment.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1><a href="index.html">StyleGAN 2</a> Model Training</h1>
<p>This is the training code for the <a href="index.html">StyleGAN 2</a> model.</p>
<p><img alt="Generated Images" src="generated_64.png"></p>
<p><small><em>These are <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">6</span><span class="mord coloredeq eqm" style=""><span class="mord" style="">4</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">6</span><span class="mord coloredeq eqm" style=""><span class="mord" style="">4</span></span></span></span></span></span> images generated after training for about 80K steps.</em></small></p>
<p><em>Our implementation is minimalistic StyleGAN 2 training code. Only single-GPU training is supported, to keep the implementation simple. We managed to keep it under 500 lines of code, including the training loop.</em></p>
<p><em>Without DDP (distributed data parallel) and multi-GPU training, it will not be possible to train the model at large resolutions (128+). If you want training code with fp16 and DDP, take a look at <a href="https://github.com/lucidrains/stylegan2-pytorch">lucidrains/stylegan2-pytorch</a>.</em></p>
<p>We trained this on the <a href="https://github.com/tkarras/progressive_growing_of_gans">CelebA-HQ dataset</a>. You can find download instructions in this <a href="https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3">discussion on fast.ai</a>. Save the images inside the <a href="#dataset_path"><code class="highlight"><span></span><span class="n">data</span><span class="o">/</span><span class="n">stylegan2</span></code>
folder</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">31</span><span></span><span class="kn">import</span> <span class="nn">math</span>
<span class="lineno">32</span><span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
<span class="lineno">33</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Iterator</span><span class="p">,</span> <span class="n">Tuple</span>
<span class="lineno">34</span>
<span class="lineno">35</span><span class="kn">import</span> <span class="nn">torchvision</span>
<span class="lineno">36</span><span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
<span class="lineno">37</span>
<span class="lineno">38</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">39</span><span class="kn">import</span> <span class="nn">torch.utils.data</span>
<span class="lineno">40</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">lab</span><span class="p">,</span> <span class="n">monit</span><span class="p">,</span> <span class="n">experiment</span>
<span class="lineno">41</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">BaseConfigs</span>
<span class="lineno">42</span><span class="kn">from</span> <span class="nn">labml_nn.gan.stylegan</span> <span class="kn">import</span> <span class="n">Discriminator</span><span class="p">,</span> <span class="n">Generator</span><span class="p">,</span> <span class="n">MappingNetwork</span><span class="p">,</span> <span class="n">GradientPenalty</span><span class="p">,</span> <span class="n">PathLengthPenalty</span>
<span class="lineno">43</span><span class="kn">from</span> <span class="nn">labml_nn.gan.wasserstein</span> <span class="kn">import</span> <span class="n">DiscriminatorLoss</span><span class="p">,</span> <span class="n">GeneratorLoss</span>
<span class="lineno">44</span><span class="kn">from</span> <span class="nn">labml_nn.helpers.device</span> <span class="kn">import</span> <span class="n">DeviceConfigs</span>
<span class="lineno">45</span><span class="kn">from</span> <span class="nn">labml_nn.helpers.trainer</span> <span class="kn">import</span> <span class="n">ModeState</span>
<span class="lineno">46</span><span class="kn">from</span> <span class="nn">labml_nn.utils</span> <span class="kn">import</span> <span class="n">cycle_dataloader</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Dataset</h2>
<p>This loads the training dataset and resizes it to the given image size.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span><span class="k">class</span> <span class="nc">Dataset</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">path</span></code>
path to the folder containing the images </li>
<li><code class="highlight"><span></span><span class="n">image_size</span></code>
size of the image</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">image_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">61</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Get the paths of all <code class="highlight"><span></span><span class="n">jpg</span></code>
files </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">64</span> <span class="bp">self</span><span class="o">.</span><span class="n">paths</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">Path</span><span class="p">(</span><span class="n">path</span><span class="p">)</span><span class="o">.</span><span class="n">glob</span><span class="p">(</span><span class="s1">&#39;**/*.jpg&#39;</span><span class="p">)]</span></pre></div>
</div>
</div>
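<div class='section'>
<div class='docs'>
<p>The pattern <code>**/*.jpg</code> matches files at any depth below <code>path</code>, not just in the top-level folder. The standard-library sketch below checks this on a throwaway directory; <code>find_jpgs</code> is a hypothetical helper, not part of this code. With pathlib's default settings the match is typically case-sensitive on POSIX systems, so files ending in <code>.JPG</code> would not be picked up.</p>
<pre>
```python
import tempfile
from pathlib import Path
from typing import List


def find_jpgs(root: str) -> List[str]:
    """Hypothetical helper: relative POSIX paths of every .jpg under root."""
    base = Path(root)
    # '**' recurses into subdirectories; '*.jpg' filters by file suffix
    return sorted(p.relative_to(base).as_posix() for p in base.glob('**/*.jpg'))


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / 'sub').mkdir()
    for name in ['a.jpg', 'sub/b.jpg', 'c.png']:
        (root / name).touch()
    found = find_jpgs(tmp)

print(found)  # ['a.jpg', 'sub/b.jpg']
```
</pre>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>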
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Transformation </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">67</span> <span class="bp">self</span><span class="o">.</span><span class="n">transform</span> <span class="o">=</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
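<p>Resize the image so that its shorter side is <code class="highlight"><span></span><span class="n">image_size</span></code> pixels (the aspect ratio is preserved) </p>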
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">69</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">Resize</span><span class="p">(</span><span class="n">image_size</span><span class="p">),</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Convert to PyTorch tensor </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">71</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
<span class="lineno">72</span> <span class="p">])</span></pre></div>
</div>
</div>
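<div class='section'>
<div class='docs'>
<p>The transform pipeline above does two things worth spelling out. Given a single int, <code>Resize</code> matches the shorter edge to <code>image_size</code> and keeps the aspect ratio (as the torchvision documentation describes), so only square inputs such as CelebA-HQ come out square. <code>ToTensor</code> then turns 8-bit height × width × channel data into a channels-first float tensor scaled to [0, 1]. The list-based sketch below reproduces both rules without PyTorch; <code>resized_hw</code> and <code>to_chw_unit</code> are hypothetical names, and the integer truncation in <code>resized_hw</code> is an assumption about rounding that may differ between torchvision versions.</p>
<pre>
```python
from typing import List, Tuple

Pixels = List[List[List[int]]]      # [height][width][channel], values 0..255
Tensor3 = List[List[List[float]]]   # [channel][height][width], values 0..1


def resized_hw(height: int, width: int, size: int) -> Tuple[int, int]:
    """Sketch of Resize(size) with an int: match the shorter edge to `size`."""
    short, long = (width, height) if width <= height else (height, width)
    # The longer edge is scaled by the same factor (truncated to an int here)
    new_long = int(size * long / short)
    return (new_long, size) if width <= height else (size, new_long)


def to_chw_unit(pixels: Pixels) -> Tensor3:
    """Sketch of ToTensor for 8-bit images: HWC ints to CHW floats in [0, 1]."""
    height, width, channels = len(pixels), len(pixels[0]), len(pixels[0][0])
    return [[[pixels[y][x][c] / 255.0 for x in range(width)]
             for y in range(height)]
            for c in range(channels)]


print(resized_hw(1024, 1024, 32))  # (32, 32): square input stays square
print(resized_hw(100, 200, 32))    # (32, 64): aspect ratio is preserved

image = [[[255, 0, 51], [0, 255, 102]]]   # one row, two RGB pixels
tensor = to_chw_unit(image)
print(len(tensor), len(tensor[0]), len(tensor[0][0]))  # 3 1 2
print(tensor[0][0])                       # [1.0, 0.0]
```
</pre>
<p>For datasets whose images are not square, adding <code>torchvision.transforms.CenterCrop(image_size)</code> after <code>Resize</code> is one common way to get equal-sized tensors that the default collation can stack into a batch.</p>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>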
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Number of images </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">74</span> <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">paths</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Get the <code class="highlight"><span></span><span class="n">index</span></code>
-th image </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">78</span> <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">80</span> <span class="n">path</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">paths</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
<span class="lineno">81</span> <span class="n">img</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">path</span><span class="p">)</span>
<span class="lineno">82</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">img</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<h2>Configurations</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">85</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">BaseConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Device to train the model on. <a href="../../helpers/device.html"><code class="highlight"><span></span><span class="n">DeviceConfigs</span></code>
</a> picks up an available CUDA device or defaults to CPU. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">93</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">DeviceConfigs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p><a href="index.html#discriminator">StyleGAN2 Discriminator</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">96</span> <span class="n">discriminator</span><span class="p">:</span> <span class="n">Discriminator</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p><a href="index.html#generator">StyleGAN2 Generator</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">98</span> <span class="n">generator</span><span class="p">:</span> <span class="n">Generator</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p><a href="index.html#mapping_network">Mapping network</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">100</span> <span class="n">mapping_network</span><span class="p">:</span> <span class="n">MappingNetwork</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Discriminator and generator loss functions. We use the <a href="../wasserstein/index.html">Wasserstein loss</a>. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">104</span> <span class="n">discriminator_loss</span><span class="p">:</span> <span class="n">DiscriminatorLoss</span>
<span class="lineno">105</span> <span class="n">generator_loss</span><span class="p">:</span> <span class="n">GeneratorLoss</span></pre></div>
</div>
</div>
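<div class='section'>
<div class='docs'>
<p>For reference, the textbook Wasserstein losses can be written directly in terms of critic scores: the discriminator (critic) minimizes the mean score on generated images minus the mean score on real images, and the generator minimizes the negative mean score on generated images. The sketch below uses plain lists of scores; the linked <code>DiscriminatorLoss</code> and <code>GeneratorLoss</code> modules may use a variant of this form (for example, a hinge-style clamp), so treat it as an illustration rather than their exact definition.</p>
<pre>
```python
from statistics import mean
from typing import Sequence


def critic_loss(real_scores: Sequence[float], fake_scores: Sequence[float]) -> float:
    """Textbook WGAN critic loss: E[f(fake)] - E[f(real)] (lower is better)."""
    return mean(fake_scores) - mean(real_scores)


def generator_loss(fake_scores: Sequence[float]) -> float:
    """Textbook WGAN generator loss: -E[f(fake)]."""
    return -mean(fake_scores)


real = [1.0, 2.0]    # critic scores on real images
fake = [0.0, -1.0]   # critic scores on generated images
print(critic_loss(real, fake))  # -2.0
print(generator_loss(fake))     # 0.5
```
</pre>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>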
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Optimizers </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">108</span> <span class="n">generator_optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span>
<span class="lineno">109</span> <span class="n">discriminator_optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span>
<span class="lineno">110</span> <span class="n">mapping_network_optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p><a href="index.html#gradient_penalty">Gradient Penalty Regularization Loss</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">113</span> <span class="n">gradient_penalty</span> <span class="o">=</span> <span class="n">GradientPenalty</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Gradient penalty coefficient <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord mathnormal" style="margin-right:0.05556em;">γ</span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">115</span> <span class="n">gradient_penalty_coefficient</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10.</span></pre></div>
</div>
</div>
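<div class='section'>
<div class='docs'>
<p>The StyleGAN 2 paper regularizes the discriminator with the R1 penalty, (γ/2) · E[‖∇<sub>x</sub>D(x)‖²] over real images x. Once the per-image gradients of the discriminator output with respect to its input are available (autograd computes them in the real code), the remaining arithmetic is the sketch below. <code>r1_penalty</code> is a hypothetical helper, and where exactly the γ/2 factor is applied in this training code is not shown in this section.</p>
<pre>
```python
from typing import Sequence


def r1_penalty(per_image_grads: Sequence[Sequence[float]], gamma: float = 10.0) -> float:
    """(gamma / 2) * mean over images of the squared L2 norm of dD/dx."""
    squared_norms = [sum(g * g for g in grad) for grad in per_image_grads]
    return 0.5 * gamma * sum(squared_norms) / len(squared_norms)


# Two flattened "images" whose gradient norms are 5 and 0
grads = [[3.0, 4.0], [0.0, 0.0]]
print(r1_penalty(grads))  # 62.5
```
</pre>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>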
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p><a href="index.html#path_length_penalty">Path length penalty</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">118</span> <span class="n">path_length_penalty</span><span class="p">:</span> <span class="n">PathLengthPenalty</span></pre></div>
</div>
</div>
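<div class='section'>
<div class='docs'>
<p>The path length penalty pushes the norm of the Jacobian-vector product ‖J<sub>w</sub><sup>T</sup>y‖, for random image-space directions y, towards a running average a of that norm; that is, it minimizes E[(‖J<sub>w</sub><sup>T</sup>y‖ - a)²]. The sketch below works on precomputed norms and keeps a as a bias-corrected exponential moving average. The decay of 0.99, the bias correction and the class name <code>PathLengthTracker</code> are assumptions made for illustration, not a description of <code>PathLengthPenalty</code>'s internals.</p>
<pre>
```python
from statistics import mean
from typing import Sequence


class PathLengthTracker:
    """Hypothetical sketch: penalty = mean((norm - a)^2), a = EMA of past mean norms."""

    def __init__(self, beta: float = 0.99):
        self.beta = beta
        self.exp_sum = 0.0   # un-normalized EMA of the mean norm
        self.steps = 0

    def __call__(self, norms: Sequence[float]) -> float:
        if self.steps > 0:
            # Bias-corrected running average a
            a = self.exp_sum / (1 - self.beta ** self.steps)
            penalty = mean((n - a) ** 2 for n in norms)
        else:
            penalty = 0.0    # no average yet on the very first call
        self.exp_sum = self.beta * self.exp_sum + (1 - self.beta) * mean(norms)
        self.steps += 1
        return penalty


pl_tracker = PathLengthTracker()
print(pl_tracker([2.0, 2.0]))  # 0.0 (the first call only seeds the average)
print(pl_tracker([3.0, 1.0]))  # about 1.0, since a is about 2.0
```
</pre>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>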
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Data loader </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">121</span> <span class="n">loader</span><span class="p">:</span> <span class="n">Iterator</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Batch size </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Dimensionality of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqo" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.02691em">w</span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">126</span> <span class="n">d_latent</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">512</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Height/width of the image </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">128</span> <span class="n">image_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Number of layers in the mapping network </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">130</span> <span class="n">mapping_network_layers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Generator &amp; Discriminator learning rate </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">132</span> <span class="n">learning_rate</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-3</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Mapping network learning rate (<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">100</span><span class="mord">×</span></span></span></span></span> lower than the others) </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">134</span> <span class="n">mapping_network_learning_rate</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Number of steps over which to accumulate gradients. Use this to increase the effective batch size. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="n">gradient_accumulate_steps</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span></pre></div>
</div>
</div>
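<div class='section'>
<div class='docs'>
<p>With gradient accumulation, each of <code>gradient_accumulate_steps</code> micro-batches of size <code>batch_size</code> contributes a gradient before the optimizer steps, so the effective batch size is their product (32 with the defaults). If each micro-batch loss is divided by the number of steps, the summed gradient equals the gradient of the mean loss over the whole effective batch. The sketch below checks this for a hypothetical one-parameter least-squares model; whether this training code divides the loss that way is not shown in this section.</p>
<pre>
```python
import math
from typing import List, Tuple

Pair = Tuple[float, float]   # (x, y)


def grad_mse(w: float, batch: List[Pair]) -> float:
    """d/dw of mean((w * x - y)^2) over the batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


data = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 8.0)]
w = 0.5
accumulate_steps = 2
micro = len(data) // accumulate_steps

# PyTorch adds up .grad across backward() calls; here we add by hand
accumulated = 0.0
for i in range(accumulate_steps):
    batch = data[i * micro:(i + 1) * micro]
    accumulated += grad_mse(w, batch) / accumulate_steps

full = grad_mse(w, data)
print(math.isclose(accumulated, full))  # True
```
</pre>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>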
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> for Adam optimizer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">138</span> <span class="n">adam_betas</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.99</span><span class="p">)</span></pre></div>
</div>
</div>
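<div class='section'>
<div class='docs'>
<p>With β<sub>1</sub> = 0, Adam keeps no momentum: the first-moment estimate is simply the current gradient, and only the second-moment average (β<sub>2</sub> = 0.99) smooths the step size. The scalar sketch below follows the standard Adam update rule; <code>adam_step</code> is a hypothetical helper, and eps = 1e-8 is the usual default rather than a value taken from this code. It also shows that the first step has magnitude close to the learning rate whatever the gradient's scale, so giving the mapping network its own optimizer with a 100× smaller learning rate makes its weights move roughly 100× more slowly.</p>
<pre>
```python
import math
from typing import Tuple


def adam_step(param: float, grad: float, m: float, v: float, t: int,
              lr: float = 1e-3, betas: Tuple[float, float] = (0.0, 0.99),
              eps: float = 1e-8) -> Tuple[float, float, float]:
    """One scalar Adam update; t is the 1-based step count."""
    beta1, beta2 = betas
    m = beta1 * m + (1 - beta1) * grad          # with beta1 = 0, m == grad
    v = beta2 * v + (1 - beta2) * grad * grad
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


for g in (0.5, 50.0):
    new_param, _, _ = adam_step(0.0, g, m=0.0, v=0.0, t=1)
    print(round(new_param, 6))  # -0.001 for both gradient scales
```
</pre>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>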
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Probability of mixing styles </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">style_mixing_prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.9</span></pre></div>
</div>
</div>
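<div class='section'>
<div class='docs'>
<p>Style mixing means that, with probability <code>style_mixing_prob</code>, a training image is generated from two latent vectors instead of one: the generator blocks before a randomly chosen crossover point receive w<sub>1</sub> and the remaining blocks receive w<sub>2</sub>. The sketch below only decides which w each block gets; <code>assign_styles</code> and its string labels are hypothetical stand-ins for the actual tensors.</p>
<pre>
```python
import random
from typing import List


def assign_styles(n_blocks: int, mixing_prob: float, rng: random.Random) -> List[str]:
    """Label which latent ('w1' or 'w2') feeds each generator block."""
    if rng.random() >= mixing_prob:
        return ['w1'] * n_blocks                   # no mixing: one style throughout
    crossover = int(rng.random() * n_blocks)       # 0 .. n_blocks - 1
    return ['w1'] * crossover + ['w2'] * (n_blocks - crossover)


rng = random.Random(0)
print(assign_styles(4, 0.0, rng))  # ['w1', 'w1', 'w1', 'w1']
print(assign_styles(4, 1.0, rng))  # a 'w1' prefix followed by a 'w2' suffix
```
</pre>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>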
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Total number of training steps </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">143</span> <span class="n">training_steps</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">150_000</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Number of blocks in the generator (calculated based on image resolution) </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">146</span> <span class="n">n_gen_blocks</span><span class="p">:</span> <span class="nb">int</span></pre></div>
</div>
</div>
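<div class='section'>
<div class='docs'>
<p>Assuming the generator starts from a 4×4 constant and each block doubles the resolution, as in StyleGAN 2, reaching <code>image_size</code> takes log<sub>2</sub>(image_size) - 1 blocks: 4 blocks (4, 8, 16, 32) for the default 32×32. The helper <code>count_gen_blocks</code> below is a hypothetical sketch of that arithmetic, not the code that sets <code>n_gen_blocks</code>.</p>
<pre>
```python
import math


def count_gen_blocks(image_size: int) -> int:
    """Blocks needed to grow a 4x4 starting resolution to image_size by doubling."""
    if image_size < 4:
        raise ValueError('image_size must be at least 4')
    log_resolution = int(math.log2(image_size))
    if 2 ** log_resolution != image_size:
        raise ValueError('image_size must be a power of two')
    return log_resolution - 1


print(count_gen_blocks(32))  # 4  (4, 8, 16, 32)
print(count_gen_blocks(64))  # 5
```
</pre>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>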
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<h3>Lazy regularization</h3>
<p>Instead of calculating the regularization losses on every training step, the paper proposes lazy regularization, where the regularization terms are computed only once every few steps. This improves training efficiency considerably. </p>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>The interval at which to compute the gradient penalty </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">154</span> <span class="n">lazy_gradient_penalty_interval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Path length penalty calculation interval </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="n">lazy_path_penalty_interval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>Skip computing the path length penalty during the initial phase of training </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">158</span> <span class="n">lazy_path_penalty_after</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5_000</span></pre></div>
</div>
</div>
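<div class='section'>
<div class='docs'>
<p>With lazy regularization, the gradient penalty is added on every <code>lazy_gradient_penalty_interval</code>-th step, and the path length penalty on every <code>lazy_path_penalty_interval</code>-th step once <code>lazy_path_penalty_after</code> steps have passed. The sketch below lists which steps would compute each term; the <code>(step + 1) % interval == 0</code> convention and the function name <code>lazy_reg_steps</code> are assumptions. Because a term computed every k steps would otherwise carry about 1/k of its intended weight, a common compensation is to multiply it by k when it is applied.</p>
<pre>
```python
from typing import List, Tuple


def lazy_reg_steps(n_steps: int, gp_interval: int, pl_interval: int,
                   pl_after: int) -> Tuple[List[int], List[int]]:
    """Step indices that would compute the gradient / path length penalties."""
    gp_steps = [s for s in range(n_steps) if (s + 1) % gp_interval == 0]
    pl_steps = [s for s in range(n_steps)
                if s > pl_after and (s + 1) % pl_interval == 0]
    return gp_steps, pl_steps


gp, pl = lazy_reg_steps(64, gp_interval=4, pl_interval=8, pl_after=10)
print(len(gp))  # 16: one in every 4 steps
print(pl)       # [15, 23, 31, 39, 47, 55, 63]
```
</pre>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>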
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>How often to log generated images </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">161</span> <span class="n">log_generated_interval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">500</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>How often to save model checkpoints </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">163</span> <span class="n">save_checkpoint_interval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2_000</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Training mode state for logging activations </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">166</span> <span class="n">mode</span><span class="p">:</span> <span class="n">ModeState</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p><a id="dataset_path"></a> We trained this on the <a href="https://github.com/tkarras/progressive_growing_of_gans">CelebA-HQ dataset</a>. You can find download instructions in this <a href="https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3">discussion on fast.ai</a>. Save the images inside the <code class="highlight"><span></span><span class="n">data</span><span class="o">/</span><span class="n">stylegan2</span></code>
folder. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">173</span> <span class="n">dataset_path</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">&#39;stylegan2&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<h3>Initialize</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">175</span> <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>Create dataset </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">180</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">Dataset</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Create data loader </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">182</span> <span class="n">dataloader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span>
<span class="lineno">183</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">drop_last</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">pin_memory</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
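<p>Because the loader uses <code>drop_last=True</code>, the final partial batch is discarded every epoch, so each batch the models see has exactly <code>batch_size</code> images. The snippet below works out the numbers for CelebA-HQ's 30,000 images. The helper <code>batches_per_epoch</code> is hypothetical, and the batch size of 32 is only an illustrative value, not necessarily this experiment's setting.</p>

```python
from typing import Tuple


def batches_per_epoch(n_images: int, batch_size: int) -> Tuple[int, int]:
    """Hypothetical helper: (full batches per epoch, images dropped) with drop_last=True."""
    full_batches = n_images // batch_size
    dropped = n_images % batch_size  # size of the partial batch that drop_last discards
    return full_batches, dropped


# CelebA-HQ has 30,000 images; a batch size of 32 is an assumed example value
print(batches_per_epoch(30_000, 32))  # (937, 16)
```

<p>Since <code>shuffle=True</code> as well, a different handful of images is dropped each epoch, so no image is permanently excluded from training.</p>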
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Continuous <a href="../../utils.html#cycle_dataloader">cyclic loader</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">185</span> <span class="bp">self</span><span class="o">.</span><span class="n">loader</span> <span class="o">=</span> <span class="n">cycle_dataloader</span><span class="p">(</span><span class="n">dataloader</span><span class="p">)</span></pre></div>
</div>
</div>
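<p>A cyclic loader lets the training loop call <code>next()</code> indefinitely without tracking epoch boundaries. The following is a minimal sketch of the idea; <code>cycle_forever</code> is a hypothetical stand-in, and the actual <code>cycle_dataloader</code> in labml_nn may be implemented differently.</p>

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def cycle_forever(loader: Iterable[T]) -> Iterator[T]:
    """Hypothetical stand-in for `cycle_dataloader`: yield batches endlessly."""
    while True:
        # Re-iterating a DataLoader starts a new epoch (and reshuffles when shuffle=True)
        for batch in loader:
            yield batch


batches = cycle_forever([1, 2, 3])  # a plain list stands in for a DataLoader here
first_seven: List[int] = [next(batches) for _ in range(7)]
print(first_seven)  # [1, 2, 3, 1, 2, 3, 1]
```

<p>Unlike <code>itertools.cycle</code>, which saves the items from its first pass and replays them in the same order, re-iterating the loader lets <code>shuffle=True</code> produce a fresh order every epoch without holding all batches in memory.</p>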
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.93858em;vertical-align:-0.24414em;"></span><span class="mop"><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span></span></span></span></span> of image resolution </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">188</span> <span class="n">log_resolution</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">image_size</span><span class="p">))</span></pre></div>
</div>
</div>
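<p>The networks are parameterized by log<sub>2</sub> of the resolution, so <code>image_size</code> is expected to be a power of two; for any other size, <code>int(math.log2(...))</code> silently rounds down. The hypothetical helper below makes that assumption explicit by rejecting other sizes.</p>

```python
import math


def log_resolution(image_size: int) -> int:
    """Hypothetical helper: log2 of the image size, rejecting non-powers of two."""
    # A power of two has exactly one bit set, so n & (n - 1) == 0
    if image_size <= 0 or image_size & (image_size - 1) != 0:
        raise ValueError(f"image_size must be a power of two, got {image_size}")
    return int(math.log2(image_size))


print(log_resolution(32))   # 5
print(log_resolution(256))  # 8
# int(math.log2(48)) would quietly give 5; this helper raises ValueError instead
```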
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<p>Create discriminator and generator </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span> <span class="o">=</span> <span class="n">Discriminator</span><span class="p">(</span><span class="n">log_resolution</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">192</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span> <span class="o">=</span> <span class="n">Generator</span><span class="p">(</span><span class="n">log_resolution</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_latent</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<p>Get number of generator blocks for creating style and noise inputs </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">194</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_gen_blocks</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="o">.</span><span class="n">n_blocks</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>Create mapping network </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">196</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network</span> <span class="o">=</span> <span class="n">MappingNetwork</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_latent</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network_layers</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>Create path length penalty loss </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">198</span> <span class="bp">self</span><span class="o">.</span><span class="n">path_length_penalty</span> <span class="o">=</span> <span class="n">PathLengthPenalty</span><span class="p">(</span><span class="mf">0.99</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
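<p>We assume the argument <code>0.99</code> is the decay coefficient β of an exponential moving average that the path length penalty keeps of the mean path length. The sketch below shows such an average with Adam-style bias correction, which keeps early estimates from being dragged toward the zero initialization. The <code>EMA</code> class is hypothetical and not part of labml_nn.</p>

```python
class EMA:
    """Hypothetical exponential moving average with bias correction."""

    def __init__(self, beta: float = 0.99) -> None:
        self.beta = beta
        self.avg = 0.0  # raw (biased) running average, initialized at zero
        self.steps = 0

    def update(self, value: float) -> float:
        self.steps += 1
        self.avg = self.beta * self.avg + (1 - self.beta) * value
        # Dividing by (1 - beta^steps) removes the bias from the zero initialization
        return self.avg / (1 - self.beta ** self.steps)


ema = EMA(beta=0.99)
for _ in range(5):
    estimate = ema.update(2.0)
print(round(estimate, 6))  # 2.0, whereas the raw average is still only about 0.098
```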
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p>Discriminator and generator losses </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">201</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_loss</span> <span class="o">=</span> <span class="n">DiscriminatorLoss</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">202</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_loss</span> <span class="o">=</span> <span class="n">GeneratorLoss</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p>Create optimizers </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span>
<span class="lineno">206</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span>
<span class="lineno">207</span> <span class="n">lr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">adam_betas</span>
<span class="lineno">208</span> <span class="p">)</span>
<span class="lineno">209</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span>
<span class="lineno">210</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span>
<span class="lineno">211</span> <span class="n">lr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">adam_betas</span>
<span class="lineno">212</span> <span class="p">)</span>
<span class="lineno">213</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network_optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span>
<span class="lineno">214</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span>
<span class="lineno">215</span> <span class="n">lr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping_network_learning_rate</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">adam_betas</span>
<span class="lineno">216</span> <span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>Set tracker configurations </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">219</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_image</span><span class="p">(</span><span class="s2">&quot;generated&quot;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<h3>Sample <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.02691em">w</span></span></span></span></span></span></h3>
<p>This samples <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqo" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span></span></span></span> randomly and gets <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.02691em">w</span></span></span></span></span></span> from the mapping network.</p>
<p>With probability <code>style_mixing_prob</code>, we also apply style mixing: we generate two latent variables <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqk" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqo" style="margin-right:0.04398em">z</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eql" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqo" style="margin-right:0.04398em">z</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and get the corresponding <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span
class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="margin-right:0.02691em">w</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="margin-right:0.02691em">w</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>. 
Then we randomly sample a cross-over point and apply <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="margin-right:0.02691em">w</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> to the generator blocks before the cross-over point and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="margin-right:0.02691em">w</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> to the blocks after.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">221</span> <span class="k">def</span> <span class="nf">get_w</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='docs'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<p>Mix styles </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">235</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(())</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">style_mixing_prob</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
<p>Random cross-over point </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">237</span> <span class="n">cross_over_point</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(())</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_gen_blocks</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<p>Sample <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqk" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqo" style="margin-right:0.04398em">z</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eql" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqo" style="margin-right:0.04398em">z</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">239</span> <span class="n">z2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_latent</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">240</span> <span class="n">z1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_latent</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<p>Get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="margin-right:0.02691em">w</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="margin-right:0.02691em">w</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">242</span> <span class="n">w1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network</span><span class="p">(</span><span class="n">z1</span><span class="p">)</span>
<span class="lineno">243</span> <span class="n">w2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network</span><span class="p">(</span><span class="n">z2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<p>Expand <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="margin-right:0.02691em">w</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="margin-right:0.02691em">w</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> for the generator blocks and concatenate </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">245</span> <span class="n">w1</span> <span class="o">=</span> <span class="n">w1</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">cross_over_point</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">246</span> <span class="n">w2</span> <span class="o">=</span> <span class="n">w2</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_gen_blocks</span> <span class="o">-</span> <span class="n">cross_over_point</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">247</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">w1</span><span class="p">,</span> <span class="n">w2</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
<p>Without mixing </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">249</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
<p>Sample <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqo" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">251</span> <span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_latent</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
<p>Get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.02691em">w</span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">253</span> <span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='docs'>
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
<p>Expand <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.02691em">w</span></span></span></span></span></span> for the generator blocks </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">255</span> <span class="k">return</span> <span class="n">w</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_gen_blocks</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
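<p>To make the shapes concrete, here is a small standalone sketch of this no-mixing path. A single <code>nn.Linear</code> stands in for the mapping network and the sizes are arbitrary assumptions. It also shows that <code>expand</code> returns a view: the new block dimension has stride 0, so every block shares the same storage for w and nothing is copied.</p>

```python
import torch
import torch.nn as nn

# Assumed toy sizes; the real values come from the experiment configs
batch_size, d_latent, n_gen_blocks = 4, 8, 5

# Stand-in for the mapping network (the real one is a deeper MLP)
mapping_network = nn.Linear(d_latent, d_latent)

# Sample z ~ N(0, I) and map it to w: shape [batch_size, d_latent]
z = torch.randn(batch_size, d_latent)
w = mapping_network(z)

# Add a leading block dimension and broadcast it without copying memory
w_blocks = w[None, :, :].expand(n_gen_blocks, -1, -1)

print(w_blocks.shape)                          # torch.Size([5, 4, 8])
# Every block sees the same w for a given sample
print(torch.equal(w_blocks[0], w_blocks[-1]))  # True
# expand creates a view: stride 0 along the new block dimension
print(w_blocks.stride()[0])                    # 0
```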
<div class='section' id='section-64'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
<h3>Generate noise</h3>
<p>This generates noise for each <a href="index.html#generator_block">generator block</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">257</span> <span class="k">def</span> <span class="nf">get_noise</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-65'>
<div class='docs'>
<div class='section-link'>
<a href='#section-65'>#</a>
</div>
<p>List to store noise </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">264</span> <span class="n">noise</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-66'>
<div class='docs'>
<div class='section-link'>
<a href='#section-66'>#</a>
</div>
<p>Noise resolution starts from <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqm" style=""><span class="mord" style="">4</span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">266</span> <span class="n">resolution</span> <span class="o">=</span> <span class="mi">4</span></pre></div>
</div>
</div>
<div class='section' id='section-67'>
<div class='docs'>
<div class='section-link'>
<a href='#section-67'>#</a>
</div>
<p>Generate noise for each generator block </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">269</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_gen_blocks</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-68'>
<div class='docs'>
<div class='section-link'>
<a href='#section-68'>#</a>
</div>
<p>The first block has only one <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">3</span></span></span></span></span> convolution </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">271</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">272</span> <span class="n">n1</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-69'>
<div class='docs'>
<div class='section-link'>
<a href='#section-69'>#</a>
</div>
<p>Generate noise to add after the first convolution layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">274</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">275</span> <span class="n">n1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">resolution</span><span class="p">,</span> <span class="n">resolution</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-70'>
<div class='docs'>
<div class='section-link'>
<a href='#section-70'>#</a>
</div>
<p>Generate noise to add after the second convolution layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">277</span> <span class="n">n2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">resolution</span><span class="p">,</span> <span class="n">resolution</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-71'>
<div class='docs'>
<div class='section-link'>
<a href='#section-71'>#</a>
</div>
<p>Add noise tensors to the list </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">280</span> <span class="n">noise</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">n1</span><span class="p">,</span> <span class="n">n2</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-72'>
<div class='docs'>
<div class='section-link'>
<a href='#section-72'>#</a>
</div>
<p>Next block has <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">2</span><span class="mord">×</span></span></span></span></span> resolution </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">283</span> <span class="n">resolution</span> <span class="o">*=</span> <span class="mi">2</span></pre></div>
</div>
</div>
<div class='section' id='section-73'>
<div class='docs'>
<div class='section-link'>
<a href='#section-73'>#</a>
</div>
<p>Return noise tensors </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">286</span> <span class="k">return</span> <span class="n">noise</span></pre></div>
</div>
</div>
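<p>The noise schedule above can be reproduced outside the class to check the shapes it yields. This sketch mirrors <code>get_noise</code> but takes <code>n_gen_blocks</code> and <code>device</code> as arguments (an assumption made so it is self-contained). Each noise tensor has a single channel and its spatial size doubles from block to block, starting at 4×4; the first block gets no first-convolution noise.</p>

```python
from typing import List, Optional, Tuple

import torch


def get_noise(batch_size: int, n_gen_blocks: int,
              device: str = "cpu") -> List[Tuple[Optional[torch.Tensor], torch.Tensor]]:
    noise = []
    resolution = 4
    for i in range(n_gen_blocks):
        # The first block has a single 3x3 convolution, so it needs no first noise tensor
        n1 = None if i == 0 else torch.randn(batch_size, 1, resolution, resolution, device=device)
        # Noise for the (last) convolution of each block
        n2 = torch.randn(batch_size, 1, resolution, resolution, device=device)
        noise.append((n1, n2))
        # Each subsequent block doubles the resolution
        resolution *= 2
    return noise


noise = get_noise(batch_size=2, n_gen_blocks=4)
for n1, n2 in noise:
    print(None if n1 is None else tuple(n1.shape), tuple(n2.shape))
# None (2, 1, 4, 4)
# (2, 1, 8, 8) (2, 1, 8, 8)
# (2, 1, 16, 16) (2, 1, 16, 16)
# (2, 1, 32, 32) (2, 1, 32, 32)
```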
<div class='section' id='section-74'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-74'>#</a>
</div>
<h3>Generate images</h3>
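<p>This generates images using the generator. It also returns the <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.02691em">w</span></span></span></span></span></span> that was used, which the path length penalty needs.</p>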
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">288</span> <span class="k">def</span> <span class="nf">generate_images</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-75'>
<div class='docs'>
<div class='section-link'>
<a href='#section-75'>#</a>
</div>
<p>Get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.02691em">w</span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">296</span> <span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_w</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-76'>
<div class='docs'>
<div class='section-link'>
<a href='#section-76'>#</a>
</div>
<p>Get noise </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">298</span> <span class="n">noise</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_noise</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-77'>
<div class='docs'>
<div class='section-link'>
<a href='#section-77'>#</a>
</div>
<p>Generate images </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">301</span> <span class="n">images</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">noise</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-78'>
<div class='docs'>
<div class='section-link'>
<a href='#section-78'>#</a>
</div>
<p>Return images and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.02691em">w</span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">304</span> <span class="k">return</span> <span class="n">images</span><span class="p">,</span> <span class="n">w</span></pre></div>
</div>
</div>
<div class='section' id='section-79'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-79'>#</a>
</div>
<h3>Training Step</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">306</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-80'>
<div class='docs'>
<div class='section-link'>
<a href='#section-80'>#</a>
</div>
<p>Train the discriminator </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">312</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Discriminator&#39;</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-81'>
<div class='docs'>
<div class='section-link'>
<a href='#section-81'>#</a>
</div>
<p>Reset gradients </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">314</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-82'>
<div class='docs'>
<div class='section-link'>
<a href='#section-82'>#</a>
</div>
<p>Accumulate gradients for <code class="highlight"><span></span><span class="n">gradient_accumulate_steps</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">317</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gradient_accumulate_steps</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-83'>
<div class='docs'>
<div class='section-link'>
<a href='#section-83'>#</a>
</div>
<p>Sample images from generator </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">319</span> <span class="n">generated_images</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generate_images</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-84'>
<div class='docs'>
<div class='section-link'>
<a href='#section-84'>#</a>
</div>
<p>Discriminator classification for generated images </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">321</span> <span class="n">fake_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">(</span><span class="n">generated_images</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span></pre></div>
</div>
</div>
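<p>Calling <code>.detach()</code> here cuts the graph between the generator and the discriminator, so the discriminator loss updates only the discriminator and no gradient work is spent on the generator. The tiny sketch below uses two <code>nn.Linear</code> layers as assumed stand-ins for the two networks. In the generator phase further down, the images are passed without <code>.detach()</code>, so gradients flow back into the generator; the discriminator also receives gradients then, but they are cleared by <code>discriminator_optimizer.zero_grad()</code> before its next update.</p>

```python
import torch
import torch.nn as nn

generator = nn.Linear(4, 4)       # stand-in for the generator
discriminator = nn.Linear(4, 1)   # stand-in for the discriminator

fake = generator(torch.randn(2, 4))
# Detach: the discriminator's loss cannot reach the generator's parameters
fake_output = discriminator(fake.detach())
fake_output.mean().backward()

print(discriminator.weight.grad is not None)  # True
print(generator.weight.grad is None)          # True: no gradient reached the generator
```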
<div class='section' id='section-85'>
<div class='docs'>
<div class='section-link'>
<a href='#section-85'>#</a>
</div>
<p>Get real images from the data loader </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">324</span> <span class="n">real_images</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loader</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
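<p><code>next(self.loader)</code> only works if <code>self.loader</code> is an iterator that never runs out; how it is built is not shown in this section. A common pattern, sketched here as an assumption, is a generator that restarts the <code>DataLoader</code> whenever an epoch ends. <code>cycle_dataloader</code> is a hypothetical name, and <code>drop_last=True</code> is an assumed choice that keeps every batch the same size.</p>

```python
import torch
from torch.utils.data import DataLoader


def cycle_dataloader(data_loader):
    """Hypothetical helper: yield batches forever, restarting at each epoch."""
    while True:
        for batch in data_loader:
            yield batch


# Toy dataset of 5 "images"; with drop_last=True each epoch yields 2 full batches
images = torch.randn(5, 3, 8, 8)
loader = cycle_dataloader(DataLoader(images, batch_size=2, shuffle=True, drop_last=True))

for _ in range(5):  # more batches than one epoch holds
    batch = next(loader)
    print(batch.shape)  # torch.Size([2, 3, 8, 8]) every time
```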
<div class='section' id='section-86'>
<div class='docs'>
<div class='section-link'>
<a href='#section-86'>#</a>
</div>
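<p>On steps where the lazy gradient penalty is applied, the real images must require gradients so that the penalty can differentiate the discriminator output with respect to them </p>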
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">326</span> <span class="k">if</span> <span class="p">(</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">lazy_gradient_penalty_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">327</span> <span class="n">real_images</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-87'>
<div class='docs'>
<div class='section-link'>
<a href='#section-87'>#</a>
</div>
<p>Discriminator classification for real images </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">329</span> <span class="n">real_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">(</span><span class="n">real_images</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-88'>
<div class='docs'>
<div class='section-link'>
<a href='#section-88'>#</a>
</div>
<p>Get discriminator loss </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">332</span> <span class="n">real_loss</span><span class="p">,</span> <span class="n">fake_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_loss</span><span class="p">(</span><span class="n">real_output</span><span class="p">,</span> <span class="n">fake_output</span><span class="p">)</span>
<span class="lineno">333</span> <span class="n">disc_loss</span> <span class="o">=</span> <span class="n">real_loss</span> <span class="o">+</span> <span class="n">fake_loss</span></pre></div>
</div>
</div>
<div class='section' id='section-89'>
<div class='docs'>
<div class='section-link'>
<a href='#section-89'>#</a>
</div>
<p>Add gradient penalty </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">336</span> <span class="k">if</span> <span class="p">(</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">lazy_gradient_penalty_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-90'>
<div class='docs'>
<div class='section-link'>
<a href='#section-90'>#</a>
</div>
<p>Calculate and log gradient penalty </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">338</span> <span class="n">gp</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">gradient_penalty</span><span class="p">(</span><span class="n">real_images</span><span class="p">,</span> <span class="n">real_output</span><span class="p">)</span>
<span class="lineno">339</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;loss.gp&#39;</span><span class="p">,</span> <span class="n">gp</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-91'>
<div class='docs'>
<div class='section-link'>
<a href='#section-91'>#</a>
</div>
<p>Scale the gradient penalty by half the coefficient and by the lazy regularization interval, then add it to the discriminator loss </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">341</span> <span class="n">disc_loss</span> <span class="o">=</span> <span class="n">disc_loss</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">gradient_penalty_coefficient</span> <span class="o">*</span> <span class="n">gp</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">lazy_gradient_penalty_interval</span></pre></div>
</div>
</div>
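<p><code>self.gradient_penalty</code> is defined outside this section. Assuming it computes the R1 penalty from StyleGAN2, the mean squared L2 norm of the gradient of the discriminator output with respect to the real images, the sketch below shows that computation together with the scaling used above: <code>0.5 * gradient_penalty_coefficient</code> is the R1 weight γ/2, and multiplying by <code>lazy_gradient_penalty_interval</code> compensates for evaluating the penalty only once every that many steps. The function name <code>r1_penalty</code>, the toy discriminator and the values γ = 10 and interval = 4 are assumptions.</p>

```python
import torch
import torch.nn as nn


def r1_penalty(real_images: torch.Tensor, real_output: torch.Tensor) -> torch.Tensor:
    """Hypothetical R1 penalty: mean over the batch of ||grad_x D(x)||^2."""
    gradients, = torch.autograd.grad(
        outputs=real_output,
        inputs=real_images,
        grad_outputs=torch.ones_like(real_output),
        create_graph=True,  # keep the graph so the penalty itself can be backpropagated
    )
    gradients = gradients.reshape(real_images.shape[0], -1)
    return (gradients.norm(2, dim=-1) ** 2).mean()


# Toy discriminator: a linear layer on flattened 3x4x4 images
discriminator = nn.Sequential(nn.Flatten(), nn.Linear(3 * 4 * 4, 1))
real_images = torch.randn(8, 3, 4, 4).requires_grad_()
real_output = discriminator(real_images)

gp = r1_penalty(real_images, real_output)
gradient_penalty_coefficient = 10.0   # assumed gamma
lazy_gradient_penalty_interval = 4    # assumed interval
penalty_term = 0.5 * gradient_penalty_coefficient * gp * lazy_gradient_penalty_interval

# For a linear D the input gradient is the weight vector for every sample,
# so the penalty equals the squared norm of the weights
weight = discriminator[1].weight
print(torch.allclose(gp, (weight ** 2).sum()))  # True
```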
<div class='section' id='section-92'>
<div class='docs'>
<div class='section-link'>
<a href='#section-92'>#</a>
</div>
<p>Compute gradients </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">344</span> <span class="n">disc_loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-93'>
<div class='docs'>
<div class='section-link'>
<a href='#section-93'>#</a>
</div>
<p>Log discriminator loss </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">347</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;loss.discriminator&#39;</span><span class="p">,</span> <span class="n">disc_loss</span><span class="p">)</span>
<span class="lineno">348</span>
<span class="lineno">349</span> <span class="k">if</span> <span class="p">(</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_generated_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-94'>
<div class='docs'>
<div class='section-link'>
<a href='#section-94'>#</a>
</div>
<p>Log discriminator model parameters occasionally </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">351</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;discriminator&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-95'>
<div class='docs'>
<div class='section-link'>
<a href='#section-95'>#</a>
</div>
<p>Clip gradients for stabilization </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">354</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_grad_norm_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">max_norm</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-96'>
<div class='docs'>
<div class='section-link'>
<a href='#section-96'>#</a>
</div>
<p>Take optimizer step </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">356</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
</div>
</div>
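<p>The discriminator update follows an accumulate-then-step pattern: gradients are zeroed once, <code>backward()</code> runs once per micro-batch so PyTorch sums the gradients in <code>.grad</code>, the total gradient norm is clipped, and the optimizer steps once. Because the loss is not divided by <code>gradient_accumulate_steps</code> in the code above, the result is the gradient of the summed loss. The sketch below, with an assumed toy model and data, checks both properties: accumulated gradients match the gradient of the summed loss, and <code>clip_grad_norm_</code> rescales them so the total norm is at most <code>max_norm</code>.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)                            # toy stand-in for the discriminator
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batches = [torch.randn(8, 4) for _ in range(3)]    # gradient_accumulate_steps = 3

# Reference: gradient of the loss summed over all micro-batches
reference = torch.autograd.grad(sum(model(x).pow(2).mean() for x in batches), model.weight)[0]

optimizer.zero_grad()
for x in batches:
    loss = model(x).pow(2).mean()
    loss.backward()                                # adds into model.weight.grad

# Accumulated gradient equals the gradient of the summed loss
accumulated_ok = torch.allclose(model.weight.grad, reference)
print(accumulated_ok)  # True

# clip_grad_norm_ rescales the gradients in place so their total norm is <= max_norm
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
clipped = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
print(clipped.item() <= 1.0 + 1e-6)  # True
optimizer.step()
```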
<div class='section' id='section-97'>
<div class='docs'>
<div class='section-link'>
<a href='#section-97'>#</a>
</div>
<p>Train the generator </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">359</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Generator&#39;</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-98'>
<div class='docs'>
<div class='section-link'>
<a href='#section-98'>#</a>
</div>
<p>Reset gradients </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">361</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="lineno">362</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network_optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-99'>
<div class='docs'>
<div class='section-link'>
<a href='#section-99'>#</a>
</div>
<p>Accumulate gradients for <code class="highlight"><span></span><span class="n">gradient_accumulate_steps</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">365</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gradient_accumulate_steps</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-100'>
<div class='docs'>
<div class='section-link'>
<a href='#section-100'>#</a>
</div>
<p>Sample images from generator </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">367</span> <span class="n">generated_images</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generate_images</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-101'>
<div class='docs'>
<div class='section-link'>
<a href='#section-101'>#</a>
</div>
<p>Discriminator classification for generated images </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">369</span> <span class="n">fake_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">(</span><span class="n">generated_images</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-102'>
<div class='docs'>
<div class='section-link'>
<a href='#section-102'>#</a>
</div>
<p>Get generator loss </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">372</span> <span class="n">gen_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_loss</span><span class="p">(</span><span class="n">fake_output</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-103'>
<div class='docs'>
<div class='section-link'>
<a href='#section-103'>#</a>
</div>
<p>Add path length penalty </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">375</span> <span class="k">if</span> <span class="n">idx</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">lazy_path_penalty_after</span> <span class="ow">and</span> <span class="p">(</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">lazy_path_penalty_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span></pre></div>
</div>
</div>
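<p>Both regularizers are applied lazily, and the path length penalty additionally waits until <code>idx</code> exceeds <code>lazy_path_penalty_after</code>. Evaluating the two conditions with hypothetical settings shows on which step indices each penalty runs.</p>

```python
# Hypothetical settings; the real values come from the experiment configs
lazy_gradient_penalty_interval = 4
lazy_path_penalty_interval = 4
lazy_path_penalty_after = 10

gp_steps = [idx for idx in range(20) if (idx + 1) % lazy_gradient_penalty_interval == 0]
plp_steps = [idx for idx in range(20)
             if idx > lazy_path_penalty_after and (idx + 1) % lazy_path_penalty_interval == 0]

print(gp_steps)   # [3, 7, 11, 15, 19]
print(plp_steps)  # [11, 15, 19]
```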
<div class='section' id='section-104'>
<div class='docs'>
<div class='section-link'>
<a href='#section-104'>#</a>
</div>
<p>Calculate path length penalty </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">377</span> <span class="n">plp</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">path_length_penalty</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">generated_images</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-105'>
<div class='docs'>
<div class='section-link'>
<a href='#section-105'>#</a>
</div>
<p>Skip the penalty if it is <code class="highlight"><span></span><span class="n">nan</span></code>; otherwise log it and add it to the generator loss
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">379</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">torch</span><span class="o">.</span><span class="n">isnan</span><span class="p">(</span><span class="n">plp</span><span class="p">):</span>
<span class="lineno">380</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;loss.plp&#39;</span><span class="p">,</span> <span class="n">plp</span><span class="p">)</span>
<span class="lineno">381</span> <span class="n">gen_loss</span> <span class="o">=</span> <span class="n">gen_loss</span> <span class="o">+</span> <span class="n">plp</span></pre></div>
</div>
</div>
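<p><code>self.path_length_penalty</code> is defined outside this section. For orientation, the StyleGAN2 path length regularizer penalizes (‖J<sub>w</sub><sup>T</sup>y‖<sub>2</sub> − a)<sup>2</sup>, where y is a random image-space direction and a is a running mean of the observed lengths. The sketch below illustrates that idea under assumptions and is not necessarily the source's implementation: the class name, β = 0.99, returning 0 on the first call and the linear toy generator are all hypothetical. The <code>torch.isnan</code> check above matters because adding a NaN penalty to <code>gen_loss</code> would turn every gradient into NaN.</p>

```python
import math

import torch
import torch.nn as nn


class PathLengthPenaltySketch:
    """Hypothetical StyleGAN2-style path length regularizer (illustrative only)."""

    def __init__(self, beta: float = 0.99):
        self.beta = beta          # decay for the running mean of path lengths
        self.steps = 0
        self.exp_sum_a = 0.0      # exponential moving sum of mean path lengths

    def __call__(self, w: torch.Tensor, images: torch.Tensor) -> torch.Tensor:
        # w: [n_blocks, batch, d_latent], must require grad; images: [batch, C, H, W]
        y = torch.randn_like(images)
        # Dividing by sqrt(H * W) keeps the length comparable across resolutions
        output = (images * y).sum() / math.sqrt(images.shape[2] * images.shape[3])
        # J_w^T y via autograd; create_graph=True so the penalty can be backpropagated
        grads, = torch.autograd.grad(output, w, create_graph=True)
        # Per-sample path length: sum over d_latent, mean over blocks, then sqrt
        norm = (grads ** 2).sum(dim=2).mean(dim=0).sqrt()
        if self.steps > 0:
            a = self.exp_sum_a / (1 - self.beta ** self.steps)  # bias-corrected running mean
            loss = ((norm - a) ** 2).mean()
        else:
            loss = norm.new_tensor(0.0)  # no running mean yet on the first call
        # Update the running mean with this batch's mean length (no gradient)
        self.exp_sum_a = self.beta * self.exp_sum_a + (1 - self.beta) * norm.mean().item()
        self.steps += 1
        return loss


# Toy "generator": a linear map from the block-averaged w to a 3x16x16 image
torch.manual_seed(0)
to_image = nn.Linear(8, 3 * 16 * 16)
w = torch.randn(3, 4, 8, requires_grad=True)   # [n_blocks, batch, d_latent]
plp = PathLengthPenaltySketch()

first = plp(w, to_image(w.mean(dim=0)).view(4, 3, 16, 16))
second = plp(w, to_image(w.mean(dim=0)).view(4, 3, 16, 16))
print(first.item())                                   # 0.0
print(bool(torch.isnan(second)), second.item() >= 0)  # False True
```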
<div class='section' id='section-106'>
<div class='docs'>
<div class='section-link'>
<a href='#section-106'>#</a>
</div>
<p>Calculate gradients </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">384</span> <span class="n">gen_loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-107'>
<div class='docs'>
<div class='section-link'>
<a href='#section-107'>#</a>
</div>
<p>Log generator loss </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">387</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;loss.generator&#39;</span><span class="p">,</span> <span class="n">gen_loss</span><span class="p">)</span>
<span class="lineno">388</span>
<span class="lineno">389</span> <span class="k">if</span> <span class="p">(</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_generated_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-108'>
<div class='docs'>
<div class='section-link'>
<a href='#section-108'>#</a>
</div>
<p>Log generator and mapping network parameters occasionally </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">391</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;generator&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">)</span>
<span class="lineno">392</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;mapping_network&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-109'>
<div class='docs'>
<div class='section-link'>
<a href='#section-109'>#</a>
</div>
<p>Clip gradients for stabilization </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">395</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_grad_norm_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">max_norm</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
<span class="lineno">396</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_grad_norm_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping_network</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">max_norm</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-110'>
<div class='docs'>
<div class='section-link'>
<a href='#section-110'>#</a>
</div>
<p>Take optimizer step </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">399</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="lineno">400</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network_optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-111'>
<div class='docs'>
<div class='section-link'>
<a href='#section-111'>#</a>
</div>
<p>Log generated images </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">403</span> <span class="k">if</span> <span class="p">(</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_generated_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">404</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;generated&#39;</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">generated_images</span><span class="p">[:</span><span class="mi">6</span><span class="p">],</span> <span class="n">real_images</span><span class="p">[:</span><span class="mi">3</span><span class="p">]],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">))</span></pre></div>
</div>
</div>
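<div class='section' id='section-111-sketch'>
<div class='docs'>
<div class='section-link'>
<a href='#section-111-sketch'>#</a>
</div>
<p>The logged tensor stacks the first six generated images and the first three real images along the batch dimension. This puts real and generated samples side by side in a single grid. The sketch below shows the shapes involved and assumes hypothetical batches of eight random 3 x 64 x 64 tensors. Concatenating on <code>dim=0</code> requires every other dimension to match. </p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch

# Hypothetical batches: 8 generated and 8 real RGB images at 64x64
# (64 matches the image_size override used in main below)
generated_images = torch.rand(8, 3, 64, 64)
real_images = torch.rand(8, 3, 64, 64)

# First 6 generated followed by first 3 real, stacked on the batch dimension
logged = torch.cat([generated_images[:6], real_images[:3]], dim=0)
print(logged.shape)  # torch.Size([9, 3, 64, 64])</pre></div>
</div>
</div>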
<div class='section' id='section-112'>
<div class='docs'>
<div class='section-link'>
<a href='#section-112'>#</a>
</div>
<p>Save model checkpoints </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">406</span> <span class="k">if</span> <span class="p">(</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">save_checkpoint_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-113'>
<div class='docs'>
<div class='section-link'>
<a href='#section-113'>#</a>
</div>
<p>Save checkpoint. This is a placeholder: the body is <code>pass</code>, so no checkpoint is saved here </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">408</span> <span class="k">pass</span></pre></div>
</div>
</div>
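<div class='section' id='section-113-sketch'>
<div class='docs'>
<div class='section-link'>
<a href='#section-113-sketch'>#</a>
</div>
<p>Both this check and the image-logging check fire when <code>(idx + 1) % interval == 0</code>. Because <code>idx</code> is 0-based, the <code>+ 1</code> turns it into the number of completed steps. The first save or log therefore happens after exactly <code>interval</code> steps rather than at step 0. The pure-Python sketch below illustrates this. <code>steps_that_fire</code> is a hypothetical helper written only for this example. </p>
</div>
<div class='code'>
<div class="highlight"><pre>def steps_that_fire(training_steps, interval):
    """Return the 0-based indices idx for which (idx + 1) % interval == 0."""
    return [idx for idx in range(training_steps) if (idx + 1) % interval == 0]


# With an interval of 200 (the log_generated_interval override in main),
# the condition first fires at idx = 199, i.e. after 200 completed steps
fired = steps_that_fire(1000, 200)
print(fired)  # [199, 399, 599, 799, 999]</pre></div>
</div>
</div>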
<div class='section' id='section-114'>
<div class='docs'>
<div class='section-link'>
<a href='#section-114'>#</a>
</div>
<p>Flush tracker </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">411</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-115'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-115'>#</a>
</div>
<h2>Train model</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">413</span> <span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-116'>
<div class='docs'>
<div class='section-link'>
<a href='#section-116'>#</a>
</div>
<p>Loop for <code class="highlight"><span></span><span class="n">training_steps</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">419</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">loop</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">training_steps</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-117'>
<div class='docs'>
<div class='section-link'>
<a href='#section-117'>#</a>
</div>
<p>Take a training step </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">421</span> <span class="bp">self</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">i</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-118'>
<div class='docs'>
<div class='section-link'>
<a href='#section-118'>#</a>
</div>
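<p>Start a new console line every <code class="highlight"><span></span><span class="n">log_generated_interval</span></code> steps </p>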
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">423</span> <span class="k">if</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_generated_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">424</span> <span class="n">tracker</span><span class="o">.</span><span class="n">new_line</span><span class="p">()</span></pre></div>
</div>
</div>
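<div class='section' id='section-118-sketch'>
<div class='docs'>
<div class='section-link'>
<a href='#section-118-sketch'>#</a>
</div>
<p>Without the labml helpers, the training loop is an ordinary <code>for</code> loop. The sketch below assumes a hypothetical <code>ToyTrainer</code> class that records step indices instead of training. A plain <code>range</code> stands in for <code>monit.loop</code>, and a counter stands in for <code>tracker.new_line()</code>. </p>
</div>
<div class='code'>
<div class="highlight"><pre>class ToyTrainer:
    """Hypothetical stand-in for Configs: records steps instead of training."""

    def __init__(self, training_steps, log_generated_interval):
        self.training_steps = training_steps
        self.log_generated_interval = log_generated_interval
        self.steps_taken = []
        self.lines_started = 0

    def step(self, idx):
        # A real step would update the discriminator and generator here
        self.steps_taken.append(idx)

    def train(self):
        # Plain range() in place of monit.loop
        for i in range(self.training_steps):
            self.step(i)
            # Counterpart of tracker.new_line(): one new line per interval
            if (i + 1) % self.log_generated_interval == 0:
                self.lines_started += 1


trainer = ToyTrainer(training_steps=10, log_generated_interval=4)
trainer.train()
print(trainer.steps_taken)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(trainer.lines_started)  # 2</pre></div>
</div>
</div>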
<div class='section' id='section-119'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-119'>#</a>
</div>
<h3>Train StyleGAN2</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">427</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-120'>
<div class='docs'>
<div class='section-link'>
<a href='#section-120'>#</a>
</div>
<p>Create an experiment </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">433</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;stylegan2&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-121'>
<div class='docs'>
<div class='section-link'>
<a href='#section-121'>#</a>
</div>
<p>Create configurations object </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">435</span> <span class="n">configs</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-122'>
<div class='docs'>
<div class='section-link'>
<a href='#section-122'>#</a>
</div>
<p>Set configurations and override some of the defaults </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">438</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">configs</span><span class="p">,</span> <span class="p">{</span>
<span class="lineno">439</span> <span class="s1">&#39;device.cuda_device&#39;</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span>
<span class="lineno">440</span> <span class="s1">&#39;image_size&#39;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span>
<span class="lineno">441</span> <span class="s1">&#39;log_generated_interval&#39;</span><span class="p">:</span> <span class="mi">200</span>
<span class="lineno">442</span> <span class="p">})</span></pre></div>
</div>
</div>
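<div class='section' id='section-122-sketch'>
<div class='docs'>
<div class='section-link'>
<a href='#section-122-sketch'>#</a>
</div>
<p>The dictionary passed to <code>experiment.configs</code> overrides defaults by name. A dotted key such as <code>'device.cuda_device'</code> addresses an option of a nested configuration, in this case the device settings. The sketch below is a rough illustration only and is not labml's implementation. It uses a hypothetical <code>apply_overrides</code> helper on a <code>SimpleNamespace</code> with made-up defaults. Overridden keys change, and every other option keeps its default. </p>
</div>
<div class='code'>
<div class="highlight"><pre>from types import SimpleNamespace


def apply_overrides(configs, overrides):
    """Hypothetical helper: set attributes on a nested config from dotted keys."""
    for key, value in overrides.items():
        *parents, leaf = key.split('.')
        target = configs
        # Walk down to the object that owns the final attribute
        for name in parents:
            target = getattr(target, name)
        setattr(target, leaf, value)
    return configs


# Made-up defaults for illustration; only the overridden keys change
configs = SimpleNamespace(
    device=SimpleNamespace(cuda_device=1),
    image_size=32,
    log_generated_interval=500,
    batch_size=32,
)
apply_overrides(configs, {
    'device.cuda_device': 0,
    'image_size': 64,
    'log_generated_interval': 200,
})
print(configs.image_size, configs.device.cuda_device, configs.batch_size)  # 64 0 32</pre></div>
</div>
</div>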
<div class='section' id='section-123'>
<div class='docs'>
<div class='section-link'>
<a href='#section-123'>#</a>
</div>
<p>Initialize </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">445</span> <span class="n">configs</span><span class="o">.</span><span class="n">init</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-124'>
<div class='docs'>
<div class='section-link'>
<a href='#section-124'>#</a>
</div>
<p>Register the models for saving and loading </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">447</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">(</span><span class="n">mapping_network</span><span class="o">=</span><span class="n">configs</span><span class="o">.</span><span class="n">mapping_network</span><span class="p">,</span>
<span class="lineno">448</span> <span class="n">generator</span><span class="o">=</span><span class="n">configs</span><span class="o">.</span><span class="n">generator</span><span class="p">,</span>
<span class="lineno">449</span> <span class="n">discriminator</span><span class="o">=</span><span class="n">configs</span><span class="o">.</span><span class="n">discriminator</span><span class="p">)</span></pre></div>
</div>
</div>
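<div class='section' id='section-124-sketch'>
<div class='docs'>
<div class='section-link'>
<a href='#section-124-sketch'>#</a>
</div>
<p>Registering the models by keyword lets labml save and restore them together. labml uses its own checkpoint format. The sketch below is a rough plain-PyTorch analogue that uses small <code>nn.Linear</code> stand-ins. It saves one <code>state_dict</code> per model under the same names, then loads them into freshly built modules of the same shape. An in-memory <code>io.BytesIO</code> buffer replaces a checkpoint file, so the example writes nothing to disk. </p>
</div>
<div class='code'>
<div class="highlight"><pre>import io

import torch
import torch.nn as nn

# Hypothetical stand-ins for the three registered models
torch.manual_seed(0)
models = {
    'mapping_network': nn.Linear(4, 4),
    'generator': nn.Linear(4, 2),
    'discriminator': nn.Linear(2, 1),
}

# Save one state_dict per model, keyed by the names used at registration
buffer = io.BytesIO()  # in-memory file; a real run would write to disk
torch.save({name: m.state_dict() for name, m in models.items()}, buffer)

# Rebuild modules with the same shapes and restore their weights
buffer.seek(0)
state = torch.load(buffer)
restored = {
    'mapping_network': nn.Linear(4, 4),
    'generator': nn.Linear(4, 2),
    'discriminator': nn.Linear(2, 1),
}
for name, module in restored.items():
    module.load_state_dict(state[name])

same = all(torch.equal(models[n].weight, restored[n].weight) for n in models)
print(same)  # True</pre></div>
</div>
</div>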
<div class='section' id='section-125'>
<div class='docs'>
<div class='section-link'>
<a href='#section-125'>#</a>
</div>
<p>Start the experiment </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">452</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-126'>
<div class='docs'>
<div class='section-link'>
<a href='#section-126'>#</a>
</div>
<p>Run the training loop </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">454</span> <span class="n">configs</span><span class="o">.</span><span class="n">train</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-127'>
<div class='docs'>
<div class='section-link'>
<a href='#section-127'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">458</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">459</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>