mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-11-04 14:29:43 +08:00 
			
		
		
		
	📚 glu variants
This commit is contained in:
		
							
								
								
									
										
											BIN
										
									
								
								docs/optimizers/noam_lr.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								docs/optimizers/noam_lr.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 35 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								docs/optimizers/radam_r_t.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								docs/optimizers/radam_r_t.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 30 KiB  | 
@ -86,11 +86,15 @@
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-1'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
        <div class='docs doc-strings'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-1'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                
 | 
			
		||||
                <p><a id="FFN"></p>
 | 
			
		||||
<h2>FFN Configurations</h2>
 | 
			
		||||
<p></a></p>
 | 
			
		||||
<p>Creates a Position-wise FeedForward Network defined in
 | 
			
		||||
<a href="feed_forward.html"><code>feed_forward.py</code></a>.</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">21</span><span class="k">class</span> <span class="nc">FeedForwardConfigs</span><span class="p">(</span><span class="n">BaseConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
@ -104,7 +108,7 @@
 | 
			
		||||
                <p>Position-wise feedforward layer</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">23</span>    <span class="n">ffn</span><span class="p">:</span> <span class="n">FeedForward</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">31</span>    <span class="n">ffn</span><span class="p">:</span> <span class="n">FeedForward</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-3'>
 | 
			
		||||
@ -115,7 +119,7 @@
 | 
			
		||||
                <p>Number of features in the embedding</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">25</span>    <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">33</span>    <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-4'>
 | 
			
		||||
@ -126,7 +130,7 @@
 | 
			
		||||
                <p>Number of features in in the hidden layer</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">27</span>    <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2048</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">35</span>    <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2048</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-5'>
 | 
			
		||||
@ -137,7 +141,7 @@
 | 
			
		||||
                <p>Dropout probability</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">29</span>    <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">37</span>    <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-6'>
 | 
			
		||||
@ -148,7 +152,7 @@
 | 
			
		||||
                <p>Activation in position-wise feedforward layer</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">31</span>    <span class="n">activation</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span> <span class="o">=</span> <span class="s1">'ReLU'</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">39</span>    <span class="n">activation</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span> <span class="o">=</span> <span class="s1">'ReLU'</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-7'>
 | 
			
		||||
@ -159,7 +163,7 @@
 | 
			
		||||
                <p>Whether the FFN layer should be gated</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">33</span>    <span class="n">is_gated</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">41</span>    <span class="n">is_gated</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-8'>
 | 
			
		||||
@ -170,7 +174,7 @@
 | 
			
		||||
                <p>Whether the first fully connected layer should have a learnable bias</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">35</span>    <span class="n">bias1</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">43</span>    <span class="n">bias1</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-9'>
 | 
			
		||||
@ -181,7 +185,7 @@
 | 
			
		||||
                <p>Whether the second fully connected layer should have a learnable bias</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">37</span>    <span class="n">bias2</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">45</span>    <span class="n">bias2</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-10'>
 | 
			
		||||
@ -192,7 +196,7 @@
 | 
			
		||||
                <p>Whether the fully connected layer for the gate should have a learnable bias</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">39</span>    <span class="n">bias_gate</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">47</span>    <span class="n">bias_gate</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-11'>
 | 
			
		||||
@ -203,7 +207,7 @@
 | 
			
		||||
                <p>Predefined GLU variants</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">41</span>    <span class="n">glu_variant</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">'none'</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">49</span>    <span class="n">glu_variant</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">'none'</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-12'>
 | 
			
		||||
@ -211,11 +215,14 @@
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-12'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>ReLU activation</p>
 | 
			
		||||
                <h3>ReLU activation</h3>
 | 
			
		||||
<p>
 | 
			
		||||
<script type="math/tex; mode=display">\max(0, x)</script>
 | 
			
		||||
</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">44</span><span class="nd">@option</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="s1">'ReLU'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">45</span><span class="k">def</span> <span class="nf">_ffn_activation_relu</span><span class="p">():</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">52</span><span class="nd">@option</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="s1">'ReLU'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">53</span><span class="k">def</span> <span class="nf">_ffn_activation_relu</span><span class="p">():</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-13'>
 | 
			
		||||
@ -226,7 +233,7 @@
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">49</span>    <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">59</span>    <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-14'>
 | 
			
		||||
@ -234,11 +241,14 @@
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-14'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>GELU activation</p>
 | 
			
		||||
                <h3>GELU activation</h3>
 | 
			
		||||
<p>
 | 
			
		||||
<script type="math/tex; mode=display">x \Phi(x)</script> where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$</p>
 | 
			
		||||
<p>It was introduced in paper <a href="https://arxiv.org/abs/1606.08415">Gaussian Error Linear Units</a>.</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">52</span><span class="nd">@option</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="s1">'GELU'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">53</span><span class="k">def</span> <span class="nf">_ffn_activation_gelu</span><span class="p">():</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">62</span><span class="nd">@option</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="s1">'GELU'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">63</span><span class="k">def</span> <span class="nf">_ffn_activation_gelu</span><span class="p">():</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-15'>
 | 
			
		||||
@ -249,7 +259,7 @@
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">57</span>    <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">()</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">71</span>    <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">()</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-16'>
 | 
			
		||||
@ -257,11 +267,11 @@
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-16'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Create feedforward layer</p>
 | 
			
		||||
                <p>Initialize a <a href="feed_forward.html">feed forward network</a></p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">60</span><span class="nd">@option</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">ffn</span><span class="p">,</span> <span class="s1">'default'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">61</span><span class="k">def</span> <span class="nf">_feed_forward</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">FeedForwardConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">74</span><span class="nd">@option</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">ffn</span><span class="p">,</span> <span class="s1">'default'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">75</span><span class="k">def</span> <span class="nf">_feed_forward</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">FeedForwardConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-17'>
 | 
			
		||||
@ -272,53 +282,129 @@
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">65</span>    <span class="k">return</span> <span class="n">FeedForward</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_ff</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">66</span>                       <span class="n">dropout</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">67</span>                       <span class="n">activation</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">68</span>                       <span class="n">is_gated</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">69</span>                       <span class="n">bias1</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">70</span>                       <span class="n">bias2</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">71</span>                       <span class="n">bias_gate</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">72</span>
 | 
			
		||||
<span class="lineno">73</span>
 | 
			
		||||
<span class="lineno">74</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">'GLU'</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">75</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">76</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">77</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">78</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">79</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sigmoid</span><span class="p">()))</span>
 | 
			
		||||
<span class="lineno">80</span>
 | 
			
		||||
<span class="lineno">81</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">'Bilinear'</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">82</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">83</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">84</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">85</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">86</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()))</span>
 | 
			
		||||
<span class="lineno">87</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">'ReGLU'</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">88</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">89</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">90</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">91</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">92</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()))</span>
 | 
			
		||||
<span class="lineno">93</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">'GEGLU'</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">94</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">95</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">96</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">97</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">98</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">()))</span>
 | 
			
		||||
<span class="lineno">99</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">'SwiGLU'</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">100</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">101</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">102</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">103</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">104</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">SiLU</span><span class="p">()))</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">79</span>    <span class="k">return</span> <span class="n">FeedForward</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_ff</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">80</span>                       <span class="n">dropout</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">81</span>                       <span class="n">activation</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">82</span>                       <span class="n">is_gated</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">83</span>                       <span class="n">bias1</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">84</span>                       <span class="n">bias2</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">85</span>                       <span class="n">bias_gate</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-18'>
 | 
			
		||||
        <div class='docs doc-strings'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-18'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <h2>GLU Variants</h2>
 | 
			
		||||
<p>These are variants with gated hidden layers for the FFN
 | 
			
		||||
as introduced in paper <a href="https://arxiv.org/abs/2002.05202">GLU Variants Improve Transformer</a>.
 | 
			
		||||
We have omitted the bias terms as specified in the paper.</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-19'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-19'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <h3>FFN with Gated Linear Units</h3>
 | 
			
		||||
<p>
 | 
			
		||||
<script type="math/tex; mode=display">FFN_{GLU}(x)(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2</script>
 | 
			
		||||
</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">95</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">'GLU'</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">96</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">97</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">98</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">99</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">100</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sigmoid</span><span class="p">()))</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-20'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-20'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <h3>FFN with Bilinear hidden layer</h3>
 | 
			
		||||
<p>
 | 
			
		||||
<script type="math/tex; mode=display">FFN_{Bilinear}(x)(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2</script>
 | 
			
		||||
</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">105</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">'Bilinear'</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">106</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">107</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">108</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">109</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">110</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()))</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-21'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-21'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <h3>FFN with ReLU gate</h3>
 | 
			
		||||
<p>
 | 
			
		||||
<script type="math/tex; mode=display">FFN_{ReGLU}(x)(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2</script>
 | 
			
		||||
</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">115</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">'ReGLU'</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">116</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">117</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">118</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">119</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">120</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()))</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-22'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-22'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <h3>FFN with GELU gate</h3>
 | 
			
		||||
<p>
 | 
			
		||||
<script type="math/tex; mode=display">FFN_{GEGLU}(x)(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2</script>
 | 
			
		||||
</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">125</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">'GEGLU'</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">126</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">127</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">128</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">129</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">130</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">()))</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-23'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-23'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <h3>FFN with Swish gate</h3>
 | 
			
		||||
<p>
 | 
			
		||||
<script type="math/tex; mode=display">FFN_{SwiGLU}(x)(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2</script>
 | 
			
		||||
where $\text{Swish}_\beta(x) = x \sigma(\beta x)$</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">136</span><span class="n">aggregate</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">glu_variant</span><span class="p">,</span> <span class="s1">'SwiGLU'</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">137</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">is_gated</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">138</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias1</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">139</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias2</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">140</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">bias_gate</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">141</span>          <span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">SiLU</span><span class="p">()))</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-24'>
 | 
			
		||||
        <div class='docs doc-strings'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-24'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p><a id="TransformerConfigs"></p>
 | 
			
		||||
<h2>Transformer Configurations</h2>
 | 
			
		||||
<p></a></p>
 | 
			
		||||
@ -328,73 +414,7 @@ These are lazy loaded and therefore only the necessary modules
 | 
			
		||||
are calculated.</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">107</span><span class="k">class</span> <span class="nc">TransformerConfigs</span><span class="p">(</span><span class="n">BaseConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-19'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-19'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Number of attention heads</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">119</span>    <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-20'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-20'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Transformer embedding size</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">121</span>    <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">512</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-21'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-21'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Number of layers</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">123</span>    <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">6</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-22'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-22'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Dropout probability</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">125</span>    <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-23'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-23'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Number of tokens in the source vocabulary (for token embeddings)</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">127</span>    <span class="n">n_src_vocab</span><span class="p">:</span> <span class="nb">int</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-24'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-24'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Number of tokens in the target vocabulary (to generate logits for prediction)</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">129</span>    <span class="n">n_tgt_vocab</span><span class="p">:</span> <span class="nb">int</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">144</span><span class="k">class</span> <span class="nc">TransformerConfigs</span><span class="p">(</span><span class="n">BaseConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-25'>
 | 
			
		||||
@ -402,10 +422,10 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-25'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>The encoder self attention</p>
 | 
			
		||||
                <p>Number of attention heads</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">132</span>    <span class="n">encoder_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span> <span class="o">=</span> <span class="s1">'mha'</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">156</span>    <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-26'>
 | 
			
		||||
@ -413,10 +433,10 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-26'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>The decoder self attention</p>
 | 
			
		||||
                <p>Transformer embedding size</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">134</span>    <span class="n">decoder_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span> <span class="o">=</span> <span class="s1">'mha'</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">158</span>    <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">512</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-27'>
 | 
			
		||||
@ -424,10 +444,10 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-27'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>The decoder memory attention</p>
 | 
			
		||||
                <p>Number of layers</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">136</span>    <span class="n">decoder_mem_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span> <span class="o">=</span> <span class="s1">'mha'</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">160</span>    <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">6</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-28'>
 | 
			
		||||
@ -435,10 +455,10 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-28'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Configurable Feedforward Layer</p>
 | 
			
		||||
                <p>Dropout probability</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">139</span>    <span class="n">ffn</span><span class="p">:</span> <span class="n">FeedForwardConfigs</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">162</span>    <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-29'>
 | 
			
		||||
@ -446,10 +466,10 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-29'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Encoder layer</p>
 | 
			
		||||
                <p>Number of tokens in the source vocabulary (for token embeddings)</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">142</span>    <span class="n">encoder_layer</span><span class="p">:</span> <span class="n">TransformerLayer</span> <span class="o">=</span> <span class="s1">'default'</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">164</span>    <span class="n">n_src_vocab</span><span class="p">:</span> <span class="nb">int</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-30'>
 | 
			
		||||
@ -457,10 +477,10 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-30'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Decoder layer</p>
 | 
			
		||||
                <p>Number of tokens in the target vocabulary (to generate logits for prediction)</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">144</span>    <span class="n">decoder_layer</span><span class="p">:</span> <span class="n">TransformerLayer</span> <span class="o">=</span> <span class="s1">'default'</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">166</span>    <span class="n">n_tgt_vocab</span><span class="p">:</span> <span class="nb">int</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-31'>
 | 
			
		||||
@ -468,10 +488,10 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-31'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Encoder consisting of multiple encoder layers</p>
 | 
			
		||||
                <p>The encoder self attention</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">147</span>    <span class="n">encoder</span><span class="p">:</span> <span class="n">Encoder</span> <span class="o">=</span> <span class="s1">'default'</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">169</span>    <span class="n">encoder_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span> <span class="o">=</span> <span class="s1">'mha'</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-32'>
 | 
			
		||||
@ -479,10 +499,10 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-32'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Encoder consisting of multiple decoder layers</p>
 | 
			
		||||
                <p>The decoder self attention</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">149</span>    <span class="n">decoder</span><span class="p">:</span> <span class="n">Decoder</span> <span class="o">=</span> <span class="s1">'default'</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">171</span>    <span class="n">decoder_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span> <span class="o">=</span> <span class="s1">'mha'</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-33'>
 | 
			
		||||
@ -490,10 +510,10 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-33'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Embedding layer for source</p>
 | 
			
		||||
                <p>The decoder memory attention</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">152</span>    <span class="n">src_embed</span><span class="p">:</span> <span class="n">Module</span> <span class="o">=</span> <span class="s1">'fixed_pos'</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">173</span>    <span class="n">decoder_mem_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span> <span class="o">=</span> <span class="s1">'mha'</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-34'>
 | 
			
		||||
@ -501,10 +521,10 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-34'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Embedding layer for target (for decoder)</p>
 | 
			
		||||
                <p>Configurable Feedforward Layer</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">154</span>    <span class="n">tgt_embed</span><span class="p">:</span> <span class="n">Module</span> <span class="o">=</span> <span class="s1">'fixed_pos'</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">176</span>    <span class="n">ffn</span><span class="p">:</span> <span class="n">FeedForwardConfigs</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-35'>
 | 
			
		||||
@ -512,10 +532,10 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-35'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Logit generator for prediction</p>
 | 
			
		||||
                <p>Encoder layer</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">157</span>    <span class="n">generator</span><span class="p">:</span> <span class="n">Generator</span> <span class="o">=</span> <span class="s1">'default'</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">179</span>    <span class="n">encoder_layer</span><span class="p">:</span> <span class="n">TransformerLayer</span> <span class="o">=</span> <span class="s1">'default'</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-36'>
 | 
			
		||||
@ -523,10 +543,10 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-36'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Encoder-decoder</p>
 | 
			
		||||
                <p>Decoder layer</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">160</span>    <span class="n">encoder_decoder</span><span class="p">:</span> <span class="n">EncoderDecoder</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">181</span>    <span class="n">decoder_layer</span><span class="p">:</span> <span class="n">TransformerLayer</span> <span class="o">=</span> <span class="s1">'default'</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-37'>
 | 
			
		||||
@ -534,16 +554,10 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-37'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <h3>Multi-head Attention</h3>
 | 
			
		||||
                <p>Encoder consisting of multiple encoder layers</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">164</span><span class="k">def</span> <span class="nf">_mha</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
 | 
			
		||||
<span class="lineno">165</span>    <span class="k">return</span> <span class="n">MultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">166</span>
 | 
			
		||||
<span class="lineno">167</span>
 | 
			
		||||
<span class="lineno">168</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_attn</span><span class="p">,</span> <span class="s1">'mha'</span><span class="p">,</span> <span class="n">_mha</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">169</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_attn</span><span class="p">,</span> <span class="s1">'mha'</span><span class="p">,</span> <span class="n">_mha</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">170</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_mem_attn</span><span class="p">,</span> <span class="s1">'mha'</span><span class="p">,</span> <span class="n">_mha</span><span class="p">)</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">184</span>    <span class="n">encoder</span><span class="p">:</span> <span class="n">Encoder</span> <span class="o">=</span> <span class="s1">'default'</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-38'>
 | 
			
		||||
@ -551,29 +565,21 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-38'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <h3>Relative Multi-head Attention</h3>
 | 
			
		||||
                <p>Encoder consisting of multiple decoder layers</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">174</span><span class="k">def</span> <span class="nf">_relative_mha</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
 | 
			
		||||
<span class="lineno">175</span>    <span class="kn">from</span> <span class="nn">.relative_mha</span> <span class="kn">import</span> <span class="n">RelativeMultiHeadAttention</span>
 | 
			
		||||
<span class="lineno">176</span>    <span class="k">return</span> <span class="n">RelativeMultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">177</span>
 | 
			
		||||
<span class="lineno">178</span>
 | 
			
		||||
<span class="lineno">179</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_attn</span><span class="p">,</span> <span class="s1">'relative'</span><span class="p">,</span> <span class="n">_relative_mha</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">180</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_attn</span><span class="p">,</span> <span class="s1">'relative'</span><span class="p">,</span> <span class="n">_relative_mha</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">181</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_mem_attn</span><span class="p">,</span> <span class="s1">'relative'</span><span class="p">,</span> <span class="n">_relative_mha</span><span class="p">)</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">186</span>    <span class="n">decoder</span><span class="p">:</span> <span class="n">Decoder</span> <span class="o">=</span> <span class="s1">'default'</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-39'>
 | 
			
		||||
        <div class='docs doc-strings'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-39'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Create feedforward layer configurations</p>
 | 
			
		||||
                <p>Embedding layer for source</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">184</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">ffn</span><span class="p">,</span> <span class="s1">'default'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">185</span><span class="k">def</span> <span class="nf">_feed_forward</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">189</span>    <span class="n">src_embed</span><span class="p">:</span> <span class="n">Module</span> <span class="o">=</span> <span class="s1">'fixed_pos'</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-40'>
 | 
			
		||||
@ -581,25 +587,21 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-40'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                
 | 
			
		||||
                <p>Embedding layer for target (for decoder)</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">189</span>    <span class="n">conf</span> <span class="o">=</span> <span class="n">FeedForwardConfigs</span><span class="p">()</span>
 | 
			
		||||
<span class="lineno">190</span>    <span class="n">conf</span><span class="o">.</span><span class="n">set_default</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">func</span><span class="o">=</span><span class="k">lambda</span><span class="p">:</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">191</span>    <span class="n">conf</span><span class="o">.</span><span class="n">set_default</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">dropout</span><span class="p">,</span> <span class="n">func</span><span class="o">=</span><span class="k">lambda</span><span class="p">:</span> <span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">192</span>    <span class="k">return</span> <span class="n">conf</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">191</span>    <span class="n">tgt_embed</span><span class="p">:</span> <span class="n">Module</span> <span class="o">=</span> <span class="s1">'fixed_pos'</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-41'>
 | 
			
		||||
        <div class='docs doc-strings'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-41'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Encoder layer</p>
 | 
			
		||||
                <p>Logit generator for prediction</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">195</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_layer</span><span class="p">,</span> <span class="s1">'default'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">196</span><span class="k">def</span> <span class="nf">_encoder_layer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">194</span>    <span class="n">generator</span><span class="p">:</span> <span class="n">Generator</span> <span class="o">=</span> <span class="s1">'default'</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-42'>
 | 
			
		||||
@ -607,24 +609,27 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-42'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                
 | 
			
		||||
                <p>Encoder-decoder</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">200</span>    <span class="k">return</span> <span class="n">TransformerLayer</span><span class="p">(</span><span class="n">d_model</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">self_attn</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">encoder_attn</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">201</span>                            <span class="n">src_attn</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">feed_forward</span><span class="o">=</span><span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">ffn</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">202</span>                            <span class="n">dropout_prob</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">)</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">197</span>    <span class="n">encoder_decoder</span><span class="p">:</span> <span class="n">EncoderDecoder</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-43'>
 | 
			
		||||
        <div class='docs doc-strings'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-43'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Decoder layer</p>
 | 
			
		||||
                <h3>Multi-head Attention</h3>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">205</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_layer</span><span class="p">,</span> <span class="s1">'default'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">206</span><span class="k">def</span> <span class="nf">_decoder_layer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">201</span><span class="k">def</span> <span class="nf">_mha</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
 | 
			
		||||
<span class="lineno">202</span>    <span class="k">return</span> <span class="n">MultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">203</span>
 | 
			
		||||
<span class="lineno">204</span>
 | 
			
		||||
<span class="lineno">205</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_attn</span><span class="p">,</span> <span class="s1">'mha'</span><span class="p">,</span> <span class="n">_mha</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">206</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_attn</span><span class="p">,</span> <span class="s1">'mha'</span><span class="p">,</span> <span class="n">_mha</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">207</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_mem_attn</span><span class="p">,</span> <span class="s1">'mha'</span><span class="p">,</span> <span class="n">_mha</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-44'>
 | 
			
		||||
@ -632,12 +637,17 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-44'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                
 | 
			
		||||
                <h3>Relative Multi-head Attention</h3>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">210</span>    <span class="k">return</span> <span class="n">TransformerLayer</span><span class="p">(</span><span class="n">d_model</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">self_attn</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">decoder_attn</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">211</span>                            <span class="n">src_attn</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">decoder_mem_attn</span><span class="p">,</span> <span class="n">feed_forward</span><span class="o">=</span><span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">ffn</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">212</span>                            <span class="n">dropout_prob</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">)</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">211</span><span class="k">def</span> <span class="nf">_relative_mha</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
 | 
			
		||||
<span class="lineno">212</span>    <span class="kn">from</span> <span class="nn">.relative_mha</span> <span class="kn">import</span> <span class="n">RelativeMultiHeadAttention</span>
 | 
			
		||||
<span class="lineno">213</span>    <span class="k">return</span> <span class="n">RelativeMultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">214</span>
 | 
			
		||||
<span class="lineno">215</span>
 | 
			
		||||
<span class="lineno">216</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_attn</span><span class="p">,</span> <span class="s1">'relative'</span><span class="p">,</span> <span class="n">_relative_mha</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">217</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_attn</span><span class="p">,</span> <span class="s1">'relative'</span><span class="p">,</span> <span class="n">_relative_mha</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">218</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_mem_attn</span><span class="p">,</span> <span class="s1">'relative'</span><span class="p">,</span> <span class="n">_relative_mha</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-45'>
 | 
			
		||||
@ -645,11 +655,11 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-45'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Encoder</p>
 | 
			
		||||
                <p>Create feedforward layer configurations</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">215</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="s1">'default'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">216</span><span class="k">def</span> <span class="nf">_encoder</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">221</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">ffn</span><span class="p">,</span> <span class="s1">'default'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">222</span><span class="k">def</span> <span class="nf">_feed_forward</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-46'>
 | 
			
		||||
@ -660,7 +670,10 @@ are calculated.</p>
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">220</span>    <span class="k">return</span> <span class="n">Encoder</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">encoder_layer</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_layers</span><span class="p">)</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">226</span>    <span class="n">conf</span> <span class="o">=</span> <span class="n">FeedForwardConfigs</span><span class="p">()</span>
 | 
			
		||||
<span class="lineno">227</span>    <span class="n">conf</span><span class="o">.</span><span class="n">set_default</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">func</span><span class="o">=</span><span class="k">lambda</span><span class="p">:</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">228</span>    <span class="n">conf</span><span class="o">.</span><span class="n">set_default</span><span class="p">(</span><span class="n">FeedForwardConfigs</span><span class="o">.</span><span class="n">dropout</span><span class="p">,</span> <span class="n">func</span><span class="o">=</span><span class="k">lambda</span><span class="p">:</span> <span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">229</span>    <span class="k">return</span> <span class="n">conf</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-47'>
 | 
			
		||||
@ -668,11 +681,11 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-47'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Decoder</p>
 | 
			
		||||
                <p>Encoder layer</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">223</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder</span><span class="p">,</span> <span class="s1">'default'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">224</span><span class="k">def</span> <span class="nf">_decoder</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">232</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_layer</span><span class="p">,</span> <span class="s1">'default'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">233</span><span class="k">def</span> <span class="nf">_encoder_layer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-48'>
 | 
			
		||||
@ -683,7 +696,9 @@ are calculated.</p>
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">228</span>    <span class="k">return</span> <span class="n">Decoder</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">decoder_layer</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_layers</span><span class="p">)</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">237</span>    <span class="k">return</span> <span class="n">TransformerLayer</span><span class="p">(</span><span class="n">d_model</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">self_attn</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">encoder_attn</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">238</span>                            <span class="n">src_attn</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">feed_forward</span><span class="o">=</span><span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">ffn</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">239</span>                            <span class="n">dropout_prob</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-49'>
 | 
			
		||||
@ -691,11 +706,11 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-49'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Logit generator</p>
 | 
			
		||||
                <p>Decoder layer</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">231</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">generator</span><span class="p">,</span> <span class="s1">'default'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">232</span><span class="k">def</span> <span class="nf">_generator</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">242</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_layer</span><span class="p">,</span> <span class="s1">'default'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">243</span><span class="k">def</span> <span class="nf">_decoder_layer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-50'>
 | 
			
		||||
@ -706,7 +721,9 @@ are calculated.</p>
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">236</span>    <span class="k">return</span> <span class="n">Generator</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_tgt_vocab</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">247</span>    <span class="k">return</span> <span class="n">TransformerLayer</span><span class="p">(</span><span class="n">d_model</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">self_attn</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">decoder_attn</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">248</span>                            <span class="n">src_attn</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">decoder_mem_attn</span><span class="p">,</span> <span class="n">feed_forward</span><span class="o">=</span><span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">ffn</span><span class="p">),</span>
 | 
			
		||||
<span class="lineno">249</span>                            <span class="n">dropout_prob</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-51'>
 | 
			
		||||
@ -714,12 +731,11 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-51'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <h2>Positional Embeddings</h2>
 | 
			
		||||
<p>Source embedding with fixed positional encodings</p>
 | 
			
		||||
                <p>Encoder</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">240</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="s1">'fixed_pos'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">241</span><span class="k">def</span> <span class="nf">_src_embed_with_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">252</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="s1">'default'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">253</span><span class="k">def</span> <span class="nf">_encoder</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-52'>
 | 
			
		||||
@ -730,7 +746,7 @@ are calculated.</p>
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">245</span>    <span class="k">return</span> <span class="n">EmbeddingsWithPositionalEncoding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_src_vocab</span><span class="p">)</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">257</span>    <span class="k">return</span> <span class="n">Encoder</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">encoder_layer</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_layers</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-53'>
 | 
			
		||||
@ -738,11 +754,11 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-53'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Target embedding with fixed positional encodings</p>
 | 
			
		||||
                <p>Decoder</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">248</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">,</span> <span class="s1">'fixed_pos'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">249</span><span class="k">def</span> <span class="nf">_tgt_embed_with_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">260</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder</span><span class="p">,</span> <span class="s1">'default'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">261</span><span class="k">def</span> <span class="nf">_decoder</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-54'>
 | 
			
		||||
@ -753,7 +769,7 @@ are calculated.</p>
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">253</span>    <span class="k">return</span> <span class="n">EmbeddingsWithPositionalEncoding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tgt_vocab</span><span class="p">)</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">265</span>    <span class="k">return</span> <span class="n">Decoder</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">decoder_layer</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_layers</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-55'>
 | 
			
		||||
@ -761,12 +777,11 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-55'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <h2>Learned Positional Embeddings</h2>
 | 
			
		||||
<p>Source embedding with learned positional encodings</p>
 | 
			
		||||
                <p>Logit generator</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">257</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="s1">'learned_pos'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">258</span><span class="k">def</span> <span class="nf">_src_embed_with_learned_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">268</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">generator</span><span class="p">,</span> <span class="s1">'default'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">269</span><span class="k">def</span> <span class="nf">_generator</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-56'>
 | 
			
		||||
@ -777,7 +792,7 @@ are calculated.</p>
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">262</span>    <span class="k">return</span> <span class="n">EmbeddingsWithLearnedPositionalEncoding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_src_vocab</span><span class="p">)</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">273</span>    <span class="k">return</span> <span class="n">Generator</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_tgt_vocab</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-57'>
 | 
			
		||||
@ -785,11 +800,12 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-57'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Target embedding with learned positional encodings</p>
 | 
			
		||||
                <h3>Fixed Positional Embeddings</h3>
 | 
			
		||||
<p>Source embedding with fixed positional encodings</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">265</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">,</span> <span class="s1">'learned_pos'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">266</span><span class="k">def</span> <span class="nf">_tgt_embed_with_learned_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">277</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="s1">'fixed_pos'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">278</span><span class="k">def</span> <span class="nf">_src_embed_with_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-58'>
 | 
			
		||||
@ -800,7 +816,7 @@ are calculated.</p>
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">270</span>    <span class="k">return</span> <span class="n">EmbeddingsWithLearnedPositionalEncoding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tgt_vocab</span><span class="p">)</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">282</span>    <span class="k">return</span> <span class="n">EmbeddingsWithPositionalEncoding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_src_vocab</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-59'>
 | 
			
		||||
@ -808,12 +824,11 @@ are calculated.</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-59'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <h2>No Positional Embeddings</h2>
 | 
			
		||||
<p>Source embedding without positional encodings</p>
 | 
			
		||||
                <p>Target embedding with fixed positional encodings</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">274</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="s1">'no_pos'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">275</span><span class="k">def</span> <span class="nf">_src_embed_without_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">285</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">,</span> <span class="s1">'fixed_pos'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">286</span><span class="k">def</span> <span class="nf">_tgt_embed_with_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-60'>
 | 
			
		||||
@ -824,25 +839,96 @@ are calculated.</p>
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">279</span>    <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_src_vocab</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">290</span>    <span class="k">return</span> <span class="n">EmbeddingsWithPositionalEncoding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tgt_vocab</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-61'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
        <div class='docs doc-strings'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-61'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <h3>Learned Positional Embeddings</h3>
 | 
			
		||||
<p>Source embedding with learned positional encodings</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">294</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="s1">'learned_pos'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">295</span><span class="k">def</span> <span class="nf">_src_embed_with_learned_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-62'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-62'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">282</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">,</span> <span class="s1">'no_pos'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">283</span><span class="k">def</span> <span class="nf">_tgt_embed_without_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
 | 
			
		||||
<span class="lineno">284</span>    <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_tgt_vocab</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">285</span>
 | 
			
		||||
<span class="lineno">286</span>
 | 
			
		||||
<span class="lineno">287</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_decoder</span><span class="p">,</span> <span class="s1">'default'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">288</span><span class="k">def</span> <span class="nf">_encoder_decoder</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
 | 
			
		||||
<span class="lineno">289</span>    <span class="k">return</span> <span class="n">EncoderDecoder</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">decoder</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">generator</span><span class="p">)</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">299</span>    <span class="k">return</span> <span class="n">EmbeddingsWithLearnedPositionalEncoding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_src_vocab</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-63'>
 | 
			
		||||
        <div class='docs doc-strings'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-63'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Target embedding with learned positional encodings</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">302</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">,</span> <span class="s1">'learned_pos'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">303</span><span class="k">def</span> <span class="nf">_tgt_embed_with_learned_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-64'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-64'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">307</span>    <span class="k">return</span> <span class="n">EmbeddingsWithLearnedPositionalEncoding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tgt_vocab</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-65'>
 | 
			
		||||
        <div class='docs doc-strings'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-65'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <h3>No Positional Embeddings</h3>
 | 
			
		||||
<p>Source embedding without positional encodings</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">311</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="s1">'no_pos'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">312</span><span class="k">def</span> <span class="nf">_src_embed_without_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-66'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-66'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">316</span>    <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_src_vocab</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-67'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-67'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">319</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">,</span> <span class="s1">'no_pos'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">320</span><span class="k">def</span> <span class="nf">_tgt_embed_without_positional</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
 | 
			
		||||
<span class="lineno">321</span>    <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_tgt_vocab</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">322</span>
 | 
			
		||||
<span class="lineno">323</span>
 | 
			
		||||
<span class="lineno">324</span><span class="nd">@option</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_decoder</span><span class="p">,</span> <span class="s1">'default'</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">325</span><span class="k">def</span> <span class="nf">_encoder_decoder</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
 | 
			
		||||
<span class="lineno">326</span>    <span class="k">return</span> <span class="n">EncoderDecoder</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">decoder</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">generator</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    </div>
 | 
			
		||||
 | 
			
		||||
@ -84,12 +84,20 @@ where $W_1$, $W_2$, $b_1$ and $b_2$ are learnable parameters.</p>
 | 
			
		||||
<p>Sometimes the
 | 
			
		||||
GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU.
 | 
			
		||||
<script type="math/tex; mode=display">x \Phi(x)</script> where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$</p>
 | 
			
		||||
<h3>Gated Linear Units</h3>
 | 
			
		||||
<p>This is a generic implementation that supports different variants including
 | 
			
		||||
<a href="https://arxiv.org/abs/2002.05202">Gated Linear Units</a> (GLU).
 | 
			
		||||
We have also implemented experiments on these:</p>
 | 
			
		||||
<ul>
 | 
			
		||||
<li><a href="glu_variants/experiment.html">experiment that uses <code>labml.configs</code></a></li>
 | 
			
		||||
<li><a href="glu_variants/simple.html">simpler version from scratch</a></li>
 | 
			
		||||
</ul>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">26</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
 | 
			
		||||
<span class="lineno">27</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span> <span class="k">as</span> <span class="n">nn</span>
 | 
			
		||||
<span class="lineno">28</span>
 | 
			
		||||
<span class="lineno">29</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">35</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
 | 
			
		||||
<span class="lineno">36</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span> <span class="k">as</span> <span class="n">nn</span>
 | 
			
		||||
<span class="lineno">37</span>
 | 
			
		||||
<span class="lineno">38</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-1'>
 | 
			
		||||
@ -97,10 +105,10 @@ GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU.
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-1'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <h2>Position-wise feed-forward network (FFN) module</h2>
 | 
			
		||||
                <h2>FFN module</h2>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">32</span><span class="k">class</span> <span class="nc">FeedForward</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">41</span><span class="k">class</span> <span class="nc">FeedForward</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-2'>
 | 
			
		||||
@ -119,13 +127,13 @@ GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU.
 | 
			
		||||
</ul>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">37</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">38</span>                 <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">39</span>                 <span class="n">activation</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
 | 
			
		||||
<span class="lineno">40</span>                 <span class="n">is_gated</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">41</span>                 <span class="n">bias1</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">42</span>                 <span class="n">bias2</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">43</span>                 <span class="n">bias_gate</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">46</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">47</span>                 <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">48</span>                 <span class="n">activation</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
 | 
			
		||||
<span class="lineno">49</span>                 <span class="n">is_gated</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">50</span>                 <span class="n">bias1</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">51</span>                 <span class="n">bias2</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">52</span>                 <span class="n">bias_gate</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-3'>
 | 
			
		||||
@ -136,14 +144,7 @@ GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU.
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">53</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
 | 
			
		||||
<span class="lineno">54</span>        <span class="bp">self</span><span class="o">.</span><span class="n">layer1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias1</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">55</span>        <span class="bp">self</span><span class="o">.</span><span class="n">layer2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_ff</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias2</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">56</span>        <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">57</span>        <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">activation</span>
 | 
			
		||||
<span class="lineno">58</span>        <span class="bp">self</span><span class="o">.</span><span class="n">is_gated</span> <span class="o">=</span> <span class="n">is_gated</span>
 | 
			
		||||
<span class="lineno">59</span>        <span class="k">if</span> <span class="n">is_gated</span><span class="p">:</span>
 | 
			
		||||
<span class="lineno">60</span>            <span class="bp">self</span><span class="o">.</span><span class="n">linear_v</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias_gate</span><span class="p">)</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">62</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-4'>
 | 
			
		||||
@ -151,17 +152,136 @@ GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU.
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-4'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Layer one parameterized by weight $W_1$ and bias $b_1$</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">64</span>        <span class="bp">self</span><span class="o">.</span><span class="n">layer1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias1</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-5'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-5'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Layer one parameterized by weight $W_1$ and bias $b_1$</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">66</span>        <span class="bp">self</span><span class="o">.</span><span class="n">layer2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_ff</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias2</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-6'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-6'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Hidden layer dropout</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">68</span>        <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-7'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-7'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Activation function $f$</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">70</span>        <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">activation</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-8'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-8'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Whether there is a gate</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">72</span>        <span class="bp">self</span><span class="o">.</span><span class="n">is_gated</span> <span class="o">=</span> <span class="n">is_gated</span>
 | 
			
		||||
<span class="lineno">73</span>        <span class="k">if</span> <span class="n">is_gated</span><span class="p">:</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-9'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-9'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>If there is a gate the linear layer to transform inputs to
 | 
			
		||||
be multiplied by the gate, parameterized by weight $V$ and bias $c$</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">76</span>            <span class="bp">self</span><span class="o">.</span><span class="n">linear_v</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias_gate</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-10'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-10'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">62</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
 | 
			
		||||
<span class="lineno">63</span>        <span class="n">g</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layer1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
 | 
			
		||||
<span class="lineno">64</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_gated</span><span class="p">:</span>
 | 
			
		||||
<span class="lineno">65</span>            <span class="n">x</span> <span class="o">=</span> <span class="n">g</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_v</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">66</span>        <span class="k">else</span><span class="p">:</span>
 | 
			
		||||
<span class="lineno">67</span>            <span class="n">x</span> <span class="o">=</span> <span class="n">g</span>
 | 
			
		||||
<span class="lineno">68</span>        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">69</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">78</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-11'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-11'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>$f(x W_1 + b_1)$</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">80</span>        <span class="n">g</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layer1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-12'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-12'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>If gated, $f(x W_1 + b_1) \otimes (x V + b) $</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">82</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_gated</span><span class="p">:</span>
 | 
			
		||||
<span class="lineno">83</span>            <span class="n">x</span> <span class="o">=</span> <span class="n">g</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_v</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-13'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-13'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Otherwise</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">85</span>        <span class="k">else</span><span class="p">:</span>
 | 
			
		||||
<span class="lineno">86</span>            <span class="n">x</span> <span class="o">=</span> <span class="n">g</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-14'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-14'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Apply dropout</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">88</span>        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-15'>
 | 
			
		||||
            <div class='docs'>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-15'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>$(f(x W_1 + b_1) \otimes (x V + b)) W_2 + b_2$ or $f(x W_1 + b_1) W_2 + b_2$
 | 
			
		||||
depending on whether it is gated</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">91</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    </div>
 | 
			
		||||
 | 
			
		||||
@ -71,19 +71,21 @@
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-0'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <h1>Train Autoregressive Transformer</h1>
 | 
			
		||||
<p>This trains a simple <a href="../../">transformer</a> model for auto-regression.</p>
 | 
			
		||||
                <h1>Gated Linear Units and Variants</h1>
 | 
			
		||||
<p>This trains a simple <a href="../../">transformer</a> model for auto-regression.
 | 
			
		||||
We try different variants for the <a href="../feed_forward">position-wise feedforward network</a>.
 | 
			
		||||
The reusable & configurable are defined in <a href="configs.html"><code>configs.py</code></a>.</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">14</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
 | 
			
		||||
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
 | 
			
		||||
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
 | 
			
		||||
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml.utils.pytorch</span> <span class="kn">import</span> <span class="n">get_modules</span>
 | 
			
		||||
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
 | 
			
		||||
<span class="lineno">19</span>
 | 
			
		||||
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.nlp_autoregression</span> <span class="kn">import</span> <span class="n">NLPAutoRegressionConfigs</span>
 | 
			
		||||
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">Generator</span><span class="p">,</span> <span class="n">TransformerConfigs</span>
 | 
			
		||||
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.utils</span> <span class="kn">import</span> <span class="n">subsequent_mask</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">16</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
 | 
			
		||||
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
 | 
			
		||||
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
 | 
			
		||||
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml.utils.pytorch</span> <span class="kn">import</span> <span class="n">get_modules</span>
 | 
			
		||||
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
 | 
			
		||||
<span class="lineno">21</span>
 | 
			
		||||
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.nlp_autoregression</span> <span class="kn">import</span> <span class="n">NLPAutoRegressionConfigs</span>
 | 
			
		||||
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">Generator</span><span class="p">,</span> <span class="n">TransformerConfigs</span>
 | 
			
		||||
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.utils</span> <span class="kn">import</span> <span class="n">subsequent_mask</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-1'>
 | 
			
		||||
@ -94,7 +96,7 @@
 | 
			
		||||
                <h2>Auto regressive model</h2>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">25</span><span class="k">class</span> <span class="nc">AutoregressiveModel</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">27</span><span class="k">class</span> <span class="nc">AutoregressiveModel</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-2'>
 | 
			
		||||
@ -105,8 +107,8 @@
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">30</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src_embed</span><span class="p">:</span> <span class="n">Module</span><span class="p">,</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">generator</span><span class="p">:</span> <span class="n">Generator</span><span class="p">):</span>
 | 
			
		||||
<span class="lineno">31</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">32</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src_embed</span><span class="p">:</span> <span class="n">Module</span><span class="p">,</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">generator</span><span class="p">:</span> <span class="n">Generator</span><span class="p">):</span>
 | 
			
		||||
<span class="lineno">33</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-3'>
 | 
			
		||||
@ -117,7 +119,7 @@
 | 
			
		||||
                <p>Token embedding module</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">33</span>        <span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span> <span class="o">=</span> <span class="n">src_embed</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">35</span>        <span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span> <span class="o">=</span> <span class="n">src_embed</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-4'>
 | 
			
		||||
@ -128,7 +130,7 @@
 | 
			
		||||
                <p>Transformer based encoder</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">35</span>        <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">37</span>        <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-5'>
 | 
			
		||||
@ -140,7 +142,7 @@
 | 
			
		||||
this give logits  of the the next token</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">38</span>        <span class="bp">self</span><span class="o">.</span><span class="n">generator</span> <span class="o">=</span> <span class="n">generator</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">40</span>        <span class="bp">self</span><span class="o">.</span><span class="n">generator</span> <span class="o">=</span> <span class="n">generator</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-6'>
 | 
			
		||||
@ -151,7 +153,7 @@ this give logits  of the the next token</p>
 | 
			
		||||
                <p>This will be initialized on the first call</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">40</span>        <span class="bp">self</span><span class="o">.</span><span class="n">src_mask</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">42</span>        <span class="bp">self</span><span class="o">.</span><span class="n">src_mask</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-7'>
 | 
			
		||||
@ -162,7 +164,7 @@ this give logits  of the the next token</p>
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">42</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">44</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-8'>
 | 
			
		||||
@ -173,8 +175,8 @@ this give logits  of the the next token</p>
 | 
			
		||||
                <p>Create subsequent mask, so that the transformer can only pay attention to past tokens.</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">44</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_mask</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_mask</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">src</span><span class="p">):</span>
 | 
			
		||||
<span class="lineno">45</span>            <span class="bp">self</span><span class="o">.</span><span class="n">src_mask</span> <span class="o">=</span> <span class="n">subsequent_mask</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">src</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">src</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">46</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_mask</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_mask</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">src</span><span class="p">):</span>
 | 
			
		||||
<span class="lineno">47</span>            <span class="bp">self</span><span class="o">.</span><span class="n">src_mask</span> <span class="o">=</span> <span class="n">subsequent_mask</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">src</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">src</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-9'>
 | 
			
		||||
@ -185,7 +187,7 @@ this give logits  of the the next token</p>
 | 
			
		||||
                <p>Embed the tokens (<code>src</code>) and run it through the the transformer</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">47</span>        <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span><span class="p">(</span><span class="n">src</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_mask</span><span class="p">)</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">49</span>        <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span><span class="p">(</span><span class="n">src</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_mask</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-10'>
 | 
			
		||||
@ -196,7 +198,7 @@ this give logits  of the the next token</p>
 | 
			
		||||
                <p>Generate logits of the next token</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">49</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">(</span><span class="n">res</span><span class="p">),</span> <span class="kc">None</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">51</span>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">(</span><span class="n">res</span><span class="p">),</span> <span class="kc">None</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-11'>
 | 
			
		||||
@ -208,7 +210,7 @@ this give logits  of the the next token</p>
 | 
			
		||||
<p>The default configs can and will be over-ridden when we start the experiment</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">52</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">54</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-12'>
 | 
			
		||||
@ -219,8 +221,8 @@ this give logits  of the the next token</p>
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">59</span>    <span class="n">transformer</span><span class="p">:</span> <span class="n">TransformerConfigs</span>
 | 
			
		||||
<span class="lineno">60</span>    <span class="n">model</span><span class="p">:</span> <span class="n">AutoregressiveModel</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">61</span>    <span class="n">transformer</span><span class="p">:</span> <span class="n">TransformerConfigs</span>
 | 
			
		||||
<span class="lineno">62</span>    <span class="n">model</span><span class="p">:</span> <span class="n">AutoregressiveModel</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-13'>
 | 
			
		||||
@ -231,8 +233,8 @@ this give logits  of the the next token</p>
 | 
			
		||||
                <p>Initialize the auto-regressive model</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">63</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">64</span><span class="k">def</span> <span class="nf">autoregressive_model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">65</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">66</span><span class="k">def</span> <span class="nf">autoregressive_model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-14'>
 | 
			
		||||
@ -243,8 +245,8 @@ this give logits  of the the next token</p>
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">68</span>    <span class="n">m</span> <span class="o">=</span> <span class="n">AutoregressiveModel</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">generator</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">69</span>    <span class="k">return</span> <span class="n">m</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">70</span>    <span class="n">m</span> <span class="o">=</span> <span class="n">AutoregressiveModel</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">src_embed</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">encoder</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">generator</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">71</span>    <span class="k">return</span> <span class="n">m</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-15'>
 | 
			
		||||
@ -252,11 +254,11 @@ this give logits  of the the next token</p>
 | 
			
		||||
                <div class='section-link'>
 | 
			
		||||
                    <a href='#section-15'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>Initialize the configurable transformer encoder for our autoregressive model</p>
 | 
			
		||||
                <p>Initialize the <a href="../configs.html">configurable transformer</a> encoder for our autoregressive model.</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">72</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">transformer</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">73</span><span class="k">def</span> <span class="nf">transformer_c</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">74</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">transformer</span><span class="p">)</span>
 | 
			
		||||
<span class="lineno">75</span><span class="k">def</span> <span class="nf">transformer_c</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-16'>
 | 
			
		||||
@ -267,11 +269,11 @@ this give logits  of the the next token</p>
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">77</span>    <span class="n">tc</span> <span class="o">=</span> <span class="n">TransformerConfigs</span><span class="p">()</span>
 | 
			
		||||
<span class="lineno">78</span>    <span class="n">tc</span><span class="o">.</span><span class="n">n_src_vocab</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span>
 | 
			
		||||
<span class="lineno">79</span>    <span class="n">tc</span><span class="o">.</span><span class="n">n_tgt_vocab</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span>
 | 
			
		||||
<span class="lineno">80</span>
 | 
			
		||||
<span class="lineno">81</span>    <span class="k">return</span> <span class="n">tc</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">79</span>    <span class="n">tc</span> <span class="o">=</span> <span class="n">TransformerConfigs</span><span class="p">()</span>
 | 
			
		||||
<span class="lineno">80</span>    <span class="n">tc</span><span class="o">.</span><span class="n">n_src_vocab</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span>
 | 
			
		||||
<span class="lineno">81</span>    <span class="n">tc</span><span class="o">.</span><span class="n">n_tgt_vocab</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span>
 | 
			
		||||
<span class="lineno">82</span>
 | 
			
		||||
<span class="lineno">83</span>    <span class="k">return</span> <span class="n">tc</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-17'>
 | 
			
		||||
@ -282,7 +284,7 @@ this give logits  of the the next token</p>
 | 
			
		||||
                
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">84</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">86</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-18'>
 | 
			
		||||
@ -293,7 +295,7 @@ this give logits  of the the next token</p>
 | 
			
		||||
                <p>Create experiment</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">86</span>    <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"glu_variants"</span><span class="p">)</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">88</span>    <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"glu_variants"</span><span class="p">)</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-19'>
 | 
			
		||||
@ -304,7 +306,7 @@ this give logits  of the the next token</p>
 | 
			
		||||
                <p>Create configs</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">88</span>    <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">90</span>    <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-20'>
 | 
			
		||||
@ -315,7 +317,7 @@ this give logits  of the the next token</p>
 | 
			
		||||
                <p>Load configurations</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">90</span>    <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">92</span>    <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-21'>
 | 
			
		||||
@ -326,19 +328,19 @@ this give logits  of the the next token</p>
 | 
			
		||||
                <p>A dictionary of configurations to override</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">92</span>                       <span class="p">{</span><span class="s1">'tokenizer'</span><span class="p">:</span> <span class="s1">'character'</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">93</span>                        <span class="s1">'prompt_separator'</span><span class="p">:</span> <span class="s1">''</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">94</span>                        <span class="s1">'prompt'</span><span class="p">:</span> <span class="s1">'It is '</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">95</span>                        <span class="s1">'text'</span><span class="p">:</span> <span class="s1">'tiny_shakespeare'</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">96</span>
 | 
			
		||||
<span class="lineno">97</span>                        <span class="s1">'optimizer.optimizer'</span><span class="p">:</span> <span class="s1">'Noam'</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">98</span>                        <span class="s1">'optimizer.learning_rate'</span><span class="p">:</span> <span class="mf">1.</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">99</span>                        <span class="s1">'optimizer.d_model'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">100</span>
 | 
			
		||||
<span class="lineno">101</span>                        <span class="s1">'seq_len'</span><span class="p">:</span> <span class="mi">1024</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">102</span>                        <span class="s1">'epochs'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">103</span>                        <span class="s1">'batch_size'</span><span class="p">:</span> <span class="mi">6</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">104</span>                        <span class="s1">'inner_iterations'</span><span class="p">:</span> <span class="mi">10</span><span class="p">,</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">94</span>                       <span class="p">{</span><span class="s1">'tokenizer'</span><span class="p">:</span> <span class="s1">'character'</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">95</span>                        <span class="s1">'prompt_separator'</span><span class="p">:</span> <span class="s1">''</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">96</span>                        <span class="s1">'prompt'</span><span class="p">:</span> <span class="s1">'It is '</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">97</span>                        <span class="s1">'text'</span><span class="p">:</span> <span class="s1">'tiny_shakespeare'</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">98</span>
 | 
			
		||||
<span class="lineno">99</span>                        <span class="s1">'optimizer.optimizer'</span><span class="p">:</span> <span class="s1">'Noam'</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">100</span>                        <span class="s1">'optimizer.learning_rate'</span><span class="p">:</span> <span class="mf">1.</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">101</span>                        <span class="s1">'optimizer.d_model'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">102</span>
 | 
			
		||||
<span class="lineno">103</span>                        <span class="s1">'seq_len'</span><span class="p">:</span> <span class="mi">1024</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">104</span>                        <span class="s1">'epochs'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">105</span>                        <span class="s1">'batch_size'</span><span class="p">:</span> <span class="mi">6</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">106</span>                        <span class="s1">'inner_iterations'</span><span class="p">:</span> <span class="mi">10</span><span class="p">,</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-22'>
 | 
			
		||||
@ -347,9 +349,11 @@ this give logits  of the the next token</p>
 | 
			
		||||
                    <a href='#section-22'>#</a>
 | 
			
		||||
                </div>
 | 
			
		||||
                <p>GLU Variant, one of GLU, Bilinear, ReGLU, GEGLU, SwiGLU</p>
 | 
			
		||||
<p>These are defined in the <a href="../configs.html#FFN">configurable FFN</a>
 | 
			
		||||
implementation</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">107</span>                        <span class="s1">'transformer.ffn.glu_variant'</span><span class="p">:</span> <span class="s1">'Bilinear'</span><span class="p">,</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">112</span>                        <span class="s1">'transformer.ffn.glu_variant'</span><span class="p">:</span> <span class="s1">'Bilinear'</span><span class="p">,</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-23'>
 | 
			
		||||
@ -360,10 +364,10 @@ this give logits  of the the next token</p>
 | 
			
		||||
                <p>Transformer configurations</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">110</span>                        <span class="s1">'transformer.d_model'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">111</span>                        <span class="s1">'transformer.ffn.d_ff'</span><span class="p">:</span> <span class="mi">1024</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">112</span>                        <span class="s1">'transformer.n_heads'</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">113</span>                        <span class="s1">'transformer.n_layers'</span><span class="p">:</span> <span class="mi">6</span><span class="p">})</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">115</span>                        <span class="s1">'transformer.d_model'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">116</span>                        <span class="s1">'transformer.ffn.d_ff'</span><span class="p">:</span> <span class="mi">1024</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">117</span>                        <span class="s1">'transformer.n_heads'</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span>
 | 
			
		||||
<span class="lineno">118</span>                        <span class="s1">'transformer.n_layers'</span><span class="p">:</span> <span class="mi">6</span><span class="p">})</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-24'>
 | 
			
		||||
@ -374,7 +378,7 @@ this give logits  of the the next token</p>
 | 
			
		||||
                <p>This is needed to initialize models</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">116</span>    <span class="n">conf</span><span class="o">.</span><span class="n">n_tokens</span> <span class="o">=</span> <span class="n">conf</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">n_tokens</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">121</span>    <span class="n">conf</span><span class="o">.</span><span class="n">n_tokens</span> <span class="o">=</span> <span class="n">conf</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">n_tokens</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-25'>
 | 
			
		||||
@ -385,7 +389,7 @@ this give logits  of the the next token</p>
 | 
			
		||||
                <p>Set models for saving and loading</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">119</span>    <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">(</span><span class="n">get_modules</span><span class="p">(</span><span class="n">conf</span><span class="p">))</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">124</span>    <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">(</span><span class="n">get_modules</span><span class="p">(</span><span class="n">conf</span><span class="p">))</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-26'>
 | 
			
		||||
@ -396,7 +400,7 @@ this give logits  of the the next token</p>
 | 
			
		||||
                <p>Start the experiment</p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">122</span>    <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">127</span>    <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    <div class='section' id='section-27'>
 | 
			
		||||
@ -407,11 +411,11 @@ this give logits  of the the next token</p>
 | 
			
		||||
                <p><code>TrainValidConfigs.run</code></p>
 | 
			
		||||
            </div>
 | 
			
		||||
            <div class='code'>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">124</span>        <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span>
 | 
			
		||||
<span class="lineno">125</span>
 | 
			
		||||
<span class="lineno">126</span>
 | 
			
		||||
<span class="lineno">127</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
 | 
			
		||||
<span class="lineno">128</span>    <span class="n">main</span><span class="p">()</span></pre></div>
 | 
			
		||||
                <div class="highlight"><pre><span class="lineno">129</span>        <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span>
 | 
			
		||||
<span class="lineno">130</span>
 | 
			
		||||
<span class="lineno">131</span>
 | 
			
		||||
<span class="lineno">132</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
 | 
			
		||||
<span class="lineno">133</span>    <span class="n">main</span><span class="p">()</span></pre></div>
 | 
			
		||||
            </div>
 | 
			
		||||
        </div>
 | 
			
		||||
    </div>
 | 
			
		||||
 | 
			
		||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -19,6 +19,14 @@ from .models import EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPosit
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FeedForwardConfigs(BaseConfigs):
 | 
			
		||||
    """
 | 
			
		||||
    <a id="FFN">
 | 
			
		||||
    ## FFN Configurations
 | 
			
		||||
    </a>
 | 
			
		||||
 | 
			
		||||
    Creates a Position-wise FeedForward Network defined in
 | 
			
		||||
    [`feed_forward.py`](feed_forward.html).
 | 
			
		||||
    """
 | 
			
		||||
    # Position-wise feedforward layer
 | 
			
		||||
    ffn: FeedForward
 | 
			
		||||
    # Number of features in the embedding
 | 
			
		||||
@ -44,7 +52,9 @@ class FeedForwardConfigs(BaseConfigs):
 | 
			
		||||
@option(FeedForwardConfigs.activation, 'ReLU')
 | 
			
		||||
def _ffn_activation_relu():
 | 
			
		||||
    """
 | 
			
		||||
    ReLU activation
 | 
			
		||||
    ### ReLU activation
 | 
			
		||||
 | 
			
		||||
    $$\max(0, x)$$
 | 
			
		||||
    """
 | 
			
		||||
    return nn.ReLU()
 | 
			
		||||
 | 
			
		||||
@ -52,7 +62,11 @@ def _ffn_activation_relu():
 | 
			
		||||
@option(FeedForwardConfigs.activation, 'GELU')
 | 
			
		||||
def _ffn_activation_gelu():
 | 
			
		||||
    """
 | 
			
		||||
    GELU activation
 | 
			
		||||
    ### GELU activation
 | 
			
		||||
 | 
			
		||||
    $$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$
 | 
			
		||||
 | 
			
		||||
    It was introduced in paper [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415).
 | 
			
		||||
    """
 | 
			
		||||
    return nn.GELU()
 | 
			
		||||
 | 
			
		||||
@ -60,7 +74,7 @@ def _ffn_activation_gelu():
 | 
			
		||||
@option(FeedForwardConfigs.ffn, 'default')
 | 
			
		||||
def _feed_forward(c: FeedForwardConfigs):
 | 
			
		||||
    """
 | 
			
		||||
    Create feedforward layer
 | 
			
		||||
    Initialize a [feed forward network](feed_forward.html)
 | 
			
		||||
    """
 | 
			
		||||
    return FeedForward(c.d_model, c.d_ff,
 | 
			
		||||
                       dropout=c.dropout,
 | 
			
		||||
@ -70,7 +84,14 @@ def _feed_forward(c: FeedForwardConfigs):
 | 
			
		||||
                       bias2=c.bias2,
 | 
			
		||||
                       bias_gate=c.bias_gate)
 | 
			
		||||
 | 
			
		||||
# ## GLU Variants
 | 
			
		||||
# These are variants with gated hidden layers for the FFN
 | 
			
		||||
# as introduced in paper [GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202).
 | 
			
		||||
# We have omitted the bias terms as specified in the paper.
 | 
			
		||||
 | 
			
		||||
# ### FFN with Gated Linear Units
 | 
			
		||||
#
 | 
			
		||||
# $$FFN_{GLU}(x)(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$$
 | 
			
		||||
aggregate(FeedForwardConfigs.glu_variant, 'GLU',
 | 
			
		||||
          (FeedForwardConfigs.is_gated, True),
 | 
			
		||||
          (FeedForwardConfigs.bias1, False),
 | 
			
		||||
@ -78,24 +99,40 @@ aggregate(FeedForwardConfigs.glu_variant, 'GLU',
 | 
			
		||||
          (FeedForwardConfigs.bias_gate, False),
 | 
			
		||||
          (FeedForwardConfigs.activation, nn.Sigmoid()))
 | 
			
		||||
 | 
			
		||||
# ### FFN with Bilinear hidden layer
 | 
			
		||||
#
 | 
			
		||||
# $$FFN_{Bilinear}(x)(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$$
 | 
			
		||||
aggregate(FeedForwardConfigs.glu_variant, 'Bilinear',
 | 
			
		||||
          (FeedForwardConfigs.is_gated, True),
 | 
			
		||||
          (FeedForwardConfigs.bias1, False),
 | 
			
		||||
          (FeedForwardConfigs.bias2, False),
 | 
			
		||||
          (FeedForwardConfigs.bias_gate, False),
 | 
			
		||||
          (FeedForwardConfigs.activation, nn.Identity()))
 | 
			
		||||
 | 
			
		||||
# ### FFN with ReLU gate
 | 
			
		||||
#
 | 
			
		||||
# $$FFN_{ReGLU}(x)(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$$
 | 
			
		||||
aggregate(FeedForwardConfigs.glu_variant, 'ReGLU',
 | 
			
		||||
          (FeedForwardConfigs.is_gated, True),
 | 
			
		||||
          (FeedForwardConfigs.bias1, False),
 | 
			
		||||
          (FeedForwardConfigs.bias2, False),
 | 
			
		||||
          (FeedForwardConfigs.bias_gate, False),
 | 
			
		||||
          (FeedForwardConfigs.activation, nn.ReLU()))
 | 
			
		||||
 | 
			
		||||
# ### FFN with GELU gate
 | 
			
		||||
#
 | 
			
		||||
# $$FFN_{GEGLU}(x)(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$$
 | 
			
		||||
aggregate(FeedForwardConfigs.glu_variant, 'GEGLU',
 | 
			
		||||
          (FeedForwardConfigs.is_gated, True),
 | 
			
		||||
          (FeedForwardConfigs.bias1, False),
 | 
			
		||||
          (FeedForwardConfigs.bias2, False),
 | 
			
		||||
          (FeedForwardConfigs.bias_gate, False),
 | 
			
		||||
          (FeedForwardConfigs.activation, nn.GELU()))
 | 
			
		||||
 | 
			
		||||
# ### FFN with Swish gate
 | 
			
		||||
#
 | 
			
		||||
# $$FFN_{SwiGLU}(x)(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$$
 | 
			
		||||
# where $\text{Swish}_\beta(x) = x \sigma(\beta x)$
 | 
			
		||||
aggregate(FeedForwardConfigs.glu_variant, 'SwiGLU',
 | 
			
		||||
          (FeedForwardConfigs.is_gated, True),
 | 
			
		||||
          (FeedForwardConfigs.bias1, False),
 | 
			
		||||
@ -236,7 +273,7 @@ def _generator(c: TransformerConfigs):
 | 
			
		||||
    return Generator(c.n_tgt_vocab, c.d_model)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# ## Positional Embeddings
 | 
			
		||||
# ### Fixed Positional Embeddings
 | 
			
		||||
@option(TransformerConfigs.src_embed, 'fixed_pos')
 | 
			
		||||
def _src_embed_with_positional(c: TransformerConfigs):
 | 
			
		||||
    """
 | 
			
		||||
@ -253,7 +290,7 @@ def _tgt_embed_with_positional(c: TransformerConfigs):
 | 
			
		||||
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# ## Learned Positional Embeddings
 | 
			
		||||
# ### Learned Positional Embeddings
 | 
			
		||||
@option(TransformerConfigs.src_embed, 'learned_pos')
 | 
			
		||||
def _src_embed_with_learned_positional(c: TransformerConfigs):
 | 
			
		||||
    """
 | 
			
		||||
@ -270,7 +307,7 @@ def _tgt_embed_with_learned_positional(c: TransformerConfigs):
 | 
			
		||||
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# ## No Positional Embeddings
 | 
			
		||||
# ### No Positional Embeddings
 | 
			
		||||
@option(TransformerConfigs.src_embed, 'no_pos')
 | 
			
		||||
def _src_embed_without_positional(c: TransformerConfigs):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
@ -21,6 +21,15 @@ where $W_1$, $W_2$, $b_1$ and $b_2$ are learnable parameters.
 | 
			
		||||
Sometimes the
 | 
			
		||||
GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU.
 | 
			
		||||
$$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$
 | 
			
		||||
 | 
			
		||||
### Gated Linear Units
 | 
			
		||||
 | 
			
		||||
This is a generic implementation that supports different variants including
 | 
			
		||||
[Gated Linear Units](https://arxiv.org/abs/2002.05202) (GLU).
 | 
			
		||||
We have also implemented experiments on these:
 | 
			
		||||
 | 
			
		||||
* [experiment that uses `labml.configs`](glu_variants/experiment.html)
 | 
			
		||||
* [simpler version from scratch](glu_variants/simple.html)
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
@ -31,7 +40,7 @@ from labml_helpers.module import Module
 | 
			
		||||
 | 
			
		||||
class FeedForward(Module):
 | 
			
		||||
    """
 | 
			
		||||
    ## Position-wise feed-forward network (FFN) module
 | 
			
		||||
    ## FFN module
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, d_model: int, d_ff: int,
 | 
			
		||||
@ -51,19 +60,32 @@ class FeedForward(Module):
 | 
			
		||||
        * `bias_gate` specified whether the fully connected layer for the gate should have a learnable bias
 | 
			
		||||
        """
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        # Layer one parameterized by weight $W_1$ and bias $b_1$
 | 
			
		||||
        self.layer1 = nn.Linear(d_model, d_ff, bias=bias1)
 | 
			
		||||
        # Layer one parameterized by weight $W_1$ and bias $b_1$
 | 
			
		||||
        self.layer2 = nn.Linear(d_ff, d_model, bias=bias2)
 | 
			
		||||
        # Hidden layer dropout
 | 
			
		||||
        self.dropout = nn.Dropout(dropout)
 | 
			
		||||
        # Activation function $f$
 | 
			
		||||
        self.activation = activation
 | 
			
		||||
        # Whether there is a gate
 | 
			
		||||
        self.is_gated = is_gated
 | 
			
		||||
        if is_gated:
 | 
			
		||||
            # If there is a gate the linear layer to transform inputs to
 | 
			
		||||
            # be multiplied by the gate, parameterized by weight $V$ and bias $c$
 | 
			
		||||
            self.linear_v = nn.Linear(d_model, d_ff, bias=bias_gate)
 | 
			
		||||
 | 
			
		||||
    def __call__(self, x: torch.Tensor):
 | 
			
		||||
        # $f(x W_1 + b_1)$
 | 
			
		||||
        g = self.activation(self.layer1(x))
 | 
			
		||||
        # If gated, $f(x W_1 + b_1) \otimes (x V + b) $
 | 
			
		||||
        if self.is_gated:
 | 
			
		||||
            x = g * self.linear_v(x)
 | 
			
		||||
        # Otherwise
 | 
			
		||||
        else:
 | 
			
		||||
            x = g
 | 
			
		||||
        # Apply dropout
 | 
			
		||||
        x = self.dropout(x)
 | 
			
		||||
        # $(f(x W_1 + b_1) \otimes (x V + b)) W_2 + b_2$ or $f(x W_1 + b_1) W_2 + b_2$
 | 
			
		||||
        # depending on whether it is gated
 | 
			
		||||
        return self.layer2(x)
 | 
			
		||||
 | 
			
		||||
@ -6,9 +6,11 @@ summary: >
 | 
			
		||||
  for the position-wise feedforward network (FFN).
 | 
			
		||||
---
 | 
			
		||||
 | 
			
		||||
# Train Autoregressive Transformer
 | 
			
		||||
# Gated Linear Units and Variants
 | 
			
		||||
 | 
			
		||||
This trains a simple [transformer](../../) model for auto-regression.
 | 
			
		||||
We try different variants for the [position-wise feedforward network](../feed_forward).
 | 
			
		||||
The reusable & configurable are defined in [`configs.py`](configs.html).
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
@ -72,7 +74,7 @@ def autoregressive_model(c: Configs):
 | 
			
		||||
@option(Configs.transformer)
 | 
			
		||||
def transformer_c(c: Configs):
 | 
			
		||||
    """
 | 
			
		||||
    Initialize the configurable transformer encoder for our autoregressive model
 | 
			
		||||
    Initialize the [configurable transformer](../configs.html) encoder for our autoregressive model.
 | 
			
		||||
    """
 | 
			
		||||
    tc = TransformerConfigs()
 | 
			
		||||
    tc.n_src_vocab = c.n_tokens
 | 
			
		||||
@ -104,6 +106,9 @@ def main():
 | 
			
		||||
                        'inner_iterations': 10,
 | 
			
		||||
 | 
			
		||||
                        # GLU Variant, one of GLU, Bilinear, ReGLU, GEGLU, SwiGLU
 | 
			
		||||
                        #
 | 
			
		||||
                        # These are defined in the [configurable FFN](../configs.html#FFN)
 | 
			
		||||
                        # implementation
 | 
			
		||||
                        'transformer.ffn.glu_variant': 'Bilinear',
 | 
			
		||||
 | 
			
		||||
                        # Transformer configurations
 | 
			
		||||
 | 
			
		||||
@ -6,9 +6,13 @@ summary: >
 | 
			
		||||
  for the position-wise feedforward network (FFN).
 | 
			
		||||
---
 | 
			
		||||
 | 
			
		||||
# Train Autoregressive Transformer
 | 
			
		||||
# Gated Linear Units and Variants
 | 
			
		||||
 | 
			
		||||
This trains a simple [transformer](../../) model for auto-regression.
 | 
			
		||||
We try different variants for the [position-wise feedforward network](../feed_forward).
 | 
			
		||||
 | 
			
		||||
*This is a simpler implementation that doesn't use [`labml.configs`](experiment.html) module.
 | 
			
		||||
We decided to write a simpler implementation to make it easier readers who are not familiar.*
 | 
			
		||||
"""
 | 
			
		||||
import dataclasses
 | 
			
		||||
 | 
			
		||||
@ -56,6 +60,9 @@ class AutoregressiveModel(nn.Module):
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
class Configs:
 | 
			
		||||
    """
 | 
			
		||||
    ### Configurations
 | 
			
		||||
    """
 | 
			
		||||
    d_model: int = 512
 | 
			
		||||
    seq_len: int = 128
 | 
			
		||||
    batch_size: int = 32
 | 
			
		||||
@ -69,71 +76,130 @@ class Configs:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TinyShakespeareDataset(Dataset):
 | 
			
		||||
    """
 | 
			
		||||
    ### Tiny Shakespeare Dataset
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, seq_len: int):
 | 
			
		||||
        # Location of the text file
 | 
			
		||||
        path = lab.get_data_path() / 'tiny_shakespeare.txt'
 | 
			
		||||
        # Download the file
 | 
			
		||||
        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
 | 
			
		||||
        # Read the downloaded file
 | 
			
		||||
        with open(str(path), 'r') as f:
 | 
			
		||||
            text = f.read()
 | 
			
		||||
 | 
			
		||||
        # Extract the characters
 | 
			
		||||
        chars = list(set(text))
 | 
			
		||||
        # Character to id (integer) map
 | 
			
		||||
        self.stoi = {c: i for i, c in enumerate(chars)}
 | 
			
		||||
        # Id to character map
 | 
			
		||||
        self.itos = {i: c for i, c in enumerate(chars)}
 | 
			
		||||
        # Length of a training sample
 | 
			
		||||
        self.seq_len = seq_len
 | 
			
		||||
        # Data in the form of a tensor of ids
 | 
			
		||||
        self.data = self.text_to_i(text)
 | 
			
		||||
 | 
			
		||||
    def text_to_i(self, text: str):
 | 
			
		||||
        """
 | 
			
		||||
        Transform the text into a tensor of ids
 | 
			
		||||
        """
 | 
			
		||||
        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)
 | 
			
		||||
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        """
 | 
			
		||||
        Number of samples in the dataset.
 | 
			
		||||
 | 
			
		||||
        *This will read the dataset `seq_len` times in a single epoch.*
 | 
			
		||||
        """
 | 
			
		||||
        return len(self.data) - self.seq_len - 1
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, idx):
 | 
			
		||||
        """
 | 
			
		||||
        Return a sample
 | 
			
		||||
        """
 | 
			
		||||
        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Trainer:
 | 
			
		||||
    """
 | 
			
		||||
    ## Trainer
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, configs: Configs):
 | 
			
		||||
        # Get the device
 | 
			
		||||
        self.device = torch.device('cpu')
 | 
			
		||||
        if torch.cuda.is_available():
 | 
			
		||||
            self.device = torch.device('cuda:0')
 | 
			
		||||
        # Initialize the dataset
 | 
			
		||||
        self.dataset = TinyShakespeareDataset(configs.seq_len)
 | 
			
		||||
        self.dataloader = DataLoader(self.dataset, batch_size=configs.batch_size, collate_fn=transpose_batch,
 | 
			
		||||
        # Initialize the dataloader
 | 
			
		||||
        self.dataloader = DataLoader(self.dataset,
 | 
			
		||||
                                     batch_size=configs.batch_size,
 | 
			
		||||
                                     collate_fn=transpose_batch,
 | 
			
		||||
                                     shuffle=True)
 | 
			
		||||
 | 
			
		||||
        # FFN with Gated Linear Unit
 | 
			
		||||
        # $$FFN_{GLU}(x)(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$$
 | 
			
		||||
        if configs.glu_variant == 'GLU':
 | 
			
		||||
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
 | 
			
		||||
        # FFN with Bilinear hidden layer
 | 
			
		||||
        # $$FFN_{Bilinear}(x)(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$$
 | 
			
		||||
        elif configs.glu_variant == 'Bilinear':
 | 
			
		||||
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
 | 
			
		||||
        # FFN with ReLU gate
 | 
			
		||||
        # $$FFN_{ReGLU}(x)(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$$
 | 
			
		||||
        elif configs.glu_variant == 'ReGLU':
 | 
			
		||||
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
 | 
			
		||||
        # FFN with GELU gate
 | 
			
		||||
        # $$FFN_{GEGLU}(x)(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$$
 | 
			
		||||
        elif configs.glu_variant == 'GEGLU':
 | 
			
		||||
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
 | 
			
		||||
        # FFN with Swish gate
 | 
			
		||||
        # $$FFN_{SwiGLU}(x)(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$$
 | 
			
		||||
        # where $\text{Swish}_\beta(x) = x \sigma(\beta x)$
 | 
			
		||||
        elif configs.glu_variant == 'SwiGLU':
 | 
			
		||||
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
 | 
			
		||||
        # FFN with ReLU activation
 | 
			
		||||
        # $$FFN_{ReLU}(x)(x, W_1, W_2, b_1, b_2) = \text{ReLU}_1(x W_1 + b_1) W_2 + b_2$$
 | 
			
		||||
        elif configs.glu_variant == 'ReLU':
 | 
			
		||||
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
 | 
			
		||||
        # FFN with ReLU activation
 | 
			
		||||
        # $$FFN_{GELU}(x)(x, W_1, W_2, b_1, b_2) = \text{GELU}_1(x W_1 + b_1) W_2 + b_2$$
 | 
			
		||||
        elif configs.glu_variant == 'GELU':
 | 
			
		||||
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError(f'Unknown variant {configs.glu_variant}')
 | 
			
		||||
 | 
			
		||||
        # Number of different characters
 | 
			
		||||
        n_chars = len(self.dataset.stoi)
 | 
			
		||||
 | 
			
		||||
        # Initialize [Multi-Head Attention module](../mha.html)
 | 
			
		||||
        mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)
 | 
			
		||||
        # Initialize the [Transformer Block](../models.html#TransformerLayer)
 | 
			
		||||
        transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
 | 
			
		||||
                                             feed_forward=ffn, dropout_prob=configs.dropout)
 | 
			
		||||
        # Initialize the model with an
 | 
			
		||||
        # [embedding layer](../models.html#EmbeddingsWithPositionalEncoding)
 | 
			
		||||
        # (with fixed positional encoding)
 | 
			
		||||
        # [transformer encoder](../models.html#Encoder) and
 | 
			
		||||
        # a linear layer to generate logits.
 | 
			
		||||
        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
 | 
			
		||||
                                         Encoder(TransformerLayer(
 | 
			
		||||
                                             d_model=configs.d_model,
 | 
			
		||||
                                             self_attn=MultiHeadAttention(configs.n_heads, configs.d_model,
 | 
			
		||||
                                                                          configs.dropout),
 | 
			
		||||
                                             src_attn=None,
 | 
			
		||||
                                             feed_forward=ffn,
 | 
			
		||||
                                             dropout_prob=configs.dropout
 | 
			
		||||
                                         ), configs.n_layers),
 | 
			
		||||
                                         Encoder(transformer_layer, configs.n_layers),
 | 
			
		||||
                                         nn.Linear(configs.d_model, n_chars))
 | 
			
		||||
 | 
			
		||||
        # Move the model to the current device
 | 
			
		||||
        self.model.to(self.device)
 | 
			
		||||
 | 
			
		||||
        # Initialize [Noam optimizer](../../optimizers/noam.html)
 | 
			
		||||
        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
 | 
			
		||||
 | 
			
		||||
        # Cross-entropy loss
 | 
			
		||||
        self.loss_func = nn.CrossEntropyLoss()
 | 
			
		||||
        # Number of training epochs;
 | 
			
		||||
        # *note that our dataset definition repeats the data `seq_len` times in a single epoch
 | 
			
		||||
        self.epochs = configs.epochs
 | 
			
		||||
        # Gradient clipping norm
 | 
			
		||||
        self.grad_norm_clip = configs.grad_norm_clip
 | 
			
		||||
 | 
			
		||||
        # Set tracker configurations
 | 
			
		||||
@ -166,18 +232,28 @@ class Trainer:
 | 
			
		||||
        logger.log(log)
 | 
			
		||||
 | 
			
		||||
    def train(self):
 | 
			
		||||
        """
 | 
			
		||||
        ### Train the model
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        # Loop for the given number of epochs
 | 
			
		||||
        for _ in monit.loop(self.epochs):
 | 
			
		||||
            # Iterate over the minibatches
 | 
			
		||||
            for i, batch in monit.enum('Train', self.dataloader):
 | 
			
		||||
                # Move data to the device
 | 
			
		||||
                data, target = batch[0].to(self.device), batch[1].to(self.device)
 | 
			
		||||
 | 
			
		||||
                # Set tracker step, as the number of characters trained on
 | 
			
		||||
                tracker.add_global_step(data.shape[0] * data.shape[1])
 | 
			
		||||
 | 
			
		||||
                # Set model state to training
 | 
			
		||||
                self.model.train()
 | 
			
		||||
                # Evaluate the model
 | 
			
		||||
                output = self.model(data)
 | 
			
		||||
 | 
			
		||||
                # Calculate and log loss
 | 
			
		||||
                # Calculate loss
 | 
			
		||||
                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
 | 
			
		||||
                # Log the loss
 | 
			
		||||
                tracker.add("loss.train", loss)
 | 
			
		||||
 | 
			
		||||
                # Calculate gradients
 | 
			
		||||
@ -186,12 +262,13 @@ class Trainer:
 | 
			
		||||
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
 | 
			
		||||
                # Take optimizer step
 | 
			
		||||
                self.optimizer.step()
 | 
			
		||||
                # Log the model parameters and gradients on last batch of every epoch
 | 
			
		||||
                # Log the model parameters and gradients
 | 
			
		||||
                if (i + 1) % 100 == 0:
 | 
			
		||||
                    tracker.add('model', self.model)
 | 
			
		||||
                # Clear the gradients
 | 
			
		||||
                self.optimizer.zero_grad()
 | 
			
		||||
 | 
			
		||||
                # Generate a sample
 | 
			
		||||
                if (i + 1) % 100 == 0:
 | 
			
		||||
                    self.model.eval()
 | 
			
		||||
                    with torch.no_grad():
 | 
			
		||||
@ -201,6 +278,7 @@ class Trainer:
 | 
			
		||||
                if (i + 1) % 10 == 0:
 | 
			
		||||
                    tracker.save()
 | 
			
		||||
 | 
			
		||||
            # Save the model
 | 
			
		||||
            experiment.save_checkpoint()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -212,12 +290,14 @@ def main():
 | 
			
		||||
    # Load configurations
 | 
			
		||||
    experiment.configs(dataclasses.asdict(configs))
 | 
			
		||||
 | 
			
		||||
    # Create trainer
 | 
			
		||||
    trainer = Trainer(configs)
 | 
			
		||||
    # Set models for training and loading
 | 
			
		||||
    experiment.add_pytorch_models({'model': trainer.model})
 | 
			
		||||
 | 
			
		||||
    # Start the experiment
 | 
			
		||||
    with experiment.start():
 | 
			
		||||
        # `TrainValidConfigs.run`
 | 
			
		||||
        # Train the model
 | 
			
		||||
        trainer.train()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user