Varuna Jayasiri a7a7a3bdb7 RETRO (#110)
2022-03-12 15:44:35 +05:30


<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A PyTorch implementation/tutorial of Deep Residual Learning for Image Recognition (ResNet)."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Deep Residual Learning for Image Recognition (ResNet)"/>
<meta name="twitter:description" content="A PyTorch implementation/tutorial of Deep Residual Learning for Image Recognition (ResNet)."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/resnet/index.html"/>
<meta property="og:title" content="Deep Residual Learning for Image Recognition (ResNet)"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Deep Residual Learning for Image Recognition (ResNet)"/>
<meta property="og:description" content="A PyTorch implementation/tutorial of Deep Residual Learning for Image Recognition (ResNet)."/>
<title>Deep Residual Learning for Image Recognition (ResNet)</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/resnet/index.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">resnet</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/resnet/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Deep Residual Learning for Image Recognition (ResNet)</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the paper <a href="https://papers.labml.ai/paper/1512.03385">Deep Residual Learning for Image Recognition</a>.</p>
<p>ResNets train layers as residual functions to overcome the <em>degradation problem</em>: the accuracy of deep neural networks degrades when the number of layers becomes very high. As layers are added, accuracy first increases, then saturates, and then starts to degrade. This is not caused by overfitting, since the training error of the deeper models increases as well.</p>
<p>The paper argues that deeper models should perform at least as well as shallower models because the extra layers can just learn to perform an identity mapping.</p>
<h2>Residual Learning</h2>
<p>If <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqh" style=""><span class="mord mathcal" style="margin-right:0.00965em">H</span><span class="mopen" style="">(</span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span><span class="mclose" style="">)</span></span></span></span></span> is the mapping that needs to be learned by a few stacked layers, the paper has these layers learn the residual function</p>
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqg" style=""><span class="mord coloredeq eqj" style=""><span class="mord coloredeq eqt" style=""><span class="mord mathcal" style="margin-right:0.09931em">F</span></span></span><span class="mopen" style="">(</span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span><span class="mclose" style="">)</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqh" style=""><span class="mord mathcal" style="margin-right:0.00965em">H</span><span class="mopen" style="">(</span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span><span class="mclose" style="">)</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span></span></span></span></span></p>
<p>instead, so the original mapping becomes <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqg" style=""><span class="mord coloredeq eqj" style=""><span class="mord coloredeq eqt" style=""><span class="mord mathcal" style="margin-right:0.09931em">F</span></span></span><span class="mopen" style="">(</span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span><span class="mclose" style="">)</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span></span></span></span>.</p>
<p>In this case, learning an identity mapping for <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqh" style=""><span class="mord mathcal" style="margin-right:0.00965em">H</span><span class="mopen" style="">(</span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span><span class="mclose" style="">)</span></span></span></span></span> is equivalent to learning <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqg" style=""><span class="mord coloredeq eqj" style=""><span class="mord coloredeq eqt" style=""><span class="mord mathcal" style="margin-right:0.09931em">F</span></span></span><span class="mopen" style="">(</span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span><span class="mclose" style="">)</span></span></span></span></span> to be <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">0</span></span></span></span>, which is easier: pushing the residual to zero is simpler than fitting an identity mapping with a stack of non-linear layers.</p>
<p>In the parameterized form this can be written as,</p>
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqd" style=""><span class="mord coloredeq eqj" style=""><span class="mord coloredeq eqt" style=""><span class="mord mathcal" style="margin-right:0.09931em">F</span></span></span><span class="mopen" style="">(</span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mopen" style="">{</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style="">})</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span></span></span></span></span></p>
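<p>As a rough illustration of why the residual form makes an identity mapping easy to reach, here is a minimal sketch (not part of this implementation) that wraps a small two-layer function with a skip connection. It uses fully connected layers instead of convolutions for brevity, and the class name <code>TinyResidual</code> is hypothetical. Zeroing the last layer of the residual function makes it output 0, so the block computes exactly the identity.</p>

```python
import torch
from torch import nn


class TinyResidual(nn.Module):
    """Hypothetical toy block computing F(x, {W_i}) + x with a two-layer F."""

    def __init__(self, dim: int):
        super().__init__()
        # F(x, {W_i}) = W_2 relu(W_1 x), using linear layers for brevity
        self.f = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x: torch.Tensor):
        # Residual form: output = F(x) + x
        return self.f(x) + x


block = TinyResidual(dim=8)
# Push the residual to zero by zeroing the last layer of F
with torch.no_grad():
    block.f[2].weight.zero_()
    block.f[2].bias.zero_()

x = torch.randn(4, 8)
y = block(x)
print(torch.equal(y, x))  # True: the block now behaves as an identity mapping
```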
<p>and when the feature map sizes of <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqj" style=""><span class="mord coloredeq eqt" style=""><span class="mord mathcal" style="margin-right:0.09931em">F</span></span></span><span class="mopen">(</span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mopen">{</span><span class="mord"><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mclose">})</span></span></span></span> and <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span></span></span></span> differ, the paper suggests a linear projection with learned weights <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqp" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">s</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span>.</p>
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqd" style=""><span class="mord coloredeq eqj" style=""><span class="mord coloredeq eqt" style=""><span class="mord mathcal" style="margin-right:0.09931em">F</span></span></span><span class="mopen" style="">(</span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mopen" style="">{</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style="">})</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqm" style=""><span class="mord coloredeq eqp" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span 
style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">s</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span></span></span></span></span></span></p>
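<p>The paper also experimented with zero-padding shortcuts instead of linear projections for the layers where the dimensions increase, and found projections to work slightly better. Using projections for all shortcuts, including those where the feature map sizes already match, gave only a marginal further improvement, so the paper uses identity shortcuts wherever the sizes match to keep memory, computation and model size down.</p>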
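<p>For comparison, the zero-padding alternative could be sketched as below. This is an assumption about one way to write it, not code from this implementation (which uses the learned projection); the function name <code>zero_pad_shortcut</code> is hypothetical. It subsamples <code>x</code> with the same stride as the residual branch and appends zero-filled channels, so it adds no parameters.</p>

```python
import torch


def zero_pad_shortcut(x: torch.Tensor, out_channels: int, stride: int):
    """Hypothetical parameter-free shortcut: subsample spatially, zero-pad channels.

    `x` has shape [batch, in_channels, height, width].
    """
    # Match the spatial size of the residual branch by taking every `stride`-th pixel
    x = x[:, :, ::stride, ::stride]
    extra = out_channels - x.shape[1]
    # pad() works from the last dimension backwards:
    # (width_left, width_right, height_top, height_bottom, channel_front, channel_back)
    return torch.nn.functional.pad(x, (0, 0, 0, 0, 0, extra))


x = torch.randn(2, 16, 32, 32)
s = zero_pad_shortcut(x, out_channels=32, stride=2)
print(s.shape)  # torch.Size([2, 32, 16, 16])
```

<p>The extra channels carry no information from <code>x</code>, which is the paper's explanation for why this option did slightly worse than a learned projection.</p>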
<p><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqj" style=""><span class="mord coloredeq eqt" style=""><span class="mord mathcal" style="margin-right:0.09931em">F</span></span></span></span></span></span> should have more than one layer, otherwise the sum <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqd" style=""><span class="mord coloredeq eqj" style=""><span class="mord coloredeq eqt" style=""><span class="mord mathcal" style="margin-right:0.09931em">F</span></span></span><span class="mopen" style="">(</span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mopen" style="">{</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style="">})</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span 
class="mord coloredeq eqm" style=""><span class="mord coloredeq eqp" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">s</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span></span></span></span></span> would contain no non-linearity and would be similar to a single linear layer, for which the paper observed no advantage.</p>
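<p>A quick numerical check of this point, as a sketch using a fully connected layer in place of a convolution and arbitrary sizes: with a single-layer residual function <code>F(x) = W_1 x + b</code> and an identity shortcut, the output <code>F(x) + x</code> equals one linear layer whose weight is <code>W_1 + I</code>, so the skip connection adds nothing a plain linear layer could not represent.</p>

```python
import torch
from torch import nn

torch.manual_seed(0)
dim = 6
# Single-layer residual function F(x) = W_1 x + b (no non-linearity inside F)
f = nn.Linear(dim, dim)
x = torch.randn(3, dim)

residual_out = f(x) + x

# The same mapping written as one linear layer with weight W_1 + I
merged = nn.Linear(dim, dim)
with torch.no_grad():
    merged.weight.copy_(f.weight + torch.eye(dim))
    merged.bias.copy_(f.bias)

print(torch.allclose(residual_out, merged(x), atol=1e-6))  # True
```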
<p>Here is <a href="experiment.html">the training code</a> for a ResNet on CIFAR-10.</p>
<p><a href="https://app.labml.ai/run/fc5ad600e4af11ebbafd23b8665193c1"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">57</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span>
<span class="lineno">58</span>
<span class="lineno">59</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">60</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">61</span>
<span class="lineno">62</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Linear projections for shortcut connection</h2>
<p>This does the <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqm" style=""><span class="mord coloredeq eqp" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">s</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span></span></span></span></span> projection described above.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span><span class="k">class</span> <span class="nc">ShortcutProjection</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">in_channels</span></code>
is the number of channels in <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span></span></span></span> </li>
<li><code class="highlight"><span></span><span class="n">out_channels</span></code>
is the number of channels in <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqd" style=""><span class="mord coloredeq eqj" style=""><span class="mord coloredeq eqt" style=""><span class="mord mathcal" style="margin-right:0.09931em">F</span></span></span><span class="mopen" style="">(</span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mopen" style="">{</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style="">})</span></span></span></span></span> </li>
<li><code class="highlight"><span></span><span class="n">stride</span></code>
is the stride length in the convolution operation for <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqt" style=""><span class="mord mathcal" style="margin-right:0.09931em">F</span></span></span></span></span>. We apply the same stride on the shortcut connection so that both feature maps have the same size.</li></ul>
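<p>The feature-map sizes line up because a convolution's output size depends only on its kernel size, padding and stride. As a small worked check in plain Python (the helper <code>conv_out_size</code> is hypothetical, and dilation is assumed to be 1), the standard formula <code>floor((H + 2p - k) / s) + 1</code> gives the same result for the 1×1 projection with no padding and for the 3×3 convolution with padding 1 used in the residual block, for any input size and stride.</p>

```python
def conv_out_size(size: int, kernel_size: int, stride: int, padding: int):
    """Output height/width of a 2D convolution with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


for size in (32, 31, 16, 7):
    for stride in (1, 2):
        # 3x3 convolution with padding 1, as in the first layer of F
        main = conv_out_size(size, kernel_size=3, stride=stride, padding=1)
        # 1x1 projection with no padding, as in the shortcut
        shortcut = conv_out_size(size, kernel_size=1, stride=stride, padding=0)
        assert main == shortcut
        print(size, stride, main)
```

<p>Both expressions reduce to <code>floor((H - 1) / s) + 1</code>, which is why matching the stride is enough, even for odd input sizes.</p>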
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">72</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">stride</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Convolution layer for linear projection <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqm" style=""><span class="mord coloredeq eqp" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">s</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span></span></span></span></span> </p>
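<p>A 1×1 convolution is a linear map applied independently at every spatial position, which is why it serves as the projection here (with stride larger than 1 it also subsamples). The sketch below, with arbitrary sizes chosen for illustration, checks this by comparing <code>nn.Conv2d</code> with <code>kernel_size=1</code> against an <code>einsum</code> over the channel dimension.</p>

```python
import torch
from torch import nn

torch.manual_seed(0)
in_channels, out_channels, stride = 4, 8, 2
conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
x = torch.randn(2, in_channels, 10, 10)

# A 1x1 convolution weight has shape [out_channels, in_channels, 1, 1];
# dropping the last two dims gives the projection matrix W_s
w_s = conv.weight[:, :, 0, 0]
# Apply W_s at every pixel of the strided input, then add the bias
x_strided = x[:, :, ::stride, ::stride]
manual = torch.einsum('oc,nchw->nohw', w_s, x_strided) + conv.bias.view(1, -1, 1, 1)

print(torch.allclose(conv(x), manual, atol=1e-5))  # True
```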
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="n">stride</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>The paper adds batch normalization right after each convolution </p>
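<p>One consequence of placing batch normalization right after a convolution is that the convolution's bias has no effect in training mode: batch normalization subtracts the per-channel batch mean, which removes any constant per-channel offset. The sketch below (arbitrary sizes) checks this numerically. Many ResNet implementations, torchvision's included, therefore pass <code>bias=False</code> to such convolutions, although the code here keeps the default bias.</p>

```python
import torch
from torch import nn

torch.manual_seed(0)
conv_with_bias = nn.Conv2d(3, 5, kernel_size=1)
conv_no_bias = nn.Conv2d(3, 5, kernel_size=1, bias=False)
with torch.no_grad():
    # Same weights; only the bias differs
    conv_no_bias.weight.copy_(conv_with_bias.weight)

bn = nn.BatchNorm2d(5)
bn.train()  # normalize with the statistics of the current batch
x = torch.randn(8, 3, 6, 6)

out_with_bias = bn(conv_with_bias(x))
out_no_bias = bn(conv_no_bias(x))
print(torch.allclose(out_with_bias, out_no_bias, atol=1e-5))  # True
```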
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">84</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="n">out_channels</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">86</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Convolution and batch normalization </p>
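<p>Putting the pieces of this section together, here is a self-contained sketch of the shortcut projection with a usage example. It subclasses <code>torch.nn.Module</code> instead of <code>labml_helpers.module.Module</code> so that it runs without labml installed; that substitution is an assumption of this sketch. With <code>stride=2</code>, the projection halves the height and width and maps 64 channels to 128.</p>

```python
import torch
from torch import nn


class ShortcutProjection(nn.Module):
    """Linear projection W_s x for the shortcut: 1x1 convolution + batch norm."""

    def __init__(self, in_channels: int, out_channels: int, stride: int):
        super().__init__()
        # 1x1 convolution implements the linear projection W_s
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        # Batch normalization after the convolution
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor):
        return self.bn(self.conv(x))


shortcut = ShortcutProjection(in_channels=64, out_channels=128, stride=2)
x = torch.randn(2, 64, 32, 32)
print(shortcut(x).shape)  # torch.Size([2, 128, 16, 16])
```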
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">88</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p> <a id="residual_block"></a></p>
<h2>Residual Block</h2>
<p>This implements the residual block described in the paper. It has two <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">3</span></span></span></span> convolution layers.</p>
<p><img alt="Residual Block" src="residual_block.svg"></p>
<p>The first convolution layer maps from <code class="highlight"><span></span><span class="n">in_channels</span></code>
to <code class="highlight"><span></span><span class="n">out_channels</span></code>, where <code class="highlight"><span></span><span class="n">out_channels</span></code>
is greater than <code class="highlight"><span></span><span class="n">in_channels</span></code>
when we reduce the feature map size with a stride length greater than <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqr" style=""><span class="mord" style="">1</span></span></span></span></span>.</p>
<p>The second convolution layer maps from <code class="highlight"><span></span><span class="n">out_channels</span></code>
to <code class="highlight"><span></span><span class="n">out_channels</span></code>
and always has a stride length of 1.</p>
<p>Both convolution layers are followed by batch normalization.</p>
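<p>The block described above can be sketched as follows. This is an illustrative, self-contained version, not the implementation that follows in this file: it uses <code>torch.nn.Module</code> instead of <code>labml_helpers.module.Module</code>, the class name <code>ResidualBlockSketch</code> is hypothetical, and it assumes (following the paper) a ReLU after the first batch normalization and another after adding the shortcut. The shortcut is the identity when the shapes already match and a 1×1 projection with batch normalization otherwise.</p>

```python
import torch
from torch import nn


class ResidualBlockSketch(nn.Module):
    """Hypothetical two-layer residual block: relu(F(x) + shortcut(x))."""

    def __init__(self, in_channels: int, out_channels: int, stride: int):
        super().__init__()
        # First 3x3 convolution may change channels and downsample with `stride`
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Second 3x3 convolution keeps the channels and always uses stride 1
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if stride != 1 or in_channels != out_channels:
            # Projection shortcut W_s x when the shapes differ
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # Identity shortcut when the shapes already match
            self.shortcut = nn.Identity()
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor):
        out = self.act(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Add the shortcut, then apply the second non-linearity
        return self.act(out + self.shortcut(x))


x = torch.randn(2, 64, 32, 32)
same = ResidualBlockSketch(64, 64, stride=1)
down = ResidualBlockSketch(64, 128, stride=2)
print(same(x).shape)  # torch.Size([2, 64, 32, 32])
print(down(x).shape)  # torch.Size([2, 128, 16, 16])
```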
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span><span class="k">class</span> <span class="nc">ResidualBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">in_channels</span></code>
is the number of channels in <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span></span></span></span> </li>
<li><code class="highlight"><span></span><span class="n">out_channels</span></code>
is the number of output channels </li>
<li><code class="highlight"><span></span><span class="n">stride</span></code>
is the stride length of the first convolution operation; the second convolution always uses a stride of 1.</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">112</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">stride</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">118</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>First <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">3</span></span></span></span> convolution layer, which maps to <code class="highlight"><span></span><span class="n">out_channels</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">121</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="n">stride</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Batch normalization after the first convolution </p>
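<p><code>nn.BatchNorm2d</code> normalizes each channel using statistics computed over the batch and both spatial dimensions. The plain-Python sketch below (the helper <code>batch_norm_2d</code> is hypothetical) shows only the training-mode computation: it uses the biased batch variance, omits the learnable scale and shift (which PyTorch initializes to 1 and 0), and ignores the running averages used at evaluation time.</p>

```python
import math


def batch_norm_2d(x, eps=1e-5):
    """Training-mode batch norm over a [N][C][H][W] nested list.

    Each channel is normalized with the mean and biased variance taken over
    the batch and both spatial axes. The learnable scale and shift are left
    at their initial values (1 and 0), so they are omitted.
    """
    n, c = len(x), len(x[0])
    # Copy the input so the caller's data is not modified
    out = [[[row[:] for row in x[i][j]] for j in range(c)] for i in range(n)]
    for j in range(c):
        values = [v for i in range(n) for row in x[i][j] for v in row]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        scale = 1.0 / math.sqrt(var + eps)
        for i in range(n):
            for row in out[i][j]:
                for k, v in enumerate(row):
                    row[k] = (v - mean) * scale
    return out
```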
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">123</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="n">out_channels</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>First activation function (ReLU) </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">125</span> <span class="bp">self</span><span class="o">.</span><span class="n">act1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Second <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">3</span></span></span></span> convolution layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">128</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">out_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Batch normalization after the second convolution </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">130</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="n">out_channels</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>The shortcut connection should be a projection if the stride length is not <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqr" style=""><span class="mord" style="">1</span></span></span></span></span> or if the number of channels changes </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">134</span> <span class="k">if</span> <span class="n">stride</span> <span class="o">!=</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">in_channels</span> <span class="o">!=</span> <span class="n">out_channels</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Projection <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqm" style=""><span class="mord coloredeq eqp" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">s</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span></span></span></span></span> </p>
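<p>The paper realizes the projection <code>W_s</code> with 1×1 convolutions, applied with a stride of 2 when the shortcut crosses feature maps of two sizes. The dependency-free sketch below assumes that form; the helper <code>project_1x1</code> is hypothetical, and any bias or normalization the <code>ShortcutProjection</code> module applies is left out. The 1×1 kernel mixes channels independently at every pixel, and the stride keeps every <code>stride</code>-th row and column.</p>

```python
def project_1x1(x, weight, stride):
    """Apply a bias-free 1x1 convolution with the given stride to one image.

    x: [C_in][H][W] nested list; weight: [C_out][C_in] matrix.
    Each output pixel is a linear combination of the input channels at the
    same location; the stride subsamples rows and columns.
    """
    c_in, h, w = len(x), len(x[0]), len(x[0][0])
    rows = range(0, h, stride)
    cols = range(0, w, stride)
    return [
        [[sum(weight[o][i] * x[i][r][c] for i in range(c_in)) for c in cols] for r in rows]
        for o in range(len(weight))
    ]
```

<p>For a 4×4 input and stride 2, the result is 2×2, the same size that a padded, strided 3×3 convolution produces, so it can be added to the main branch.</p>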
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="bp">self</span><span class="o">.</span><span class="n">shortcut</span> <span class="o">=</span> <span class="n">ShortcutProjection</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">stride</span><span class="p">)</span>
<span class="lineno">137</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Identity <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">139</span> <span class="bp">self</span><span class="o">.</span><span class="n">shortcut</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Second activation function (ReLU), applied after adding the shortcut </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">142</span> <span class="bp">self</span><span class="o">.</span><span class="n">act2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
is the input of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">144</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Get the shortcut connection </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">149</span> <span class="n">shortcut</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">shortcut</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>First convolution and activation </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">151</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bn1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Second convolution </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">153</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Activation function after adding the shortcut </p>
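<p>The last step computes <code>ReLU(F(x) + shortcut)</code>. The plain-Python sketch below works element-wise on flat lists (the function names are hypothetical) and shows why the residual form helps. If the residual branch outputs zeros, the block passes its input through. This is exact for non-negative inputs, such as the output of a preceding ReLU.</p>

```python
def relu(v):
    """Rectified linear unit on a single number."""
    return max(0.0, v)


def residual_output(residual, shortcut):
    """Return ReLU(F(x) + shortcut) element-wise for two flat lists.

    `residual` stands for the output of bn2 (the residual branch F(x)) and
    `shortcut` for the output of the identity or projection path.
    """
    if len(residual) != len(shortcut):
        raise ValueError("residual and shortcut must have the same shape")
    return [relu(f + s) for f, s in zip(residual, shortcut)]


x = [1.0, -2.0, 3.0]
# If the residual branch outputs zeros, the block reduces to ReLU(x)
print(residual_output([0.0, 0.0, 0.0], x))   # [1.0, 0.0, 3.0]
print(residual_output([0.5, 1.0, -4.0], x))  # [1.5, 0.0, 0.0]
```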
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">155</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">act2</span><span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="n">shortcut</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p> <a id="bottleneck_residual_block"></a></p>
<h2>Bottleneck Residual Block</h2>
<p>This implements the bottleneck block described in the paper. It has <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqr" style=""><span class="mord" style="">1</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqr" style=""><span class="mord" style="">1</span></span></span></span></span>, <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">3</span></span></span></span>, and <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqr" style=""><span class="mord" style="">1</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqr" style=""><span class="mord" style="">1</span></span></span></span></span> convolution layers.</p>
<p><img alt="Bottleneck Block" src="bottleneck_block.svg"></p>
<p>The first convolution layer maps from <code class="highlight"><span></span><span class="n">in_channels</span></code>
to <code class="highlight"><span></span><span class="n">bottleneck_channels</span></code>
with a <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqr" style=""><span class="mord" style="">1</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqr" style=""><span class="mord" style="">1</span></span></span></span></span> convolution, where <code class="highlight"><span></span><span class="n">bottleneck_channels</span></code>
is usually smaller than <code class="highlight"><span></span><span class="n">in_channels</span></code>
.</p>
<p>The second <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">3</span></span></span></span> convolution layer maps from <code class="highlight"><span></span><span class="n">bottleneck_channels</span></code>
to <code class="highlight"><span></span><span class="n">bottleneck_channels</span></code>
. This layer can have a stride greater than <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqr" style=""><span class="mord" style="">1</span></span></span></span></span> when we want to downsample the feature map.</p>
<p>The third and final <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqr" style=""><span class="mord" style="">1</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqr" style=""><span class="mord" style="">1</span></span></span></span></span> convolution layer maps to <code class="highlight"><span></span><span class="n">out_channels</span></code>
. In the first block of a stage, where the stride is typically greater than <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqr" style=""><span class="mord" style="">1</span></span></span></span></span>, <code class="highlight"><span></span><span class="n">out_channels</span></code>
is larger than <code class="highlight"><span></span><span class="n">in_channels</span></code>
; in the remaining blocks of the stage, <code class="highlight"><span></span><span class="n">out_channels</span></code>
is equal to <code class="highlight"><span></span><span class="n">in_channels</span></code>
.</p>
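<p>To make the channel and size changes concrete, the hypothetical helper below traces <code>(channels, height)</code> through the three convolutions. It uses the padding of the layers in this section (0 for the 1×1 convolutions, 1 for the 3×3) and the standard convolution output-size formula. The example values (256 input channels, 128 bottleneck channels, 512 output channels, a 56×56 input with stride 2) are illustrative; they correspond to the transition into the second bottleneck stage of ResNet-50 in the paper.</p>

```python
def conv_out(size, kernel_size, stride, padding):
    """Output height/width of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def trace_bottleneck(in_channels, bottleneck_channels, out_channels, stride, size):
    """List (channels, spatial size) before and after each convolution."""
    shapes = [(in_channels, size)]
    size = conv_out(size, kernel_size=1, stride=1, padding=0)       # conv1: reduce channels
    shapes.append((bottleneck_channels, size))
    size = conv_out(size, kernel_size=3, stride=stride, padding=1)  # conv2: may downsample
    shapes.append((bottleneck_channels, size))
    size = conv_out(size, kernel_size=1, stride=1, padding=0)       # conv3: expand channels
    shapes.append((out_channels, size))
    return shapes


print(trace_bottleneck(256, 128, 512, stride=2, size=56))
# [(256, 56), (128, 56), (128, 28), (512, 28)]
```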
<p><code class="highlight"><span></span><span class="n">bottleneck_channels</span></code>
is usually smaller than <code class="highlight"><span></span><span class="n">in_channels</span></code>
, and the <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">3</span></span></span></span> convolution is performed in this reduced space (hence the name bottleneck). The two <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqr" style=""><span class="mord" style="">1</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqr" style=""><span class="mord" style="">1</span></span></span></span></span> convolutions first reduce and then expand the number of channels.</p>
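<p>The saving from performing the 3×3 convolution in the reduced space can be checked by counting weights. The sketch below uses hypothetical helper names and ignores biases and batch-norm parameters. It compares a 256-channel bottleneck block (with 64 bottleneck channels) against two 3×3 convolutions on 256 channels, and against the 64-channel basic block that the paper describes as having similar time complexity to the bottleneck design.</p>

```python
def conv_weights(c_in, c_out, k):
    """Number of weights in a k x k convolution (biases ignored)."""
    return c_in * c_out * k * k


def basic_block_weights(channels):
    # Two 3x3 convolutions that keep the channel count
    return 2 * conv_weights(channels, channels, 3)


def bottleneck_block_weights(in_channels, bottleneck_channels, out_channels):
    # 1x1 reduce, 3x3 in the narrow space, 1x1 expand
    return (conv_weights(in_channels, bottleneck_channels, 1)
            + conv_weights(bottleneck_channels, bottleneck_channels, 3)
            + conv_weights(bottleneck_channels, out_channels, 1))


print(basic_block_weights(256))                # 1179648
print(bottleneck_block_weights(256, 64, 256))  # 69632
print(basic_block_weights(64))                 # 73728
```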
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">158</span><span class="k">class</span> <span class="nc">BottleneckResidualBlock</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">in_channels</span></code>
is the number of channels in <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span></span></span></span> </li>
<li><code class="highlight"><span></span><span class="n">bottleneck_channels</span></code>
is the number of channels for the <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">3</span></span></span></span> convolution </li>
<li><code class="highlight"><span></span><span class="n">out_channels</span></code>
is the number of output channels </li>
<li><code class="highlight"><span></span><span class="n">stride</span></code>
is the stride length in the <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">3</span></span></span></span> convolution operation.</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">186</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">bottleneck_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">stride</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">193</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>First <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqr" style=""><span class="mord" style="">1</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqr" style=""><span class="mord" style="">1</span></span></span></span></span> convolution layer, which maps to <code class="highlight"><span></span><span class="n">bottleneck_channels</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">196</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">bottleneck_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Batch normalization after the first convolution </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">198</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="n">bottleneck_channels</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>First activation function (ReLU) </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">200</span> <span class="bp">self</span><span class="o">.</span><span class="n">act1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Second <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">3</span></span></span></span> convolution layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">203</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">bottleneck_channels</span><span class="p">,</span> <span class="n">bottleneck_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="n">stride</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Batch normalization after the second convolution </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="n">bottleneck_channels</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Second activation function (ReLU) </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">207</span> <span class="bp">self</span><span class="o">.</span><span class="n">act2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Third <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqr" style=""><span class="mord" style="">1</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqr" style=""><span class="mord" style="">1</span></span></span></span></span> convolution layer, which maps to <code class="highlight"><span></span><span class="n">out_channels</span></code>
. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">210</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv3</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">bottleneck_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Batch normalization after the third convolution </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">212</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn3</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="n">out_channels</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>The shortcut connection should be a projection if the stride length is not <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqr" style=""><span class="mord" style="">1</span></span></span></span></span> or if the number of channels changes </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">216</span> <span class="k">if</span> <span class="n">stride</span> <span class="o">!=</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">in_channels</span> <span class="o">!=</span> <span class="n">out_channels</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>Projection <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqm" style=""><span class="mord coloredeq eqp" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">s</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">218</span> <span class="bp">self</span><span class="o">.</span><span class="n">shortcut</span> <span class="o">=</span> <span class="n">ShortcutProjection</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">stride</span><span class="p">)</span>
<span class="lineno">219</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Identity <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="">x</span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">221</span> <span class="bp">self</span><span class="o">.</span><span class="n">shortcut</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>Third activation function (ReLU), applied after adding the shortcut </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">224</span> <span class="bp">self</span><span class="o">.</span><span class="n">act3</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
is the input of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">226</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Get the shortcut connection </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">231</span> <span class="n">shortcut</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">shortcut</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>First convolution, batch normalization, and activation </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">233</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bn1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>Second convolution, batch normalization, and activation </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">235</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bn2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Third convolution and batch normalization (no activation until the shortcut is added) </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">237</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn3</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv3</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Activation function after adding the shortcut </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">239</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">act3</span><span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="n">shortcut</span><span class="p">)</span></pre></div>
</div>
</div>
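<p>Putting the steps of <code>forward</code> together, the bottleneck block can be sketched as one self-contained module. <code>BottleneckSketch</code> is a hypothetical name; the sketch assumes the stride is applied in the 3x3 convolution and uses an inline 1x1 projection shortcut, so it may differ in detail from the class defined on this page.</p>

```python
import torch
from torch import nn


class BottleneckSketch(nn.Module):
    """Minimal bottleneck residual block mirroring the forward pass above.
    Assumptions: the stride sits in the 3x3 convolution, and the shortcut is
    a 1x1 projection whenever the stride or channel count changes."""

    def __init__(self, in_channels: int, bottleneck_channels: int,
                 out_channels: int, stride: int = 1):
        super().__init__()
        # 1x1 convolution reduces the channels to the bottleneck width
        self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(bottleneck_channels)
        # 3x3 convolution carries the stride (spatial downsampling)
        self.conv2 = nn.Conv2d(bottleneck_channels, bottleneck_channels,
                               kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(bottleneck_channels)
        # 1x1 convolution expands back to out_channels
        self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1)
        self.bn3 = nn.BatchNorm2d(out_channels)
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels))
        else:
            self.shortcut = nn.Identity()
        # One ReLU module reused for all three activations (ReLU is stateless)
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the shortcut first, before x is overwritten
        shortcut = self.shortcut(x)
        x = self.act(self.bn1(self.conv1(x)))
        x = self.act(self.bn2(self.conv2(x)))
        # No activation here: ReLU is applied only after the addition
        x = self.bn3(self.conv3(x))
        return self.act(x + shortcut)


block = BottleneckSketch(256, 64, 512, stride=2)
y = block(torch.randn(2, 256, 16, 16))
print(y.shape)  # torch.Size([2, 512, 8, 8])
```

<p>Reusing a single <code>nn.ReLU</code> is safe because it has no parameters or state; the page keeps separate <code>act1</code>, <code>act2</code> and <code>act3</code> modules, which reads more directly against the equations.</p>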
<div class='section' id='section-46'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<h2>ResNet Model</h2>
<p>This is the base of the ResNet model, without the final linear layer and softmax used for classification.</p>
<p>The ResNet is made of stacked <a href="#residual_block">residual blocks</a> or <a href="#bottleneck_residual_block">bottleneck residual blocks</a>. After every few blocks, the feature map size is halved by a block with a stride length of <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">2</span></span></span></span>, and the number of channels is increased at each such reduction. Finally, the feature map is average pooled to get a vector representation.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">242</span><span class="k">class</span> <span class="nc">ResNetBase</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
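<p>To make the halving concrete, the spatial size after each stage can be computed with the standard convolution output-size formula. This sketch assumes a 224x224 input, four stages, 3x3 convolutions with padding 1 carrying the stride, and the schedule described on this page (stride 1 for the very first block, stride 2 for the first block of every later stage); <code>conv_out</code> and <code>stage_sizes</code> are hypothetical helpers.</p>

```python
from typing import List


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    # Standard convolution output size: floor((n + 2p - k) / s) + 1
    return (size + 2 * padding - kernel) // stride + 1


def stage_sizes(image_size: int, n_stages: int, first_kernel_size: int = 7) -> List[int]:
    # Initial convolution: stride 2, padding first_kernel_size // 2
    size = conv_out(image_size, first_kernel_size, 2, first_kernel_size // 2)
    sizes = []
    for i in range(n_stages):
        # Assumed schedule: the first stage keeps the size; every later stage
        # starts with a 3x3, padding-1 convolution of stride 2
        stride = 1 if i == 0 else 2
        size = conv_out(size, 3, stride, 1)
        sizes.append(size)
    return sizes


print(stage_sizes(224, 4))  # [112, 56, 28, 14]
```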
<div class='section' id='section-47'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">n_blocks</span></code>
is a list of the number of blocks for each feature map size. </li>
<li><code class="highlight"><span></span><span class="n">n_channels</span></code>
is a list of the number of channels for each feature map size. </li>
<li><code class="highlight"><span></span><span class="n">bottlenecks</span></code>
is a list of the number of channels in the bottlenecks for each feature map size. If this is <code class="highlight"><span></span><span class="kc">None</span></code>
, <a href="#residual_block">residual blocks</a> are used. </li>
<li><code class="highlight"><span></span><span class="n">img_channels</span></code>
is the number of channels in the input. </li>
<li><code class="highlight"><span></span><span class="n">first_kernel_size</span></code>
is the kernel size of the initial convolution layer.</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">256</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_blocks</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">n_channels</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span>
<span class="lineno">257</span> <span class="n">bottlenecks</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="lineno">258</span> <span class="n">img_channels</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">3</span><span class="p">,</span> <span class="n">first_kernel_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">7</span><span class="p">):</span></pre></div>
</div>
</div>
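<p>As a usage illustration of these parameters, the hypothetical helper below counts the weighted layers that a configuration implies. The ResNet-50-style and ResNet-18-style lists are assumptions chosen to match the stage widths in the original paper; the count adds one layer for the initial convolution and one for the classifier's linear layer (which <code>ResNetBase</code> itself leaves out), and does not count shortcut projections.</p>

```python
from typing import List, Optional


def count_weighted_layers(n_blocks: List[int], n_channels: List[int],
                          bottlenecks: Optional[List[int]] = None) -> int:
    # Same consistency checks as the constructor's assertions
    assert len(n_blocks) == len(n_channels)
    assert bottlenecks is None or len(bottlenecks) == len(n_channels)
    # Residual blocks have 2 convolutions on the main path, bottleneck blocks 3
    convs_per_block = 2 if bottlenecks is None else 3
    # +1 for the initial convolution, +1 for a classifier linear layer
    # (hypothetical here; ResNetBase has none); shortcuts are not counted
    return 1 + convs_per_block * sum(n_blocks) + 1


# Hypothetical ResNet-50-style configuration with bottleneck blocks
print(count_weighted_layers([3, 4, 6, 3], [256, 512, 1024, 2048], [64, 128, 256, 512]))  # 50
# Hypothetical ResNet-18-style configuration with plain residual blocks
print(count_weighted_layers([2, 2, 2, 2], [64, 128, 256, 512]))  # 18
```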
<div class='section' id='section-48'>
<div class='docs'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">267</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>There must be a number of blocks and a number of channels for each feature map size </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">270</span> <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">n_blocks</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">n_channels</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>If <a href="#bottleneck_residual_block">bottleneck residual blocks</a> are used, the number of channels in bottlenecks should be provided for each feature map size </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">273</span> <span class="k">assert</span> <span class="n">bottlenecks</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">bottlenecks</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">n_channels</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p>Initial convolution layer maps from <code class="highlight"><span></span><span class="n">img_channels</span></code>
to number of channels in the first residual block (<code class="highlight"><span></span><span class="n">n_channels</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></code>
) </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">277</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">img_channels</span><span class="p">,</span> <span class="n">n_channels</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
<span class="lineno">278</span> <span class="n">kernel_size</span><span class="o">=</span><span class="n">first_kernel_size</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="n">first_kernel_size</span> <span class="o">//</span> <span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
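<p>Setting <code>padding=first_kernel_size // 2</code> with stride 2 halves an even input size exactly for any odd kernel size. The check below is a sketch that assumes <code>n_channels[0] = 64</code> and a 224x224 RGB input; both are illustrative choices.</p>

```python
import torch
from torch import nn

x = torch.randn(1, 3, 224, 224)
shapes = {}
with torch.no_grad():
    for k in (3, 5, 7):
        # Same construction as self.conv above, with n_channels[0] = 64 assumed
        conv = nn.Conv2d(3, 64, kernel_size=k, stride=2, padding=k // 2)
        shapes[k] = tuple(conv(x).shape)
print(shapes)
# {3: (1, 64, 112, 112), 5: (1, 64, 112, 112), 7: (1, 64, 112, 112)}
```

<p>With an even kernel size this padding no longer gives an exact halving; for example, a kernel size of 4 with padding 2 would give 113.</p>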
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p>Batch norm after initial convolution </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">280</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="n">n_channels</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>List of blocks </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">283</span> <span class="n">blocks</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<p>Number of channels from previous layer (or block) </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">285</span> <span class="n">prev_channels</span> <span class="o">=</span> <span class="n">n_channels</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='docs'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<p>Loop through each feature map size </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">287</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">channels</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">n_channels</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
<p>The first block for each new feature map size has a stride length of <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">2</span></span></span></span>, except for the very first block </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">290</span> <span class="n">stride</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">blocks</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="mi">2</span>
<span class="lineno">291</span>
<span class="lineno">292</span> <span class="k">if</span> <span class="n">bottlenecks</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<p>A <a href="#residual_block">residual block</a> that maps from <code class="highlight"><span></span><span class="n">prev_channels</span></code>
to <code class="highlight"><span></span><span class="n">channels</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">294</span> <span class="n">blocks</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">ResidualBlock</span><span class="p">(</span><span class="n">prev_channels</span><span class="p">,</span> <span class="n">channels</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="n">stride</span><span class="p">))</span>
<span class="lineno">295</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<p>A <a href="#bottleneck_residual_block">bottleneck residual block</a> that maps from <code class="highlight"><span></span><span class="n">prev_channels</span></code>
to <code class="highlight"><span></span><span class="n">channels</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">298</span> <span class="n">blocks</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">BottleneckResidualBlock</span><span class="p">(</span><span class="n">prev_channels</span><span class="p">,</span> <span class="n">bottlenecks</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">channels</span><span class="p">,</span>
<span class="lineno">299</span> <span class="n">stride</span><span class="o">=</span><span class="n">stride</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<p>Update the number of input channels for the next block </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">302</span> <span class="n">prev_channels</span> <span class="o">=</span> <span class="n">channels</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
<p>Add the rest of the blocks, with no change in feature map size or number of channels </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">304</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_blocks</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
<span class="lineno">305</span> <span class="k">if</span> <span class="n">bottlenecks</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
<p><a href="#residual_block">residual blocks</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">307</span> <span class="n">blocks</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">ResidualBlock</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="n">channels</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
<span class="lineno">308</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
<p><a href="#bottleneck_residual_block">bottleneck residual blocks</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">310</span> <span class="n">blocks</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">BottleneckResidualBlock</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="n">bottlenecks</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">channels</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='docs'>
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
<p>Stack the blocks </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">313</span> <span class="bp">self</span><span class="o">.</span><span class="n">blocks</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">blocks</span><span class="p">)</span></pre></div>
</div>
</div>
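<p>The construction loop above can be traced without building any modules. <code>block_plan</code> is a hypothetical helper that mirrors the loop and records <code>(in_channels, out_channels, stride)</code> for each block; it follows the stride schedule stated in the comments (stride 1 for the very first block, stride 2 for the first block of every later feature map size).</p>

```python
from typing import List, Tuple


def block_plan(n_blocks: List[int], n_channels: List[int]) -> List[Tuple[int, int, int]]:
    """Return (in_channels, out_channels, stride) for every block, mirroring
    the construction loop (hypothetical helper, not part of the page)."""
    plan = []
    prev_channels = n_channels[0]
    for i, channels in enumerate(n_channels):
        # First block of each feature map size: stride 2, except the very first block
        stride = 1 if len(plan) == 0 else 2
        plan.append((prev_channels, channels, stride))
        prev_channels = channels
        # Remaining blocks keep the feature map size and the channel count
        for _ in range(n_blocks[i] - 1):
            plan.append((channels, channels, 1))
    return plan


print(block_plan([2, 2], [64, 128]))
# [(64, 64, 1), (64, 64, 1), (64, 128, 2), (128, 128, 1)]
```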
<div class='section' id='section-64'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">img_channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">315</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-65'>
<div class='docs'>
<div class='section-link'>
<a href='#section-65'>#</a>
</div>
<p>Initial convolution and batch normalization </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">321</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-66'>
<div class='docs'>
<div class='section-link'>
<a href='#section-66'>#</a>
</div>
<p>Residual (or bottleneck) blocks </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">323</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">blocks</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-67'>
<div class='docs'>
<div class='section-link'>
<a href='#section-67'>#</a>
</div>
<p>Change <code class="highlight"><span></span><span class="n">x</span></code>
from shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">channels</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span><span class="p">]</span></code>
to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">channels</span><span class="p">,</span> <span class="n">h</span> <span class="o">*</span> <span class="n">w</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">325</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-68'>
<div class='docs'>
<div class='section-link'>
<a href='#section-68'>#</a>
</div>
<p>Global average pooling </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">327</span> <span class="k">return</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
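<p>The reshape-and-mean above is global average pooling. As a sketch, assuming a final feature map of shape <code>[2, 512, 7, 7]</code>, it can be checked against <code>torch.nn.functional.adaptive_avg_pool2d</code> with an output size of 1.</p>

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 512, 7, 7)
# The reshape-and-mean used in forward above
pooled = x.view(x.shape[0], x.shape[1], -1).mean(dim=-1)
# Built-in equivalent: adaptive average pooling to 1x1, then flatten
reference = F.adaptive_avg_pool2d(x, 1).flatten(1)
print(pooled.shape)                                     # torch.Size([2, 512])
print(torch.allclose(pooled, reference, atol=1e-6))     # True
```

<p><code>x.mean(dim=(2, 3))</code> is another equivalent form that skips the reshape.</p>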
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../interactive.js?v=1"></script>
</body>
</html>