Varuna Jayasiri
2021-07-17 15:24:17 +05:30
parent f0bf8d39e4
commit f038ab673d
12 changed files with 986 additions and 193 deletions

View File

@@ -95,6 +95,7 @@ implementations.</p>
<li><a href="transformers/mlm/index.html">Masked Language Model</a></li>
<li><a href="transformers/mlp_mixer/index.html">MLP-Mixer: An all-MLP Architecture for Vision</a></li>
<li><a href="transformers/gmlp/index.html">Pay Attention to MLPs (gMLP)</a></li>
<li><a href="transformers/vit/index.html">Vision Transformer (ViT)</a></li>
</ul>
<h4><a href="recurrent_highway_networks/index.html">Recurrent Highway Networks</a></h4>
<h4><a href="lstm/index.html">LSTM</a></h4>

View File

@@ -117,12 +117,15 @@ It does single GPU training but we implement the concept of switching as describ
<h2><a href="gmlp/index.html">Pay Attention to MLPs (gMLP)</a></h2>
<p>This is an implementation of the paper
<a href="https://papers.labml.ai/paper/2105.08050">Pay Attention to MLPs</a>.</p>
<h2><a href="vit/index.html">Vision Transformer (ViT)</a></h2>
<p>This is an implementation of the paper
<a href="https://arxiv.org/abs/2010.11929">An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">87</span><span></span><span class="kn">from</span> <span class="nn">.configs</span> <span class="kn">import</span> <span class="n">TransformerConfigs</span>
<span class="lineno">88</span><span class="kn">from</span> <span class="nn">.models</span> <span class="kn">import</span> <span class="n">TransformerLayer</span><span class="p">,</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">Decoder</span><span class="p">,</span> <span class="n">Generator</span><span class="p">,</span> <span class="n">EncoderDecoder</span>
<span class="lineno">89</span><span class="kn">from</span> <span class="nn">.mha</span> <span class="kn">import</span> <span class="n">MultiHeadAttention</span>
<span class="lineno">90</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.xl.relative_mha</span> <span class="kn">import</span> <span class="n">RelativeMultiHeadAttention</span></pre></div>
<div class="highlight"><pre><span class="lineno">92</span><span></span><span class="kn">from</span> <span class="nn">.configs</span> <span class="kn">import</span> <span class="n">TransformerConfigs</span>
<span class="lineno">93</span><span class="kn">from</span> <span class="nn">.models</span> <span class="kn">import</span> <span class="n">TransformerLayer</span><span class="p">,</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">Decoder</span><span class="p">,</span> <span class="n">Generator</span><span class="p">,</span> <span class="n">EncoderDecoder</span>
<span class="lineno">94</span><span class="kn">from</span> <span class="nn">.mha</span> <span class="kn">import</span> <span class="n">MultiHeadAttention</span>
<span class="lineno">95</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.xl.relative_mha</span> <span class="kn">import</span> <span class="n">RelativeMultiHeadAttention</span></pre></div>
</div>
</div>
</div>

View File

@@ -3,24 +3,24 @@
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="Train a ViT on CIFAR 10"/>
<meta name="description" content="Train a Vision Transformer (ViT) on CIFAR 10"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Train a ViT on CIFAR 10"/>
<meta name="twitter:description" content="Train a ViT on CIFAR 10"/>
<meta name="twitter:title" content="Train a Vision Transformer (ViT) on CIFAR 10"/>
<meta name="twitter:description" content="Train a Vision Transformer (ViT) on CIFAR 10"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/transformers/vit/experiment.html"/>
<meta property="og:title" content="Train a ViT on CIFAR 10"/>
<meta property="og:title" content="Train a Vision Transformer (ViT) on CIFAR 10"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Train a ViT on CIFAR 10"/>
<meta property="og:description" content="Train a ViT on CIFAR 10"/>
<meta property="og:title" content="Train a Vision Transformer (ViT) on CIFAR 10"/>
<meta property="og:description" content="Train a Vision Transformer (ViT) on CIFAR 10"/>
<title>Train a ViT on CIFAR 10</title>
<title>Train a Vision Transformer (ViT) on CIFAR 10</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/transformers/vit/experiment.html"/>
@@ -67,13 +67,14 @@
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Train a ViT on CIFAR 10</h1>
<h1>Train a <a href="index.html">Vision Transformer (ViT)</a> on CIFAR 10</h1>
<p><a href="https://app.labml.ai/run/8b531d9ce3dc11eb84fc87df6756eb8f"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">11</span><span></span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
<span class="lineno">12</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
<span class="lineno">13</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.cifar10</span> <span class="kn">import</span> <span class="n">CIFAR10Configs</span>
<span class="lineno">14</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">TransformerConfigs</span></pre></div>
<div class="highlight"><pre><span class="lineno">13</span><span></span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
<span class="lineno">14</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.cifar10</span> <span class="kn">import</span> <span class="n">CIFAR10Configs</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">TransformerConfigs</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
@@ -82,11 +83,11 @@
<a href='#section-1'>#</a>
</div>
<h2>Configurations</h2>
<p>We use <a href="../experiments/cifar10.html"><code>CIFAR10Configs</code></a> which defines all the
<p>We use <a href="../../experiments/cifar10.html"><code>CIFAR10Configs</code></a> which defines all the
dataset-related configurations, the optimizer, and a training loop.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">17</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">CIFAR10Configs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">19</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">CIFAR10Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
@@ -94,31 +95,22 @@ dataset related configurations, optimizer, and a training loop.</p>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<p><a href="../configs.html#TransformerConfigs">Transformer configurations</a>
to get <a href="../models.html#TransformerLayer">transformer layer</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">25</span> <span class="n">transformer</span><span class="p">:</span> <span class="n">TransformerConfigs</span>
<span class="lineno">26</span>
<span class="lineno">27</span> <span class="n">patch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span>
<span class="lineno">28</span> <span class="n">n_hidden</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2048</span>
<span class="lineno">29</span> <span class="n">n_classes</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span></pre></div>
<div class="highlight"><pre><span class="lineno">29</span> <span class="n">transformer</span><span class="p">:</span> <span class="n">TransformerConfigs</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs doc-strings'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<h3>Create model</h3>
<p>Size of a patch</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">32</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">transformer</span><span class="p">)</span>
<span class="lineno">33</span><span class="k">def</span> <span class="nf">_transformer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span>
<span class="lineno">34</span> <span class="k">return</span> <span class="n">TransformerConfigs</span><span class="p">()</span>
<span class="lineno">35</span>
<span class="lineno">36</span>
<span class="lineno">37</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
<span class="lineno">38</span><span class="k">def</span> <span class="nf">_vit</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">32</span> <span class="n">patch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
@@ -126,17 +118,10 @@ dataset related configurations, optimizer, and a training loop.</p>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Size of the hidden layer in classification head</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">42</span> <span class="kn">from</span> <span class="nn">labml_nn.transformers.vit</span> <span class="kn">import</span> <span class="n">VisionTransformer</span><span class="p">,</span> <span class="n">LearnedPositionalEmbeddings</span><span class="p">,</span> <span class="n">ClassificationHead</span><span class="p">,</span> \
<span class="lineno">43</span> <span class="n">PatchEmbeddings</span>
<span class="lineno">44</span>
<span class="lineno">45</span> <span class="n">d_model</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">d_model</span>
<span class="lineno">46</span> <span class="k">return</span> <span class="n">VisionTransformer</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">encoder_layer</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">n_layers</span><span class="p">,</span>
<span class="lineno">47</span> <span class="n">PatchEmbeddings</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">patch_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span>
<span class="lineno">48</span> <span class="n">LearnedPositionalEmbeddings</span><span class="p">(</span><span class="n">d_model</span><span class="p">),</span>
<span class="lineno">49</span> <span class="n">ClassificationHead</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_classes</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">34</span> <span class="n">n_hidden_classification</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2048</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
@@ -144,21 +129,22 @@ dataset related configurations, optimizer, and a training loop.</p>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Number of classes in the task</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">52</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
<div class="highlight"><pre><span class="lineno">36</span> <span class="n">n_classes</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Create experiment</p>
<p>Create transformer configs</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">54</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;ViT&#39;</span><span class="p">,</span> <span class="n">comment</span><span class="o">=</span><span class="s1">&#39;cifar10&#39;</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">39</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">transformer</span><span class="p">)</span>
<span class="lineno">40</span><span class="k">def</span> <span class="nf">_transformer</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
@@ -166,22 +152,22 @@ dataset related configurations, optimizer, and a training loop.</p>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Create configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">44</span> <span class="k">return</span> <span class="n">TransformerConfigs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Load configurations</p>
<h3>Create model</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">58</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span>
<span class="lineno">59</span> <span class="s1">&#39;device.cuda_device&#39;</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span></pre></div>
<div class="highlight"><pre><span class="lineno">47</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
<span class="lineno">48</span><span class="k">def</span> <span class="nf">_vit</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
@@ -189,22 +175,11 @@ dataset related configurations, optimizer, and a training loop.</p>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>&lsquo;optimizer.optimizer&rsquo;: &lsquo;Noam&rsquo;,
&lsquo;optimizer.learning_rate&rsquo;: 1.,</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">63</span> <span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Adam&#39;</span><span class="p">,</span>
<span class="lineno">64</span> <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">2.5e-4</span><span class="p">,</span>
<span class="lineno">65</span> <span class="s1">&#39;optimizer.d_model&#39;</span><span class="p">:</span> <span class="mi">512</span><span class="p">,</span>
<span class="lineno">66</span>
<span class="lineno">67</span> <span class="s1">&#39;transformer.d_model&#39;</span><span class="p">:</span> <span class="mi">512</span><span class="p">,</span>
<span class="lineno">68</span>
<span class="lineno">69</span> <span class="s1">&#39;epochs&#39;</span><span class="p">:</span> <span class="mi">1000</span><span class="p">,</span>
<span class="lineno">70</span> <span class="s1">&#39;train_batch_size&#39;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span>
<span class="lineno">71</span>
<span class="lineno">72</span> <span class="s1">&#39;train_dataset&#39;</span><span class="p">:</span> <span class="s1">&#39;cifar10_train_augmented&#39;</span><span class="p">,</span>
<span class="lineno">73</span> <span class="s1">&#39;valid_dataset&#39;</span><span class="p">:</span> <span class="s1">&#39;cifar10_valid_no_augment&#39;</span><span class="p">,</span>
<span class="lineno">74</span> <span class="p">})</span></pre></div>
<div class="highlight"><pre><span class="lineno">52</span> <span class="kn">from</span> <span class="nn">labml_nn.transformers.vit</span> <span class="kn">import</span> <span class="n">VisionTransformer</span><span class="p">,</span> <span class="n">LearnedPositionalEmbeddings</span><span class="p">,</span> <span class="n">ClassificationHead</span><span class="p">,</span> \
<span class="lineno">53</span> <span class="n">PatchEmbeddings</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
@@ -212,10 +187,10 @@ dataset related configurations, optimizer, and a training loop.</p>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Set model for saving/loading</p>
<p>Transformer size from <a href="../configs.html#TransformerConfigs">Transformer configurations</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">&#39;model&#39;</span><span class="p">:</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">})</span></pre></div>
<div class="highlight"><pre><span class="lineno">56</span> <span class="n">d_model</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">d_model</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
@@ -223,11 +198,13 @@ dataset related configurations, optimizer, and a training loop.</p>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Start the experiment and run the training loop</p>
<p>Create a vision transformer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">78</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
<span class="lineno">79</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">58</span> <span class="k">return</span> <span class="n">VisionTransformer</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">encoder_layer</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">n_layers</span><span class="p">,</span>
<span class="lineno">59</span> <span class="n">PatchEmbeddings</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">patch_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span>
<span class="lineno">60</span> <span class="n">LearnedPositionalEmbeddings</span><span class="p">(</span><span class="n">d_model</span><span class="p">),</span>
<span class="lineno">61</span> <span class="n">ClassificationHead</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_hidden_classification</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_classes</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
@@ -238,8 +215,133 @@ dataset related configurations, optimizer, and a training loop.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">83</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">84</span> <span class="n">main</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">64</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Create experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">66</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;ViT&#39;</span><span class="p">,</span> <span class="n">comment</span><span class="o">=</span><span class="s1">&#39;cifar10&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Create configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Load configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Optimizer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">72</span> <span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Adam&#39;</span><span class="p">,</span>
<span class="lineno">73</span> <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">2.5e-4</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Transformer embedding size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="s1">&#39;transformer.d_model&#39;</span><span class="p">:</span> <span class="mi">512</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Training epochs and batch size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span> <span class="s1">&#39;epochs&#39;</span><span class="p">:</span> <span class="mi">1000</span><span class="p">,</span>
<span class="lineno">80</span> <span class="s1">&#39;train_batch_size&#39;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Augment CIFAR 10 images for training</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">83</span> <span class="s1">&#39;train_dataset&#39;</span><span class="p">:</span> <span class="s1">&#39;cifar10_train_augmented&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Do not augment CIFAR 10 images for validation</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">85</span> <span class="s1">&#39;valid_dataset&#39;</span><span class="p">:</span> <span class="s1">&#39;cifar10_valid_no_augment&#39;</span><span class="p">,</span>
<span class="lineno">86</span> <span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Set model for saving/loading</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">88</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">&#39;model&#39;</span><span class="p">:</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Start the experiment and run the training loop</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">90</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
<span class="lineno">91</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">95</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">96</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
</div>
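Pieced together, the experiment script annotated above amounts to roughly the following plain-Python sketch (reconstructed from the fragments on this page; it assumes the `Configs` class defined earlier in the same file):

```python
from labml import experiment


def main():
    # Create experiment
    experiment.create(name='ViT', comment='cifar10')
    # Create configurations
    conf = Configs()
    # Load configurations
    experiment.configs(conf, {
        # Optimizer
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        # Transformer embedding size
        'transformer.d_model': 512,
        # Training epochs and batch size
        'epochs': 1000,
        'train_batch_size': 64,
        # Augment CIFAR 10 images for training, but not for validation
        'train_dataset': 'cifar10_train_augmented',
        'valid_dataset': 'cifar10_valid_no_augment',
    })
    # Set model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
```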

View File

@@ -3,24 +3,24 @@
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content=""/>
<meta name="description" content="A PyTorch implementation/tutorial of the paper &quot;An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale&quot;"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="__init__.py"/>
<meta name="twitter:description" content=""/>
<meta name="twitter:title" content="Vision Transformer (ViT)"/>
<meta name="twitter:description" content="A PyTorch implementation/tutorial of the paper &quot;An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale&quot;"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/transformers/vit/index.html"/>
<meta property="og:title" content="__init__.py"/>
<meta property="og:title" content="Vision Transformer (ViT)"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="__init__.py"/>
<meta property="og:description" content=""/>
<meta property="og:title" content="Vision Transformer (ViT)"/>
<meta property="og:description" content="A PyTorch implementation/tutorial of the paper &quot;An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale&quot;"/>
<title>__init__.py</title>
<title>Vision Transformer (ViT)</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/transformers/vit/index.html"/>
@@ -63,19 +63,46 @@
</div>
</div>
<div class='section' id='section-0'>
<div class='docs'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Vision Transformer (ViT)</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the paper
<a href="https://arxiv.org/abs/2010.11929">An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale</a>.</p>
<p>The vision transformer applies a pure transformer to images
without any convolution layers.
It splits the image into patches and applies a transformer to the patch embeddings.
<a href="#PatchEmbeddings">Patch embeddings</a> are generated by applying a simple linear transformation
to the flattened pixel values of each patch.
A standard transformer encoder is then fed the patch embeddings, along with a
classification token <code>[CLS]</code>.
The encoding of the <code>[CLS]</code> token is used to classify the image with an MLP.</p>
<p>When feeding the patches to the transformer, learned positional embeddings are
added to the patch embeddings, because the patch embeddings alone carry no information
about where each patch came from.
The positional embeddings are a set of vectors, one per patch location, trained
by gradient descent along with the other parameters.</p>
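These pieces compose into a very short forward pass. Here is a condensed sketch mirroring the `VisionTransformer` code annotated further down this page (the arguments stand in for the actual `PatchEmbeddings`, `LearnedPositionalEmbeddings`, encoder layers, and `ClassificationHead` modules):

```python
import torch


def vit_forward(x, patch_emb, pos_emb, cls_token_emb, transformer_layers, classification):
    # x: [batch_size, channels, height, width]
    x = patch_emb(x)                                # [patches, batch_size, d_model]
    x = pos_emb(x)                                  # add learned positional embeddings
    cls = cls_token_emb.expand(-1, x.shape[1], -1)  # [1, batch_size, d_model]
    x = torch.cat([cls, x])                         # prepend the [CLS] token
    for layer in transformer_layers:                # standard transformer encoder
        x = layer(x=x, mask=None)
    return classification(x[0])                     # classify from the [CLS] encoding
```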
<p>ViTs perform well when they are pre-trained on large datasets.
The paper suggests pre-training them with an MLP classification head and
then using a single linear layer when fine-tuning.
The paper beats SOTA with a ViT pre-trained on a 300 million image dataset.
They also use higher resolution images during inference while keeping the
patch size the same.
The positional embeddings for new patch locations are calculated by interpolating
the learned positional embeddings.</p>
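As a rough illustration of that interpolation step, here is a hedged sketch (this helper is not part of the implementation below; it assumes the positional embeddings form a square grid of patch locations):

```python
import torch
import torch.nn.functional as F


def interpolate_pos_emb(pe: torch.Tensor, old_grid: int, new_grid: int) -> torch.Tensor:
    # pe: [old_grid * old_grid, d_model] learned positional embeddings
    d_model = pe.shape[-1]
    # Arrange as a 2D grid: [1, d_model, old_grid, old_grid]
    grid = pe.reshape(old_grid, old_grid, d_model).permute(2, 0, 1).unsqueeze(0)
    # Interpolate to the new grid size
    grid = F.interpolate(grid, size=(new_grid, new_grid),
                         mode='bicubic', align_corners=False)
    # Back to [new_grid * new_grid, d_model]
    return grid.squeeze(0).permute(1, 2, 0).reshape(new_grid * new_grid, d_model)
```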
<p>Here&rsquo;s <a href="experiment.html">an experiment</a> that trains ViT on CIFAR-10.
It doesn&rsquo;t do very well because it&rsquo;s trained on a small dataset.
It&rsquo;s a simple experiment that anyone can run to play with ViTs.</p>
<p><a href="https://app.labml.ai/run/8b531d9ce3dc11eb84fc87df6756eb8f"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">1</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">2</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">3</span>
<span class="lineno">4</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">5</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">TransformerLayer</span>
<span class="lineno">6</span><span class="kn">from</span> <span class="nn">labml_nn.utils</span> <span class="kn">import</span> <span class="n">clone_module_list</span></pre></div>
<div class="highlight"><pre><span class="lineno">45</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">46</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">47</span>
<span class="lineno">48</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">49</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">TransformerLayer</span>
<span class="lineno">50</span><span class="kn">from</span> <span class="nn">labml_nn.utils</span> <span class="kn">import</span> <span class="n">clone_module_list</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
@@ -84,36 +111,40 @@
<a href='#section-1'>#</a>
</div>
<p><a id="PatchEmbeddings"></p>
<h2>Embed patches</h2>
<h2>Get patch embeddings</h2>
<p></a></p>
<p>The paper splits the image into patches of equal size and applies a linear transformation
to the flattened pixels of each patch.</p>
<p>We implement the same thing with a convolution layer, because it&rsquo;s simpler.</p>
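To see why the two are equivalent, here is a small self-contained check (the sizes are arbitrary, and the weight tying is only there to make the comparison exact):

```python
import torch
from torch import nn

# Arbitrary sizes for illustration
batch, in_channels, d_model, patch_size = 2, 3, 8, 4
x = torch.randn(batch, in_channels, 32, 32)

# The convolution used by PatchEmbeddings: kernel size == stride == patch size
conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)

# The "split into patches, flatten, linear transform" formulation,
# with weights tied to the convolution so the outputs match
linear = nn.Linear(in_channels * patch_size ** 2, d_model)
linear.weight.data = conv.weight.data.view(d_model, -1)
linear.bias.data = conv.bias.data

# Split the image into non-overlapping patches and flatten each one
patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(batch, -1, in_channels * patch_size ** 2)

out_linear = linear(patches)                   # [batch, patches, d_model]
out_conv = conv(x).flatten(2).transpose(1, 2)  # [batch, patches, d_model]
assert torch.allclose(out_linear, out_conv, atol=1e-5)
```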
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">9</span><span class="k">class</span> <span class="nc">PatchEmbeddings</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">53</span><span class="k">class</span> <span class="nc">PatchEmbeddings</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul>
<li><code>d_model</code> is the transformer embeddings size</li>
<li><code>patch_size</code> is the size of the patch</li>
<li><code>in_channels</code> is the number of channels in the input image (3 for RGB)</li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">patch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">16</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">patch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">17</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">18</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span> <span class="o">=</span> <span class="n">patch_size</span>
<span class="lineno">19</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">patch_size</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="n">patch_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>x has shape <code>[batch_size, channels, height, width]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">21</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">71</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
@@ -121,15 +152,12 @@
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>We create a convolution layer with kernel size and stride length equal to the patch size.
This is equivalent to splitting the image into patches and doing a linear
transformation on each patch.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">25</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">26</span> <span class="n">bs</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
<span class="lineno">27</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="lineno">28</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">h</span> <span class="o">*</span> <span class="n">w</span><span class="p">,</span> <span class="n">bs</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span>
<span class="lineno">29</span>
<span class="lineno">30</span> <span class="k">return</span> <span class="n">x</span></pre></div>
<div class="highlight"><pre><span class="lineno">76</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">patch_size</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="n">patch_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
@@ -137,12 +165,12 @@
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p><a id="LearnedPositionalEmbeddings"></p>
<h2>Add parameterized positional encodings</h2>
<p></a></p>
<ul>
<li><code>x</code> is the input image of shape <code>[batch_size, channels, height, width]</code></li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">33</span><span class="k">class</span> <span class="nc">LearnedPositionalEmbeddings</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">78</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
@@ -150,12 +178,10 @@
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Apply convolution layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">40</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">max_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5_000</span><span class="p">):</span>
<span class="lineno">41</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">42</span> <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">max_len</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">d_model</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">83</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
@@ -163,12 +189,10 @@
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Get the shape.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">44</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">45</span> <span class="n">pe</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span><span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span>
<span class="lineno">46</span> <span class="k">return</span> <span class="n">x</span> <span class="o">+</span> <span class="n">pe</span></pre></div>
<div class="highlight"><pre><span class="lineno">85</span> <span class="n">bs</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
@@ -176,10 +200,11 @@
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Rearrange to shape <code>[patches, batch_size, d_model]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span><span class="k">class</span> <span class="nc">ClassificationHead</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">87</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="lineno">88</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">h</span> <span class="o">*</span> <span class="n">w</span><span class="p">,</span> <span class="n">bs</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
@@ -187,42 +212,38 @@
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Return the patch embeddings</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">50</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">51</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">52</span> <span class="bp">self</span><span class="o">.</span><span class="n">ln</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
<span class="lineno">53</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">)</span>
<span class="lineno">54</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span>
<span class="lineno">55</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_hidden</span><span class="p">,</span> <span class="n">n_classes</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">91</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p><a id="LearnedPositionalEmbeddings"></p>
<h2>Add parameterized positional encodings</h2>
<p></a></p>
<p>This adds learned positional embeddings to patch embeddings.</p>
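Since the class is annotated piecewise below, here is the same module condensed into one place (sketched with plain <code>nn.Module</code>/<code>forward</code> instead of the repo's <code>Module</code>/<code>__call__</code> convention):

```python
import torch
from torch import nn


class LearnedPositionalEmbeddings(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5_000):
        super().__init__()
        # One trainable vector per patch location
        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model),
                                                 requires_grad=True)

    def forward(self, x: torch.Tensor):
        # x: [patches, batch_size, d_model]
        # Slice one embedding per patch position and add
        pe = self.positional_encodings[:x.shape[0]]
        return x + pe
```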
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">57</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">58</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ln</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">59</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">linear1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
<span class="lineno">60</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">61</span>
<span class="lineno">62</span> <span class="k">return</span> <span class="n">x</span></pre></div>
<div class="highlight"><pre><span class="lineno">94</span><span class="k">class</span> <span class="nc">LearnedPositionalEmbeddings</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<ul>
<li><code>d_model</code> is the transformer embeddings size</li>
<li><code>max_len</code> is the maximum number of patches</li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span><span class="k">class</span> <span class="nc">VisionTransformer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">103</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">max_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5_000</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
@@ -233,10 +254,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">66</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">transformer_layer</span><span class="p">:</span> <span class="n">TransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">67</span> <span class="n">patch_emb</span><span class="p">:</span> <span class="n">PatchEmbeddings</span><span class="p">,</span> <span class="n">pos_emb</span><span class="p">:</span> <span class="n">LearnedPositionalEmbeddings</span><span class="p">,</span>
<span class="lineno">68</span> <span class="n">classification</span><span class="p">:</span> <span class="n">ClassificationHead</span><span class="p">):</span>
<span class="lineno">69</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">108</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
@@ -244,38 +262,368 @@
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Make copies of the transformer layer</p>
<p>Positional embeddings for each location</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">71</span> <span class="bp">self</span><span class="o">.</span><span class="n">classification</span> <span class="o">=</span> <span class="n">classification</span>
<span class="lineno">72</span> <span class="bp">self</span><span class="o">.</span><span class="n">pos_emb</span> <span class="o">=</span> <span class="n">pos_emb</span>
<span class="lineno">73</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_emb</span> <span class="o">=</span> <span class="n">patch_emb</span>
<span class="lineno">74</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer_layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">transformer_layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span>
<span class="lineno">75</span>
<span class="lineno">76</span> <span class="bp">self</span><span class="o">.</span><span class="n">cls_token_emb</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">transformer_layer</span><span class="o">.</span><span class="n">size</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">110</span> <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">max_len</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">d_model</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<ul>
<li><code>x</code> is the patch embeddings of shape <code>[patches, batch_size, d_model]</code></li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">112</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Get the positional embeddings for the given patches</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">117</span> <span class="n">pe</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span><span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Add to patch embeddings and return</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">119</span> <span class="k">return</span> <span class="n">x</span> <span class="o">+</span> <span class="n">pe</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p><a id="ClassificationHead"></p>
<h2>MLP Classification Head</h2>
<p></a></p>
<p>This is the two-layer MLP head that classifies the image based on the <code>[CLS]</code> token embedding.</p>
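Condensed, the head is a LayerNorm followed by a two-layer MLP (again sketched with plain <code>nn.Module</code>; the pieces are annotated individually below):

```python
import torch
from torch import nn


class ClassificationHead(nn.Module):
    def __init__(self, d_model: int, n_hidden: int, n_classes: int):
        super().__init__()
        self.ln = nn.LayerNorm([d_model])
        self.linear1 = nn.Linear(d_model, n_hidden)
        self.act = nn.ReLU()
        self.linear2 = nn.Linear(n_hidden, n_classes)

    def forward(self, x: torch.Tensor):
        # x is the [CLS] token encoding: [batch_size, d_model]
        x = self.ln(x)
        x = self.act(self.linear1(x))
        return self.linear2(x)
```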
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span><span class="k">class</span> <span class="nc">ClassificationHead</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<ul>
<li><code>d_model</code> is the transformer embedding size</li>
<li><code>n_hidden</code> is the size of the hidden layer</li>
<li><code>n_classes</code> is the number of classes in the classification task</li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">130</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">78</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="lineno">79</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_emb</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">80</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pos_emb</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">81</span> <span class="n">cls_token_emb</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cls_token_emb</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">82</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">cls_token_emb</span><span class="p">,</span> <span class="n">x</span><span class="p">])</span>
<span class="lineno">83</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer_layers</span><span class="p">:</span>
<span class="lineno">84</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="lineno">85</span>
<span class="lineno">86</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="lineno">87</span>
<span class="lineno">88</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">classification</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">89</span>
<span class="lineno">90</span> <span class="k">return</span> <span class="n">x</span></pre></div>
<div class="highlight"><pre><span class="lineno">136</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>First layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">138</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Activation</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">140</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Second layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">142</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_hidden</span><span class="p">,</span> <span class="n">n_classes</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<ul>
<li><code>x</code> is the transformer encoding for <code>[CLS]</code> token</li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">144</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>First layer and activation</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">149</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">linear1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Second layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">151</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">154</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<h2>Vision Transformer</h2>
<p>This combines the <a href="#PatchEmbeddings">patch embeddings</a>,
<a href="#LearnedPositionalEmbeddings">positional embeddings</a>,
transformer and the <a href="#ClassificationHead">classification head</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">157</span><span class="k">class</span> <span class="nc">VisionTransformer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<ul>
<li><code>transformer_layer</code> is a copy of a single <a href="../models.html#TransformerLayer">transformer layer</a>.
We make <code>n_layers</code> copies of it to build the transformer.</li>
<li><code>n_layers</code> is the number of <a href="../models.html#TransformerLayer">transformer layers</a>.</li>
<li><code>patch_emb</code> is the <a href="#PatchEmbeddings">patch embeddings layer</a>.</li>
<li><code>pos_emb</code> is the <a href="#LearnedPositionalEmbeddings">positional embeddings layer</a>.</li>
<li><code>classification</code> is the <a href="#ClassificationHead">classification head</a>.</li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">165</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">transformer_layer</span><span class="p">:</span> <span class="n">TransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">166</span> <span class="n">patch_emb</span><span class="p">:</span> <span class="n">PatchEmbeddings</span><span class="p">,</span> <span class="n">pos_emb</span><span class="p">:</span> <span class="n">LearnedPositionalEmbeddings</span><span class="p">,</span>
<span class="lineno">167</span> <span class="n">classification</span><span class="p">:</span> <span class="n">ClassificationHead</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Patch embeddings</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">178</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_emb</span> <span class="o">=</span> <span class="n">patch_emb</span>
<span class="lineno">179</span> <span class="bp">self</span><span class="o">.</span><span class="n">pos_emb</span> <span class="o">=</span> <span class="n">pos_emb</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Classification head</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">181</span> <span class="bp">self</span><span class="o">.</span><span class="n">classification</span> <span class="o">=</span> <span class="n">classification</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Make copies of the transformer layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">183</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer_layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">transformer_layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p><code>[CLS]</code> token embedding</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">186</span> <span class="bp">self</span><span class="o">.</span><span class="n">cls_token_emb</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">transformer_layer</span><span class="o">.</span><span class="n">size</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Final normalization layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">188</span> <span class="bp">self</span><span class="o">.</span><span class="n">ln</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">transformer_layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<ul>
<li><code>x</code> is the input image of shape <code>[batch_size, channels, height, width]</code></li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">190</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Get patch embeddings. This gives a tensor of shape <code>[patches, batch_size, d_model]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">195</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_emb</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>Add positional embeddings</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pos_emb</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Concatenate the <code>[CLS]</code> token embeddings before feeding the transformer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">199</span> <span class="n">cls_token_emb</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cls_token_emb</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">200</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">cls_token_emb</span><span class="p">,</span> <span class="n">x</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>Pass through transformer layers with no attention masking</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">203</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer_layers</span><span class="p">:</span>
<span class="lineno">204</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Get the transformer output of the <code>[CLS]</code> token (which is the first in the sequence).</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">207</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Layer normalization</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">210</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ln</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Classification head, to get logits</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">213</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">classification</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">216</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
</div>

View File

@@ -0,0 +1,162 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content=""/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content=" Vision Transformer (ViT)"/>
<meta name="twitter:description" content=""/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/transformers/vit/readme.html"/>
<meta property="og:title" content=" Vision Transformer (ViT)"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content=" Vision Transformer (ViT)"/>
<meta property="og:description" content=""/>
<title> Vision Transformer (ViT)</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/transformers/vit/readme.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">transformers</a>
<a class="parent" href="index.html">vit</a>
</p>
<p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/transformers/vit/readme.md">
<img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1><a href="https://nn.labml.ai/transformer/vit/index.html">Vision Transformer (ViT)</a></h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the paper
<a href="https://arxiv.org/abs/2010.11929">An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale</a>.</p>
<p>The vision transformer applies a pure transformer to images,
without any convolution layers.
It splits the image into patches and applies a transformer on the patch embeddings.
<a href="https://nn.labml.ai/transformers/vit/index.html#PatchEmbeddings">Patch embeddings</a> are generated by applying a simple linear transformation
to the flattened pixel values of each patch.
Then a standard transformer encoder is fed with the patch embeddings, along with a
classification token <code>[CLS]</code>.
The encoding on the <code>[CLS]</code> token is used to classify the image with an MLP.</p>
<p>When feeding the transformer with the patches, learned positional embeddings are
added to the patch embeddings, because the patch embeddings do not have any information
about where that patch is from.
The positional embeddings are a set of vectors for each patch location that get trained
with gradient descent along with other parameters.</p>
<p>ViTs perform well when they are pre-trained on large datasets.
The paper suggests pre-training them with an MLP classification head and
then using a single linear layer when fine-tuning.
The paper beats SOTA with a ViT pre-trained on a 300 million image dataset.
They also use higher resolution images during inference while keeping the
patch size the same.
The positional embeddings for the new patch locations are calculated by interpolating
the learned positional embeddings.</p>
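<p>As a rough illustration, here is a minimal sketch (an assumption for exposition, not code from this implementation) of bilinearly interpolating an 8x8 grid of learned positional embeddings up to 16x16 for higher resolution images:</p>
<div class="highlight"><pre>import torch
import torch.nn.functional as F

pe = torch.randn(64, 1, 512)                         # [patches, 1, d_model], trained on an 8x8 patch grid
grid = pe.permute(1, 2, 0).reshape(1, 512, 8, 8)     # [1, d_model, h, w]
grid = F.interpolate(grid, size=(16, 16), mode='bilinear', align_corners=False)
pe_new = grid.reshape(1, 512, 256).permute(2, 0, 1)  # [256, 1, d_model]</pre></div>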
<p>Here&rsquo;s <a href="https://nn.labml.ai/transformers/vit/experiment.html">an experiment</a> that trains ViT on CIFAR-10.
It doesn&rsquo;t do very well because it&rsquo;s trained on a small dataset.
It&rsquo;s a simple experiment that anyone can run to play with ViTs.</p>
</div>
<div class='code'>
</div>
</div>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

View File

@@ -31,6 +31,7 @@ implementations.
* [Masked Language Model](transformers/mlm/index.html)
* [MLP-Mixer: An all-MLP Architecture for Vision](transformers/mlp_mixer/index.html)
* [Pay Attention to MLPs (gMLP)](transformers/gmlp/index.html)
* [Vision Transformer (ViT)](transformers/vit/index.html)
#### ✨ [Recurrent Highway Networks](recurrent_highway_networks/index.html)

View File

@@ -82,6 +82,11 @@ This is an implementation of the paper
This is an implementation of the paper
[Pay Attention to MLPs](https://papers.labml.ai/paper/2105.08050).
## [Vision Transformer (ViT)](vit/index.html)
This is an implementation of the paper
[An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale](https://arxiv.org/abs/2010.11929).
"""
from .configs import TransformerConfigs

View File

@@ -1,3 +1,47 @@
"""
---
title: Vision Transformer (ViT)
summary: >
A PyTorch implementation/tutorial of the paper
"An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale"
---
# Vision Transformer (ViT)
This is a [PyTorch](https://pytorch.org) implementation of the paper
[An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale](https://arxiv.org/abs/2010.11929).
The vision transformer applies a pure transformer to images,
without any convolution layers.
It splits the image into patches and applies a transformer on the patch embeddings.
[Patch embeddings](#PatchEmbeddings) are generated by applying a simple linear transformation
to the flattened pixel values of each patch.
Then a standard transformer encoder is fed with the patch embeddings, along with a
classification token `[CLS]`.
The encoding on the `[CLS]` token is used to classify the image with an MLP.
When feeding the transformer with the patches, learned positional embeddings are
added to the patch embeddings, because the patch embeddings do not have any information
about where that patch is from.
The positional embeddings are a set of vectors for each patch location that get trained
with gradient descent along with other parameters.
ViTs perform well when they are pre-trained on large datasets.
The paper suggests pre-training them with an MLP classification head and
then using a single linear layer when fine-tuning.
The paper beats SOTA with a ViT pre-trained on a 300 million image dataset.
They also use higher resolution images during inference while keeping the
patch size the same.
The positional embeddings for the new patch locations are calculated by interpolating
the learned positional embeddings.
Here's [an experiment](experiment.html) that trains ViT on CIFAR-10.
It doesn't do very well because it's trained on a small dataset.
It's a simple experiment that anyone can run to play with ViTs.
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/8b531d9ce3dc11eb84fc87df6756eb8f)
"""
import torch
from torch import nn
@@ -9,24 +53,41 @@ from labml_nn.utils import clone_module_list
class PatchEmbeddings(Module):
"""
<a id="PatchEmbeddings">
## Get patch embeddings
</a>
The paper splits the image into patches of equal size and does a linear transformation
on the flattened pixels of each patch.
We implement the same thing with a convolution layer, because it's simpler.
"""
def __init__(self, d_model: int, patch_size: int, in_channels: int):
"""
* `d_model` is the transformer embeddings size
* `patch_size` is the size of the patch
* `in_channels` is the number of channels in the input image (3 for rgb)
"""
super().__init__()
self.patch_size = patch_size
# We create a convolution layer with kernel size and stride length equal to the patch size.
# This is equivalent to splitting the image into patches and doing a linear
# transformation on each patch.
self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
def __call__(self, x: torch.Tensor):
"""
* `x` is the input image of shape `[batch_size, channels, height, width]`
"""
# Apply convolution layer
x = self.conv(x)
# Get the shape.
bs, c, h, w = x.shape
# Rearrange to shape `[patches, batch_size, d_model]`
x = x.permute(2, 3, 0, 1)
x = x.view(h * w, bs, c)
# Return the patch embeddings
return x
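# Editor's aside — a minimal, self-contained check (an illustrative sketch,
# not part of the original file) that a Conv2d with kernel_size == stride ==
# patch_size is the same as one linear map applied to every flattened patch:
def _check_patch_conv_is_linear():
    conv = nn.Conv2d(3, 8, kernel_size=4, stride=4, bias=False)
    patch = torch.randn(1, 3, 4, 4)  # a single 4x4 RGB patch
    # Apply the convolution kernel as a plain weight matrix to the flattened patch
    as_linear = patch.flatten() @ conv.weight.view(8, -1).t()
    # The convolution over one patch gives the same 8 numbers
    assert torch.allclose(conv(patch).view(8), as_linear, atol=1e-5)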
@@ -35,56 +96,121 @@ class LearnedPositionalEmbeddings(Module):
<a id="LearnedPositionalEmbeddings">
## Add parameterized positional encodings
</a>
This adds learned positional embeddings to patch embeddings.
"""
def __init__(self, d_model: int, max_len: int = 5_000):
"""
* `d_model` is the transformer embeddings size
* `max_len` is the maximum number of patches
"""
super().__init__()
# Positional embeddings for each location
self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
def __call__(self, x: torch.Tensor):
"""
* `x` is the patch embeddings of shape `[patches, batch_size, d_model]`
"""
# Get the positional embeddings for the given patches
pe = self.positional_encodings[:x.shape[0]]
# Add to patch embeddings and return
return x + pe
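# Editor's aside — a hypothetical usage sketch (not part of the original file):
# the slice above picks one trained vector per patch position, so 64 patch
# embeddings get the first 64 rows of the (max_len, 1, d_model) parameter.
def _demo_positional_embeddings():
    pos_emb = LearnedPositionalEmbeddings(d_model=512)
    patches = torch.randn(64, 2, 512)  # [patches, batch_size, d_model]
    assert pos_emb(patches).shape == (64, 2, 512)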
class ClassificationHead(Module):
"""
<a id="ClassificationHead">
## MLP Classification Head
</a>
This is the two-layer MLP head to classify the image based on the `[CLS]` token embedding.
"""
def __init__(self, d_model: int, n_hidden: int, n_classes: int):
"""
* `d_model` is the transformer embedding size
* `n_hidden` is the size of the hidden layer
* `n_classes` is the number of classes in the classification task
"""
super().__init__()
# First layer
self.linear1 = nn.Linear(d_model, n_hidden)
# Activation
self.act = nn.ReLU()
# Second layer
self.linear2 = nn.Linear(n_hidden, n_classes)
def __call__(self, x: torch.Tensor):
"""
* `x` is the transformer encoding for `[CLS]` token
"""
# First layer and activation
x = self.act(self.linear1(x))
# Second layer
x = self.linear2(x)
#
return x
class VisionTransformer(Module):
"""
## Vision Transformer
This combines the [patch embeddings](#PatchEmbeddings),
[positional embeddings](#LearnedPositionalEmbeddings),
transformer and the [classification head](#ClassificationHead).
"""
def __init__(self, transformer_layer: TransformerLayer, n_layers: int,
patch_emb: PatchEmbeddings, pos_emb: LearnedPositionalEmbeddings,
classification: ClassificationHead):
"""
* `transformer_layer` is a copy of a single [transformer layer](../models.html#TransformerLayer).
We make `n_layers` copies of it to build the transformer.
* `n_layers` is the number of [transformer layers](../models.html#TransformerLayer).
* `patch_emb` is the [patch embeddings layer](#PatchEmbeddings).
* `pos_emb` is the [positional embeddings layer](#LearnedPositionalEmbeddings).
* `classification` is the [classification head](#ClassificationHead).
"""
super().__init__()
# Patch embeddings
self.patch_emb = patch_emb
self.pos_emb = pos_emb
# Classification head
self.classification = classification
# Make copies of the transformer layer
self.transformer_layers = clone_module_list(transformer_layer, n_layers)
# `[CLS]` token embedding
self.cls_token_emb = nn.Parameter(torch.randn(1, 1, transformer_layer.size), requires_grad=True)
# Final normalization layer
self.ln = nn.LayerNorm([transformer_layer.size])
def __call__(self, x: torch.Tensor):
"""
* `x` is the input image of shape `[batch_size, channels, height, width]`
"""
# Get patch embeddings. This gives a tensor of shape `[patches, batch_size, d_model]`
x = self.patch_emb(x)
# Add positional embeddings
x = self.pos_emb(x)
# Concatenate the `[CLS]` token embeddings before feeding the transformer
cls_token_emb = self.cls_token_emb.expand(-1, x.shape[1], -1)
x = torch.cat([cls_token_emb, x])
# Pass through transformer layers with no attention masking
for layer in self.transformer_layers:
x = layer(x=x, mask=None)
# Get the transformer output of the `[CLS]` token (which is the first in the sequence).
x = x[0]
# Layer normalization
x = self.ln(x)
# Classification head, to get logits
x = self.classification(x)
#
return x
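# Editor's aside — a hedged end-to-end sketch; the layer construction below is
# an assumption for illustration and not part of the original file:
def _demo_vision_transformer():
    from labml_nn.transformers import TransformerLayer, MultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward
    # A single encoder layer; it gets cloned `n_layers` times
    layer = TransformerLayer(d_model=512,
                             self_attn=MultiHeadAttention(8, 512),
                             src_attn=None,
                             feed_forward=FeedForward(512, 2048),
                             dropout_prob=0.1)
    vit = VisionTransformer(layer, n_layers=6,
                            patch_emb=PatchEmbeddings(512, 4, 3),
                            pos_emb=LearnedPositionalEmbeddings(512),
                            classification=ClassificationHead(512, 2048, 10))
    # A dummy CIFAR-10 sized batch: 2 images, 3 channels, 32x32 pixels
    logits = vit(torch.randn(2, 3, 32, 32))
    assert logits.shape == (2, 10)  # one logit per class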

View File

@@ -1,11 +1,13 @@
"""
---
title: Train a Vision Transformer (ViT) on CIFAR 10
summary: >
Train a Vision Transformer (ViT) on CIFAR 10
---
# Train a [Vision Transformer (ViT)](index.html) on CIFAR 10
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/8b531d9ce3dc11eb84fc87df6756eb8f)
"""
from labml import experiment
@@ -18,19 +20,27 @@ class Configs(CIFAR10Configs):
"""
## Configurations
We use [`CIFAR10Configs`](../../experiments/cifar10.html) which defines all the
dataset related configurations, optimizer, and a training loop.
"""
# [Transformer configurations](../configs.html#TransformerConfigs)
# to get [transformer layer](../models.html#TransformerLayer)
transformer: TransformerConfigs
# Size of a patch
patch_size: int = 4
# Size of the hidden layer in classification head
n_hidden_classification: int = 2048
# Number of classes in the task
n_classes: int = 10
@option(Configs.transformer)
def _transformer():
"""
Create transformer configs
"""
return TransformerConfigs()
@@ -42,11 +52,13 @@ def _vit(c: Configs):
from labml_nn.transformers.vit import VisionTransformer, LearnedPositionalEmbeddings, ClassificationHead, \
PatchEmbeddings
# Transformer size from [Transformer configurations](../configs.html#TransformerConfigs)
d_model = c.transformer.d_model
# Create a vision transformer
return VisionTransformer(c.transformer.encoder_layer, c.transformer.n_layers,
PatchEmbeddings(d_model, c.patch_size, 3),
LearnedPositionalEmbeddings(d_model),
ClassificationHead(d_model, c.n_hidden_classification, c.n_classes)).to(c.device)
def main():
@@ -56,20 +68,20 @@ def main():
conf = Configs()
# Load configurations
experiment.configs(conf, {
# Optimizer
'optimizer.optimizer': 'Adam',
'optimizer.learning_rate': 2.5e-4,
# Transformer embedding size
'transformer.d_model': 512,
# Training epochs and batch size
'epochs': 1000,
'train_batch_size': 64,
# Augment CIFAR 10 images for training
'train_dataset': 'cifar10_train_augmented',
# Do not augment CIFAR 10 images for validation
'valid_dataset': 'cifar10_valid_no_augment',
})
# Set model for saving/loading
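# Editor's aside — a hypothetical way to launch this experiment (the module
# path is assumed from the repository layout, labml_nn/transformers/vit/):
#
#     from labml_nn.transformers.vit.experiment import main
#     main()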

View File

@@ -0,0 +1,32 @@
# [Vision Transformer (ViT)](https://nn.labml.ai/transformers/vit/index.html)
This is a [PyTorch](https://pytorch.org) implementation of the paper
[An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale](https://arxiv.org/abs/2010.11929).
The vision transformer applies a pure transformer to images,
without any convolution layers.
It splits the image into patches and applies a transformer on the patch embeddings.
[Patch embeddings](https://nn.labml.ai/transformers/vit/index.html#PatchEmbeddings) are generated by applying a simple linear transformation
to the flattened pixel values of each patch.
Then a standard transformer encoder is fed with the patch embeddings, along with a
classification token `[CLS]`.
The encoding on the `[CLS]` token is used to classify the image with an MLP.
When feeding the transformer with the patches, learned positional embeddings are
added to the patch embeddings, because the patch embeddings do not have any information
about where that patch is from.
The positional embeddings are a set of vectors for each patch location that get trained
with gradient descent along with other parameters.
ViTs perform well when they are pre-trained on large datasets.
The paper suggests pre-training them with an MLP classification head and
then using a single linear layer when fine-tuning.
The paper beats SOTA with a ViT pre-trained on a 300 million image dataset.
They also use higher resolution images during inference while keeping the
patch size the same.
The positional embeddings for the new patch locations are calculated by interpolating
the learned positional embeddings.
Here's [an experiment](https://nn.labml.ai/transformers/vit/experiment.html) that trains ViT on CIFAR-10.
It doesn't do very well because it's trained on a small dataset.
It's a simple experiment that anyone can run to play with ViTs.

View File

@@ -37,6 +37,7 @@ implementations almost weekly.
* [Masked Language Model](https://nn.labml.ai/transformers/mlm/index.html)
* [MLP-Mixer: An all-MLP Architecture for Vision](https://nn.labml.ai/transformers/mlp_mixer/index.html)
* [Pay Attention to MLPs (gMLP)](https://nn.labml.ai/transformers/gmlp/index.html)
* [Vision Transformer (ViT)](https://nn.labml.ai/transformers/vit/index.html)
#### ✨ [Recurrent Highway Networks](https://nn.labml.ai/recurrent_highway_networks/index.html)

View File

@@ -5,7 +5,7 @@ with open("readme.md", "r") as f:
setuptools.setup(
name='labml-nn',
version='0.4.102',
version='0.4.103',
author="Varuna Jayasiri, Nipun Wijerathne",
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
description="A collection of PyTorch implementations of neural network architectures and layers.",