2024-06-21 19:35:22 +05:30

<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="UNet model for Denoising Diffusion Probabilistic Models (DDPM)"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="U-Net model for Denoising Diffusion Probabilistic Models (DDPM)"/>
<meta name="twitter:description" content="UNet model for Denoising Diffusion Probabilistic Models (DDPM)"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/diffusion/ddpm/unet.html"/>
<meta property="og:title" content="U-Net model for Denoising Diffusion Probabilistic Models (DDPM)"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="U-Net model for Denoising Diffusion Probabilistic Models (DDPM)"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="UNet model for Denoising Diffusion Probabilistic Models (DDPM)"/>
<title>U-Net model for Denoising Diffusion Probabilistic Models (DDPM)</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/diffusion/ddpm/unet.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">diffusion</a>
<a class="parent" href="index.html">ddpm</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/diffusion/ddpm/unet.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>U-Net model for <a href="index.html">Denoising Diffusion Probabilistic Models (DDPM)</a></h1>
<p>This is a <a href="../../unet/index.html">U-Net</a> based model to predict noise <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord" style="color:lightgreen"><span class="mord mathnormal" style="">ϵ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">θ</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2805559999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqj" style=""><span class="mord mathnormal mtight" style="">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">t</span></span><span class="mclose">)</span></span></span></span></span>.</p>
<p>U-Net gets its name from the U shape in the model diagram. It processes a given image by progressively lowering (halving) the feature map resolution and then increasing the resolution again. There are pass-through connections at each resolution.</p>
<p><img alt="U-Net diagram from paper" src="../../unet/unet.png"></p>
<p>This implementation makes several modifications to the original U-Net (residual blocks, multi-head attention) and also adds time-step embeddings <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.61508em;vertical-align:0em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span></span>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">24</span><span></span><span class="kn">import</span> <span class="nn">math</span>
<span class="lineno">25</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">List</span>
<span class="lineno">26</span>
<span class="lineno">27</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">29</span>
<span class="lineno">30</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
</div>
</div>
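<div class='section'>
<div class='docs'>
<p>As a rough illustration of the down-then-up flow described above, the sketch below tracks only tensor shapes. Average pooling and nearest-neighbour upsampling stand in for the learned layers, and the name <code>unet_shape_sketch</code> and the choice of three levels are assumptions made for this example, not part of the implementation. The feature map is stored before each halving and concatenated back on the way up, which is why the channel count grows.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
import torch.nn.functional as F


def unet_shape_sketch(x: torch.Tensor, n_levels: int = 3):
    """Toy down/up pass with no learned layers; only the shapes matter here."""
    skips = []
    h = x
    # Down path: remember the feature map, then halve height and width
    for _ in range(n_levels):
        skips.append(h)
        h = F.avg_pool2d(h, kernel_size=2)
    bottom = tuple(h.shape[-2:])
    # Up path: double the resolution, then concatenate the stored map
    # from the matching resolution along the channel dimension
    for _ in range(n_levels):
        h = F.interpolate(h, scale_factor=2, mode="nearest")
        h = torch.cat([h, skips.pop()], dim=1)
    return h, bottom


x = torch.randn(1, 3, 32, 32)
out, bottom = unet_shape_sketch(x)
print(bottom, tuple(out.shape))
```
</pre></div>
</div>
</div>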
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h3>Swish activation function</h3>
<p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.44445em;vertical-align:0em;"></span><span class="mord mathnormal">x</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">33</span><span class="k">class</span> <span class="nc">Swish</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">40</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="lineno">41</span> <span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
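<div class='section'>
<div class='docs'>
<p>PyTorch provides the same function under the name SiLU. The short check below, written as a stand-alone function rather than the module above, compares the two on a few sample values; the helper name <code>swish</code> is only for this example.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
import torch.nn.functional as F


def swish(x: torch.Tensor) -> torch.Tensor:
    # x * sigmoid(x), the same expression as the forward method above
    return x * torch.sigmoid(x)


x = torch.linspace(-4.0, 4.0, steps=9)
y = swish(x)
# Compare against the built-in SiLU on the same inputs
same_as_silu = torch.allclose(y, F.silu(x), atol=1e-6)
print(same_as_silu)
```
</pre></div>
</div>
</div>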
<div class='section' id='section-3'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<h3>Embeddings for <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.61508em;vertical-align:0em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span></span></h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">44</span><span class="k">class</span> <span class="nc">TimeEmbedding</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">n_channels</span></code>
is the number of dimensions in the embedding</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">53</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">54</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_channels</span> <span class="o">=</span> <span class="n">n_channels</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>First linear layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="bp">self</span><span class="o">.</span><span class="n">lin1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_channels</span> <span class="o">//</span> <span class="mi">4</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_channels</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Activation </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">58</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span> <span class="o">=</span> <span class="n">Swish</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Second linear layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">60</span> <span class="bp">self</span><span class="o">.</span><span class="n">lin2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_channels</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_channels</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Create sinusoidal position embeddings <a href="../../transformers/positional_encoding.html">same as those from the transformer</a></p>
<span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:6.600059999999999em;vertical-align:-3.0500299999999996em;"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:3.5500299999999996em;"><span style="top:-5.55003em;"><span class="pstrut" style="height:3.75em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">P</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0448em;"><span style="top:-2.4231360000000004em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqj" style=""><span class="mord mathnormal mtight" style="">t</span></span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">i</span></span></span></span><span style="top:-3.2198em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight">1</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.412972em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.2500000000000004em;"><span class="pstrut" style="height:3.75em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">P</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span 
class="vlist-r"><span class="vlist" style="height:1.0448em;"><span style="top:-2.4231360000000004em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqj" style=""><span class="mord mathnormal mtight" style="">t</span></span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">i</span></span></span></span><span style="top:-3.2198em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight">2</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.412972em;"><span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:3.0500299999999996em;"><span></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:3.5500299999999996em;"><span style="top:-5.55003em;"><span class="pstrut" style="height:3.75em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord mathnormal">s</span><span class="mord mathnormal">in</span><span class="mord"><span class="delimsizing size4">(</span></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.29208em;"><span style="top:-2.121225em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord">1000</span><span 
class="mord"><span class="mord">0</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9887749999999998em;"><span style="top:-3.3902150000000004em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8550857142857142em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqi" style=""><span class="mord mathnormal mtight" style="">d</span></span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span><span style="top:-3.2255em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em;"></span></span><span style="top:-3.384em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.40352142857142853em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span><span class="vlist-s"></span></span><span 
class="vlist-r"><span class="vlist" style="height:0.8787749999999999em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord"><span class="delimsizing size4">)</span></span></span></span><span style="top:-2.2500000000000004em;"><span class="pstrut" style="height:3.75em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord mathnormal">cos</span><span class="mord"><span class="delimsizing size4">(</span></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.29208em;"><span style="top:-2.121225em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord">1000</span><span class="mord"><span class="mord">0</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9887749999999998em;"><span style="top:-3.3902150000000004em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8550857142857142em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqi" style=""><span class="mord mathnormal mtight" style="">d</span></span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span><span style="top:-3.2255em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" 
style="border-bottom-width:0.049em;"></span></span><span style="top:-3.384em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.40352142857142853em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span></span></span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.8787749999999999em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord"><span class="delimsizing size4">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:3.0500299999999996em;"><span></span></span></span></span></span></span></span></span></span></span></span></span><p>where <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqi" style=""><span class="mord mathnormal" style="">d</span></span></span></span></span></span> is <code class="highlight"><span></span><span class="n">half_dim</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">72</span> <span class="n">half_dim</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_channels</span> <span class="o">//</span> <span class="mi">8</span>
<span class="lineno">73</span> <span class="n">emb</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="mi">10_000</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">half_dim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="lineno">74</span> <span class="n">emb</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">half_dim</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">t</span><span class="o">.</span><span class="n">device</span><span class="p">)</span> <span class="o">*</span> <span class="o">-</span><span class="n">emb</span><span class="p">)</span>
<span class="lineno">75</span> <span class="n">emb</span> <span class="o">=</span> <span class="n">t</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">emb</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
<span class="lineno">76</span> <span class="n">emb</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">emb</span><span class="o">.</span><span class="n">sin</span><span class="p">(),</span> <span class="n">emb</span><span class="o">.</span><span class="n">cos</span><span class="p">()),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
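<div class='section'>
<div class='docs'>
<p>The exponential-of-logarithm form in the code computes the same frequencies as the power in the formula. The sketch below repeats the four lines above as a stand-alone function and compares the frequencies with a direct evaluation of the power; the names <code>sinusoidal_features</code>, <code>freqs</code> and <code>direct</code>, and the value <code>half_dim=4</code>, are assumptions chosen for this example. At time step zero every sine feature is 0 and every cosine feature is 1, which makes a handy sanity check.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import math

import torch


def sinusoidal_features(t: torch.Tensor, half_dim: int) -> torch.Tensor:
    # Frequencies 10000^(-i / (half_dim - 1)) for i = 0 .. half_dim - 1,
    # computed through exp and log as in the code above
    scale = math.log(10_000) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
    # Every time step multiplied by every frequency: [batch, half_dim]
    args = t[:, None] * freqs[None, :]
    # Sine features followed by cosine features: [batch, 2 * half_dim]
    return torch.cat((args.sin(), args.cos()), dim=1)


t = torch.tensor([0, 1, 500])
feats = sinusoidal_features(t, half_dim=4)

# The same frequencies, written once with exp/log and once as a power
freqs = torch.exp(torch.arange(4) * -(math.log(10_000) / 3))
direct = 10_000.0 ** (-torch.arange(4) / 3)
print(tuple(feats.shape), torch.allclose(freqs, direct, rtol=1e-4))
```
</pre></div>
</div>
</div>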
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Transform with the MLP </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span> <span class="n">emb</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lin1</span><span class="p">(</span><span class="n">emb</span><span class="p">))</span>
<span class="lineno">80</span> <span class="n">emb</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lin2</span><span class="p">(</span><span class="n">emb</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">83</span> <span class="k">return</span> <span class="n">emb</span></pre></div>
</div>
</div>
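<div class='section'>
<div class='docs'>
<p>A usage sketch for the module as a whole follows. It restates the class in a self-contained form, with <code>nn.SiLU</code> substituted for <code>Swish</code>, so that it runs on its own; the class name <code>TimeEmbeddingSketch</code>, the value <code>n_channels=256</code> and the range of time steps are assumptions for illustration. Two constraints follow from the integer divisions: <code>n_channels</code> should be a multiple of 8 so that the concatenated sine and cosine features have exactly <code>n_channels // 4</code> entries, and it should be at least 16 so that <code>half_dim - 1</code> is not zero.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import math

import torch
from torch import nn


class TimeEmbeddingSketch(nn.Module):
    """Stand-alone restatement of the module above, with nn.SiLU in place of Swish."""

    def __init__(self, n_channels: int):
        super().__init__()
        self.n_channels = n_channels
        self.lin1 = nn.Linear(n_channels // 4, n_channels)
        self.act = nn.SiLU()
        self.lin2 = nn.Linear(n_channels, n_channels)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half_dim = self.n_channels // 8
        scale = math.log(10_000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
        emb = t[:, None] * freqs[None, :]
        # Sine and cosine halves together give n_channels // 4 features
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)
        # Two-layer MLP maps the features to n_channels
        return self.lin2(self.act(self.lin1(emb)))


time_emb = TimeEmbeddingSketch(n_channels=256)
t = torch.randint(0, 1000, (8,))  # one integer time step per sample
out = time_emb(t)
print(tuple(out.shape))
```
</pre></div>
</div>
</div>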
<div class='section' id='section-13'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<h3>Residual block</h3>
<p>A residual block has two convolution layers with group normalization. Each resolution is processed with two residual blocks.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">86</span><span class="k">class</span> <span class="nc">ResidualBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">in_channels</span></code>
is the number of input channels </li>
<li><code class="highlight"><span></span><span class="n">out_channels</span></code>
is the number of output channels </li>
<li><code class="highlight"><span></span><span class="n">time_channels</span></code>
is the number of channels in the time step (<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.61508em;vertical-align:0em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="">t</span></span></span></span></span></span>) embeddings </li>
<li><code class="highlight"><span></span><span class="n">n_groups</span></code>
is the number of groups for <a href="../../normalization/group_norm/index.html">group normalization</a> </li>
<li><code class="highlight"><span></span><span class="n">dropout</span></code>
is the dropout rate</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">94</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">95</span> <span class="n">n_groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">103</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Group normalization and the first convolution layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">105</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">(</span><span class="n">n_groups</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">)</span>
<span class="lineno">106</span> <span class="bp">self</span><span class="o">.</span><span class="n">act1</span> <span class="o">=</span> <span class="n">Swish</span><span class="p">()</span>
<span class="lineno">107</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Group normalization and the second convolution layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">(</span><span class="n">n_groups</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">)</span>
<span class="lineno">111</span> <span class="bp">self</span><span class="o">.</span><span class="n">act2</span> <span class="o">=</span> <span class="n">Swish</span><span class="p">()</span>
<span class="lineno">112</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">out_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>If the number of input channels differs from the number of output channels, the shortcut connection is projected with a 1&times;1 convolution so that it can be added to the block output; otherwise it is the identity </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="k">if</span> <span class="n">in_channels</span> <span class="o">!=</span> <span class="n">out_channels</span><span class="p">:</span>
<span class="lineno">117</span> <span class="bp">self</span><span class="o">.</span><span class="n">shortcut</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="lineno">118</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">119</span> <span class="bp">self</span><span class="o">.</span><span class="n">shortcut</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span></pre></div>
</div>
</div>
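<div class='section'>
<div class='docs'>
<p>A minimal sketch of the shortcut rule above, run on its own. The channel counts and the input size are made-up values chosen only for illustration; the point is that a 1&times;1 convolution changes the number of channels while leaving height and width untouched. </p>
</div>
<div class='code'>

```python
import torch
from torch import nn

# Hypothetical sizes, chosen only for illustration
in_channels, out_channels = 8, 16
x = torch.randn(2, in_channels, 4, 4)

# Same rule as above: project only when the channel counts differ
if in_channels != out_channels:
    shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
else:
    shortcut = nn.Identity()

# The 1x1 convolution changes channels but keeps height and width
projected = shortcut(x)
projected_shape = tuple(projected.shape)
```

</div>
</div>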
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Linear layer and activation for time embeddings, and the dropout layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_emb</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">time_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">)</span>
<span class="lineno">123</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_act</span> <span class="o">=</span> <span class="n">Swish</span><span class="p">()</span>
<span class="lineno">124</span>
<span class="lineno">125</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code>
</li>
<li><code class="highlight"><span></span><span class="n">t</span></code>
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">]</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">127</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>First convolution layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">133</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">act1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">norm1</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Add time embeddings, broadcast over the height and width dimensions </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">135</span> <span class="n">h</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_emb</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">time_act</span><span class="p">(</span><span class="n">t</span><span class="p">))[:,</span> <span class="p">:,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span></pre></div>
</div>
</div>
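<div class='section'>
<div class='docs'>
<p>The indexing <code>[:, :, None, None]</code> is what lets a per-sample embedding be added to a feature map. The sketch below assumes small illustrative sizes and leaves out the activation on <code>t</code> to keep the focus on shapes; it is not the author's code. </p>
</div>
<div class='code'>

```python
import torch
from torch import nn

# Illustrative sizes (assumptions, not values from the model)
batch_size, time_channels, out_channels = 2, 32, 16
h = torch.zeros(batch_size, out_channels, 4, 4)
t = torch.randn(batch_size, time_channels)

time_emb = nn.Linear(time_channels, out_channels)
emb = time_emb(t)               # [batch_size, out_channels]
emb_4d = emb[:, :, None, None]  # [batch_size, out_channels, 1, 1]

# Broadcasting copies the same embedding to every spatial position
out = h + emb_4d
same_everywhere = torch.equal(out[:, :, 0, 0], out[:, :, 3, 3])
```

</div>
</div>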
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Second convolution layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">137</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">act2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">norm2</span><span class="p">(</span><span class="n">h</span><span class="p">))))</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Add the shortcut connection and return </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">140</span> <span class="k">return</span> <span class="n">h</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shortcut</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
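<div class='section'>
<div class='docs'>
<p>To see the whole residual block run end to end, here is a compact restatement under some assumptions: the class name <code>ResidualBlockSketch</code> is hypothetical, <code>nn.SiLU</code> stands in for <code>Swish</code> (both compute <code>x * sigmoid(x)</code>), and the group count and tensor sizes are small illustrative values. </p>
</div>
<div class='code'>

```python
import torch
from torch import nn


class ResidualBlockSketch(nn.Module):
    """Hypothetical compact restatement of the block above."""

    def __init__(self, in_channels, out_channels, time_channels,
                 n_groups=4, dropout=0.1):
        super().__init__()
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()
        self.time_emb = nn.Linear(time_channels, out_channels)
        self.act = nn.SiLU()  # stand-in for Swish
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, t):
        h = self.conv1(self.act(self.norm1(x)))
        h = h + self.time_emb(self.act(t))[:, :, None, None]
        h = self.conv2(self.dropout(self.act(self.norm2(h))))
        return h + self.shortcut(x)


block = ResidualBlockSketch(in_channels=8, out_channels=16, time_channels=32)
x = torch.randn(2, 8, 10, 10)
t = torch.randn(2, 32)
out_shape = tuple(block(x, t).shape)
```

</div>
</div>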
<div class='section' id='section-25'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<h3>Attention block</h3>
<p>This is similar to <a href="../../transformers/mha.html">transformer multi-head attention</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">143</span><span class="k">class</span> <span class="nc">AttentionBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">n_channels</span></code>
is the number of channels in the input </li>
<li><code class="highlight"><span></span><span class="n">n_heads</span></code>
is the number of heads in multi-head attention </li>
<li><code class="highlight"><span></span><span class="n">d_k</span></code>
is the number of dimensions in each head </li>
<li><code class="highlight"><span></span><span class="n">n_groups</span></code>
is the number of groups for <a href="../../normalization/group_norm/index.html">group normalization</a></li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">150</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> <span class="n">d_k</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">n_groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">157</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>If <code class="highlight"><span></span><span class="n">d_k</span></code>
 is not given, default it to the number of input channels </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">160</span> <span class="k">if</span> <span class="n">d_k</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">161</span> <span class="n">d_k</span> <span class="o">=</span> <span class="n">n_channels</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Normalization layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">163</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">(</span><span class="n">n_groups</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Projections for query, key and values </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">165</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_channels</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span> <span class="o">*</span> <span class="mi">3</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Linear layer for final transformation </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">167</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Scale for dot-product attention </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">169</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">d_k</span> <span class="o">**</span> <span class="o">-</span><span class="mf">0.5</span></pre></div>
</div>
</div>
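<div class='section'>
<div class='docs'>
<p>One way to see why the scale is <code>d_k ** -0.5</code>: for vectors with independent unit-variance entries, the dot product has standard deviation close to the square root of <code>d_k</code>, and the scale brings it back to roughly one. The sketch below checks this empirically; the value of <code>d_k</code>, the sample count and the seed are arbitrary assumptions. </p>
</div>
<div class='code'>

```python
import torch

torch.manual_seed(0)  # fixed seed so the estimate is repeatable
d_k = 64              # illustrative head size
scale = d_k ** -0.5   # same expression as above

# Random queries and keys with unit-variance entries
q = torch.randn(10000, d_k)
k = torch.randn(10000, d_k)
dots = (q * k).sum(dim=-1)

raw_std = dots.std().item()                # close to sqrt(d_k)
scaled_std = (dots * scale).std().item()   # close to 1
```

</div>
</div>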
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">171</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">=</span> <span class="n">n_heads</span>
<span class="lineno">172</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">=</span> <span class="n">d_k</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code>
</li>
<li><code class="highlight"><span></span><span class="n">t</span></code>
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">]</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">174</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p><code class="highlight"><span></span><span class="n">t</span></code>
 is not used; it is kept in the arguments so that the attention layer&#x27;s function signature matches that of <code  class="highlight"><span></span><span class="n">ResidualBlock</span></code>
. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">181</span> <span class="n">_</span> <span class="o">=</span> <span class="n">t</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Get shape </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">183</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>Change <code class="highlight"><span></span><span class="n">x</span></code>
to shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">185</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
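<div class='section'>
<div class='docs'>
<p>The flattening step treats every pixel as one sequence element. A small sketch with made-up sizes shows where a given pixel ends up: position <code>(row, col)</code> becomes sequence index <code>row * width + col</code>. </p>
</div>
<div class='code'>

```python
import torch

# Illustrative sizes (assumptions)
batch_size, n_channels, height, width = 2, 3, 4, 5
x = torch.arange(batch_size * n_channels * height * width, dtype=torch.float32)
x = x.view(batch_size, n_channels, height, width)

# Same operation as above: flatten H and W, then move channels last
seq = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
seq_shape = tuple(seq.shape)

# Pixel (row, col) becomes sequence index row * width + col
row, col = 2, 3
matches = torch.equal(seq[:, row * width + col, :], x[:, :, row, col])
```

</div>
</div>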
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Get the concatenated query, key, and value vectors and reshape them to <code  class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">187</span> <span class="n">qkv</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="mi">3</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>Split query, key, and values. Each of them will have shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">189</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">chunk</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
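<div class='section'>
<div class='docs'>
<p>The projection, reshape and split can be traced on a toy input. The sizes below are illustrative assumptions; the sketch only confirms the shapes stated above and that the query is the first <code>d_k</code> entries of each head's slice. </p>
</div>
<div class='code'>

```python
import torch
from torch import nn

# Illustrative sizes (assumptions)
batch_size, seq, n_channels, n_heads, d_k = 2, 16, 8, 2, 4
x = torch.randn(batch_size, seq, n_channels)

# One linear layer produces query, key and value for every head
projection = nn.Linear(n_channels, n_heads * d_k * 3)
qkv = projection(x).view(batch_size, -1, n_heads, 3 * d_k)

# Split the last dimension into three equal parts
q, k, v = torch.chunk(qkv, 3, dim=-1)
q_is_first_slice = torch.equal(q, qkv[..., :d_k])
```

</div>
</div>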
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Calculate scaled dot-product $\frac{Q K^\top}{\sqrt{d_k}}$ </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;bihd,bjhd-&gt;bijh&#39;</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div>
</div>
</div>
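<div class='section'>
<div class='docs'>
<p>The einsum string <code>bihd,bjhd-&gt;bijh</code> computes, for each batch element and head, the dot product between query position <code>i</code> and key position <code>j</code>. As a cross-check, the sketch below compares it with an explicit per-head matrix product on random tensors of assumed sizes. </p>
</div>
<div class='code'>

```python
import torch

torch.manual_seed(0)
# Illustrative sizes (assumptions)
batch_size, seq, n_heads, d_k = 2, 5, 3, 4
q = torch.randn(batch_size, seq, n_heads, d_k)
k = torch.randn(batch_size, seq, n_heads, d_k)
scale = d_k ** -0.5

# Same expression as above
attn = torch.einsum('bihd,bjhd->bijh', q, k) * scale

# Equivalent form: bring heads forward, multiply q by k transposed, move heads back
q_h = q.permute(0, 2, 1, 3)  # [batch, heads, i, d_k]
k_h = k.permute(0, 2, 1, 3)  # [batch, heads, j, d_k]
reference = (q_h @ k_h.transpose(-1, -2)).permute(0, 2, 3, 1) * scale

agrees = torch.allclose(attn, reference, atol=1e-5)
```

</div>
</div>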
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Softmax along the sequence dimension $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$ </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">193</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
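<div class='section'>
<div class='docs'>
<p>With scores laid out as <code>[batch_size, i, j, n_heads]</code>, <code>dim=2</code> is the key position <code>j</code>, so each query position gets a distribution over the positions it attends to. A quick check on random scores of assumed shape: </p>
</div>
<div class='code'>

```python
import torch

torch.manual_seed(0)
# Random stand-in for the scaled scores, shape [batch, i, j, heads]
attn = torch.randn(2, 5, 5, 3)

# Normalize over the key positions j
weights = attn.softmax(dim=2)

# For every batch element, query position and head the weights sum to one
row_sums = weights.sum(dim=2)
sums_to_one = torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-6)
```

</div>
</div>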
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Multiply by values </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">195</span> <span class="n">res</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;bijh,bjhd-&gt;bihd&#39;</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div>
</div>
</div>
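<div class='section'>
<div class='docs'>
<p>The second einsum takes a weighted sum of value vectors over key positions <code>j</code>. A simple sanity check, sketched here with assumed sizes: if the attention weights are uniform, every output position should equal the mean of the values. </p>
</div>
<div class='code'>

```python
import torch

torch.manual_seed(0)
# Illustrative sizes (assumptions)
batch_size, seq, n_heads, d_k = 2, 5, 3, 4
v = torch.randn(batch_size, seq, n_heads, d_k)

# Uniform attention: each query attends equally to all key positions
attn = torch.full((batch_size, seq, seq, n_heads), 1.0 / seq)

# Same expression as above
res = torch.einsum('bijh,bjhd->bihd', attn, v)

# With uniform weights the result is the mean of the values at every position
expected = v.mean(dim=1, keepdim=True).expand(-1, seq, -1, -1)
is_mean = torch.allclose(res, expected, atol=1e-6)
```

</div>
</div>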
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>Reshape to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">res</span> <span class="o">=</span> <span class="n">res</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Transform to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">199</span> <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">res</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Add skip connection </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">202</span> <span class="n">res</span> <span class="o">+=</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p>Change to shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span> <span class="n">res</span> <span class="o">=</span> <span class="n">res</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">208</span> <span class="k">return</span> <span class="n">res</span></pre></div>
</div>
</div>
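<div class='section'>
<div class='docs'>
<p>The change of shape around the attention computation can be illustrated without tensors. The following is a minimal pure-Python sketch on nested lists, not the code used by the attention block. The helper names <code class="highlight"><span></span><span class="n">to_sequence</span></code> and <code class="highlight"><span></span><span class="n">to_image</span></code> are hypothetical, and the sketch assumes the spatial positions are flattened in row-major order, as a view of a contiguous tensor would do.</p>
<p>Each sequence element holds all channels of one pixel, and restoring the image shape recovers the original layout exactly.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def to_sequence(x):
    """[batch][channel][height][width] nested lists to [batch][seq][channel]."""
    out = []
    for sample in x:
        n_channels = len(sample)
        height, width = len(sample[0]), len(sample[0][0])
        # Row-major flattening: position (h, w) becomes sequence index h * width + w
        out.append([[sample[c][h][w] for c in range(n_channels)]
                    for h in range(height) for w in range(width)])
    return out


def to_image(seq_batch, height, width):
    """[batch][seq][channel] nested lists back to [batch][channel][height][width]."""
    out = []
    for seq in seq_batch:
        n_channels = len(seq[0])
        out.append([[[seq[h * width + w][c] for w in range(width)]
                     for h in range(height)]
                    for c in range(n_channels)])
    return out


# One sample, 2 channels, a 2 x 3 feature map (illustrative values)
x = [[[[1, 2, 3], [4, 5, 6]],
      [[7, 8, 9], [10, 11, 12]]]]
seq = to_sequence(x)
restored = to_image(seq, height=2, width=3)
```
</pre></div>
</div>
</div>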
<div class='section' id='section-48'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<h3>Down block</h3>
<p>This combines <code class="highlight"><span></span><span class="n">ResidualBlock</span></code>
and <code class="highlight"><span></span><span class="n">AttentionBlock</span></code>
. These are used in the first half of U-Net at each resolution.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">211</span><span class="k">class</span> <span class="nc">DownBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">218</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">has_attn</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span>
<span class="lineno">219</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">220</span> <span class="bp">self</span><span class="o">.</span><span class="n">res</span> <span class="o">=</span> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">)</span>
<span class="lineno">221</span> <span class="k">if</span> <span class="n">has_attn</span><span class="p">:</span>
<span class="lineno">222</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">AttentionBlock</span><span class="p">(</span><span class="n">out_channels</span><span class="p">)</span>
<span class="lineno">223</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">224</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">226</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">227</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">res</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="lineno">228</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">229</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<h3>Up block</h3>
<p>This combines <code class="highlight"><span></span><span class="n">ResidualBlock</span></code>
and <code class="highlight"><span></span><span class="n">AttentionBlock</span></code>
. These are used in the second half of U-Net at each resolution.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">232</span><span class="k">class</span> <span class="nc">UpBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">239</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">has_attn</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span>
<span class="lineno">240</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>The input has <code class="highlight"><span></span><span class="n">in_channels</span> <span class="o">+</span> <span class="n">out_channels</span></code>
channels, because the output of the same resolution from the first half of the U-Net is concatenated to it </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">243</span> <span class="bp">self</span><span class="o">.</span><span class="n">res</span> <span class="o">=</span> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span> <span class="o">+</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">)</span>
<span class="lineno">244</span> <span class="k">if</span> <span class="n">has_attn</span><span class="p">:</span>
<span class="lineno">245</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">AttentionBlock</span><span class="p">(</span><span class="n">out_channels</span><span class="p">)</span>
<span class="lineno">246</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">247</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">249</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">250</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">res</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="lineno">251</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">252</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
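<div class='section'>
<div class='docs'>
<p>The channel count expected by the residual block of the up block can be checked with simple shape arithmetic. The sketch below is an illustration only: <code class="highlight"><span></span><span class="n">concat_channels</span></code> is a hypothetical helper, the sizes are made up, and it assumes the stored first-half output is joined along the channel dimension before the block is called.</p>
<p>With these sizes the joined input has <code class="highlight"><span></span><span class="n">in_channels</span> <span class="o">+</span> <span class="n">out_channels</span></code> channels, which is what the residual block is constructed to accept.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def concat_channels(x_shape, skip_shape):
    """Shape after joining two (batch, channels, height, width) shapes on the channel axis."""
    batch, channels, height, width = x_shape
    skip_batch, skip_channels, skip_height, skip_width = skip_shape
    if (batch, height, width) != (skip_batch, skip_height, skip_width):
        raise ValueError("batch size and spatial size must match to concatenate")
    return (batch, channels + skip_channels, height, width)


# Hypothetical sizes: the up path carries 256 channels and the stored
# first-half output at this resolution also has 256 channels.
in_channels, out_channels = 256, 256
joined = concat_channels((8, in_channels, 16, 16), (8, out_channels, 16, 16))
```
</pre></div>
</div>
</div>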
<div class='section' id='section-55'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<h3>Middle block</h3>
<p>It combines a <code class="highlight"><span></span><span class="n">ResidualBlock</span></code>
, <code class="highlight"><span></span><span class="n">AttentionBlock</span></code>
, followed by another <code class="highlight"><span></span><span class="n">ResidualBlock</span></code>
. This block is applied at the lowest resolution of the U-Net.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">255</span><span class="k">class</span> <span class="nc">MiddleBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">263</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">264</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">265</span> <span class="bp">self</span><span class="o">.</span><span class="n">res1</span> <span class="o">=</span> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">n_channels</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">)</span>
<span class="lineno">266</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">AttentionBlock</span><span class="p">(</span><span class="n">n_channels</span><span class="p">)</span>
<span class="lineno">267</span> <span class="bp">self</span><span class="o">.</span><span class="n">res2</span> <span class="o">=</span> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">n_channels</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="n">time_channels</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">269</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">270</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">res1</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="lineno">271</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">272</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">res2</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="lineno">273</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<h3>Scale up the feature map by <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">2</span><span class="mord">×</span></span></span></span></span></h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">276</span><span class="k">class</span> <span class="nc">Upsample</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">281</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">):</span>
<span class="lineno">282</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">283</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ConvTranspose2d</span><span class="p">(</span><span class="n">n_channels</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">285</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
<p><code class="highlight"><span></span><span class="n">t</span></code>
is not used, but it&#x27;s kept in the arguments so that the function signature matches that of <code class="highlight"><span></span><span class="n">ResidualBlock</span></code>
. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">288</span> <span class="n">_</span> <span class="o">=</span> <span class="n">t</span>
<span class="lineno">289</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
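<div class='section'>
<div class='docs'>
<p>Why a kernel of 4, a stride of 2 and a padding of 1 doubles the feature map can be seen from the output size formula of a transposed convolution. This is a small sketch of that arithmetic for one spatial axis. The function name <code class="highlight"><span></span><span class="n">conv_transpose2d_out_size</span></code> is hypothetical, and the defaults for output padding and dilation are assumed to be 0 and 1.</p>
<p>With these settings the formula reduces to twice the input size.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def conv_transpose2d_out_size(size, kernel_size=4, stride=2, padding=1,
                              output_padding=0, dilation=1):
    """Output length along one spatial axis of a transposed convolution."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


# Each size is doubled: (size - 1) * 2 - 2 + 3 + 1 = 2 * size
sizes = [conv_transpose2d_out_size(s) for s in (4, 8, 16)]
```
</pre></div>
</div>
</div>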
<div class='section' id='section-62'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
<h3>Scale down the feature map by <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.190108em;vertical-align:-0.345em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.845108em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord">×</span></span></span></span></span></h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">292</span><span class="k">class</span> <span class="nc">Downsample</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='docs'>
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">297</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">):</span>
<span class="lineno">298</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">299</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">n_channels</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-64'>
<div class='docs'>
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">301</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-65'>
<div class='docs'>
<div class='section-link'>
<a href='#section-65'>#</a>
</div>
<p><code class="highlight"><span></span><span class="n">t</span></code>
is not used, but it&#x27;s kept in the arguments so that the function signature matches that of <code class="highlight"><span></span><span class="n">ResidualBlock</span></code>
. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">304</span> <span class="n">_</span> <span class="o">=</span> <span class="n">t</span>
<span class="lineno">305</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
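<div class='section'>
<div class='docs'>
<p>The halving performed by a kernel of 3, a stride of 2 and a padding of 1 follows from the output size formula of a strided convolution. The sketch below works through that arithmetic for one spatial axis. The function name <code class="highlight"><span></span><span class="n">conv2d_out_size</span></code> is hypothetical and a dilation of 1 is assumed.</p>
<p>Even sizes are halved exactly, while an odd size is rounded up, so 7 becomes 4. Input sizes that stay even at every downsampling step therefore give the most predictable shapes.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def conv2d_out_size(size, kernel_size=3, stride=2, padding=1, dilation=1):
    """Output length along one spatial axis of a strided convolution."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Even sizes are halved exactly
even_sizes = [conv2d_out_size(s) for s in (32, 16, 8)]
# An odd size is rounded up
odd_size = conv2d_out_size(7)
```
</pre></div>
</div>
</div>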
<div class='section' id='section-66'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-66'>#</a>
</div>
<h2>U-Net</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">308</span><span class="k">class</span> <span class="nc">UNet</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-67'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-67'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">image_channels</span></code>
is the number of channels in the image. <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">3</span></span></span></span></span> for RGB. </li>
<li><code class="highlight"><span></span><span class="n">n_channels</span></code>
is the number of channels in the initial feature map that the image is projected into </li>
<li><code class="highlight"><span></span><span class="n">ch_mults</span></code>
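is the list of channel multipliers at each resolution. Starting from <code class="highlight"><span></span><span class="n">n_channels</span></code>, the number of channels at resolution <code class="highlight"><span></span><span class="n">i</span></code> is the number of channels of the previous resolution multiplied by <code class="highlight"><span></span><span class="n">ch_mults</span><span class="p">[</span><span class="n">i</span><span class="p">]</span></code>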
</li>
<li><code class="highlight"><span></span><span class="n">is_attn</span></code>
is a list of booleans that indicate whether to use attention at each resolution </li>
<li><code class="highlight"><span></span><span class="n">n_blocks</span></code>
is the number of <code class="highlight"><span></span><span class="n">UpDownBlocks</span></code>
at each resolution</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">313</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">image_channels</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">3</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span><span class="p">,</span>
<span class="lineno">314</span> <span class="n">ch_mults</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span>
<span class="lineno">315</span> <span class="n">is_attn</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">bool</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">List</span><span class="p">[</span><span class="nb">bool</span><span class="p">]]</span> <span class="o">=</span> <span class="p">(</span><span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">True</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
<span class="lineno">316</span> <span class="n">n_blocks</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-68'>
<div class='docs'>
<div class='section-link'>
<a href='#section-68'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">324</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-69'>
<div class='docs'>
<div class='section-link'>
<a href='#section-69'>#</a>
</div>
<p>Number of resolutions </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">327</span> <span class="n">n_resolutions</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">ch_mults</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-70'>
<div class='docs'>
<div class='section-link'>
<a href='#section-70'>#</a>
</div>
<p>Project image into feature map </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">330</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">image_channels</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-71'>
<div class='docs'>
<div class='section-link'>
<a href='#section-71'>#</a>
</div>
<p>Time embedding layer. Time embedding has <code class="highlight"><span></span><span class="n">n_channels</span> <span class="o">*</span> <span class="mi">4</span></code>
channels </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">333</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_emb</span> <span class="o">=</span> <span class="n">TimeEmbedding</span><span class="p">(</span><span class="n">n_channels</span> <span class="o">*</span> <span class="mi">4</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-72'>
<div class='docs'>
<div class='section-link'>
<a href='#section-72'>#</a>
</div>
<h4>First half of U-Net - decreasing resolution</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">336</span> <span class="n">down</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-73'>
<div class='docs'>
<div class='section-link'>
<a href='#section-73'>#</a>
</div>
<p>Number of channels </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">338</span> <span class="n">out_channels</span> <span class="o">=</span> <span class="n">in_channels</span> <span class="o">=</span> <span class="n">n_channels</span></pre></div>
</div>
</div>
<div class='section' id='section-74'>
<div class='docs'>
<div class='section-link'>
<a href='#section-74'>#</a>
</div>
<p>For each resolution </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">340</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_resolutions</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-75'>
<div class='docs'>
<div class='section-link'>
<a href='#section-75'>#</a>
</div>
<p>Number of output channels at this resolution </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">342</span> <span class="n">out_channels</span> <span class="o">=</span> <span class="n">in_channels</span> <span class="o">*</span> <span class="n">ch_mults</span><span class="p">[</span><span class="n">i</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-76'>
<div class='docs'>
<div class='section-link'>
<a href='#section-76'>#</a>
</div>
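<p>Add <code class="highlight"><span></span><span class="n">n_blocks</span></code>
<code class="highlight"><span></span><span class="n">DownBlock</span></code>
modules. Only the first one changes the number of channels </p>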
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">344</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_blocks</span><span class="p">):</span>
<span class="lineno">345</span> <span class="n">down</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">DownBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">n_channels</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span> <span class="n">is_attn</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span>
<span class="lineno">346</span> <span class="n">in_channels</span> <span class="o">=</span> <span class="n">out_channels</span></pre></div>
</div>
</div>
<div class='section' id='section-77'>
<div class='docs'>
<div class='section-link'>
<a href='#section-77'>#</a>
</div>
<p>Down sample at all resolutions except the last </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">348</span> <span class="k">if</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">n_resolutions</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span>
<span class="lineno">349</span> <span class="n">down</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">Downsample</span><span class="p">(</span><span class="n">in_channels</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-78'>
<div class='docs'>
<div class='section-link'>
<a href='#section-78'>#</a>
</div>
<p>Combine the set of modules </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">352</span> <span class="bp">self</span><span class="o">.</span><span class="n">down</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">(</span><span class="n">down</span><span class="p">)</span></pre></div>
</div>
</div>
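<div class='section'>
<div class='docs'>
<p>The bookkeeping of the loop above can be traced with a standalone sketch that is not part of the model code. The hypothetical helper <code class="highlight"><span></span><span class="n">down_plan</span></code> records plain tuples instead of building modules and uses the default arguments of the constructor.</p>
<p>Because each multiplier is applied to the channel count of the previous resolution, the defaults give 64, 128, 256 and 1024 channels, with eight down blocks and three downsampling layers in total.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def down_plan(n_channels=64, ch_mults=(1, 2, 2, 4),
              is_attn=(False, False, True, True), n_blocks=2):
    """List the modules of the first half as plain tuples (a hypothetical helper)."""
    plan = []
    in_channels = n_channels
    n_resolutions = len(ch_mults)
    for i in range(n_resolutions):
        # The multiplier applies to the channel count of the previous resolution
        out_channels = in_channels * ch_mults[i]
        for _ in range(n_blocks):
            plan.append(("DownBlock", in_channels, out_channels, is_attn[i]))
            in_channels = out_channels
        # No downsampling after the last resolution
        if i != n_resolutions - 1:
            plan.append(("Downsample", in_channels))
    return plan


plan = down_plan()
# Output channels of every DownBlock, in order
block_channels = [step[2] for step in plan if step[0] == "DownBlock"]
```
</pre></div>
</div>
</div>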
<div class='section' id='section-79'>
<div class='docs'>
<div class='section-link'>
<a href='#section-79'>#</a>
</div>
<p>Middle block </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">355</span> <span class="bp">self</span><span class="o">.</span><span class="n">middle</span> <span class="o">=</span> <span class="n">MiddleBlock</span><span class="p">(</span><span class="n">out_channels</span><span class="p">,</span> <span class="n">n_channels</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span> <span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-80'>
<div class='docs'>
<div class='section-link'>
<a href='#section-80'>#</a>
</div>
<h4>Second half of U-Net - increasing resolution</h4>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">358</span> <span class="n">up</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-81'>
<div class='docs'>
<div class='section-link'>
<a href='#section-81'>#</a>
</div>
<p>Number of channels at the lowest resolution, carried over from the first half </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">360</span> <span class="n">in_channels</span> <span class="o">=</span> <span class="n">out_channels</span></pre></div>
</div>
</div>
<div class='section' id='section-82'>
<div class='docs'>
<div class='section-link'>
<a href='#section-82'>#</a>
</div>
<p>For each resolution, from the lowest to the highest </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">362</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">n_resolutions</span><span class="p">)):</span></pre></div>
</div>
</div>
<div class='section' id='section-83'>
<div class='docs'>
<div class='section-link'>
<a href='#section-83'>#</a>
</div>
<p>Add <code class="highlight"><span></span><span class="n">n_blocks</span></code>
blocks at the same resolution, keeping the number of channels </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">364</span> <span class="n">out_channels</span> <span class="o">=</span> <span class="n">in_channels</span>
<span class="lineno">365</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_blocks</span><span class="p">):</span>
<span class="lineno">366</span> <span class="n">up</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">UpBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">n_channels</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span> <span class="n">is_attn</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span></pre></div>
</div>
</div>
<div class='section' id='section-84'>
<div class='docs'>
<div class='section-link'>
<a href='#section-84'>#</a>
</div>
<p>Final block to reduce the number of channels </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">368</span> <span class="n">out_channels</span> <span class="o">=</span> <span class="n">in_channels</span> <span class="o">//</span> <span class="n">ch_mults</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="lineno">369</span> <span class="n">up</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">UpBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">n_channels</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span> <span class="n">is_attn</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span>
<span class="lineno">370</span> <span class="n">in_channels</span> <span class="o">=</span> <span class="n">out_channels</span></pre></div>
</div>
</div>
<div class='section' id='section-85'>
<div class='docs'>
<div class='section-link'>
<a href='#section-85'>#</a>
</div>
<p>Up sample at all resolutions except the last </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">372</span> <span class="k">if</span> <span class="n">i</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">373</span> <span class="n">up</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">Upsample</span><span class="p">(</span><span class="n">in_channels</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-86'>
<div class='docs'>
<div class='section-link'>
<a href='#section-86'>#</a>
</div>
<p>Combine the set of modules </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">376</span> <span class="bp">self</span><span class="o">.</span><span class="n">up</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">(</span><span class="n">up</span><span class="p">)</span></pre></div>
</div>
</div>
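<div class='section'>
<div class='docs'>
<p>The loop above can be traced with plain tuples to see which modules end up in the second half. This is only a sketch: <code>up_layout</code> is a hypothetical helper, the tuples stand in for the real <code>UpBlock</code> and <code>Upsample</code> modules, and the example values are chosen for illustration. At each resolution it lists <code>n_blocks</code> blocks that keep the channel count, one block that divides it by <code>ch_mults[i]</code>, and an up-sampling step everywhere except at the highest resolution.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
# Hypothetical helper that mirrors the constructor loop with plain tuples
def up_layout(in_channels, ch_mults, n_blocks):
    """List the up-path modules as (name, in_channels, out_channels)."""
    layout = []
    n_resolutions = len(ch_mults)
    for i in reversed(range(n_resolutions)):
        # n_blocks blocks that keep the channel count
        out_channels = in_channels
        for _ in range(n_blocks):
            layout.append(("UpBlock", in_channels, out_channels))
        # One extra block that divides the channel count by ch_mults[i]
        out_channels = in_channels // ch_mults[i]
        layout.append(("UpBlock", in_channels, out_channels))
        in_channels = out_channels
        # Up sample everywhere except at the highest resolution (i == 0)
        if i != 0:
            layout.append(("Upsample", in_channels, in_channels))
    return layout


# Example values chosen only for illustration
for module in up_layout(1024, (1, 2, 2, 4), n_blocks=1):
    print(module)
```
</pre></div>
</div>
</div>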
<div class='section' id='section-87'>
<div class='docs'>
<div class='section-link'>
<a href='#section-87'>#</a>
</div>
<p>Final normalization and convolution layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">379</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">)</span>
<span class="lineno">380</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span> <span class="o">=</span> <span class="n">Swish</span><span class="p">()</span>
<span class="lineno">381</span> <span class="bp">self</span><span class="o">.</span><span class="n">final</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">image_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-88'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-88'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code>
</li>
<li><code class="highlight"><span></span><span class="n">t</span></code>
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">]</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">383</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-89'>
<div class='docs'>
<div class='section-link'>
<a href='#section-89'>#</a>
</div>
<p>Get time-step embeddings </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">390</span> <span class="n">t</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">time_emb</span><span class="p">(</span><span class="n">t</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-90'>
<div class='docs'>
<div class='section-link'>
<a href='#section-90'>#</a>
</div>
<p>Get image projection </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">393</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-91'>
<div class='docs'>
<div class='section-link'>
<a href='#section-91'>#</a>
</div>
<p><code class="highlight"><span></span><span class="n">h</span></code>
stores the output of every module in the first half for the skip connections </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">396</span> <span class="n">h</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-92'>
<div class='docs'>
<div class='section-link'>
<a href='#section-92'>#</a>
</div>
<p>First half of U-Net </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">398</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">down</span><span class="p">:</span>
<span class="lineno">399</span> <span class="n">x</span> <span class="o">=</span> <span class="n">m</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="lineno">400</span> <span class="n">h</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-93'>
<div class='docs'>
<div class='section-link'>
<a href='#section-93'>#</a>
</div>
<p>Middle (bottom) </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">403</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">middle</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-94'>
<div class='docs'>
<div class='section-link'>
<a href='#section-94'>#</a>
</div>
<p>Second half of U-Net </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">406</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">up</span><span class="p">:</span>
<span class="lineno">407</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">Upsample</span><span class="p">):</span>
<span class="lineno">408</span> <span class="n">x</span> <span class="o">=</span> <span class="n">m</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="lineno">409</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-95'>
<div class='docs'>
<div class='section-link'>
<a href='#section-95'>#</a>
</div>
<p>Get the skip connection from first half of U-Net and concatenate </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">411</span> <span class="n">s</span> <span class="o">=</span> <span class="n">h</span><span class="o">.</span><span class="n">pop</span><span class="p">()</span>
<span class="lineno">412</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">x</span><span class="p">,</span> <span class="n">s</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-96'>
<div class='docs'>
<div class='section-link'>
<a href='#section-96'>#</a>
</div>
<p>Apply the block to the concatenated features </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">414</span> <span class="n">x</span> <span class="o">=</span> <span class="n">m</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-97'>
<div class='docs'>
<div class='section-link'>
<a href='#section-97'>#</a>
</div>
<p>Final normalization and convolution </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">417</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">final</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">act</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span></pre></div>
</div>
</div>
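<div class='section'>
<div class='docs'>
<p>The skip-connection bookkeeping in <code>forward</code> can be checked with a small simulation in which each tensor is replaced by its spatial size. This is a sketch under stated assumptions, not the model itself: <code>simulate_skips</code> is a hypothetical helper, the first half is assumed to append <code>n_blocks</code> blocks per resolution, <code>Downsample</code> is assumed to halve the height and width, <code>Upsample</code> is assumed to double them, and the middle block is assumed to keep them. Under these assumptions every <code>h.pop()</code> returns a tensor whose size matches <code>x</code>, which is what concatenation along <code>dim=1</code> requires, and the stack is empty at the end.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
# Hypothetical simulation: tensors are replaced by their spatial size
def simulate_skips(size, n_resolutions, n_blocks):
    """Trace the skip stack h through the down and up paths."""
    h = [size]  # output of the image projection
    for i in range(n_resolutions):
        for _ in range(n_blocks):
            h.append(size)  # a down block keeps the resolution
        if i != n_resolutions - 1:
            size //= 2  # assumption: Downsample halves height and width
            h.append(size)
    pushed = len(h)
    for i in reversed(range(n_resolutions)):
        for _ in range(n_blocks + 1):
            skip = h.pop()
            # Concatenation along dim=1 needs matching height and width
            assert skip == size
        if i != 0:
            size *= 2  # assumption: Upsample doubles height and width
    return pushed, len(h), size


# Example values chosen only for illustration
print(simulate_skips(32, n_resolutions=4, n_blocks=2))  # (12, 0, 32)
```
</pre></div>
</div>
</div>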
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>