mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 09:31:42 +08:00
fast weights experiment
This commit is contained in:
@ -72,17 +72,21 @@
|
|||||||
<div class='section-link'>
|
<div class='section-link'>
|
||||||
<a href='#section-0'>#</a>
|
<a href='#section-0'>#</a>
|
||||||
</div>
|
</div>
|
||||||
|
<h1>Train Fast Weights Transformer</h1>
|
||||||
|
<p>This trains a fast weights transformer model for auto-regression.</p>
|
||||||
|
<p>Here’s a Colab notebook for training a fast weights transformer on Tiny Shakespeare dataset.</p>
|
||||||
|
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/fast_weights/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
|
||||||
|
<a href="https://app.labml.ai/run/928aadc0846c11eb85710242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">8</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
<div class="highlight"><pre><span class="lineno">17</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||||
<span class="lineno">9</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
|
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
|
||||||
<span class="lineno">10</span>
|
<span class="lineno">19</span>
|
||||||
<span class="lineno">11</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
|
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
|
||||||
<span class="lineno">12</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
|
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
|
||||||
<span class="lineno">13</span><span class="kn">from</span> <span class="nn">labml.utils.pytorch</span> <span class="kn">import</span> <span class="n">get_modules</span>
|
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml.utils.pytorch</span> <span class="kn">import</span> <span class="n">get_modules</span>
|
||||||
<span class="lineno">14</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
|
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
|
||||||
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.nlp_autoregression</span> <span class="kn">import</span> <span class="n">NLPAutoRegressionConfigs</span></pre></div>
|
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.nlp_autoregression</span> <span class="kn">import</span> <span class="n">NLPAutoRegressionConfigs</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-1'>
|
<div class='section' id='section-1'>
|
||||||
@ -93,7 +97,7 @@
|
|||||||
<h2>Auto regressive model</h2>
|
<h2>Auto regressive model</h2>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">18</span><span class="k">class</span> <span class="nc">AutoregressiveModel</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">27</span><span class="k">class</span> <span class="nc">AutoregressiveModel</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-2'>
|
<div class='section' id='section-2'>
|
||||||
@ -104,8 +108,8 @@
|
|||||||
|
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">23</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">transformer</span><span class="p">:</span> <span class="n">Module</span><span class="p">):</span>
|
<div class="highlight"><pre><span class="lineno">32</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">transformer</span><span class="p">:</span> <span class="n">Module</span><span class="p">):</span>
|
||||||
<span class="lineno">24</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
<span class="lineno">33</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-3'>
|
<div class='section' id='section-3'>
|
||||||
@ -116,9 +120,9 @@
|
|||||||
<p>Token embedding module</p>
|
<p>Token embedding module</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">26</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">n_vocab</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
|
<div class="highlight"><pre><span class="lineno">35</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">n_vocab</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
|
||||||
<span class="lineno">27</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer</span> <span class="o">=</span> <span class="n">transformer</span>
|
<span class="lineno">36</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer</span> <span class="o">=</span> <span class="n">transformer</span>
|
||||||
<span class="lineno">28</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">)</span></pre></div>
|
<span class="lineno">37</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-4'>
|
<div class='section' id='section-4'>
|
||||||
@ -129,7 +133,7 @@
|
|||||||
|
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">30</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">39</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-5'>
|
<div class='section' id='section-5'>
|
||||||
@ -140,7 +144,7 @@
|
|||||||
<p>Embed the tokens</p>
|
<p>Embed the tokens</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">32</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">41</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-6'>
|
<div class='section' id='section-6'>
|
||||||
@ -151,7 +155,7 @@
|
|||||||
<p>Run it through the the transformer</p>
|
<p>Run it through the the transformer</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">34</span> <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">43</span> <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-7'>
|
<div class='section' id='section-7'>
|
||||||
@ -162,7 +166,7 @@
|
|||||||
<p>Generate logits of the next token</p>
|
<p>Generate logits of the next token</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">36</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">(</span><span class="n">res</span><span class="p">),</span> <span class="kc">None</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">45</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">(</span><span class="n">res</span><span class="p">),</span> <span class="kc">None</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-8'>
|
<div class='section' id='section-8'>
|
||||||
@ -174,7 +178,7 @@
|
|||||||
<p>The default configs can and will be over-ridden when we start the experiment</p>
|
<p>The default configs can and will be over-ridden when we start the experiment</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">39</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">48</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-9'>
|
<div class='section' id='section-9'>
|
||||||
@ -185,14 +189,14 @@
|
|||||||
|
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">46</span> <span class="n">model</span><span class="p">:</span> <span class="n">AutoregressiveModel</span>
|
<div class="highlight"><pre><span class="lineno">55</span> <span class="n">model</span><span class="p">:</span> <span class="n">AutoregressiveModel</span>
|
||||||
<span class="lineno">47</span>
|
<span class="lineno">56</span>
|
||||||
<span class="lineno">48</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">512</span>
|
<span class="lineno">57</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">512</span>
|
||||||
<span class="lineno">49</span> <span class="n">nu</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span>
|
<span class="lineno">58</span> <span class="n">nu</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span>
|
||||||
<span class="lineno">50</span> <span class="n">heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span>
|
<span class="lineno">59</span> <span class="n">heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span>
|
||||||
<span class="lineno">51</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span>
|
<span class="lineno">60</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span>
|
||||||
<span class="lineno">52</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2048</span>
|
<span class="lineno">61</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2048</span>
|
||||||
<span class="lineno">53</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">6</span></pre></div>
|
<span class="lineno">62</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">6</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-10'>
|
<div class='section' id='section-10'>
|
||||||
@ -203,8 +207,8 @@
|
|||||||
<p>Create <a href="index.html">fast weights transformer</a>.</p>
|
<p>Create <a href="index.html">fast weights transformer</a>.</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">56</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
|
<div class="highlight"><pre><span class="lineno">65</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
|
||||||
<span class="lineno">57</span><span class="k">def</span> <span class="nf">fast_weights_transformer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
|
<span class="lineno">66</span><span class="k">def</span> <span class="nf">fast_weights_transformer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-11'>
|
<div class='section' id='section-11'>
|
||||||
@ -215,18 +219,18 @@
|
|||||||
|
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">61</span> <span class="kn">from</span> <span class="nn">labml_nn.transformers.fast_weights</span> <span class="kn">import</span> <span class="n">FastWeightsAttentionTransformer</span><span class="p">,</span> \
|
<div class="highlight"><pre><span class="lineno">70</span> <span class="kn">from</span> <span class="nn">labml_nn.transformers.fast_weights</span> <span class="kn">import</span> <span class="n">FastWeightsAttentionTransformer</span><span class="p">,</span> \
|
||||||
<span class="lineno">62</span> <span class="n">FastWeightsAttentionTransformerLayer</span><span class="p">,</span> <span class="n">FastWeightsAttention</span><span class="p">,</span> <span class="n">FeedForward</span>
|
<span class="lineno">71</span> <span class="n">FastWeightsAttentionTransformerLayer</span><span class="p">,</span> <span class="n">FastWeightsAttention</span><span class="p">,</span> <span class="n">FeedForward</span>
|
||||||
<span class="lineno">63</span>
|
<span class="lineno">72</span>
|
||||||
<span class="lineno">64</span> <span class="kn">from</span> <span class="nn">labml_nn.transformers.fast_weights</span> <span class="kn">import</span> <span class="n">DPFP</span>
|
<span class="lineno">73</span> <span class="kn">from</span> <span class="nn">labml_nn.transformers.fast_weights</span> <span class="kn">import</span> <span class="n">DPFP</span>
|
||||||
<span class="lineno">65</span> <span class="k">return</span> <span class="n">AutoregressiveModel</span><span class="p">(</span>
|
<span class="lineno">74</span> <span class="k">return</span> <span class="n">AutoregressiveModel</span><span class="p">(</span>
|
||||||
<span class="lineno">66</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span>
|
<span class="lineno">75</span> <span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span>
|
||||||
<span class="lineno">67</span> <span class="n">FastWeightsAttentionTransformer</span><span class="p">(</span>
|
<span class="lineno">76</span> <span class="n">FastWeightsAttentionTransformer</span><span class="p">(</span>
|
||||||
<span class="lineno">68</span> <span class="n">FastWeightsAttentionTransformerLayer</span><span class="p">(</span><span class="n">d_model</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span>
|
<span class="lineno">77</span> <span class="n">FastWeightsAttentionTransformerLayer</span><span class="p">(</span><span class="n">d_model</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span>
|
||||||
<span class="lineno">69</span> <span class="n">attn</span><span class="o">=</span><span class="n">FastWeightsAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">,</span> <span class="n">DPFP</span><span class="p">(</span><span class="n">nu</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">nu</span><span class="p">)),</span>
|
<span class="lineno">78</span> <span class="n">attn</span><span class="o">=</span><span class="n">FastWeightsAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">,</span> <span class="n">DPFP</span><span class="p">(</span><span class="n">nu</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">nu</span><span class="p">)),</span>
|
||||||
<span class="lineno">70</span> <span class="n">feed_forward</span><span class="o">=</span><span class="n">FeedForward</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_ff</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">),</span>
|
<span class="lineno">79</span> <span class="n">feed_forward</span><span class="o">=</span><span class="n">FeedForward</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_ff</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">),</span>
|
||||||
<span class="lineno">71</span> <span class="n">dropout_prob</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">),</span>
|
<span class="lineno">80</span> <span class="n">dropout_prob</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">),</span>
|
||||||
<span class="lineno">72</span> <span class="n">c</span><span class="o">.</span><span class="n">n_layers</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
<span class="lineno">81</span> <span class="n">c</span><span class="o">.</span><span class="n">n_layers</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-12'>
|
<div class='section' id='section-12'>
|
||||||
@ -237,7 +241,7 @@
|
|||||||
|
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">75</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">84</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-13'>
|
<div class='section' id='section-13'>
|
||||||
@ -248,7 +252,7 @@
|
|||||||
<p>Create experiment</p>
|
<p>Create experiment</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">77</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"fast_weights_transformer"</span><span class="p">)</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">86</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"fast_weights_transformer"</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-14'>
|
<div class='section' id='section-14'>
|
||||||
@ -259,7 +263,7 @@
|
|||||||
<p>Create configs</p>
|
<p>Create configs</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">79</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">88</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-15'>
|
<div class='section' id='section-15'>
|
||||||
@ -270,7 +274,7 @@
|
|||||||
<p>Load configurations</p>
|
<p>Load configurations</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">81</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">90</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-16'>
|
<div class='section' id='section-16'>
|
||||||
@ -281,20 +285,20 @@
|
|||||||
<p>A dictionary of configurations to override</p>
|
<p>A dictionary of configurations to override</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">83</span> <span class="p">{</span><span class="s1">'tokenizer'</span><span class="p">:</span> <span class="s1">'character'</span><span class="p">,</span>
|
<div class="highlight"><pre><span class="lineno">92</span> <span class="p">{</span><span class="s1">'tokenizer'</span><span class="p">:</span> <span class="s1">'character'</span><span class="p">,</span>
|
||||||
<span class="lineno">84</span> <span class="s1">'text'</span><span class="p">:</span> <span class="s1">'tiny_shakespeare'</span><span class="p">,</span>
|
<span class="lineno">93</span> <span class="s1">'text'</span><span class="p">:</span> <span class="s1">'tiny_shakespeare'</span><span class="p">,</span>
|
||||||
<span class="lineno">85</span> <span class="s1">'optimizer.learning_rate'</span><span class="p">:</span> <span class="mf">1.0</span><span class="p">,</span>
|
<span class="lineno">94</span> <span class="s1">'optimizer.learning_rate'</span><span class="p">:</span> <span class="mf">1.0</span><span class="p">,</span>
|
||||||
<span class="lineno">86</span> <span class="s1">'optimizer.optimizer'</span><span class="p">:</span> <span class="s1">'Noam'</span><span class="p">,</span>
|
<span class="lineno">95</span> <span class="s1">'optimizer.optimizer'</span><span class="p">:</span> <span class="s1">'Noam'</span><span class="p">,</span>
|
||||||
<span class="lineno">87</span> <span class="s1">'prompt'</span><span class="p">:</span> <span class="s1">'It is'</span><span class="p">,</span>
|
<span class="lineno">96</span> <span class="s1">'prompt'</span><span class="p">:</span> <span class="s1">'It is'</span><span class="p">,</span>
|
||||||
<span class="lineno">88</span> <span class="s1">'prompt_separator'</span><span class="p">:</span> <span class="s1">''</span><span class="p">,</span>
|
<span class="lineno">97</span> <span class="s1">'prompt_separator'</span><span class="p">:</span> <span class="s1">''</span><span class="p">,</span>
|
||||||
<span class="lineno">89</span>
|
<span class="lineno">98</span>
|
||||||
<span class="lineno">90</span> <span class="s1">'train_loader'</span><span class="p">:</span> <span class="s1">'shuffled_train_loader'</span><span class="p">,</span>
|
<span class="lineno">99</span> <span class="s1">'train_loader'</span><span class="p">:</span> <span class="s1">'shuffled_train_loader'</span><span class="p">,</span>
|
||||||
<span class="lineno">91</span> <span class="s1">'valid_loader'</span><span class="p">:</span> <span class="s1">'shuffled_valid_loader'</span><span class="p">,</span>
|
<span class="lineno">100</span> <span class="s1">'valid_loader'</span><span class="p">:</span> <span class="s1">'shuffled_valid_loader'</span><span class="p">,</span>
|
||||||
<span class="lineno">92</span>
|
<span class="lineno">101</span>
|
||||||
<span class="lineno">93</span> <span class="s1">'seq_len'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span>
|
<span class="lineno">102</span> <span class="s1">'seq_len'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span>
|
||||||
<span class="lineno">94</span> <span class="s1">'epochs'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span>
|
<span class="lineno">103</span> <span class="s1">'epochs'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span>
|
||||||
<span class="lineno">95</span> <span class="s1">'batch_size'</span><span class="p">:</span> <span class="mi">16</span><span class="p">,</span>
|
<span class="lineno">104</span> <span class="s1">'batch_size'</span><span class="p">:</span> <span class="mi">16</span><span class="p">,</span>
|
||||||
<span class="lineno">96</span> <span class="s1">'inner_iterations'</span><span class="p">:</span> <span class="mi">25</span><span class="p">})</span></pre></div>
|
<span class="lineno">105</span> <span class="s1">'inner_iterations'</span><span class="p">:</span> <span class="mi">25</span><span class="p">})</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-17'>
|
<div class='section' id='section-17'>
|
||||||
@ -305,7 +309,7 @@
|
|||||||
<p>Set models for saving and loading</p>
|
<p>Set models for saving and loading</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">99</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">(</span><span class="n">get_modules</span><span class="p">(</span><span class="n">conf</span><span class="p">))</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">108</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">(</span><span class="n">get_modules</span><span class="p">(</span><span class="n">conf</span><span class="p">))</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-18'>
|
<div class='section' id='section-18'>
|
||||||
@ -316,7 +320,7 @@
|
|||||||
<p>Start the experiment</p>
|
<p>Start the experiment</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">102</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">111</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-19'>
|
<div class='section' id='section-19'>
|
||||||
@ -327,11 +331,11 @@
|
|||||||
<p>Run the training loop</p>
|
<p>Run the training loop</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">104</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span>
|
<div class="highlight"><pre><span class="lineno">113</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span>
|
||||||
<span class="lineno">105</span>
|
<span class="lineno">114</span>
|
||||||
<span class="lineno">106</span>
|
<span class="lineno">115</span>
|
||||||
<span class="lineno">107</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
<span class="lineno">116</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||||
<span class="lineno">108</span> <span class="n">main</span><span class="p">()</span></pre></div>
|
<span class="lineno">117</span> <span class="n">main</span><span class="p">()</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
@ -140,15 +140,19 @@ y^{(i)} &= \frac{1}{z^{(i)} \cdot \color{lightgreen}{\phi(q^{(i)})}}
|
|||||||
<p>The paper introduces a new linear attention projection function $\color{lightgreen}{\phi}$
|
<p>The paper introduces a new linear attention projection function $\color{lightgreen}{\phi}$
|
||||||
a new update rule for $\color{cyan}{W^{(i)}} = f(\color{cyan}{W^{(i-1)}})$ and change the normalization
|
a new update rule for $\color{cyan}{W^{(i)}} = f(\color{cyan}{W^{(i-1)}})$ and change the normalization
|
||||||
$\frac{1}{z^{(i)} \cdot \color{lightgreen}{\phi(q^{(i)})}}$</p>
|
$\frac{1}{z^{(i)} \cdot \color{lightgreen}{\phi(q^{(i)})}}$</p>
|
||||||
|
<p>Here’s <a href="experiment.html">the training code</a> and a notebook for training a fast weights
|
||||||
|
transformer on Tiny Shakespeare dataset.</p>
|
||||||
|
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/fast_weights/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
|
||||||
|
<a href="https://app.labml.ai/run/928aadc0846c11eb85710242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">86</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
<div class="highlight"><pre><span class="lineno">92</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||||
<span class="lineno">87</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
|
<span class="lineno">93</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
|
||||||
<span class="lineno">88</span>
|
<span class="lineno">94</span>
|
||||||
<span class="lineno">89</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
|
<span class="lineno">95</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
|
||||||
<span class="lineno">90</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.feed_forward</span> <span class="kn">import</span> <span class="n">FeedForward</span>
|
<span class="lineno">96</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.feed_forward</span> <span class="kn">import</span> <span class="n">FeedForward</span>
|
||||||
<span class="lineno">91</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.mha</span> <span class="kn">import</span> <span class="n">PrepareForMultiHeadAttention</span>
|
<span class="lineno">97</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.mha</span> <span class="kn">import</span> <span class="n">PrepareForMultiHeadAttention</span>
|
||||||
<span class="lineno">92</span><span class="kn">from</span> <span class="nn">labml_nn.utils</span> <span class="kn">import</span> <span class="n">clone_module_list</span></pre></div>
|
<span class="lineno">98</span><span class="kn">from</span> <span class="nn">labml_nn.utils</span> <span class="kn">import</span> <span class="n">clone_module_list</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-1'>
|
<div class='section' id='section-1'>
|
||||||
@ -183,7 +187,7 @@ unless $k^{(i)}$ and $k^{(j)}$ are very similar.</p>
|
|||||||
<p><em>Check the paper for derivation.</em></p>
|
<p><em>Check the paper for derivation.</em></p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">95</span><span class="k">class</span> <span class="nc">DPFP</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">101</span><span class="k">class</span> <span class="nc">DPFP</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-2'>
|
<div class='section' id='section-2'>
|
||||||
@ -197,7 +201,7 @@ unless $k^{(i)}$ and $k^{(j)}$ are very similar.</p>
|
|||||||
</ul>
|
</ul>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">129</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">nu</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-6</span><span class="p">):</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">135</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">nu</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-6</span><span class="p">):</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-3'>
|
<div class='section' id='section-3'>
|
||||||
@ -208,10 +212,10 @@ unless $k^{(i)}$ and $k^{(j)}$ are very similar.</p>
|
|||||||
|
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">134</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
<div class="highlight"><pre><span class="lineno">140</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||||
<span class="lineno">135</span> <span class="bp">self</span><span class="o">.</span><span class="n">nu</span> <span class="o">=</span> <span class="n">nu</span>
|
<span class="lineno">141</span> <span class="bp">self</span><span class="o">.</span><span class="n">nu</span> <span class="o">=</span> <span class="n">nu</span>
|
||||||
<span class="lineno">136</span> <span class="bp">self</span><span class="o">.</span><span class="n">relu</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span>
|
<span class="lineno">142</span> <span class="bp">self</span><span class="o">.</span><span class="n">relu</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span>
|
||||||
<span class="lineno">137</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span></pre></div>
|
<span class="lineno">143</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-4'>
|
<div class='section' id='section-4'>
|
||||||
@ -222,7 +226,7 @@ unless $k^{(i)}$ and $k^{(j)}$ are very similar.</p>
|
|||||||
|
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">139</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">145</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-5'>
|
<div class='section' id='section-5'>
|
||||||
@ -233,7 +237,7 @@ unless $k^{(i)}$ and $k^{(j)}$ are very similar.</p>
|
|||||||
<p>Get $\color{lightgreen}{\phi(k)}$</p>
|
<p>Get $\color{lightgreen}{\phi(k)}$</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">141</span> <span class="n">k</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dpfp</span><span class="p">(</span><span class="n">k</span><span class="p">)</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">147</span> <span class="n">k</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dpfp</span><span class="p">(</span><span class="n">k</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-6'>
|
<div class='section' id='section-6'>
|
||||||
@ -244,7 +248,7 @@ unless $k^{(i)}$ and $k^{(j)}$ are very similar.</p>
|
|||||||
<p>Normalize by $\sum^{d_{dot}}_{j=1} \color{lightgreen}{\phi(k)_j}$</p>
|
<p>Normalize by $\sum^{d_{dot}}_{j=1} \color{lightgreen}{\phi(k)_j}$</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">143</span> <span class="k">return</span> <span class="n">k</span> <span class="o">/</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">149</span> <span class="k">return</span> <span class="n">k</span> <span class="o">/</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-7'>
|
<div class='section' id='section-7'>
|
||||||
@ -257,7 +261,7 @@ unless $k^{(i)}$ and $k^{(j)}$ are very similar.</p>
|
|||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">145</span> <span class="k">def</span> <span class="nf">dpfp</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">151</span> <span class="k">def</span> <span class="nf">dpfp</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-8'>
|
<div class='section' id='section-8'>
|
||||||
@ -268,7 +272,7 @@ unless $k^{(i)}$ and $k^{(j)}$ are very similar.</p>
|
|||||||
<p>$x = \text{ReLU}\Big(\big[k, -k\big]\Big)$</p>
|
<p>$x = \text{ReLU}\Big(\big[k, -k\big]\Big)$</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">150</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">k</span><span class="p">,</span> <span class="o">-</span><span class="n">k</span><span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">156</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">k</span><span class="p">,</span> <span class="o">-</span><span class="n">k</span><span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-9'>
|
<div class='section' id='section-9'>
|
||||||
@ -281,7 +285,7 @@ to get <script type="math/tex; mode=display">x'_{i,j} = \text{ReLU}\Big(\big[k,
|
|||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">153</span> <span class="n">x_rolled</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">roll</span><span class="p">(</span><span class="n">shifts</span><span class="o">=</span><span class="n">i</span><span class="p">,</span> <span class="n">dims</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nu</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">159</span> <span class="n">x_rolled</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">roll</span><span class="p">(</span><span class="n">shifts</span><span class="o">=</span><span class="n">i</span><span class="p">,</span> <span class="n">dims</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nu</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-10'>
|
<div class='section' id='section-10'>
|
||||||
@ -294,7 +298,7 @@ to get <script type="math/tex; mode=display">x'_{i,j} = \text{ReLU}\Big(\big[k,
|
|||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">156</span> <span class="n">x_rolled</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">x_rolled</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">162</span> <span class="n">x_rolled</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">x_rolled</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-11'>
|
<div class='section' id='section-11'>
|
||||||
@ -305,7 +309,7 @@ to get <script type="math/tex; mode=display">x'_{i,j} = \text{ReLU}\Big(\big[k,
|
|||||||
<p>Concatenate copies of $x$</p>
|
<p>Concatenate copies of $x$</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">158</span> <span class="n">x_repeat</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">x</span><span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">nu</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">164</span> <span class="n">x_repeat</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">x</span><span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">nu</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-12'>
|
<div class='section' id='section-12'>
|
||||||
@ -320,7 +324,7 @@ to get <script type="math/tex; mode=display">x'_{i,j} = \text{ReLU}\Big(\big[k,
|
|||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">164</span> <span class="k">return</span> <span class="n">x_repeat</span> <span class="o">*</span> <span class="n">x_rolled</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">170</span> <span class="k">return</span> <span class="n">x_repeat</span> <span class="o">*</span> <span class="n">x_rolled</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-13'>
|
<div class='section' id='section-13'>
|
||||||
@ -352,7 +356,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Note that we don’t need the normalization term $z$ because $\color{lightgreen}{\phi’}$ is normalized.</p>
|
<p>Note that we don’t need the normalization term $z$ because $\color{lightgreen}{\phi’}$ is normalized.</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">167</span><span class="k">class</span> <span class="nc">FastWeightsAttention</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">173</span><span class="k">class</span> <span class="nc">FastWeightsAttention</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-14'>
|
<div class='section' id='section-14'>
|
||||||
@ -363,8 +367,8 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
|
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">195</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">phi</span><span class="p">:</span> <span class="n">DPFP</span><span class="p">):</span>
|
<div class="highlight"><pre><span class="lineno">201</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">phi</span><span class="p">:</span> <span class="n">DPFP</span><span class="p">):</span>
|
||||||
<span class="lineno">196</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
<span class="lineno">202</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-15'>
|
<div class='section' id='section-15'>
|
||||||
@ -375,7 +379,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Number of features per head $d_k$</p>
|
<p>Number of features per head $d_k$</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">199</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">=</span> <span class="n">d_model</span> <span class="o">//</span> <span class="n">heads</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">205</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">=</span> <span class="n">d_model</span> <span class="o">//</span> <span class="n">heads</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-16'>
|
<div class='section' id='section-16'>
|
||||||
@ -386,7 +390,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Number of heads</p>
|
<p>Number of heads</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">201</span> <span class="bp">self</span><span class="o">.</span><span class="n">heads</span> <span class="o">=</span> <span class="n">heads</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">207</span> <span class="bp">self</span><span class="o">.</span><span class="n">heads</span> <span class="o">=</span> <span class="n">heads</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-17'>
|
<div class='section' id='section-17'>
|
||||||
@ -397,9 +401,9 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>These transform the <code>query</code>, <code>key</code> and <code>value</code> multi-headed attention.</p>
|
<p>These transform the <code>query</code>, <code>key</code> and <code>value</code> multi-headed attention.</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">204</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span> <span class="o">=</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
<div class="highlight"><pre><span class="lineno">210</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span> <span class="o">=</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||||
<span class="lineno">205</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span> <span class="o">=</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
<span class="lineno">211</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span> <span class="o">=</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||||||
<span class="lineno">206</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
|
<span class="lineno">212</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-18'>
|
<div class='section' id='section-18'>
|
||||||
@ -410,10 +414,10 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Interpolation weight function $\sigma \Big(\color{orange}{W_\beta} x^{(i)} \Big)$ for each head</p>
|
<p>Interpolation weight function $\sigma \Big(\color{orange}{W_\beta} x^{(i)} \Big)$ for each head</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">209</span> <span class="bp">self</span><span class="o">.</span><span class="n">interpolation_weight</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
|
<div class="highlight"><pre><span class="lineno">215</span> <span class="bp">self</span><span class="o">.</span><span class="n">interpolation_weight</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
|
||||||
<span class="lineno">210</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span>
|
<span class="lineno">216</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span>
|
||||||
<span class="lineno">211</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sigmoid</span><span class="p">()</span>
|
<span class="lineno">217</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sigmoid</span><span class="p">()</span>
|
||||||
<span class="lineno">212</span> <span class="p">)</span></pre></div>
|
<span class="lineno">218</span> <span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-19'>
|
<div class='section' id='section-19'>
|
||||||
@ -424,7 +428,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>$\color{lightgreen}{\phi’}$</p>
|
<p>$\color{lightgreen}{\phi’}$</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">215</span> <span class="bp">self</span><span class="o">.</span><span class="n">phi</span> <span class="o">=</span> <span class="n">phi</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">221</span> <span class="bp">self</span><span class="o">.</span><span class="n">phi</span> <span class="o">=</span> <span class="n">phi</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-20'>
|
<div class='section' id='section-20'>
|
||||||
@ -435,7 +439,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Output layer</p>
|
<p>Output layer</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">218</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">224</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-21'>
|
<div class='section' id='section-21'>
|
||||||
@ -446,7 +450,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Dropout</p>
|
<p>Dropout</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">220</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">226</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-22'>
|
<div class='section' id='section-22'>
|
||||||
@ -457,7 +461,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
|
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">222</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">228</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-23'>
|
<div class='section' id='section-23'>
|
||||||
@ -468,7 +472,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Get the number of steps $L$</p>
|
<p>Get the number of steps $L$</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">224</span> <span class="n">seq_len</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">230</span> <span class="n">seq_len</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-24'>
|
<div class='section' id='section-24'>
|
||||||
@ -479,7 +483,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>$\color{lightgreen}{\phi’(q^{(i)})}$ for all steps and heads</p>
|
<p>$\color{lightgreen}{\phi’(q^{(i)})}$ for all steps and heads</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">226</span> <span class="n">query</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">phi</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">232</span> <span class="n">query</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">phi</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-25'>
|
<div class='section' id='section-25'>
|
||||||
@ -490,7 +494,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>$\color{lightgreen}{\phi’(k^{(i)})}$ for all steps and heads</p>
|
<p>$\color{lightgreen}{\phi’(k^{(i)})}$ for all steps and heads</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">228</span> <span class="n">key</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">phi</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">234</span> <span class="n">key</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">phi</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-26'>
|
<div class='section' id='section-26'>
|
||||||
@ -501,7 +505,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>$v^{(i)}$ for all steps and heads</p>
|
<p>$v^{(i)}$ for all steps and heads</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">230</span> <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">236</span> <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-27'>
|
<div class='section' id='section-27'>
|
||||||
@ -512,7 +516,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>$\beta^{(i)}$ for all steps and heads</p>
|
<p>$\beta^{(i)}$ for all steps and heads</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">232</span> <span class="n">beta</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">interpolation_weight</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">238</span> <span class="n">beta</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">interpolation_weight</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-28'>
|
<div class='section' id='section-28'>
|
||||||
@ -523,7 +527,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>$\color{cyan}{W^{(0)}}$</p>
|
<p>$\color{cyan}{W^{(0)}}$</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">235</span> <span class="n">weights</span> <span class="o">=</span> <span class="n">key</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">((</span><span class="n">key</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">key</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">value</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="n">key</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">]))</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">241</span> <span class="n">weights</span> <span class="o">=</span> <span class="n">key</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">((</span><span class="n">key</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">key</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">value</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="n">key</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">]))</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-29'>
|
<div class='section' id='section-29'>
|
||||||
@ -534,7 +538,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>List to store outputs $y^{(i)}$</p>
|
<p>List to store outputs $y^{(i)}$</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">237</span> <span class="n">outputs</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">243</span> <span class="n">outputs</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-30'>
|
<div class='section' id='section-30'>
|
||||||
@ -545,7 +549,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Iterate through steps</p>
|
<p>Iterate through steps</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">240</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">seq_len</span><span class="p">):</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">246</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">seq_len</span><span class="p">):</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-31'>
|
<div class='section' id='section-31'>
|
||||||
@ -558,7 +562,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">242</span> <span class="n">value_existing</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bhvk,bhk->bhv'</span><span class="p">,</span> <span class="n">weights</span><span class="p">,</span> <span class="n">key</span><span class="p">[</span><span class="n">i</span><span class="p">])</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">248</span> <span class="n">value_existing</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bhvk,bhk->bhv'</span><span class="p">,</span> <span class="n">weights</span><span class="p">,</span> <span class="n">key</span><span class="p">[</span><span class="n">i</span><span class="p">])</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-32'>
|
<div class='section' id='section-32'>
|
||||||
@ -573,7 +577,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">247</span> <span class="n">weights</span> <span class="o">=</span> <span class="n">weights</span> <span class="o">+</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bhv,bhk->bhvk'</span><span class="p">,</span> <span class="n">beta</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="n">value</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="n">value_existing</span><span class="p">),</span> <span class="n">key</span><span class="p">[</span><span class="n">i</span><span class="p">])</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">253</span> <span class="n">weights</span> <span class="o">=</span> <span class="n">weights</span> <span class="o">+</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bhv,bhk->bhvk'</span><span class="p">,</span> <span class="n">beta</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="n">value</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="n">value_existing</span><span class="p">),</span> <span class="n">key</span><span class="p">[</span><span class="n">i</span><span class="p">])</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-33'>
|
<div class='section' id='section-33'>
|
||||||
@ -586,7 +590,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">250</span> <span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bhvk,bhk->bhv'</span><span class="p">,</span> <span class="n">weights</span><span class="p">,</span> <span class="n">query</span><span class="p">[</span><span class="n">i</span><span class="p">])</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">256</span> <span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bhvk,bhk->bhv'</span><span class="p">,</span> <span class="n">weights</span><span class="p">,</span> <span class="n">query</span><span class="p">[</span><span class="n">i</span><span class="p">])</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-34'>
|
<div class='section' id='section-34'>
|
||||||
@ -597,7 +601,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Merge multiple heads and append to <code>outputs</code></p>
|
<p>Merge multiple heads and append to <code>outputs</code></p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">253</span> <span class="n">outputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">259</span> <span class="n">outputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-35'>
|
<div class='section' id='section-35'>
|
||||||
@ -608,7 +612,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Stack outputs at each step into a single tensor</p>
|
<p>Stack outputs at each step into a single tensor</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">256</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">outputs</span><span class="p">)</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">262</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">outputs</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-36'>
|
<div class='section' id='section-36'>
|
||||||
@ -619,7 +623,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Output layer</p>
|
<p>Output layer</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">259</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">265</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-37'>
|
<div class='section' id='section-37'>
|
||||||
@ -630,7 +634,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>This is a general transformer layer that combines self attention and feedforward network.</p>
|
<p>This is a general transformer layer that combines self attention and feedforward network.</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">262</span><span class="k">class</span> <span class="nc">FastWeightsAttentionTransformerLayer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">268</span><span class="k">class</span> <span class="nc">FastWeightsAttentionTransformerLayer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-38'>
|
<div class='section' id='section-38'>
|
||||||
@ -641,12 +645,12 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
|
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">266</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
<div class="highlight"><pre><span class="lineno">272</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
||||||
<span class="lineno">267</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
<span class="lineno">273</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||||||
<span class="lineno">268</span> <span class="n">attn</span><span class="p">:</span> <span class="n">FastWeightsAttention</span><span class="p">,</span>
|
<span class="lineno">274</span> <span class="n">attn</span><span class="p">:</span> <span class="n">FastWeightsAttention</span><span class="p">,</span>
|
||||||
<span class="lineno">269</span> <span class="n">feed_forward</span><span class="p">:</span> <span class="n">FeedForward</span><span class="p">,</span>
|
<span class="lineno">275</span> <span class="n">feed_forward</span><span class="p">:</span> <span class="n">FeedForward</span><span class="p">,</span>
|
||||||
<span class="lineno">270</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
|
<span class="lineno">276</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
|
||||||
<span class="lineno">271</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
<span class="lineno">277</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-39'>
|
<div class='section' id='section-39'>
|
||||||
@ -657,7 +661,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Transformer size $d_{model}$</p>
|
<p>Transformer size $d_{model}$</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">273</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">d_model</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">279</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">d_model</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-40'>
|
<div class='section' id='section-40'>
|
||||||
@ -668,9 +672,9 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
|
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">275</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span>
|
<div class="highlight"><pre><span class="lineno">281</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span>
|
||||||
<span class="lineno">276</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span>
|
<span class="lineno">282</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span>
|
||||||
<span class="lineno">277</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span></pre></div>
|
<span class="lineno">283</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-41'>
|
<div class='section' id='section-41'>
|
||||||
@ -681,8 +685,8 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Normalization layers</p>
|
<p>Normalization layers</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">280</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
|
<div class="highlight"><pre><span class="lineno">286</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
|
||||||
<span class="lineno">281</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span></pre></div>
|
<span class="lineno">287</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-42'>
|
<div class='section' id='section-42'>
|
||||||
@ -693,8 +697,8 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
|
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">283</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
<div class="highlight"><pre><span class="lineno">289</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
||||||
<span class="lineno">284</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
<span class="lineno">290</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-43'>
|
<div class='section' id='section-43'>
|
||||||
@ -705,7 +709,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Add the self attention results</p>
|
<p>Add the self attention results</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">286</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">292</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-44'>
|
<div class='section' id='section-44'>
|
||||||
@ -716,7 +720,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Normalize for feed-forward</p>
|
<p>Normalize for feed-forward</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">289</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">295</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-45'>
|
<div class='section' id='section-45'>
|
||||||
@ -727,7 +731,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Pass through the feed-forward network</p>
|
<p>Pass through the feed-forward network</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">291</span> <span class="n">ff</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">297</span> <span class="n">ff</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-46'>
|
<div class='section' id='section-46'>
|
||||||
@ -738,7 +742,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Add the feed-forward results back</p>
|
<p>Add the feed-forward results back</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">293</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">ff</span><span class="p">)</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">299</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">ff</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-47'>
|
<div class='section' id='section-47'>
|
||||||
@ -749,7 +753,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
|
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">296</span> <span class="k">return</span> <span class="n">x</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">302</span> <span class="k">return</span> <span class="n">x</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-48'>
|
<div class='section' id='section-48'>
|
||||||
@ -760,7 +764,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>This is a general transformer module with multiple transformer layers</p>
|
<p>This is a general transformer module with multiple transformer layers</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">299</span><span class="k">class</span> <span class="nc">FastWeightsAttentionTransformer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">305</span><span class="k">class</span> <span class="nc">FastWeightsAttentionTransformer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-49'>
|
<div class='section' id='section-49'>
|
||||||
@ -771,8 +775,8 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
|
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">303</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">FastWeightsAttentionTransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
<div class="highlight"><pre><span class="lineno">309</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">FastWeightsAttentionTransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
||||||
<span class="lineno">304</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
<span class="lineno">310</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-50'>
|
<div class='section' id='section-50'>
|
||||||
@ -783,7 +787,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Make copies of the transformer layer</p>
|
<p>Make copies of the transformer layer</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">306</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">312</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-51'>
|
<div class='section' id='section-51'>
|
||||||
@ -794,7 +798,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Final normalization layer</p>
|
<p>Final normalization layer</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">308</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">314</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-52'>
|
<div class='section' id='section-52'>
|
||||||
@ -805,8 +809,8 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
|
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">310</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
<div class="highlight"><pre><span class="lineno">316</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
||||||
<span class="lineno">311</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">):</span></pre></div>
|
<span class="lineno">317</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">):</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-53'>
|
<div class='section' id='section-53'>
|
||||||
@ -817,7 +821,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Get layer output</p>
|
<p>Get layer output</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">313</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">319</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class='section' id='section-54'>
|
<div class='section' id='section-54'>
|
||||||
@ -828,7 +832,7 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
|
|||||||
<p>Normalize the output</p>
|
<p>Normalize the output</p>
|
||||||
</div>
|
</div>
|
||||||
<div class='code'>
|
<div class='code'>
|
||||||
<div class="highlight"><pre><span class="lineno">316</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
<div class="highlight"><pre><span class="lineno">322</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
@ -81,6 +81,12 @@ This is quite similar to fast weights.
|
|||||||
The paper introduces a new linear attention projection function $\color{lightgreen}{\phi}$
|
The paper introduces a new linear attention projection function $\color{lightgreen}{\phi}$
|
||||||
a new update rule for $\color{cyan}{W^{(i)}} = f(\color{cyan}{W^{(i-1)}})$ and change the normalization
|
a new update rule for $\color{cyan}{W^{(i)}} = f(\color{cyan}{W^{(i-1)}})$ and change the normalization
|
||||||
$\frac{1}{z^{(i)} \cdot \color{lightgreen}{\phi(q^{(i)})}}$
|
$\frac{1}{z^{(i)} \cdot \color{lightgreen}{\phi(q^{(i)})}}$
|
||||||
|
|
||||||
|
Here's [the training code](experiment.html) and a notebook for training a fast weights
|
||||||
|
transformer on Tiny Shakespeare dataset.
|
||||||
|
|
||||||
|
[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/fast_weights/experiment.ipynb)
|
||||||
|
[](https://app.labml.ai/run/928aadc0846c11eb85710242ac1c0002)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -250,7 +256,7 @@ class FastWeightsAttention(Module):
|
|||||||
y = torch.einsum('bhvk,bhk->bhv', weights, query[i])
|
y = torch.einsum('bhvk,bhk->bhv', weights, query[i])
|
||||||
|
|
||||||
# Merge multiple heads and append to `outputs`
|
# Merge multiple heads and append to `outputs`
|
||||||
outputs.append(x.reshape(y.shape[0], -1))
|
outputs.append(y.reshape(y.shape[0], -1))
|
||||||
|
|
||||||
# Stack outputs at each step into a single tensor
|
# Stack outputs at each step into a single tensor
|
||||||
x = torch.stack(outputs)
|
x = torch.stack(outputs)
|
||||||
|
@ -3,6 +3,15 @@
|
|||||||
title: Train Fast Weights Transformer
|
title: Train Fast Weights Transformer
|
||||||
summary: This is training code with notes for a Fast Weights Transformer.
|
summary: This is training code with notes for a Fast Weights Transformer.
|
||||||
---
|
---
|
||||||
|
|
||||||
|
# Train Fast Weights Transformer
|
||||||
|
|
||||||
|
This trains a fast weights transformer model for auto-regression.
|
||||||
|
|
||||||
|
Here’s a Colab notebook for training a fast weights transformer on Tiny Shakespeare dataset.
|
||||||
|
|
||||||
|
[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/fast_weights/experiment.ipynb)
|
||||||
|
[](https://app.labml.ai/run/928aadc0846c11eb85710242ac1c0002)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
2
setup.py
2
setup.py
@ -5,7 +5,7 @@ with open("readme.md", "r") as f:
|
|||||||
|
|
||||||
setuptools.setup(
|
setuptools.setup(
|
||||||
name='labml-nn',
|
name='labml-nn',
|
||||||
version='0.4.89',
|
version='0.4.91',
|
||||||
author="Varuna Jayasiri, Nipun Wijerathne",
|
author="Varuna Jayasiri, Nipun Wijerathne",
|
||||||
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
|
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
|
||||||
description="A collection of PyTorch implementations of neural network architectures and layers.",
|
description="A collection of PyTorch implementations of neural network architectures and layers.",
|
||||||
|
Reference in New Issue
Block a user