Varuna Jayasiri 1c14551a19 zh
2023-02-28 08:40:22 +05:30

<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A PyTorch implementation/tutorial of the paper “Patches Are All You Need?”"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Patches Are All You Need? (ConvMixer)"/>
<meta name="twitter:description" content="A PyTorch implementation/tutorial of the paper “Patches Are All You Need?”"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/conv_mixer/index.html"/>
<meta property="og:title" content="Patches Are All You Need? (ConvMixer)"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Patches Are All You Need? (ConvMixer)"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Patches Are All You Need? (ConvMixer)"/>
<meta property="og:description" content="A PyTorch implementation/tutorial of the paper “Patches Are All You Need?”"/>
<title>Patches Are All You Need? (ConvMixer)</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/conv_mixer/index.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">conv_mixer</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/conv_mixer/__init__.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Patches Are All You Need? (ConvMixer)</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the paper <a href="https://papers.labml.ai/paper/2201.09792">Patches Are All You Need?</a></p>
<p><img alt="ConvMixer diagram from the paper" src="conv_mixer.png"></p>
<p>ConvMixer is similar to <a href="../transformers/mlp_mixer/index.html">MLP-Mixer</a>. MLP-Mixer separates the mixing of spatial and channel dimensions by applying an MLP across the spatial dimension and then an MLP across the channel dimension (the spatial MLP replaces the <a href="../transformers/vit/index.html">ViT</a> attention, and the channel MLP is the <a href="../transformers/feed_forward.html">FFN</a> of ViT).</p>
<p>ConvMixer uses a <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqa" style=""><span class="mord" style="">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">1</span></span></span></span></span></span> convolution for channel mixing and a depth-wise convolution for spatial mixing. Since it is a convolution instead of a full MLP across the space, it mixes only nearby patches, in contrast to ViT or MLP-Mixer. Also, MLP-Mixer uses a two-layer MLP for each mixing step, while ConvMixer uses a single layer for each mixing step.</p>
<p>The paper recommends removing the residual connection across the channel mixing (point-wise convolution) and keeping only a residual connection over the spatial mixing (depth-wise convolution). They also use <a href="../normalization/batch_norm/index.html">batch normalization</a> instead of <a href="../normalization/layer_norm/index.html">layer normalization</a>.</p>
<p><a href="experiment.html">Here is an experiment that trains ConvMixer on CIFAR-10</a>.</p>
<p><a href="https://app.labml.ai/run/0fc344da2cd011ecb0bc3fdb2e774a3d"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">38</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">39</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">40</span>
<span class="lineno">41</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">42</span><span class="kn">from</span> <span class="nn">labml_nn.utils</span> <span class="kn">import</span> <span class="n">clone_module_list</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<p><a id="ConvMixerLayer"></a></p>
<h2>ConvMixer layer</h2>
<p>This is a single ConvMixer layer. The model will have a series of these.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">45</span><span class="k">class</span> <span class="nc">ConvMixerLayer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
 is the number of channels in the patch embeddings, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqd" style=""><span class="mord mathnormal" style="">h</span></span></span></span></span></span></li>
<li><code class="highlight"><span></span><span class="n">kernel_size</span></code>
 is the size of the kernel of the spatial convolution, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.03148em;">k</span></span></span></span></span></li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">54</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">59</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
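<p>A depth-wise convolution is a separate convolution for each channel. We do this with a convolution layer whose number of groups equals the number of channels, so that each channel is its own group.</p>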
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">63</span> <span class="bp">self</span><span class="o">.</span><span class="n">depth_wise_conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span>
<span class="lineno">64</span> <span class="n">kernel_size</span><span class="o">=</span><span class="n">kernel_size</span><span class="p">,</span>
<span class="lineno">65</span> <span class="n">groups</span><span class="o">=</span><span class="n">d_model</span><span class="p">,</span>
<span class="lineno">66</span> <span class="n">padding</span><span class="o">=</span><span class="p">(</span><span class="n">kernel_size</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
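<div class='section'>
<div class='docs'>
<p>As a rough illustration of the grouping described above, the sketch below builds the same kind of layer and inspects it. The values <code>d_model = 8</code> and <code>kernel_size = 5</code> and the input size are arbitrary assumptions chosen for the example, not settings from the paper. It checks that each channel gets its own single-channel filter, that the padding keeps the spatial size for an odd kernel, and that perturbing one input channel leaves the other output channels unchanged.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
from torch import nn

# Illustrative values only (assumptions, not taken from the paper)
d_model, kernel_size = 8, 5

# Depth-wise convolution: one group per channel
depth_wise_conv = nn.Conv2d(d_model, d_model,
                            kernel_size=kernel_size,
                            groups=d_model,
                            padding=(kernel_size - 1) // 2)

x = torch.randn(2, d_model, 16, 16)
y = depth_wise_conv(x)

# Each output channel has a filter over exactly one input channel
weight_shape = tuple(depth_wise_conv.weight.shape)
# With an odd kernel size, this padding preserves height and width
output_shape = tuple(y.shape)

# Perturb only input channel 0 and compare the remaining output channels
x2 = x.clone()
x2[:, 0] += 1.0
y2 = depth_wise_conv(x2)
other_channels_unchanged = torch.allclose(y[:, 1:], y2[:, 1:])
channel_zero_changed = not torch.allclose(y[:, 0], y2[:, 0])
```
</pre></div>
</div>
</div>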
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Activation after the depth-wise convolution</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="bp">self</span><span class="o">.</span><span class="n">act1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Normalization after the depth-wise convolution</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="n">d_model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>The point-wise convolution is a <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord coloredeq eqa" style=""><span class="mord" style="">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin" style="">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord" style="">1</span></span></span></span></span></span> convolution, i.e. a linear transformation of the patch embeddings</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">74</span> <span class="bp">self</span><span class="o">.</span><span class="n">point_wise_conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Activation after the point-wise convolution</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="bp">self</span><span class="o">.</span><span class="n">act2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Normalization after the point-wise convolution</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">78</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="n">d_model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">80</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Keep the input for the residual connection around the depth-wise convolution</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span> <span class="n">residual</span> <span class="o">=</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Depth-wise convolution, activation and normalization</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">85</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">depth_wise_conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">86</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">87</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Add the residual connection</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">90</span> <span class="n">x</span> <span class="o">+=</span> <span class="n">residual</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Point-wise convolution, activation and normalization</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">93</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">point_wise_conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">94</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">95</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">98</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
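<div class='section'>
<div class='docs'>
<p>To try the layer in isolation, here is a minimal self-contained sketch. It assumes plain <code>nn.Module</code> in place of the <code>Module</code> helper imported above, and the class name <code>ConvMixerLayerSketch</code> and the sizes used are hypothetical. It follows the same order of operations (convolution, activation, normalization, with a residual only around the depth-wise part) and confirms that the layer keeps the input shape.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
from torch import nn


class ConvMixerLayerSketch(nn.Module):
    """Hypothetical stand-in for the layer above, built on plain nn.Module."""

    def __init__(self, d_model: int, kernel_size: int):
        super().__init__()
        # Spatial mixing: one convolution per channel
        self.depth_wise_conv = nn.Conv2d(d_model, d_model,
                                         kernel_size=kernel_size,
                                         groups=d_model,
                                         padding=(kernel_size - 1) // 2)
        self.act1 = nn.GELU()
        self.norm1 = nn.BatchNorm2d(d_model)
        # Channel mixing: 1x1 convolution
        self.point_wise_conv = nn.Conv2d(d_model, d_model, kernel_size=1)
        self.act2 = nn.GELU()
        self.norm2 = nn.BatchNorm2d(d_model)

    def forward(self, x: torch.Tensor):
        # Residual connection only around the depth-wise (spatial) mixing
        x = x + self.norm1(self.act1(self.depth_wise_conv(x)))
        # No residual connection around the point-wise (channel) mixing
        return self.norm2(self.act2(self.point_wise_conv(x)))


# Arbitrary example sizes (assumptions)
layer = ConvMixerLayerSketch(d_model=16, kernel_size=7)
x = torch.randn(4, 16, 8, 8)
out_shape = tuple(layer(x).shape)
```
</pre></div>
</div>
</div>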
<div class='section' id='section-16'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p><a id="PatchEmbeddings"></a></p>
<h2>Get patch embeddings</h2>
<p>This splits the image into patches of size <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.7777700000000001em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqf" style=""><span class="mord mathnormal" style="">p</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqf" style=""><span class="mord mathnormal" style="">p</span></span></span></span></span></span> and gives an embedding for each patch.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">101</span><span class="k">class</span> <span class="nc">PatchEmbeddings</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
 is the number of channels in the patch embeddings, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqd" style=""><span class="mord mathnormal" style="">h</span></span></span></span></span></span></li>
<li><code class="highlight"><span></span><span class="n">patch_size</span></code>
 is the size of the patch, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqf" style=""><span class="mord mathnormal" style="">p</span></span></span></span></span></span></li>
<li><code class="highlight"><span></span><span class="n">in_channels</span></code>
 is the number of channels in the input image (3 for RGB)</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">patch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
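<p>We create a convolution layer with a kernel size and stride equal to the patch size. This is equivalent to splitting the image into patches and applying a linear transformation to each patch.</p>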
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">121</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="n">patch_size</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="n">patch_size</span><span class="p">)</span></pre></div>
</div>
</div>
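<div class='section'>
<div class='docs'>
<p>The equivalence claimed above can be checked numerically. The sketch below is only an illustration under assumed sizes (3 input channels, <code>d_model = 16</code>, <code>patch_size = 4</code>, 32×32 images). It extracts non-overlapping patches with <code>torch.nn.functional.unfold</code>, applies the convolution's own weights as a linear map on each flattened patch, and compares the result with the strided convolution.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
from torch import nn
import torch.nn.functional as F

# Illustrative sizes (assumptions)
in_channels, d_model, patch_size = 3, 16, 4
conv = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)

x = torch.randn(2, in_channels, 32, 32)

with torch.no_grad():
    # Strided convolution, shape [2, 16, 8, 8]
    conv_out = conv(x)

    # Split into non-overlapping patches and flatten each one,
    # shape [2, in_channels * patch_size * patch_size, 64]
    patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)

    # Use the convolution weights as a linear transformation per patch
    weight = conv.weight.view(d_model, -1)
    linear_out = patches.transpose(1, 2) @ weight.t() + conv.bias  # [2, 64, 16]

    # Put the patches back on an 8 x 8 grid with channels first
    linear_out = linear_out.transpose(1, 2).reshape(2, d_model, 8, 8)

same = torch.allclose(conv_out, linear_out, atol=1e-4)
```
</pre></div>
</div>
</div>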
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Activation function</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">123</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Batch normalization</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">125</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="n">d_model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
 is the input image of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">127</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Apply the convolution layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">132</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Activation and normalization</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">134</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">135</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">138</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p><a id="ClassificationHead"></a></p>
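<h2>Classification head</h2>
<p>This does average pooling (taking the mean of all patch embeddings) followed by a final linear transformation to predict the logits (unnormalized log-probabilities) of the image classes.</p>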
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">141</span><span class="k">class</span> <span class="nc">ClassificationHead</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
 is the number of channels in the patch embeddings, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqd" style=""><span class="mord mathnormal" style="">h</span></span></span></span></span></span></li>
<li><code class="highlight"><span></span><span class="n">n_classes</span></code>
 is the number of classes in the classification task</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">151</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Average pooling layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">158</span> <span class="bp">self</span><span class="o">.</span><span class="n">pool</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">AdaptiveAvgPool2d</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Linear layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">160</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_classes</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Average pooling</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">164</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pool</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Get the embedding; after pooling, <code class="highlight"><span></span><span class="n">x</span></code>
 has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">166</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Apply the linear layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">171</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
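<div class='section'>
<div class='docs'>
<p>As a quick sanity check of the pooling step in the classification head, the following sketch uses assumed sizes (16 channels, an 8×8 grid of patch embeddings, 10 classes). It compares adaptive average pooling to a 1×1 output against a plain mean over the two spatial dimensions, and shows the shape the linear layer then produces.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
from torch import nn

# Illustrative sizes (assumptions)
d_model, n_classes = 16, 10
x = torch.randn(2, d_model, 8, 8)

# Pool to [batch_size, d_model, 1, 1], then drop the two spatial dimensions
pooled = nn.AdaptiveAvgPool2d((1, 1))(x)[:, :, 0, 0]

# The same quantity computed directly as a mean over height and width
mean = x.mean(dim=(2, 3))
same = torch.allclose(pooled, mean, atol=1e-6)

# The final linear layer maps each embedding to one score per class
logits = nn.Linear(d_model, n_classes)(pooled)
pooled_shape = tuple(pooled.shape)
logits_shape = tuple(logits.shape)
```
</pre></div>
</div>
</div>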
<div class='section' id='section-36'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<h2>ConvMixer</h2>
<p>This combines the patch embeddings block, a number of ConvMixer layers and a classification head.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">174</span><span class="k">class</span> <span class="nc">ConvMixer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">conv_mixer_layer</span></code>
 is a copy of a single <a href="#ConvMixerLayer">ConvMixer layer</a>. We make copies of it to build a ConvMixer with <code class="highlight"><span></span><span class="n">n_layers</span></code> layers
</li>
<li><code class="highlight"><span></span><span class="n">n_layers</span></code>
 is the number of ConvMixer layers (or depth), <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathnormal">d</span></span></span></span></span></li>
<li><code class="highlight"><span></span><span class="n">patch_emb</span></code>
 is the <a href="#PatchEmbeddings">patch embeddings layer</a></li>
<li><code class="highlight"><span></span><span class="n">classification</span></code>
 is the <a href="#ClassificationHead">classification head</a></li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">181</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">conv_mixer_layer</span><span class="p">:</span> <span class="n">ConvMixerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">182</span> <span class="n">patch_emb</span><span class="p">:</span> <span class="n">PatchEmbeddings</span><span class="p">,</span>
<span class="lineno">183</span> <span class="n">classification</span><span class="p">:</span> <span class="n">ClassificationHead</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>Patch embeddings</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">193</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_emb</span> <span class="o">=</span> <span class="n">patch_emb</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Classification head</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">195</span> <span class="bp">self</span><span class="o">.</span><span class="n">classification</span> <span class="o">=</span> <span class="n">classification</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Make copies of the <a href="#ConvMixerLayer">ConvMixer layer</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_mixer_layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">conv_mixer_layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
 is the input image of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">199</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>Get the patch embeddings. This gives a tensor of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">height</span> <span class="o">/</span> <span class="n">patch_size</span><span class="p">,</span> <span class="n">width</span> <span class="o">/</span> <span class="n">patch_size</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">204</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_emb</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
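<div class='section'>
<div class='docs'>
<p>The shape change above can be checked with a small sketch. It assumes the patch embedding is a convolution whose kernel size and stride both equal <code>patch_size</code>, as described in the ConvMixer paper. The sizes and the name <code>patch_conv</code> are illustrative assumptions, and any activation or normalization applied after the convolution is left out because it does not change the shape.</p>
</div>
<div class='code'>

```python
import torch
from torch import nn

# Illustrative sizes (assumptions, not this tutorial's configuration)
batch_size, channels, height, width = 2, 3, 32, 32
d_model, patch_size = 64, 4

# A convolution with kernel size equal to its stride embeds non-overlapping patches
patch_conv = nn.Conv2d(channels, d_model, kernel_size=patch_size, stride=patch_size)

x = torch.randn(batch_size, channels, height, width)
emb = patch_conv(x)

# [batch_size, d_model, height / patch_size, width / patch_size]
emb_shape = tuple(emb.shape)
```

</div>
</div>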
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Pass through the <a href="#ConvMixerLayer">ConvMixer layers</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">207</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_mixer_layers</span><span class="p">:</span>
<span class="lineno">208</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
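<div class='section'>
<div class='docs'>
<p>The loop above keeps the tensor shape unchanged from layer to layer. The sketch below illustrates this with a hypothetical <code>SketchMixerLayer</code> that follows the layer structure described in the ConvMixer paper: a depth-wise convolution with a residual connection, followed by a point-wise convolution. It is an assumption-laden stand-in and may differ from the <code>ConvMixerLayer</code> defined in this implementation.</p>
</div>
<div class='code'>

```python
import torch
from torch import nn


class SketchMixerLayer(nn.Module):
    """Hypothetical layer: depth-wise conv with a residual, then a point-wise conv."""

    def __init__(self, d_model: int, kernel_size: int):
        super().__init__()
        # Depth-wise convolution mixes spatial positions within each channel;
        # the padding keeps height and width unchanged for odd kernel sizes
        self.depth_wise = nn.Conv2d(d_model, d_model, kernel_size=kernel_size,
                                    groups=d_model, padding=(kernel_size - 1) // 2)
        self.act1 = nn.GELU()
        self.norm1 = nn.BatchNorm2d(d_model)
        # Point-wise (1x1) convolution mixes channels at each position
        self.point_wise = nn.Conv2d(d_model, d_model, kernel_size=1)
        self.act2 = nn.GELU()
        self.norm2 = nn.BatchNorm2d(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection around the depth-wise branch
        x = x + self.norm1(self.act1(self.depth_wise(x)))
        return self.norm2(self.act2(self.point_wise(x)))


# Three stacked layers applied in sequence, as in the loop above
layers = nn.ModuleList([SketchMixerLayer(d_model=16, kernel_size=7) for _ in range(3)])
x = torch.randn(2, 16, 8, 8)
for layer in layers:
    x = layer(x)
out_shape = tuple(x.shape)
```

</div>
</div>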
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Classification head, to get the logits</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">211</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">classification</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
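<div class='section'>
<div class='docs'>
<p>A classification head of this kind can be sketched as global average pooling followed by a linear layer. This assumes the head pools over the spatial grid before projecting to class scores, as in the ConvMixer paper. The names <code>pool</code> and <code>linear</code> and all sizes are illustrative assumptions rather than this implementation's <code>ClassificationHead</code>.</p>
</div>
<div class='code'>

```python
import torch
from torch import nn

# Illustrative sizes (assumptions)
batch_size, d_model, n_classes = 2, 16, 10

# Average over the spatial grid, then project the pooled features to class scores
pool = nn.AdaptiveAvgPool2d((1, 1))
linear = nn.Linear(d_model, n_classes)

x = torch.randn(batch_size, d_model, 8, 8)
# Pooling gives [batch_size, d_model, 1, 1]; flatten to [batch_size, d_model]
pooled = pool(x).flatten(1)
# Logits have shape [batch_size, n_classes]
logits = linear(pooled)
logits_shape = tuple(logits.shape)
```

</div>
</div>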
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">214</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../interactive.js?v=1"></script>
</body>
</html>