Files
Varuna Jayasiri cf565bcc1d cleanup
2024-06-18 11:09:02 +05:30

752 lines
52 KiB
HTML

<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="These are PyTorch implementations of Transformer based encoder and decoder models, as well as other related modules."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Transformer Encoder and Decoder Models"/>
<meta name="twitter:description" content="These are PyTorch implementations of Transformer based encoder and decoder models, as well as other related modules."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/transformers/models.html"/>
<meta property="og:title" content="Transformer Encoder and Decoder Models"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Transformer Encoder and Decoder Models"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Transformer Encoder and Decoder Models"/>
<meta property="og:description" content="These are PyTorch implementations of Transformer based encoder and decoder models, as well as other related modules."/>
<title>Transformer Encoder and Decoder Models</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/transformers/models.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">transformers</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/models.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Transformer Encoder and Decoder Models</h1>
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/basic/autoregressive_experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">13</span><span></span><span class="kn">import</span> <span class="nn">math</span>
<span class="lineno">14</span>
<span class="lineno">15</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">16</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="lineno">17</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml_nn.utils</span> <span class="kn">import</span> <span class="n">clone_module_list</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">.feed_forward</span> <span class="kn">import</span> <span class="n">FeedForward</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">.mha</span> <span class="kn">import</span> <span class="n">MultiHeadAttention</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">.positional_encoding</span> <span class="kn">import</span> <span class="n">get_positional_encoding</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<p> <a id="EmbeddingsWithPositionalEncoding"></a></p>
<h2>Embed tokens and add <a href="positional_encoding.html">fixed positional encoding</a></h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">24</span><span class="k">class</span> <span class="nc">EmbeddingsWithPositionalEncoding</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">31</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">max_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5000</span><span class="p">):</span>
<span class="lineno">32</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">33</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">n_vocab</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
<span class="lineno">34</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_model</span> <span class="o">=</span> <span class="n">d_model</span>
<span class="lineno">35</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">&#39;positional_encodings&#39;</span><span class="p">,</span> <span class="n">get_positional_encoding</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">max_len</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">37</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">38</span> <span class="n">pe</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">False</span><span class="p">)</span>
<span class="lineno">39</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span> <span class="o">+</span> <span class="n">pe</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p> <a id="EmbeddingsWithLearnedPositionalEncoding"></a></p>
<h2>Embed tokens and add parameterized positional encodings</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">42</span><span class="k">class</span> <span class="nc">EmbeddingsWithLearnedPositionalEncoding</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">max_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5000</span><span class="p">):</span>
<span class="lineno">50</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">51</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">n_vocab</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
<span class="lineno">52</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_model</span> <span class="o">=</span> <span class="n">d_model</span>
<span class="lineno">53</span> <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">max_len</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">d_model</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">55</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">56</span> <span class="n">pe</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span>
<span class="lineno">57</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span> <span class="o">+</span> <span class="n">pe</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p> <a id="TransformerLayer"></a></p>
<h2>Transformer Layer</h2>
<p>This can act as an encoder layer or a decoder layer. We use pre-norm.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">60</span><span class="k">class</span> <span class="nc">TransformerLayer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
is the token embedding size </li>
<li><code class="highlight"><span></span><span class="n">self_attn</span></code>
is the self attention module </li>
<li><code class="highlight"><span></span><span class="n">src_attn</span></code>
is the source attention module (when this is used in a decoder) </li>
<li><code class="highlight"><span></span><span class="n">feed_forward</span></code>
is the feed forward module </li>
<li><code class="highlight"><span></span><span class="n">dropout_prob</span></code>
is the probability of dropping out after self attention and FFN</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">69</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">70</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">71</span> <span class="n">self_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span><span class="p">,</span>
<span class="lineno">72</span> <span class="n">src_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="lineno">73</span> <span class="n">feed_forward</span><span class="p">:</span> <span class="n">FeedForward</span><span class="p">,</span>
<span class="lineno">74</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">83</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">d_model</span>
<span class="lineno">84</span> <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span> <span class="o">=</span> <span class="n">self_attn</span>
<span class="lineno">85</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_attn</span> <span class="o">=</span> <span class="n">src_attn</span>
<span class="lineno">86</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span>
<span class="lineno">87</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span>
<span class="lineno">88</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
<span class="lineno">89</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_attn</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">90</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_src_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
<span class="lineno">91</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Whether to save input to the feed forward layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">93</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_save_ff_input</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">95</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">96</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">97</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">98</span> <span class="n">src</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="lineno">99</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Normalize the vectors before doing self attention </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">101</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Run through self attention, i.e. keys and values are from self </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">103</span> <span class="n">self_attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Add the self attention results </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">105</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">self_attn</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>If a source is provided, get results from attention to source. This is when you have a decoder layer that pays attention to encoder outputs </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span> <span class="k">if</span> <span class="n">src</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Normalize vectors </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">112</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_src_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Attention to source. i.e. keys and values are from source </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">114</span> <span class="n">attn_src</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">src</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">src</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">src_mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Add the source attention results </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">attn_src</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Normalize for feed-forward </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">119</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Save the input to the feed forward layer if specified </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">121</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_save_ff_input</span><span class="p">:</span>
<span class="lineno">122</span> <span class="bp">self</span><span class="o">.</span><span class="n">ff_input</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Pass through the feed-forward network </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="n">ff</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Add the feed-forward results back </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">126</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">ff</span><span class="p">)</span>
<span class="lineno">127</span>
<span class="lineno">128</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p> <a id="Encoder"></a></p>
<h2>Transformer Encoder</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">131</span><span class="k">class</span> <span class="nc">Encoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">138</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">TransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">139</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Make copies of the transformer layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">141</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Final normalization layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">143</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">145</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Run through each transformer layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">147</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
<span class="lineno">148</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Finally, normalize the vectors </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">150</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p> <a id="Decoder"></a></p>
<h2>Transformer Decoder</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">153</span><span class="k">class</span> <span class="nc">Decoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">160</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">TransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">161</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Make copies of the transformer layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">163</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Final normalization layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">165</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">167</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">memory</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Run through each transformer layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">169</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
<span class="lineno">170</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">tgt_mask</span><span class="p">,</span> <span class="n">src</span><span class="o">=</span><span class="n">memory</span><span class="p">,</span> <span class="n">src_mask</span><span class="o">=</span><span class="n">src_mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Finally, normalize the vectors </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">172</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p> <a id="Generator"></a></p>
<h2>Generator</h2>
<p>This predicts the tokens and gives the lof softmax of those. You don&#x27;t need this if you are using <code class="highlight"><span></span><span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span></code>
.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">175</span><span class="k">class</span> <span class="nc">Generator</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">185</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">186</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">187</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">189</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="lineno">190</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p> <a id="EncoderDecoder"></a></p>
<h2>Combined Encoder-Decoder</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">193</span><span class="k">class</span> <span class="nc">EncoderDecoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">200</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">decoder</span><span class="p">:</span> <span class="n">Decoder</span><span class="p">,</span> <span class="n">src_embed</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">tgt_embed</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">generator</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="lineno">201</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">202</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span>
<span class="lineno">203</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">decoder</span>
<span class="lineno">204</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span> <span class="o">=</span> <span class="n">src_embed</span>
<span class="lineno">205</span> <span class="bp">self</span><span class="o">.</span><span class="n">tgt_embed</span> <span class="o">=</span> <span class="n">tgt_embed</span>
<span class="lineno">206</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span> <span class="o">=</span> <span class="n">generator</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>This was important from their code. Initialize parameters with Glorot / fan_avg. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">210</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
<span class="lineno">211</span> <span class="k">if</span> <span class="n">p</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="lineno">212</span> <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">xavier_uniform_</span><span class="p">(</span><span class="n">p</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">214</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Run the source through encoder </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">216</span> <span class="n">enc</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Run encodings and targets through decoder </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">218</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">enc</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">,</span> <span class="n">tgt</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">220</span> <span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">221</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span><span class="p">(</span><span class="n">src</span><span class="p">),</span> <span class="n">src_mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">223</span> <span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">memory</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">224</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">(</span><span class="n">tgt</span><span class="p">),</span> <span class="n">memory</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src=../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>