<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="An annotated PyTorch implementation of StyleGAN 2 model training code."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="StyleGAN 2 Model Training"/>
<meta name="twitter:description" content="An annotated PyTorch implementation of StyleGAN 2 model training code."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/gan/stylegan/experiment.html"/>
<meta property="og:title" content="StyleGAN 2 Model Training"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="StyleGAN 2 Model Training"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="StyleGAN 2 Model Training"/>
<meta property="og:description" content="An annotated PyTorch implementation of StyleGAN 2 model training code."/>
<title>StyleGAN 2 Model Training</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/gan/stylegan/experiment.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">gan</a>
<a class="parent" href="index.html">stylegan</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/gan/stylegan/experiment.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1><a href="index.html">StyleGAN 2</a> Model Training</h1>
<p>This is the training code for the <a href="index.html">StyleGAN 2</a> model.</p>
<p><img alt="Generated Images" src="generated_64.png"></p>
<p><small><em>These are <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">6</span><span class="mord coloredeq eqm" style=""><span class="mord" style="">4</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">6</span><span class="mord coloredeq eqm" style=""><span class="mord" style="">4</span></span></span></span></span></span> images generated after training for about 80K steps.</em></small></p>
<p><em>Our implementation is a minimalistic StyleGAN 2 model training code. Only single-GPU training is supported, to keep the implementation simple. We managed to keep it under 500 lines of code, including the training loop.</em></p>
<p><em>Without DDP (distributed data parallel) and multi-GPU training, it will not be possible to train the model at large resolutions (128+). If you want training code with fp16 and DDP, take a look at <a href="https://github.com/lucidrains/stylegan2-pytorch">lucidrains/stylegan2-pytorch</a>.</em></p>
<p>We trained this on the <a href="https://github.com/tkarras/progressive_growing_of_gans">CelebA-HQ dataset</a>. You can find download instructions in this <a href="https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3">discussion on fast.ai</a>. Save the images inside the <a href="#dataset_path"><code class="highlight"><span></span><span class="n">data</span><span class="o">/</span><span class="n">stylegan2</span></code>
folder</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">31</span><span></span><span class="kn">import</span> <span class="nn">math</span>
<span class="lineno">32</span><span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
<span class="lineno">33</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Iterator</span><span class="p">,</span> <span class="n">Tuple</span>
<span class="lineno">34</span>
<span class="lineno">35</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">36</span><span class="kn">import</span> <span class="nn">torch.utils.data</span>
<span class="lineno">37</span><span class="kn">import</span> <span class="nn">torchvision</span>
<span class="lineno">38</span><span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
<span class="lineno">39</span>
<span class="lineno">40</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">lab</span><span class="p">,</span> <span class="n">monit</span><span class="p">,</span> <span class="n">experiment</span>
<span class="lineno">41</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">BaseConfigs</span>
<span class="lineno">42</span><span class="kn">from</span> <span class="nn">labml_helpers.device</span> <span class="kn">import</span> <span class="n">DeviceConfigs</span>
<span class="lineno">43</span><span class="kn">from</span> <span class="nn">labml_helpers.train_valid</span> <span class="kn">import</span> <span class="n">ModeState</span><span class="p">,</span> <span class="n">hook_model_outputs</span>
<span class="lineno">44</span><span class="kn">from</span> <span class="nn">labml_nn.gan.stylegan</span> <span class="kn">import</span> <span class="n">Discriminator</span><span class="p">,</span> <span class="n">Generator</span><span class="p">,</span> <span class="n">MappingNetwork</span><span class="p">,</span> <span class="n">GradientPenalty</span><span class="p">,</span> <span class="n">PathLengthPenalty</span>
<span class="lineno">45</span><span class="kn">from</span> <span class="nn">labml_nn.gan.wasserstein</span> <span class="kn">import</span> <span class="n">DiscriminatorLoss</span><span class="p">,</span> <span class="n">GeneratorLoss</span>
<span class="lineno">46</span><span class="kn">from</span> <span class="nn">labml_nn.utils</span> <span class="kn">import</span> <span class="n">cycle_dataloader</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Dataset</h2>
<p>This loads the training dataset and resizes the images to the given image size.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span><span class="k">class</span> <span class="nc">Dataset</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">path</span></code>
is the path to the folder containing the images</li>
<li><code class="highlight"><span></span><span class="n">image_size</span></code>
is the size of the image</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">image_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">61</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Get the paths of all <code class="highlight"><span></span><span class="n">jpg</span></code>
files</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">64</span> <span class="bp">self</span><span class="o">.</span><span class="n">paths</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">Path</span><span class="p">(</span><span class="n">path</span><span class="p">)</span><span class="o">.</span><span class="n">glob</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;**/*.jpg&#39;</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Transformation</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">67</span> <span class="bp">self</span><span class="o">.</span><span class="n">transform</span> <span class="o">=</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Resize the image</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">69</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">Resize</span><span class="p">(</span><span class="n">image_size</span><span class="p">),</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Convert to a PyTorch tensor</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">71</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
<span class="lineno">72</span> <span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Number of images</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">74</span> <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">paths</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Get the <code class="highlight"><span></span><span class="n">index</span></code>
-th image</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">78</span> <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">80</span> <span class="n">path</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">paths</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
<span class="lineno">81</span> <span class="n">img</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">path</span><span class="p">)</span>
<span class="lineno">82</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">img</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<h2>Configurations</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">85</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">BaseConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Device to train the model on. <a href="https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs"><code class="highlight"><span></span><span class="n">DeviceConfigs</span></code>
</a> picks an available CUDA device or defaults to CPU.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">93</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">DeviceConfigs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p><a href="index.html#discriminator">StyleGAN 2 Discriminator</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">96</span> <span class="n">discriminator</span><span class="p">:</span> <span class="n">Discriminator</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p><a href="index.html#generator">StyleGAN 2 Generator</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">98</span> <span class="n">generator</span><span class="p">:</span> <span class="n">Generator</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p><a href="index.html#mapping_network">Mapping network</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">100</span> <span class="n">mapping_network</span><span class="p">:</span> <span class="n">MappingNetwork</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Discriminator and generator loss functions. We use the <a href="../wasserstein/index.html">Wasserstein loss</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">104</span> <span class="n">discriminator_loss</span><span class="p">:</span> <span class="n">DiscriminatorLoss</span>
<span class="lineno">105</span> <span class="n">generator_loss</span><span class="p">:</span> <span class="n">GeneratorLoss</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Optimizers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">108</span> <span class="n">generator_optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span>
<span class="lineno">109</span> <span class="n">discriminator_optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span>
<span class="lineno">110</span> <span class="n">mapping_network_optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p><a href="index.html#gradient_penalty">Gradient penalty regularization loss</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">113</span> <span class="n">gradient_penalty</span> <span class="o">=</span> <span class="n">GradientPenalty</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Gradient penalty coefficient <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord mathnormal" style="margin-right:0.05556em;">γ</span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">115</span> <span class="n">gradient_penalty_coefficient</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10.</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p><a href="index.html#path_length_penalty">Path length penalty</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">118</span> <span class="n">path_length_penalty</span><span class="p">:</span> <span class="n">PathLengthPenalty</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Data loader</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">121</span> <span class="n">loader</span><span class="p">:</span> <span class="n">Iterator</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Batch size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Dimensionality of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqo" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.02691em">w</span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">126</span> <span class="n">d_latent</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">512</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Height/width of the image</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">128</span> <span class="n">image_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Number of layers in the mapping network</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">130</span> <span class="n">mapping_network_layers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Generator and discriminator learning rate</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">132</span> <span class="n">learning_rate</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-3</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Mapping network learning rate (<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">100</span><span class="mord">×</span></span></span></span></span> lower than the others)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">134</span> <span class="n">mapping_network_learning_rate</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Number of steps over which to accumulate gradients. Use this to increase the effective batch size.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="n">gradient_accumulate_steps</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span></pre></div>
</div>
</div>
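<div class='section'>
<div class='docs'>
<p>A minimal sketch of why accumulating gradients increases the effective batch size, written in plain Python rather than taken from the training code. With the defaults above, the effective batch size would be <code>batch_size * gradient_accumulate_steps</code>, i.e. 32 × 1 = 32. The helper names <code>accumulated_gradient</code> and <code>mean_squared_grad</code> and the toy loss are assumptions made for illustration. The sketch also assumes equally sized micro-batches, with each micro-batch gradient scaled by the number of accumulation steps.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def accumulated_gradient(micro_batches, grad_fn):
    """Average per-micro-batch gradients, as if they formed one larger batch."""
    steps = len(micro_batches)
    total = 0.0
    for batch in micro_batches:
        # Scale each micro-batch gradient by 1/steps so the sum is an average.
        total += grad_fn(batch) / steps
    return total


def mean_squared_grad(batch, w=2.0):
    """Gradient with respect to w of mean((w * x) ** 2) / 2, i.e. mean(w * x * x)."""
    return sum(w * x * x for x in batch) / len(batch)


# Hypothetical toy data: two micro-batches of two samples each
# (this plays the role of gradient_accumulate_steps = 2).
micro_batches = [[1.0, 2.0], [3.0, 4.0]]
full_batch = [x for batch in micro_batches for x in batch]

accumulated = accumulated_gradient(micro_batches, mean_squared_grad)
full = mean_squared_grad(full_batch)
print(accumulated, full)
```
</pre></div>
</div>
</div>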
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.05278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> for the Adam optimizer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">138</span> <span class="n">adam_betas</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.99</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Probability of mixing styles</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">style_mixing_prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.9</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Total number of training steps</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">143</span> <span class="n">training_steps</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">150_000</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Number of blocks in the generator (calculated from the image resolution)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">146</span> <span class="n">n_gen_blocks</span><span class="p">:</span> <span class="nb">int</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
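<h3>Lazy regularization</h3>
<p>Instead of calculating the regularization losses at every step, the paper proposes lazy regularization, where the regularization terms are calculated only once in a while. This improves training efficiency considerably.</p>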
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>The interval at which to compute the gradient penalty</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">154</span> <span class="n">lazy_gradient_penalty_interval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>The interval at which to compute the path length penalty</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="n">lazy_path_penalty_interval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>Skip calculating the path length penalty during the initial phase of training</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">158</span> <span class="n">lazy_path_penalty_after</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5_000</span></pre></div>
</div>
</div>
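<div class='section'>
<div class='docs'>
<p>One way to picture the schedule that the three lazy-regularization settings describe is the sketch below. The function name <code>penalties_for_step</code> and the exact tests on the step counter are assumptions for illustration, so the training loop may check the step differently. The default arguments mirror the configuration values above.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def penalties_for_step(step, gp_interval=4, plp_interval=32, plp_after=5_000):
    """Return (compute_gradient_penalty, compute_path_length_penalty) for a step.

    Assumed convention: `step` is zero-based and a penalty is computed
    when (step + 1) is a multiple of its interval.
    """
    compute_gp = (step + 1) % gp_interval == 0
    # The path length penalty is additionally skipped early in training.
    compute_plp = step > plp_after and (step + 1) % plp_interval == 0
    return compute_gp, compute_plp


# Count how often the gradient penalty would run over the full 150,000 steps.
gp_steps = sum(penalties_for_step(s)[0] for s in range(150_000))
print(gp_steps)
```
</pre></div>
</div>
</div>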
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>How often to log generated images</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">161</span> <span class="n">log_generated_interval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">500</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>How often to save model checkpoints</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">163</span> <span class="n">save_checkpoint_interval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2_000</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Training mode state, used when logging activations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">166</span> <span class="n">mode</span><span class="p">:</span> <span class="n">ModeState</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Whether to log model layer outputs</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span> <span class="n">log_layer_outputs</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p><a id="dataset_path"></a>We trained this on the <a href="https://github.com/tkarras/progressive_growing_of_gans">CelebA-HQ dataset</a>. You can find download instructions in this <a href="https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3">discussion on fast.ai</a>. Save the images inside the <code class="highlight"><span></span><span class="n">data</span><span class="o">/</span><span class="n">stylegan2</span></code>
folder.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">175</span> <span class="n">dataset_path</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">&#39;stylegan2&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<h3>Initialize</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">177</span> <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Create the dataset</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">182</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">Dataset</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Create the data loader</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">184</span> <span class="n">dataloader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span>
<span class="lineno">185</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">drop_last</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">pin_memory</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p>Continuous <a href="../../utils.html#cycle_dataloader">cyclic loader</a> that restarts the data loader whenever it is exhausted</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">187</span> <span class="bp">self</span><span class="o">.</span><span class="n">loader</span> <span class="o">=</span> <span class="n">cycle_dataloader</span><span class="p">(</span><span class="n">dataloader</span><span class="p">)</span></pre></div>
</div>
</div>
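<div class='section'>
<div class='docs'>
<p>A minimal sketch of what a cyclic loader does, assuming it simply restarts iteration after every full pass. The name <code>cycle_batches</code> is hypothetical and this is not necessarily how <code>cycle_dataloader</code> is implemented; a plain list stands in for the data loader.</p>
</div>
<div class='code'>
```python
def cycle_batches(loader):
    """Yield batches forever, restarting the loader after each full pass."""
    while True:
        for batch in loader:
            yield batch


# A plain list of "batches" stands in for a DataLoader here (assumption).
batches = cycle_batches([[0, 1], [2, 3]])

# Pulling five batches wraps around the two-batch loader.
first_five = [next(batches) for _ in range(5)]
```
</div>
</div>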
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.93858em;vertical-align:-0.24414em;"></span><span class="mop"><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.20696799999999996em;"><span style="top:-2.4558600000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.24414em;"><span></span></span></span></span></span></span></span></span></span></span> of the image resolution</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">190</span> <span class="n">log_resolution</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">image_size</span><span class="p">))</span></pre></div>
</div>
</div>
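<div class='section'>
<div class='docs'>
<p>A small worked example of the computation above. The helper name <code>log_resolution_of</code> and the sample sizes are illustrative assumptions. It shows that power-of-two image sizes map exactly to their exponent, while any other size is truncated downward by <code>int</code>.</p>
</div>
<div class='code'>
```python
import math


def log_resolution_of(image_size):
    """Return log2 of the image size, truncated to an integer."""
    return int(math.log2(image_size))


# 32 = 2**5 and 64 = 2**6 are exact; 100 lies between 2**6 and 2**7.
sizes = [32, 64, 100]
log_resolutions = [log_resolution_of(s) for s in sizes]
```
</div>
</div>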
<div class='section' id='section-48'>
<div class='docs'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<p>Create the discriminator and generator</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">193</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span> <span class="o">=</span> <span class="n">Discriminator</span><span class="p">(</span><span class="n">log_resolution</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">194</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span> <span class="o">=</span> <span class="n">Generator</span><span class="p">(</span><span class="n">log_resolution</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_latent</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>Get the number of generator blocks, which is needed to create the style and noise inputs</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">196</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_gen_blocks</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="o">.</span><span class="n">n_blocks</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>Create the mapping network</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">198</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network</span> <span class="o">=</span> <span class="n">MappingNetwork</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_latent</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network_layers</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p>Create the path length penalty loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">200</span> <span class="bp">self</span><span class="o">.</span><span class="n">path_length_penalty</span> <span class="o">=</span> <span class="n">PathLengthPenalty</span><span class="p">(</span><span class="mf">0.99</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p>Add model hooks to monitor layer outputs</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">203</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_layer_outputs</span><span class="p">:</span>
<span class="lineno">204</span> <span class="n">hook_model_outputs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">,</span> <span class="s1">&#39;discriminator&#39;</span><span class="p">)</span>
<span class="lineno">205</span> <span class="n">hook_model_outputs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">,</span> <span class="s1">&#39;generator&#39;</span><span class="p">)</span>
<span class="lineno">206</span> <span class="n">hook_model_outputs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network</span><span class="p">,</span> <span class="s1">&#39;mapping_network&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>Discriminator and generator losses</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">209</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_loss</span> <span class="o">=</span> <span class="n">DiscriminatorLoss</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">210</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_loss</span> <span class="o">=</span> <span class="n">GeneratorLoss</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<p>Create the optimizers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">213</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span>
<span class="lineno">214</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span>
<span class="lineno">215</span> <span class="n">lr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">adam_betas</span>
<span class="lineno">216</span> <span class="p">)</span>
<span class="lineno">217</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span>
<span class="lineno">218</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span>
<span class="lineno">219</span> <span class="n">lr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">adam_betas</span>
<span class="lineno">220</span> <span class="p">)</span>
<span class="lineno">221</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network_optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span>
<span class="lineno">222</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span>
<span class="lineno">223</span> <span class="n">lr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping_network_learning_rate</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">adam_betas</span>
<span class="lineno">224</span> <span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='docs'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<p>Set tracker configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">227</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_image</span><span class="p">(</span><span class="s2">&quot;generated&quot;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
<h3>Sample <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.02691em">w</span></span></span></span></span></span></h3>
<p>This samples <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqo" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span></span></span></span> randomly and gets <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.02691em">w</span></span></span></span></span></span> from the mapping network.</p>
<p>Sometimes we also apply style mixing: we generate two latent variables <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqk" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqo" style="margin-right:0.04398em">z</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eql" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqo" style="margin-right:0.04398em">z</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and get the corresponding <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="margin-right:0.02691em">w</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="margin-right:0.02691em">w</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>. Then we randomly sample a cross-over point and apply <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="margin-right:0.02691em">w</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> to the generator blocks before the cross-over point and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="margin-right:0.02691em">w</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> to the blocks after it.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">229</span> <span class="k">def</span> <span class="nf">get_w</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<p>Mix styles</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">243</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(())</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">style_mixing_prob</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<p>Random cross-over point</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">245</span> <span class="n">cross_over_point</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(())</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_gen_blocks</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<p>Sample <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqk" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqo" style="margin-right:0.04398em">z</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eql" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqo" style="margin-right:0.04398em">z</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">247</span> <span class="n">z2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_latent</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">248</span> <span class="n">z1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_latent</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
<p>Get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="margin-right:0.02691em">w</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="margin-right:0.02691em">w</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> from the mapping network</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">250</span> <span class="n">w1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network</span><span class="p">(</span><span class="n">z1</span><span class="p">)</span>
<span class="lineno">251</span> <span class="n">w2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network</span><span class="p">(</span><span class="n">z2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
<p>Expand <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="margin-right:0.02691em">w</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="margin-right:0.02691em">w</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> for the generator blocks and concatenate them</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">253</span> <span class="n">w1</span> <span class="o">=</span> <span class="n">w1</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">cross_over_point</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">254</span> <span class="n">w2</span> <span class="o">=</span> <span class="n">w2</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_gen_blocks</span> <span class="o">-</span> <span class="n">cross_over_point</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">255</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">w1</span><span class="p">,</span> <span class="n">w2</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
<p>Without mixing</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">257</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='docs'>
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
<p>Sample <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqo" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">259</span> <span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_latent</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-64'>
<div class='docs'>
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
<p>Get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.02691em">w</span></span></span></span></span></span> from the mapping network</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">261</span> <span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-65'>
<div class='docs'>
<div class='section-link'>
<a href='#section-65'>#</a>
</div>
<p>Expand <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.02691em">w</span></span></span></span></span></span> for the generator blocks</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">263</span> <span class="k">return</span> <span class="n">w</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_gen_blocks</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-66'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-66'>#</a>
</div>
<h3>Generate noise</h3>
<p>This generates noise for each <a href="index.html#generator_block">generator block</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">265</span> <span class="k">def</span> <span class="nf">get_noise</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-67'>
<div class='docs'>
<div class='section-link'>
<a href='#section-67'>#</a>
</div>
<p>List to store the noise</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">272</span> <span class="n">noise</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-68'>
<div class='docs'>
<div class='section-link'>
<a href='#section-68'>#</a>
</div>
<p>Noise resolution starts from <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqm" style=""><span class="mord" style="">4</span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">274</span> <span class="n">resolution</span> <span class="o">=</span> <span class="mi">4</span></pre></div>
</div>
</div>
<div class='section' id='section-69'>
<div class='docs'>
<div class='section-link'>
<a href='#section-69'>#</a>
</div>
<p>Generate noise for each generator block</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">277</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_gen_blocks</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-70'>
<div class='docs'>
<div class='section-link'>
<a href='#section-70'>#</a>
</div>
<p>The first block has only one <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">3</span></span></span></span></span> convolution, so it takes no first noise input</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">279</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">280</span> <span class="n">n1</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-71'>
<div class='docs'>
<div class='section-link'>
<a href='#section-71'>#</a>
</div>
<p>Generate noise to add after the first convolution layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">282</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">283</span> <span class="n">n1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">resolution</span><span class="p">,</span> <span class="n">resolution</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-72'>
<div class='docs'>
<div class='section-link'>
<a href='#section-72'>#</a>
</div>
<p>Generate noise to add after the second convolution layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">285</span> <span class="n">n2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">resolution</span><span class="p">,</span> <span class="n">resolution</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-73'>
<div class='docs'>
<div class='section-link'>
<a href='#section-73'>#</a>
</div>
<p>Add the noise tensors to the list</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">288</span> <span class="n">noise</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">n1</span><span class="p">,</span> <span class="n">n2</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-74'>
<div class='docs'>
<div class='section-link'>
<a href='#section-74'>#</a>
</div>
<p>The next block has <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">2</span><span class="mord">×</span></span></span></span></span> the resolution</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">291</span> <span class="n">resolution</span> <span class="o">*=</span> <span class="mi">2</span></pre></div>
</div>
</div>
<div class='section' id='section-75'>
<div class='docs'>
<div class='section-link'>
<a href='#section-75'>#</a>
</div>
<p>Return the noise tensors</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">294</span> <span class="k">return</span> <span class="n">noise</span></pre></div>
</div>
</div>
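<div class='section'>
<div class='docs'>
<p>The loop above can be summarised by the tensor shapes it produces. The sketch below is plain Python without PyTorch; <code>noise_shapes</code> is a hypothetical helper that is not part of the original code. It only lists the shapes that would be passed to <code>torch.randn</code>, assuming the resolution starts at 4 and doubles after every block as described above.</p>
</div>
<div class='code'>

```python
from typing import List, Optional, Tuple

Shape = Tuple[int, int, int, int]


def noise_shapes(n_gen_blocks: int, batch_size: int) -> List[Tuple[Optional[Shape], Shape]]:
    """Hypothetical helper: shapes of the (n1, n2) noise pair for each generator block."""
    shapes = []
    resolution = 4  # noise resolution starts from 4
    for i in range(n_gen_blocks):
        shape = (batch_size, 1, resolution, resolution)
        # The first block has a single 3x3 convolution, so it has no first noise input
        n1 = None if i == 0 else shape
        shapes.append((n1, shape))
        # The next block has twice the resolution
        resolution *= 2
    return shapes


if __name__ == "__main__":
    for block, (n1, n2) in enumerate(noise_shapes(n_gen_blocks=3, batch_size=2)):
        print(block, n1, n2)
```

</div>
</div>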
<div class='section' id='section-76'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-76'>#</a>
</div>
<h3>Generate images</h3>
<p>This generates images using the generator</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">296</span> <span class="k">def</span> <span class="nf">generate_images</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-77'>
<div class='docs'>
<div class='section-link'>
<a href='#section-77'>#</a>
</div>
<p>Get <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.02691em">w</span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">304</span> <span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_w</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-78'>
<div class='docs'>
<div class='section-link'>
<a href='#section-78'>#</a>
</div>
<p>Get the noise</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">306</span> <span class="n">noise</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_noise</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-79'>
<div class='docs'>
<div class='section-link'>
<a href='#section-79'>#</a>
</div>
<p>Generate images</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">309</span> <span class="n">images</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">noise</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-80'>
<div class='docs'>
<div class='section-link'>
<a href='#section-80'>#</a>
</div>
<p>Return the images and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="margin-right:0.02691em">w</span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">312</span> <span class="k">return</span> <span class="n">images</span><span class="p">,</span> <span class="n">w</span></pre></div>
</div>
</div>
<div class='section' id='section-81'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-81'>#</a>
</div>
<h3>Training step</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">314</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-82'>
<div class='docs'>
<div class='section-link'>
<a href='#section-82'>#</a>
</div>
<p>Train the discriminator</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">320</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Discriminator&#39;</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-83'>
<div class='docs'>
<div class='section-link'>
<a href='#section-83'>#</a>
</div>
<p>Reset gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">322</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-84'>
<div class='docs'>
<div class='section-link'>
<a href='#section-84'>#</a>
</div>
<p>Accumulate gradients for <code class="highlight"><span></span><span class="n">gradient_accumulate_steps</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">325</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gradient_accumulate_steps</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-85'>
<div class='docs'>
<div class='section-link'>
<a href='#section-85'>#</a>
</div>
<p>Update <code class="highlight"><span></span><span class="n">mode</span></code>
. Set whether to log activations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">327</span> <span class="k">with</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">is_log_activations</span><span class="o">=</span><span class="p">(</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_generated_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-86'>
<div class='docs'>
<div class='section-link'>
<a href='#section-86'>#</a>
</div>
<p>Sample images from the generator</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">329</span> <span class="n">generated_images</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generate_images</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-87'>
<div class='docs'>
<div class='section-link'>
<a href='#section-87'>#</a>
</div>
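<p>Discriminator classification for the generated images. They are detached so that no gradients flow back to the generator</p>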
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">331</span> <span class="n">fake_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">(</span><span class="n">generated_images</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-88'>
<div class='docs'>
<div class='section-link'>
<a href='#section-88'>#</a>
</div>
<p>Get real images from the data loader</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">334</span> <span class="n">real_images</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loader</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-89'>
<div class='docs'>
<div class='section-link'>
<a href='#section-89'>#</a>
</div>
<p>We need gradients with respect to the real images to compute the gradient penalty</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">336</span> <span class="k">if</span> <span class="p">(</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">lazy_gradient_penalty_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">337</span> <span class="n">real_images</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-90'>
<div class='docs'>
<div class='section-link'>
<a href='#section-90'>#</a>
</div>
<p>Discriminator classification for the real images</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">339</span> <span class="n">real_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">(</span><span class="n">real_images</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-91'>
<div class='docs'>
<div class='section-link'>
<a href='#section-91'>#</a>
</div>
<p>Get the discriminator loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">342</span> <span class="n">real_loss</span><span class="p">,</span> <span class="n">fake_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_loss</span><span class="p">(</span><span class="n">real_output</span><span class="p">,</span> <span class="n">fake_output</span><span class="p">)</span>
<span class="lineno">343</span> <span class="n">disc_loss</span> <span class="o">=</span> <span class="n">real_loss</span> <span class="o">+</span> <span class="n">fake_loss</span></pre></div>
</div>
</div>
<div class='section' id='section-92'>
<div class='docs'>
<div class='section-link'>
<a href='#section-92'>#</a>
</div>
<p>Add the gradient penalty</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">346</span> <span class="k">if</span> <span class="p">(</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">lazy_gradient_penalty_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-93'>
<div class='docs'>
<div class='section-link'>
<a href='#section-93'>#</a>
</div>
<p>Calculate and log the gradient penalty</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">348</span> <span class="n">gp</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">gradient_penalty</span><span class="p">(</span><span class="n">real_images</span><span class="p">,</span> <span class="n">real_output</span><span class="p">)</span>
<span class="lineno">349</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;loss.gp&#39;</span><span class="p">,</span> <span class="n">gp</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-94'>
<div class='docs'>
<div class='section-link'>
<a href='#section-94'>#</a>
</div>
<p>Multiply by the coefficient and add the gradient penalty</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">351</span> <span class="n">disc_loss</span> <span class="o">=</span> <span class="n">disc_loss</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">gradient_penalty_coefficient</span> <span class="o">*</span> <span class="n">gp</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">lazy_gradient_penalty_interval</span></pre></div>
</div>
</div>
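<div class='section'>
<div class='docs'>
<p>The extra factor <code>lazy_gradient_penalty_interval</code> in the line above appears to compensate for the penalty being added only once every <code>lazy_gradient_penalty_interval</code> steps, so that its average weight per step stays at half the coefficient. The following is a minimal plain-Python sketch of that bookkeeping; <code>lazy_penalty_weight</code> and <code>average_weight</code> are hypothetical helpers, and the numbers in the usage lines are illustrative rather than the experiment's settings.</p>
</div>
<div class='code'>

```python
def lazy_penalty_weight(idx: int, interval: int, coefficient: float) -> float:
    """Hypothetical helper: weight applied to the gradient penalty at step `idx`."""
    if (idx + 1) % interval == 0:
        # The penalty is only added on every `interval`-th step,
        # so it is scaled up by `interval` on the steps where it is added
        return 0.5 * coefficient * interval
    return 0.0


def average_weight(n_steps: int, interval: int, coefficient: float) -> float:
    """Hypothetical helper: mean penalty weight over the first `n_steps` steps."""
    total = sum(lazy_penalty_weight(i, interval, coefficient) for i in range(n_steps))
    return total / n_steps


if __name__ == "__main__":
    # Illustrative values only
    print(lazy_penalty_weight(idx=3, interval=4, coefficient=10.0))
    print(average_weight(n_steps=8, interval=4, coefficient=10.0))
```

</div>
</div>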
<div class='section' id='section-95'>
<div class='docs'>
<div class='section-link'>
<a href='#section-95'>#</a>
</div>
<p>Compute gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">354</span> <span class="n">disc_loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-96'>
<div class='docs'>
<div class='section-link'>
<a href='#section-96'>#</a>
</div>
<p>Log the discriminator loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">357</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;loss.discriminator&#39;</span><span class="p">,</span> <span class="n">disc_loss</span><span class="p">)</span>
<span class="lineno">358</span>
<span class="lineno">359</span> <span class="k">if</span> <span class="p">(</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_generated_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-97'>
<div class='docs'>
<div class='section-link'>
<a href='#section-97'>#</a>
</div>
<p>Log the discriminator model parameters occasionally</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">361</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;discriminator&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-98'>
<div class='docs'>
<div class='section-link'>
<a href='#section-98'>#</a>
</div>
<p>Clip gradients for stabilization</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">364</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_grad_norm_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">max_norm</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span></pre></div>
</div>
</div>
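<div class='section'>
<div class='docs'>
<p>Norm clipping rescales all gradients by one shared factor whenever their joint L2 norm exceeds <code>max_norm</code>, which keeps the gradient direction and only limits its length. The sketch below shows the arithmetic in plain Python on nested lists of floats. It is an illustration of the technique, not the PyTorch implementation; <code>clip_by_global_norm</code> and the small <code>eps</code> constant are assumptions made for this example.</p>
</div>
<div class='code'>

```python
import math
from typing import List


def clip_by_global_norm(grads: List[List[float]], max_norm: float, eps: float = 1e-6) -> float:
    """Scale all gradients in place so that their joint L2 norm is at most `max_norm`.

    Returns the joint norm measured before clipping.
    """
    # One norm over every gradient value of every parameter
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # Never scale gradients up: the coefficient is capped at 1
    clip_coef = min(1.0, max_norm / (total_norm + eps))
    for grad in grads:
        for j in range(len(grad)):
            grad[j] *= clip_coef
    return total_norm


if __name__ == "__main__":
    grads = [[3.0], [4.0]]  # joint norm is 5
    print(clip_by_global_norm(grads, max_norm=1.0), grads)
```

</div>
</div>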
<div class='section' id='section-99'>
<div class='docs'>
<div class='section-link'>
<a href='#section-99'>#</a>
</div>
<p>Take an optimizer step</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">366</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
</div>
</div>
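<div class='section'>
<div class='docs'>
<p>The discriminator update follows the usual gradient accumulation pattern: reset the gradients once, add up the gradients of several mini-batches, then take a single optimizer step. The toy sketch below mirrors that order with one scalar parameter and a squared-error loss. It assumes the optimizer step runs once after the accumulation loop and that the losses are summed rather than averaged across mini-batches; <code>accumulate_and_step</code> is a hypothetical helper using plain gradient descent, not the optimizer used in the experiment.</p>
</div>
<div class='code'>

```python
from typing import List


def accumulate_and_step(param: float, micro_batches: List[List[float]], lr: float) -> float:
    """Hypothetical helper: one update of `param` using accumulated gradients.

    The loss of a mini-batch is mean((param - x) ** 2) over its values x.
    """
    grad = 0.0  # corresponds to zero_grad()
    for batch in micro_batches:
        # Gradient of the mini-batch loss with respect to `param`
        batch_grad = sum(2.0 * (param - x) for x in batch) / len(batch)
        # backward() adds to the stored gradient instead of replacing it
        grad += batch_grad
    # A single optimizer step after all mini-batches
    return param - lr * grad


if __name__ == "__main__":
    print(accumulate_and_step(0.0, [[1.0], [3.0]], lr=0.1))
```

</div>
</div>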
<div class='section' id='section-100'>
<div class='docs'>
<div class='section-link'>
<a href='#section-100'>#</a>
</div>
<p>Train the generator</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">369</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Generator&#39;</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-101'>
<div class='docs'>
<div class='section-link'>
<a href='#section-101'>#</a>
</div>
<p>Reset gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">371</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="lineno">372</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network_optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-102'>
<div class='docs'>
<div class='section-link'>
<a href='#section-102'>#</a>
</div>
<p>Accumulate gradients for <code class="highlight"><span></span><span class="n">gradient_accumulate_steps</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">375</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gradient_accumulate_steps</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-103'>
<div class='docs'>
<div class='section-link'>
<a href='#section-103'>#</a>
</div>
<p>Sample images from the generator</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">377</span> <span class="n">generated_images</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generate_images</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-104'>
<div class='docs'>
<div class='section-link'>
<a href='#section-104'>#</a>
</div>
<p>Discriminator classification for the generated images</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">379</span> <span class="n">fake_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">(</span><span class="n">generated_images</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-105'>
<div class='docs'>
<div class='section-link'>
<a href='#section-105'>#</a>
</div>
<p>Get the generator loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">382</span> <span class="n">gen_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_loss</span><span class="p">(</span><span class="n">fake_output</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-106'>
<div class='docs'>
<div class='section-link'>
<a href='#section-106'>#</a>
</div>
<p>Add the path length penalty</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">385</span> <span class="k">if</span> <span class="n">idx</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">lazy_path_penalty_after</span> <span class="ow">and</span> <span class="p">(</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">lazy_path_penalty_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-107'>
<div class='docs'>
<div class='section-link'>
<a href='#section-107'>#</a>
</div>
<p>Calculate the path length penalty</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">387</span> <span class="n">plp</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">path_length_penalty</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">generated_images</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-108'>
<div class='docs'>
<div class='section-link'>
<a href='#section-108'>#</a>
</div>
<p>Ignore the penalty if it is <code class="highlight"><span></span><span class="n">nan</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">389</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">torch</span><span class="o">.</span><span class="n">isnan</span><span class="p">(</span><span class="n">plp</span><span class="p">):</span>
<span class="lineno">390</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;loss.plp&#39;</span><span class="p">,</span> <span class="n">plp</span><span class="p">)</span>
<span class="lineno">391</span> <span class="n">gen_loss</span> <span class="o">=</span> <span class="n">gen_loss</span> <span class="o">+</span> <span class="n">plp</span></pre></div>
</div>
</div>
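<div class='section'>
<div class='docs'>
<p>Two guards control the path length penalty: it is skipped during the first <code>lazy_path_penalty_after</code> steps and then applied only on every <code>lazy_path_penalty_interval</code>-th step, and a <code>nan</code> value is dropped instead of being added to the loss. The sketch below restates both checks on plain floats. The helper names and the numbers in the usage lines are hypothetical and chosen only for illustration.</p>
</div>
<div class='code'>

```python
import math


def should_apply_path_penalty(idx: int, after: int, interval: int) -> bool:
    """Hypothetical helper: whether step `idx` adds the path length penalty."""
    return idx > after and (idx + 1) % interval == 0


def add_penalty_if_valid(loss: float, penalty: float) -> float:
    """Hypothetical helper: add `penalty` to `loss` unless the penalty is NaN."""
    if math.isnan(penalty):
        # A NaN penalty would poison the loss and every gradient derived from it
        return loss
    return loss + penalty


if __name__ == "__main__":
    # Illustrative values only
    print(should_apply_path_penalty(idx=31, after=5000, interval=32))
    print(should_apply_path_penalty(idx=5023, after=5000, interval=32))
    print(add_penalty_if_valid(1.0, float("nan")))
```

</div>
</div>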
<div class='section' id='section-109'>
<div class='docs'>
<div class='section-link'>
<a href='#section-109'>#</a>
</div>
<p>Compute gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">394</span> <span class="n">gen_loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-110'>
<div class='docs'>
<div class='section-link'>
<a href='#section-110'>#</a>
</div>
<p>Log the generator loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">397</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;loss.generator&#39;</span><span class="p">,</span> <span class="n">gen_loss</span><span class="p">)</span>
<span class="lineno">398</span>
<span class="lineno">399</span> <span class="k">if</span> <span class="p">(</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_generated_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-111'>
<div class='docs'>
<div class='section-link'>
<a href='#section-111'>#</a>
</div>
<p>Log the generator and mapping network parameters occasionally</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">401</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;generator&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">)</span>
<span class="lineno">402</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;mapping_network&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-112'>
<div class='docs'>
<div class='section-link'>
<a href='#section-112'>#</a>
</div>
<p>Clip gradients for stabilization</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">405</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_grad_norm_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">max_norm</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
<span class="lineno">406</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_grad_norm_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping_network</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">max_norm</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-113'>
<div class='docs'>
<div class='section-link'>
<a href='#section-113'>#</a>
</div>
<p>Take an optimizer step</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">409</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="lineno">410</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping_network_optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-114'>
<div class='docs'>
<div class='section-link'>
<a href='#section-114'>#</a>
</div>
<p>Log generated images alongside a few real images</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">413</span> <span class="k">if</span> <span class="p">(</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_generated_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">414</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;generated&#39;</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">generated_images</span><span class="p">[:</span><span class="mi">6</span><span class="p">],</span> <span class="n">real_images</span><span class="p">[:</span><span class="mi">3</span><span class="p">]],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-115'>
<div class='docs'>
<div class='section-link'>
<a href='#section-115'>#</a>
</div>
<p>Save model checkpoints</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">416</span> <span class="k">if</span> <span class="p">(</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">save_checkpoint_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">417</span> <span class="n">experiment</span><span class="o">.</span><span class="n">save_checkpoint</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-116'>
<div class='docs'>
<div class='section-link'>
<a href='#section-116'>#</a>
</div>
<p>Flush the tracker</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">420</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
</div>
</div>
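<div class='section'>
<div class='docs'>
<p>The logging and checkpoint checks above test <code>(idx + 1) % interval == 0</code> rather than <code>idx % interval == 0</code>. Below is a minimal plain-Python sketch of which step indices that condition selects, assuming <code>idx</code> counts from zero. The helper name <code>fires_on</code> is hypothetical and is not part of this implementation or of labml.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def fires_on(idx, interval):
    """Return True when the 0-based step idx completes a multiple of interval steps."""
    return (idx + 1) % interval == 0


# With an interval of 200 the first trigger is index 199 (the 200th step), not index 0.
triggered = [idx for idx in range(1000) if fires_on(idx, 200)]
print(triggered)  # [199, 399, 599, 799, 999]
```
</pre></div>
</div>
</div>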
<div class='section' id='section-117'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-117'>#</a>
</div>
<h2>Train model</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">422</span> <span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-118'>
<div class='docs'>
<div class='section-link'>
<a href='#section-118'>#</a>
</div>
<p>Loop for <code class="highlight"><span></span><span class="n">training_steps</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">428</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">loop</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">training_steps</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-119'>
<div class='docs'>
<div class='section-link'>
<a href='#section-119'>#</a>
</div>
<p>Take a training step</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">430</span> <span class="bp">self</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">i</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-120'>
<div class='docs'>
<div class='section-link'>
<a href='#section-120'>#</a>
</div>
<p>Start a new line in the tracker output every <code>log_generated_interval</code> steps</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">432</span> <span class="k">if</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_generated_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">433</span> <span class="n">tracker</span><span class="o">.</span><span class="n">new_line</span><span class="p">()</span></pre></div>
</div>
</div>
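<div class='section'>
<div class='docs'>
<p>The control flow of <code>train</code> can be sketched without labml. In this rough stand-in, <code>range</code> replaces <code>monit.loop</code> and a list of indices replaces the calls to <code>tracker.new_line()</code>. The function <code>run_loop</code> and its parameters are hypothetical names chosen for illustration.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def run_loop(training_steps, log_interval, step):
    """Call step(i) for each index and return the indices after which a new log line starts."""
    line_breaks = []
    for i in range(training_steps):  # stand-in for monit.loop(training_steps)
        step(i)  # one training step
        if (i + 1) % log_interval == 0:
            line_breaks.append(i)  # stand-in for tracker.new_line()
    return line_breaks


seen = []
line_breaks = run_loop(training_steps=10, log_interval=4, step=seen.append)
print(seen)         # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(line_breaks)  # [3, 7]
```
</pre></div>
</div>
</div>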
<div class='section' id='section-121'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-121'>#</a>
</div>
<h3>Train StyleGAN2</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">436</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-122'>
<div class='docs'>
<div class='section-link'>
<a href='#section-122'>#</a>
</div>
<p>Create an experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">442</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;stylegan2&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-123'>
<div class='docs'>
<div class='section-link'>
<a href='#section-123'>#</a>
</div>
<p>Create the configurations object</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">444</span> <span class="n">configs</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-124'>
<div class='docs'>
<div class='section-link'>
<a href='#section-124'>#</a>
</div>
<p>Set the configurations and override some of them</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">447</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">configs</span><span class="p">,</span> <span class="p">{</span>
<span class="lineno">448</span> <span class="s1">&#39;device.cuda_device&#39;</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span>
<span class="lineno">449</span> <span class="s1">&#39;image_size&#39;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span>
<span class="lineno">450</span> <span class="s1">&#39;log_generated_interval&#39;</span><span class="p">:</span> <span class="mi">200</span>
<span class="lineno">451</span> <span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-125'>
<div class='docs'>
<div class='section-link'>
<a href='#section-125'>#</a>
</div>
<p>Initialize</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">454</span> <span class="n">configs</span><span class="o">.</span><span class="n">init</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-126'>
<div class='docs'>
<div class='section-link'>
<a href='#section-126'>#</a>
</div>
<p>Set the models for saving and loading</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">456</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">(</span><span class="n">mapping_network</span><span class="o">=</span><span class="n">configs</span><span class="o">.</span><span class="n">mapping_network</span><span class="p">,</span>
<span class="lineno">457</span> <span class="n">generator</span><span class="o">=</span><span class="n">configs</span><span class="o">.</span><span class="n">generator</span><span class="p">,</span>
<span class="lineno">458</span> <span class="n">discriminator</span><span class="o">=</span><span class="n">configs</span><span class="o">.</span><span class="n">discriminator</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-127'>
<div class='docs'>
<div class='section-link'>
<a href='#section-127'>#</a>
</div>
<p>Start the experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">461</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-128'>
<div class='docs'>
<div class='section-link'>
<a href='#section-128'>#</a>
</div>
<p>Run the training loop</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">463</span> <span class="n">configs</span><span class="o">.</span><span class="n">train</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-129'>
<div class='docs'>
<div class='section-link'>
<a href='#section-129'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">467</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">468</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
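<div class='section'>
<div class='docs'>
<p>In <code>main</code>, a dictionary of overrides is passed to <code>experiment.configs</code> along with the <code>Configs</code> object. labml resolves these through its own configuration machinery. The sketch below only illustrates the effect with a plain dictionary merge. The default values and the helper <code>apply_overrides</code> are made up for illustration, the dotted <code>device.cuda_device</code> key is left out, and rejecting unknown keys is a choice of this sketch rather than documented labml behavior.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
# Hypothetical defaults; the real values are defined on the Configs class.
defaults = {'image_size': 32, 'log_generated_interval': 500, 'batch_size': 32}
# Two of the overrides used in main().
overrides = {'image_size': 64, 'log_generated_interval': 200}


def apply_overrides(defaults, overrides):
    """Return a new dict in which the override values replace the defaults."""
    unknown = set(overrides) - set(defaults)
    if unknown:
        # Sketch-only behavior: fail loudly on keys that have no default.
        raise KeyError(f'unknown config keys: {sorted(unknown)}')
    return {**defaults, **overrides}


effective = apply_overrides(defaults, overrides)
print(effective)  # {'image_size': 64, 'log_generated_interval': 200, 'batch_size': 32}
```
</pre></div>
</div>
</div>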
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>