mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 17:41:37 +08:00
668 lines
46 KiB
HTML
668 lines
46 KiB
HTML
<!DOCTYPE html>
|
||
<html lang="en">
|
||
<head>
|
||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||
<meta name="description" content="PyTorch implementation and tutorial of U-Net model."/>
|
||
|
||
<meta name="twitter:card" content="summary"/>
|
||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||
<meta name="twitter:title" content="U-Net"/>
|
||
<meta name="twitter:description" content="PyTorch implementation and tutorial of U-Net model."/>
|
||
<meta name="twitter:site" content="@labmlai"/>
|
||
<meta name="twitter:creator" content="@labmlai"/>
|
||
|
||
<meta property="og:url" content="https://nn.labml.ai/unet/index.html"/>
|
||
<meta property="og:title" content="U-Net"/>
|
||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||
<meta property="og:site_name" content="U-Net"/>
|
||
<meta property="og:type" content="object"/>
|
||
<meta property="og:title" content="U-Net"/>
|
||
<meta property="og:description" content="PyTorch implementation and tutorial of U-Net model."/>
|
||
|
||
<title>U-Net</title>
|
||
<link rel="shortcut icon" href="/icon.png"/>
|
||
<link rel="stylesheet" href="../pylit.css?v=1">
|
||
<link rel="canonical" href="https://nn.labml.ai/unet/index.html"/>
|
||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
||
|
||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||
<script>
|
||
window.dataLayer = window.dataLayer || [];
|
||
|
||
// Google Analytics helper: appends this call's raw `arguments` object to the global `dataLayer` queue.
function gtag() {
|
||
dataLayer.push(arguments);
|
||
}
|
||
|
||
gtag('js', new Date());
|
||
|
||
gtag('config', 'G-4V3HC8HBLH');
|
||
</script>
|
||
</head>
|
||
<body>
|
||
<div id='container'>
|
||
<div id="background"></div>
|
||
<div class='section'>
|
||
<div class='docs'>
|
||
<p>
|
||
<a class="parent" href="/">home</a>
|
||
<a class="parent" href="index.html">unet</a>
|
||
</p>
|
||
<p>
|
||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
|
||
<img alt="Github"
|
||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||
style="max-width:100%;"/></a>
|
||
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
|
||
<img alt="Twitter"
|
||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||
style="max-width:100%;"/></a>
|
||
</p>
|
||
<p>
|
||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/unet/__init__.py" target="_blank">
|
||
View code on Github</a>
|
||
</p>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-0'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-0'>#</a>
|
||
</div>
|
||
<h1>U-Net</h1>
|
||
<p>This is an implementation of the U-Net model from the paper, <a href="https://arxiv.org/abs/1505.04597">U-Net: Convolutional Networks for Biomedical Image Segmentation</a>.</p>
|
||
<p>U-Net consists of a contracting path and an expansive path. The contracting path is a series of convolutional layers and pooling layers, where the resolution of the feature map gets progressively reduced. The expansive path is a series of up-sampling layers and convolutional layers where the resolution of the feature map gets progressively increased.</p>
|
||
<p>At every step in the expansive path the corresponding feature map from the contracting path is concatenated with the current feature map.</p>
|
||
<p><img alt="U-Net diagram from paper" src="unet.png"></p>
|
||
<p>Here is the <a href="experiment.html">training code</a> for an experiment that trains a U-Net on the <a href="carvana.html">Carvana dataset</a>.</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">27</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||
<span class="lineno">28</span><span class="kn">import</span> <span class="nn">torchvision.transforms.functional</span>
|
||
<span class="lineno">29</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-1'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-1'>#</a>
|
||
</div>
|
||
<h3>Two <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">3</span></span></span></span></span></span> Convolution Layers</h3>
|
||
<p>Each step in the contracting path and expansive path has two <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">3</span></span></span></span></span></span> convolutional layers followed by ReLU activations.</p>
|
||
<p>In the U-Net paper they used <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">0</span></span></span></span></span> padding, but we use <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqf" style=""><span class="mord" style="">1</span></span></span></span></span></span> padding so that the final feature map is not cropped.</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">32</span><span class="k">class</span> <span class="nc">DoubleConvolution</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-2'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-2'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">in_channels</span></code>
|
||
is the number of input channels </li>
|
||
<li><code class="highlight"><span></span><span class="n">out_channels</span></code>
|
||
is the number of output channels</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">43</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-3'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-3'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">48</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-4'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-4'>#</a>
|
||
</div>
|
||
<p>First <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">3</span></span></span></span></span></span> convolutional layer </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">51</span> <span class="bp">self</span><span class="o">.</span><span class="n">first</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="lineno">52</span> <span class="bp">self</span><span class="o">.</span><span class="n">act1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-5'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-5'>#</a>
|
||
</div>
|
||
<p>Second <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">3</span></span></span></span></span></span> convolutional layer </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">54</span> <span class="bp">self</span><span class="o">.</span><span class="n">second</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">out_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="lineno">55</span> <span class="bp">self</span><span class="o">.</span><span class="n">act2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-6'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-6'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">57</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-7'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-7'>#</a>
|
||
</div>
|
||
<p>Apply the two convolution layers and activations </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">59</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">first</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||
<span class="lineno">60</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||
<span class="lineno">61</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">second</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||
<span class="lineno">62</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">act2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-8'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-8'>#</a>
|
||
</div>
|
||
<h3>Down-sample</h3>
|
||
<p>Each step in the contracting path down-samples the feature map with a <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqb" style=""><span class="mord" style="">2</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">2</span></span></span></span></span></span> max pooling layer.</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">65</span><span class="k">class</span> <span class="nc">DownSample</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-9'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-9'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">73</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="lineno">74</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-10'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-10'>#</a>
|
||
</div>
|
||
<p>Max pooling layer </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">76</span> <span class="bp">self</span><span class="o">.</span><span class="n">pool</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-11'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-11'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">78</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="lineno">79</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pool</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-12'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-12'>#</a>
|
||
</div>
|
||
<h3>Up-sample</h3>
|
||
<p>Each step in the expansive path up-samples the feature map with a <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqb" style=""><span class="mord" style="">2</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">2</span></span></span></span></span></span> up-convolution.</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">82</span><span class="k">class</span> <span class="nc">UpSample</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-13'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-13'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">89</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="lineno">90</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-14'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-14'>#</a>
|
||
</div>
|
||
<p>Up-convolution </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">93</span> <span class="bp">self</span><span class="o">.</span><span class="n">up</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ConvTranspose2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-15'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-15'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">95</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="lineno">96</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">up</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-16'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-16'>#</a>
|
||
</div>
|
||
<h3>Crop and Concatenate the feature map</h3>
|
||
<p>At every step in the expansive path the corresponding feature map from the contracting path is concatenated with the current feature map.</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">99</span><span class="k">class</span> <span class="nc">CropAndConcat</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-17'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-17'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
|
||
current feature map in the expansive path </li>
|
||
<li><code class="highlight"><span></span><span class="n">contracting_x</span></code>
|
||
corresponding feature map from the contracting path</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">106</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">contracting_x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-18'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-18'>#</a>
|
||
</div>
|
||
<p>Crop the feature map from the contracting path to the size of the current feature map </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">113</span> <span class="n">contracting_x</span> <span class="o">=</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">center_crop</span><span class="p">(</span><span class="n">contracting_x</span><span class="p">,</span> <span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">]])</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-19'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-19'>#</a>
|
||
</div>
|
||
<p>Concatenate the feature maps </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">115</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">x</span><span class="p">,</span> <span class="n">contracting_x</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-20'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-20'>#</a>
|
||
</div>
|
||
<p> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">117</span> <span class="k">return</span> <span class="n">x</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-21'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-21'>#</a>
|
||
</div>
|
||
<h2>U-Net</h2>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">120</span><span class="k">class</span> <span class="nc">UNet</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-22'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-22'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">in_channels</span></code>
|
||
number of channels in the input image </li>
|
||
<li><code class="highlight"><span></span><span class="n">out_channels</span></code>
|
||
number of channels in the result feature map</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">124</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-23'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-23'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">129</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-24'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-24'>#</a>
|
||
</div>
|
||
<p>Double convolution layers for the contracting path. The number of features gets doubled at each step starting from <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">64</span></span></span></span></span>. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">133</span> <span class="bp">self</span><span class="o">.</span><span class="n">down_conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">DoubleConvolution</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">o</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">o</span> <span class="ow">in</span>
|
||
<span class="lineno">134</span> <span class="p">[(</span><span class="n">in_channels</span><span class="p">,</span> <span class="mi">64</span><span class="p">),</span> <span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">),</span> <span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">256</span><span class="p">),</span> <span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">512</span><span class="p">)]])</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-25'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-25'>#</a>
|
||
</div>
|
||
<p>Down sampling layers for the contracting path </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">136</span> <span class="bp">self</span><span class="o">.</span><span class="n">down_sample</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">DownSample</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">4</span><span class="p">)])</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-26'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-26'>#</a>
|
||
</div>
|
||
<p>The two convolution layers at the lowest resolution (the bottom of the U). </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">139</span> <span class="bp">self</span><span class="o">.</span><span class="n">middle_conv</span> <span class="o">=</span> <span class="n">DoubleConvolution</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">1024</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-27'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-27'>#</a>
|
||
</div>
|
||
<p>Up sampling layers for the expansive path. The number of features is halved with up-sampling. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">143</span> <span class="bp">self</span><span class="o">.</span><span class="n">up_sample</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">UpSample</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">o</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">o</span> <span class="ow">in</span>
|
||
<span class="lineno">144</span> <span class="p">[(</span><span class="mi">1024</span><span class="p">,</span> <span class="mi">512</span><span class="p">),</span> <span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">256</span><span class="p">),</span> <span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">128</span><span class="p">),</span> <span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">64</span><span class="p">)]])</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-28'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-28'>#</a>
|
||
</div>
|
||
<p>Double convolution layers for the expansive path. Their input is the concatenation of the current feature map and the feature map from the contracting path. Therefore, the number of input features is double the number of features from up-sampling. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">149</span> <span class="bp">self</span><span class="o">.</span><span class="n">up_conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">DoubleConvolution</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">o</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">o</span> <span class="ow">in</span>
|
||
<span class="lineno">150</span> <span class="p">[(</span><span class="mi">1024</span><span class="p">,</span> <span class="mi">512</span><span class="p">),</span> <span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">256</span><span class="p">),</span> <span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">128</span><span class="p">),</span> <span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">64</span><span class="p">)]])</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-29'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-29'>#</a>
|
||
</div>
|
||
<p>Crop and concatenate layers for the expansive path. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">152</span> <span class="bp">self</span><span class="o">.</span><span class="n">concat</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">CropAndConcat</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">4</span><span class="p">)])</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-30'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-30'>#</a>
|
||
</div>
|
||
<p>Final <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqa" style=""><span class="mord" style=""><span class="mord coloredeq eqf" style="">1</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style=""><span class="mord coloredeq eqf" style="">1</span></span></span></span></span></span></span> convolution layer to produce the output </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">154</span> <span class="bp">self</span><span class="o">.</span><span class="n">final_conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-31'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-31'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
|
||
input image</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">156</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-32'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-32'>#</a>
|
||
</div>
|
||
<p>List that collects the outputs of the contracting path for later concatenation in the expansive path. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">161</span> <span class="n">pass_through</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-33'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-33'>#</a>
|
||
</div>
|
||
<p>Contracting path </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">163</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">down_conv</span><span class="p">)):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-34'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-34'>#</a>
|
||
</div>
|
||
<p>Two <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">3</span></span></span></span></span></span> convolutional layers </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">165</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">down_conv</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-35'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-35'>#</a>
|
||
</div>
|
||
<p>Collect the output </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">167</span> <span class="n">pass_through</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-36'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-36'>#</a>
|
||
</div>
|
||
<p>Down-sample </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">169</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">down_sample</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-37'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-37'>#</a>
|
||
</div>
|
||
<p>Two <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">3</span></span></span></span></span></span> convolutional layers at the bottom of the U-Net </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">172</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">middle_conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-38'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-38'>#</a>
|
||
</div>
|
||
<p>Expansive path </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">175</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">up_conv</span><span class="p">)):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-39'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-39'>#</a>
|
||
</div>
|
||
<p>Up-sample </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">177</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">up_sample</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-40'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-40'>#</a>
|
||
</div>
|
||
<p>Concatenate the output of the contracting path </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">179</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">concat</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">pass_through</span><span class="o">.</span><span class="n">pop</span><span class="p">())</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-41'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-41'>#</a>
|
||
</div>
|
||
<p>Two <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style="">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">3</span></span></span></span></span></span> convolutional layers </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">181</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">up_conv</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-42'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-42'>#</a>
|
||
</div>
|
||
<p>Final <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqa" style=""><span class="mord" style=""><span class="mord coloredeq eqf" style="">1</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style=""><span class="mord coloredeq eqf" style="">1</span></span></span></span></span></span></span> convolution layer </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">184</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">final_conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-43'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-43'>#</a>
|
||
</div>
|
||
<p> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">187</span> <span class="k">return</span> <span class="n">x</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='footer'>
|
||
<a href="https://labml.ai">labml.ai</a>
|
||
</div>
|
||
</div>
|
||
<script src="../interactive.js?v=1"></script>
|
||
<script>
|
||
/**
 * Attach the modal behaviour to every image that is a direct child of a
 * paragraph, visiting the matches in document order.
 */
function handleImages() {
    // querySelectorAll returns a NodeList, so borrow Array's forEach to walk it.
    Array.prototype.forEach.call(document.querySelectorAll('p>img'), function (paragraphImage) {
        handleImage(paragraphImage)
    })
}
|
||
|
||
/**
 * Make an image open in a modal when it is clicked.
 *
 * Centers the image inside its parent element and builds a detached modal:
 * a `div` with id `modal` containing a content `div` that wraps an `img`,
 * followed by a `span.close` control whose text is `x`. Clicking the source
 * image attaches the modal to `document.body` and copies the image's `src`
 * and `alt` into the modal image; clicking the close control detaches it.
 *
 * @param {HTMLImageElement} img - image element to enhance; its `onclick`
 *     property is overwritten.
 */
function handleImage(img) {
    img.parentElement.style.textAlign = 'center'

    // Built once per image, but kept out of the document until the image is clicked.
    var modal = document.createElement('div')
    modal.id = 'modal'

    var modalContent = document.createElement('div')
    modal.appendChild(modalContent)

    var modalImage = document.createElement('img')
    modalContent.appendChild(modalImage)

    var span = document.createElement('span')
    span.classList.add('close')
    span.textContent = 'x'
    modal.appendChild(span)

    img.onclick = function () {
        document.body.appendChild(modal)
        modalImage.src = img.src
        // Carry the text alternative over so the modal copy is not an unlabeled image.
        modalImage.alt = img.alt
    }

    span.onclick = function () {
        // removeChild throws when the node is not a child, so detach only if attached.
        if (modal.parentNode) {
            modal.parentNode.removeChild(modal)
        }
    }
}
|
||
|
||
// Wire up every paragraph image now: this script is the last element before </body>,
// so the markup above it has already been parsed when this line runs.
handleImages()
|
||
</script>
|
||
</body>
|
||
</html> |