transformer xl links

Varuna Jayasiri
2021-02-07 16:17:50 +05:30
parent ab01567b52
commit bbd442d4a9
9 changed files with 100 additions and 60 deletions

View File

@@ -84,7 +84,10 @@ implementations.</p>
 <ul>
 <li><a href="transformers/mha.html">Multi-headed attention</a></li>
 <li><a href="transformers/models.html">Transformer building blocks</a></li>
-<li><a href="transformers/xl/relative_mha.html">Relative multi-headed attention</a>.</li>
+<li><a href="transformers/xl/index.html">Transformer XL</a><ul>
+<li><a href="transformers/xl/relative_mha.html">Relative multi-headed attention</a></li>
+</ul>
+</li>
 <li><a href="transformers/gpt/index.html">GPT Architecture</a></li>
 <li><a href="transformers/glu_variants/simple.html">GLU Variants</a></li>
 <li><a href="transformers/knn/index.html">kNN-LM: Generalization through Memorization</a></li>

View File

@@ -426,6 +426,13 @@
 </url>
+<url>
+<loc>https://nn.labml.ai/transformers/xl/experiment.html</loc>
+<lastmod>2021-02-07T16:30:00+00:00</lastmod>
+<priority>1.00</priority>
+</url>
 <url>
 <loc>https://nn.labml.ai/transformers/xl/index.html</loc>
 <lastmod>2021-02-07T16:30:00+00:00</lastmod>

View File

@@ -78,10 +78,12 @@ from paper <a href="https://arxiv.org/abs/1706.03762">Attention Is All You Need<
 and derivatives and enhancements of it.</p>
 <ul>
 <li><a href="mha.html">Multi-head attention</a></li>
-<li><a href="xl/relative_mha.html">Relative multi-head attention</a></li>
 <li><a href="models.html">Transformer Encoder and Decoder Models</a></li>
 <li><a href="positional_encoding.html">Fixed positional encoding</a></li>
 </ul>
+<h2><a href="xl/index.html">Transformer XL</a></h2>
+<p>This implements Transformer XL model using
+<a href="xl/relative_mha.html">relative multi-head attention</a></p>
 <h2><a href="gpt">GPT Architecture</a></h2>
 <p>This is an implementation of GPT-2 architecture.</p>
 <h2><a href="glu_variants/simple.html">GLU Variants</a></h2>
@@ -100,10 +102,10 @@ Our implementation only has a few million parameters and doesn&rsquo;t do model
 It does single GPU training but we implement the concept of switching as described in the paper.</p>
 </div>
 <div class='code'>
-49  from .configs import TransformerConfigs
-50  from .models import TransformerLayer, Encoder, Decoder, Generator, EncoderDecoder
-51  from .mha import MultiHeadAttention
-52  from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
+52  from .configs import TransformerConfigs
+53  from .models import TransformerLayer, Encoder, Decoder, Generator, EncoderDecoder
+54  from .mha import MultiHeadAttention
+55  from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
 </div>
 </div>
 </div>

View File

@@ -93,15 +93,15 @@ are introduced at the attention calculation.</p>
 <a href="https://web.lab-ml.com/run?uuid=d3b6760c692e11ebb6a70242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
 </div>
 <div class='code'>
-37  from typing import List, Optional
-38
-39  import torch
-40  import torch.nn as nn
-41
-42  from labml_helpers.module import Module
-43  from labml_nn.utils import clone_module_list
-44  from .relative_mha import RelativeMultiHeadAttention
-45  from ..feed_forward import FeedForward
+36  from typing import List, Optional
+37
+38  import torch
+39  import torch.nn as nn
+40
+41  from labml_helpers.module import Module
+42  from labml_nn.utils import clone_module_list
+43  from .relative_mha import RelativeMultiHeadAttention
+44  from ..feed_forward import FeedForward
 </div>
 </div>
 <div class='section' id='section-1'>
@@ -113,7 +113,7 @@ are introduced at the attention calculation.</p>
 <p>The transformer XL model comprises of a number of these layers.</p>
 </div>
 <div class='code'>
-48  class TransformerXLLayer(Module):
+47  class TransformerXLLayer(Module):
 </div>
 </div>
 <div class='section' id='section-2'>
@@ -129,11 +129,11 @@ are introduced at the attention calculation.</p>
 </ul>
 </div>
 <div class='code'>
-54      def __init__(self, *,
-55                   d_model: int,
-56                   self_attn: RelativeMultiHeadAttention,
-57                   feed_forward: FeedForward,
-58                   dropout_prob: float):
+53      def __init__(self, *,
+54                   d_model: int,
+55                   self_attn: RelativeMultiHeadAttention,
+56                   feed_forward: FeedForward,
+57                   dropout_prob: float):
 </div>
 </div>
 <div class='section' id='section-3'>
@@ -144,13 +144,13 @@ are introduced at the attention calculation.</p>
 </div>
 <div class='code'>
-65          super().__init__()
-66          self.size = d_model
-67          self.self_attn = self_attn
-68          self.feed_forward = feed_forward
-69          self.dropout = nn.Dropout(dropout_prob)
-70          self.norm_self_attn = nn.LayerNorm([d_model])
-71          self.norm_ff = nn.LayerNorm([d_model])
+64          super().__init__()
+65          self.size = d_model
+66          self.self_attn = self_attn
+67          self.feed_forward = feed_forward
+68          self.dropout = nn.Dropout(dropout_prob)
+69          self.norm_self_attn = nn.LayerNorm([d_model])
+70          self.norm_ff = nn.LayerNorm([d_model])
 </div>
 </div>
 <div class='section' id='section-4'>
@@ -166,10 +166,10 @@ are introduced at the attention calculation.</p>
 </ul>
 </div>
 <div class='code'>
-73      def forward(self, *,
-74                  x: torch.Tensor,
-75                  mem: Optional[torch.Tensor],
-76                  mask: torch.Tensor):
+72      def forward(self, *,
+73                  x: torch.Tensor,
+74                  mem: Optional[torch.Tensor],
+75                  mask: torch.Tensor):
 </div>
 </div>
 <div class='section' id='section-5'>
@@ -180,7 +180,7 @@ are introduced at the attention calculation.</p>
 <p>Normalize the vectors before doing self attention</p>
 </div>
 <div class='code'>
-84          z = self.norm_self_attn(x)
+83          z = self.norm_self_attn(x)
 </div>
 </div>
 <div class='section' id='section-6'>
@@ -191,7 +191,7 @@ are introduced at the attention calculation.</p>
 <p>If there is memory</p>
 </div>
 <div class='code'>
-86          if mem is not None:
+85          if mem is not None:
 </div>
 </div>
 <div class='section' id='section-7'>
@@ -202,7 +202,7 @@ are introduced at the attention calculation.</p>
 <p>Normalize it</p>
 </div>
 <div class='code'>
-88              mem = self.norm_self_attn(mem)
+87              mem = self.norm_self_attn(mem)
 </div>
 </div>
 <div class='section' id='section-8'>
@@ -213,7 +213,7 @@ are introduced at the attention calculation.</p>
 <p>Concatenate with <code>z</code></p>
 </div>
 <div class='code'>
-90              m_z = torch.cat((mem, z), dim=0)
+89              m_z = torch.cat((mem, z), dim=0)
 </div>
 </div>
 <div class='section' id='section-9'>
@@ -224,8 +224,8 @@ are introduced at the attention calculation.</p>
 <p>Ignore if there is no memory</p>
 </div>
 <div class='code'>
-92          else:
-93              m_z = z
+91          else:
+92              m_z = z
 </div>
 </div>
 <div class='section' id='section-10'>
@@ -236,7 +236,7 @@ are introduced at the attention calculation.</p>
 <p>Attention</p>
 </div>
 <div class='code'>
-95          self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
+94          self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
 </div>
 </div>
 <div class='section' id='section-11'>
@@ -247,7 +247,7 @@ are introduced at the attention calculation.</p>
 <p>Add the attention results</p>
 </div>
 <div class='code'>
-97          x = x + self.dropout(self_attn)
+96          x = x + self.dropout(self_attn)
 </div>
 </div>
 <div class='section' id='section-12'>
@@ -258,7 +258,7 @@ are introduced at the attention calculation.</p>
 <p>Normalize for feed-forward</p>
 </div>
 <div class='code'>
-100         z = self.norm_ff(x)
+99          z = self.norm_ff(x)
 </div>
 </div>
 <div class='section' id='section-13'>
@@ -269,7 +269,7 @@ are introduced at the attention calculation.</p>
 <p>Pass through the feed-forward network</p>
 </div>
 <div class='code'>
-102         ff = self.feed_forward(z)
+101         ff = self.feed_forward(z)
 </div>
 </div>
 <div class='section' id='section-14'>
@@ -280,7 +280,7 @@ are introduced at the attention calculation.</p>
 <p>Add the feed-forward results back</p>
 </div>
 <div class='code'>
-104         x = x + self.dropout(ff)
+103         x = x + self.dropout(ff)
 </div>
 </div>
 <div class='section' id='section-15'>
@@ -291,7 +291,7 @@ are introduced at the attention calculation.</p>
 </div>
 <div class='code'>
-107         return x
+106         return x
 </div>
 </div>
 <div class='section' id='section-16'>
@@ -303,7 +303,7 @@ are introduced at the attention calculation.</p>
 <p>This consists of multiple transformer XL layers</p>
 </div>
 <div class='code'>
-110 class TransformerXL(Module):
+109 class TransformerXL(Module):
 </div>
 </div>
 <div class='section' id='section-17'>
@@ -314,8 +314,8 @@ are introduced at the attention calculation.</p>
 </div>
 <div class='code'>
-117     def __init__(self, layer: TransformerXLLayer, n_layers: int):
-118         super().__init__()
+116     def __init__(self, layer: TransformerXLLayer, n_layers: int):
+117         super().__init__()
 </div>
 </div>
 <div class='section' id='section-18'>
@@ -326,7 +326,7 @@ are introduced at the attention calculation.</p>
 <p>Make copies of the transformer layer</p>
 </div>
 <div class='code'>
-120         self.layers = clone_module_list(layer, n_layers)
+119         self.layers = clone_module_list(layer, n_layers)
 </div>
 </div>
 <div class='section' id='section-19'>
@@ -337,7 +337,7 @@ are introduced at the attention calculation.</p>
 <p>Final normalization layer</p>
 </div>
 <div class='code'>
-122         self.norm = nn.LayerNorm([layer.size])
+121         self.norm = nn.LayerNorm([layer.size])
 </div>
 </div>
 <div class='section' id='section-20'>
@@ -352,7 +352,7 @@ are introduced at the attention calculation.</p>
 </ul>
 </div>
 <div class='code'>
-124     def forward(self, x: torch.Tensor, mem: List[torch.Tensor], mask: torch.Tensor):
+123     def forward(self, x: torch.Tensor, mem: List[torch.Tensor], mask: torch.Tensor):
 </div>
 </div>
 <div class='section' id='section-21'>
@@ -364,7 +364,7 @@ are introduced at the attention calculation.</p>
 which will be the memories for the next sequential batch.</p>
 </div>
 <div class='code'>
-132         new_mem = []
+131         new_mem = []
 </div>
 </div>
 <div class='section' id='section-22'>
@@ -375,7 +375,7 @@ which will be the memories for the next sequential batch.</p>
 <p>Run through each transformer layer</p>
 </div>
 <div class='code'>
-134         for i, layer in enumerate(self.layers):
+133         for i, layer in enumerate(self.layers):
 </div>
 </div>
 <div class='section' id='section-23'>
@@ -386,7 +386,7 @@ which will be the memories for the next sequential batch.</p>
 <p>Add to the list of feature vectors</p>
 </div>
 <div class='code'>
-136             new_mem.append(x.detach())
+135             new_mem.append(x.detach())
 </div>
 </div>
 <div class='section' id='section-24'>
@@ -397,7 +397,7 @@ which will be the memories for the next sequential batch.</p>
 <p>Memory</p>
 </div>
 <div class='code'>
-138             m = mem[i] if mem else None
+137             m = mem[i] if mem else None
 </div>
 </div>
 <div class='section' id='section-25'>
@@ -408,7 +408,7 @@ which will be the memories for the next sequential batch.</p>
 <p>Run through the transformer XL layer</p>
 </div>
 <div class='code'>
-140             x = layer(x=x, mem=m, mask=mask)
+139             x = layer(x=x, mem=m, mask=mask)
 </div>
 </div>
 <div class='section' id='section-26'>
@@ -419,7 +419,7 @@ which will be the memories for the next sequential batch.</p>
 <p>Finally, normalize the vectors</p>
 </div>
 <div class='code'>
-142         return self.norm(x), new_mem
+141         return self.norm(x), new_mem
 </div>
 </div>
 </div>
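
The listing above only renumbers the displayed source lines of the annotated `TransformerXLLayer` / `TransformerXL` code. As a rough, self-contained sketch of the memory mechanism that code implements, the following uses plain `nn.MultiheadAttention` in place of the repository's `RelativeMultiHeadAttention` and omits the attention mask, so the names below (`MemoryAttentionLayer` and its driver loop) are illustrative only and not part of this commit:

```python
from typing import Optional

import torch
import torch.nn as nn


class MemoryAttentionLayer(nn.Module):
    """Illustrative stand-in for TransformerXLLayer (attention part only)."""

    def __init__(self, d_model: int, heads: int):
        super().__init__()
        self.norm_self_attn = nn.LayerNorm(d_model)
        # Plain multi-head attention; the real layer uses relative multi-head attention.
        self.attn = nn.MultiheadAttention(d_model, heads)

    def forward(self, x: torch.Tensor, mem: Optional[torch.Tensor]) -> torch.Tensor:
        # Normalize the current embeddings before self attention
        z = self.norm_self_attn(x)
        # If there is memory, normalize it and concatenate with `z` along the sequence axis
        m_z = torch.cat((self.norm_self_attn(mem), z), dim=0) if mem is not None else z
        # Queries come from the current step; keys/values also cover the cached past
        attn_out, _ = self.attn(z, m_z, m_z, need_weights=False)
        # Residual connection, as in the diffed code
        return x + attn_out


# Drive the layer over sequential batches, carrying memories forward
seq_len, batch_size, d_model = 8, 4, 64
layer = MemoryAttentionLayer(d_model, heads=4)
mem: Optional[torch.Tensor] = None
for _ in range(3):
    x = torch.randn(seq_len, batch_size, d_model)  # [seq_len, batch, d_model]
    out = layer(x, mem)
    mem = x.detach()  # cache the layer inputs (no gradients) as memory for the next batch
```

As in the diffed code, the memory is normalized and concatenated with the current embeddings to form the keys and values, and the layer inputs are detached and cached as the memories for the next sequential batch.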

View File

@@ -17,7 +17,8 @@ implementations.
 * [Multi-headed attention](transformers/mha.html)
 * [Transformer building blocks](transformers/models.html)
-* [Relative multi-headed attention](transformers/xl/relative_mha.html).
+* [Transformer XL](transformers/xl/index.html)
+  * [Relative multi-headed attention](transformers/xl/relative_mha.html)
 * [GPT Architecture](transformers/gpt/index.html)
 * [GLU Variants](transformers/glu_variants/simple.html)
 * [kNN-LM: Generalization through Memorization](transformers/knn/index.html)

View File

@@ -14,10 +14,13 @@ from paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762),
 and derivatives and enhancements of it.
 * [Multi-head attention](mha.html)
-* [Relative multi-head attention](xl/relative_mha.html)
 * [Transformer Encoder and Decoder Models](models.html)
 * [Fixed positional encoding](positional_encoding.html)
+## [Transformer XL](xl/index.html)
+This implements Transformer XL model using
+[relative multi-head attention](xl/relative_mha.html)
 ## [GPT Architecture](gpt)
 This is an implementation of GPT-2 architecture.

View File

@@ -30,7 +30,6 @@ Here's [the training code](experiment.html) and a notebook for training a transf
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/xl/experiment.ipynb)
 [![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=d3b6760c692e11ebb6a70242ac1c0002)
 """

View File

@@ -0,0 +1,24 @@
# Transformer XL
This is an implementation of
[Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860)
in [PyTorch](https://pytorch.org).
Transformer has a limited attention span,
equal to the length of the sequence trained in parallel.
All these positions have a fixed positional encoding.
Transformer XL increases this attention span by letting
each of the positions pay attention to precalculated past embeddings.
For instance if the context length is $l$ it will keep the embeddings of
all layers for previous batch of length $l$ and feed them to current step.
If we use fixed-positional encodings these pre-calculated embeddings will have
the same positions as the current context.
They introduce relative positional encoding, where the positional encodings
are introduced at the attention calculation.
Annotated implementation of relative multi-headed attention is in [`relative_mha.py`](relative_mha.html).
Here's [the training code](experiment.html) and a notebook for training a transformer XL model on Tiny Shakespeare dataset.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/xl/experiment.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=d3b6760c692e11ebb6a70242ac1c0002)
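
The new readme above notes that the positional information is "introduced at the attention calculation". For reference, the relative attention score between query position $i$ and key position $j$ from the linked paper (background from Transformer-XL, not text added by this commit) is

$$A^{\mathrm{rel}}_{i,j} = \underbrace{E_{x_i}^\top W_q^\top W_{k,E}\, E_{x_j}}_{\text{content}} + \underbrace{E_{x_i}^\top W_q^\top W_{k,R}\, R_{i-j}}_{\text{content-position}} + \underbrace{u^\top W_{k,E}\, E_{x_j}}_{\text{global content bias}} + \underbrace{v^\top W_{k,R}\, R_{i-j}}_{\text{global position bias}}$$

where $E_{x_i}$ are token embeddings, $R_{i-j}$ is a sinusoidal encoding of the relative distance $i-j$, and $u$, $v$ are learned bias vectors. Replacing absolute positions with $R_{i-j}$ in this way is what the annotated [`relative_mha.py`](relative_mha.html) linked above implements.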

View File

@@ -23,7 +23,8 @@ implementations almost weekly.
 * [Multi-headed attention](https://nn.labml.ai/transformers/mha.html)
 * [Transformer building blocks](https://nn.labml.ai/transformers/models.html)
-* [Relative multi-headed attention](https://nn.labml.ai/transformers/xl/relative_mha.html).
+* [Transformer XL](https://nn.labml.ai/transformers/xl/index.html)
+  * [Relative multi-headed attention](https://nn.labml.ai/transformers/xl/relative_mha.html)
 * [GPT Architecture](https://nn.labml.ai/transformers/gpt/index.html)
 * [GLU Variants](https://nn.labml.ai/transformers/glu_variants/simple.html)
 * [kNN-LM: Generalization through Memorization](https://nn.labml.ai/transformers/knn)