2024-06-21 19:35:22 +05:30


<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="U-Net model for Denoising Diffusion Probabilistic Models (DDPM)"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="U-Net model for Denoising Diffusion Probabilistic Models (DDPM)"/>
<meta name="twitter:description" content="U-Net model for Denoising Diffusion Probabilistic Models (DDPM)"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/diffusion/ddpm/unet.html"/>
<meta property="og:title" content="U-Net model for Denoising Diffusion Probabilistic Models (DDPM)"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="U-Net model for Denoising Diffusion Probabilistic Models (DDPM)"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="U-Net model for Denoising Diffusion Probabilistic Models (DDPM)"/>
<title>U-Net model for Denoising Diffusion Probabilistic Models (DDPM)</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/diffusion/ddpm/unet.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/diffusion/ddpm/unet.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>U-Net model for <a href="index.html">Denoising Diffusion Probabilistic Models (DDPM)</a></h1>
<p>This is a <a href="../../unet/index.html">U-Net</a> based model that predicts the noise <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord" style="color:lightgreen"><span class="mord mathnormal" style="">ϵ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">θ</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqj" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">t</span></span><span class="mclose">)</span></span></span></span></span>.</p>
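<p>U-Net gets its name from the U shape of its model diagram. It processes an image by progressively lowering (halving) the feature-map resolution and then increasing it again. There are pass-through connections at each resolution.</p>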
<p><img alt="U-Net diagram from paper" src="../../unet/unet.png"></p>
<p>This implementation contains a number of modifications to the original U-Net (residual blocks, multi-head attention) and also adds time-step embeddings <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.61508em;vertical-align:0em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span></span>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">24</span><span></span><span class="kn">import</span> <span class="nn">math</span>
<span class="lineno">25</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">List</span>
<span class="lineno">26</span>
<span class="lineno">27</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">29</span>
<span class="lineno">30</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
</div>
</div>
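<div class='section'>
<div class='docs'>
<p>To make the shape of the "U" concrete, here is a minimal plain-Python sketch that only tracks the side length of the feature map. It assumes a hypothetical 32×32 input and four resolution levels (neither value comes from this page) and pairs each down-sampling stage with the up-sampling stage of the same resolution, which is where the pass-through connections go. The helpers <code>unet_resolutions</code> and <code>skip_pairs</code> are illustrative names only.</p>
<pre>
```python
from typing import List, Tuple


def unet_resolutions(image_size: int, n_levels: int) -> Tuple[List[int], List[int]]:
    """Side length of the feature map at each level of the down and up paths.

    Hypothetical helper for illustration: it tracks spatial sizes only.
    """
    down: List[int] = []
    size = image_size
    for level in range(n_levels):
        down.append(size)
        # Halve the resolution between consecutive levels, but not after the last one
        if level != n_levels - 1:
            if size % 2:
                raise ValueError(f"cannot halve an odd resolution {size}")
            size //= 2
    # The up path revisits the same resolutions in reverse order
    return down, down[::-1]


def skip_pairs(n_levels: int) -> List[Tuple[int, int]]:
    """(down-path index, up-path index) pairs joined by pass-through connections."""
    return [(i, n_levels - 1 - i) for i in range(n_levels)]


down, up = unet_resolutions(32, 4)
print(down)  # [32, 16, 8, 4]
print(up)    # [4, 8, 16, 32]
for i, j in skip_pairs(4):
    # Each pass-through connection joins two stages with the same resolution
    assert down[i] == up[j]
```
</pre>
</div>
</div>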
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h3>Swish activation function</h3>
<p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.44445em;vertical-align:0em;"></span><span class="mord mathnormal">x</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">33</span><span class="k">class</span> <span class="nc">Swish</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">40</span>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="lineno">41</span>        <span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
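<div class='section'>
<div class='docs'>
<p>A plain-Python sketch of the same formula for scalar inputs, handy for checking values by hand. It is an illustration and not part of the original module; in PyTorch 1.7 and later the built-in <code>torch.nn.SiLU</code> computes the same function.</p>
<pre>
```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid, split into two branches to avoid overflow in exp."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def swish(x: float) -> float:
    """Swish activation: x * sigmoid(x)."""
    return x * sigmoid(x)


print(swish(0.0))             # 0.0
print(round(swish(1.0), 4))   # 0.7311
print(round(swish(-1.0), 4))  # -0.2689
```
</pre>
<p>Unlike ReLU, Swish is smooth and lets small negative values through, while large positive inputs pass through almost unchanged.</p>
</div>
</div>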
<div class='section' id='section-3'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<h3>Embeddings for <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.61508em;vertical-align:0em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span></span></h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">44</span><span class="k">class</span> <span class="nc">TimeEmbedding</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">n_channels</span></code>
is the number of dimensions in the embedding</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">53</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">54</span>        <span class="bp">self</span><span class="o">.</span><span class="n">n_channels</span> <span class="o">=</span> <span class="n">n_channels</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>First linear layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span>        <span class="bp">self</span><span class="o">.</span><span class="n">lin1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_channels</span> <span class="o">//</span> <span class="mi">4</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_channels</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Activation</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">58</span>        <span class="bp">self</span><span class="o">.</span><span class="n">act</span> <span class="o">=</span> <span class="n">Swish</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Second linear layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">60</span>        <span class="bp">self</span><span class="o">.</span><span class="n">lin2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_channels</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_channels</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Create sinusoidal position embeddings, the same as those used in the <a href="../../transformers/positional_encoding.html">transformer</a></p>
<span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:6.600059999999999em;vertical-align:-3.0500299999999996em;"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:3.5500299999999996em;"><span style="top:-5.55003em;"><span class="pstrut" style="height:3.75em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">P</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0448em;"><span style="top:-2.4231360000000004em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqj" style=""><span class="mord mathnormal mtight" style="">t</span></span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">i</span></span></span></span><span style="top:-3.2198em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight">1</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.412972em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.2500000000000004em;"><span class="pstrut" style="height:3.75em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">P</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span 
class="vlist-r"><span class="vlist" style="height:1.0448em;"><span style="top:-2.4231360000000004em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqj" style=""><span class="mord mathnormal mtight" style="">t</span></span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">i</span></span></span></span><span style="top:-3.2198em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight">2</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.412972em;"><span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:3.0500299999999996em;"><span></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:3.5500299999999996em;"><span style="top:-5.55003em;"><span class="pstrut" style="height:3.75em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord mathnormal">s</span><span class="mord mathnormal">in</span><span class="mord"><span class="delimsizing size4">(</span></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.29208em;"><span style="top:-2.121225em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord">1000</span><span 
class="mord"><span class="mord">0</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9887749999999998em;"><span style="top:-3.3902150000000004em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8550857142857142em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqi" style=""><span class="mord mathnormal mtight" style="">d</span></span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span><span style="top:-3.2255em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em;"></span></span><span style="top:-3.384em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.40352142857142853em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span><span class="vlist-s"></span></span><span 
class="vlist-r"><span class="vlist" style="height:0.8787749999999999em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord"><span class="delimsizing size4">)</span></span></span></span><span style="top:-2.2500000000000004em;"><span class="pstrut" style="height:3.75em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord mathnormal">cos</span><span class="mord"><span class="delimsizing size4">(</span></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.29208em;"><span style="top:-2.121225em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord">1000</span><span class="mord"><span class="mord">0</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9887749999999998em;"><span style="top:-3.3902150000000004em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8550857142857142em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqi" style=""><span class="mord mathnormal mtight" style="">d</span></span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span><span style="top:-3.2255em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" 
style="border-bottom-width:0.049em;"></span></span><span style="top:-3.384em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.40352142857142853em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.8787749999999999em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord"><span class="delimsizing size4">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:3.0500299999999996em;"><span></span></span></span></span></span></span></span></span></span></span></span></span><p>where <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqi" style=""><span class="mord mathnormal" style="">d</span></span></span></span></span></span> is <code class="highlight"><span></span><span class="n">half_dim</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">72</span>        <span class="n">half_dim</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_channels</span> <span class="o">//</span> <span class="mi">8</span>
<span class="lineno">73</span>        <span class="n">emb</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="mi">10_000</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">half_dim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="lineno">74</span>        <span class="n">emb</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">half_dim</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">t</span><span class="o">.</span><span class="n">device</span><span class="p">)</span> <span class="o">*</span> <span class="o">-</span><span class="n">emb</span><span class="p">)</span>
<span class="lineno">75</span>        <span class="n">emb</span> <span class="o">=</span> <span class="n">t</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">emb</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
<span class="lineno">76</span>        <span class="n">emb</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">emb</span><span class="o">.</span><span class="n">sin</span><span class="p">(),</span> <span class="n">emb</span><span class="o">.</span><span class="n">cos</span><span class="p">()),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
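<div class='section'>
<div class='docs'>
<p>The same computation for a single time step, sketched in plain Python. With <em>d</em> = <code>half_dim</code>, the <em>i</em>-th frequency is 10000<sup>−i/(d − 1)</sup>, and the result is the <em>d</em> sines followed by the <em>d</em> cosines, so its length is <code>n_channels // 4</code>, the input size of <code>lin1</code>. The function name, the example value <code>n_channels = 64</code>, and the guard against too-small <code>n_channels</code> are assumptions added for illustration.</p>
<pre>
```python
import math
from typing import List


def sinusoidal_time_embedding(t: float, n_channels: int) -> List[float]:
    """Sinusoidal embedding of a single time step, mirroring TimeEmbedding.forward.

    Returns n_channels // 4 values: half_dim sines followed by half_dim cosines.
    """
    half_dim = n_channels // 8
    # Guard added for illustration: half_dim - 1 must be non-zero (n_channels >= 16)
    if half_dim in (0, 1):
        raise ValueError("n_channels must be at least 16")
    scale = math.log(10_000) / (half_dim - 1)
    # Frequencies 10000^(-i / (half_dim - 1)) for i = 0, ..., half_dim - 1
    freqs = [math.exp(-scale * i) for i in range(half_dim)]
    angles = [t * f for f in freqs]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


emb = sinusoidal_time_embedding(t=10, n_channels=64)
print(len(emb))  # 16, i.e. 64 // 4, the input size of lin1
```
</pre>
<p>At <em>t</em> = 0 every sine is 0 and every cosine is 1, and each sine/cosine pair lies on the unit circle, so the embedding encodes <em>t</em> as a set of phases at geometrically spaced frequencies.</p>
</div>
</div>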
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Transform with the MLP</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span>        <span class="n">emb</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lin1</span><span class="p">(</span><span class="n">emb</span><span class="p">))</span>
<span class="lineno">80</span>        <span class="n">emb</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lin2</span><span class="p">(</span><span class="n">emb</span><span class="p">)</span></pre></div>
</div>
</div>
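<div class='section'>
<div class='docs'>
<p>Since <code>lin1</code> maps <code>n_channels // 4</code> inputs to <code>n_channels</code> outputs and <code>lin2</code> maps <code>n_channels</code> to <code>n_channels</code>, the size of this MLP follows from <code>n_channels</code> alone. A small sketch of the bookkeeping, assuming a hypothetical <code>n_channels = 64</code> and the default <code>bias=True</code> of <code>nn.Linear</code> (one out × in weight matrix plus one bias per output):</p>
<pre>
```python
def linear_param_count(in_features: int, out_features: int, bias: bool = True) -> int:
    """Parameters of a fully connected layer: an out x in weight matrix plus an optional bias."""
    return in_features * out_features + (out_features if bias else 0)


def time_mlp_param_count(n_channels: int) -> int:
    """lin1 maps n_channels // 4 -> n_channels, lin2 maps n_channels -> n_channels."""
    return (linear_param_count(n_channels // 4, n_channels)
            + linear_param_count(n_channels, n_channels))


n_channels = 64  # hypothetical value for illustration
print(linear_param_count(n_channels // 4, n_channels))  # 1088
print(linear_param_count(n_channels, n_channels))       # 4160
print(time_mlp_param_count(n_channels))                 # 5248
```
</pre>
</div>
</div>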
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">83</span>        <span class="k">return</span> <span class="n">emb</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<h3>Residual block</h3>
<p>A residual block has two convolution layers with group normalization. Each resolution is processed with two residual blocks.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">86</span><span class="k">class</span> <span class="nc">ResidualBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">in_channels</span></code>
is the number of input channels</li>
<li><code class="highlight"><span></span><span class="n">out_channels</span></code>
is the number of output channels</li>
<li><code class="highlight"><span></span><span class="n">time_channels</span></code>
is the number of channels in the time step (<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.61508em;vertical-align:0em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span></span>) embeddings</li>
<li><code class="highlight"><span></span><span class="n">n_groups</span></code>
is the number of groups for <a href="../../normalization/group_norm/index.html">group normalization</a></li>
<li><code class="highlight"><span></span><span class="n">dropout</span></code>
is the dropout rate</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">94</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">95</span>                 <span class="n">n_groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">103</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Group normalization and the first convolution layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">105</span>        <span class="bp">self</span><span class="o">.</span><span class="n">norm1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">(</span><span class="n">n_groups</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">)</span>
<span class="lineno">106</span>        <span class="bp">self</span><span class="o">.</span><span class="n">act1</span> <span class="o">=</span> <span class="n">Swish</span><span class="p">()</span>
<span class="lineno">107</span>        <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
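<div class='section'>
<div class='docs'>
<p><code>nn.GroupNorm(n_groups, num_channels)</code> splits the channels into <code>n_groups</code> groups and normalizes each group with its own mean and variance, so the channel count must be divisible by <code>n_groups</code>; with the default <code>n_groups = 32</code>, both <code>in_channels</code> and <code>out_channels</code> need to be multiples of 32. The plain-Python sketch below shows the idea for one sample. It leaves out the learnable per-channel scale and shift that <code>nn.GroupNorm</code> applies by default, and its data layout (a list of channels, each a flat list of pixel values) is a hypothetical choice.</p>
<pre>
```python
import math
from typing import List

# One sample: a list of channels, each a flat list of spatial values
Sample = List[List[float]]


def group_norm(x: Sample, n_groups: int, eps: float = 1e-5) -> Sample:
    """Normalize each group of channels to zero mean and unit variance.

    Sketch only: no learnable scale/shift, and x holds a single sample.
    """
    n_channels = len(x)
    if n_channels % n_groups:
        raise ValueError("number of channels must be divisible by n_groups")
    per_group = n_channels // n_groups
    out: Sample = []
    for g in range(n_groups):
        channels = x[g * per_group:(g + 1) * per_group]
        values = [v for ch in channels for v in ch]
        mean = sum(values) / len(values)
        # Biased variance over every value in the group
        var = sum((v - mean) ** 2 for v in values) / len(values)
        inv_std = 1.0 / math.sqrt(var + eps)
        out.extend([[(v - mean) * inv_std for v in ch] for ch in channels])
    return out


x = [[1.0, 2.0], [3.0, 4.0],      # group 0: channels 0 and 1
     [10.0, 20.0], [30.0, 40.0]]  # group 1: channels 2 and 3
y = group_norm(x, n_groups=2)
print([round(v, 3) for ch in y[:2] for v in ch])  # [-1.342, -0.447, 0.447, 1.342]
print([round(v, 3) for ch in y[2:] for v in ch])  # [-1.342, -0.447, 0.447, 1.342]
```
</pre>
<p>Both groups end up with (nearly) the same normalized values even though the second group is ten times larger, because each group is shifted and scaled by its own statistics.</p>
</div>
</div>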
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Group normalization and the second convolution layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span>        <span class="bp">self</span><span class="o">.</span><span class="n">norm2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">(</span><span class="n">n_groups</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">)</span>
<span class="lineno">111</span>        <span class="bp">self</span><span class="o">.</span><span class="n">act2</span> <span class="o">=</span> <span class="n">Swish</span><span class="p">()</span>
<span class="lineno">112</span>        <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">out_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
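<div class='section'>
<div class='docs'>
<p>Both convolutions use a 3×3 kernel with padding 1 and PyTorch's default stride of 1, so they can change the number of channels while leaving height and width unchanged. The standard output-size formula, ⌊(n + 2p − k) / s⌋ + 1 for dilation 1, makes this explicit; <code>conv_output_size</code> is a hypothetical helper:</p>
<pre>
```python
def conv_output_size(size: int, kernel: int, padding: int = 0, stride: int = 1) -> int:
    """Spatial output size of a convolution along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


# conv1 / conv2: kernel_size=(3, 3), padding=(1, 1), stride 1, so the size is preserved
print(conv_output_size(32, kernel=3, padding=1))  # 32
# Without padding a 3x3 kernel would shrink each side by 2
print(conv_output_size(32, kernel=3))  # 30
```
</pre>
</div>
</div>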
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
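<p>If the number of input channels is not equal to the number of output channels, the shortcut connection has to be projected with a 1×1 convolution; otherwise it is the identity</p>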
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span>        <span class="k">if</span> <span class="n">in_channels</span> <span class="o">!=</span> <span class="n">out_channels</span><span class="p">:</span>
<span class="lineno">117</span>            <span class="bp">self</span><span class="o">.</span><span class="n">shortcut</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="lineno">118</span>        <span class="k">else</span><span class="p">:</span>
<span class="lineno">119</span>            <span class="bp">self</span><span class="o">.</span><span class="n">shortcut</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span></pre></div>
</div>
</div>
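<div class='section'>
<div class='docs'>
<p>A 1×1 convolution applies the same channel-mixing matrix at every pixel. It changes the number of channels while leaving height and width alone, which is what lets <code>h + self.shortcut(x)</code> line up when <code>in_channels != out_channels</code>. Below is a minimal pure-Python sketch of that operation on nested lists laid out as <code>[channels][height][width]</code>. The helper <code>conv1x1</code> and the toy weights are hypothetical, and this is not how PyTorch implements <code>nn.Conv2d</code>.</p>

```python
from typing import List

Image = List[List[List[float]]]  # [channels][height][width]


def conv1x1(x: Image, weight: List[List[float]], bias: List[float]) -> Image:
    """Hypothetical 1x1 convolution: y[o][r][c] = sum_i weight[o][i] * x[i][r][c] + bias[o]."""
    in_ch, height, width = len(x), len(x[0]), len(x[0][0])
    out = []
    for o in range(len(weight)):          # one output plane per output channel
        plane = []
        for r in range(height):
            row = []
            for c in range(width):
                acc = bias[o]
                for i in range(in_ch):    # mix input channels at this single pixel
                    acc += weight[o][i] * x[i][r][c]
                row.append(acc)
            plane.append(row)
        out.append(plane)
    return out


# Two input channels on a 1x2 image, projected to three output channels.
x = [[[1.0, 2.0]], [[3.0, 4.0]]]
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b = [0.0, 0.0, 0.5]
y = conv1x1(x, w, b)
```

<p>The spatial size (1×2) is unchanged; only the channel count goes from 2 to 3.</p>
</div>
</div>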
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
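<p>Linear layer (preceded by a Swish activation) that projects the time embedding from <code class="highlight"><span></span><span class="n">time_channels</span></code> to <code class="highlight"><span></span><span class="n">out_channels</span></code>, plus the dropout layer used before the second convolution</p>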
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_emb</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">time_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">)</span>
<span class="lineno">123</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_act</span> <span class="o">=</span> <span class="n">Swish</span><span class="p">()</span>
<span class="lineno">124</span>
<span class="lineno">125</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span></pre></div>
</div>
</div>
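<div class='section'>
<div class='docs'>
<p>The time embedding arrives as <code>[batch_size, time_channels]</code>. Applying Swish and then <code>self.time_emb</code> maps it to <code>[batch_size, out_channels]</code>, one value per feature-map channel. The following pure-Python sketch shows that computation for a single sample, assuming Swish is <code>v * sigmoid(v)</code>. The helpers <code>swish</code>, <code>linear</code>, and <code>project_time</code> are hypothetical and not part of this codebase.</p>

```python
import math
from typing import List


def swish(v: float) -> float:
    # Swish (also called SiLU): v * sigmoid(v), written to avoid overflow in exp().
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)


def linear(v: List[float], weight: List[List[float]], bias: List[float]) -> List[float]:
    # y[o] = sum_i weight[o][i] * v[i] + bias[o], the map nn.Linear applies to one sample.
    return [sum(w_oi * v_i for w_oi, v_i in zip(row, v)) + b for row, b in zip(weight, bias)]


def project_time(t_emb: List[float], weight: List[List[float]], bias: List[float]) -> List[float]:
    # Mirrors self.time_emb(self.time_act(t)) for one sample: activation first, then linear.
    return linear([swish(v) for v in t_emb], weight, bias)


# Toy sizes: time_channels = 2 projected to out_channels = 3.
out = project_time([0.0, 2.0], [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [0.0, 0.0, 0.0])
```
</div>
</div>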
<div class='section' id='section-20'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code>
</li>
<li><code class="highlight"><span></span><span class="n">t</span></code>
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">]</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">127</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>First convolution layer, applied after group normalization and the Swish activation</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">133</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">act1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">norm1</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Add time embeddings, broadcast over the height and width dimensions</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">135</span> <span class="n">h</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_emb</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">time_act</span><span class="p">(</span><span class="n">t</span><span class="p">))[:,</span> <span class="p">:,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span></pre></div>
</div>
</div>
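<div class='section'>
<div class='docs'>
<p>Indexing with <code>[:, :, None, None]</code> reshapes the projected embedding to <code>[batch_size, out_channels, 1, 1]</code>, so broadcasting adds one scalar per sample and channel to every spatial position of <code>h</code>. Here is a minimal pure-Python sketch of the same addition, using nested lists in place of tensors; the helper name is hypothetical.</p>

```python
from typing import List

Feature = List[List[List[List[float]]]]  # [batch][channels][height][width]


def add_time_embedding(h: Feature, emb: List[List[float]]) -> Feature:
    # emb is [batch][channels]. In PyTorch, emb[:, :, None, None] has shape
    # [batch, channels, 1, 1] and broadcasts one value per (sample, channel) over
    # every pixel; this comprehension performs the same addition explicitly.
    return [
        [[[pixel + emb[b][c] for pixel in row] for row in h[b][c]] for c in range(len(h[b]))]
        for b in range(len(h))
    ]


h = [[[[0.0, 1.0], [2.0, 3.0]], [[10.0, 10.0], [10.0, 10.0]]]]  # batch 1, 2 channels, 2x2
emb = [[100.0, -1.0]]
result = add_time_embedding(h, emb)
```
</div>
</div>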
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Second convolution layer, with normalization, activation, and dropout applied first</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">137</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">act2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">norm2</span><span class="p">(</span><span class="n">h</span><span class="p">))))</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Add the shortcut connection and return</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">140</span> <span class="k">return</span> <span class="n">h</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shortcut</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<h3>Attention block</h3>
<p>This is similar to <a href="../../transformers/mha.html">transformer multi-head attention</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">143</span><span class="k">class</span> <span class="nc">AttentionBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">n_channels</span></code>
is the number of channels in the input</li>
<li><code class="highlight"><span></span><span class="n">n_heads</span></code>
is the number of heads in multi-head attention</li>
<li><code class="highlight"><span></span><span class="n">d_k</span></code>
is the number of dimensions in each head</li>
<li><code class="highlight"><span></span><span class="n">n_groups</span></code>
is the number of groups for <a href="../../normalization/group_norm/index.html">group normalization</a></li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">150</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> <span class="n">d_k</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">n_groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">157</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Default <code class="highlight"><span></span><span class="n">d_k</span></code> to <code class="highlight"><span></span><span class="n">n_channels</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">160</span> <span class="k">if</span> <span class="n">d_k</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">161</span> <span class="n">d_k</span> <span class="o">=</span> <span class="n">n_channels</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
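<p>Group normalization layer (note that <code class="highlight"><span></span><span class="n">forward</span></code> as written does not apply it)</p>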
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">163</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">(</span><span class="n">n_groups</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Projections for query, key, and values, for all heads at once</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">165</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_channels</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span> <span class="o">*</span> <span class="mi">3</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
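<p>Linear layer for the final transformation back to <code class="highlight"><span></span><span class="n">n_channels</span></code></p>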
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">167</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">)</span></pre></div>
</div>
</div>
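<div class='section'>
<div class='docs'>
<p>As a worked example of the layer sizes, take <code>n_channels = 64</code> (an illustrative value) with the defaults <code>n_heads = 1</code> and <code>d_k = n_channels</code>. The projection then has 64 × 192 + 192 = 12,480 parameters and the output layer has 64 × 64 + 64 = 4,160. Splitting into 4 heads of <code>d_k = 16</code> keeps <code>n_heads * d_k = 64</code>, so both counts stay the same. The hypothetical helper below reproduces that arithmetic.</p>

```python
from typing import Optional, Tuple


def attention_linear_params(n_channels: int, n_heads: int = 1,
                            d_k: Optional[int] = None) -> Tuple[int, int]:
    # Weights + biases of the two nn.Linear layers in AttentionBlock (GroupNorm excluded).
    if d_k is None:
        d_k = n_channels  # same default as AttentionBlock.__init__
    qkv_out = n_heads * d_k * 3
    projection = n_channels * qkv_out + qkv_out           # nn.Linear(n_channels, n_heads * d_k * 3)
    output = (n_heads * d_k) * n_channels + n_channels    # nn.Linear(n_heads * d_k, n_channels)
    return projection, output


single_head = attention_linear_params(64)
four_heads = attention_linear_params(64, n_heads=4, d_k=16)
```
</div>
</div>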
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
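<p>Scale for dot-product attention, <code class="highlight"><span></span><span class="n">d_k</span> <span class="o">**</span> <span class="o">-</span><span class="mf">0.5</span></code> (that is, 1/√d_k)</p>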
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">169</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">d_k</span> <span class="o">**</span> <span class="o">-</span><span class="mf">0.5</span></pre></div>
</div>
</div>
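<div class='section'>
<div class='docs'>
<p>The usual motivation for this scale: if the entries of a query and a key are independent with mean 0 and variance 1, their dot product has variance <code>d_k</code>. Multiplying by <code>d_k ** -0.5</code> brings it back to about 1, so the softmax is less likely to saturate. This pure-Python sketch checks that empirically with random ±1 vectors; the helper and its sampling setup are assumptions made for illustration.</p>

```python
import random
import statistics


def dot_product_variance(d_k: int, trials: int = 2000, scale: float = 1.0, seed: int = 0) -> float:
    # Draw q and k with independent entries of +1 or -1 (mean 0, variance 1)
    # and return the population variance of scale * (q . k) over all trials.
    rng = random.Random(seed)
    samples = []
    for _ in range(trials):
        q = [rng.choice((-1.0, 1.0)) for _ in range(d_k)]
        k = [rng.choice((-1.0, 1.0)) for _ in range(d_k)]
        samples.append(scale * sum(a * b for a, b in zip(q, k)))
    return statistics.pvariance(samples)


d_k = 64
raw = dot_product_variance(d_k)                         # close to d_k
scaled = dot_product_variance(d_k, scale=d_k ** -0.5)   # close to 1
```

<p>Both calls use the same seed, so <code>scaled</code> is exactly <code>raw / d_k</code>.</p>
</div>
</div>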
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">171</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">=</span> <span class="n">n_heads</span>
<span class="lineno">172</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">=</span> <span class="n">d_k</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code>
</li>
<li><code class="highlight"><span></span><span class="n">t</span></code>
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">]</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">174</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p><code class="highlight"><span></span><span class="n">t</span></code>
is not used, but it is kept in the arguments so that the attention layer's function signature matches that of <code class="highlight"><span></span><span class="n">ResidualBlock</span></code>.
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">181</span> <span class="n">_</span> <span class="o">=</span> <span class="n">t</span></pre></div>
</div>
</div>
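<div class='section'>
<div class='docs'>
<p>Keeping <code>t</code> in the signature lets code that holds a mix of blocks call every one of them the same way. The toy classes below are hypothetical scalar stand-ins (not the real <code>ResidualBlock</code> or <code>AttentionBlock</code>) that illustrate the idea in plain Python.</p>

```python
from typing import Callable, List, Optional


class ToyResidual:
    # Hypothetical stand-in for ResidualBlock: it actually uses t.
    def __call__(self, x: float, t: float) -> float:
        return x + t


class ToyAttention:
    # Hypothetical stand-in for AttentionBlock: t is accepted but ignored.
    def __call__(self, x: float, t: Optional[float] = None) -> float:
        _ = t
        return 2.0 * x


def run_blocks(blocks: List[Callable[..., float]], x: float, t: float) -> float:
    # Every block accepts (x, t), so the loop needs no per-type special cases.
    for block in blocks:
        x = block(x, t)
    return x


y = run_blocks([ToyResidual(), ToyAttention(), ToyResidual()], 1.0, 0.5)
```
</div>
</div>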
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Get shape</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">183</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
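<p>Change <code class="highlight"><span></span><span class="n">x</span></code>
to shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">]</span></code>, where <code class="highlight"><span></span><span class="n">seq</span> <span class="o">=</span> <span class="n">height</span> <span class="o">*</span> <span class="n">width</span></code>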
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">185</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
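<div class='section'>
<div class='docs'>
<p>For a single sample, <code>view(batch_size, n_channels, -1)</code> flattens each channel's <code>height × width</code> grid in row-major order and <code>permute(0, 2, 1)</code> puts the channel axis last. Pixel <code>(r, w)</code> therefore becomes sequence position <code>r * width + w</code>, carrying a vector of <code>n_channels</code> features. A minimal pure-Python sketch of that layout change, with a hypothetical helper name:</p>

```python
from typing import List


def to_sequence(x: List[List[List[float]]]) -> List[List[float]]:
    # x is [channels][height][width] for one sample; the result is [height * width][channels].
    # Pixel (r, w) lands at sequence index r * width + w, as with view(...).permute(0, 2, 1).
    channels, height, width = len(x), len(x[0]), len(x[0][0])
    return [[x[c][r][w] for c in range(channels)] for r in range(height) for w in range(width)]


x = [[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],        # channel 0, 2x3
     [[10.0, 11.0, 12.0], [13.0, 14.0, 15.0]]]  # channel 1, 2x3
seq = to_sequence(x)
```
</div>
</div>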
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Get query, key, and values (concatenated) and reshape them to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">187</span> <span class="n">qkv</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="mi">3</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>Split query, key, and values. Each of them will have shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">189</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">chunk</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
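<div class='section'>
<div class='docs'>
<p>Because the projection output is viewed as <code>[..., n_heads, 3 * d_k]</code> before <code>chunk</code>, each head takes one contiguous slice of length <code>3 * d_k</code> and splits it into its own query, key, and value. The flat vector is laid out head by head, not as all queries followed by all keys; since the projection weights are learned, this is only a layout convention. The pure-Python sketch below, for one sequence position and with a hypothetical helper name, makes the indexing explicit.</p>

```python
from typing import List, Tuple

PerHead = List[List[float]]  # [n_heads][d_k]


def split_qkv(projected: List[float], n_heads: int, d_k: int) -> Tuple[PerHead, PerHead, PerHead]:
    # projected is the flat projection output for ONE sequence position.
    # view(..., n_heads, 3 * d_k) groups it per head; chunk(3, dim=-1) splits each group.
    assert len(projected) == n_heads * 3 * d_k
    q, k, v = [], [], []
    for h in range(n_heads):
        head = projected[h * 3 * d_k:(h + 1) * 3 * d_k]
        q.append(head[:d_k])
        k.append(head[d_k:2 * d_k])
        v.append(head[2 * d_k:])
    return q, k, v


q, k, v = split_qkv([float(i) for i in range(12)], n_heads=2, d_k=2)
```
</div>
</div>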
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Calculate the scaled dot-product <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.633028em;vertical-align:-0.538em;"></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.095028em;"><span style="top:-2.5864385em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord sqrt mtight" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8622307142857143em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mtight" style="padding-left:0.833em"><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqi" style="">d</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.8222307142857144em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail mtight" style="min-width:0.853em;height:1.08em"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.17776928571428574em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.446108em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">Q</span><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9270285714285713em;"><span style="top:-2.931em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style="">⊤</span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.538em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;bihd,bjhd-&gt;bijh&#39;</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div>
</div>
</div>
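<div class='section'>
<div class='docs'>
<p>The einsum string <code>'bihd,bjhd-&gt;bijh'</code> sums over the per-head feature dimension <code>d</code> for every query position <code>i</code>, key position <code>j</code>, and head <code>h</code>. Below is a pure-Python sketch of the same contraction for one batch element, with nested lists standing in for tensors; the helper name is hypothetical.</p>

```python
import math
from typing import List

Heads = List[List[List[float]]]  # [seq][n_heads][d_k] for one batch element


def scaled_scores(q: Heads, k: Heads, scale: float) -> List[List[List[float]]]:
    # attn[i][j][h] = scale * sum_d q[i][h][d] * k[j][h][d],
    # i.e. einsum('bihd,bjhd->bijh', q, k) * scale with the batch axis dropped.
    n_heads = len(q[0])
    return [[[scale * sum(a * b for a, b in zip(q[i][h], k[j][h])) for h in range(n_heads)]
             for j in range(len(k))]
            for i in range(len(q))]


d_k = 4
q = k = [[[1.0, 1.0, 1.0, 1.0]], [[1.0, 0.0, 0.0, 0.0]]]  # seq = 2, n_heads = 1
scores = scaled_scores(q, k, 1.0 / math.sqrt(d_k))
```
</div>
</div>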
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Softmax along the key sequence dimension <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:3.0000299999999998em;vertical-align:-1.25003em;"></span><span class="mord"><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6944399999999998em;"><span style="top:-2.20556em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">se</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop"><span class="mord mathnormal">so</span><span class="mord mathnormal" style="margin-right:0.10764em;">f</span><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">t</span></span><span class="mord mathnormal">ma</span><span class="mord mathnormal">x</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.030548em;"><span></span></span></span></span></span></span><span class="mord"><span class="delimsizing size4">(</span></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.095028em;"><span style="top:-2.5864385em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord sqrt mtight" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8622307142857143em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mtight" style="padding-left:0.833em"><span 
class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqi" style="">d</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.8222307142857144em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail mtight" style="min-width:0.853em;height:1.08em"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.17776928571428574em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.446108em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">Q</span><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9270285714285713em;"><span style="top:-2.931em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style="">⊤</span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.538em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span><span class="mord"><span class="delimsizing size4">)</span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">193</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
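<div class='section'>
<div class='docs'>
<p><code>attn</code> is indexed as <code>[b, i, j, h]</code>, so <code>dim=2</code> normalizes over the key positions <code>j</code>: for each query and head, the weights are non-negative and sum to 1. Here is a pure-Python sketch of that normalization for one batch element. The helper name is hypothetical, and it subtracts the maximum before exponentiating for numerical stability.</p>

```python
import math
from typing import List


def softmax_over_keys(scores: List[List[List[float]]]) -> List[List[List[float]]]:
    # scores is [seq_i][seq_j][n_heads]. For every query i and head h, normalize over
    # the key index j, matching attn.softmax(dim=2) on a [batch, i, j, heads] tensor.
    n_keys, n_heads = len(scores[0]), len(scores[0][0])
    out = [[[0.0] * n_heads for _ in range(n_keys)] for _ in scores]
    for i, row in enumerate(scores):
        for h in range(n_heads):
            m = max(row[j][h] for j in range(n_keys))  # subtract the max for stability
            exps = [math.exp(row[j][h] - m) for j in range(n_keys)]
            total = sum(exps)
            for j in range(n_keys):
                out[i][j][h] = exps[j] / total
    return out


scores = [[[0.0], [0.0]], [[math.log(3.0)], [0.0]]]  # seq = 2, n_heads = 1
attn = softmax_over_keys(scores)
```
</div>
</div>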
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Multiply by values</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">195</span> <span class="n">res</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;bijh,bjhd-&gt;bihd&#39;</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div>
</div>
</div>
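<div class='section'>
<div class='docs'>
<p>The einsum <code>'bijh,bjhd-&gt;bihd'</code> forms, for every query position <code>i</code> and head <code>h</code>, the attention-weighted sum of the value vectors over key positions <code>j</code>. A pure-Python sketch for one batch element, with a hypothetical helper name:</p>

```python
from typing import List


def weighted_values(attn: List[List[List[float]]],
                    v: List[List[List[float]]]) -> List[List[List[float]]]:
    # attn is [seq_i][seq_j][n_heads]; v is [seq_j][n_heads][d_k].
    # res[i][h][d] = sum_j attn[i][j][h] * v[j][h][d]  (einsum 'bijh,bjhd->bihd', no batch axis)
    n_heads, d_k = len(v[0]), len(v[0][0])
    return [[[sum(attn[i][j][h] * v[j][h][d] for j in range(len(v))) for d in range(d_k)]
             for h in range(n_heads)]
            for i in range(len(attn))]


attn = [[[0.75], [0.25]]]               # one query, two keys, one head
v = [[[4.0, 0.0]], [[0.0, 8.0]]]        # two keys, one head, d_k = 2
res = weighted_values(attn, v)
```
</div>
</div>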
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>Reshape to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">res</span> <span class="o">=</span> <span class="n">res</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Transform to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">199</span> <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">res</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Add skip connection</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">202</span> <span class="n">res</span> <span class="o">+=</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p>Change to shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span> <span class="n">res</span> <span class="o">=</span> <span class="n">res</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">)</span></pre></div>
</div>
</div>
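<div class='section'>
<div class='docs'>
<p>This undoes the earlier flattening: <code>permute(0, 2, 1)</code> moves channels back in front and <code>view</code> splits the sequence axis into <code>height × width</code>, so sequence position <code>r * width + w</code> returns to pixel <code>(r, w)</code>. A pure-Python sketch for one sample, with a hypothetical helper name:</p>

```python
from typing import List


def from_sequence(seq: List[List[float]], height: int, width: int) -> List[List[List[float]]]:
    # seq is [height * width][channels] for one sample; the result is [channels][height][width].
    # Per sample this matches res.permute(0, 2, 1).view(batch_size, n_channels, height, width):
    # sequence index r * width + w goes back to pixel (r, w).
    channels = len(seq[0])
    return [[[seq[r * width + w][c] for w in range(width)] for r in range(height)]
            for c in range(channels)]


seq = [[0.0, 10.0], [1.0, 11.0], [2.0, 12.0], [3.0, 13.0], [4.0, 14.0], [5.0, 15.0]]
image = from_sequence(seq, height=2, width=3)
```
</div>
</div>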
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">208</span> <span class="k">return</span> <span class="n">res</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<h3>Down block</h3>
<p>This combines <code class="highlight"><span></span><span class="n">ResidualBlock</span></code>
and <code class="highlight"><span></span><span class="n">AttentionBlock</span></code>.
These are used in the first half of the U-Net at each resolution.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">211</span><span class="k">class</span> <span class="nc">DownBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">218</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">has_attn</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span>
<span class="lineno">219</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">220</span> <span class="bp">self</span><span class="o">.</span><span class="n">res</span> <span class="o">=</span> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">)</span>
<span class="lineno">221</span> <span class="k">if</span> <span class="n">has_attn</span><span class="p">:</span>
<span class="lineno">222</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">AttentionBlock</span><span class="p">(</span><span class="n">out_channels</span><span class="p">)</span>
<span class="lineno">223</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">224</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">226</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">227</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">res</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="lineno">228</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">229</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<h3>Up block</h3>
<p>This combines a <code class="highlight"><span></span><span class="n">ResidualBlock</span></code>
and an optional <code class="highlight"><span></span><span class="n">AttentionBlock</span></code>.
These are used in the second half of the U-Net at each resolution.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">232</span><span class="k">class</span> <span class="nc">UpBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">239</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">has_attn</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span>
<span class="lineno">240</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>The input has <code class="highlight"><span></span><span class="n">in_channels</span> <span class="o">+</span> <span class="n">out_channels</span></code>
channels because we concatenate the output of the same resolution from the first half of the U-Net.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">243</span> <span class="bp">self</span><span class="o">.</span><span class="n">res</span> <span class="o">=</span> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span> <span class="o">+</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">)</span>
<span class="lineno">244</span> <span class="k">if</span> <span class="n">has_attn</span><span class="p">:</span>
<span class="lineno">245</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">AttentionBlock</span><span class="p">(</span><span class="n">out_channels</span><span class="p">)</span>
<span class="lineno">246</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">247</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">249</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">250</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">res</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="lineno">251</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">252</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<h3>Middle block</h3>
<p>It combines a <code class="highlight"><span></span><span class="n">ResidualBlock</span></code>
and an <code class="highlight"><span></span><span class="n">AttentionBlock</span></code>,
followed by another <code class="highlight"><span></span><span class="n">ResidualBlock</span></code>.
This block is applied at the lowest resolution of the U-Net.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">255</span><span class="k">class</span> <span class="nc">MiddleBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">263</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">264</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">265</span> <span class="bp">self</span><span class="o">.</span><span class="n">res1</span> <span class="o">=</span> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">n_channels</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">)</span>
<span class="lineno">266</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">AttentionBlock</span><span class="p">(</span><span class="n">n_channels</span><span class="p">)</span>
<span class="lineno">267</span> <span class="bp">self</span><span class="o">.</span><span class="n">res2</span> <span class="o">=</span> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">n_channels</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">269</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">270</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">res1</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="lineno">271</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">272</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">res2</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="lineno">273</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<h3>Scale up the feature map by <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">2</span><span class="mord">×</span></span></span></span></span></h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">276</span><span class="k">class</span> <span class="nc">Upsample</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">281</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">):</span>
<span class="lineno">282</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">283</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ConvTranspose2d</span><span class="p">(</span><span class="n">n_channels</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">285</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
<p><code class="highlight"><span></span><span class="n">t</span></code>
is not used, but it is kept in the arguments so that the function signature matches that of <code class="highlight"><span></span><span class="n">ResidualBlock</span></code>.
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">288</span> <span class="n">_</span> <span class="o">=</span> <span class="n">t</span>
<span class="lineno">289</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
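<p>As a quick sanity check on the 2× claim, the sketch below evaluates the output-size formula documented for PyTorch's <code>nn.ConvTranspose2d</code> for the kernel 4, stride 2, padding 1 configuration used above (dilation 1 and no output padding, the defaults here). The helper name <code>conv_transpose_out</code> is hypothetical and only reproduces the shape arithmetic, not the layer itself.</p>

```python
def conv_transpose_out(size: int, kernel: int = 4, stride: int = 2, padding: int = 1,
                       dilation: int = 1, output_padding: int = 0) -> int:
    """Output length of a transposed convolution along one spatial axis.

    Follows the shape formula given in the PyTorch ConvTranspose2d documentation.
    """
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


# With kernel 4, stride 2 and padding 1 the output is exactly twice the input size
for h in [1, 4, 7, 16, 32]:
    print(h, "->", conv_transpose_out(h))
```

<p>Substituting the values gives <code>(size - 1) * 2 - 2 + 3 + 1 = 2 * size</code>, so the spatial size doubles for any input size, odd or even.</p>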
<div class='section' id='section-62'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
<h3>Scale down the feature map by <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.190108em;vertical-align:-0.345em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.845108em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord">×</span></span></span></span></span></h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">292</span><span class="k">class</span> <span class="nc">Downsample</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='docs'>
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">297</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">):</span>
<span class="lineno">298</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">299</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">n_channels</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-64'>
<div class='docs'>
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">301</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-65'>
<div class='docs'>
<div class='section-link'>
<a href='#section-65'>#</a>
</div>
<p><code class="highlight"><span></span><span class="n">t</span></code>
is not used, but it is kept in the arguments so that the function signature matches that of <code class="highlight"><span></span><span class="n">ResidualBlock</span></code>.
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">304</span> <span class="n">_</span> <span class="o">=</span> <span class="n">t</span>
<span class="lineno">305</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
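<p>The same kind of check for <code>Downsample</code>: the sketch below applies the output-size formula from the PyTorch <code>nn.Conv2d</code> documentation to kernel 3, stride 2, padding 1, and pairs it with the transposed-convolution formula used by <code>Upsample</code>. The helper names are hypothetical. It also shows a consequence worth keeping in mind: for odd sizes the halving rounds up, so a downsample followed by an upsample only restores the original size when that size is even.</p>

```python
def conv_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1, dilation: int = 1) -> int:
    """Output length of a strided convolution along one spatial axis (PyTorch Conv2d formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv_transpose_out(size: int, kernel: int = 4, stride: int = 2, padding: int = 1) -> int:
    """Output length of the transposed convolution used by Upsample (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + (kernel - 1) + 1


# Round trip: Downsample, then Upsample
for h in [32, 16, 7]:
    down = conv_out(h)
    up = conv_transpose_out(down)
    print(f"{h} -> down {down} -> up {up}")
```

<p>With <code>n_resolutions</code> resolutions there are <code>n_resolutions - 1</code> downsamples, so this arithmetic suggests using image sizes divisible by 2<sup>n_resolutions - 1</sup> (8 with the default four resolutions); otherwise an upsampled feature map and the skip connection it is concatenated with would not match in size.</p>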
<div class='section' id='section-66'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-66'>#</a>
</div>
<h2>U-Net</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">308</span><span class="k">class</span> <span class="nc">UNet</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-67'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-67'>#</a>
</div>
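<ul><li><code class="highlight"><span></span><span class="n">image_channels</span></code>
is the number of channels in the image (<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">3</span></span></span></span></span> for RGB).</li>
<li><code class="highlight"><span></span><span class="n">n_channels</span></code>
is the number of channels in the initial feature map that the image is projected into</li>
<li><code class="highlight"><span></span><span class="n">ch_mults</span></code>
is the list of channel multipliers, one per resolution. Starting from <code class="highlight"><span></span><span class="n">n_channels</span></code>, the number of channels at resolution <code class="highlight"><span></span><span class="n">i</span></code> is the number at the previous resolution times <code class="highlight"><span></span><span class="n">ch_mults</span><span class="p">[</span><span class="n">i</span><span class="p">]</span></code>, so the multipliers compound
</li>
<li><code class="highlight"><span></span><span class="n">is_attn</span></code>
is a list of booleans indicating whether to use attention at each resolution</li>
<li><code class="highlight"><span></span><span class="n">n_blocks</span></code>
is the number of <code class="highlight"><span></span><span class="n">DownBlock</span></code>s at each resolution (the second half uses <code class="highlight"><span></span><span class="n">n_blocks</span> <span class="o">+</span> <span class="mi">1</span></code>
<code class="highlight"><span></span><span class="n">UpBlock</span></code>s per resolution)</li></ul>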
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">313</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">image_channels</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">3</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span><span class="p">,</span>
<span class="lineno">314</span> <span class="n">ch_mults</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span>
<span class="lineno">315</span> <span class="n">is_attn</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">bool</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">List</span><span class="p">[</span><span class="nb">bool</span><span class="p">]]</span> <span class="o">=</span> <span class="p">(</span><span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">True</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
<span class="lineno">316</span> <span class="n">n_blocks</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span><span class="p">):</span></pre></div>
</div>
</div>
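<p>Because the constructor multiplies the running channel count by <code>ch_mults[i]</code> at every resolution (<code>out_channels = in_channels * ch_mults[i]</code> in the down path), the multipliers compound rather than each being applied to <code>n_channels</code> directly. A minimal sketch of the resulting channel schedule, using a hypothetical helper <code>channels_per_resolution</code>:</p>

```python
from itertools import accumulate
from operator import mul
from typing import List, Sequence


def channels_per_resolution(n_channels: int, ch_mults: Sequence[int]) -> List[int]:
    """Channel count of the feature maps at each resolution of the down path."""
    # Running product of the multipliers, scaled by the initial channel count
    return [n_channels * m for m in accumulate(ch_mults, mul)]


print(channels_per_resolution(64, (1, 2, 2, 4)))  # → [64, 128, 256, 1024]
```

<p>So with the defaults the lowest resolution, and therefore the <code>MiddleBlock</code>, works with 1024 channels, not <code>4 * 64 = 256</code>.</p>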
<div class='section' id='section-68'>
<div class='docs'>
<div class='section-link'>
<a href='#section-68'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">324</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-69'>
<div class='docs'>
<div class='section-link'>
<a href='#section-69'>#</a>
</div>
<p>Number of resolutions</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">327</span> <span class="n">n_resolutions</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">ch_mults</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-70'>
<div class='docs'>
<div class='section-link'>
<a href='#section-70'>#</a>
</div>
<p>Project the image into a feature map</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">330</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">image_channels</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-71'>
<div class='docs'>
<div class='section-link'>
<a href='#section-71'>#</a>
</div>
<p>Time embedding layer. The time embedding has <code class="highlight"><span></span><span class="n">n_channels</span> <span class="o">*</span> <span class="mi">4</span></code>
channels</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">333</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_emb</span> <span class="o">=</span> <span class="n">TimeEmbedding</span><span class="p">(</span><span class="n">n_channels</span> <span class="o">*</span> <span class="mi">4</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-72'>
<div class='docs'>
<div class='section-link'>
<a href='#section-72'>#</a>
</div>
<h4>First half of U-Net: decreasing resolution</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">336</span> <span class="n">down</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-73'>
<div class='docs'>
<div class='section-link'>
<a href='#section-73'>#</a>
</div>
<p>Number of channels</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">338</span> <span class="n">out_channels</span> <span class="o">=</span> <span class="n">in_channels</span> <span class="o">=</span> <span class="n">n_channels</span></pre></div>
</div>
</div>
<div class='section' id='section-74'>
<div class='docs'>
<div class='section-link'>
<a href='#section-74'>#</a>
</div>
<p>For each resolution</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">340</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_resolutions</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-75'>
<div class='docs'>
<div class='section-link'>
<a href='#section-75'>#</a>
</div>
<p>Number of output channels at this resolution</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">342</span> <span class="n">out_channels</span> <span class="o">=</span> <span class="n">in_channels</span> <span class="o">*</span> <span class="n">ch_mults</span><span class="p">[</span><span class="n">i</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-76'>
<div class='docs'>
<div class='section-link'>
<a href='#section-76'>#</a>
</div>
<p>Add <code class="highlight"><span></span><span class="n">n_blocks</span></code>
down blocks</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">344</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_blocks</span><span class="p">):</span>
<span class="lineno">345</span> <span class="n">down</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">DownBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">n_channels</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span> <span class="n">is_attn</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span>
<span class="lineno">346</span> <span class="n">in_channels</span> <span class="o">=</span> <span class="n">out_channels</span></pre></div>
</div>
</div>
<div class='section' id='section-77'>
<div class='docs'>
<div class='section-link'>
<a href='#section-77'>#</a>
</div>
<p>Downsample at all resolutions except the last</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">348</span> <span class="k">if</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">n_resolutions</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span>
<span class="lineno">349</span> <span class="n">down</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">Downsample</span><span class="p">(</span><span class="n">in_channels</span><span class="p">))</span></pre></div>
</div>
</div>
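<p>To make the loop above concrete, the following sketch replays it without building any layers and lists the modules that end up in <code>self.down</code>, with their input and output channels and the spatial size they produce for a 32×32 input. The function name <code>describe_down_path</code> and the tuple layout are hypothetical; the size update uses the <code>Conv2d</code> output-size formula for kernel 3, stride 2, padding 1.</p>

```python
from typing import List, Sequence, Tuple


def describe_down_path(n_channels: int, ch_mults: Sequence[int], is_attn: Sequence[bool],
                       n_blocks: int, image_size: int) -> List[Tuple[str, int, int, int]]:
    """Replay the down-path loop, returning (module, in_channels, out_channels, size) tuples."""
    modules = []
    in_channels = n_channels
    size = image_size
    n_resolutions = len(ch_mults)
    for i in range(n_resolutions):
        out_channels = in_channels * ch_mults[i]
        for _ in range(n_blocks):
            name = "DownBlock+attn" if is_attn[i] else "DownBlock"
            modules.append((name, in_channels, out_channels, size))
            in_channels = out_channels
        # Downsample keeps the channel count; Conv2d(k=3, s=2, p=1) gives ceil(size / 2)
        if i < n_resolutions - 1:
            size = (size + 1) // 2
            modules.append(("Downsample", in_channels, in_channels, size))
    return modules


for row in describe_down_path(64, (1, 2, 2, 4), (False, False, True, True), 2, 32):
    print(row)
```

<p>With the default <code>is_attn</code>, attention is only used at the two lowest resolutions (8×8 and 4×4 for a 32×32 input), where self-attention over all pixel positions, whose cost grows quadratically with the number of positions, is cheapest.</p>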
<div class='section' id='section-78'>
<div class='docs'>
<div class='section-link'>
<a href='#section-78'>#</a>
</div>
<p>Combine the set of modules</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">352</span> <span class="bp">self</span><span class="o">.</span><span class="n">down</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">(</span><span class="n">down</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-79'>
<div class='docs'>
<div class='section-link'>
<a href='#section-79'>#</a>
</div>
<p>Middle block</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">355</span> <span class="bp">self</span><span class="o">.</span><span class="n">middle</span> <span class="o">=</span> <span class="n">MiddleBlock</span><span class="p">(</span><span class="n">out_channels</span><span class="p">,</span> <span class="n">n_channels</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span> <span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-80'>
<div class='docs'>
<div class='section-link'>
<a href='#section-80'>#</a>
</div>
<h4>Second half of U-Net: increasing resolution</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">358</span> <span class="n">up</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-81'>
<div class='docs'>
<div class='section-link'>
<a href='#section-81'>#</a>
</div>
<p>Number of channels</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">360</span> <span class="n">in_channels</span> <span class="o">=</span> <span class="n">out_channels</span></pre></div>
</div>
</div>
<div class='section' id='section-82'>
<div class='docs'>
<div class='section-link'>
<a href='#section-82'>#</a>
</div>
<p>For each resolution</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">362</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">n_resolutions</span><span class="p">)):</span></pre></div>
</div>
</div>
<div class='section' id='section-83'>
<div class='docs'>
<div class='section-link'>
<a href='#section-83'>#</a>
</div>
<p><code class="highlight"><span></span><span class="n">n_blocks</span></code>
up blocks at the same resolution, keeping the number of channels</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">364</span> <span class="n">out_channels</span> <span class="o">=</span> <span class="n">in_channels</span>
<span class="lineno">365</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_blocks</span><span class="p">):</span>
<span class="lineno">366</span> <span class="n">up</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">UpBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">n_channels</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span> <span class="n">is_attn</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span></pre></div>
</div>
</div>
<div class='section' id='section-84'>
<div class='docs'>
<div class='section-link'>
<a href='#section-84'>#</a>
</div>
<p>Final block to reduce the number of channels</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">368</span> <span class="n">out_channels</span> <span class="o">=</span> <span class="n">in_channels</span> <span class="o">//</span> <span class="n">ch_mults</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="lineno">369</span> <span class="n">up</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">UpBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">n_channels</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span> <span class="n">is_attn</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span>
<span class="lineno">370</span> <span class="n">in_channels</span> <span class="o">=</span> <span class="n">out_channels</span></pre></div>
</div>
</div>
<div class='section' id='section-85'>
<div class='docs'>
<div class='section-link'>
<a href='#section-85'>#</a>
</div>
<p>Upsample at all resolutions except the last</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">372</span> <span class="k">if</span> <span class="n">i</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">373</span>     <span class="n">up</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">Upsample</span><span class="p">(</span><span class="n">in_channels</span><span class="p">))</span></pre></div>
</div>
</div>
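<div class='section'>
<div class='docs'>
<p>The channel arithmetic of the loop above can be traced without building any layers. The sketch below is a hypothetical helper, <code>up_schedule</code>, that records each up-path module as a <code>(kind, in_channels, out_channels)</code> tuple. It assumes the down path (not shown in this section) multiplied the channel count by every entry of <code>ch_mults</code>, so the bottom of the U-Net has <code>n_channels * prod(ch_mults)</code> channels, and that every division is exact. The trace also shows why <code>self.final</code> can take <code>in_channels</code>: after the last resolution it is back to <code>n_channels</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><code class="language-python">def up_schedule(n_channels, ch_mults, n_blocks):
    """Hypothetical helper: list the up-path modules as (kind, in_ch, out_ch)."""
    n_resolutions = len(ch_mults)
    # Assumed channel count at the bottom: the down path multiplied by every ch_mult
    in_channels = n_channels
    for mult in ch_mults:
        in_channels *= mult
    schedule = []
    # Visit resolutions from the lowest back up to the input resolution
    for i in reversed(range(n_resolutions)):
        # n_blocks blocks at the same resolution keep the channel count
        out_channels = in_channels
        for _ in range(n_blocks):
            schedule.append(("UpBlock", in_channels, out_channels))
        # The final block at this resolution divides by its channel multiplier
        out_channels = in_channels // ch_mults[i]
        schedule.append(("UpBlock", in_channels, out_channels))
        in_channels = out_channels
        # Upsample at every resolution except the last one (i == 0)
        if i != 0:
            schedule.append(("Upsample", in_channels, in_channels))
    return schedule


schedule = up_schedule(64, (1, 2, 2, 4), 2)
# 15 entries: 12 UpBlocks and 3 Upsamples; the last entry is ("UpBlock", 64, 64)
</code></pre></div>
</div>
</div>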
<div class='section' id='section-86'>
<div class='docs'>
<div class='section-link'>
<a href='#section-86'>#</a>
</div>
<p>Combine the set of modules</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">376</span> <span class="bp">self</span><span class="o">.</span><span class="n">up</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">(</span><span class="n">up</span><span class="p">)</span></pre></div>
</div>
</div>
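<div class='section'>
<div class='docs'>
<p>Wrapping <code>up</code> in <code>nn.ModuleList</code>, instead of keeping the plain Python list, is what registers the blocks as submodules: only registered submodules have their parameters returned by <code>.parameters()</code> (and therefore seen by the optimizer), moved by <code>.to(device)</code>, and saved in <code>state_dict()</code>. A minimal comparison, with two small <code>nn.Linear</code> layers standing in for the U-Net blocks; the class names are illustrative.</p>
</div>
<div class='code'>
<div class="highlight"><pre><code class="language-python">import torch.nn as nn


class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list: nn.Module does not look inside it
        self.layers = [nn.Linear(4, 4), nn.Linear(4, 4)]


class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers every entry as a submodule
        self.layers = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])


def n_params(module):
    # Total number of parameters PyTorch can see (and an optimizer would update)
    return sum(p.numel() for p in module.parameters())


print(n_params(PlainList()), n_params(Registered()))  # 0 40
</code></pre></div>
</div>
</div>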
<div class='section' id='section-87'>
<div class='docs'>
<div class='section-link'>
<a href='#section-87'>#</a>
</div>
<p>Final normalization and convolution layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">379</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">)</span>
<span class="lineno">380</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span> <span class="o">=</span> <span class="n">Swish</span><span class="p">()</span>
<span class="lineno">381</span> <span class="bp">self</span><span class="o">.</span><span class="n">final</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">image_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
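<div class='section'>
<div class='docs'>
<p>The final layers use Swish, x·σ(x) (the same function PyTorch provides as <code>nn.SiLU</code>), and <code>nn.GroupNorm(8, n_channels)</code>, which splits the channels into 8 groups and normalizes each group with its own mean and variance, taken over the group's channels and spatial positions; <code>n_channels</code> therefore has to be divisible by 8. The dependency-free sketch below spells out both for a single sample. The function names and the list-of-lists layout are illustrative assumptions, not part of the model code.</p>
</div>
<div class='code'>
<div class="highlight"><pre><code class="language-python">import math


def swish(x):
    # Swish / SiLU: x * sigmoid(x), with sigmoid written via tanh to avoid overflow
    return x * 0.5 * (1.0 + math.tanh(0.5 * x))


def group_norm(channels, num_groups, eps=1e-5):
    """Group-normalize one sample given as a list of per-channel value lists.

    The learnable per-channel scale and shift are omitted; PyTorch initializes
    them to 1 and 0, so this should match a fresh nn.GroupNorm up to rounding."""
    n = len(channels)
    if n % num_groups != 0:
        raise ValueError("the number of channels must be divisible by num_groups")
    size = n // num_groups
    out = []
    for g in range(num_groups):
        group = channels[g * size:(g + 1) * size]
        # Statistics are shared by all channels and positions in the group
        values = [v for ch in group for v in ch]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)  # biased variance
        for ch in group:
            out.append([(v - mean) / math.sqrt(var + eps) for v in ch])
    return out
</code></pre></div>
</div>
</div>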
<div class='section' id='section-88'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-88'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code>
</li>
<li><code class="highlight"><span></span><span class="n">t</span></code>
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">]</span></code> (one time step per sample)
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">383</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-89'>
<div class='docs'>
<div class='section-link'>
<a href='#section-89'>#</a>
</div>
<p>Get time-step embeddings</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">390</span> <span class="n">t</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_emb</span><span class="p">(</span><span class="n">t</span><span class="p">)</span></pre></div>
</div>
</div>
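<div class='section'>
<div class='docs'>
<p><code>self.time_emb</code> is defined outside this section. DDPM-style U-Nets usually start from a Transformer-style sinusoidal encoding of the integer step <code>t</code> and pass it through a small MLP; assuming that is what happens here, the result is the time embedding handed to every block (the <code>n_channels * 4</code> passed to each <code>UpBlock</code> above appears to be its width). A sketch of the sinusoidal part only, for a single step; the function name, the base of 10000 and the sines-then-cosines layout are assumptions.</p>
</div>
<div class='code'>
<div class="highlight"><pre><code class="language-python">import math


def sinusoidal_embedding(t, dim):
    """Encode an integer time step t as a `dim`-dimensional vector (dim even, at least 4).

    The first half holds sines and the second half cosines, at frequencies
    spaced geometrically from 1 down to 1/10000."""
    half = dim // 2
    scale = math.log(10000) / (half - 1)
    freqs = [math.exp(-scale * k) for k in range(half)]
    args = [t * f for f in freqs]
    return [math.sin(a) for a in args] + [math.cos(a) for a in args]


emb = sinusoidal_embedding(0, 8)  # t = 0: all sines are 0.0, all cosines are 1.0
</code></pre></div>
</div>
</div>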
<div class='section' id='section-90'>
<div class='docs'>
<div class='section-link'>
<a href='#section-90'>#</a>
</div>
<p>Get image projection</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">393</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
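<div class='section'>
<div class='docs'>
<p><code>self.image_proj</code> is also defined outside this section. Assuming it is a 3×3 convolution with padding 1, like <code>self.final</code> above, it changes only the channel count (from <code>image_channels</code> to <code>n_channels</code>) and keeps the height and width, so the first entry of <code>h</code> and the final convolution line up with the input image. The general output-size formula from the PyTorch <code>Conv2d</code> documentation, written as a tiny hypothetical helper:</p>
</div>
<div class='code'>
<div class="highlight"><pre><code class="language-python">def conv2d_out(size, kernel=3, padding=1, stride=1, dilation=1):
    # Output height (or width) of a convolution, per the PyTorch Conv2d docs
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


print(conv2d_out(32))             # 3x3 with padding 1 keeps the size: 32
print(conv2d_out(32, padding=0))  # 3x3 without padding loses one pixel per side: 30
print(conv2d_out(32, stride=2))   # stride 2 halves the size: 16
</code></pre></div>
</div>
</div>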
<div class='section' id='section-91'>
<div class='docs'>
<div class='section-link'>
<a href='#section-91'>#</a>
</div>
<p><code class="highlight"><span></span><span class="n">h</span></code>
will store the output at each resolution for the skip connections</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">396</span> <span class="n">h</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-92'>
<div class='docs'>
<div class='section-link'>
<a href='#section-92'>#</a>
</div>
<p>First half of U-Net</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">398</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">down</span><span class="p">:</span>
<span class="lineno">399</span>     <span class="n">x</span> <span class="o">=</span> <span class="n">m</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="lineno">400</span>     <span class="n">h</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
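<div class='section'>
<div class='docs'>
<p>Because <code>h</code> is used as a stack, the second half of the U-Net consumes the stored outputs in reverse order. The trace below, a hypothetical helper with string labels in place of tensors, assumes the down path (not shown in this section) mirrors the up path: <code>n_blocks</code> blocks per resolution, followed by a downsample at every resolution except the last. Under that assumption each of the <code>n_blocks + 1</code> up blocks per resolution pops exactly one entry, every block at resolution <code>i</code> receives an output of the same spatial size (the output of <code>downsample{i}</code> already has the size of resolution <code>i + 1</code>), and <code>h</code> is empty when the loop ends.</p>
</div>
<div class='code'>
<div class="highlight"><pre><code class="language-python">def trace_skips(n_resolutions, n_blocks):
    """Hypothetical trace of which stored output each up-path block pops from h."""
    h = ["proj"]  # h = [x] starts with the image projection
    for i in range(n_resolutions):
        for b in range(n_blocks):
            h.append(f"down{i}.{b}")
        # Assumed: a downsample after every resolution except the last
        if i != n_resolutions - 1:
            h.append(f"downsample{i}")
    pops = []
    for i in reversed(range(n_resolutions)):
        # n_blocks + 1 UpBlocks per resolution; Upsample modules pop nothing
        for _ in range(n_blocks + 1):
            pops.append((i, h.pop()))
    return pops, h


pops, leftover = trace_skips(2, 1)
# pops: [(1, 'down1.0'), (1, 'downsample0'), (0, 'down0.0'), (0, 'proj')]; leftover: []
</code></pre></div>
</div>
</div>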
<div class='section' id='section-93'>
<div class='docs'>
<div class='section-link'>
<a href='#section-93'>#</a>
</div>
<p>Middle (bottom)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">403</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">middle</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-94'>
<div class='docs'>
<div class='section-link'>
<a href='#section-94'>#</a>
</div>
<p>Second half of U-Net</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">406</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">up</span><span class="p">:</span>
<span class="lineno">407</span>     <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">Upsample</span><span class="p">):</span>
<span class="lineno">408</span>         <span class="n">x</span> <span class="o">=</span> <span class="n">m</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="lineno">409</span>     <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-95'>
<div class='docs'>
<div class='section-link'>
<a href='#section-95'>#</a>
</div>
<p>Get the skip connection from the first half of U-Net and concatenate</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">411</span> <span class="n">s</span> <span class="o">=</span> <span class="n">h</span><span class="o">.</span><span class="n">pop</span><span class="p">()</span>
<span class="lineno">412</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">x</span><span class="p">,</span> <span class="n">s</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-96'>
<div class='docs'>
<div class='section-link'>
<a href='#section-96'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">414</span> <span class="n">x</span> <span class="o">=</span> <span class="n">m</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></pre></div>
</div>
</div>
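<div class='section'>
<div class='docs'>
<p>The concatenation joins the skip tensor to the current features along <code>dim=1</code>, the channel axis of the <code>[batch_size, channels, height, width]</code> layout, so the block that follows sees the sum of both channel counts. Batch size, height and width must already agree, which is why <code>Upsample</code> runs before the blocks of the next resolution. A small check with made-up sizes:</p>
</div>
<div class='code'>
<div class="highlight"><pre><code class="language-python">import torch

x = torch.randn(2, 256, 8, 8)  # current up-path features (illustrative sizes)
s = torch.randn(2, 128, 8, 8)  # skip connection popped from h
# Concatenate along the channel dimension
y = torch.cat((x, s), dim=1)
print(y.shape)  # torch.Size([2, 384, 8, 8])

# Mismatched spatial sizes cannot be concatenated
try:
    torch.cat((x, torch.randn(2, 128, 16, 16)), dim=1)
except RuntimeError as err:
    print("RuntimeError:", err)
</code></pre></div>
</div>
</div>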
<div class='section' id='section-97'>
<div class='docs'>
<div class='section-link'>
<a href='#section-97'>#</a>
</div>
<p>Final normalization and convolution</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">417</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">final</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">act</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span></pre></div>
</div>
</div>
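<div class='section'>
<div class='docs'>
<p>Putting the head together: the output has <code>image_channels</code> channels and the same height and width as the input <code>x</code>, which is what DDPM needs, since the network predicts the noise added to the image at step <code>t</code>. A standalone sketch with illustrative sizes (64 feature channels, 3-channel 32×32 images), using <code>nn.SiLU</code> in place of the <code>Swish</code> module used above (both compute x·σ(x)):</p>
</div>
<div class='code'>
<div class="highlight"><pre><code class="language-python">import torch
import torch.nn as nn

n_channels, image_channels = 64, 3
norm = nn.GroupNorm(8, n_channels)
act = nn.SiLU()  # x * sigmoid(x), i.e. Swish
final = nn.Conv2d(n_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))

features = torch.randn(4, n_channels, 32, 32)  # stand-in for the last up block's output
out = final(act(norm(features)))
print(out.shape)  # torch.Size([4, 3, 32, 32])
</code></pre></div>
</div>
</div>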
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>