unescape *

This commit is contained in:
Varuna Jayasiri
2021-10-21 11:46:06 +05:30
parent 4c5a706836
commit 8aa83ddf7b
179 changed files with 9727 additions and 35256 deletions

View File

@ -1,470 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This experiment trains MLP Mixer on Tiny Shakespeare dataset."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="MLP Mixer experiment"/>
<meta name="twitter:description" content="This experiment trains MLP Mixer on Tiny Shakespeare dataset."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/transformers/mlp_mixer/experiment.html"/>
<meta property="og:title" content="MLP Mixer experiment"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="MLP Mixer experiment"/>
<meta property="og:description" content="This experiment trains MLP Mixer on Tiny Shakespeare dataset."/>
<title>MLP Mixer experiment</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/transformers/mlp_mixer/experiment.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">transformers</a>
<a class="parent" href="index.html">mlp_mixer</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/mlp_mixer/experiment.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1><a href="index.html">MLP Mixer</a> Experiment</h1>
<p>This is an annotated PyTorch experiment to train a <a href="index.html">MLP Mixer Model</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">12</span><span></span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
<span class="lineno">13</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
<span class="lineno">14</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">TransformerConfigs</span>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.configs</span> <span class="kn">import</span> <span class="n">FeedForwardConfigs</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.mlm.experiment</span> <span class="kn">import</span> <span class="n">TransformerMLM</span><span class="p">,</span> <span class="n">Configs</span> <span class="k">as</span> <span class="n">MLMConfigs</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Configurations</h2>
<p>This inherits from
<a href="../mlm/experiment.html"><code>MLMConfigs</code></a> where we define an experiment for
<a href="../mlm.index.html">Masked Language Models</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">19</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">MLMConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<p>Configurable <a href="../feed_forward.html">Feed-Forward Network</a> for the MLP</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">29</span> <span class="n">mix_mlp</span><span class="p">:</span> <span class="n">FeedForwardConfigs</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>The mixing MLP configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">32</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">mix_mlp</span><span class="p">)</span>
<span class="lineno">33</span><span class="k">def</span> <span class="nf">_mix_mlp_configs</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">38</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">FeedForwardConfigs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Size of the MLP is the sequence length, because it is applied across tokens</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">40</span> <span class="n">conf</span><span class="o">.</span><span class="n">d_model</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">seq_len</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>The paper suggests $GELU$ activation</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">42</span> <span class="n">conf</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="s1">&#39;GELU&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">45</span> <span class="k">return</span> <span class="n">conf</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<h3>Transformer configurations</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">48</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">transformer</span><span class="p">)</span>
<span class="lineno">49</span><span class="k">def</span> <span class="nf">_transformer_configs</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>We use our
<a href="../configs.html#TransformerConfigs">configurable transformer implementation</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">TransformerConfigs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Set the vocabulary sizes for embeddings and generating logits</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">58</span> <span class="n">conf</span><span class="o">.</span><span class="n">n_src_vocab</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span>
<span class="lineno">59</span> <span class="n">conf</span><span class="o">.</span><span class="n">n_tgt_vocab</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Embedding size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">61</span> <span class="n">conf</span><span class="o">.</span><span class="n">d_model</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Change attention module to <a href="index.html">MLPMixer</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">63</span> <span class="kn">from</span> <span class="nn">labml_nn.transformers.mlp_mixer</span> <span class="kn">import</span> <span class="n">MLPMixer</span>
<span class="lineno">64</span> <span class="n">conf</span><span class="o">.</span><span class="n">encoder_attn</span> <span class="o">=</span> <span class="n">MLPMixer</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">mix_mlp</span><span class="o">.</span><span class="n">ffn</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">67</span> <span class="k">return</span> <span class="n">conf</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Create experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">72</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;mlp_mixer_mlm&quot;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Create configs</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">74</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Override configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Batch size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">78</span> <span class="s1">&#39;batch_size&#39;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Sequence length of $32$. We use a short sequence length to train faster.
Otherwise MLM models take forever to train.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">81</span> <span class="s1">&#39;seq_len&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Train for 1024 epochs.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">84</span> <span class="s1">&#39;epochs&#39;</span><span class="p">:</span> <span class="mi">1024</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Switch between training and validation for $1$ times
per epoch</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">87</span> <span class="s1">&#39;inner_iterations&#39;</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Transformer configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">90</span> <span class="s1">&#39;d_model&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span>
<span class="lineno">91</span> <span class="s1">&#39;transformer.ffn.d_ff&#39;</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span>
<span class="lineno">92</span> <span class="s1">&#39;transformer.n_heads&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span>
<span class="lineno">93</span> <span class="s1">&#39;transformer.n_layers&#39;</span><span class="p">:</span> <span class="mi">6</span><span class="p">,</span>
<span class="lineno">94</span> <span class="s1">&#39;transformer.ffn.activation&#39;</span><span class="p">:</span> <span class="s1">&#39;GELU&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Mixer MLP hidden layer size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">97</span> <span class="s1">&#39;mix_mlp.d_ff&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Use <a href="../../optimizers/noam.html">Noam optimizer</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">100</span> <span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Noam&#39;</span><span class="p">,</span>
<span class="lineno">101</span> <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">1.</span><span class="p">,</span>
<span class="lineno">102</span> <span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Set models for saving and loading</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">105</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">&#39;model&#39;</span><span class="p">:</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Start the experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">108</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Run training</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">114</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">115</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

View File

@ -1,289 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This is an annotated implementation/tutorial of MLP-Mixer: An all-MLP Architecture for Vision in PyTorch."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="MLP-Mixer: An all-MLP Architecture for Vision"/>
<meta name="twitter:description" content="This is an annotated implementation/tutorial of MLP-Mixer: An all-MLP Architecture for Vision in PyTorch."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/transformers/mlp_mixer/index.html"/>
<meta property="og:title" content="MLP-Mixer: An all-MLP Architecture for Vision"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="MLP-Mixer: An all-MLP Architecture for Vision"/>
<meta property="og:description" content="This is an annotated implementation/tutorial of MLP-Mixer: An all-MLP Architecture for Vision in PyTorch."/>
<title>MLP-Mixer: An all-MLP Architecture for Vision</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/transformers/mlp_mixer/index.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">transformers</a>
<a class="parent" href="index.html">mlp_mixer</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/mlp_mixer/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>MLP-Mixer: An all-MLP Architecture for Vision</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the paper
<a href="https://papers.labml.ai/paper/2105.01601">MLP-Mixer: An all-MLP Architecture for Vision</a>.</p>
<p>This paper applies the model on vision tasks.
The model is similar to a transformer with attention layer being replaced by a MLP
that is applied across the patches (or tokens in case of a NLP task).</p>
<p>Our implementation of MLP Mixer is a drop in replacement for the <a href="../mha.html">self-attention layer</a>
in <a href="../models.html">our transformer implementation</a>.
So it&rsquo;s just a couple of lines of code, transposing the tensor to apply the MLP
across the sequence dimension.</p>
<p>Although the paper applied MLP Mixer on vision tasks,
we tried it on a <a href="../mlm/index.html">masked language model</a>.
<a href="experiment.html">Here is the experiment code</a>.</p>
<p><a href="https://app.labml.ai/run/994263d2cdb511eb961e872301f0dbab"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">29</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span>
<span class="lineno">30</span>
<span class="lineno">31</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">32</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>MLP Mixer</h2>
<p>This module is a drop-in replacement for <a href="../mha.html">self-attention layer</a>.
It transposes the input tensor before feeding it to the MLP and transposes back,
so that the MLP is applied across the sequence dimension (across tokens or image patches) instead
of the feature dimension.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">35</span><span class="k">class</span> <span class="nc">MLPMixer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul>
<li><code>ffn</code> is the MLP module.</li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">45</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mlp</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">50</span> <span class="bp">self</span><span class="o">.</span><span class="n">mlp</span> <span class="o">=</span> <span class="n">mlp</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>The <a href="../mha.html">normal attention module</a> can be fed with different token embeddings for
$\text{query}$,$\text{key}$, and $\text{value}$ and a mask.</p>
<p>We follow the same function signature so that we can replace it directly.</p>
<p>For MLP mixing, <script type="math/tex; mode=display">x = \text{query} = \text{key} = \text{value}</script> and masking is not possible.
Shape of <code>query</code> (and <code>key</code> and <code>value</code>) is <code>[seq_len, batch_size, d_model]</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">52</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>$\text{query}$,$\text{key}$, and $\text{value}$ all should be the same</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">64</span> <span class="k">assert</span> <span class="n">query</span> <span class="ow">is</span> <span class="n">key</span> <span class="ow">and</span> <span class="n">key</span> <span class="ow">is</span> <span class="n">value</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>MLP mixer doesn&rsquo;t support masking. i.e. all tokens will see all other token embeddings.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">66</span> <span class="k">assert</span> <span class="n">mask</span> <span class="ow">is</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Assign to <code>x</code> for clarity</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">69</span> <span class="n">x</span> <span class="o">=</span> <span class="n">query</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Transpose so that the last dimension is the sequence dimension.
New shape is <code>[d_model, batch_size, seq_len]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">73</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Apply the MLP across tokens</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">75</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mlp</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Transpose back into original form</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">77</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">80</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

View File

@ -1,153 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content=""/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="MLP-Mixer: An all-MLP Architecture for Vision"/>
<meta name="twitter:description" content=""/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/transformers/mlp_mixer/readme.html"/>
<meta property="og:title" content="MLP-Mixer: An all-MLP Architecture for Vision"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="MLP-Mixer: An all-MLP Architecture for Vision"/>
<meta property="og:description" content=""/>
<title>MLP-Mixer: An all-MLP Architecture for Vision</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/transformers/mlp_mixer/readme.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">transformers</a>
<a class="parent" href="index.html">mlp_mixer</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/mlp_mixer/readme.md">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1><a href="https://nn.labml.ai/transformers/mlp_mixer/index.html">MLP-Mixer: An all-MLP Architecture for Vision</a></h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the paper
<a href="https://papers.labml.ai/paper/2105.01601">MLP-Mixer: An all-MLP Architecture for Vision</a>.</p>
<p>This paper applies the model on vision tasks.
The model is similar to a transformer with attention layer being replaced by a MLP
that is applied across the patches (or tokens in case of a NLP task).</p>
<p>Our implementation of MLP Mixer is a drop in replacement for the <a href="https://nn.labml.ai/transformers/mha.html">self-attention layer</a>
in <a href="https://nn.labml.ai/transformers/models.html">our transformer implementation</a>.
So it&rsquo;s just a couple of lines of code, transposing the tensor to apply the MLP
across the sequence dimension.</p>
<p>Although the paper applied MLP Mixer on vision tasks,
we tried it on a <a href="https://nn.labml.ai/transformers/mlm/index.html">masked language model</a>.
<a href="https://nn.labml.ai/transformers/mlp_mixer/experiment.html">Here is the experiment code</a>.</p>
<p><a href="https://app.labml.ai/run/994263d2cdb511eb961e872301f0dbab"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>