This commit is contained in:
Varuna Jayasiri
2022-07-02 14:31:16 +05:30
parent ab4264cbda
commit b6bef1d2fe
8 changed files with 215 additions and 223 deletions

View File

@ -76,10 +76,10 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">17</span><span></span><span class="kn">import</span> <span class="nn">torch</span> <div class="highlight"><pre><span class="lineno">17</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">18</span> <span class="lineno">18</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span> <span class="lineno">19</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span> <span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span> <span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.nlp_autoregression</span> <span class="kn">import</span> <span class="n">NLPAutoRegressionConfigs</span> <span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.nlp_autoregression</span> <span class="kn">import</span> <span class="n">NLPAutoRegressionConfigs</span>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">TransformerConfigs</span><span class="p">,</span> <span class="n">Encoder</span> <span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">TransformerConfigs</span><span class="p">,</span> <span class="n">Encoder</span>
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.utils</span> <span class="kn">import</span> <span class="n">subsequent_mask</span></pre></div> <span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.utils</span> <span class="kn">import</span> <span class="n">subsequent_mask</span></pre></div>
@ -94,7 +94,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">27</span><span class="k">class</span> <span class="nc">AutoregressiveTransformer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">27</span><span class="k">class</span> <span class="nc">AutoregressiveTransformer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-2'> <div class='section' id='section-2'>
@ -111,7 +111,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">31</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">src_embed</span><span class="p">:</span> <span class="n">Module</span><span class="p">,</span> <span class="n">generator</span><span class="p">:</span> <span class="n">Module</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">31</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">src_embed</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">generator</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-3'> <div class='section' id='section-3'>

View File

@ -80,10 +80,9 @@
<span class="lineno">26</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">List</span> <span class="lineno">26</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">List</span>
<span class="lineno">27</span> <span class="lineno">27</span>
<span class="lineno">28</span><span class="kn">import</span> <span class="nn">torch</span> <span class="lineno">28</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">29</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span> <span class="k">as</span> <span class="n">nn</span> <span class="lineno">29</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">30</span> <span class="lineno">30</span>
<span class="lineno">31</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">tracker</span> <span class="lineno">31</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">tracker</span></pre></div>
<span class="lineno">32</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-1'> <div class='section' id='section-1'>
@ -97,7 +96,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">35</span><span class="k">class</span> <span class="nc">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">34</span><span class="k">class</span> <span class="nc">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-2'> <div class='section' id='section-2'>
@ -108,8 +107,8 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">46</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_k</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">bias</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span> <div class="highlight"><pre><span class="lineno">45</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_k</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">bias</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span>
<span class="lineno">47</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div> <span class="lineno">46</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-3'> <div class='section' id='section-3'>
@ -121,7 +120,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">48</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-4'> <div class='section' id='section-4'>
@ -133,7 +132,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">51</span> <span class="bp">self</span><span class="o">.</span><span class="n">heads</span> <span class="o">=</span> <span class="n">heads</span></pre></div> <div class="highlight"><pre><span class="lineno">50</span> <span class="bp">self</span><span class="o">.</span><span class="n">heads</span> <span class="o">=</span> <span class="n">heads</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-5'> <div class='section' id='section-5'>
@ -145,7 +144,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">53</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">=</span> <span class="n">d_k</span></pre></div> <div class="highlight"><pre><span class="lineno">52</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">=</span> <span class="n">d_k</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-6'> <div class='section' id='section-6'>
@ -156,7 +155,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">55</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">54</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-7'> <div class='section' id='section-7'>
@ -170,7 +169,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">59</span> <span class="n">head_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div> <div class="highlight"><pre><span class="lineno">58</span> <span class="n">head_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-8'> <div class='section' id='section-8'>
@ -182,7 +181,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">61</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-9'> <div class='section' id='section-9'>
@ -194,7 +193,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">65</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">head_shape</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">64</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">head_shape</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-10'> <div class='section' id='section-10'>
@ -208,7 +207,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="k">return</span> <span class="n">x</span></pre></div> <div class="highlight"><pre><span class="lineno">67</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-11'> <div class='section' id='section-11'>
@ -251,7 +250,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">71</span><span class="k">class</span> <span class="nc">MultiHeadAttention</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">70</span><span class="k">class</span> <span class="nc">MultiHeadAttention</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-12'> <div class='section' id='section-12'>
@ -269,7 +268,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">92</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span> <span class="n">bias</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">91</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span> <span class="n">bias</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-13'> <div class='section' id='section-13'>
@ -280,7 +279,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">98</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div> <div class="highlight"><pre><span class="lineno">97</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-14'> <div class='section' id='section-14'>
@ -292,7 +291,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">101</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">=</span> <span class="n">d_model</span> <span class="o">//</span> <span class="n">heads</span></pre></div> <div class="highlight"><pre><span class="lineno">100</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">=</span> <span class="n">d_model</span> <span class="o">//</span> <span class="n">heads</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-15'> <div class='section' id='section-15'>
@ -304,7 +303,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">103</span> <span class="bp">self</span><span class="o">.</span><span class="n">heads</span> <span class="o">=</span> <span class="n">heads</span></pre></div> <div class="highlight"><pre><span class="lineno">102</span> <span class="bp">self</span><span class="o">.</span><span class="n">heads</span> <span class="o">=</span> <span class="n">heads</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-16'> <div class='section' id='section-16'>
@ -319,9 +318,9 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">106</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span> <span class="o">=</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">)</span> <div class="highlight"><pre><span class="lineno">105</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span> <span class="o">=</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">)</span>
<span class="lineno">107</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span> <span class="o">=</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">)</span> <span class="lineno">106</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span> <span class="o">=</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">)</span>
<span class="lineno">108</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div> <span class="lineno">107</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-17'> <div class='section' id='section-17'>
@ -334,7 +333,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">111</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">110</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-18'> <div class='section' id='section-18'>
@ -346,7 +345,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">114</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">113</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-19'> <div class='section' id='section-19'>
@ -358,7 +357,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">115</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-20'> <div class='section' id='section-20'>
@ -370,7 +369,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">118</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">117</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-21'> <div class='section' id='section-21'>
@ -382,7 +381,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">121</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="kc">None</span></pre></div> <div class="highlight"><pre><span class="lineno">120</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-22'> <div class='section' id='section-22'>
@ -395,7 +394,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">123</span> <span class="k">def</span> <span class="nf">get_scores</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">122</span> <span class="k">def</span> <span class="nf">get_scores</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-23'> <div class='section' id='section-23'>
@ -407,7 +406,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">131</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;ibhd,jbhd-&gt;ijbh&#39;</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">130</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;ibhd,jbhd-&gt;ijbh&#39;</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-24'> <div class='section' id='section-24'>
@ -421,7 +420,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">133</span> <span class="k">def</span> <span class="nf">prepare_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">query_shape</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">key_shape</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]):</span></pre></div> <div class="highlight"><pre><span class="lineno">132</span> <span class="k">def</span> <span class="nf">prepare_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">query_shape</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">key_shape</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-25'> <div class='section' id='section-25'>
@ -432,9 +431,9 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">139</span> <span class="k">assert</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">query_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <div class="highlight"><pre><span class="lineno">138</span> <span class="k">assert</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">query_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="lineno">140</span> <span class="k">assert</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">key_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="lineno">139</span> <span class="k">assert</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">key_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="lineno">141</span> <span class="k">assert</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">==</span> <span class="n">query_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div> <span class="lineno">140</span> <span class="k">assert</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">==</span> <span class="n">query_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-26'> <div class='section' id='section-26'>
@ -446,7 +445,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">144</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">143</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-27'> <div class='section' id='section-27'>
@ -459,7 +458,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">147</span> <span class="k">return</span> <span class="n">mask</span></pre></div> <div class="highlight"><pre><span class="lineno">146</span> <span class="k">return</span> <span class="n">mask</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-28'> <div class='section' id='section-28'>
@ -482,11 +481,11 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">149</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <div class="highlight"><pre><span class="lineno">148</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">150</span> <span class="n">query</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="lineno">149</span> <span class="n">query</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">151</span> <span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="lineno">150</span> <span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">152</span> <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="lineno">151</span> <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">153</span> <span class="n">mask</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div> <span class="lineno">152</span> <span class="n">mask</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-29'> <div class='section' id='section-29'>
@ -502,10 +501,10 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">165</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">shape</span> <div class="highlight"><pre><span class="lineno">164</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">shape</span>
<span class="lineno">166</span> <span class="lineno">165</span>
<span class="lineno">167</span> <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="lineno">166</span> <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">168</span> <span class="n">mask</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_mask</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">query</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">key</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div> <span class="lineno">167</span> <span class="n">mask</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_mask</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">query</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">key</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-30'> <div class='section' id='section-30'>
@ -521,9 +520,9 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">172</span> <span class="n">query</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">query</span><span class="p">)</span> <div class="highlight"><pre><span class="lineno">171</span> <span class="n">query</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">query</span><span class="p">)</span>
<span class="lineno">173</span> <span class="n">key</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> <span class="lineno">172</span> <span class="n">key</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
<span class="lineno">174</span> <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span><span class="p">(</span><span class="n">value</span><span class="p">)</span></pre></div> <span class="lineno">173</span> <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span><span class="p">(</span><span class="n">value</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-31'> <div class='section' id='section-31'>
@ -536,7 +535,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">178</span> <span class="n">scores</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_scores</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">177</span> <span class="n">scores</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_scores</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-32'> <div class='section' id='section-32'>
@ -559,7 +558,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">181</span> <span class="n">scores</span> <span class="o">*=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div> <div class="highlight"><pre><span class="lineno">180</span> <span class="n">scores</span> <span class="o">*=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-33'> <div class='section' id='section-33'>
@ -571,8 +570,8 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">184</span> <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <div class="highlight"><pre><span class="lineno">183</span> <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">185</span> <span class="n">scores</span> <span class="o">=</span> <span class="n">scores</span><span class="o">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">))</span></pre></div> <span class="lineno">184</span> <span class="n">scores</span> <span class="o">=</span> <span class="n">scores</span><span class="o">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">))</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-34'> <div class='section' id='section-34'>
@ -595,7 +594,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">189</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">scores</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">188</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">scores</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-35'> <div class='section' id='section-35'>
@ -607,7 +606,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">192</span> <span class="n">tracker</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="s1">&#39;attn&#39;</span><span class="p">,</span> <span class="n">attn</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">191</span> <span class="n">tracker</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="s1">&#39;attn&#39;</span><span class="p">,</span> <span class="n">attn</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-36'> <div class='section' id='section-36'>
@ -619,7 +618,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">195</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">194</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-37'> <div class='section' id='section-37'>
@ -642,7 +641,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">199</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">&quot;ijbh,jbhd-&gt;ibhd&quot;</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">198</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">&quot;ijbh,jbhd-&gt;ibhd&quot;</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-38'> <div class='section' id='section-38'>
@ -654,7 +653,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">202</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></pre></div> <div class="highlight"><pre><span class="lineno">201</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-39'> <div class='section' id='section-39'>
@ -666,7 +665,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">205</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">204</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-40'> <div class='section' id='section-40'>
@ -678,7 +677,7 @@ M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">208</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">207</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='footer'> <div class='footer'>

View File

@ -77,12 +77,11 @@
<span class="lineno">15</span> <span class="lineno">15</span>
<span class="lineno">16</span><span class="kn">import</span> <span class="nn">torch</span> <span class="lineno">16</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">17</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span> <span class="lineno">17</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span> <span class="lineno">18</span>
<span class="lineno">19</span> <span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml_nn.utils</span> <span class="kn">import</span> <span class="n">clone_module_list</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_nn.utils</span> <span class="kn">import</span> <span class="n">clone_module_list</span> <span class="lineno">20</span><span class="kn">from</span> <span class="nn">.feed_forward</span> <span class="kn">import</span> <span class="n">FeedForward</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">.feed_forward</span> <span class="kn">import</span> <span class="n">FeedForward</span> <span class="lineno">21</span><span class="kn">from</span> <span class="nn">.mha</span> <span class="kn">import</span> <span class="n">MultiHeadAttention</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">.mha</span> <span class="kn">import</span> <span class="n">MultiHeadAttention</span> <span class="lineno">22</span><span class="kn">from</span> <span class="nn">.positional_encoding</span> <span class="kn">import</span> <span class="n">get_positional_encoding</span></pre></div>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">.positional_encoding</span> <span class="kn">import</span> <span class="n">get_positional_encoding</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-1'> <div class='section' id='section-1'>
@ -95,7 +94,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">26</span><span class="k">class</span> <span class="nc">EmbeddingsWithPositionalEncoding</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">25</span><span class="k">class</span> <span class="nc">EmbeddingsWithPositionalEncoding</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-2'> <div class='section' id='section-2'>
@ -106,11 +105,11 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">33</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">max_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5000</span><span class="p">):</span> <div class="highlight"><pre><span class="lineno">32</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">max_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5000</span><span class="p">):</span>
<span class="lineno">34</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="lineno">33</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">35</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">n_vocab</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span> <span class="lineno">34</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">n_vocab</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
<span class="lineno">36</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_model</span> <span class="o">=</span> <span class="n">d_model</span> <span class="lineno">35</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_model</span> <span class="o">=</span> <span class="n">d_model</span>
<span class="lineno">37</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">&#39;positional_encodings&#39;</span><span class="p">,</span> <span class="n">get_positional_encoding</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">max_len</span><span class="p">))</span></pre></div> <span class="lineno">36</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">&#39;positional_encodings&#39;</span><span class="p">,</span> <span class="n">get_positional_encoding</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">max_len</span><span class="p">))</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-3'> <div class='section' id='section-3'>
@ -121,9 +120,9 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">39</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span> <div class="highlight"><pre><span class="lineno">38</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">40</span> <span class="n">pe</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">False</span><span class="p">)</span> <span class="lineno">39</span> <span class="n">pe</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">False</span><span class="p">)</span>
<span class="lineno">41</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span> <span class="o">+</span> <span class="n">pe</span></pre></div> <span class="lineno">40</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span> <span class="o">+</span> <span class="n">pe</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-4'> <div class='section' id='section-4'>
@ -136,7 +135,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">44</span><span class="k">class</span> <span class="nc">EmbeddingsWithLearnedPositionalEncoding</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">43</span><span class="k">class</span> <span class="nc">EmbeddingsWithLearnedPositionalEncoding</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-5'> <div class='section' id='section-5'>
@ -147,11 +146,11 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">51</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">max_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5000</span><span class="p">):</span> <div class="highlight"><pre><span class="lineno">50</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">max_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5000</span><span class="p">):</span>
<span class="lineno">52</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="lineno">51</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">53</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">n_vocab</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span> <span class="lineno">52</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">n_vocab</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
<span class="lineno">54</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_model</span> <span class="o">=</span> <span class="n">d_model</span> <span class="lineno">53</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_model</span> <span class="o">=</span> <span class="n">d_model</span>
<span class="lineno">55</span> <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">max_len</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">d_model</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div> <span class="lineno">54</span> <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">max_len</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">d_model</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-6'> <div class='section' id='section-6'>
@ -162,9 +161,9 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">57</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span> <div class="highlight"><pre><span class="lineno">56</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">58</span> <span class="n">pe</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span> <span class="lineno">57</span> <span class="n">pe</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span>
<span class="lineno">59</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span> <span class="o">+</span> <span class="n">pe</span></pre></div> <span class="lineno">58</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span> <span class="o">+</span> <span class="n">pe</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-7'> <div class='section' id='section-7'>
@ -179,7 +178,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">62</span><span class="k">class</span> <span class="nc">TransformerLayer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">61</span><span class="k">class</span> <span class="nc">TransformerLayer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-8'> <div class='section' id='section-8'>
@ -200,12 +199,12 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">80</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <div class="highlight"><pre><span class="lineno">79</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">81</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="lineno">80</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">82</span> <span class="n">self_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span><span class="p">,</span> <span class="lineno">81</span> <span class="n">self_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span><span class="p">,</span>
<span class="lineno">83</span> <span class="n">src_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="lineno">82</span> <span class="n">src_attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="lineno">84</span> <span class="n">feed_forward</span><span class="p">:</span> <span class="n">FeedForward</span><span class="p">,</span> <span class="lineno">83</span> <span class="n">feed_forward</span><span class="p">:</span> <span class="n">FeedForward</span><span class="p">,</span>
<span class="lineno">85</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div> <span class="lineno">84</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-9'> <div class='section' id='section-9'>
@ -216,16 +215,16 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">93</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <div class="highlight"><pre><span class="lineno">92</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">94</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">d_model</span> <span class="lineno">93</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">d_model</span>
<span class="lineno">95</span> <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span> <span class="o">=</span> <span class="n">self_attn</span> <span class="lineno">94</span> <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span> <span class="o">=</span> <span class="n">self_attn</span>
<span class="lineno">96</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_attn</span> <span class="o">=</span> <span class="n">src_attn</span> <span class="lineno">95</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_attn</span> <span class="o">=</span> <span class="n">src_attn</span>
<span class="lineno">97</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span> <span class="lineno">96</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span>
<span class="lineno">98</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span> <span class="lineno">97</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span>
<span class="lineno">99</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span> <span class="lineno">98</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
<span class="lineno">100</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_attn</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="lineno">99</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_attn</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">101</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_src_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span> <span class="lineno">100</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_src_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
<span class="lineno">102</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span></pre></div> <span class="lineno">101</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-10'> <div class='section' id='section-10'>
@ -237,7 +236,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">104</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_save_ff_input</span> <span class="o">=</span> <span class="kc">False</span></pre></div> <div class="highlight"><pre><span class="lineno">103</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_save_ff_input</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-11'> <div class='section' id='section-11'>
@ -248,11 +247,11 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">106</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <div class="highlight"><pre><span class="lineno">105</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">107</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="lineno">106</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">108</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="lineno">107</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">109</span> <span class="n">src</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="lineno">108</span> <span class="n">src</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="lineno">110</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div> <span class="lineno">109</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-12'> <div class='section' id='section-12'>
@ -264,7 +263,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">112</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">111</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-13'> <div class='section' id='section-13'>
@ -276,7 +275,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">114</span> <span class="n">self_attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">113</span> <span class="n">self_attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-14'> <div class='section' id='section-14'>
@ -288,7 +287,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">self_attn</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">115</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">self_attn</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-15'> <div class='section' id='section-15'>
@ -300,7 +299,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">121</span> <span class="k">if</span> <span class="n">src</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span></pre></div> <div class="highlight"><pre><span class="lineno">120</span> <span class="k">if</span> <span class="n">src</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-16'> <div class='section' id='section-16'>
@ -312,7 +311,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">123</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_src_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">122</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_src_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-17'> <div class='section' id='section-17'>
@ -324,7 +323,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">125</span> <span class="n">attn_src</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">src</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">src</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">src_mask</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">124</span> <span class="n">attn_src</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">src</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">src</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">src_mask</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-18'> <div class='section' id='section-18'>
@ -336,7 +335,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">127</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">attn_src</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">126</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">attn_src</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-19'> <div class='section' id='section-19'>
@ -348,7 +347,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">130</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">129</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-20'> <div class='section' id='section-20'>
@ -360,8 +359,8 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">132</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_save_ff_input</span><span class="p">:</span> <div class="highlight"><pre><span class="lineno">131</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_save_ff_input</span><span class="p">:</span>
<span class="lineno">133</span> <span class="bp">self</span><span class="o">.</span><span class="n">ff_input</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span></pre></div> <span class="lineno">132</span> <span class="bp">self</span><span class="o">.</span><span class="n">ff_input</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-21'> <div class='section' id='section-21'>
@ -373,7 +372,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">135</span> <span class="n">ff</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">134</span> <span class="n">ff</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-22'> <div class='section' id='section-22'>
@ -385,9 +384,9 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">137</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">ff</span><span class="p">)</span> <div class="highlight"><pre><span class="lineno">136</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">ff</span><span class="p">)</span>
<span class="lineno">138</span> <span class="lineno">137</span>
<span class="lineno">139</span> <span class="k">return</span> <span class="n">x</span></pre></div> <span class="lineno">138</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-23'> <div class='section' id='section-23'>
@ -400,7 +399,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">142</span><span class="k">class</span> <span class="nc">Encoder</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">141</span><span class="k">class</span> <span class="nc">Encoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-24'> <div class='section' id='section-24'>
@ -411,8 +410,8 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">149</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">TransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span> <div class="highlight"><pre><span class="lineno">148</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">TransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">150</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div> <span class="lineno">149</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-25'> <div class='section' id='section-25'>
@ -424,7 +423,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">152</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">151</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-26'> <div class='section' id='section-26'>
@ -436,7 +435,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">154</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div> <div class="highlight"><pre><span class="lineno">153</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-27'> <div class='section' id='section-27'>
@ -447,7 +446,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">155</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-28'> <div class='section' id='section-28'>
@ -459,8 +458,8 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">158</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span> <div class="highlight"><pre><span class="lineno">157</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
<span class="lineno">159</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div> <span class="lineno">158</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-29'> <div class='section' id='section-29'>
@ -472,7 +471,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">161</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">160</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-30'> <div class='section' id='section-30'>
@ -485,7 +484,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">164</span><span class="k">class</span> <span class="nc">Decoder</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">163</span><span class="k">class</span> <span class="nc">Decoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-31'> <div class='section' id='section-31'>
@ -496,8 +495,8 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">171</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">TransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span> <div class="highlight"><pre><span class="lineno">170</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">TransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">172</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div> <span class="lineno">171</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-32'> <div class='section' id='section-32'>
@ -509,7 +508,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">174</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">173</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-33'> <div class='section' id='section-33'>
@ -521,7 +520,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">176</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div> <div class="highlight"><pre><span class="lineno">175</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-34'> <div class='section' id='section-34'>
@ -532,7 +531,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">178</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">memory</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">177</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">memory</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-35'> <div class='section' id='section-35'>
@ -544,8 +543,8 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">180</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span> <div class="highlight"><pre><span class="lineno">179</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
<span class="lineno">181</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">tgt_mask</span><span class="p">,</span> <span class="n">src</span><span class="o">=</span><span class="n">memory</span><span class="p">,</span> <span class="n">src_mask</span><span class="o">=</span><span class="n">src_mask</span><span class="p">)</span></pre></div> <span class="lineno">180</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">tgt_mask</span><span class="p">,</span> <span class="n">src</span><span class="o">=</span><span class="n">memory</span><span class="p">,</span> <span class="n">src_mask</span><span class="o">=</span><span class="n">src_mask</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-36'> <div class='section' id='section-36'>
@ -557,7 +556,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">183</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">182</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-37'> <div class='section' id='section-37'>
@ -572,7 +571,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">186</span><span class="k">class</span> <span class="nc">Generator</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">185</span><span class="k">class</span> <span class="nc">Generator</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-38'> <div class='section' id='section-38'>
@ -583,9 +582,9 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">196</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span> <div class="highlight"><pre><span class="lineno">195</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">197</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="lineno">196</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">198</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">)</span></pre></div> <span class="lineno">197</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-39'> <div class='section' id='section-39'>
@ -596,8 +595,8 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">200</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <div class="highlight"><pre><span class="lineno">199</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="lineno">201</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div> <span class="lineno">200</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-40'> <div class='section' id='section-40'>
@ -610,7 +609,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">204</span><span class="k">class</span> <span class="nc">EncoderDecoder</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">203</span><span class="k">class</span> <span class="nc">EncoderDecoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-41'> <div class='section' id='section-41'>
@ -621,13 +620,13 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">211</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">decoder</span><span class="p">:</span> <span class="n">Decoder</span><span class="p">,</span> <span class="n">src_embed</span><span class="p">:</span> <span class="n">Module</span><span class="p">,</span> <span class="n">tgt_embed</span><span class="p">:</span> <span class="n">Module</span><span class="p">,</span> <span class="n">generator</span><span class="p">:</span> <span class="n">Module</span><span class="p">):</span> <div class="highlight"><pre><span class="lineno">210</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">Encoder</span><span class="p">,</span> <span class="n">decoder</span><span class="p">:</span> <span class="n">Decoder</span><span class="p">,</span> <span class="n">src_embed</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">tgt_embed</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">generator</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="lineno">212</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="lineno">211</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">213</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span> <span class="lineno">212</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span>
<span class="lineno">214</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">decoder</span> <span class="lineno">213</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">decoder</span>
<span class="lineno">215</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span> <span class="o">=</span> <span class="n">src_embed</span> <span class="lineno">214</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span> <span class="o">=</span> <span class="n">src_embed</span>
<span class="lineno">216</span> <span class="bp">self</span><span class="o">.</span><span class="n">tgt_embed</span> <span class="o">=</span> <span class="n">tgt_embed</span> <span class="lineno">215</span> <span class="bp">self</span><span class="o">.</span><span class="n">tgt_embed</span> <span class="o">=</span> <span class="n">tgt_embed</span>
<span class="lineno">217</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span> <span class="o">=</span> <span class="n">generator</span></pre></div> <span class="lineno">216</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span> <span class="o">=</span> <span class="n">generator</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-42'> <div class='section' id='section-42'>
@ -639,9 +638,9 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">221</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span> <div class="highlight"><pre><span class="lineno">220</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
<span class="lineno">222</span> <span class="k">if</span> <span class="n">p</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span> <span class="lineno">221</span> <span class="k">if</span> <span class="n">p</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="lineno">223</span> <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">xavier_uniform_</span><span class="p">(</span><span class="n">p</span><span class="p">)</span></pre></div> <span class="lineno">222</span> <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">xavier_uniform_</span><span class="p">(</span><span class="n">p</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-43'> <div class='section' id='section-43'>
@ -652,7 +651,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">225</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">224</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-44'> <div class='section' id='section-44'>
@ -664,7 +663,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">227</span> <span class="n">enc</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">226</span> <span class="n">enc</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-45'> <div class='section' id='section-45'>
@ -676,7 +675,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">229</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">enc</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">,</span> <span class="n">tgt</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">228</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">enc</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">,</span> <span class="n">tgt</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-46'> <div class='section' id='section-46'>
@ -687,8 +686,8 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">231</span> <span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span> <div class="highlight"><pre><span class="lineno">230</span> <span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">src</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">232</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span><span class="p">(</span><span class="n">src</span><span class="p">),</span> <span class="n">src_mask</span><span class="p">)</span></pre></div> <span class="lineno">231</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span><span class="p">(</span><span class="n">src</span><span class="p">),</span> <span class="n">src_mask</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-47'> <div class='section' id='section-47'>
@ -699,8 +698,8 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">234</span> <span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">memory</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span> <div class="highlight"><pre><span class="lineno">233</span> <span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">memory</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">235</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">(</span><span class="n">tgt</span><span class="p">),</span> <span class="n">memory</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">)</span></pre></div> <span class="lineno">234</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tgt_embed</span><span class="p">(</span><span class="n">tgt</span><span class="p">),</span> <span class="n">memory</span><span class="p">,</span> <span class="n">src_mask</span><span class="p">,</span> <span class="n">tgt_mask</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='footer'> <div class='footer'>

View File

@ -79,9 +79,7 @@
<span class="lineno">24</span> <span class="lineno">24</span>
<span class="lineno">25</span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="lineno">25</span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="lineno">26</span><span class="kn">import</span> <span class="nn">torch</span> <span class="lineno">26</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">27</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span> <span class="lineno">27</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span></pre></div>
<span class="lineno">28</span>
<span class="lineno">29</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-1'> <div class='section' id='section-1'>
@ -92,7 +90,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">32</span><span class="k">class</span> <span class="nc">PositionalEncoding</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">30</span><span class="k">class</span> <span class="nc">PositionalEncoding</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-2'> <div class='section' id='section-2'>
@ -103,11 +101,11 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">33</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">max_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5000</span><span class="p">):</span> <div class="highlight"><pre><span class="lineno">31</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">max_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5000</span><span class="p">):</span>
<span class="lineno">34</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="lineno">32</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">35</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span> <span class="lineno">33</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span>
<span class="lineno">36</span> <span class="lineno">34</span>
<span class="lineno">37</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">&#39;positional_encodings&#39;</span><span class="p">,</span> <span class="n">get_positional_encoding</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">max_len</span><span class="p">),</span> <span class="kc">False</span><span class="p">)</span></pre></div> <span class="lineno">35</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">&#39;positional_encodings&#39;</span><span class="p">,</span> <span class="n">get_positional_encoding</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">max_len</span><span class="p">),</span> <span class="kc">False</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-3'> <div class='section' id='section-3'>
@ -118,11 +116,11 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">39</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span> <div class="highlight"><pre><span class="lineno">37</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">40</span> <span class="n">pe</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">False</span><span class="p">)</span> <span class="lineno">38</span> <span class="n">pe</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">positional_encodings</span><span class="p">[:</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">False</span><span class="p">)</span>
<span class="lineno">41</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">pe</span> <span class="lineno">39</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">pe</span>
<span class="lineno">42</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="lineno">40</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">43</span> <span class="k">return</span> <span class="n">x</span></pre></div> <span class="lineno">41</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-4'> <div class='section' id='section-4'>
@ -133,7 +131,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">46</span><span class="k">def</span> <span class="nf">get_positional_encoding</span><span class="p">(</span><span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">max_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5000</span><span class="p">):</span></pre></div> <div class="highlight"><pre><span class="lineno">44</span><span class="k">def</span> <span class="nf">get_positional_encoding</span><span class="p">(</span><span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">max_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5000</span><span class="p">):</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-5'> <div class='section' id='section-5'>
@ -145,7 +143,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">48</span> <span class="n">encodings</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">max_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">46</span> <span class="n">encodings</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">max_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-6'> <div class='section' id='section-6'>
@ -157,7 +155,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">50</span> <span class="n">position</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">max_len</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">48</span> <span class="n">position</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">max_len</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-7'> <div class='section' id='section-7'>
@ -169,7 +167,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">52</span> <span class="n">two_i</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">50</span> <span class="n">two_i</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-8'> <div class='section' id='section-8'>
@ -181,7 +179,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">54</span> <span class="n">div_term</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">two_i</span> <span class="o">*</span> <span class="o">-</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="mf">10000.0</span><span class="p">)</span> <span class="o">/</span> <span class="n">d_model</span><span class="p">))</span></pre></div> <div class="highlight"><pre><span class="lineno">52</span> <span class="n">div_term</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">two_i</span> <span class="o">*</span> <span class="o">-</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="mf">10000.0</span><span class="p">)</span> <span class="o">/</span> <span class="n">d_model</span><span class="p">))</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-9'> <div class='section' id='section-9'>
@ -193,7 +191,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="n">encodings</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">position</span> <span class="o">*</span> <span class="n">div_term</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">54</span> <span class="n">encodings</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">position</span> <span class="o">*</span> <span class="n">div_term</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-10'> <div class='section' id='section-10'>
@ -205,7 +203,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">58</span> <span class="n">encodings</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">position</span> <span class="o">*</span> <span class="n">div_term</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">56</span> <span class="n">encodings</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">position</span> <span class="o">*</span> <span class="n">div_term</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-11'> <div class='section' id='section-11'>
@ -217,9 +215,9 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">61</span> <span class="n">encodings</span> <span class="o">=</span> <span class="n">encodings</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">False</span><span class="p">)</span> <div class="highlight"><pre><span class="lineno">59</span> <span class="n">encodings</span> <span class="o">=</span> <span class="n">encodings</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">False</span><span class="p">)</span>
<span class="lineno">62</span> <span class="lineno">60</span>
<span class="lineno">63</span> <span class="k">return</span> <span class="n">encodings</span></pre></div> <span class="lineno">61</span> <span class="k">return</span> <span class="n">encodings</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-12'> <div class='section' id='section-12'>
@ -230,19 +228,19 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">66</span><span class="k">def</span> <span class="nf">_test_positional_encoding</span><span class="p">():</span> <div class="highlight"><pre><span class="lineno">64</span><span class="k">def</span> <span class="nf">_test_positional_encoding</span><span class="p">():</span>
<span class="lineno">67</span> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span> <span class="lineno">65</span> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
<span class="lineno">68</span> <span class="lineno">66</span>
<span class="lineno">69</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">15</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span> <span class="lineno">67</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">15</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="lineno">70</span> <span class="n">pe</span> <span class="o">=</span> <span class="n">get_positional_encoding</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span> <span class="lineno">68</span> <span class="n">pe</span> <span class="o">=</span> <span class="n">get_positional_encoding</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>
<span class="lineno">71</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">100</span><span class="p">),</span> <span class="n">pe</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">4</span><span class="p">:</span><span class="mi">8</span><span class="p">]</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span> <span class="lineno">69</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">100</span><span class="p">),</span> <span class="n">pe</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">4</span><span class="p">:</span><span class="mi">8</span><span class="p">]</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
<span class="lineno">72</span> <span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">([</span><span class="s2">&quot;dim </span><span class="si">%d</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="n">p</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">7</span><span class="p">]])</span> <span class="lineno">70</span> <span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">([</span><span class="s2">&quot;dim </span><span class="si">%d</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="n">p</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">7</span><span class="p">]])</span>
<span class="lineno">73</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">&quot;Positional encoding&quot;</span><span class="p">)</span> <span class="lineno">71</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">&quot;Positional encoding&quot;</span><span class="p">)</span>
<span class="lineno">74</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="lineno">72</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
<span class="lineno">75</span> <span class="lineno">73</span>
<span class="lineno">76</span> <span class="lineno">74</span>
<span class="lineno">77</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span> <span class="lineno">75</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">78</span> <span class="n">_test_positional_encoding</span><span class="p">()</span></pre></div> <span class="lineno">76</span> <span class="n">_test_positional_encoding</span><span class="p">()</span></pre></div>
</div> </div>
</div> </div>
<div class='footer'> <div class='footer'>

View File

@ -15,20 +15,20 @@ on an NLP auto-regression task (with Tiny Shakespeare dataset).
""" """
import torch import torch
from torch import nn
from labml import experiment from labml import experiment
from labml.configs import option from labml.configs import option
from labml_helpers.module import Module
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import TransformerConfigs, Encoder from labml_nn.transformers import TransformerConfigs, Encoder
from labml_nn.transformers.utils import subsequent_mask from labml_nn.transformers.utils import subsequent_mask
class AutoregressiveTransformer(Module): class AutoregressiveTransformer(nn.Module):
""" """
## Auto-Regressive model ## Auto-Regressive model
""" """
def __init__(self, encoder: Encoder, src_embed: Module, generator: Module): def __init__(self, encoder: Encoder, src_embed: nn.Module, generator: nn.Module):
""" """
* `encoder` is the transformer [Encoder](../models.html#Encoder) * `encoder` is the transformer [Encoder](../models.html#Encoder)
* `src_embed` is the token * `src_embed` is the token

View File

@ -26,13 +26,12 @@ import math
from typing import Optional, List from typing import Optional, List
import torch import torch
from torch import nn as nn from torch import nn
from labml import tracker from labml import tracker
from labml_helpers.module import Module
class PrepareForMultiHeadAttention(Module): class PrepareForMultiHeadAttention(nn.Module):
""" """
<a id="PrepareMHA"></a> <a id="PrepareMHA"></a>
@ -68,7 +67,7 @@ class PrepareForMultiHeadAttention(Module):
return x return x
class MultiHeadAttention(Module): class MultiHeadAttention(nn.Module):
r""" r"""
<a id="MHA"></a> <a id="MHA"></a>

View File

@ -15,7 +15,6 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from labml_helpers.module import Module
from labml_nn.utils import clone_module_list from labml_nn.utils import clone_module_list
from .feed_forward import FeedForward from .feed_forward import FeedForward
@ -23,7 +22,7 @@ from .mha import MultiHeadAttention
from .positional_encoding import get_positional_encoding from .positional_encoding import get_positional_encoding
class EmbeddingsWithPositionalEncoding(Module): class EmbeddingsWithPositionalEncoding(nn.Module):
""" """
<a id="EmbeddingsWithPositionalEncoding"></a> <a id="EmbeddingsWithPositionalEncoding"></a>
@ -41,7 +40,7 @@ class EmbeddingsWithPositionalEncoding(Module):
return self.linear(x) * math.sqrt(self.d_model) + pe return self.linear(x) * math.sqrt(self.d_model) + pe
class EmbeddingsWithLearnedPositionalEncoding(Module): class EmbeddingsWithLearnedPositionalEncoding(nn.Module):
""" """
<a id="EmbeddingsWithLearnedPositionalEncoding"></a> <a id="EmbeddingsWithLearnedPositionalEncoding"></a>
@ -59,7 +58,7 @@ class EmbeddingsWithLearnedPositionalEncoding(Module):
return self.linear(x) * math.sqrt(self.d_model) + pe return self.linear(x) * math.sqrt(self.d_model) + pe
class TransformerLayer(Module): class TransformerLayer(nn.Module):
""" """
<a id="TransformerLayer"></a> <a id="TransformerLayer"></a>
@ -139,7 +138,7 @@ class TransformerLayer(Module):
return x return x
class Encoder(Module): class Encoder(nn.Module):
""" """
<a id="Encoder"></a> <a id="Encoder"></a>
@ -161,7 +160,7 @@ class Encoder(Module):
return self.norm(x) return self.norm(x)
class Decoder(Module): class Decoder(nn.Module):
""" """
<a id="Decoder"></a> <a id="Decoder"></a>
@ -183,7 +182,7 @@ class Decoder(Module):
return self.norm(x) return self.norm(x)
class Generator(Module): class Generator(nn.Module):
""" """
<a id="Generator"></a> <a id="Generator"></a>
@ -201,14 +200,14 @@ class Generator(Module):
return self.projection(x) return self.projection(x)
class EncoderDecoder(Module): class EncoderDecoder(nn.Module):
""" """
<a id="EncoderDecoder"></a> <a id="EncoderDecoder"></a>
## Combined Encoder-Decoder ## Combined Encoder-Decoder
""" """
def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module): def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: nn.Module, tgt_embed: nn.Module, generator: nn.Module):
super().__init__() super().__init__()
self.encoder = encoder self.encoder = encoder
self.decoder = decoder self.decoder = decoder

View File

@ -26,10 +26,8 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from labml_helpers.module import Module
class PositionalEncoding(nn.Module):
class PositionalEncoding(Module):
def __init__(self, d_model: int, dropout_prob: float, max_len: int = 5000): def __init__(self, d_model: int, dropout_prob: float, max_len: int = 5000):
super().__init__() super().__init__()
self.dropout = nn.Dropout(dropout_prob) self.dropout = nn.Dropout(dropout_prob)