Varuna Jayasiri c4d2e8cd22 docs
2025-07-31 08:48:07 +05:30
<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A PyTorch implementation/tutorial of the paper &quot;Patches Are All You Need?&quot;"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Patches Are All You Need? (ConvMixer)"/>
<meta name="twitter:description" content="A PyTorch implementation/tutorial of the paper &quot;Patches Are All You Need?&quot;"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/conv_mixer/index.html"/>
<meta property="og:title" content="Patches Are All You Need? (ConvMixer)"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Patches Are All You Need? (ConvMixer)"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="A PyTorch implementation/tutorial of the paper &quot;Patches Are All You Need?&quot;"/>
<title>Patches Are All You Need? (ConvMixer)</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/conv_mixer/index.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">conv_mixer</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/conv_mixer/__init__.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Patches Are All You Need? (ConvMixer)</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the paper <a href="https://arxiv.org/abs/2201.09792">Patches Are All You Need?</a>.</p>
<p><img alt="ConvMixer diagram from the paper" src="conv_mixer.png"></p>
<p>ConvMixer is similar to <a href="../transformers/mlp_mixer/index.html">MLP-Mixer</a>. MLP-Mixer separates the mixing of spatial and channel dimensions by applying an MLP across the spatial dimension and then an MLP across the channel dimension (the spatial MLP replaces the <a href="../transformers/vit/index.html">ViT</a> attention and the channel MLP is the <a href="../transformers/feed_forward.html">FFN</a> of ViT).</p>
<p>ConvMixer uses a <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqa" style=""><span class="mord" style="">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">1</span></span></span></span></span></span> convolution for channel mixing and a depth-wise convolution for spatial mixing. Since it&#x27;s a convolution instead of a full MLP across the space, it mixes only nearby patches, in contrast to ViT or MLP-Mixer, which mix all patches. Also, MLP-Mixer uses a two-layer MLP for each mixing step, while ConvMixer uses a single layer for each.</p>
<p>The paper recommends removing the residual connection across the channel mixing (point-wise convolution) and having only a residual connection over the spatial mixing (depth-wise convolution). They also use <a href="../normalization/batch_norm/index.html">Batch normalization</a> instead of <a href="../normalization/layer_norm/index.html">Layer normalization</a>.</p>
<p>Here&#x27;s <a href="experiment.html">an experiment</a> that trains ConvMixer on CIFAR-10.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">36</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">37</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">38</span>
<span class="lineno">39</span><span class="kn">from</span> <span class="nn">labml_nn.utils</span> <span class="kn">import</span> <span class="n">clone_module_list</span></pre></div>
</div>
</div>
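<div class='section'>
<div class='docs'>
<p>The difference between spatial mixing and channel mixing can be checked numerically. The following is a minimal sketch, not part of the original implementation: the sizes and the perturbed position are hypothetical values chosen for illustration. It changes a single input value and records which outputs react. The depth-wise convolution should respond only inside the kernel neighbourhood of the same channel, while the 1×1 convolution should respond only at the same position, but in every channel.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical sizes chosen only for this illustration
d_model, kernel_size, size = 4, 3, 8

# Depth-wise convolution: one filter per channel (spatial mixing)
depth_wise = nn.Conv2d(d_model, d_model, kernel_size=kernel_size,
                       groups=d_model, padding=(kernel_size - 1) // 2)
# Point-wise convolution: 1x1 kernel (channel mixing)
point_wise = nn.Conv2d(d_model, d_model, kernel_size=1)

x = torch.randn(1, d_model, size, size)
# Perturb a single value: channel 0 at position (4, 4)
x_perturbed = x.clone()
x_perturbed[0, 0, 4, 4] += 1.0

with torch.no_grad():
    # Boolean masks of the outputs that changed, shape [1, d_model, size, size]
    dw_changed = (depth_wise(x_perturbed) - depth_wise(x)).abs().gt(1e-6)
    pw_changed = (point_wise(x_perturbed) - point_wise(x)).abs().gt(1e-6)

# Which channels changed after the depth-wise convolution
dw_channels = dw_changed.any(dim=3).any(dim=2)[0].tolist()
# Which positions of channel 0 changed after the depth-wise convolution
dw_positions = dw_changed[0, 0].nonzero().tolist()
# Which positions changed (in any channel) after the point-wise convolution
pw_positions = pw_changed[0].any(dim=0).nonzero().tolist()

print(dw_channels, dw_positions, pw_positions)
```
</pre></div>
</div>
</div>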
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<p> <a id="ConvMixerLayer"></a></p>
<h2>ConvMixer layer</h2>
<p>This is a single ConvMixer layer. The model will have a series of these.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">42</span><span class="k">class</span> <span class="nc">ConvMixerLayer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
is the number of channels in patch embeddings, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqd" style=""><span class="mord mathnormal" style="">h</span></span></span></span></span></span> </li>
<li><code class="highlight"><span></span><span class="n">kernel_size</span></code>
is the size of the kernel of spatial convolution, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.03148em;">k</span></span></span></span></span></li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">51</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Depth-wise convolution is a separate convolution for each channel. We do this with a convolution layer whose number of groups equals the number of channels, so that each channel is its own group. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">60</span> <span class="bp">self</span><span class="o">.</span><span class="n">depth_wise_conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span>
<span class="lineno">61</span> <span class="n">kernel_size</span><span class="o">=</span><span class="n">kernel_size</span><span class="p">,</span>
<span class="lineno">62</span> <span class="n">groups</span><span class="o">=</span><span class="n">d_model</span><span class="p">,</span>
<span class="lineno">63</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="n">kernel_size</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Activation after depth-wise convolution </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span> <span class="bp">self</span><span class="o">.</span><span class="n">act1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Normalization after depth-wise convolution </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">67</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="n">d_model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Point-wise convolution is a <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqa" style=""><span class="mord" style="">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">1</span></span></span></span></span></span> convolution, i.e. the same linear transformation applied to each patch embedding. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">71</span> <span class="bp">self</span><span class="o">.</span><span class="n">point_wise_conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Activation after point-wise convolution </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">73</span> <span class="bp">self</span><span class="o">.</span><span class="n">act2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Normalization after point-wise convolution </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">75</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="n">d_model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">77</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>For the residual connection around the depth-wise convolution </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span> <span class="n">residual</span> <span class="o">=</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Depth-wise convolution, activation and normalization </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">depth_wise_conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">83</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">84</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Add residual connection </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">87</span> <span class="n">x</span> <span class="o">+=</span> <span class="n">residual</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Point-wise convolution, activation and normalization </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">90</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">point_wise_conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">91</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">92</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">95</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
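<div class='section'>
<div class='docs'>
<p>Two properties of the layer above can be verified with a short, self-contained sketch. It is an illustration under assumed sizes, not part of the original implementation. First, with the number of groups equal to the number of channels, each output channel owns a single kernel, and the padding keeps the height and width unchanged for an odd kernel size. An even kernel size would shrink the output by one, and the residual addition would no longer match. Second, the 1×1 convolution gives the same result as a linear layer with the same weights applied to every patch embedding.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical sizes chosen only for this illustration
d_model, kernel_size = 8, 5

depth_wise = nn.Conv2d(d_model, d_model, kernel_size=kernel_size,
                       groups=d_model, padding=(kernel_size - 1) // 2)
point_wise = nn.Conv2d(d_model, d_model, kernel_size=1)

# With groups equal to channels, each output channel has one k x k filter
depth_wise_weight_shape = tuple(depth_wise.weight.shape)

x = torch.randn(2, d_model, 6, 6)
with torch.no_grad():
    # Odd kernel sizes keep height and width, so the residual can be added
    same_shape = depth_wise(x).shape == x.shape

    # A linear layer that shares the weights of the 1x1 convolution
    linear = nn.Linear(d_model, d_model)
    linear.weight.copy_(point_wise.weight[:, :, 0, 0])
    linear.bias.copy_(point_wise.bias)

    conv_out = point_wise(x)
    # Move channels last, apply the linear layer to every patch, move back
    linear_out = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    max_diff = (conv_out - linear_out).abs().max().item()

print(depth_wise_weight_shape, same_shape, max_diff)
```
</pre></div>
</div>
</div>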
<div class='section' id='section-16'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p> <a id="PatchEmbeddings"></a></p>
<h2>Get patch embeddings</h2>
<p>This splits the image into patches of size <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.7777700000000001em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqf" style=""><span class="mord mathnormal" style="">p</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqf" style=""><span class="mord mathnormal" style="">p</span></span></span></span></span></span> and gives an embedding for each patch.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">98</span><span class="k">class</span> <span class="nc">PatchEmbeddings</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
is the number of channels in patch embeddings <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqd" style=""><span class="mord mathnormal" style="">h</span></span></span></span></span></span> </li>
<li><code class="highlight"><span></span><span class="n">patch_size</span></code>
is the size of the patch, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqf" style=""><span class="mord mathnormal" style="">p</span></span></span></span></span></span> </li>
<li><code class="highlight"><span></span><span class="n">in_channels</span></code>
is the number of channels in the input image (3 for RGB)</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">107</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">patch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">113</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>We create a convolution layer with a kernel size and stride length equal to the patch size. This is equivalent to splitting the image into non-overlapping patches and doing a linear transformation on each patch. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">118</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="n">patch_size</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="n">patch_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Activation function </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">120</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Batch normalization </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="n">d_model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
is the input image of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Apply convolution layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">129</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Activation and normalization </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">131</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">132</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">135</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
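<div class='section'>
<div class='docs'>
<p>The claim that a strided convolution equals splitting the image into patches followed by a linear transformation can be checked directly. The following is a minimal sketch with hypothetical sizes, not part of the original implementation. It uses <code>torch.nn.functional.unfold</code> to extract the flattened non-overlapping patches, multiplies them by the convolution kernel reshaped into a matrix, and compares the result with the output of the convolution.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes chosen only for this illustration
d_model, patch_size, in_channels = 16, 2, 3
batch_size, height, width = 2, 8, 8

conv = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)
x = torch.randn(batch_size, in_channels, height, width)

with torch.no_grad():
    conv_out = conv(x)

    # Split the image into non-overlapping patches and flatten each one:
    # shape [batch_size, in_channels * patch_size * patch_size, n_patches]
    patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)
    # The convolution kernel viewed as a matrix acting on flattened patches
    weight = conv.weight.reshape(d_model, -1)
    linear_out = weight @ patches + conv.bias[None, :, None]
    # Put the patch embeddings back on the [height / p, width / p] grid
    linear_out = linear_out.reshape(batch_size, d_model,
                                    height // patch_size, width // patch_size)

    max_diff = (conv_out - linear_out).abs().max().item()

print(tuple(conv_out.shape), max_diff)
```
</pre></div>
</div>
</div>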
<div class='section' id='section-26'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p> <a id="ClassificationHead"></a></p>
<h2>Classification Head</h2>
<p>They do average pooling (taking the mean of all patch embeddings) and a final linear transformation to predict the logits (unnormalized log-probabilities) of the image classes.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">138</span><span class="k">class</span> <span class="nc">ClassificationHead</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
is the number of channels in patch embeddings, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqd" style=""><span class="mord mathnormal" style="">h</span></span></span></span></span></span> </li>
<li><code class="highlight"><span></span><span class="n">n_classes</span></code>
is the number of classes in the classification task</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">148</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">153</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Average Pool </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">155</span> <span class="bp">self</span><span class="o">.</span><span class="n">pool</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">AdaptiveAvgPool2d</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Linear layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">157</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_classes</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">159</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Average pooling </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">161</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pool</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Get the embedding. After pooling, <code class="highlight"><span></span><span class="n">x</span></code>
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span></code>
, so we index away the two singleton spatial dimensions.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">163</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Linear layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">165</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
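<div class='section'>
<div class='docs'>
<p>As a quick check of the head above, the following sketch (with hypothetical sizes, not part of the original implementation) confirms that adaptive average pooling to a 1×1 output followed by indexing is the same as taking the mean over the height and width dimensions. It also shows that the linear layer produces one unnormalized score per class, which a softmax would turn into probabilities.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical sizes chosen only for this illustration
d_model, n_classes = 8, 10

pool = nn.AdaptiveAvgPool2d((1, 1))
linear = nn.Linear(d_model, n_classes)

x = torch.randn(4, d_model, 6, 6)
with torch.no_grad():
    # Pool to [batch_size, d_model, 1, 1] and drop the singleton dimensions
    pooled = pool(x)[:, :, 0, 0]
    # The same result as the mean over all patch positions
    mean = x.mean(dim=(2, 3))
    max_diff = (pooled - mean).abs().max().item()

    # One unnormalized score per class
    logits = linear(pooled)
    # Softmax turns the scores into class probabilities that sum to one
    probs = logits.softmax(dim=-1)

print(tuple(logits.shape), max_diff)
```
</pre></div>
</div>
</div>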
<div class='section' id='section-36'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<h2>ConvMixer</h2>
<p>This combines the patch embeddings block, a number of ConvMixer layers and a classification head.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">171</span><span class="k">class</span> <span class="nc">ConvMixer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">conv_mixer_layer</span></code>
is a single <a href="#ConvMixerLayer">ConvMixer layer</a>. We make copies of it to build a ConvMixer with <code class="highlight"><span></span><span class="n">n_layers</span></code>
layers. </li>
<li><code class="highlight"><span></span><span class="n">n_layers</span></code>
is the number of ConvMixer layers (or depth), <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathnormal">d</span></span></span></span></span>. </li>
<li><code class="highlight"><span></span><span class="n">patch_emb</span></code>
is the <a href="#PatchEmbeddings">patch embeddings layer</a>. </li>
<li><code class="highlight"><span></span><span class="n">classification</span></code>
is the <a href="#ClassificationHead">classification head</a>.</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">178</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">conv_mixer_layer</span><span class="p">:</span> <span class="n">ConvMixerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">179</span> <span class="n">patch_emb</span><span class="p">:</span> <span class="n">PatchEmbeddings</span><span class="p">,</span>
<span class="lineno">180</span> <span class="n">classification</span><span class="p">:</span> <span class="n">ClassificationHead</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">188</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>Patch embeddings </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">190</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_emb</span> <span class="o">=</span> <span class="n">patch_emb</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Classification head </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">192</span> <span class="bp">self</span><span class="o">.</span><span class="n">classification</span> <span class="o">=</span> <span class="n">classification</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Make copies of the <a href="#ConvMixerLayer">ConvMixer layer</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">194</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_mixer_layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">conv_mixer_layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
</div>
</div>
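<div class='section'>
<div class='docs'>
<p>The cloning step above can be sketched as follows. This is a minimal sketch that assumes <code class="highlight"><span></span><span class="n">clone_module_list</span></code> behaves like a list of deep copies wrapped in a module list; <code class="highlight"><span></span><span class="n">clone_layers</span></code> is a hypothetical stand-in, and a plain convolution is used in place of a ConvMixer layer.</p>
</div>
<div class='code'>

```python
import copy

import torch
from torch import nn


def clone_layers(module: nn.Module, n: int):
    # Hypothetical stand-in for clone_module_list (assumption):
    # n independent deep copies registered in an nn.ModuleList
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])


# A small convolution stands in for a ConvMixer layer
layer = nn.Conv2d(8, 8, kernel_size=3, padding=1)
layers = clone_layers(layer, 4)

# The copies start with identical values but do not share parameter storage,
# so each layer is trained independently
same_values = torch.equal(layers[0].weight, layers[1].weight)
shared_storage = layers[0].weight.data_ptr() == layers[1].weight.data_ptr()
```

</div>
<div class='docs'>
<p>Deep copies matter here: reusing the same module object <code class="highlight"><span></span><span class="n">n_layers</span></code> times would tie the weights of every layer together.</p>
</div>
</div>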
<div class='section' id='section-42'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
is the input image of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">196</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>Get patch embeddings. This gives a tensor of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">height</span> <span class="o">/</span> <span class="n">patch_size</span><span class="p">,</span> <span class="n">width</span> <span class="o">/</span> <span class="n">patch_size</span><span class="p">]</span></code>
. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">201</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_emb</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
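<div class='section'>
<div class='docs'>
<p>The shape stated above can be checked with a small sketch. It assumes the patch embedding is a convolution whose kernel size and stride both equal <code class="highlight"><span></span><span class="n">patch_size</span></code>; the sizes below are illustrative and not taken from the original.</p>
</div>
<div class='code'>

```python
import torch
from torch import nn

# Illustrative sizes (assumptions, not from the original)
batch_size, channels, height, width = 2, 3, 32, 32
d_model, patch_size = 64, 4

# Assumed patch embedding: non-overlapping patches via kernel_size == stride
patch_conv = nn.Conv2d(channels, d_model, kernel_size=patch_size, stride=patch_size)

x = torch.randn(batch_size, channels, height, width)
emb = patch_conv(x)

# Each spatial dimension shrinks by a factor of patch_size
emb_shape = tuple(emb.shape)
expected_shape = (batch_size, d_model, height // patch_size, width // patch_size)
```

</div>
<div class='docs'>
<p>The division is exact only when the height and width are multiples of <code class="highlight"><span></span><span class="n">patch_size</span></code>; otherwise a convolution like this one drops the leftover border pixels.</p>
</div>
</div>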
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Pass through <a href="#ConvMixerLayer">ConvMixer layers</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">204</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_mixer_layers</span><span class="p">:</span>
<span class="lineno">205</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Classification head, to get logits </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">208</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">classification</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p>Return the logits </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">211</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
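<div class='section'>
<div class='docs'>
<p>The whole forward pass can be traced end to end with a small self-contained sketch. The classes and sizes below (<code class="highlight"><span></span><span class="n">TinyMixerLayer</span></code>, <code class="highlight"><span></span><span class="n">TinyConvMixer</span></code>, the patch embedding and the head) are simplified, hypothetical stand-ins for the modules on this page; normalization is left out to keep it short.</p>
</div>
<div class='code'>

```python
import copy

import torch
from torch import nn


class TinyMixerLayer(nn.Module):
    # Simplified stand-in for a ConvMixer layer (no normalization)
    def __init__(self, d_model: int, kernel_size: int):
        super().__init__()
        # Depth-wise convolution; padding keeps the spatial size for odd kernels
        self.depth_wise = nn.Conv2d(d_model, d_model, kernel_size,
                                    groups=d_model, padding=(kernel_size - 1) // 2)
        # Point-wise (1x1) convolution mixes channels
        self.point_wise = nn.Conv2d(d_model, d_model, kernel_size=1)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor):
        # Residual connection around the depth-wise convolution
        x = x + self.act(self.depth_wise(x))
        return self.act(self.point_wise(x))


class TinyConvMixer(nn.Module):
    # Same three stages as above: patch embeddings, mixer layers, classification head
    def __init__(self, layer: nn.Module, n_layers: int,
                 patch_emb: nn.Module, classification: nn.Module):
        super().__init__()
        self.patch_emb = patch_emb
        self.classification = classification
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])

    def forward(self, x: torch.Tensor):
        x = self.patch_emb(x)
        for layer in self.layers:
            x = layer(x)
        return self.classification(x)


# Illustrative sizes (assumptions)
d_model, patch_size, n_classes = 32, 4, 10
patch_emb = nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size)
# Assumed head: global average pooling followed by a linear layer
head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(d_model, n_classes))

model = TinyConvMixer(TinyMixerLayer(d_model, kernel_size=7), n_layers=3,
                      patch_emb=patch_emb, classification=head)

images = torch.randn(2, 3, 32, 32)
logits = model(images)
logits_shape = tuple(logits.shape)
```

</div>
<div class='docs'>
<p>With these sizes the feature map stays at 8 by 8 through every mixer layer, and the head reduces it to one vector of class logits per image.</p>
</div>
</div>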
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>