<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A PyTorch implementation and tutorial of the U-Net model."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="U-Net"/>
<meta name="twitter:description" content="A PyTorch implementation and tutorial of the U-Net model."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/unet/index.html"/>
<meta property="og:title" content="U-Net"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="U-Net"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="U-Net"/>
<meta property="og:description" content="A PyTorch implementation and tutorial of the U-Net model."/>
<title>U-Net</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/unet/index.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">unet</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/unet/__init__.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>U-Net</h1>
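<p>This is an implementation of the U-Net model from the paper <a href="https://papers.labml.ai/paper/1505.04597">U-Net: Convolutional Networks for Biomedical Image Segmentation</a>.</p>
<p>U-Net consists of a contracting path and an expansive path. The contracting path is a series of convolutional layers and pooling layers, where the resolution of the feature map is progressively reduced. The expansive path is a series of up-sampling layers and convolutional layers, where the resolution of the feature map is progressively increased.</p>
<p>At every step in the expansive path, the corresponding feature map from the contracting path is concatenated with the current feature map.</p>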
<p><img alt="U-Net diagram from paper" src="unet.png"></p>
<p>Here is the <a href="experiment.html">training code</a> for an experiment that trains a U-Net on the <a href="carvana.html">Carvana dataset</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">27</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">28</span><span class="kn">import</span> <span class="nn">torchvision.transforms.functional</span>
<span class="lineno">29</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h3>Two <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">3</span></span></span></span></span></span> Convolution Layers</h3>
<p>Each step in the contracting path and the expansive path has two <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">3</span></span></span></span></span></span> convolutional layers, each followed by a ReLU activation.</p>
<p>The U-Net paper uses a padding of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">0</span></span></span></span></span> (unpadded convolutions), but we use a padding of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqf" style=""><span class="mord" style="">1</span></span></span></span></span></span> so that the final feature map is not cropped.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">32</span><span class="k">class</span> <span class="nc">DoubleConvolution</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">in_channels</span></code>
 is the number of input channels</li>
<li><code class="highlight"><span></span><span class="n">out_channels</span></code>
 is the number of output channels</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">43</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">48</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>First <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">3</span></span></span></span></span></span> convolutional layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">51</span> <span class="bp">self</span><span class="o">.</span><span class="n">first</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">52</span> <span class="bp">self</span><span class="o">.</span><span class="n">act1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Second <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">3</span></span></span></span></span></span> convolutional layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">54</span> <span class="bp">self</span><span class="o">.</span><span class="n">second</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">out_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">55</span> <span class="bp">self</span><span class="o">.</span><span class="n">act2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">57</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Apply the two convolution layers and their activations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">59</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">first</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">60</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">61</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">second</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">62</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">act2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
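<div class='section'>
<div class='docs'>
<p>The effect of the padding choice on the feature map size can be checked with the standard convolution output-size formula. The following is a minimal sketch in plain Python, not part of the original implementation; the helper names <code>conv_out_size</code> and <code>double_conv_out_size</code> and the input size of 572 are assumptions chosen for illustration.</p>
</div>
<div class='code'>

```python
def conv_out_size(size, kernel_size=3, padding=0, stride=1):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel_size) // stride + 1


def double_conv_out_size(size, padding):
    """Size after two 3x3 convolutions applied one after the other."""
    size = conv_out_size(size, kernel_size=3, padding=padding)
    return conv_out_size(size, kernel_size=3, padding=padding)


# Padding of 1 (as in DoubleConvolution): the size is unchanged
same = double_conv_out_size(572, padding=1)
# Padding of 0 (unpadded): each convolution trims one pixel from every border
shrunk = double_conv_out_size(572, padding=0)

print(same, shrunk)
```

</div>
</div>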
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<h3>Down-sample</h3>
<p>Each step in the contracting path down-samples the feature map with a <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqb" style=""><span class="mord" style="">2</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">2</span></span></span></span></span></span> max pooling layer.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span><span class="k">class</span> <span class="nc">DownSample</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">73</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">74</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Max pooling layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="bp">self</span><span class="o">.</span><span class="n">pool</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">78</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">79</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pool</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
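<div class='section'>
<div class='docs'>
<p>As a small illustration of what the pooling layer does, the sketch below applies <code>nn.MaxPool2d(2)</code> to a toy feature map. The 4 by 4 input holding the values 0 to 15 is an assumption made for the example; it is not taken from the original code.</p>
</div>
<div class='code'>

```python
import torch
from torch import nn

# Same layer as in DownSample: 2x2 window, stride defaults to the window size
pool = nn.MaxPool2d(2)

# A single-channel 4x4 feature map holding the values 0..15, row by row
x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

# Each non-overlapping 2x2 block is replaced by its maximum,
# so height and width are both halved
y = pool(x)

print(y.shape)
print(y.flatten().tolist())
```

</div>
</div>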
<div class='section' id='section-12'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<h3>Up-sample</h3>
<p>Each step in the expansive path up-samples the feature map with a <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqb" style=""><span class="mord" style="">2</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">2</span></span></span></span></span></span> up-convolution.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span><span class="k">class</span> <span class="nc">UpSample</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">89</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">90</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Up-convolution</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">93</span> <span class="bp">self</span><span class="o">.</span><span class="n">up</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ConvTranspose2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">95</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">96</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">up</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
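<div class='section'>
<div class='docs'>
<p>With a kernel size of 2 and a stride of 2, a transposed convolution doubles the height and width, which mirrors the halving done by max pooling. The sketch below checks this on a dummy tensor; the channel counts (8 and 4) and the 5 by 7 input are assumptions picked only to keep the example small.</p>
</div>
<div class='code'>

```python
import torch
from torch import nn

# Same configuration as UpSample, with small illustrative channel counts
up = nn.ConvTranspose2d(8, 4, kernel_size=2, stride=2)

# Dummy feature map: batch 1, 8 channels, height 5, width 7
x = torch.zeros(1, 8, 5, 7)

# No gradients are needed for a shape check
with torch.no_grad():
    y = up(x)

# Output size per dimension is (size - 1) * stride + kernel_size = 2 * size
print(y.shape)
```

</div>
</div>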
<div class='section' id='section-16'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
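<h3>Crop and Concatenate the Feature Map</h3>
<p>At every step in the expansive path, the corresponding feature map from the contracting path is concatenated with the current feature map.</p>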
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">99</span><span class="k">class</span> <span class="nc">CropAndConcat</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
 is the current feature map in the expansive path</li>
<li><code class="highlight"><span></span><span class="n">contracting_x</span></code>
 is the corresponding feature map from the contracting path</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">106</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">contracting_x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Crop the feature map from the contracting path to the size of the current feature map</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">113</span> <span class="n">contracting_x</span> <span class="o">=</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">center_crop</span><span class="p">(</span><span class="n">contracting_x</span><span class="p">,</span> <span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">]])</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Concatenate the feature maps</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">115</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">x</span><span class="p">,</span> <span class="n">contracting_x</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">117</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
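<div class='section'>
<div class='docs'>
<p>The crop-then-concatenate step can be sketched with plain tensor slicing. This is an illustrative stand-in, not the implementation above: the hypothetical helper <code>crop_and_concat</code> uses floor division to place the crop window, so it may sit one pixel away from the <code>center_crop</code> result when the size difference is odd. The tensor shapes are assumptions chosen for the example.</p>
</div>
<div class='code'>

```python
import torch


def crop_and_concat(x, contracting_x):
    """Center-crop contracting_x to the size of x, then join along channels."""
    h, w = x.shape[2], x.shape[3]
    # Offsets of the crop window inside the larger feature map
    top = (contracting_x.shape[2] - h) // 2
    left = (contracting_x.shape[3] - w) // 2
    cropped = contracting_x[:, :, top:top + h, left:left + w]
    # dim=1 is the channel dimension in the (batch, channel, height, width) layout
    return torch.cat([x, cropped], dim=1)


# Current feature map in the expansive path (illustrative shape)
x = torch.zeros(1, 512, 56, 56)
# Larger feature map saved from the contracting path (illustrative shape)
skip = torch.ones(1, 512, 64, 64)

out = crop_and_concat(x, skip)
print(out.shape)
```

</div>
</div>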
<div class='section' id='section-21'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<h2>U-Net</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">120</span><span class="k">class</span> <span class="nc">UNet</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">in_channels</span></code>
 is the number of channels in the input image</li>
<li><code class="highlight"><span></span><span class="n">out_channels</span></code>
 is the number of channels in the result feature map</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">129</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Double convolution layers for the contracting path. The number of features doubles at each step, starting from <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">64</span></span></span></span></span>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">133</span> <span class="bp">self</span><span class="o">.</span><span class="n">down_conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">DoubleConvolution</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">o</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">o</span> <span class="ow">in</span>
<span class="lineno">134</span> <span class="p">[(</span><span class="n">in_channels</span><span class="p">,</span> <span class="mi">64</span><span class="p">),</span> <span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">),</span> <span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">256</span><span class="p">),</span> <span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">512</span><span class="p">)]])</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Down-sampling layers for the contracting path</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="bp">self</span><span class="o">.</span><span class="n">down_sample</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">DownSample</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">4</span><span class="p">)])</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>The two convolution layers at the lowest resolution (the bottom of the U).</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">139</span> <span class="bp">self</span><span class="o">.</span><span class="n">middle_conv</span> <span class="o">=</span> <span class="n">DoubleConvolution</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">1024</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Up-sampling layers for the expansive path. Each up-sampling step halves the number of features.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">143</span> <span class="bp">self</span><span class="o">.</span><span class="n">up_sample</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">UpSample</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">o</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">o</span> <span class="ow">in</span>
<span class="lineno">144</span> <span class="p">[(</span><span class="mi">1024</span><span class="p">,</span> <span class="mi">512</span><span class="p">),</span> <span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">256</span><span class="p">),</span> <span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">128</span><span class="p">),</span> <span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">64</span><span class="p">)]])</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
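<p>Double convolution layers for the expansive path. Their input is the concatenation of the current feature map and the feature map from the contracting path. Therefore, the number of input features is double the number of features produced by up-sampling.</p>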
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">149</span> <span class="bp">self</span><span class="o">.</span><span class="n">up_conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">DoubleConvolution</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">o</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">o</span> <span class="ow">in</span>
<span class="lineno">150</span> <span class="p">[(</span><span class="mi">1024</span><span class="p">,</span> <span class="mi">512</span><span class="p">),</span> <span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">256</span><span class="p">),</span> <span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">128</span><span class="p">),</span> <span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">64</span><span class="p">)]])</span></pre></div>
</div>
</div>
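<div class='section'>
<div class='docs'>
<p>The channel bookkeeping described above can be verified with a few lines of plain Python. The lists below copy the channel pairs from the constructor; the value 3 for <code>in_channels</code> is an assumption (an RGB input) and the variable names are illustrative.</p>
</div>
<div class='code'>

```python
# (in, out) channel pairs as configured in the constructor;
# 3 input channels is an assumed value for in_channels
down_conv = [(3, 64), (64, 128), (128, 256), (256, 512)]
up_sample = [(1024, 512), (512, 256), (256, 128), (128, 64)]
up_conv = [(1024, 512), (512, 256), (256, 128), (128, 64)]

# Skip connections are consumed in reverse order: deepest feature map first
skip_channels = [out for _, out in reversed(down_conv)]

# Channels after concatenating the up-sampled map with its skip connection
concat_channels = [up_out + skip
                   for (_, up_out), skip in zip(up_sample, skip_channels)]

# These must match the input channels of the expansive double convolutions
expected_inputs = [i for i, _ in up_conv]
print(concat_channels, expected_inputs)
```

</div>
</div>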
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Crop and concatenate layers for the expansive path.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">152</span> <span class="bp">self</span><span class="o">.</span><span class="n">concat</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">CropAndConcat</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">4</span><span class="p">)])</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Final <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqa" style=""><span class="mord" style=""><span class="mord coloredeq eqf" style="">1</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style=""><span class="mord coloredeq eqf" style="">1</span></span></span></span></span></span></span> convolution layer that produces the output</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">154</span> <span class="bp">self</span><span class="o">.</span><span class="n">final_conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
 is the input image</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Collect the outputs of the contracting path for later concatenation with the expansive path.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">161</span> <span class="n">pass_through</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Contracting path</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">163</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">down_conv</span><span class="p">)):</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Two <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">3</span></span></span></span></span></span> convolutional layers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">165</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">down_conv</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Collect the output</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">167</span> <span class="n">pass_through</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Down-sample</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">169</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">down_sample</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
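<div class='section'>
<div class='docs'>
<p>The contracting loop above can be followed with plain shape bookkeeping. This is a minimal sketch rather than the actual implementation: the channel widths, the input resolution, the size-preserving (padded) convolutions and the halving 2 × 2 pooling are all assumptions made for illustration, and no tensors are involved.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
# Hypothetical settings, assumed for illustration only.
channels = [64, 128, 256, 512]  # assumed output width of each down_conv stage
height = width = 256            # assumed input resolution

pass_through = []
for out_channels in channels:
    # Two 3x3 convolutions: assumed padded, so only the channel count changes.
    shape = (out_channels, height, width)
    # Collect the output for the skip connection.
    pass_through.append(shape)
    # Down-sample: an assumed 2x2 pooling halves height and width.
    height, width = height // 2, width // 2

# Shape that would reach the convolutions at the bottom of the U-Net.
bottom_input = (channels[-1], height, width)
print(pass_through)
print(bottom_input)
```
</pre></div>
</div>
</div>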
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>Two <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">3</span></span></span></span></span></span> convolutional layers at the bottom of the U-Net</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">172</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">middle_conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Expansive path</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">175</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">up_conv</span><span class="p">)):</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>Up-sample</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">177</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">up_sample</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Concatenate the output of the contracting path</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">179</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">concat</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">pass_through</span><span class="o">.</span><span class="n">pop</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Two <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">3</span></span></span></span></span></span> convolutional layers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">181</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">up_conv</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
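<div class='section'>
<div class='docs'>
<p>Because <code>pass_through</code> is a Python list used as a stack, <code>pop()</code> returns the most recently stored (deepest) feature map first, which is the one produced at the same depth as the first up-sampling stage. The sketch below traces one possible set of shapes through the expansive loop. The shapes, the channel-halving up-sampling and the channel widths after the convolutions are assumptions for illustration, not values taken from the implementation.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
# Hypothetical skip shapes (channels, height, width), shallowest first.
pass_through = [(64, 256, 256), (128, 128, 128), (256, 64, 64), (512, 32, 32)]
x = (1024, 16, 16)  # assumed output shape of the bottom convolutions

trace = []
while pass_through:
    # Up-sample: assumed to halve the channels and double height and width.
    x = (x[0] // 2, x[1] * 2, x[2] * 2)
    # pop() removes the last stored (deepest) skip first.
    skip = pass_through.pop()
    assert skip[1:] == x[1:], "spatial sizes must agree before concatenating"
    # Concatenating along the channel axis adds the channel counts.
    concat_channels = x[0] + skip[0]
    # Two 3x3 convolutions: assumed to reduce the channels to the skip width.
    x = (skip[0], x[1], x[2])
    trace.append((skip[0], concat_channels, x))

for skip_channels, concat_channels, shape in trace:
    print(skip_channels, concat_channels, shape)
```
</pre></div>
</div>
</div>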
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Final <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqa" style=""><span class="mord" style=""><span class="mord coloredeq eqf" style="">1</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style=""><span class="mord coloredeq eqf" style="">1</span></span></span></span></span></span></span> convolution layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">184</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">final_conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
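<div class='section'>
<div class='docs'>
<p>A 1 × 1 convolution applies the same linear map to the channel vector of every pixel, so it changes the number of channels without mixing neighbouring pixels. The following pure-Python sketch illustrates that idea on nested lists. The helper <code>conv1x1</code> and the toy weights are hypothetical and stand in for the real layer only to show the arithmetic.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def conv1x1(image, weight, bias):
    """Apply a 1x1 convolution to `image`, a nested list indexed [channel][row][col].

    `weight` is indexed [out_channel][in_channel]; `bias` is indexed [out_channel].
    """
    in_channels = len(image)
    rows, cols = len(image[0]), len(image[0][0])
    return [
        [
            [
                # Weighted sum over the input channels of a single pixel.
                bias[o] + sum(weight[o][c] * image[c][r][k] for c in range(in_channels))
                for k in range(cols)
            ]
            for r in range(rows)
        ]
        for o in range(len(weight))
    ]


# Hypothetical 2-channel, 1x2 feature map projected to a single output channel.
features = [[[1.0, 2.0]], [[3.0, 4.0]]]
out = conv1x1(features, weight=[[0.5, -1.0]], bias=[0.25])
print(out)
```
</pre></div>
</div>
</div>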
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">187</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>