<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A PyTorch implementation/tutorial of the paper &quot;An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale&quot;"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Vision Transformer (ViT)"/>
<meta name="twitter:description" content="A PyTorch implementation/tutorial of the paper &quot;An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale&quot;"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/transformers/vit/index.html"/>
<meta property="og:title" content="Vision Transformer (ViT)"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Vision Transformer (ViT)"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Vision Transformer (ViT)"/>
<meta property="og:description" content="A PyTorch implementation/tutorial of the paper &quot;An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale&quot;"/>
<title>Vision Transformer (ViT)</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/transformers/vit/index.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">transformers</a>
<a class="parent" href="index.html">vit</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/vit/__init__.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Vision Transformer (ViT)</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the paper <a href="https://arxiv.org/abs/2010.11929">An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale</a>.</p>
<p>The vision transformer applies a pure transformer to images, without any convolution layers. It splits the image into patches and applies a transformer on the patch embeddings. <a href="#PatchEmbeddings">Patch embeddings</a> are generated by applying a simple linear transformation to the flattened pixel values of each patch. A standard transformer encoder is then fed the patch embeddings, along with a classification token <code class="highlight"><span></span><span class="p">[</span><span class="n">CLS</span><span class="p">]</span></code>. The encoding of the <code class="highlight"><span></span><span class="p">[</span><span class="n">CLS</span><span class="p">]</span></code> token is used to classify the image with an MLP.</p>
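<p>As a concrete example, in the standard ViT-B/16 configuration from the paper a 224&times;224 RGB image is split into 16&times;16 patches:</p>
<div class="highlight"><pre>n_patches = (224 // 16) ** 2  # 14 * 14 = 196 patches per image
patch_dim = 16 * 16 * 3       # 768 flattened pixel values per patch</pre></div>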
<p>Learned positional embeddings are added to the patch embeddings before they are fed to the transformer, because the patch embeddings alone carry no information about where each patch came from. The positional embeddings are a set of vectors, one per patch location, trained with gradient descent along with the other parameters.</p>
<p>ViTs perform well when pre-trained on large datasets. The paper suggests pre-training them with an MLP classification head and then using a single linear layer when fine-tuning. The paper beats SOTA with a ViT pre-trained on a 300-million-image dataset. They also use higher resolution images during inference, while keeping the patch size the same. The positional embeddings for the new patch locations are calculated by interpolating the learned positional embeddings.</p>
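<p>The sketch below shows one way to do that interpolation, assuming the <code>[patches, batch_size, d_model]</code> layout used in this implementation; the grid sizes and <code>bicubic</code> mode are illustrative assumptions, not part of this module:</p>
<div class="highlight"><pre>import torch
import torch.nn.functional as F

# Positional embeddings trained for an 8x8 grid of patches (d_model = 256)
pe = torch.randn(64, 1, 256)
# View the patch dimension as a 2D grid and resize it to a 12x12 grid
grid = pe.permute(1, 2, 0).reshape(1, 256, 8, 8)
grid = F.interpolate(grid, size=(12, 12), mode='bicubic', align_corners=False)
# Back to [patches, batch, d_model] for the new resolution
pe_new = grid.reshape(1, 256, 144).permute(2, 0, 1)  # [144, 1, 256]</pre></div>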
<p>Here&#x27;s <a href="experiment.html">an experiment</a> that trains a ViT on CIFAR-10. It doesn&#x27;t do very well because it&#x27;s trained on a small dataset, but it&#x27;s a simple experiment that anyone can run to play with ViTs.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">43</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">44</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">45</span>
<span class="lineno">46</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">TransformerLayer</span>
<span class="lineno">47</span><span class="kn">from</span> <span class="nn">labml_nn.utils</span> <span class="kn">import</span> <span class="n">clone_module_list</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<p> <a id="PatchEmbeddings"></a></p>
<h2>Get patch embeddings</h2>
<p>The paper splits the image into patches of equal size and does a linear transformation on the flattened pixels of each patch.</p>
<p>We implement this with a convolution layer because it&#x27;s simpler.</p>
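<p>As a sanity check, here is a standalone sketch (not part of this module) showing that a convolution with kernel size and stride equal to the patch size matches a shared linear layer applied to flattened patches:</p>
<div class="highlight"><pre>import torch
import torch.nn as nn

# A 2x2-patch convolution over a toy 4x4 RGB image
conv = nn.Conv2d(3, 8, kernel_size=2, stride=2)
linear = nn.Linear(3 * 2 * 2, 8)
# Share weights: the [8, 3, 2, 2] conv kernel flattens to an [8, 12] linear weight
with torch.no_grad():
    linear.weight.copy_(conv.weight.view(8, -1))
    linear.bias.copy_(conv.bias)

x = torch.randn(1, 3, 4, 4)
# Extract the four 2x2 patches and flatten each in (channel, row, column) order
patches = x.unfold(2, 2, 2).unfold(3, 2, 2)  # [1, 3, 2, 2, 2, 2]
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, 4, 12)
# The conv output, rearranged to [batch, patches, d_model], matches the linear layer
assert torch.allclose(conv(x).flatten(2).transpose(1, 2), linear(patches), atol=1e-6)</pre></div>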
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">50</span><span class="k">class</span> <span class="nc">PatchEmbeddings</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
is the transformer embeddings size </li>
<li><code class="highlight"><span></span><span class="n">patch_size</span></code>
is the size of the patch </li>
<li><code class="highlight"><span></span><span class="n">in_channels</span></code>
is the number of channels in the input image (3 for rgb)</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">patch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>We create a convolution layer with a kernel size and stride length equal to the patch size. This is equivalent to splitting the image into patches and applying a linear transformation to each patch. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">73</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">patch_size</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="n">patch_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
is the input image of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">75</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Apply convolution layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">80</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Get the shape. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span> <span class="n">bs</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Rearrange to shape <code class="highlight"><span></span><span class="p">[</span><span class="n">patches</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">84</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="lineno">85</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">h</span> <span class="o">*</span> <span class="n">w</span><span class="p">,</span> <span class="n">bs</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Return the patch embeddings </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">88</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p> <a id="LearnedPositionalEmbeddings"></a></p>
<h2>Add parameterized positional encodings</h2>
<p>This adds learned positional embeddings to patch embeddings.</p>
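<p>A minimal usage sketch (illustration only, with assumed sizes): the embedding table is sliced to the number of patches in the input and added element-wise:</p>
<div class="highlight"><pre>import torch

pos_emb = LearnedPositionalEmbeddings(d_model=256)
x = torch.randn(196, 2, 256)  # 196 patches (a 14x14 grid), batch size 2
out = pos_emb(x)              # adds the first 196 rows of the embedding table
assert out.shape == (196, 2, 256)</pre></div>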
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span><span class="k">class</span> <span class="nc">LearnedPositionalEmbeddings</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
is the transformer embeddings size </li>
<li><code class="highlight"><span></span><span class="n">max_len</span></code>
is the maximum number of patches</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">100</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">max_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5_000</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">105</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Positional embeddings for each location </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">107</span> <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">max_len</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">d_model</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
is the patch embeddings of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">patches</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">109</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Get the positional embeddings for the given patches </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">114</span> <span class="n">pe</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Add to patch embeddings and return </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="k">return</span> <span class="n">x</span> <span class="o">+</span> <span class="n">pe</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p> <a id="ClassificationHead"></a></p>
<h2>MLP Classification Head</h2>
<p>This is the two-layer MLP head that classifies the image based on the <code class="highlight"><span></span><span class="p">[</span><span class="n">CLS</span><span class="p">]</span></code> token embedding.</p>
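<p>A minimal usage sketch (illustration only, with assumed sizes):</p>
<div class="highlight"><pre>import torch

head = ClassificationHead(d_model=256, n_hidden=512, n_classes=10)
cls_encoding = torch.randn(2, 256)  # [CLS] encodings for a batch of 2 images
logits = head(cls_encoding)         # [2, 10]</pre></div>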
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">119</span><span class="k">class</span> <span class="nc">ClassificationHead</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
is the transformer embedding size </li>
<li><code class="highlight"><span></span><span class="n">n_hidden</span></code>
is the size of the hidden layer </li>
<li><code class="highlight"><span></span><span class="n">n_classes</span></code>
is the number of classes in the classification task</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">127</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">133</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>First layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">135</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Activation </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">137</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Second layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">139</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_hidden</span><span class="p">,</span> <span class="n">n_classes</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
is the transformer encoding for <code class="highlight"><span></span><span class="p">[</span><span class="n">CLS</span><span class="p">]</span></code>
token</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">141</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>First layer and activation </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">146</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">linear1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Second layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">148</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">151</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<h2>Vision Transformer</h2>
<p>This combines the <a href="#PatchEmbeddings">patch embeddings</a>, <a href="#LearnedPositionalEmbeddings">positional embeddings</a>, the transformer layers, and the <a href="#ClassificationHead">classification head</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">154</span><span class="k">class</span> <span class="nc">VisionTransformer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">transformer_layer</span></code>
is a copy of a single <a href="../models.html#TransformerLayer">transformer layer</a>. We make copies of it to make the transformer with <code class="highlight"><span></span><span class="n">n_layers</span></code>
. </li>
<li><code class="highlight"><span></span><span class="n">n_layers</span></code>
is the number of <a href="../models.html#TransformerLayer">transformer layers</a>. </li>
<li><code class="highlight"><span></span><span class="n">patch_emb</span></code>
is the <a href="#PatchEmbeddings">patch embeddings layer</a>. </li>
<li><code class="highlight"><span></span><span class="n">pos_emb</span></code>
is the <a href="#LearnedPositionalEmbeddings">positional embeddings layer</a>. </li>
<li><code class="highlight"><span></span><span class="n">classification</span></code>
is the <a href="#ClassificationHead">classification head</a>.</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">transformer_layer</span><span class="p">:</span> <span class="n">TransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">163</span> <span class="n">patch_emb</span><span class="p">:</span> <span class="n">PatchEmbeddings</span><span class="p">,</span> <span class="n">pos_emb</span><span class="p">:</span> <span class="n">LearnedPositionalEmbeddings</span><span class="p">,</span>
<span class="lineno">164</span> <span class="n">classification</span><span class="p">:</span> <span class="n">ClassificationHead</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">173</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Patch embeddings </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">175</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_emb</span> <span class="o">=</span> <span class="n">patch_emb</span>
<span class="lineno">176</span> <span class="bp">self</span><span class="o">.</span><span class="n">pos_emb</span> <span class="o">=</span> <span class="n">pos_emb</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Classification head </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">178</span> <span class="bp">self</span><span class="o">.</span><span class="n">classification</span> <span class="o">=</span> <span class="n">classification</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Make copies of the transformer layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">180</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer_layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">transformer_layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p><code class="highlight"><span></span><span class="p">[</span><span class="n">CLS</span><span class="p">]</span></code>
token embedding </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">183</span> <span class="bp">self</span><span class="o">.</span><span class="n">cls_token_emb</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">transformer_layer</span><span class="o">.</span><span class="n">size</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Final normalization layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">185</span> <span class="bp">self</span><span class="o">.</span><span class="n">ln</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">transformer_layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
is the input image of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">channels</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">]</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">187</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Get patch embeddings. This gives a tensor of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">patches</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">192</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_emb</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>Concatenate the <code class="highlight"><span></span><span class="p">[</span><span class="n">CLS</span><span class="p">]</span></code>
token embeddings before feeding the transformer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">194</span> <span class="n">cls_token_emb</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cls_token_emb</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">195</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">cls_token_emb</span><span class="p">,</span> <span class="n">x</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Add positional embeddings </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pos_emb</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>Pass through transformer layers with no attention masking </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">200</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer_layers</span><span class="p">:</span>
<span class="lineno">201</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Get the transformer output of the <code class="highlight"><span></span><span class="p">[</span><span class="n">CLS</span><span class="p">]</span></code>
token (which is the first in the sequence). </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">204</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Layer normalization </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">207</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ln</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Classification head, to get logits </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">210</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">classification</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">213</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
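<div class='section'>
<div class='docs'>
<p>A hedged end-to-end sketch for CIFAR-10-sized inputs. The <code>TransformerLayer</code>, <code>MultiHeadAttention</code>, and <code>FeedForward</code> constructor arguments below are assumptions based on <code>labml_nn.transformers</code>; check those modules for the exact signatures.</p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch
from labml_nn.transformers import TransformerLayer, MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward

d_model = 256
# One layer; VisionTransformer clones it n_layers times
layer = TransformerLayer(d_model=d_model,
                         self_attn=MultiHeadAttention(8, d_model),
                         src_attn=None,
                         feed_forward=FeedForward(d_model, 1024),
                         dropout_prob=0.1)
vit = VisionTransformer(layer, n_layers=6,
                        patch_emb=PatchEmbeddings(d_model, patch_size=4, in_channels=3),
                        pos_emb=LearnedPositionalEmbeddings(d_model),
                        classification=ClassificationHead(d_model, n_hidden=512, n_classes=10))

images = torch.randn(2, 3, 32, 32)  # a batch of two 32x32 RGB images
logits = vit(images)                # [2, 10] class logits</pre></div>
</div>
</div>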
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>