Files
Varuna Jayasiri 748de53461 clickable math
2021-12-14 17:30:59 +05:30

1011 lines
91 KiB
HTML
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="These are configurable components that can be re-used quite easily."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Configurable Transformer Components"/>
<meta name="twitter:description" content="These are configurable components that can be re-used quite easily."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/transformers/configs.html"/>
<meta property="og:title" content="Configurable Transformer Components"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Configurable Transformer Components"/>
<meta property="og:description" content="These are configurable components that can be re-used quite easily."/>
<title>Configurable Transformer Components</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/transformers/configs.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
// Google Analytics bootstrap: reuse the global dataLayer queue if one already exists, otherwise start an empty one.
window.dataLayer = window.dataLayer || [];
// Queue a call by pushing its raw `arguments` object onto dataLayer.
// NOTE(review): presumably the async gtag.js loader expects the `arguments` object itself rather than an array; confirm before refactoring.
function gtag() {
dataLayer.push(arguments);
}
// Queue a 'js' entry carrying the current Date, then a 'config' entry for the ID 'G-4V3HC8HBLH'.
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">transformers</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/configs.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Configurable Transformer Components</h1>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">9</span><span></span><span class="kn">import</span> <span class="nn">copy</span>
<span class="lineno">10</span>
<span class="lineno">11</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="lineno">12</span>
<span class="lineno">13</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">BaseConfigs</span><span class="p">,</span> <span class="n">option</span><span class="p">,</span> <span class="n">calculate</span><span class="p">,</span> <span class="n">aggregate</span>
<span class="lineno">14</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">.feed_forward</span> <span class="kn">import</span> <span class="n">FeedForward</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">.mha</span> <span class="kn">import</span> <span class="n">MultiHeadAttention</span>
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">.models</span> <span class="kn">import</span> <span class="n">EmbeddingsWithPositionalEncoding</span><span class="p">,</span> <span class="n">EmbeddingsWithLearnedPositionalEncoding</span><span class="p">,</span> <span class="n">TransformerLayer</span><span class="p">,</span> \
<span class="lineno">18</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">Decoder</span><span class="p">,</span> <span class="n">Generator</span><span class="p">,</span> <span class="n">EncoderDecoder</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<p> <a id="FFN"></a></p>
<h2>FFN Configurations</h2>
<p>Creates a Position-wise FeedForward Network defined in <a href="feed_forward.html"><code class="highlight"><span></span><span class="n">feed_forward</span><span class="o">.</span><span class="n">py</span></code>
</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">21</span><span class="k">class</span> <span class="nc">FeedForwardConfigs</span><span class="p">(</span><span class="n">BaseConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<p>Position-wise feedforward layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">31</span> <span class="n">ffn</span><span class="p">:</span> <span class="n">FeedForward</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>Number of features in the embedding </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">33</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Number of features in the hidden layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">35</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2048</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Dropout probability </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">37</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Activation in position-wise feedforward layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">39</span> <span class="n">activation</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span> <span class="o">=</span> <span class="s1">&#39;ReLU&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Whether the FFN layer should be gated </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">41</span> <span class="n">is_gated</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Whether the first fully connected layer should have a learnable bias </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">43</span> <span class="n">bias1</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Whether the second fully connected layer should have a learnable bias </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">45</span> <span class="n">bias2</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Whether the fully connected layer for the gate should have a learnable bias </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span> <span class="n">bias_gate</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Predefined GLU variants </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="n">glu_variant</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;none&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<h3>ReLU activation</h3>
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop">max</span><span class="mopen">(</span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">52</span><span class="nd">@option</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="s1">&#39;ReLU&#39;</span><span class="p">)</span>
<span class="lineno">53</span><span class="k">def</span> <span class="nf">_ffn_activation_relu</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">59</span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<h3>GELU activation</h3>
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">x</span><span class="mord">Φ</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span> where <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord">Φ</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">P</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.07847em;">X</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">≤</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" style="margin-right:0.07847em;">X</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">∼</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathcal" style="margin-right:0.14736em;">N</span><span class="mopen">(</span><span class="mord">0</span><span
class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">1</span><span class="mclose">)</span></span></span></span></p>
<p>It was introduced in the paper <a href="https://papers.labml.ai/paper/1606.08415">Gaussian Error Linear Units</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span><span class="nd">@option</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="s1">&#39;GELU&#39;</span><span class="p">)</span>
<span class="lineno">63</span><span class="k">def</span> <span class="nf">_ffn_activation_gelu</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">71</span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p> Initialize a <a href="feed_forward.html">feed forward network</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">74</span><span class="nd">@option</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">ffn</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">75</span><span class="k">def</span> <span class="nf">_feed_forward</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">FeedForwardConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span> <span class="k">return</span> <span class="n">FeedForward</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_ff</span><span class="p">,</span>
<span class="lineno">80</span> <span class="n">dropout</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">,</span>
<span class="lineno">81</span> <span class="n">activation</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span>
<span class="lineno">82</span> <span class="n">is_gated</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span>
<span class="lineno">83</span> <span class="n">bias1</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span>
<span class="lineno">84</span> <span class="n">bias2</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span>
<span class="lineno">85</span> <span class="n">bias_gate</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<h2>GLU Variants</h2>
<p>These are variants with gated hidden layers for the FFN as introduced in the paper <a href="https://papers.labml.ai/paper/2002.05202">GLU Variants Improve Transformer</a>. We have omitted the bias terms as specified in the paper. </p>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<h3>FFN with Gated Linear Units</h3>
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">FF</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.32833099999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">G</span><span class="mord mathnormal mtight" style="margin-right:0.10903em;">LU</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" 
style="margin-right:0.22222em;">V</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⊗</span><span class="mspace"
style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">x</span><span class="mord mathnormal" style="margin-right:0.22222em;">V</span><span class="mclose">)</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">95</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">&#39;GLU&#39;</span><span class="p">,</span>
<span class="lineno">96</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
<span class="lineno">97</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">98</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">99</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">100</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sigmoid</span><span class="p">()))</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<h3>FFN with Bilinear hidden layer</h3>
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">FF</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05017em;">B</span><span class="mord mathnormal mtight">i</span><span class="mord mathnormal mtight" style="margin-right:0.01968em;">l</span><span class="mord mathnormal mtight">in</span><span class="mord mathnormal mtight">e</span><span class="mord mathnormal mtight">a</span><span class="mord mathnormal mtight" style="margin-right:0.02778em;">r</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span 
class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" style="margin-right:0.22222em;">V</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span 
class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⊗</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">x</span><span class="mord mathnormal" style="margin-right:0.22222em;">V</span><span class="mclose">)</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">105</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">&#39;Bilinear&#39;</span><span class="p">,</span>
<span class="lineno">106</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
<span class="lineno">107</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">108</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">109</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">110</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()))</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<h3>FFN with ReLU gate</h3>
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">FF</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.32833099999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.00773em;">R</span><span class="mord mathnormal mtight">e</span><span class="mord mathnormal mtight">G</span><span class="mord mathnormal mtight" style="margin-right:0.10903em;">LU</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span 
class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" style="margin-right:0.22222em;">V</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mop">max</span><span class="mopen">(</span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal">x</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" 
style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⊗</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">x</span><span class="mord mathnormal" style="margin-right:0.22222em;">V</span><span class="mclose">)</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">115</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">&#39;ReGLU&#39;</span><span class="p">,</span>
<span class="lineno">116</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
<span class="lineno">117</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">118</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">119</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">120</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()))</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<h3>FFN with GELU gate</h3>
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">FF</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.32833099999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">GEG</span><span class="mord mathnormal mtight" style="margin-right:0.10903em;">LU</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" 
style="margin-right:0.22222em;">V</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord text"><span class="mord">GELU</span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⊗</span><span class="mspace" 
style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">x</span><span class="mord mathnormal" style="margin-right:0.22222em;">V</span><span class="mclose">)</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">125</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">&#39;GEGLU&#39;</span><span class="p">,</span>
<span class="lineno">126</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
<span class="lineno">127</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">128</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">129</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">130</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">()))</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<h3>FFN with Swish gate</h3>
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">FF</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.32833099999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.02691em;">Sw</span><span class="mord mathnormal mtight">i</span><span class="mord mathnormal mtight">G</span><span class="mord mathnormal mtight" style="margin-right:0.10903em;">LU</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span 
class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" style="margin-right:0.22222em;">V</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord"><span class="mord text"><span class="mord">Swish</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span 
class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⊗</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">x</span><span class="mord mathnormal" style="margin-right:0.22222em;">V</span><span class="mclose">)</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> where <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mord"><span class="mord text"><span class="mord">Swish</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361079999999999em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" 
style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05278em;">β</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">x</span><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.05278em;">β</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">&#39;SwiGLU&#39;</span><span class="p">,</span>
<span class="lineno">137</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
<span class="lineno">138</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">139</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">140</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">141</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">SiLU</span><span class="p">()))</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p> <a id="TransformerConfigs"></a></p>
<h2>Transformer Configurations</h2>
<p>This defines configurations for a transformer. The configurations are calculated using option functions. These are lazy loaded and therefore only the necessary modules are calculated.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">144</span><span class="k">class</span> <span class="nc">TransformerConfigs</span><span class="p">(</span><span class="n">BaseConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Number of attention heads </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Transformer embedding size </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">158</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">512</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Number of layers </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">160</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">6</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Dropout probability </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Number of tokens in the source vocabulary (for token embeddings) </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">164</span> <span class="n">n_src_vocab</span><span class="p">:</span> <span class="nb">int</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Number of tokens in the target vocabulary (to generate logits for prediction) </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">166</span> <span class="n">n_tgt_vocab</span><span class="p">:</span> <span class="nb">int</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>The encoder self attention </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">169</span> <span class="n">encoder_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span> <span class="o">=</span> <span class="s1">&#39;mha&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>The decoder self attention </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">171</span> <span class="n">decoder_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span> <span class="o">=</span> <span class="s1">&#39;mha&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>The decoder memory attention </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">173</span> <span class="n">decoder_mem_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span> <span class="o">=</span> <span class="s1">&#39;mha&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Configurable Feedforward Layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span> <span class="n">ffn</span><span class="p">:</span> <span class="n">FeedForwardConfigs</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Encoder layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">179</span> <span class="n">encoder_layer</span><span class="p">:</span> <span class="n">TransformerLayer</span> <span class="o">=</span> <span class="s1">&#39;default&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Decoder layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">181</span> <span class="n">decoder_layer</span><span class="p">:</span> <span class="n">TransformerLayer</span> <span class="o">=</span> <span class="s1">&#39;default&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>Encoder consisting of multiple encoder layers </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">184</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">Encoder</span> <span class="o">=</span> <span class="s1">&#39;default&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Decoder consisting of multiple decoder layers </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">186</span> <span class="n">decoder</span><span class="p">:</span> <span class="n">Decoder</span> <span class="o">=</span> <span class="s1">&#39;default&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>Embedding layer for source </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">189</span> <span class="n">src_embed</span><span class="p">:</span> <span class="n">Module</span> <span class="o">=</span> <span class="s1">&#39;fixed_pos&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Embedding layer for target (for decoder) </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span> <span class="n">tgt_embed</span><span class="p">:</span> <span class="n">Module</span> <span class="o">=</span> <span class="s1">&#39;fixed_pos&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Logit generator for prediction </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">194</span> <span class="n">generator</span><span class="p">:</span> <span class="n">Generator</span> <span class="o">=</span> <span class="s1">&#39;default&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Encoder-decoder </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">encoder_decoder</span><span class="p">:</span> <span class="n">EncoderDecoder</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<h3>Multi-head Attention</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">201</span><span class="k">def</span> <span class="nf">_mha</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
<span class="lineno">202</span> <span class="k">return</span> <span class="n">MultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">dropout_prob</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">)</span>
<span class="lineno">203</span>
<span class="lineno">204</span>
<span class="lineno">205</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_attn</span><span class="p">,</span> <span class="s1">&#39;mha&#39;</span><span class="p">,</span> <span class="n">_mha</span><span class="p">)</span>
<span class="lineno">206</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_attn</span><span class="p">,</span> <span class="s1">&#39;mha&#39;</span><span class="p">,</span> <span class="n">_mha</span><span class="p">)</span>
<span class="lineno">207</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_mem_attn</span><span class="p">,</span> <span class="s1">&#39;mha&#39;</span><span class="p">,</span> <span class="n">_mha</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<h3>Relative Multi-head Attention</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">211</span><span class="k">def</span> <span class="nf">_relative_mha</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
<span class="lineno">212</span> <span class="kn">from</span> <span class="nn">labml_nn.transformers.xl.relative_mha</span> <span class="kn">import</span> <span class="n">RelativeMultiHeadAttention</span>
<span class="lineno">213</span> <span class="k">return</span> <span class="n">RelativeMultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span>
<span class="lineno">214</span>
<span class="lineno">215</span>
<span class="lineno">216</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_attn</span><span class="p">,</span> <span class="s1">&#39;relative&#39;</span><span class="p">,</span> <span class="n">_relative_mha</span><span class="p">)</span>
<span class="lineno">217</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_attn</span><span class="p">,</span> <span class="s1">&#39;relative&#39;</span><span class="p">,</span> <span class="n">_relative_mha</span><span class="p">)</span>
<span class="lineno">218</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_mem_attn</span><span class="p">,</span> <span class="s1">&#39;relative&#39;</span><span class="p">,</span> <span class="n">_relative_mha</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p> Create feedforward layer configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">221</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">ffn</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">222</span><span class="k">def</span> <span class="nf">_feed_forward</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">226</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">FeedForwardConfigs</span><span class="p">()</span>
<span class="lineno">227</span> <span class="n">conf</span><span class="o">.</span><span class="n">set_default</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">func</span><span class="o">=</span><span class="k">lambda</span><span class="p">:</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span>
<span class="lineno">228</span> <span class="n">conf</span><span class="o">.</span><span class="n">set_default</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">dropout</span><span class="p">,</span> <span class="n">func</span><span class="o">=</span><span class="k">lambda</span><span class="p">:</span> <span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">)</span>
<span class="lineno">229</span> <span class="k">return</span> <span class="n">conf</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<p> Encoder layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">232</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_layer</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">233</span><span class="k">def</span> <span class="nf">_encoder_layer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">237</span> <span class="k">return</span> <span class="n">TransformerLayer</span><span class="p">(</span><span class="n">d_model</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">self_attn</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">encoder_attn</span><span class="p">,</span>
<span class="lineno">238</span> <span class="n">src_attn</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">feed_forward</span><span class="o">=</span><span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">ffn</span><span class="p">),</span>
<span class="lineno">239</span> <span class="n">dropout_prob</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p> Decoder layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">242</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_layer</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">243</span><span class="k">def</span> <span class="nf">_decoder_layer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">247</span> <span class="k">return</span> <span class="n">TransformerLayer</span><span class="p">(</span><span class="n">d_model</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">self_attn</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">decoder_attn</span><span class="p">,</span>
<span class="lineno">248</span> <span class="n">src_attn</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">decoder_mem_attn</span><span class="p">,</span> <span class="n">feed_forward</span><span class="o">=</span><span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">ffn</span><span class="p">),</span>
<span class="lineno">249</span> <span class="n">dropout_prob</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p> Encoder</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">252</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">253</span><span class="k">def</span> <span class="nf">_encoder</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">257</span> <span class="k">return</span> <span class="n">Encoder</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">encoder_layer</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p> Decoder</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">260</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">261</span><span class="k">def</span> <span class="nf">_decoder</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">265</span> <span class="k">return</span> <span class="n">Decoder</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">decoder_layer</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<p> Logit generator</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">268</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">generator</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">269</span><span class="k">def</span> <span class="nf">_generator</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">273</span> <span class="k">return</span> <span class="n">Generator</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_tgt_vocab</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<h3>Fixed Positional Embeddings</h3>
<p>Source embedding with fixed positional encodings</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">277</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="s1">&#39;fixed_pos&#39;</span><span class="p">)</span>
<span class="lineno">278</span><span class="k">def</span> <span class="nf">_src_embed_with_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">282</span> <span class="k">return</span> <span class="n">EmbeddingsWithPositionalEncoding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_src_vocab</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<p> Target embedding with fixed positional encodings</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">285</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">,</span> <span class="s1">&#39;fixed_pos&#39;</span><span class="p">)</span>
<span class="lineno">286</span><span class="k">def</span> <span class="nf">_tgt_embed_with_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">290</span> <span class="k">return</span> <span class="n">EmbeddingsWithPositionalEncoding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tgt_vocab</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
<h3>Learned Positional Embeddings</h3>
<p>Source embedding with learned positional encodings</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">294</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="s1">&#39;learned_pos&#39;</span><span class="p">)</span>
<span class="lineno">295</span><span class="k">def</span> <span class="nf">_src_embed_with_learned_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">299</span> <span class="k">return</span> <span class="n">EmbeddingsWithLearnedPositionalEncoding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_src_vocab</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
<p> Target embedding with learned positional encodings</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">302</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">,</span> <span class="s1">&#39;learned_pos&#39;</span><span class="p">)</span>
<span class="lineno">303</span><span class="k">def</span> <span class="nf">_tgt_embed_with_learned_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-64'>
<div class='docs'>
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">307</span> <span class="k">return</span> <span class="n">EmbeddingsWithLearnedPositionalEncoding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tgt_vocab</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-65'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-65'>#</a>
</div>
<h3>No Positional Embeddings</h3>
<p>Source embedding without positional encodings</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">311</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="s1">&#39;no_pos&#39;</span><span class="p">)</span>
<span class="lineno">312</span><span class="k">def</span> <span class="nf">_src_embed_without_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-66'>
<div class='docs'>
<div class='section-link'>
<a href='#section-66'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">316</span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_src_vocab</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-67'>
<div class='docs'>
<div class='section-link'>
<a href='#section-67'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">319</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">,</span> <span class="s1">&#39;no_pos&#39;</span><span class="p">)</span>
<span class="lineno">320</span><span class="k">def</span> <span class="nf">_tgt_embed_without_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
<span class="lineno">321</span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_tgt_vocab</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span>
<span class="lineno">322</span>
<span class="lineno">323</span>
<span class="lineno">324</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_decoder</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">325</span><span class="k">def</span> <span class="nf">_encoder_decoder</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
<span class="lineno">326</span> <span class="k">return</span> <span class="n">EncoderDecoder</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">decoder</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">generator</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../interactive.js?v=1"></script>
<script>
// Hand every <img> that is a direct child of a <p> to handleImage, in document order.
function handleImages() {
    var paragraphImages = document.querySelectorAll('p>img')
    Array.prototype.forEach.call(paragraphImages, function (image) {
        handleImage(image)
    })
}
/**
 * Make an image open in a modal element when it is clicked.
 *
 * Centers the text alignment of the image's parent element, then builds a
 * detached `div#modal` containing a wrapper `div` with an `<img>` and a
 * `span.close` showing "x". The modal is appended to `document.body` when the
 * image is clicked and removed again when the "x" is clicked.
 *
 * @param {HTMLImageElement} img - the image to make enlargeable
 */
function handleImage(img) {
    img.parentElement.style.textAlign = 'center'

    var modal = document.createElement('div')
    modal.id = 'modal'

    var modalContent = document.createElement('div')
    modal.appendChild(modalContent)

    var modalImage = document.createElement('img')
    modalContent.appendChild(modalImage)

    var span = document.createElement('span')
    span.classList.add('close')
    span.textContent = 'x'
    modal.appendChild(span)

    img.onclick = function () {
        document.body.appendChild(modal)
        modalImage.src = img.src
        // Carry the text alternative over so the enlarged copy is not an unlabeled image.
        modalImage.alt = img.alt || ''
    }

    span.onclick = function () {
        // removeChild throws when the node is not a child of body, so only detach if attached.
        if (modal.parentNode === document.body) {
            document.body.removeChild(modal)
        }
    }
}
// Wire up the images that exist in the document when this script runs.
handleImages()
</script>
</body>
</html>