mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-26 16:50:39 +08:00
RoPER (#126)
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -16,3 +16,4 @@ html/
|
||||
diagrams/
|
||||
.comet.config
|
||||
settings.md
|
||||
labml_app.log
|
@ -19,3 +19,4 @@ indicators:
|
||||
name: optim.*
|
||||
options:
|
||||
comet: false
|
||||
web_api: http://localhost:5005/api/v1/track?
|
||||
|
900
docs/experiments/arithmetic_dataset.html
Normal file
900
docs/experiments/arithmetic_dataset.html
Normal file
@ -0,0 +1,900 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content="This creates arithmetic problems."/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="Arithmetic Dataset"/>
|
||||
<meta name="twitter:description" content="This creates arithmetic problems."/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/experiments/arithmetic_dataset.html"/>
|
||||
<meta property="og:title" content="Arithmetic Dataset"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="Arithmetic Dataset"/>
|
||||
<meta property="og:description" content="This creates arithmetic problems."/>
|
||||
|
||||
<title>Arithmetic Dataset</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../pylit.css?v=1">
|
||||
<link rel="canonical" href="https://nn.labml.ai/experiments/arithmetic_dataset.html"/>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
||||
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="index.html">experiments</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/experiments/arithmetic_dataset.py">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<p><em>This is based on code by <a href="https://twitter.com/gharik">Georges Harik (@gharik)</a>.</em></p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">11</span><span></span><span class="kn">import</span> <span class="nn">random</span>
|
||||
<span class="lineno">12</span><span class="kn">import</span> <span class="nn">string</span>
|
||||
<span class="lineno">13</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span>
|
||||
<span class="lineno">14</span>
|
||||
<span class="lineno">15</span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">Text</span>
|
||||
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">Dataset</span>
|
||||
<span class="lineno">18</span>
|
||||
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">monit</span><span class="p">,</span> <span class="n">logger</span><span class="p">,</span> <span class="n">tracker</span>
|
||||
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
|
||||
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.nlp_autoregression</span> <span class="kn">import</span> <span class="n">NLPAutoRegressionConfigs</span><span class="p">,</span> <span class="n">transpose_batch</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<h2>Arithmetic Dataset</h2>
|
||||
<p>This creates arithmetic addition problems and solutions with workings. We've only implemented addition so far.</p>
|
||||
<p>It's based on a character level tokenization.</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">24</span><span class="k">class</span> <span class="nc">ArithmeticDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
<ul><li><code>seq_len</code> is the sequence length of generated math problems. We fill as many problems as possible upto this length :max_digits: is the maximum number of digits in the operand integers :n_sequences: is the number of sequences per epoch</li></ul>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">34</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">max_digits</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_sequences</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">41</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_sequences</span> <span class="o">=</span> <span class="n">n_sequences</span>
|
||||
<span class="lineno">42</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_digits</span> <span class="o">=</span> <span class="n">max_digits</span>
|
||||
<span class="lineno">43</span> <span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span> <span class="o">=</span> <span class="n">seq_len</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
<p>Token id to string </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">45</span> <span class="bp">self</span><span class="o">.</span><span class="n">itos</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">string</span><span class="o">.</span><span class="n">digits</span> <span class="o">+</span> <span class="s1">'xe =</span><span class="se">\n</span><span class="s1">?+;'</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
<p>Character to token id </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">47</span> <span class="bp">self</span><span class="o">.</span><span class="n">stoi</span> <span class="o">=</span> <span class="p">{</span><span class="n">c</span><span class="p">:</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">c</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">itos</span><span class="p">)}</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
<p> Generates an integer with <code class="highlight"><span></span><span class="n">n_digit</span></code>
|
||||
number of digits</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">49</span> <span class="nd">@staticmethod</span>
|
||||
<span class="lineno">50</span> <span class="k">def</span> <span class="nf">make_int</span><span class="p">(</span><span class="n">n_digits</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">54</span> <span class="n">res</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="lineno">55</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_digits</span><span class="p">):</span>
|
||||
<span class="lineno">56</span> <span class="n">d</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randrange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">random</span><span class="o">.</span><span class="n">randrange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span>
|
||||
<span class="lineno">57</span> <span class="n">res</span> <span class="o">=</span> <span class="n">res</span> <span class="o">*</span> <span class="mi">10</span> <span class="o">+</span> <span class="n">d</span>
|
||||
<span class="lineno">58</span>
|
||||
<span class="lineno">59</span> <span class="k">return</span> <span class="n">res</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-8'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-8'>#</a>
|
||||
</div>
|
||||
<p> Generates the workings for <code class="highlight"><span></span><span class="n">x</span> <span class="o">+</span> <span class="n">y</span></code>
|
||||
. For example for <code class="highlight"><span></span><span class="mi">11</span><span class="o">+</span><span class="mi">29</span></code>
|
||||
it generates <code class="highlight"><span></span><span class="mf">1e0</span><span class="o">+</span><span class="mf">9e0</span><span class="o">+</span><span class="mf">0e0</span><span class="o">=</span><span class="mf">10e0</span> <span class="mf">1e0</span><span class="o">+</span><span class="mf">2e0</span><span class="o">+</span><span class="mf">1e0</span><span class="o">=</span><span class="mf">4e0</span></code>
|
||||
.</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">61</span> <span class="nd">@staticmethod</span>
|
||||
<span class="lineno">62</span> <span class="k">def</span> <span class="nf">get_add_explanation</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-9'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-9'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">69</span> <span class="n">carry</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="lineno">70</span> <span class="n">e</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="lineno">71</span> <span class="n">explanation</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="lineno">72</span> <span class="k">while</span> <span class="n">x</span> <span class="o">></span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">y</span> <span class="o">></span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">carry</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
|
||||
<span class="lineno">73</span> <span class="n">rx</span><span class="p">,</span> <span class="n">ry</span> <span class="o">=</span> <span class="n">x</span> <span class="o">%</span> <span class="mi">10</span><span class="p">,</span> <span class="n">y</span> <span class="o">%</span> <span class="mi">10</span>
|
||||
<span class="lineno">74</span> <span class="n">total</span> <span class="o">=</span> <span class="n">rx</span> <span class="o">+</span> <span class="n">ry</span> <span class="o">+</span> <span class="n">carry</span>
|
||||
<span class="lineno">75</span> <span class="n">explanation</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">rx</span><span class="si">}</span><span class="s2">e</span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s2">+</span><span class="si">{</span><span class="n">ry</span><span class="si">}</span><span class="s2">e</span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s2">+</span><span class="si">{</span><span class="n">carry</span><span class="si">}</span><span class="s2">e</span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s2">==</span><span class="si">{</span><span class="n">total</span><span class="si">}</span><span class="s2">e</span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||||
<span class="lineno">76</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">carry</span> <span class="o">=</span> <span class="n">x</span> <span class="o">//</span> <span class="mi">10</span><span class="p">,</span> <span class="n">y</span> <span class="o">//</span> <span class="mi">10</span><span class="p">,</span> <span class="n">total</span> <span class="o">//</span> <span class="mi">10</span>
|
||||
<span class="lineno">77</span> <span class="n">e</span> <span class="o">+=</span> <span class="mi">1</span>
|
||||
<span class="lineno">78</span>
|
||||
<span class="lineno">79</span> <span class="k">return</span> <span class="s1">' '</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">explanation</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-10'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
<p>Make a problem with a pre_explanation or not</p>
|
||||
<p>Creates an arithmetic addition problem with workings and answer.</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">82</span> <span class="k">def</span> <span class="nf">make_add_problem</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-11'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">86</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">make_int</span><span class="p">(</span><span class="n">n_digits</span><span class="o">=</span><span class="n">random</span><span class="o">.</span><span class="n">randrange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_digits</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
|
||||
<span class="lineno">87</span> <span class="n">y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">make_int</span><span class="p">(</span><span class="n">n_digits</span><span class="o">=</span><span class="n">random</span><span class="o">.</span><span class="n">randrange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_digits</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
|
||||
<span class="lineno">88</span>
|
||||
<span class="lineno">89</span> <span class="n">explanation</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_add_explanation</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
|
||||
<span class="lineno">90</span> <span class="k">return</span> <span class="sa">f</span><span class="s2">"x=</span><span class="si">{</span><span class="n">x</span><span class="si">}</span><span class="s2">+</span><span class="si">{</span><span class="n">y</span><span class="si">}</span><span class="s2">; </span><span class="si">{</span><span class="n">explanation</span><span class="si">}</span><span class="s2"> x==</span><span class="si">{</span><span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="si">}</span><span class="se">\n</span><span class="s2">"</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-12'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
<p> Get arithmetic problem and answer. This is used for evaluation.</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">92</span> <span class="k">def</span> <span class="nf">get_qa</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-13'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-13'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">96</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">make_int</span><span class="p">(</span><span class="n">n_digits</span><span class="o">=</span><span class="n">random</span><span class="o">.</span><span class="n">randrange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_digits</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
|
||||
<span class="lineno">97</span> <span class="n">y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">make_int</span><span class="p">(</span><span class="n">n_digits</span><span class="o">=</span><span class="n">random</span><span class="o">.</span><span class="n">randrange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_digits</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
|
||||
<span class="lineno">98</span>
|
||||
<span class="lineno">99</span> <span class="k">return</span> <span class="sa">f</span><span class="s1">'x=</span><span class="si">{</span><span class="n">x</span><span class="si">}</span><span class="s1">+</span><span class="si">{</span><span class="n">y</span><span class="si">}</span><span class="s1">;'</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="si">}</span><span class="s1">'</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-14'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-14'>#</a>
|
||||
</div>
|
||||
<p> Generate multiple problems and pack them into a sequence.</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">101</span> <span class="k">def</span> <span class="nf">get_packed_math_input</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-15'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-15'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">105</span> <span class="n">s_enc</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="lineno">106</span> <span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">s_enc</span><span class="p">)</span> <span class="o"><=</span> <span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span><span class="p">:</span>
|
||||
<span class="lineno">107</span> <span class="n">s_part</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">make_add_problem</span><span class="p">()</span>
|
||||
<span class="lineno">108</span> <span class="n">s_part_enc</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="s1">'?'</span> <span class="o">+</span> <span class="n">s_part</span><span class="p">)</span>
|
||||
<span class="lineno">109</span> <span class="n">s_enc</span> <span class="o">=</span> <span class="n">s_enc</span> <span class="o">+</span> <span class="n">s_part_enc</span>
|
||||
<span class="lineno">110</span> <span class="k">return</span> <span class="n">s_enc</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-16'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-16'>#</a>
|
||||
</div>
|
||||
<p> Encode a given string</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">112</span> <span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">s</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-17'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-17'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">116</span> <span class="k">return</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">stoi</span><span class="p">[</span><span class="n">c</span><span class="p">]</span> <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="n">s</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-18'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-18'>#</a>
|
||||
</div>
|
||||
<p> Decode a list of token ids</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">118</span> <span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">arr</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-19'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-19'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">122</span> <span class="k">return</span> <span class="s1">''</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">c</span><span class="p">]</span> <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="n">arr</span><span class="p">])</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-20'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-20'>#</a>
|
||||
</div>
|
||||
<p> Get a input and target pair for auto-regressive modelling</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">124</span> <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-21'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-21'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">128</span> <span class="n">s</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_packed_math_input</span><span class="p">())</span>
|
||||
<span class="lineno">129</span> <span class="k">return</span> <span class="n">s</span><span class="p">[:</span><span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span><span class="p">],</span> <span class="n">s</span><span class="p">[</span><span class="mi">1</span><span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-22'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-22'>#</a>
|
||||
</div>
|
||||
<p> Number of sequences per epoch</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">131</span> <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-23'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-23'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">135</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_sequences</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-24'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-24'>#</a>
|
||||
</div>
|
||||
<h2>Arithmetic Task Experiment Configurations</h2>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">138</span><span class="k">class</span> <span class="nc">ArithmeticAutoregression</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-25'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-25'>#</a>
|
||||
</div>
|
||||
<p>Maximum number of digits per operand integer </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">143</span> <span class="n">max_digits</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-26'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-26'>#</a>
|
||||
</div>
|
||||
<p>Number of training sequences per epoch </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">145</span> <span class="n">train_sequences_per_epoch</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">**</span> <span class="mi">12</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-27'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-27'>#</a>
|
||||
</div>
|
||||
<p>Training data loader </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">147</span> <span class="n">train_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="s1">'arithmetic_train_loader'</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-28'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-28'>#</a>
|
||||
</div>
|
||||
<p>Number of problems in evaluation </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">149</span> <span class="n">n_tests</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-29'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-29'>#</a>
|
||||
</div>
|
||||
<p>No need of a validation dataset </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">151</span> <span class="n">validator</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-30'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-30'>#</a>
|
||||
</div>
|
||||
<p>Number of times to run evaluations per epoch </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">153</span> <span class="n">inner_iterations</span> <span class="o">=</span> <span class="mi">4</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-31'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-31'>#</a>
|
||||
</div>
|
||||
<p>Number of tokens in the vocabulary </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">155</span> <span class="n">n_tokens</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">ArithmeticDataset</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">itos</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-32'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-32'>#</a>
|
||||
</div>
|
||||
<h3>Evaluation</h3>
|
||||
<p>We use the sampling function to evaluate the model on a set of problems</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">157</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
|
||||
<span class="lineno">158</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-33'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-33'>#</a>
|
||||
</div>
|
||||
<p>Skip in the first epoch </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">166</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_loop</span><span class="o">.</span><span class="n">idx</span> <span class="o"><</span> <span class="mi">1</span><span class="p">:</span>
|
||||
<span class="lineno">167</span> <span class="k">return</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-34'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-34'>#</a>
|
||||
</div>
|
||||
<p>Create a dataset to generate problems </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">170</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">ArithmeticDataset</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_digits</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-35'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-35'>#</a>
|
||||
</div>
|
||||
<p>Get a set of problems and answers </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">172</span> <span class="n">qa</span> <span class="o">=</span> <span class="p">[</span><span class="n">dataset</span><span class="o">.</span><span class="n">get_qa</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_tests</span><span class="p">)]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-36'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-36'>#</a>
|
||||
</div>
|
||||
<p>Collect the problems only </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">174</span> <span class="n">questions</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">qa</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-37'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-37'>#</a>
|
||||
</div>
|
||||
<p>Create a tensor with only the initial token </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">177</span> <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="n">dataset</span><span class="o">.</span><span class="n">stoi</span><span class="p">[</span><span class="n">p</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">questions</span><span class="p">]])</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-38'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-38'>#</a>
|
||||
</div>
|
||||
<p>Move to device </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">179</span> <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-39'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-39'>#</a>
|
||||
</div>
|
||||
<p>Number of sequences that have completed </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">182</span> <span class="n">finished</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">questions</span><span class="p">),))</span><span class="o">.</span><span class="n">bool</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-40'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-40'>#</a>
|
||||
</div>
|
||||
<p>Token id of the new line character - this marks end of the answer </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">184</span> <span class="n">new_line</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">stoi</span><span class="p">[</span><span class="s1">'</span><span class="se">\n</span><span class="s1">'</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-41'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-41'>#</a>
|
||||
</div>
|
||||
<p>Sampled results </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">187</span> <span class="n">results</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">questions</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-42'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-42'>#</a>
|
||||
</div>
|
||||
<p>Sample upto sequence length </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">190</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">iterate</span><span class="p">(</span><span class="s1">'Sample'</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-43'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-43'>#</a>
|
||||
</div>
|
||||
<p>If all the sequences have completed we skip this </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">192</span> <span class="k">if</span> <span class="n">finished</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">finished</span><span class="p">):</span>
|
||||
<span class="lineno">193</span> <span class="k">continue</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-44'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-44'>#</a>
|
||||
</div>
|
||||
<p>Get the model output </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">196</span> <span class="n">output</span><span class="p">,</span> <span class="o">*</span><span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-45'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-45'>#</a>
|
||||
</div>
|
||||
<p>Get the model prediction (greedy) </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">198</span> <span class="n">output</span> <span class="o">=</span> <span class="n">output</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-46'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-46'>#</a>
|
||||
</div>
|
||||
<p>Find which sequences have finished </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">201</span> <span class="n">finished</span> <span class="o">=</span> <span class="n">finished</span> <span class="o">|</span> <span class="p">(</span><span class="n">output</span> <span class="o">==</span> <span class="n">new_line</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-47'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-47'>#</a>
|
||||
</div>
|
||||
<p>Skip if all have finished </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">203</span> <span class="k">if</span> <span class="n">finished</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">finished</span><span class="p">):</span>
|
||||
<span class="lineno">204</span> <span class="k">continue</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-48'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-48'>#</a>
|
||||
</div>
|
||||
<p>Override with the question </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">207</span> <span class="k">for</span> <span class="n">j</span><span class="p">,</span> <span class="n">p</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">questions</span><span class="p">):</span>
|
||||
<span class="lineno">208</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">p</span><span class="p">)</span> <span class="o">></span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:</span>
|
||||
<span class="lineno">209</span> <span class="n">output</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">stoi</span><span class="p">[</span><span class="n">p</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-49'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-49'>#</a>
|
||||
</div>
|
||||
<p>Add the next token to the input </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">212</span> <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">data</span><span class="p">,</span> <span class="n">output</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-50'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-50'>#</a>
|
||||
</div>
|
||||
<p>Get the sampled results </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">215</span> <span class="k">for</span> <span class="n">j</span><span class="p">,</span> <span class="n">c</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">output</span><span class="p">):</span>
|
||||
<span class="lineno">216</span> <span class="n">results</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">c</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-51'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-51'>#</a>
|
||||
</div>
|
||||
<p>Discard everything after the answer in the results </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">219</span> <span class="n">results</span> <span class="o">=</span> <span class="p">[</span><span class="n">r</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">'</span><span class="se">\n</span><span class="s1">'</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">results</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-52'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-52'>#</a>
|
||||
</div>
|
||||
<p>Log a sample </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">222</span> <span class="n">res_sample</span> <span class="o">=</span> <span class="n">results</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">';'</span><span class="p">)</span>
|
||||
<span class="lineno">223</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">([(</span><span class="n">res_sample</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">Text</span><span class="o">.</span><span class="n">key</span><span class="p">),</span> <span class="p">(</span><span class="s1">';'</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">subtle</span><span class="p">),</span> <span class="p">(</span><span class="s1">';'</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">res_sample</span><span class="p">[</span><span class="mi">1</span><span class="p">:]),</span> <span class="n">Text</span><span class="o">.</span><span class="n">none</span><span class="p">)])</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-53'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-53'>#</a>
|
||||
</div>
|
||||
<p>Get the answers </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">226</span> <span class="n">results</span> <span class="o">=</span> <span class="p">[</span><span class="n">r</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">'x=='</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">results</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-54'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-54'>#</a>
|
||||
</div>
|
||||
<p>Count the number of correct answers </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">229</span> <span class="n">correct</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="lineno">230</span> <span class="k">for</span> <span class="n">r</span><span class="p">,</span> <span class="n">_qa</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">results</span><span class="p">,</span> <span class="n">qa</span><span class="p">):</span>
|
||||
<span class="lineno">231</span> <span class="k">if</span> <span class="n">r</span> <span class="o">==</span> <span class="n">_qa</span><span class="p">[</span><span class="mi">1</span><span class="p">]:</span>
|
||||
<span class="lineno">232</span> <span class="n">correct</span> <span class="o">+=</span> <span class="mi">1</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-55'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-55'>#</a>
|
||||
</div>
|
||||
<p>Log the score </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">235</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s1">'score'</span><span class="p">,</span> <span class="n">correct</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">results</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-56'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-56'>#</a>
|
||||
</div>
|
||||
<p> Training data loader</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">238</span><span class="nd">@option</span><span class="p">(</span><span class="n">ArithmeticAutoregression</span><span class="o">.</span><span class="n">train_loader</span><span class="p">)</span>
|
||||
<span class="lineno">239</span><span class="k">def</span> <span class="nf">arithmetic_train_loader</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">ArithmeticAutoregression</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-57'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-57'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">243</span> <span class="k">return</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">ArithmeticDataset</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">max_digits</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">train_sequences_per_epoch</span><span class="p">),</span>
|
||||
<span class="lineno">244</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span>
|
||||
<span class="lineno">245</span> <span class="n">collate_fn</span><span class="o">=</span><span class="n">transpose_batch</span><span class="p">,</span>
|
||||
<span class="lineno">246</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-58'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-58'>#</a>
|
||||
</div>
|
||||
<p> Code to test generated problems</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">249</span><span class="k">def</span> <span class="nf">_test</span><span class="p">():</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-59'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-59'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">253</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">ArithmeticDataset</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
|
||||
<span class="lineno">254</span>
|
||||
<span class="lineno">255</span> <span class="nb">print</span><span class="p">(</span><span class="n">dataset</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">dataset</span><span class="o">.</span><span class="n">get_packed_math_input</span><span class="p">()))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-60'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-60'>#</a>
|
||||
</div>
|
||||
<p> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">259</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||
<span class="lineno">260</span> <span class="n">_test</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src=../interactive.js?v=1"></script>
|
||||
<script>
|
||||
function handleImages() {
|
||||
var images = document.querySelectorAll('p>img')
|
||||
|
||||
for (var i = 0; i < images.length; ++i) {
|
||||
handleImage(images[i])
|
||||
}
|
||||
}
|
||||
|
||||
function handleImage(img) {
|
||||
img.parentElement.style.textAlign = 'center'
|
||||
|
||||
var modal = document.createElement('div')
|
||||
modal.id = 'modal'
|
||||
|
||||
var modalContent = document.createElement('div')
|
||||
modal.appendChild(modalContent)
|
||||
|
||||
var modalImage = document.createElement('img')
|
||||
modalContent.appendChild(modalImage)
|
||||
|
||||
var span = document.createElement('span')
|
||||
span.classList.add('close')
|
||||
span.textContent = 'x'
|
||||
modal.appendChild(span)
|
||||
|
||||
img.onclick = function () {
|
||||
console.log('clicked')
|
||||
document.body.appendChild(modal)
|
||||
modalImage.src = img.src
|
||||
}
|
||||
|
||||
span.onclick = function () {
|
||||
document.body.removeChild(modal)
|
||||
}
|
||||
}
|
||||
|
||||
handleImages()
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
@ -70,7 +70,8 @@
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1><a href="index.html">DeepNorm</a> Experiment</h1>
|
||||
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/normalization/deep_norm/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://app.labml.ai/run/ec8e4dacb7f311ec8d1cd37d50b05c3d"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a></p>
|
||||
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/normalization/deep_norm/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://app.labml.ai/run/ec8e4dacb7f311ec8d1cd37d50b05c3d"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a> <a href="https://www.comet.ml/labml/deep-norm/61d817f80ff143c8825fba4aacd431d4?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&xAxis=step"><img alt="Open In Comet" src="https://images.labml.ai/images/comet.svg?experiment=deep_norm&file=experiment"></a></p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">15</span><span></span><span class="kn">import</span> <span class="nn">copy</span>
|
||||
|
@ -204,7 +204,7 @@
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/normalization/deep_norm/index.html</loc>
|
||||
<lastmod>2022-04-23T16:30:00+00:00</lastmod>
|
||||
<lastmod>2022-05-18T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
@ -244,6 +244,13 @@
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/experiments/arithmetic_dataset.html</loc>
|
||||
<lastmod>2022-06-02T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/experiments/index.html</loc>
|
||||
<lastmod>2020-12-26T16:30:00+00:00</lastmod>
|
||||
@ -603,14 +610,35 @@
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/transformers/rope/index.html</loc>
|
||||
<lastmod>2022-04-05T16:30:00+00:00</lastmod>
|
||||
<lastmod>2022-05-31T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/transformers/rope/value_pe/arithmetic_experiment.html</loc>
|
||||
<lastmod>2022-06-02T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/transformers/rope/value_pe/index.html</loc>
|
||||
<lastmod>2022-06-02T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/transformers/rope/value_pe/experiment.html</loc>
|
||||
<lastmod>2022-05-31T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/transformers/rope/experiment.html</loc>
|
||||
<lastmod>2022-03-12T16:30:00+00:00</lastmod>
|
||||
<lastmod>2022-05-31T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
@ -92,7 +92,7 @@
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">21</span><span class="k">def</span> <span class="nf">_rotary_pe_mha</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
|
||||
<span class="lineno">22</span> <span class="kn">from</span> <span class="nn">labml_nn.transformers.rope</span> <span class="kn">import</span> <span class="n">RotaryPEMultiHeadAttention</span>
|
||||
<span class="lineno">23</span> <span class="k">return</span> <span class="n">RotaryPEMultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span></pre></div>
|
||||
<span class="lineno">23</span> <span class="k">return</span> <span class="n">RotaryPEMultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="mf">1.</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
@ -157,7 +157,7 @@
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">46</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"rotary_pe_transformer"</span><span class="p">)</span></pre></div>
|
||||
<div class="highlight"><pre><span class="lineno">46</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"rotary_pe_transformer"</span><span class="p">,</span> <span class="n">writers</span><span class="o">=</span><span class="p">{</span><span class="s1">'screen'</span><span class="p">})</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
|
File diff suppressed because one or more lines are too long
403
docs/transformers/rope/value_pe/arithmetic_experiment.html
Normal file
403
docs/transformers/rope/value_pe/arithmetic_experiment.html
Normal file
@ -0,0 +1,403 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content="This experiment trains a transformer model with Rotary Positional Embeddings with Relative Distance (RoPER) on the arithmetic addition task."/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="Rotary Positional Embeddings with Relative distance (RoPER) Experiment"/>
|
||||
<meta name="twitter:description" content="This experiment trains a transformer model with Rotary Positional Embeddings with Relative Distance (RoPER) on the arithmetic addition task."/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/transformers/rope/value_pe/arithmetic_experiment.html"/>
|
||||
<meta property="og:title" content="Rotary Positional Embeddings with Relative distance (RoPER) Experiment"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="Rotary Positional Embeddings with Relative distance (RoPER) Experiment"/>
|
||||
<meta property="og:description" content="This experiment trains a transformer model with Rotary Positional Embeddings with Relative Distance (RoPER) on the arithmetic addition task."/>
|
||||
|
||||
<title>Rotary Positional Embeddings with Relative distance (RoPER) Experiment</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../../pylit.css?v=1">
|
||||
<link rel="canonical" href="https://nn.labml.ai/transformers/rope/value_pe/arithmetic_experiment.html"/>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
||||
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../../index.html">transformers</a>
|
||||
<a class="parent" href="../index.html">rope</a>
|
||||
<a class="parent" href="index.html">value_pe</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>Rotary Positional Embeddings with Relative distance (<a href="index.html">RoPER</a>) Experiment</h1>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">11</span><span></span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
|
||||
<span class="lineno">12</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">calculate</span>
|
||||
<span class="lineno">13</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.arithmetic_dataset</span> <span class="kn">import</span> <span class="n">ArithmeticAutoregression</span>
|
||||
<span class="lineno">14</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">TransformerConfigs</span>
|
||||
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.rope.experiment</span> <span class="kn">import</span> <span class="n">Configs</span> <span class="k">as</span> <span class="n">RoPEConfigs</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<p> We inherit <a href="../experiment.html">RoPE experiment</a> and use it for <a href="../../experiments/arithmetic_dataset.html">arithmetic addition task</a>.</p>
|
||||
<p>We add the option to change attention to use Rotary Positional Embeddings with Relative distance (RoPER) below.</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">18</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">RoPEConfigs</span><span class="p">,</span> <span class="n">ArithmeticAutoregression</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">26</span> <span class="k">pass</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
<p> Use Rotary Positional Embeddings with Relative distance (<a href="index.html">RoPER</a>) in attention.</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">29</span><span class="k">def</span> <span class="nf">_rotary_value_pe_mha</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">33</span> <span class="kn">from</span> <span class="nn">labml_nn.transformers.rope.value_pe</span> <span class="kn">import</span> <span class="n">RotaryValuePEMultiHeadAttention</span>
|
||||
<span class="lineno">34</span> <span class="k">return</span> <span class="n">RotaryValuePEMultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
<p>Configuration options </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">38</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_attn</span><span class="p">,</span> <span class="s1">'rotary_value'</span><span class="p">,</span> <span class="n">_rotary_value_pe_mha</span><span class="p">)</span>
|
||||
<span class="lineno">39</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_attn</span><span class="p">,</span> <span class="s1">'rotary_value'</span><span class="p">,</span> <span class="n">_rotary_value_pe_mha</span><span class="p">)</span>
|
||||
<span class="lineno">40</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_mem_attn</span><span class="p">,</span> <span class="s1">'rotary_value'</span><span class="p">,</span> <span class="n">_rotary_value_pe_mha</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">43</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
<p>Create experiment </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">45</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"roper_addition"</span><span class="p">,</span> <span class="n">comment</span><span class="o">=</span><span class="s2">"rotary value 7"</span><span class="p">,</span> <span class="n">writers</span><span class="o">=</span><span class="p">{</span><span class="s1">'screen'</span><span class="p">,</span> <span class="s1">'labml'</span><span class="p">,</span> <span class="s1">'comet'</span><span class="p">})</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-8'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-8'>#</a>
|
||||
</div>
|
||||
<p>Create configs </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">47</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-9'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-9'>#</a>
|
||||
</div>
|
||||
<p>Override configurations </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">49</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span>
|
||||
<span class="lineno">50</span> <span class="s1">'max_digits'</span><span class="p">:</span> <span class="mi">7</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-10'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
<p>No fixed positional embeddings </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">53</span> <span class="s1">'transformer.src_embed'</span><span class="p">:</span> <span class="s1">'no_pos'</span><span class="p">,</span>
|
||||
<span class="lineno">54</span> <span class="s1">'transformer.tgt_embed'</span><span class="p">:</span> <span class="s1">'no_pos'</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-11'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
<p>Encoder with RoPER attention </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">57</span> <span class="s1">'transformer.encoder_attn'</span><span class="p">:</span> <span class="s1">'rotary_value'</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-12'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
<p>Encoder with RoPE attention 'transformer.encoder_attn': 'rotary', </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-13'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-13'>#</a>
|
||||
</div>
|
||||
<p> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">62</span> <span class="s1">'model'</span><span class="p">:</span> <span class="s1">'rotary_pe_transformer'</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-14'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-14'>#</a>
|
||||
</div>
|
||||
<p>Use a context size of <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">256</span></span></span></span> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">65</span> <span class="s1">'seq_len'</span><span class="p">:</span> <span class="mi">512</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-15'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-15'>#</a>
|
||||
</div>
|
||||
<p>Train for 32 epochs </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">67</span> <span class="s1">'epochs'</span><span class="p">:</span> <span class="mi">20</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-16'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-16'>#</a>
|
||||
</div>
|
||||
<p>Batch size <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">4</span></span></span></span> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">69</span> <span class="s1">'batch_size'</span><span class="p">:</span> <span class="mi">16</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-17'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-17'>#</a>
|
||||
</div>
|
||||
<p>Model size </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">72</span> <span class="s1">'d_model'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span>
|
||||
<span class="lineno">73</span> <span class="s1">'transformer.ffn.d_ff'</span><span class="p">:</span> <span class="mi">512</span><span class="p">,</span>
|
||||
<span class="lineno">74</span> <span class="s1">'transformer.n_heads'</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span>
|
||||
<span class="lineno">75</span> <span class="s1">'transformer.dropout'</span><span class="p">:</span> <span class="mf">0.0</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-18'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-18'>#</a>
|
||||
</div>
|
||||
<p>Use <a href="../../optimizers/noam.html">Adam optimizer</a> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">78</span> <span class="s1">'optimizer.optimizer'</span><span class="p">:</span> <span class="s1">'Adam'</span><span class="p">,</span>
|
||||
<span class="lineno">79</span> <span class="s1">'optimizer.learning_rate'</span><span class="p">:</span> <span class="mf">2.5e-4</span><span class="p">,</span>
|
||||
<span class="lineno">80</span> <span class="p">})</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-19'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-19'>#</a>
|
||||
</div>
|
||||
<p>Set models for saving and loading </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">83</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">'model'</span><span class="p">:</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">})</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-20'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-20'>#</a>
|
||||
</div>
|
||||
<p>Start the experiment </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">86</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-21'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-21'>#</a>
|
||||
</div>
|
||||
<p>Run training </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">88</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-22'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-22'>#</a>
|
||||
</div>
|
||||
<p> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">92</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||
<span class="lineno">93</span> <span class="n">main</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src=../../../interactive.js?v=1"></script>
|
||||
<script>
|
||||
function handleImages() {
|
||||
var images = document.querySelectorAll('p>img')
|
||||
|
||||
for (var i = 0; i < images.length; ++i) {
|
||||
handleImage(images[i])
|
||||
}
|
||||
}
|
||||
|
||||
function handleImage(img) {
|
||||
img.parentElement.style.textAlign = 'center'
|
||||
|
||||
var modal = document.createElement('div')
|
||||
modal.id = 'modal'
|
||||
|
||||
var modalContent = document.createElement('div')
|
||||
modal.appendChild(modalContent)
|
||||
|
||||
var modalImage = document.createElement('img')
|
||||
modalContent.appendChild(modalImage)
|
||||
|
||||
var span = document.createElement('span')
|
||||
span.classList.add('close')
|
||||
span.textContent = 'x'
|
||||
modal.appendChild(span)
|
||||
|
||||
img.onclick = function () {
|
||||
console.log('clicked')
|
||||
document.body.appendChild(modal)
|
||||
modalImage.src = img.src
|
||||
}
|
||||
|
||||
span.onclick = function () {
|
||||
document.body.removeChild(modal)
|
||||
}
|
||||
}
|
||||
|
||||
handleImages()
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
454
docs/transformers/rope/value_pe/experiment.html
Normal file
454
docs/transformers/rope/value_pe/experiment.html
Normal file
@ -0,0 +1,454 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content="This experiment trains a transformer model with Rotary Positional Embeddings (RoPE) on tiny Shakespeare dataset."/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="Rotary Positional Embeddings (RoPE) Experiment"/>
|
||||
<meta name="twitter:description" content="This experiment trains a transformer model with Rotary Positional Embeddings (RoPE) on tiny Shakespeare dataset."/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/transformers/rope/value_pe/experiment.html"/>
|
||||
<meta property="og:title" content="Rotary Positional Embeddings (RoPE) Experiment"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="Rotary Positional Embeddings (RoPE) Experiment"/>
|
||||
<meta property="og:description" content="This experiment trains a transformer model with Rotary Positional Embeddings (RoPE) on tiny Shakespeare dataset."/>
|
||||
|
||||
<title>Rotary Positional Embeddings (RoPE) Experiment</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../../pylit.css?v=1">
|
||||
<link rel="canonical" href="https://nn.labml.ai/transformers/rope/value_pe/experiment.html"/>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
||||
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../../index.html">transformers</a>
|
||||
<a class="parent" href="../index.html">rope</a>
|
||||
<a class="parent" href="index.html">value_pe</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/rope/value_pe/experiment.py">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>Rotary Positional Embeddings (RoPE) Experiment</h1>
|
||||
<p>This is an annotated PyTorch experiment to train a transformer model with Rotary Positional Embeddings (RoPE).</p>
|
||||
<p><a href="https://app.labml.ai/run/1cf508e693be11ecacc98de8b38a61fe"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a></p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">14</span><span></span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
|
||||
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">calculate</span>
|
||||
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">TransformerConfigs</span>
|
||||
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.rope.experiment</span> <span class="kn">import</span> <span class="n">Configs</span> <span class="k">as</span> <span class="n">RoPEConfigs</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<h3>Rotary PE attention</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">22</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">RoPEConfigs</span><span class="p">):</span> <span class="c1"># , ArithmeticAutoregression):</span>
|
||||
<span class="lineno">23</span> <span class="k">pass</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">26</span><span class="k">def</span> <span class="nf">_rotary_value_pe_mha</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
|
||||
<span class="lineno">27</span> <span class="kn">from</span> <span class="nn">labml_nn.transformers.rope.value_pe</span> <span class="kn">import</span> <span class="n">RotaryValuePEMultiHeadAttention</span>
|
||||
<span class="lineno">28</span> <span class="k">return</span> <span class="n">RotaryValuePEMultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
<p>Configuration options </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">32</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_attn</span><span class="p">,</span> <span class="s1">'rotary_value'</span><span class="p">,</span> <span class="n">_rotary_value_pe_mha</span><span class="p">)</span>
|
||||
<span class="lineno">33</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_attn</span><span class="p">,</span> <span class="s1">'rotary_value'</span><span class="p">,</span> <span class="n">_rotary_value_pe_mha</span><span class="p">)</span>
|
||||
<span class="lineno">34</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_mem_attn</span><span class="p">,</span> <span class="s1">'rotary_value'</span><span class="p">,</span> <span class="n">_rotary_value_pe_mha</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">37</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
<p>Create experiment </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">39</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"rotary_shakespeare"</span><span class="p">,</span> <span class="n">comment</span><span class="o">=</span><span class="s2">"rotary value"</span><span class="p">,</span> <span class="n">writers</span><span class="o">=</span><span class="p">{</span><span class="s1">'screen'</span><span class="p">,</span> <span class="s1">'labml'</span><span class="p">})</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
<p>Create configs </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">41</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-8'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-8'>#</a>
|
||||
</div>
|
||||
<p>Override configurations </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">43</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-9'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-9'>#</a>
|
||||
</div>
|
||||
<p>No fixed positional embeddings </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">45</span> <span class="s1">'transformer.src_embed'</span><span class="p">:</span> <span class="s1">'no_pos'</span><span class="p">,</span>
|
||||
<span class="lineno">46</span> <span class="s1">'transformer.tgt_embed'</span><span class="p">:</span> <span class="s1">'no_pos'</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-10'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
<p>Encoder with RoPE </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">49</span> <span class="s1">'transformer.encoder_attn'</span><span class="p">:</span> <span class="s1">'rotary_value'</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-11'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
<p>'transformer.encoder_attn': 'rotary', </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-12'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
<p> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">53</span> <span class="s1">'model'</span><span class="p">:</span> <span class="s1">'rotary_pe_transformer'</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-13'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-13'>#</a>
|
||||
</div>
|
||||
<p>Use character level tokenizer </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">56</span> <span class="s1">'tokenizer'</span><span class="p">:</span> <span class="s1">'character'</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-14'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-14'>#</a>
|
||||
</div>
|
||||
<p>Prompt separator is blank </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">58</span> <span class="s1">'prompt_separator'</span><span class="p">:</span> <span class="s1">''</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-15'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-15'>#</a>
|
||||
</div>
|
||||
<p>Starting prompt for sampling </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">60</span> <span class="s1">'prompt'</span><span class="p">:</span> <span class="s1">'It is '</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-16'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-16'>#</a>
|
||||
</div>
|
||||
<p>Use Tiny Shakespeare dataset </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">62</span> <span class="s1">'text'</span><span class="p">:</span> <span class="s1">'tiny_shakespeare'</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-17'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-17'>#</a>
|
||||
</div>
|
||||
<p>Use a context size of <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">256</span></span></span></span> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">65</span> <span class="s1">'seq_len'</span><span class="p">:</span> <span class="mi">512</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-18'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-18'>#</a>
|
||||
</div>
|
||||
<p>Train for 32 epochs </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">67</span> <span class="s1">'epochs'</span><span class="p">:</span> <span class="mi">24</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-19'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-19'>#</a>
|
||||
</div>
|
||||
<p>Batch size <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">4</span></span></span></span> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">69</span> <span class="s1">'batch_size'</span><span class="p">:</span> <span class="mi">16</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-20'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-20'>#</a>
|
||||
</div>
|
||||
<p>Switch between training and validation for <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">10</span></span></span></span> times per epoch </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">72</span> <span class="s1">'inner_iterations'</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-21'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-21'>#</a>
|
||||
</div>
|
||||
<p>Model size </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">75</span> <span class="s1">'d_model'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span>
|
||||
<span class="lineno">76</span> <span class="s1">'transformer.ffn.d_ff'</span><span class="p">:</span> <span class="mi">512</span><span class="p">,</span>
|
||||
<span class="lineno">77</span> <span class="s1">'transformer.n_heads'</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span>
|
||||
<span class="lineno">78</span> <span class="s1">'transformer.dropout'</span><span class="p">:</span> <span class="mf">0.0</span><span class="p">,</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-22'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-22'>#</a>
|
||||
</div>
|
||||
<p>Use <a href="../../optimizers/noam.html">Adam optimizer</a> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">81</span> <span class="s1">'optimizer.optimizer'</span><span class="p">:</span> <span class="s1">'Adam'</span><span class="p">,</span>
|
||||
<span class="lineno">82</span> <span class="s1">'optimizer.learning_rate'</span><span class="p">:</span> <span class="mf">2.5e-4</span><span class="p">,</span>
|
||||
<span class="lineno">83</span>
|
||||
<span class="lineno">84</span> <span class="s1">'dataloader_shuffle_with_replacement'</span><span class="p">:</span> <span class="kc">True</span>
|
||||
<span class="lineno">85</span> <span class="p">})</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-23'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-23'>#</a>
|
||||
</div>
|
||||
<p>Set models for saving and loading </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">88</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">'model'</span><span class="p">:</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">})</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-24'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-24'>#</a>
|
||||
</div>
|
||||
<p>Start the experiment </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">91</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-25'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-25'>#</a>
|
||||
</div>
|
||||
<p>Run training </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">93</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-26'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-26'>#</a>
|
||||
</div>
|
||||
<p> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">97</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||
<span class="lineno">98</span> <span class="n">main</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src=../../../interactive.js?v=1"></script>
|
||||
<script>
|
||||
function handleImages() {
|
||||
var images = document.querySelectorAll('p>img')
|
||||
|
||||
for (var i = 0; i < images.length; ++i) {
|
||||
handleImage(images[i])
|
||||
}
|
||||
}
|
||||
|
||||
function handleImage(img) {
|
||||
img.parentElement.style.textAlign = 'center'
|
||||
|
||||
var modal = document.createElement('div')
|
||||
modal.id = 'modal'
|
||||
|
||||
var modalContent = document.createElement('div')
|
||||
modal.appendChild(modalContent)
|
||||
|
||||
var modalImage = document.createElement('img')
|
||||
modalContent.appendChild(modalImage)
|
||||
|
||||
var span = document.createElement('span')
|
||||
span.classList.add('close')
|
||||
span.textContent = 'x'
|
||||
modal.appendChild(span)
|
||||
|
||||
img.onclick = function () {
|
||||
console.log('clicked')
|
||||
document.body.appendChild(modal)
|
||||
modalImage.src = img.src
|
||||
}
|
||||
|
||||
span.onclick = function () {
|
||||
document.body.removeChild(modal)
|
||||
}
|
||||
}
|
||||
|
||||
handleImages()
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
510
docs/transformers/rope/value_pe/index.html
Normal file
510
docs/transformers/rope/value_pe/index.html
Normal file
File diff suppressed because one or more lines are too long
260
labml_nn/experiments/arithmetic_dataset.py
Normal file
260
labml_nn/experiments/arithmetic_dataset.py
Normal file
@ -0,0 +1,260 @@
|
||||
"""
|
||||
---
|
||||
title: Arithmetic Dataset
|
||||
summary: >
|
||||
This creates arithmetic problems.
|
||||
---
|
||||
|
||||
*This is based on code by [Georges Harik (@gharik)](https://twitter.com/gharik).*
|
||||
"""
|
||||
|
||||
import random
|
||||
import string
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from labml.logger import Text
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from labml import monit, logger, tracker
|
||||
from labml.configs import option
|
||||
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch
|
||||
|
||||
|
||||
class ArithmeticDataset(Dataset):
|
||||
"""
|
||||
## Arithmetic Dataset
|
||||
|
||||
This creates arithmetic addition problems and solutions with workings.
|
||||
We've only implemented addition so far.
|
||||
|
||||
It's based on a character level tokenization.
|
||||
"""
|
||||
|
||||
def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
|
||||
"""
|
||||
:param seq_len: is the sequence length of generated math problems.
|
||||
We fill as many problems as possible upto this length
|
||||
:max_digits: is the maximum number of digits in the operand integers
|
||||
:n_sequences: is the number of sequences per epoch
|
||||
"""
|
||||
self.n_sequences = n_sequences
|
||||
self.max_digits = max_digits
|
||||
self.seq_len = seq_len
|
||||
# Token id to string
|
||||
self.itos = list(string.digits + 'xe =\n?+;')
|
||||
# Character to token id
|
||||
self.stoi = {c: i for i, c in enumerate(self.itos)}
|
||||
|
||||
@staticmethod
|
||||
def make_int(n_digits: int):
|
||||
"""
|
||||
Generates an integer with `n_digit` number of digits
|
||||
"""
|
||||
res = 0
|
||||
for i in range(n_digits):
|
||||
d = random.randrange(1, 11) if i == 0 else random.randrange(0, 11)
|
||||
res = res * 10 + d
|
||||
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def get_add_explanation(x: int, y: int):
|
||||
"""
|
||||
Generates the workings for `x + y`.
|
||||
For example for `11+29` it generates
|
||||
`1e0+9e0+0e0=10e0 1e0+2e0+1e0=4e0`.
|
||||
"""
|
||||
|
||||
carry = 0
|
||||
e = 0
|
||||
explanation = []
|
||||
while x > 0 or y > 0 or carry > 0:
|
||||
rx, ry = x % 10, y % 10
|
||||
total = rx + ry + carry
|
||||
explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
|
||||
x, y, carry = x // 10, y // 10, total // 10
|
||||
e += 1
|
||||
|
||||
return ' '.join(explanation)
|
||||
|
||||
# Make a problem with a pre_explanation or not
|
||||
def make_add_problem(self):
|
||||
"""
|
||||
Creates an arithmetic addition problem with workings and answer.
|
||||
"""
|
||||
x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
|
||||
y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
|
||||
|
||||
explanation = self.get_add_explanation(x, y)
|
||||
return f"x={x}+{y}; {explanation} x=={x + y}\n"
|
||||
|
||||
def get_qa(self):
|
||||
"""
|
||||
Get arithmetic problem and answer. This is used for evaluation.
|
||||
"""
|
||||
x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
|
||||
y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
|
||||
|
||||
return f'x={x}+{y};', f'{x + y}'
|
||||
|
||||
def get_packed_math_input(self):
|
||||
"""
|
||||
Generate multiple problems and pack them into a sequence.
|
||||
"""
|
||||
s_enc = []
|
||||
while len(s_enc) <= self.seq_len:
|
||||
s_part = self.make_add_problem()
|
||||
s_part_enc = self.encode('?' + s_part)
|
||||
s_enc = s_enc + s_part_enc
|
||||
return s_enc
|
||||
|
||||
def encode(self, s: str):
|
||||
"""
|
||||
Encode a given string
|
||||
"""
|
||||
return [self.stoi[c] for c in s]
|
||||
|
||||
def decode(self, arr: List[int]):
|
||||
"""
|
||||
Decode a list of token ids
|
||||
"""
|
||||
return ''.join([self.itos[c] for c in arr])
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
"""
|
||||
Get a input and target pair for auto-regressive modelling
|
||||
"""
|
||||
s = torch.tensor(self.get_packed_math_input())
|
||||
return s[:self.seq_len], s[1:self.seq_len + 1]
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Number of sequences per epoch
|
||||
"""
|
||||
return self.n_sequences
|
||||
|
||||
|
||||
class ArithmeticAutoregression(NLPAutoRegressionConfigs):
|
||||
"""
|
||||
## Arithmetic Task Experiment Configurations
|
||||
"""
|
||||
# Maximum number of digits per operand integer
|
||||
max_digits: int = 4
|
||||
# Number of training sequences per epoch
|
||||
train_sequences_per_epoch: int = 2 ** 12
|
||||
# Training data loader
|
||||
train_loader: DataLoader = 'arithmetic_train_loader'
|
||||
# Number of problems in evaluation
|
||||
n_tests: int = 64
|
||||
# No need of a validation dataset
|
||||
validator = None
|
||||
# Number of times to run evaluations per epoch
|
||||
inner_iterations = 4
|
||||
# Number of tokens in the vocabulary
|
||||
n_tokens = len(ArithmeticDataset(1, 1, 1).itos)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self):
|
||||
"""
|
||||
### Evaluation
|
||||
|
||||
We use the sampling function to evaluate the model on a set of problems
|
||||
"""
|
||||
|
||||
# Skip in the first epoch
|
||||
if self.training_loop.idx < 1:
|
||||
return
|
||||
|
||||
# Create a dataset to generate problems
|
||||
dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
|
||||
# Get a set of problems and answers
|
||||
qa = [dataset.get_qa() for _ in range(self.n_tests)]
|
||||
# Collect the problems only
|
||||
questions = [p[0] for p in qa]
|
||||
|
||||
# Create a tensor with only the initial token
|
||||
data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])
|
||||
# Move to device
|
||||
data = data.to(self.device)
|
||||
|
||||
# Number of sequences that have completed
|
||||
finished = torch.zeros((len(questions),)).bool().to(self.device)
|
||||
# Token id of the new line character - this marks end of the answer
|
||||
new_line = dataset.stoi['\n']
|
||||
|
||||
# Sampled results
|
||||
results = [p[0] for p in questions]
|
||||
|
||||
# Sample upto sequence length
|
||||
for i in monit.iterate('Sample', self.seq_len - 1):
|
||||
# If all the sequences have completed we skip this
|
||||
if finished.sum() == len(finished):
|
||||
continue
|
||||
|
||||
# Get the model output
|
||||
output, *_ = self.model(data)
|
||||
# Get the model prediction (greedy)
|
||||
output = output[-1].argmax(dim=-1)
|
||||
|
||||
# Find which sequences have finished
|
||||
finished = finished | (output == new_line)
|
||||
# Skip if all have finished
|
||||
if finished.sum() == len(finished):
|
||||
continue
|
||||
|
||||
# Override with the question
|
||||
for j, p in enumerate(questions):
|
||||
if len(p) > i + 1:
|
||||
output[j] = dataset.stoi[p[i + 1]]
|
||||
|
||||
# Add the next token to the input
|
||||
data = torch.cat([data, output[None, :]], dim=0)
|
||||
|
||||
# Get the sampled results
|
||||
for j, c in enumerate(output):
|
||||
results[j] += dataset.itos[c]
|
||||
|
||||
# Discard everything after the answer in the results
|
||||
results = [r.split('\n')[0] for r in results]
|
||||
|
||||
# Log a sample
|
||||
res_sample = results[0].split(';')
|
||||
logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])
|
||||
|
||||
# Get the answers
|
||||
results = [r.split('x==')[-1] for r in results]
|
||||
|
||||
# Count the number of correct answers
|
||||
correct = 0
|
||||
for r, _qa in zip(results, qa):
|
||||
if r == _qa[1]:
|
||||
correct += 1
|
||||
|
||||
# Log the score
|
||||
tracker.save('score', correct / len(results))
|
||||
|
||||
|
||||
@option(ArithmeticAutoregression.train_loader)
|
||||
def arithmetic_train_loader(c: ArithmeticAutoregression):
|
||||
"""
|
||||
Training data loader
|
||||
"""
|
||||
return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
|
||||
batch_size=c.batch_size,
|
||||
collate_fn=transpose_batch,
|
||||
num_workers=4)
|
||||
|
||||
|
||||
def _test():
|
||||
"""
|
||||
Code to test generated problems
|
||||
"""
|
||||
dataset = ArithmeticDataset(256, 8, 10)
|
||||
|
||||
print(dataset.decode(dataset.get_packed_math_input()))
|
||||
|
||||
|
||||
#
|
||||
if __name__ == '__main__':
|
||||
_test()
|
@ -115,37 +115,67 @@ class RotaryPositionalEmbeddings(nn.Module):
|
||||
\end{pmatrix} \\
|
||||
\end{align}
|
||||
"""
|
||||
|
||||
def __init__(self, d: int, base: int = 10_000):
|
||||
"""
|
||||
* `d` is the number of features $d$
|
||||
* `base` is the constant used for calculating $\Theta$
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.base = base
|
||||
self.d = d
|
||||
self.cos_cached = None
|
||||
self.sin_cached = None
|
||||
|
||||
def _build_cache(self, x: torch.Tensor):
|
||||
"""
|
||||
Cache $\cos$ and $\sin$ values
|
||||
"""
|
||||
# Return if cache is already built
|
||||
if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
|
||||
return
|
||||
|
||||
# Get sequence length
|
||||
seq_len = x.shape[0]
|
||||
|
||||
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
|
||||
self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
* `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
|
||||
"""
|
||||
# Extract the shape
|
||||
seq_len, batch_size, n_heads, d = x.shape
|
||||
|
||||
# $\frac{d}{2}$
|
||||
d_2 = d // 2
|
||||
theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
|
||||
|
||||
# Create position indexes `[0, 1, ..., seq_len - 1]`
|
||||
seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
|
||||
seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
|
||||
|
||||
# Calculate the product of position index and $\theta_i$
|
||||
idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)
|
||||
idx_theta = torch.einsum('n,d->nd', seq_idx, theta)
|
||||
|
||||
# Concatenate so that for row $m$ we have
|
||||
# $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
|
||||
idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
|
||||
|
||||
# Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., -x^{(\frac{d}{2})}]$
|
||||
neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
|
||||
# Cache them
|
||||
self.cos_cached = idx_theta2.cos()[:, None, None, :]
|
||||
self.sin_cached = idx_theta2.sin()[:, None, None, :]
|
||||
|
||||
def _neg_half(self, x: torch.Tensor):
|
||||
# $\frac{d}{2}$
|
||||
d_2 = self.d // 2
|
||||
|
||||
# Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
|
||||
return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
* `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
|
||||
"""
|
||||
# Cache $\cos$ and $\sin$ values
|
||||
self._build_cache(x)
|
||||
|
||||
# Split the features, we can choose to apply rotary embeddings only to a partial set of features.
|
||||
x_rope, x_pass = x[..., :self.d], x[..., self.d:]
|
||||
|
||||
# Calculate
|
||||
# $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
|
||||
neg_half_x = self._neg_half(x_rope)
|
||||
|
||||
# Calculate
|
||||
#
|
||||
@ -157,10 +187,10 @@ class RotaryPositionalEmbeddings(nn.Module):
|
||||
# \end{align}
|
||||
#
|
||||
# for $i \in {1, 2, ..., \frac{d}{2}}$
|
||||
rx = (x * idx_theta2.cos()[:, None, None, :]) + (neg_half_x * idx_theta2.sin()[:, None, None, :])
|
||||
x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
|
||||
|
||||
#
|
||||
return rx
|
||||
return torch.cat((x_rope, x_pass), dim=-1)
|
||||
|
||||
|
||||
class RotaryPEMultiHeadAttention(MultiHeadAttention):
|
||||
@ -170,15 +200,13 @@ class RotaryPEMultiHeadAttention(MultiHeadAttention):
|
||||
We override [multi-head attention from original transformer](../mha.html).
|
||||
"""
|
||||
|
||||
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
|
||||
# The linear transformations do not need a bias since we
|
||||
# explicitly include it when calculating scores.
|
||||
# However having a bias for `value` might make sense.
|
||||
super().__init__(heads, d_model, dropout_prob, bias=False)
|
||||
def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
|
||||
super().__init__(heads, d_model, dropout_prob)
|
||||
|
||||
# Rotary positional embedding layers
|
||||
self.query_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
|
||||
self.key_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
|
||||
d_rope = int(self.d_k * rope_percentage)
|
||||
self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
|
||||
self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
|
||||
|
||||
def get_scores(self, query: torch.Tensor, key: torch.Tensor):
|
||||
"""
|
||||
|
@ -20,7 +20,7 @@ from labml_nn.transformers.basic.autoregressive_experiment import Autoregressive
|
||||
# ### Rotary PE attention
|
||||
def _rotary_pe_mha(c: TransformerConfigs):
|
||||
from labml_nn.transformers.rope import RotaryPEMultiHeadAttention
|
||||
return RotaryPEMultiHeadAttention(c.n_heads, c.d_model)
|
||||
return RotaryPEMultiHeadAttention(c.n_heads, c.d_model, 1.)
|
||||
|
||||
|
||||
# Configuration options
|
||||
@ -43,7 +43,7 @@ def _model(c: Configs):
|
||||
|
||||
def main():
|
||||
# Create experiment
|
||||
experiment.create(name="rotary_pe_transformer")
|
||||
experiment.create(name="rotary_pe_transformer", writers={'screen'})
|
||||
# Create configs
|
||||
conf = Configs()
|
||||
# Override configurations
|
||||
|
246
labml_nn/transformers/rope/value_pe/__init__.py
Normal file
246
labml_nn/transformers/rope/value_pe/__init__.py
Normal file
@ -0,0 +1,246 @@
|
||||
"""
|
||||
---
|
||||
title: Rotary Positional Embeddings with Relative distance (RoPER)
|
||||
summary: >
|
||||
This is an implementation of RoPER which adds relative distance information to embeddings on
|
||||
top of RoPE introduced in RoFormer: Enhanced Transformer with Rotary Position Embedding
|
||||
---
|
||||
|
||||
*RoPER is work by [Georges Harik (@gharik)](https://twitter.com/gharik),
|
||||
and this implementation is based on his original code.*
|
||||
|
||||
# Rotary Positional Embeddings with Relative distance (RoPER)
|
||||
|
||||
[Rotary Positional Embeddings (RoPE)](https://papers.labml.ai/paper/2104.09864) includes
|
||||
relative positions in attention score calculation.
|
||||
However, the embeddings themselves do not get any positional information
|
||||
, [except what it can get implicitly from causal attention](https://papers.labml.ai/paper/2c364684b15b11ecac827bce58715ee7).
|
||||
|
||||
RoPER adds relative positional information explicitly to value embeddings.
|
||||
Specifically, it adds the relative positions of the tokens it paid attention to.
|
||||
We use same rotary positional embeddings to rotate the values in attention,
|
||||
Then, after taking the weighted sum,
|
||||
we rotate the final in the opposite direction.
|
||||
Which is equivalent to rotating each of the values (before attention) relative to the current position.
|
||||
|
||||
Here's [the training code](experiment.html) for training a transformer model with RoPER
|
||||
on an arithmetic addition where we can see significant improvement over RoPE.
|
||||
|
||||
### Relative distances in embeddings
|
||||
|
||||
For any head, let $a_{n,i}$ be the attention from position $n$ to position $i$,
|
||||
and $v_i$ be the value embeddings at position $i$. Let's denote individual features
|
||||
as $v^{(1)}_i, v^{(2)}_i, \dots$.
|
||||
|
||||
Normally, we would take the weight sum of value embeddings
|
||||
|
||||
$$o^{(j)}_n = \sum_i a_{n,i} v^{(j)}_i$$
|
||||
|
||||
This doesn't explicitly add any distance information about the positions $i$ to final
|
||||
result $o^{(j)}_n$.
|
||||
|
||||
RoPER pairs features like RoPE and transform them.
|
||||
For a pair $v^{(1)}_m$ and $v^{(2)}_m$ it transforms them by
|
||||
$RoPE\big(v^{(1)}_m, v^{(2)}_m, m\big)$.
|
||||
Let us donate the transformed features with $\hat{v}^{(1)}_m, \hat{v}^{(2)}_m$.
|
||||
Then it rotates the weighted sum $\hat{o}^{(j)}_n$ in the the reverse direction with
|
||||
$RoPE\big(\hat{o}^{(1)}_n, \hat{o}^{(2)}_n, -n\big)$.
|
||||
*Note the *$-n$.
|
||||
|
||||
Note that,
|
||||
|
||||
\begin{align}
|
||||
RoPE\big(x^{(1)}_m, x^{(2)}_m, m\big) &=
|
||||
\begin{pmatrix}
|
||||
\cos m \theta & - \sin m \theta \\
|
||||
\sin m \theta & \cos m \theta
|
||||
\end{pmatrix}
|
||||
\begin{pmatrix}
|
||||
x^{(1)}_m \\
|
||||
x^{(2)}_m \\
|
||||
\end{pmatrix} \\
|
||||
&=
|
||||
\begin{pmatrix}
|
||||
x^{(1)}_m \cos m\theta - x^{(2)}_m \sin m \theta \\
|
||||
x^{(2)}_m \cos m\theta + x^{(1)}_m \sin m \theta \\
|
||||
\end{pmatrix} \\
|
||||
\end{align}
|
||||
|
||||
Final output after with the transformations is,
|
||||
|
||||
\begin{align}
|
||||
RoPE\big(\hat{o}^{(1)}_n, \hat{o}^{(2)}_n, -n\big) &= \\
|
||||
\begin{pmatrix}
|
||||
\hat{o}^{(1)}_n \cos n\theta + \hat{o}^{(2)}_n \sin n \theta \\
|
||||
\hat{o}^{(2)}_n \cos n\theta - \hat{o}^{(1)}_n \sin n \theta \\
|
||||
\end{pmatrix} \\
|
||||
\end{align}
|
||||
|
||||
*Note that *$\sin (-n \theta) = -\sin n \theta$.
|
||||
|
||||
Let's expand the first term $\hat{o}^{(1)}_n \cos n\theta + \hat{o}^{(2)}_n \sin n \theta$,
|
||||
|
||||
\begin{align}
|
||||
\hat{o}^{(1)}_n \cos n\theta + \hat{o}^{(2)}_n \sin n \theta &= \\
|
||||
\sum_i a_{n,i} \hat{v}^{(1)}_i \cos n\theta + \sum_i a_{n,i} \hat{v}^{(2)}_i \sin n \theta &= \\
|
||||
|
||||
\sum_i a_{n,i} \Big( v^{(1)}_i \cos i\theta - v^{(2)}_i \sin i \theta \Big) \cos n\theta &+ \\
|
||||
\sum_i a_{n,i} \Big( v^{(2)}_i \cos i\theta + v^{(1)}_i \sin i \theta \Big) \sin m \theta &= \\
|
||||
|
||||
\sum_i a_{n,i} v^{(1)}_i \Big( \cos i\theta \cos n\theta + \sin i \theta \sin n \theta \Big) &+ \\
|
||||
\sum_i a_{n,i} v^{(2)}_i \Big( \cos i\theta \sin n\theta - \sin i \theta \cos n \theta \Big) &= \\
|
||||
|
||||
\sum_i a_{n,i} v^{(1)}_i \cos (i - n) \theta - \sum_i a_{n,i} v^{(2)}_i \sin (i - n) \theta &= \\
|
||||
|
||||
\sum_i a_{n,i} v^{(1)}_i \cos (i - n) \theta - \sum_i a_{n,i} v^{(2)}_i \sin (i - n) \theta
|
||||
\end{align}
|
||||
|
||||
Simiarly we can show the second term is equal to,
|
||||
|
||||
$$\sum_i a_{n,i} v^{(1)}_i \cos (i - n) \theta + \sum_i a_{n,i} v^{(2)}_i \sin (i - n) \theta$$
|
||||
|
||||
Which gives,
|
||||
|
||||
\begin{align}
|
||||
RoPE\big(\hat{o}^{(1)}_n, \hat{o}^{(2)}_n, -n\big) &= \\
|
||||
\begin{pmatrix}
|
||||
\sum_i a_{n,i} v^{(1)}_i \cos (i - n) \theta - \sum_i a_{n,i} v^{(2)}_i \sin (i - n) \theta \\
|
||||
\sum_i a_{n,i} v^{(1)}_i \cos (i - n) \theta + \sum_i a_{n,i} v^{(2)}_i \sin (i - n) \theta \\
|
||||
\end{pmatrix} &= \\
|
||||
\sum_i a_{n,i} RoPE \big (v^{(1)}_i, v^{(1)}_i, (i - n) \theta \big)
|
||||
\end{align}
|
||||
|
||||
That is, the weighted average of values rotated relative to current position.
|
||||
|
||||
[Here's an experiment](arithmetic_experiment.html) that uses RoPER on an arthmetic addition task.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from labml_nn.transformers.rope import RotaryPositionalEmbeddings, RotaryPEMultiHeadAttention
|
||||
|
||||
|
||||
class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
|
||||
"""
|
||||
## RoPE module that rotates in the opposite direction
|
||||
|
||||
This inherits from [RoPE rotation implementation](../index.html) and changes the direction.
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
* `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
|
||||
"""
|
||||
# Cache $\cos$ and $\sin$ values
|
||||
self._build_cache(x)
|
||||
|
||||
# Split the features, we can choose to apply rotary embeddings only to a partial set of features.
|
||||
x_rope, x_pass = x[..., :self.d], x[..., self.d:]
|
||||
|
||||
# Calculate
|
||||
# $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
|
||||
neg_half_x = self._neg_half(x_rope)
|
||||
|
||||
# Calculate
|
||||
#
|
||||
# \begin{align}
|
||||
# \begin{pmatrix}
|
||||
# x^{(i)}_m \cos -m \theta_i - x^{(i + \frac{d}{2})}_m \sin -m \theta_i \\
|
||||
# x^{(i + \frac{d}{2})}_m \cos -m\theta_i + x^{(i)}_m \sin -m \theta_i \\
|
||||
# \end{pmatrix} = \\
|
||||
# \begin{pmatrix}
|
||||
# x^{(i)}_m \cos m \theta_i + x^{(i + \frac{d}{2})}_m \sin m \theta_i \\
|
||||
# x^{(i + \frac{d}{2})}_m \cos m\theta_i - x^{(i)}_m \sin m \theta_i \\
|
||||
# \end{pmatrix} \\
|
||||
# \end{align}
|
||||
#
|
||||
# for $i \in {1, 2, ..., \frac{d}{2}}$
|
||||
x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])
|
||||
|
||||
#
|
||||
return torch.cat((x_rope, x_pass), dim=-1)
|
||||
|
||||
|
||||
class RotaryValuePEMultiHeadAttention(RotaryPEMultiHeadAttention):
|
||||
"""
|
||||
## Multi-head attention with rotary positional embeddings
|
||||
|
||||
We override [multi-head attention from original transformer](../mha.html).
|
||||
"""
|
||||
|
||||
def __init__(self, heads: int, d_model: int,
|
||||
rope_percentage: float = 0.5, rope_value_percentage: float = 0.5,
|
||||
dropout_prob: float = 0.0):
|
||||
super().__init__(heads, d_model, rope_percentage, dropout_prob)
|
||||
|
||||
# Rotary positional embedding layers
|
||||
d_rope_value = int(self.d_k * rope_value_percentage)
|
||||
|
||||
self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value)
|
||||
self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value)
|
||||
|
||||
def forward(self, *,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
`query`, `key` and `value` are the tensors that store
|
||||
collection of *query*, *key* and *value* vectors.
|
||||
They have shape `[seq_len, batch_size, d_model]`.
|
||||
|
||||
`mask` has shape `[seq_len, seq_len, batch_size]` and
|
||||
`mask[i, j, b]` indicates whether for batch `b`,
|
||||
query at position `i` has access to key-value at position `j`.
|
||||
"""
|
||||
|
||||
# `query`, `key` and `value` have shape `[seq_len, batch_size, d_model]`
|
||||
seq_len, batch_size, _ = query.shape
|
||||
|
||||
if mask is not None:
|
||||
mask = self.prepare_mask(mask, query.shape, key.shape)
|
||||
|
||||
# Prepare `query`, `key` and `value` for attention computation.
|
||||
# These will then have shape `[seq_len, batch_size, heads, d_k]`.
|
||||
query = self.query(query)
|
||||
key = self.key(key)
|
||||
value = self.value(value)
|
||||
|
||||
# Compute attention scores $Q K^\top$.
|
||||
# This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.
|
||||
scores = self.get_scores(query, key)
|
||||
|
||||
# Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
|
||||
scores *= self.scale
|
||||
|
||||
# Apply mask
|
||||
if mask is not None:
|
||||
scores = scores.masked_fill(mask == 0, float('-inf'))
|
||||
|
||||
# $softmax$ attention along the key sequence dimension
|
||||
# $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
|
||||
attn = self.softmax(scores)
|
||||
|
||||
# Apply dropout
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# Rotate value embeddings before taking the weighted sum so that they contain positional information
|
||||
value = self.value_rotary_pe(value)
|
||||
|
||||
# Multiply by values
|
||||
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
|
||||
x = torch.einsum("ijbh,jbhd->ibhd", attn, self.value_rotary_pe(value))
|
||||
|
||||
# Rotate in the opposite direction so that each embedding hold the relative positions
|
||||
x = self.value_reverse_rotary_pe(x)
|
||||
|
||||
# Save attentions for any other calculations
|
||||
self.attn = attn.detach()
|
||||
|
||||
# Concatenate multiple heads
|
||||
x = x.reshape(seq_len, batch_size, -1)
|
||||
|
||||
# Output layer
|
||||
return self.output(x)
|
93
labml_nn/transformers/rope/value_pe/arithmetic_experiment.py
Normal file
93
labml_nn/transformers/rope/value_pe/arithmetic_experiment.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""
|
||||
---
|
||||
title: Rotary Positional Embeddings with Relative distance (RoPER) Experiment
|
||||
summary: This experiment trains a transformer model with Rotary Positional Embeddings with
|
||||
Relative Distance (RoPER) on the arithmetic addition task.
|
||||
---
|
||||
|
||||
# Rotary Positional Embeddings with Relative distance ([RoPER](index.html)) Experiment
|
||||
"""
|
||||
|
||||
from labml import experiment
|
||||
from labml.configs import calculate
|
||||
from labml_nn.experiments.arithmetic_dataset import ArithmeticAutoregression
|
||||
from labml_nn.transformers import TransformerConfigs
|
||||
from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs
|
||||
|
||||
|
||||
class Configs(RoPEConfigs, ArithmeticAutoregression):
|
||||
"""
|
||||
We inherit [RoPE experiment](../experiment.html) and use it for
|
||||
[arithmetic addition task](../../experiments/arithmetic_dataset.html).
|
||||
|
||||
We add the option to change attention to use Rotary Positional Embeddings with Relative distance (RoPER)
|
||||
below.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def _rotary_value_pe_mha(c: TransformerConfigs):
|
||||
"""
|
||||
Use Rotary Positional Embeddings with Relative distance ([RoPER](index.html)) in attention.
|
||||
"""
|
||||
from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention
|
||||
return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 1.)
|
||||
|
||||
|
||||
# Configuration options
|
||||
calculate(TransformerConfigs.encoder_attn, 'rotary_value', _rotary_value_pe_mha)
|
||||
calculate(TransformerConfigs.decoder_attn, 'rotary_value', _rotary_value_pe_mha)
|
||||
calculate(TransformerConfigs.decoder_mem_attn, 'rotary_value', _rotary_value_pe_mha)
|
||||
|
||||
|
||||
def main():
|
||||
# Create experiment
|
||||
experiment.create(name="roper_addition", comment="rotary value 7", writers={'screen', 'labml', 'comet'})
|
||||
# Create configs
|
||||
conf = Configs()
|
||||
# Override configurations
|
||||
experiment.configs(conf, {
|
||||
'max_digits': 7,
|
||||
|
||||
# No fixed positional embeddings
|
||||
'transformer.src_embed': 'no_pos',
|
||||
'transformer.tgt_embed': 'no_pos',
|
||||
|
||||
# Encoder with RoPER attention
|
||||
'transformer.encoder_attn': 'rotary_value',
|
||||
# Encoder with RoPE attention
|
||||
# 'transformer.encoder_attn': 'rotary',
|
||||
|
||||
#
|
||||
'model': 'rotary_pe_transformer',
|
||||
|
||||
# Use a context size of $256$
|
||||
'seq_len': 512,
|
||||
# Train for 32 epochs
|
||||
'epochs': 20,
|
||||
# Batch size $4$
|
||||
'batch_size': 16,
|
||||
|
||||
# Model size
|
||||
'd_model': 128,
|
||||
'transformer.ffn.d_ff': 512,
|
||||
'transformer.n_heads': 4,
|
||||
'transformer.dropout': 0.0,
|
||||
|
||||
# Use [Adam optimizer](../../optimizers/noam.html)
|
||||
'optimizer.optimizer': 'Adam',
|
||||
'optimizer.learning_rate': 2.5e-4,
|
||||
})
|
||||
|
||||
# Set models for saving and loading
|
||||
experiment.add_pytorch_models({'model': conf.model})
|
||||
|
||||
# Start the experiment
|
||||
with experiment.start():
|
||||
# Run training
|
||||
conf.run()
|
||||
|
||||
|
||||
#
|
||||
if __name__ == '__main__':
|
||||
main()
|
98
labml_nn/transformers/rope/value_pe/experiment.py
Normal file
98
labml_nn/transformers/rope/value_pe/experiment.py
Normal file
@ -0,0 +1,98 @@
|
||||
"""
|
||||
---
|
||||
title: Rotary Positional Embeddings (RoPE) Experiment
|
||||
summary: This experiment trains a transformer model with Rotary Positional Embeddings (RoPE) on tiny Shakespeare dataset.
|
||||
---
|
||||
|
||||
# Rotary Positional Embeddings (RoPE) Experiment
|
||||
|
||||
This is an annotated PyTorch experiment to train a transformer model with Rotary Positional Embeddings (RoPE).
|
||||
|
||||
[](https://app.labml.ai/run/1cf508e693be11ecacc98de8b38a61fe)
|
||||
"""
|
||||
|
||||
from labml import experiment
|
||||
from labml.configs import calculate
|
||||
from labml_nn.transformers import TransformerConfigs
|
||||
from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs
|
||||
|
||||
|
||||
# ### Rotary PE attention
|
||||
|
||||
class Configs(RoPEConfigs): # , ArithmeticAutoregression):
|
||||
pass
|
||||
|
||||
|
||||
def _rotary_value_pe_mha(c: TransformerConfigs):
|
||||
from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention
|
||||
return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 1.)
|
||||
|
||||
|
||||
# Configuration options
|
||||
calculate(TransformerConfigs.encoder_attn, 'rotary_value', _rotary_value_pe_mha)
|
||||
calculate(TransformerConfigs.decoder_attn, 'rotary_value', _rotary_value_pe_mha)
|
||||
calculate(TransformerConfigs.decoder_mem_attn, 'rotary_value', _rotary_value_pe_mha)
|
||||
|
||||
|
||||
def main():
|
||||
# Create experiment
|
||||
experiment.create(name="rotary_shakespeare", comment="rotary value", writers={'screen', 'labml'})
|
||||
# Create configs
|
||||
conf = Configs()
|
||||
# Override configurations
|
||||
experiment.configs(conf, {
|
||||
# No fixed positional embeddings
|
||||
'transformer.src_embed': 'no_pos',
|
||||
'transformer.tgt_embed': 'no_pos',
|
||||
|
||||
# Encoder with RoPE
|
||||
'transformer.encoder_attn': 'rotary_value',
|
||||
# 'transformer.encoder_attn': 'rotary',
|
||||
|
||||
#
|
||||
'model': 'rotary_pe_transformer',
|
||||
|
||||
# Use character level tokenizer
|
||||
'tokenizer': 'character',
|
||||
# Prompt separator is blank
|
||||
'prompt_separator': '',
|
||||
# Starting prompt for sampling
|
||||
'prompt': 'It is ',
|
||||
# Use Tiny Shakespeare dataset
|
||||
'text': 'tiny_shakespeare',
|
||||
|
||||
# Use a context size of $256$
|
||||
'seq_len': 512,
|
||||
# Train for 32 epochs
|
||||
'epochs': 24,
|
||||
# Batch size $4$
|
||||
'batch_size': 16,
|
||||
# Switch between training and validation for $10$ times
|
||||
# per epoch
|
||||
'inner_iterations': 4,
|
||||
|
||||
# Model size
|
||||
'd_model': 128,
|
||||
'transformer.ffn.d_ff': 512,
|
||||
'transformer.n_heads': 4,
|
||||
'transformer.dropout': 0.0,
|
||||
|
||||
# Use [Adam optimizer](../../optimizers/noam.html)
|
||||
'optimizer.optimizer': 'Adam',
|
||||
'optimizer.learning_rate': 2.5e-4,
|
||||
|
||||
'dataloader_shuffle_with_replacement': True
|
||||
})
|
||||
|
||||
# Set models for saving and loading
|
||||
experiment.add_pytorch_models({'model': conf.model})
|
||||
|
||||
# Start the experiment
|
||||
with experiment.start():
|
||||
# Run training
|
||||
conf.run()
|
||||
|
||||
|
||||
#
|
||||
if __name__ == '__main__':
|
||||
main()
|
6
setup.py
6
setup.py
@ -5,7 +5,7 @@ with open("readme.md", "r") as f:
|
||||
|
||||
setuptools.setup(
|
||||
name='labml-nn',
|
||||
version='0.4.122',
|
||||
version='0.4.123',
|
||||
author="Varuna Jayasiri, Nipun Wijerathne",
|
||||
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
|
||||
description="🧑🏫 Implementations/tutorials of deep learning papers with side-by-side notes 📝; including transformers (original, xl, switch, feedback, vit), optimizers (adam, radam, adabelief), gans(dcgan, cyclegan, stylegan2), 🎮 reinforcement learning (ppo, dqn), capsnet, distillation, etc. 🧠",
|
||||
@ -20,8 +20,8 @@ setuptools.setup(
|
||||
'labml_helpers', 'labml_helpers.*',
|
||||
'test',
|
||||
'test.*')),
|
||||
install_requires=['labml>=0.4.147',
|
||||
'labml-helpers>=0.4.84',
|
||||
install_requires=['labml>=0.4.151',
|
||||
'labml-helpers>=0.4.86',
|
||||
'torch',
|
||||
'torchtext',
|
||||
'torchvision',
|
||||
|
Reference in New Issue
Block a user