📚 glu variants

This commit is contained in:
Varuna Jayasiri
2021-01-26 16:54:23 +05:30
parent 20d2e27a3c
commit abe5caba6f
10 changed files with 1390 additions and 552 deletions

BIN
docs/optimizers/noam_lr.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 35 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

View File

@ -86,11 +86,15 @@
</div>
</div>
<div class='section' id='section-1'>
<div class='docs'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<p><a id="FFN"></p>
<h2>FFN Configurations</h2>
<p></a></p>
<p>Creates a Position-wise FeedForward Network defined in
<a href="feed_forward.html"><code>feed_forward.py</code></a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">21</span><span class="k">class</span> <span class="nc">FeedForwardConfigs</span><span class="p">(</span><span class="n">BaseConfigs</span><span class="p">):</span></pre></div>
@ -104,7 +108,7 @@
<p>Position-wise feedforward layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">23</span> <span class="n">ffn</span><span class="p">:</span> <span class="n">FeedForward</span></pre></div>
<div class="highlight"><pre><span class="lineno">31</span> <span class="n">ffn</span><span class="p">:</span> <span class="n">FeedForward</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
@ -115,7 +119,7 @@
<p>Number of features in the embedding</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">25</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span></pre></div>
<div class="highlight"><pre><span class="lineno">33</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
@ -126,7 +130,7 @@
<p>Number of features in the hidden layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">27</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2048</span></pre></div>
<div class="highlight"><pre><span class="lineno">35</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2048</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
@ -137,7 +141,7 @@
<p>Dropout probability</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">29</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span></pre></div>
<div class="highlight"><pre><span class="lineno">37</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
@ -148,7 +152,7 @@
<p>Activation in position-wise feedforward layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">31</span> <span class="n">activation</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span> <span class="o">=</span> <span class="s1">&#39;ReLU&#39;</span></pre></div>
<div class="highlight"><pre><span class="lineno">39</span> <span class="n">activation</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span> <span class="o">=</span> <span class="s1">&#39;ReLU&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
@ -159,7 +163,7 @@
<p>Whether the FFN layer should be gated</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">33</span> <span class="n">is_gated</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
<div class="highlight"><pre><span class="lineno">41</span> <span class="n">is_gated</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
@ -170,7 +174,7 @@
<p>Whether the first fully connected layer should have a learnable bias</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">35</span> <span class="n">bias1</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
<div class="highlight"><pre><span class="lineno">43</span> <span class="n">bias1</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
@ -181,7 +185,7 @@
<p>Whether the second fully connected layer should have a learnable bias</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">37</span> <span class="n">bias2</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
<div class="highlight"><pre><span class="lineno">45</span> <span class="n">bias2</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
@ -192,7 +196,7 @@
<p>Whether the fully connected layer for the gate should have a learnable bias</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">39</span> <span class="n">bias_gate</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
<div class="highlight"><pre><span class="lineno">47</span> <span class="n">bias_gate</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
@ -203,7 +207,7 @@
<p>Predefined GLU variants</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">41</span> <span class="n">glu_variant</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;none&#39;</span></pre></div>
<div class="highlight"><pre><span class="lineno">49</span> <span class="n">glu_variant</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;none&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
@ -211,11 +215,14 @@
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>ReLU activation</p>
<h3>ReLU activation</h3>
<p>
<script type="math/tex; mode=display">\max(0, x)</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">44</span><span class="nd">@option</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="s1">&#39;ReLU&#39;</span><span class="p">)</span>
<span class="lineno">45</span><span class="k">def</span> <span class="nf">_ffn_activation_relu</span><span class="p">():</span></pre></div>
<div class="highlight"><pre><span class="lineno">52</span><span class="nd">@option</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="s1">&#39;ReLU&#39;</span><span class="p">)</span>
<span class="lineno">53</span><span class="k">def</span> <span class="nf">_ffn_activation_relu</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
@ -226,7 +233,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">59</span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
@ -234,11 +241,14 @@
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>GELU activation</p>
<h3>GELU activation</h3>
<p>
<script type="math/tex; mode=display">x \Phi(x)</script> where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$</p>
<p>It was introduced in paper <a href="https://arxiv.org/abs/1606.08415">Gaussian Error Linear Units</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">52</span><span class="nd">@option</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="s1">&#39;GELU&#39;</span><span class="p">)</span>
<span class="lineno">53</span><span class="k">def</span> <span class="nf">_ffn_activation_gelu</span><span class="p">():</span></pre></div>
<div class="highlight"><pre><span class="lineno">62</span><span class="nd">@option</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="s1">&#39;GELU&#39;</span><span class="p">)</span>
<span class="lineno">63</span><span class="k">def</span> <span class="nf">_ffn_activation_gelu</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
@ -249,7 +259,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">57</span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">71</span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
@ -257,11 +267,11 @@
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Create feedforward layer</p>
<p>Initialize a <a href="feed_forward.html">feed forward network</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">60</span><span class="nd">@option</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">ffn</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">61</span><span class="k">def</span> <span class="nf">_feed_forward</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">FeedForwardConfigs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">74</span><span class="nd">@option</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">ffn</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">75</span><span class="k">def</span> <span class="nf">_feed_forward</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">FeedForwardConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
@ -272,53 +282,129 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span> <span class="k">return</span> <span class="n">FeedForward</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_ff</span><span class="p">,</span>
<span class="lineno">66</span> <span class="n">dropout</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">,</span>
<span class="lineno">67</span> <span class="n">activation</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span>
<span class="lineno">68</span> <span class="n">is_gated</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span>
<span class="lineno">69</span> <span class="n">bias1</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span>
<span class="lineno">70</span> <span class="n">bias2</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span>
<span class="lineno">71</span> <span class="n">bias_gate</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">)</span>
<span class="lineno">72</span>
<span class="lineno">73</span>
<span class="lineno">74</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">&#39;GLU&#39;</span><span class="p">,</span>
<span class="lineno">75</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
<span class="lineno">76</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">77</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">78</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">79</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sigmoid</span><span class="p">()))</span>
<span class="lineno">80</span>
<span class="lineno">81</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">&#39;Bilinear&#39;</span><span class="p">,</span>
<span class="lineno">82</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
<span class="lineno">83</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">84</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">85</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">86</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()))</span>
<span class="lineno">87</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">&#39;ReGLU&#39;</span><span class="p">,</span>
<span class="lineno">88</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
<span class="lineno">89</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">90</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">91</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">92</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()))</span>
<span class="lineno">93</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">&#39;GEGLU&#39;</span><span class="p">,</span>
<span class="lineno">94</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
<span class="lineno">95</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">96</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">97</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">98</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">()))</span>
<span class="lineno">99</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">&#39;SwiGLU&#39;</span><span class="p">,</span>
<span class="lineno">100</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
<span class="lineno">101</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">102</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">103</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">104</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">SiLU</span><span class="p">()))</span></pre></div>
<div class="highlight"><pre><span class="lineno">79</span> <span class="k">return</span> <span class="n">FeedForward</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_ff</span><span class="p">,</span>
<span class="lineno">80</span> <span class="n">dropout</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">,</span>
<span class="lineno">81</span> <span class="n">activation</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span>
<span class="lineno">82</span> <span class="n">is_gated</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span>
<span class="lineno">83</span> <span class="n">bias1</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span>
<span class="lineno">84</span> <span class="n">bias2</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span>
<span class="lineno">85</span> <span class="n">bias_gate</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs doc-strings'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<h2>GLU Variants</h2>
<p>These are variants with gated hidden layers for the FFN
as introduced in paper <a href="https://arxiv.org/abs/2002.05202">GLU Variants Improve Transformer</a>.
We have omitted the bias terms as specified in the paper.</p>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<h3>FFN with Gated Linear Units</h3>
<p>
<script type="math/tex; mode=display">FFN_{GLU}(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">95</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">&#39;GLU&#39;</span><span class="p">,</span>
<span class="lineno">96</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
<span class="lineno">97</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">98</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">99</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">100</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sigmoid</span><span class="p">()))</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<h3>FFN with Bilinear hidden layer</h3>
<p>
<script type="math/tex; mode=display">FFN_{Bilinear}(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">105</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">&#39;Bilinear&#39;</span><span class="p">,</span>
<span class="lineno">106</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
<span class="lineno">107</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">108</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">109</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">110</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()))</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<h3>FFN with ReLU gate</h3>
<p>
<script type="math/tex; mode=display">FFN_{ReGLU}(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">115</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">&#39;ReGLU&#39;</span><span class="p">,</span>
<span class="lineno">116</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
<span class="lineno">117</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">118</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">119</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">120</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()))</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<h3>FFN with GELU gate</h3>
<p>
<script type="math/tex; mode=display">FFN_{GEGLU}(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">125</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">&#39;GEGLU&#39;</span><span class="p">,</span>
<span class="lineno">126</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
<span class="lineno">127</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">128</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">129</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">130</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">()))</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<h3>FFN with Swish gate</h3>
<p>
<script type="math/tex; mode=display">FFN_{SwiGLU}(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2</script>
where $\text{Swish}_\beta(x) = x \sigma(\beta x)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">&#39;SwiGLU&#39;</span><span class="p">,</span>
<span class="lineno">137</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
<span class="lineno">138</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">139</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">140</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
<span class="lineno">141</span> <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">SiLU</span><span class="p">()))</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p><a id="TransformerConfigs"></p>
<h2>Transformer Configurations</h2>
<p></a></p>
@ -328,73 +414,7 @@ These are lazy loaded and therefore only the necessary modules
are calculated.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">107</span><span class="k">class</span> <span class="nc">TransformerConfigs</span><span class="p">(</span><span class="n">BaseConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Number of attention heads</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">119</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Transformer embedding size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">121</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">512</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Number of layers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">123</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">6</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Dropout probability</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">125</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Number of tokens in the source vocabulary (for token embeddings)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">127</span> <span class="n">n_src_vocab</span><span class="p">:</span> <span class="nb">int</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Number of tokens in the target vocabulary (to generate logits for prediction)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">129</span> <span class="n">n_tgt_vocab</span><span class="p">:</span> <span class="nb">int</span></pre></div>
<div class="highlight"><pre><span class="lineno">144</span><span class="k">class</span> <span class="nc">TransformerConfigs</span><span class="p">(</span><span class="n">BaseConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
@ -402,10 +422,10 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>The encoder self attention</p>
<p>Number of attention heads</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">132</span> <span class="n">encoder_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span> <span class="o">=</span> <span class="s1">&#39;mha&#39;</span></pre></div>
<div class="highlight"><pre><span class="lineno">156</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
@ -413,10 +433,10 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>The decoder self attention</p>
<p>Transformer embedding size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">134</span> <span class="n">decoder_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span> <span class="o">=</span> <span class="s1">&#39;mha&#39;</span></pre></div>
<div class="highlight"><pre><span class="lineno">158</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">512</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
@ -424,10 +444,10 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>The decoder memory attention</p>
<p>Number of layers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="n">decoder_mem_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span> <span class="o">=</span> <span class="s1">&#39;mha&#39;</span></pre></div>
<div class="highlight"><pre><span class="lineno">160</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">6</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
@ -435,10 +455,10 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Configurable Feedforward Layer</p>
<p>Dropout probability</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">139</span> <span class="n">ffn</span><span class="p">:</span> <span class="n">FeedForwardConfigs</span></pre></div>
<div class="highlight"><pre><span class="lineno">162</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
@ -446,10 +466,10 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Encoder layer</p>
<p>Number of tokens in the source vocabulary (for token embeddings)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">142</span> <span class="n">encoder_layer</span><span class="p">:</span> <span class="n">TransformerLayer</span> <span class="o">=</span> <span class="s1">&#39;default&#39;</span></pre></div>
<div class="highlight"><pre><span class="lineno">164</span> <span class="n">n_src_vocab</span><span class="p">:</span> <span class="nb">int</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
@ -457,10 +477,10 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Decoder layer</p>
<p>Number of tokens in the target vocabulary (to generate logits for prediction)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">144</span> <span class="n">decoder_layer</span><span class="p">:</span> <span class="n">TransformerLayer</span> <span class="o">=</span> <span class="s1">&#39;default&#39;</span></pre></div>
<div class="highlight"><pre><span class="lineno">166</span> <span class="n">n_tgt_vocab</span><span class="p">:</span> <span class="nb">int</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
@ -468,10 +488,10 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Encoder consisting of multiple encoder layers</p>
<p>The encoder self attention</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">147</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">Encoder</span> <span class="o">=</span> <span class="s1">&#39;default&#39;</span></pre></div>
<div class="highlight"><pre><span class="lineno">169</span> <span class="n">encoder_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span> <span class="o">=</span> <span class="s1">&#39;mha&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
@ -479,10 +499,10 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Decoder consisting of multiple decoder layers</p>
<p>The decoder self attention</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">149</span> <span class="n">decoder</span><span class="p">:</span> <span class="n">Decoder</span> <span class="o">=</span> <span class="s1">&#39;default&#39;</span></pre></div>
<div class="highlight"><pre><span class="lineno">171</span> <span class="n">decoder_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span> <span class="o">=</span> <span class="s1">&#39;mha&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
@ -490,10 +510,10 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Embedding layer for source</p>
<p>The decoder memory attention</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">152</span> <span class="n">src_embed</span><span class="p">:</span> <span class="n">Module</span> <span class="o">=</span> <span class="s1">&#39;fixed_pos&#39;</span></pre></div>
<div class="highlight"><pre><span class="lineno">173</span> <span class="n">decoder_mem_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span> <span class="o">=</span> <span class="s1">&#39;mha&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
@ -501,10 +521,10 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Embedding layer for target (for decoder)</p>
<p>Configurable Feedforward Layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">154</span> <span class="n">tgt_embed</span><span class="p">:</span> <span class="n">Module</span> <span class="o">=</span> <span class="s1">&#39;fixed_pos&#39;</span></pre></div>
<div class="highlight"><pre><span class="lineno">176</span> <span class="n">ffn</span><span class="p">:</span> <span class="n">FeedForwardConfigs</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
@ -512,10 +532,10 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Logit generator for prediction</p>
<p>Encoder layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">157</span> <span class="n">generator</span><span class="p">:</span> <span class="n">Generator</span> <span class="o">=</span> <span class="s1">&#39;default&#39;</span></pre></div>
<div class="highlight"><pre><span class="lineno">179</span> <span class="n">encoder_layer</span><span class="p">:</span> <span class="n">TransformerLayer</span> <span class="o">=</span> <span class="s1">&#39;default&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
@ -523,10 +543,10 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Encoder-decoder</p>
<p>Decoder layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">160</span> <span class="n">encoder_decoder</span><span class="p">:</span> <span class="n">EncoderDecoder</span></pre></div>
<div class="highlight"><pre><span class="lineno">181</span> <span class="n">decoder_layer</span><span class="p">:</span> <span class="n">TransformerLayer</span> <span class="o">=</span> <span class="s1">&#39;default&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
@ -534,16 +554,10 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<h3>Multi-head Attention</h3>
<p>Encoder consisting of multiple encoder layers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">164</span><span class="k">def</span> <span class="nf">_mha</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
<span class="lineno">165</span> <span class="k">return</span> <span class="n">MultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span>
<span class="lineno">166</span>
<span class="lineno">167</span>
<span class="lineno">168</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_attn</span><span class="p">,</span> <span class="s1">&#39;mha&#39;</span><span class="p">,</span> <span class="n">_mha</span><span class="p">)</span>
<span class="lineno">169</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_attn</span><span class="p">,</span> <span class="s1">&#39;mha&#39;</span><span class="p">,</span> <span class="n">_mha</span><span class="p">)</span>
<span class="lineno">170</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_mem_attn</span><span class="p">,</span> <span class="s1">&#39;mha&#39;</span><span class="p">,</span> <span class="n">_mha</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">184</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">Encoder</span> <span class="o">=</span> <span class="s1">&#39;default&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
@ -551,29 +565,21 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<h3>Relative Multi-head Attention</h3>
<p>Decoder consisting of multiple decoder layers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">174</span><span class="k">def</span> <span class="nf">_relative_mha</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
<span class="lineno">175</span> <span class="kn">from</span> <span class="nn">.relative_mha</span> <span class="kn">import</span> <span class="n">RelativeMultiHeadAttention</span>
<span class="lineno">176</span> <span class="k">return</span> <span class="n">RelativeMultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span>
<span class="lineno">177</span>
<span class="lineno">178</span>
<span class="lineno">179</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_attn</span><span class="p">,</span> <span class="s1">&#39;relative&#39;</span><span class="p">,</span> <span class="n">_relative_mha</span><span class="p">)</span>
<span class="lineno">180</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_attn</span><span class="p">,</span> <span class="s1">&#39;relative&#39;</span><span class="p">,</span> <span class="n">_relative_mha</span><span class="p">)</span>
<span class="lineno">181</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_mem_attn</span><span class="p">,</span> <span class="s1">&#39;relative&#39;</span><span class="p">,</span> <span class="n">_relative_mha</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">186</span> <span class="n">decoder</span><span class="p">:</span> <span class="n">Decoder</span> <span class="o">=</span> <span class="s1">&#39;default&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs doc-strings'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>Create feedforward layer configurations</p>
<p>Embedding layer for source</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">184</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">ffn</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">185</span><span class="k">def</span> <span class="nf">_feed_forward</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">189</span> <span class="n">src_embed</span><span class="p">:</span> <span class="n">Module</span> <span class="o">=</span> <span class="s1">&#39;fixed_pos&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
@ -581,25 +587,21 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Embedding layer for target (for decoder)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">189</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">FeedForwardConfigs</span><span class="p">()</span>
<span class="lineno">190</span> <span class="n">conf</span><span class="o">.</span><span class="n">set_default</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">func</span><span class="o">=</span><span class="k">lambda</span><span class="p">:</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span>
<span class="lineno">191</span> <span class="n">conf</span><span class="o">.</span><span class="n">set_default</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">dropout</span><span class="p">,</span> <span class="n">func</span><span class="o">=</span><span class="k">lambda</span><span class="p">:</span> <span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">)</span>
<span class="lineno">192</span> <span class="k">return</span> <span class="n">conf</span></pre></div>
<div class="highlight"><pre><span class="lineno">191</span> <span class="n">tgt_embed</span><span class="p">:</span> <span class="n">Module</span> <span class="o">=</span> <span class="s1">&#39;fixed_pos&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs doc-strings'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Encoder layer</p>
<p>Logit generator for prediction</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">195</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_layer</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">196</span><span class="k">def</span> <span class="nf">_encoder_layer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">194</span> <span class="n">generator</span><span class="p">:</span> <span class="n">Generator</span> <span class="o">=</span> <span class="s1">&#39;default&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
@ -607,24 +609,27 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Encoder-decoder</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">200</span> <span class="k">return</span> <span class="n">TransformerLayer</span><span class="p">(</span><span class="n">d_model</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">self_attn</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">encoder_attn</span><span class="p">,</span>
<span class="lineno">201</span> <span class="n">src_attn</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">feed_forward</span><span class="o">=</span><span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">ffn</span><span class="p">),</span>
<span class="lineno">202</span> <span class="n">dropout_prob</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">encoder_decoder</span><span class="p">:</span> <span class="n">EncoderDecoder</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs doc-strings'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>Decoder layer</p>
<h3>Multi-head Attention</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_layer</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">206</span><span class="k">def</span> <span class="nf">_decoder_layer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">201</span><span class="k">def</span> <span class="nf">_mha</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
<span class="lineno">202</span> <span class="k">return</span> <span class="n">MultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span>
<span class="lineno">203</span>
<span class="lineno">204</span>
<span class="lineno">205</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_attn</span><span class="p">,</span> <span class="s1">&#39;mha&#39;</span><span class="p">,</span> <span class="n">_mha</span><span class="p">)</span>
<span class="lineno">206</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_attn</span><span class="p">,</span> <span class="s1">&#39;mha&#39;</span><span class="p">,</span> <span class="n">_mha</span><span class="p">)</span>
<span class="lineno">207</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_mem_attn</span><span class="p">,</span> <span class="s1">&#39;mha&#39;</span><span class="p">,</span> <span class="n">_mha</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
@ -632,12 +637,17 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<h3>Relative Multi-head Attention</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">210</span> <span class="k">return</span> <span class="n">TransformerLayer</span><span class="p">(</span><span class="n">d_model</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">self_attn</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">decoder_attn</span><span class="p">,</span>
<span class="lineno">211</span> <span class="n">src_attn</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">decoder_mem_attn</span><span class="p">,</span> <span class="n">feed_forward</span><span class="o">=</span><span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">ffn</span><span class="p">),</span>
<span class="lineno">212</span> <span class="n">dropout_prob</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">211</span><span class="k">def</span> <span class="nf">_relative_mha</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
<span class="lineno">212</span> <span class="kn">from</span> <span class="nn">.relative_mha</span> <span class="kn">import</span> <span class="n">RelativeMultiHeadAttention</span>
<span class="lineno">213</span> <span class="k">return</span> <span class="n">RelativeMultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span>
<span class="lineno">214</span>
<span class="lineno">215</span>
<span class="lineno">216</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_attn</span><span class="p">,</span> <span class="s1">&#39;relative&#39;</span><span class="p">,</span> <span class="n">_relative_mha</span><span class="p">)</span>
<span class="lineno">217</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_attn</span><span class="p">,</span> <span class="s1">&#39;relative&#39;</span><span class="p">,</span> <span class="n">_relative_mha</span><span class="p">)</span>
<span class="lineno">218</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_mem_attn</span><span class="p">,</span> <span class="s1">&#39;relative&#39;</span><span class="p">,</span> <span class="n">_relative_mha</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
@ -645,11 +655,11 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Encoder</p>
<p>Create feedforward layer configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">215</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">216</span><span class="k">def</span> <span class="nf">_encoder</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">221</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">ffn</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">222</span><span class="k">def</span> <span class="nf">_feed_forward</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
@ -660,7 +670,10 @@ are calculated.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">220</span> <span class="k">return</span> <span class="n">Encoder</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">encoder_layer</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_layers</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">226</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">FeedForwardConfigs</span><span class="p">()</span>
<span class="lineno">227</span> <span class="n">conf</span><span class="o">.</span><span class="n">set_default</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">func</span><span class="o">=</span><span class="k">lambda</span><span class="p">:</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span>
<span class="lineno">228</span> <span class="n">conf</span><span class="o">.</span><span class="n">set_default</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">dropout</span><span class="p">,</span> <span class="n">func</span><span class="o">=</span><span class="k">lambda</span><span class="p">:</span> <span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">)</span>
<span class="lineno">229</span> <span class="k">return</span> <span class="n">conf</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
@ -668,11 +681,11 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<p>Decoder</p>
<p>Encoder layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">223</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">224</span><span class="k">def</span> <span class="nf">_decoder</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">232</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_layer</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">233</span><span class="k">def</span> <span class="nf">_encoder_layer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
@ -683,7 +696,9 @@ are calculated.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">228</span> <span class="k">return</span> <span class="n">Decoder</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">decoder_layer</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_layers</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">237</span> <span class="k">return</span> <span class="n">TransformerLayer</span><span class="p">(</span><span class="n">d_model</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">self_attn</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">encoder_attn</span><span class="p">,</span>
<span class="lineno">238</span> <span class="n">src_attn</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">feed_forward</span><span class="o">=</span><span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">ffn</span><span class="p">),</span>
<span class="lineno">239</span> <span class="n">dropout_prob</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
@ -691,11 +706,11 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>Logit generator</p>
<p>Decoder layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">231</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">generator</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">232</span><span class="k">def</span> <span class="nf">_generator</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">242</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_layer</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">243</span><span class="k">def</span> <span class="nf">_decoder_layer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
@ -706,7 +721,9 @@ are calculated.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">236</span> <span class="k">return</span> <span class="n">Generator</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_tgt_vocab</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">247</span> <span class="k">return</span> <span class="n">TransformerLayer</span><span class="p">(</span><span class="n">d_model</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">self_attn</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">decoder_attn</span><span class="p">,</span>
<span class="lineno">248</span> <span class="n">src_attn</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">decoder_mem_attn</span><span class="p">,</span> <span class="n">feed_forward</span><span class="o">=</span><span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">ffn</span><span class="p">),</span>
<span class="lineno">249</span> <span class="n">dropout_prob</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
@ -714,12 +731,11 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<h2>Positional Embeddings</h2>
<p>Source embedding with fixed positional encodings</p>
<p>Encoder</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">240</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="s1">&#39;fixed_pos&#39;</span><span class="p">)</span>
<span class="lineno">241</span><span class="k">def</span> <span class="nf">_src_embed_with_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">252</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">253</span><span class="k">def</span> <span class="nf">_encoder</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
@ -730,7 +746,7 @@ are calculated.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">245</span> <span class="k">return</span> <span class="n">EmbeddingsWithPositionalEncoding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_src_vocab</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">257</span> <span class="k">return</span> <span class="n">Encoder</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">encoder_layer</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
@ -738,11 +754,11 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>Target embedding with fixed positional encodings</p>
<p>Decoder</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">248</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">,</span> <span class="s1">&#39;fixed_pos&#39;</span><span class="p">)</span>
<span class="lineno">249</span><span class="k">def</span> <span class="nf">_tgt_embed_with_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">260</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">261</span><span class="k">def</span> <span class="nf">_decoder</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
@ -753,7 +769,7 @@ are calculated.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">253</span> <span class="k">return</span> <span class="n">EmbeddingsWithPositionalEncoding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tgt_vocab</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">265</span> <span class="k">return</span> <span class="n">Decoder</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">decoder_layer</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
@ -761,12 +777,11 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<h2>Learned Positional Embeddings</h2>
<p>Source embedding with learned positional encodings</p>
<p>Logit generator</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">257</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="s1">&#39;learned_pos&#39;</span><span class="p">)</span>
<span class="lineno">258</span><span class="k">def</span> <span class="nf">_src_embed_with_learned_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">268</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">generator</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">269</span><span class="k">def</span> <span class="nf">_generator</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
@ -777,7 +792,7 @@ are calculated.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">262</span> <span class="k">return</span> <span class="n">EmbeddingsWithLearnedPositionalEncoding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_src_vocab</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">273</span> <span class="k">return</span> <span class="n">Generator</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_tgt_vocab</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
@ -785,11 +800,12 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<p>Target embedding with learned positional encodings</p>
<h3>Fixed Positional Embeddings</h3>
<p>Source embedding with fixed positional encodings</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">265</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">,</span> <span class="s1">&#39;learned_pos&#39;</span><span class="p">)</span>
<span class="lineno">266</span><span class="k">def</span> <span class="nf">_tgt_embed_with_learned_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">277</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="s1">&#39;fixed_pos&#39;</span><span class="p">)</span>
<span class="lineno">278</span><span class="k">def</span> <span class="nf">_src_embed_with_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
@ -800,7 +816,7 @@ are calculated.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">270</span> <span class="k">return</span> <span class="n">EmbeddingsWithLearnedPositionalEncoding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tgt_vocab</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">282</span> <span class="k">return</span> <span class="n">EmbeddingsWithPositionalEncoding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_src_vocab</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
@ -808,12 +824,11 @@ are calculated.</p>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<h2>No Positional Embeddings</h2>
<p>Source embedding without positional encodings</p>
<p>Target embedding with fixed positional encodings</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">274</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="s1">&#39;no_pos&#39;</span><span class="p">)</span>
<span class="lineno">275</span><span class="k">def</span> <span class="nf">_src_embed_without_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">285</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">,</span> <span class="s1">&#39;fixed_pos&#39;</span><span class="p">)</span>
<span class="lineno">286</span><span class="k">def</span> <span class="nf">_tgt_embed_with_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
@ -824,25 +839,96 @@ are calculated.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">279</span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_src_vocab</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">290</span> <span class="k">return</span> <span class="n">EmbeddingsWithPositionalEncoding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tgt_vocab</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
<h3>Learned Positional Embeddings</h3>
<p>Source embedding with learned positional encodings</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">294</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="s1">&#39;learned_pos&#39;</span><span class="p">)</span>
<span class="lineno">295</span><span class="k">def</span> <span class="nf">_src_embed_with_learned_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">282</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">,</span> <span class="s1">&#39;no_pos&#39;</span><span class="p">)</span>
<span class="lineno">283</span><span class="k">def</span> <span class="nf">_tgt_embed_without_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
<span class="lineno">284</span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_tgt_vocab</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span>
<span class="lineno">285</span>
<span class="lineno">286</span>
<span class="lineno">287</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_decoder</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">288</span><span class="k">def</span> <span class="nf">_encoder_decoder</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
<span class="lineno">289</span> <span class="k">return</span> <span class="n">EncoderDecoder</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">decoder</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">generator</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">299</span> <span class="k">return</span> <span class="n">EmbeddingsWithLearnedPositionalEncoding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_src_vocab</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
<p>Target embedding with learned positional encodings</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">302</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">,</span> <span class="s1">&#39;learned_pos&#39;</span><span class="p">)</span>
<span class="lineno">303</span><span class="k">def</span> <span class="nf">_tgt_embed_with_learned_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-64'>
<div class='docs'>
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">307</span> <span class="k">return</span> <span class="n">EmbeddingsWithLearnedPositionalEncoding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tgt_vocab</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-65'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-65'>#</a>
</div>
<h3>No Positional Embeddings</h3>
<p>Source embedding without positional encodings</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">311</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="s1">&#39;no_pos&#39;</span><span class="p">)</span>
<span class="lineno">312</span><span class="k">def</span> <span class="nf">_src_embed_without_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-66'>
<div class='docs'>
<div class='section-link'>
<a href='#section-66'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">316</span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_src_vocab</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-67'>
<div class='docs'>
<div class='section-link'>
<a href='#section-67'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">319</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">,</span> <span class="s1">&#39;no_pos&#39;</span><span class="p">)</span>
<span class="lineno">320</span><span class="k">def</span> <span class="nf">_tgt_embed_without_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
<span class="lineno">321</span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_tgt_vocab</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span>
<span class="lineno">322</span>
<span class="lineno">323</span>
<span class="lineno">324</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_decoder</span><span class="p">,</span> <span class="s1">&#39;default&#39;</span><span class="p">)</span>
<span class="lineno">325</span><span class="k">def</span> <span class="nf">_encoder_decoder</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
<span class="lineno">326</span> <span class="k">return</span> <span class="n">EncoderDecoder</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">decoder</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">generator</span><span class="p">)</span></pre></div>
</div>
</div>
</div>

View File

@ -84,12 +84,20 @@ where $W_1$, $W_2$, $b_1$ and $b_2$ are learnable parameters.</p>
<p>Sometimes the
GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU.
<script type="math/tex; mode=display">x \Phi(x)</script> where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$</p>
<h3>Gated Linear Units</h3>
<p>This is a generic implementation that supports different variants including
<a href="https://arxiv.org/abs/2002.05202">Gated Linear Units</a> (GLU).
We have also implemented experiments on these:</p>
<ul>
<li><a href="glu_variants/experiment.html">experiment that uses <code>labml.configs</code></a></li>
<li><a href="glu_variants/simple.html">simpler version from scratch</a></li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">26</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">27</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span> <span class="k">as</span> <span class="n">nn</span>
<span class="lineno">28</span>
<span class="lineno">29</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
<div class="highlight"><pre><span class="lineno">35</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">36</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span> <span class="k">as</span> <span class="n">nn</span>
<span class="lineno">37</span>
<span class="lineno">38</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
@ -97,10 +105,10 @@ GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU.
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Position-wise feed-forward network (FFN) module</h2>
<h2>FFN module</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">32</span><span class="k">class</span> <span class="nc">FeedForward</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">41</span><span class="k">class</span> <span class="nc">FeedForward</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
@ -119,13 +127,13 @@ GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU.
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">37</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">38</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
<span class="lineno">39</span> <span class="n">activation</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
<span class="lineno">40</span> <span class="n">is_gated</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="lineno">41</span> <span class="n">bias1</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="lineno">42</span> <span class="n">bias2</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="lineno">43</span> <span class="n">bias_gate</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">46</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">47</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
<span class="lineno">48</span> <span class="n">activation</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
<span class="lineno">49</span> <span class="n">is_gated</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="lineno">50</span> <span class="n">bias1</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="lineno">51</span> <span class="n">bias2</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="lineno">52</span> <span class="n">bias_gate</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
@ -136,14 +144,7 @@ GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU.
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">53</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">54</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias1</span><span class="p">)</span>
<span class="lineno">55</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_ff</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias2</span><span class="p">)</span>
<span class="lineno">56</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span>
<span class="lineno">57</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">activation</span>
<span class="lineno">58</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_gated</span> <span class="o">=</span> <span class="n">is_gated</span>
<span class="lineno">59</span> <span class="k">if</span> <span class="n">is_gated</span><span class="p">:</span>
<span class="lineno">60</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_v</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias_gate</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">62</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
@ -151,17 +152,136 @@ GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU.
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Layer one parameterized by weight $W_1$ and bias $b_1$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">64</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Layer two parameterized by weight $W_2$ and bias $b_2$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">66</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_ff</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Hidden layer dropout</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Activation function $f$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">activation</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Whether there is a gate</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">72</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_gated</span> <span class="o">=</span> <span class="n">is_gated</span>
<span class="lineno">73</span> <span class="k">if</span> <span class="n">is_gated</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>If there is a gate the linear layer to transform inputs to
be multiplied by the gate, parameterized by weight $V$ and bias $c$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_v</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias_gate</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">63</span> <span class="n">g</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layer1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
<span class="lineno">64</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_gated</span><span class="p">:</span>
<span class="lineno">65</span> <span class="n">x</span> <span class="o">=</span> <span class="n">g</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_v</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">66</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">67</span> <span class="n">x</span> <span class="o">=</span> <span class="n">g</span>
<span class="lineno">68</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">69</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">78</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>$f(x W_1 + b_1)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">80</span> <span class="n">g</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layer1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>If gated, $f(x W_1 + b_1) \otimes (x V + c)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_gated</span><span class="p">:</span>
<span class="lineno">83</span> <span class="n">x</span> <span class="o">=</span> <span class="n">g</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_v</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Otherwise</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">85</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">86</span> <span class="n">x</span> <span class="o">=</span> <span class="n">g</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Apply dropout</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">88</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>$(f(x W_1 + b_1) \otimes (x V + c)) W_2 + b_2$ or $f(x W_1 + b_1) W_2 + b_2$
depending on whether it is gated</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
</div>

View File

@ -71,19 +71,21 @@
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Train Autoregressive Transformer</h1>
<p>This trains a simple <a href="../../">transformer</a> model for auto-regression.</p>
<h1>Gated Linear Units and Variants</h1>
<p>This trains a simple <a href="../../">transformer</a> model for auto-regression.
We try different variants for the <a href="../feed_forward">position-wise feedforward network</a>.
The reusable &amp; configurable modules are defined in <a href="configs.html"><code>configs.py</code></a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">14</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml.utils.pytorch</span> <span class="kn">import</span> <span class="n">get_modules</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">19</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.nlp_autoregression</span> <span class="kn">import</span> <span class="n">NLPAutoRegressionConfigs</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">Generator</span><span class="p">,</span> <span class="n">TransformerConfigs</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.utils</span> <span class="kn">import</span> <span class="n">subsequent_mask</span></pre></div>
<div class="highlight"><pre><span class="lineno">16</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml.utils.pytorch</span> <span class="kn">import</span> <span class="n">get_modules</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">21</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.nlp_autoregression</span> <span class="kn">import</span> <span class="n">NLPAutoRegressionConfigs</span>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">Generator</span><span class="p">,</span> <span class="n">TransformerConfigs</span>
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.utils</span> <span class="kn">import</span> <span class="n">subsequent_mask</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
@ -94,7 +96,7 @@
<h2>Auto regressive model</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">25</span><span class="k">class</span> <span class="nc">AutoregressiveModel</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">27</span><span class="k">class</span> <span class="nc">AutoregressiveModel</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
@ -105,8 +107,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">30</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src_embed</span><span class="p">:</span> <span class="n">Module</span><span class="p">,</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">generator</span><span class="p">:</span> <span class="n">Generator</span><span class="p">):</span>
<span class="lineno">31</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">32</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src_embed</span><span class="p">:</span> <span class="n">Module</span><span class="p">,</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">generator</span><span class="p">:</span> <span class="n">Generator</span><span class="p">):</span>
<span class="lineno">33</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
@ -117,7 +119,7 @@
<p>Token embedding module</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">33</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span> <span class="o">=</span> <span class="n">src_embed</span></pre></div>
<div class="highlight"><pre><span class="lineno">35</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span> <span class="o">=</span> <span class="n">src_embed</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
@ -128,7 +130,7 @@
<p>Transformer based encoder</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">35</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span></pre></div>
<div class="highlight"><pre><span class="lineno">37</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
@ -140,7 +142,7 @@
this give logits of the the next token</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">38</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span> <span class="o">=</span> <span class="n">generator</span></pre></div>
<div class="highlight"><pre><span class="lineno">40</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span> <span class="o">=</span> <span class="n">generator</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
@ -151,7 +153,7 @@ this give logits of the the next token</p>
<p>This will be initialized on the first call</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">40</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_mask</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
<div class="highlight"><pre><span class="lineno">42</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_mask</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
@ -162,7 +164,7 @@ this give logits of the the next token</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">42</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">44</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
@ -173,8 +175,8 @@ this give logits of the the next token</p>
<p>Create subsequent mask, so that the transformer can only pay attention to past tokens.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">44</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_mask</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_mask</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">src</span><span class="p">):</span>
<span class="lineno">45</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_mask</span> <span class="o">=</span> <span class="n">subsequent_mask</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">src</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">src</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">46</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_mask</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_mask</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">src</span><span class="p">):</span>
<span class="lineno">47</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_mask</span> <span class="o">=</span> <span class="n">subsequent_mask</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">src</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">src</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
@ -185,7 +187,7 @@ this give logits of the the next token</p>
<p>Embed the tokens (<code>src</code>) and run it through the transformer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span> <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span><span class="p">(</span><span class="n">src</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_mask</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">49</span> <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span><span class="p">(</span><span class="n">src</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
@ -196,7 +198,7 @@ this give logits of the the next token</p>
<p>Generate logits of the next token</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">(</span><span class="n">res</span><span class="p">),</span> <span class="kc">None</span></pre></div>
<div class="highlight"><pre><span class="lineno">51</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">(</span><span class="n">res</span><span class="p">),</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
@ -208,7 +210,7 @@ this give logits of the the next token</p>
<p>The default configs can and will be over-ridden when we start the experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">52</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">54</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
@ -219,8 +221,8 @@ this give logits of the the next token</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">59</span> <span class="n">transformer</span><span class="p">:</span> <span class="n">TransformerConfigs</span>
<span class="lineno">60</span> <span class="n">model</span><span class="p">:</span> <span class="n">AutoregressiveModel</span></pre></div>
<div class="highlight"><pre><span class="lineno">61</span> <span class="n">transformer</span><span class="p">:</span> <span class="n">TransformerConfigs</span>
<span class="lineno">62</span> <span class="n">model</span><span class="p">:</span> <span class="n">AutoregressiveModel</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
@ -231,8 +233,8 @@ this give logits of the the next token</p>
<p>Initialize the auto-regressive model</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">63</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
<span class="lineno">64</span><span class="k">def</span> <span class="nf">autoregressive_model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">65</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
<span class="lineno">66</span><span class="k">def</span> <span class="nf">autoregressive_model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
@ -243,8 +245,8 @@ this give logits of the the next token</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="n">m</span> <span class="o">=</span> <span class="n">AutoregressiveModel</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">generator</span><span class="p">)</span>
<span class="lineno">69</span> <span class="k">return</span> <span class="n">m</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">m</span> <span class="o">=</span> <span class="n">AutoregressiveModel</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">generator</span><span class="p">)</span>
<span class="lineno">71</span> <span class="k">return</span> <span class="n">m</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
@ -252,11 +254,11 @@ this give logits of the the next token</p>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Initialize the configurable transformer encoder for our autoregressive model</p>
<p>Initialize the <a href="../configs.html">configurable transformer</a> encoder for our autoregressive model.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">72</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">transformer</span><span class="p">)</span>
<span class="lineno">73</span><span class="k">def</span> <span class="nf">transformer_c</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">74</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">transformer</span><span class="p">)</span>
<span class="lineno">75</span><span class="k">def</span> <span class="nf">transformer_c</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
@ -267,11 +269,11 @@ this give logits of the the next token</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">77</span> <span class="n">tc</span> <span class="o">=</span> <span class="n">TransformerConfigs</span><span class="p">()</span>
<span class="lineno">78</span> <span class="n">tc</span><span class="o">.</span><span class="n">n_src_vocab</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span>
<span class="lineno">79</span> <span class="n">tc</span><span class="o">.</span><span class="n">n_tgt_vocab</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span>
<span class="lineno">80</span>
<span class="lineno">81</span> <span class="k">return</span> <span class="n">tc</span></pre></div>
<div class="highlight"><pre><span class="lineno">79</span> <span class="n">tc</span> <span class="o">=</span> <span class="n">TransformerConfigs</span><span class="p">()</span>
<span class="lineno">80</span> <span class="n">tc</span><span class="o">.</span><span class="n">n_src_vocab</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span>
<span class="lineno">81</span> <span class="n">tc</span><span class="o">.</span><span class="n">n_tgt_vocab</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span>
<span class="lineno">82</span>
<span class="lineno">83</span> <span class="k">return</span> <span class="n">tc</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
@ -282,7 +284,7 @@ this give logits of the the next token</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">84</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
<div class="highlight"><pre><span class="lineno">86</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
@ -293,7 +295,7 @@ this give logits of the the next token</p>
<p>Create experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">86</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;glu_variants&quot;</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">88</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;glu_variants&quot;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
@ -304,7 +306,7 @@ this give logits of the the next token</p>
<p>Create configs</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">88</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">90</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
@ -315,7 +317,7 @@ this give logits of the the next token</p>
<p>Load configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">90</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span></pre></div>
<div class="highlight"><pre><span class="lineno">92</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
@ -326,19 +328,19 @@ this give logits of the the next token</p>
<p>A dictionary of configurations to override</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">92</span> <span class="p">{</span><span class="s1">&#39;tokenizer&#39;</span><span class="p">:</span> <span class="s1">&#39;character&#39;</span><span class="p">,</span>
<span class="lineno">93</span> <span class="s1">&#39;prompt_separator&#39;</span><span class="p">:</span> <span class="s1">&#39;&#39;</span><span class="p">,</span>
<span class="lineno">94</span> <span class="s1">&#39;prompt&#39;</span><span class="p">:</span> <span class="s1">&#39;It is &#39;</span><span class="p">,</span>
<span class="lineno">95</span> <span class="s1">&#39;text&#39;</span><span class="p">:</span> <span class="s1">&#39;tiny_shakespeare&#39;</span><span class="p">,</span>
<span class="lineno">96</span>
<span class="lineno">97</span> <span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Noam&#39;</span><span class="p">,</span>
<span class="lineno">98</span> <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">1.</span><span class="p">,</span>
<span class="lineno">99</span> <span class="s1">&#39;optimizer.d_model&#39;</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span>
<span class="lineno">100</span>
<span class="lineno">101</span> <span class="s1">&#39;seq_len&#39;</span><span class="p">:</span> <span class="mi">1024</span><span class="p">,</span>
<span class="lineno">102</span> <span class="s1">&#39;epochs&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span>
<span class="lineno">103</span> <span class="s1">&#39;batch_size&#39;</span><span class="p">:</span> <span class="mi">6</span><span class="p">,</span>
<span class="lineno">104</span> <span class="s1">&#39;inner_iterations&#39;</span><span class="p">:</span> <span class="mi">10</span><span class="p">,</span></pre></div>
<div class="highlight"><pre><span class="lineno">94</span> <span class="p">{</span><span class="s1">&#39;tokenizer&#39;</span><span class="p">:</span> <span class="s1">&#39;character&#39;</span><span class="p">,</span>
<span class="lineno">95</span> <span class="s1">&#39;prompt_separator&#39;</span><span class="p">:</span> <span class="s1">&#39;&#39;</span><span class="p">,</span>
<span class="lineno">96</span> <span class="s1">&#39;prompt&#39;</span><span class="p">:</span> <span class="s1">&#39;It is &#39;</span><span class="p">,</span>
<span class="lineno">97</span> <span class="s1">&#39;text&#39;</span><span class="p">:</span> <span class="s1">&#39;tiny_shakespeare&#39;</span><span class="p">,</span>
<span class="lineno">98</span>
<span class="lineno">99</span> <span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Noam&#39;</span><span class="p">,</span>
<span class="lineno">100</span> <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">1.</span><span class="p">,</span>
<span class="lineno">101</span> <span class="s1">&#39;optimizer.d_model&#39;</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span>
<span class="lineno">102</span>
<span class="lineno">103</span> <span class="s1">&#39;seq_len&#39;</span><span class="p">:</span> <span class="mi">1024</span><span class="p">,</span>
<span class="lineno">104</span> <span class="s1">&#39;epochs&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span>
<span class="lineno">105</span> <span class="s1">&#39;batch_size&#39;</span><span class="p">:</span> <span class="mi">6</span><span class="p">,</span>
<span class="lineno">106</span> <span class="s1">&#39;inner_iterations&#39;</span><span class="p">:</span> <span class="mi">10</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
@ -347,9 +349,11 @@ this give logits of the the next token</p>
<a href='#section-22'>#</a>
</div>
<p>GLU Variant, one of GLU, Bilinear, ReGLU, GEGLU, SwiGLU</p>
<p>These are defined in the <a href="../configs.html#FFN">configurable FFN</a>
implementation</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">107</span> <span class="s1">&#39;transformer.ffn.glu_variant&#39;</span><span class="p">:</span> <span class="s1">&#39;Bilinear&#39;</span><span class="p">,</span></pre></div>
<div class="highlight"><pre><span class="lineno">112</span> <span class="s1">&#39;transformer.ffn.glu_variant&#39;</span><span class="p">:</span> <span class="s1">&#39;Bilinear&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
@ -360,10 +364,10 @@ this give logits of the the next token</p>
<p>Transformer configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span> <span class="s1">&#39;transformer.d_model&#39;</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span>
<span class="lineno">111</span> <span class="s1">&#39;transformer.ffn.d_ff&#39;</span><span class="p">:</span> <span class="mi">1024</span><span class="p">,</span>
<span class="lineno">112</span> <span class="s1">&#39;transformer.n_heads&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span>
<span class="lineno">113</span> <span class="s1">&#39;transformer.n_layers&#39;</span><span class="p">:</span> <span class="mi">6</span><span class="p">})</span></pre></div>
<div class="highlight"><pre><span class="lineno">115</span> <span class="s1">&#39;transformer.d_model&#39;</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span>
<span class="lineno">116</span> <span class="s1">&#39;transformer.ffn.d_ff&#39;</span><span class="p">:</span> <span class="mi">1024</span><span class="p">,</span>
<span class="lineno">117</span> <span class="s1">&#39;transformer.n_heads&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span>
<span class="lineno">118</span> <span class="s1">&#39;transformer.n_layers&#39;</span><span class="p">:</span> <span class="mi">6</span><span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
@ -374,7 +378,7 @@ this give logits of the the next token</p>
<p>This is needed to initialize models</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="n">conf</span><span class="o">.</span><span class="n">n_tokens</span> <span class="o">=</span> <span class="n">conf</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">n_tokens</span></pre></div>
<div class="highlight"><pre><span class="lineno">121</span> <span class="n">conf</span><span class="o">.</span><span class="n">n_tokens</span> <span class="o">=</span> <span class="n">conf</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">n_tokens</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
@ -385,7 +389,7 @@ this give logits of the the next token</p>
<p>Set models for saving and loading</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">119</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">(</span><span class="n">get_modules</span><span class="p">(</span><span class="n">conf</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">124</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">(</span><span class="n">get_modules</span><span class="p">(</span><span class="n">conf</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
@ -396,7 +400,7 @@ this give logits of the the next token</p>
<p>Start the experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
<div class="highlight"><pre><span class="lineno">127</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
@ -407,11 +411,11 @@ this give logits of the the next token</p>
<p><code>TrainValidConfigs.run</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span>
<span class="lineno">125</span>
<span class="lineno">126</span>
<span class="lineno">127</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">128</span> <span class="n">main</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">129</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span>
<span class="lineno">130</span>
<span class="lineno">131</span>
<span class="lineno">132</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">133</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
</div>

File diff suppressed because it is too large Load Diff

View File

@ -19,6 +19,14 @@ from .models import EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPosit
class FeedForwardConfigs(BaseConfigs):
"""
<a id="FFN">
## FFN Configurations
</a>
Creates a Position-wise FeedForward Network defined in
[`feed_forward.py`](feed_forward.html).
"""
# Position-wise feedforward layer
ffn: FeedForward
# Number of features in the embedding
@ -44,7 +52,9 @@ class FeedForwardConfigs(BaseConfigs):
@option(FeedForwardConfigs.activation, 'ReLU')
def _ffn_activation_relu():
"""
ReLU activation
### ReLU activation
$$\max(0, x)$$
"""
return nn.ReLU()
@ -52,7 +62,11 @@ def _ffn_activation_relu():
@option(FeedForwardConfigs.activation, 'GELU')
def _ffn_activation_gelu():
"""
GELU activation
### GELU activation
$$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$
It was introduced in paper [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415).
"""
return nn.GELU()
@ -60,7 +74,7 @@ def _ffn_activation_gelu():
@option(FeedForwardConfigs.ffn, 'default')
def _feed_forward(c: FeedForwardConfigs):
"""
Create feedforward layer
Initialize a [feed forward network](feed_forward.html)
"""
return FeedForward(c.d_model, c.d_ff,
dropout=c.dropout,
@ -70,7 +84,14 @@ def _feed_forward(c: FeedForwardConfigs):
bias2=c.bias2,
bias_gate=c.bias_gate)
# ## GLU Variants
# These are variants with gated hidden layers for the FFN
# as introduced in paper [GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202).
# We have omitted the bias terms as specified in the paper.
# ### FFN with Gated Linear Units
#
# $$FFN_{GLU}(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'GLU',
(FeedForwardConfigs.is_gated, True),
(FeedForwardConfigs.bias1, False),
@ -78,24 +99,40 @@ aggregate(FeedForwardConfigs.glu_variant, 'GLU',
(FeedForwardConfigs.bias_gate, False),
(FeedForwardConfigs.activation, nn.Sigmoid()))
# ### FFN with Bilinear hidden layer
#
# $$FFN_{Bilinear}(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'Bilinear',
(FeedForwardConfigs.is_gated, True),
(FeedForwardConfigs.bias1, False),
(FeedForwardConfigs.bias2, False),
(FeedForwardConfigs.bias_gate, False),
(FeedForwardConfigs.activation, nn.Identity()))
# ### FFN with ReLU gate
#
# $$FFN_{ReGLU}(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'ReGLU',
(FeedForwardConfigs.is_gated, True),
(FeedForwardConfigs.bias1, False),
(FeedForwardConfigs.bias2, False),
(FeedForwardConfigs.bias_gate, False),
(FeedForwardConfigs.activation, nn.ReLU()))
# ### FFN with GELU gate
#
# $$FFN_{GEGLU}(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'GEGLU',
(FeedForwardConfigs.is_gated, True),
(FeedForwardConfigs.bias1, False),
(FeedForwardConfigs.bias2, False),
(FeedForwardConfigs.bias_gate, False),
(FeedForwardConfigs.activation, nn.GELU()))
# ### FFN with Swish gate
#
# $$FFN_{SwiGLU}(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$$
# where $\text{Swish}_\beta(x) = x \sigma(\beta x)$
aggregate(FeedForwardConfigs.glu_variant, 'SwiGLU',
(FeedForwardConfigs.is_gated, True),
(FeedForwardConfigs.bias1, False),
@ -236,7 +273,7 @@ def _generator(c: TransformerConfigs):
return Generator(c.n_tgt_vocab, c.d_model)
# ## Positional Embeddings
# ### Fixed Positional Embeddings
@option(TransformerConfigs.src_embed, 'fixed_pos')
def _src_embed_with_positional(c: TransformerConfigs):
"""
@ -253,7 +290,7 @@ def _tgt_embed_with_positional(c: TransformerConfigs):
return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)
# ## Learned Positional Embeddings
# ### Learned Positional Embeddings
@option(TransformerConfigs.src_embed, 'learned_pos')
def _src_embed_with_learned_positional(c: TransformerConfigs):
"""
@ -270,7 +307,7 @@ def _tgt_embed_with_learned_positional(c: TransformerConfigs):
return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)
# ## No Positional Embeddings
# ### No Positional Embeddings
@option(TransformerConfigs.src_embed, 'no_pos')
def _src_embed_without_positional(c: TransformerConfigs):
"""

View File

@ -21,6 +21,15 @@ where $W_1$, $W_2$, $b_1$ and $b_2$ are learnable parameters.
Sometimes the
GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU.
$$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$
### Gated Linear Units
This is a generic implementation that supports different variants including
[Gated Linear Units](https://arxiv.org/abs/2002.05202) (GLU).
We have also implemented experiments on these:
* [experiment that uses `labml.configs`](glu_variants/experiment.html)
* [simpler version from scratch](glu_variants/simple.html)
"""
import torch
@ -31,7 +40,7 @@ from labml_helpers.module import Module
class FeedForward(Module):
"""
## Position-wise feed-forward network (FFN) module
## FFN module
"""
def __init__(self, d_model: int, d_ff: int,
@ -51,19 +60,32 @@ class FeedForward(Module):
* `bias_gate` specified whether the fully connected layer for the gate should have a learnable bias
"""
super().__init__()
# Layer one parameterized by weight $W_1$ and bias $b_1$
self.layer1 = nn.Linear(d_model, d_ff, bias=bias1)
# Layer two parameterized by weight $W_2$ and bias $b_2$
self.layer2 = nn.Linear(d_ff, d_model, bias=bias2)
# Hidden layer dropout
self.dropout = nn.Dropout(dropout)
# Activation function $f$
self.activation = activation
# Whether there is a gate
self.is_gated = is_gated
if is_gated:
# If there is a gate the linear layer to transform inputs to
# be multiplied by the gate, parameterized by weight $V$ and bias $c$
self.linear_v = nn.Linear(d_model, d_ff, bias=bias_gate)
def __call__(self, x: torch.Tensor):
# $f(x W_1 + b_1)$
g = self.activation(self.layer1(x))
# If gated, $f(x W_1 + b_1) \otimes (x V + b) $
if self.is_gated:
x = g * self.linear_v(x)
# Otherwise
else:
x = g
# Apply dropout
x = self.dropout(x)
# $(f(x W_1 + b_1) \otimes (x V + b)) W_2 + b_2$ or $f(x W_1 + b_1) W_2 + b_2$
# depending on whether it is gated
return self.layer2(x)

View File

@ -6,9 +6,11 @@ summary: >
for the position-wise feedforward network (FFN).
---
# Train Autoregressive Transformer
# Gated Linear Units and Variants
This trains a simple [transformer](../../) model for auto-regression.
We try different variants for the [position-wise feedforward network](../feed_forward).
The reusable & configurable components are defined in [`configs.py`](configs.html).
"""
import torch
@ -72,7 +74,7 @@ def autoregressive_model(c: Configs):
@option(Configs.transformer)
def transformer_c(c: Configs):
"""
Initialize the configurable transformer encoder for our autoregressive model
Initialize the [configurable transformer](../configs.html) encoder for our autoregressive model.
"""
tc = TransformerConfigs()
tc.n_src_vocab = c.n_tokens
@ -104,6 +106,9 @@ def main():
'inner_iterations': 10,
# GLU Variant, one of GLU, Bilinear, ReGLU, GEGLU, SwiGLU
#
# These are defined in the [configurable FFN](../configs.html#FFN)
# implementation
'transformer.ffn.glu_variant': 'Bilinear',
# Transformer configurations

View File

@ -6,9 +6,13 @@ summary: >
for the position-wise feedforward network (FFN).
---
# Train Autoregressive Transformer
# Gated Linear Units and Variants
This trains a simple [transformer](../../) model for auto-regression.
We try different variants for the [position-wise feedforward network](../feed_forward).
*This is a simpler implementation that doesn't use the [`labml.configs`](experiment.html) module.
We decided to write a simpler implementation to make it easier for readers who are not familiar with it.*
"""
import dataclasses
@ -56,6 +60,9 @@ class AutoregressiveModel(nn.Module):
@dataclasses.dataclass
class Configs:
"""
### Configurations
"""
d_model: int = 512
seq_len: int = 128
batch_size: int = 32
@ -69,71 +76,130 @@ class Configs:
class TinyShakespeareDataset(Dataset):
"""
### Tiny Shakespeare Dataset
"""
def __init__(self, seq_len: int):
# Location of the text file
path = lab.get_data_path() / 'tiny_shakespeare.txt'
# Download the file
download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
# Read the downloaded file
with open(str(path), 'r') as f:
text = f.read()
# Extract the characters
chars = list(set(text))
# Character to id (integer) map
self.stoi = {c: i for i, c in enumerate(chars)}
# Id to character map
self.itos = {i: c for i, c in enumerate(chars)}
# Length of a training sample
self.seq_len = seq_len
# Data in the form of a tensor of ids
self.data = self.text_to_i(text)
def text_to_i(self, text: str):
"""
Transform the text into a tensor of ids
"""
return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)
def __len__(self):
"""
Number of samples in the dataset.
*This will read the dataset `seq_len` times in a single epoch.*
"""
return len(self.data) - self.seq_len - 1
def __getitem__(self, idx):
"""
Return a sample
"""
return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
class Trainer:
"""
## Trainer
"""
def __init__(self, configs: Configs):
# Get the device
self.device = torch.device('cpu')
if torch.cuda.is_available():
self.device = torch.device('cuda:0')
# Initialize the dataset
self.dataset = TinyShakespeareDataset(configs.seq_len)
self.dataloader = DataLoader(self.dataset, batch_size=configs.batch_size, collate_fn=transpose_batch,
# Initialize the dataloader
self.dataloader = DataLoader(self.dataset,
batch_size=configs.batch_size,
collate_fn=transpose_batch,
shuffle=True)
# FFN with Gated Linear Unit
# $$FFN_{GLU}(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$$
if configs.glu_variant == 'GLU':
ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
# FFN with Bilinear hidden layer
# $$FFN_{Bilinear}(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$$
elif configs.glu_variant == 'Bilinear':
ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
# FFN with ReLU gate
# $$FFN_{ReGLU}(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$$
elif configs.glu_variant == 'ReGLU':
ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
# FFN with GELU gate
# $$FFN_{GEGLU}(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$$
elif configs.glu_variant == 'GEGLU':
ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
# FFN with Swish gate
# $$FFN_{SwiGLU}(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$$
# where $\text{Swish}_\beta(x) = x \sigma(\beta x)$
elif configs.glu_variant == 'SwiGLU':
ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
# FFN with ReLU activation
# $$FFN_{ReLU}(x, W_1, W_2, b_1, b_2) = \text{ReLU}(x W_1 + b_1) W_2 + b_2$$
elif configs.glu_variant == 'ReLU':
ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
# FFN with GELU activation
# $$FFN_{GELU}(x, W_1, W_2, b_1, b_2) = \text{GELU}(x W_1 + b_1) W_2 + b_2$$
elif configs.glu_variant == 'GELU':
ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
else:
raise ValueError(f'Unknown variant {configs.glu_variant}')
# Number of different characters
n_chars = len(self.dataset.stoi)
# Initialize [Multi-Head Attention module](../mha.html)
mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)
# Initialize the [Transformer Block](../models.html#TransformerLayer)
transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
feed_forward=ffn, dropout_prob=configs.dropout)
# Initialize the model with an
# [embedding layer](../models.html#EmbeddingsWithPositionalEncoding)
# (with fixed positional encoding)
# [transformer encoder](../models.html#Encoder) and
# a linear layer to generate logits.
self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
Encoder(TransformerLayer(
d_model=configs.d_model,
self_attn=MultiHeadAttention(configs.n_heads, configs.d_model,
configs.dropout),
src_attn=None,
feed_forward=ffn,
dropout_prob=configs.dropout
), configs.n_layers),
Encoder(transformer_layer, configs.n_layers),
nn.Linear(configs.d_model, n_chars))
# Move the model to the current device
self.model.to(self.device)
# Initialize [Noam optimizer](../../optimizers/noam.html)
self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
# Cross-entropy loss
self.loss_func = nn.CrossEntropyLoss()
# Number of training epochs;
# *note that our dataset definition repeats the data `seq_len` times in a single epoch*
self.epochs = configs.epochs
# Gradient clipping norm
self.grad_norm_clip = configs.grad_norm_clip
# Set tracker configurations
@ -166,18 +232,28 @@ class Trainer:
logger.log(log)
def train(self):
"""
### Train the model
"""
# Loop for the given number of epochs
for _ in monit.loop(self.epochs):
# Iterate over the minibatches
for i, batch in monit.enum('Train', self.dataloader):
# Move data to the device
data, target = batch[0].to(self.device), batch[1].to(self.device)
# Set tracker step, as the number of characters trained on
tracker.add_global_step(data.shape[0] * data.shape[1])
# Set model state to training
self.model.train()
# Evaluate the model
output = self.model(data)
# Calculate and log loss
# Calculate loss
loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
# Log the loss
tracker.add("loss.train", loss)
# Calculate gradients
@ -186,12 +262,13 @@ class Trainer:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
# Take optimizer step
self.optimizer.step()
# Log the model parameters and gradients on last batch of every epoch
# Log the model parameters and gradients
if (i + 1) % 100 == 0:
tracker.add('model', self.model)
# Clear the gradients
self.optimizer.zero_grad()
# Generate a sample
if (i + 1) % 100 == 0:
self.model.eval()
with torch.no_grad():
@ -201,6 +278,7 @@ class Trainer:
if (i + 1) % 10 == 0:
tracker.save()
# Save the model
experiment.save_checkpoint()
@ -212,12 +290,14 @@ def main():
# Load configurations
experiment.configs(dataclasses.asdict(configs))
# Create trainer
trainer = Trainer(configs)
# Set models for training and loading
experiment.add_pytorch_models({'model': trainer.model})
# Start the experiment
with experiment.start():
# `TrainValidConfigs.run`
# Train the model
trainer.train()