Varuna Jayasiri
2022-06-03 21:29:41 +05:30
committed by GitHub
parent c5d9235280
commit 0ce65adf9e
17 changed files with 3207 additions and 105 deletions

.gitignore vendored
View File

@@ -15,4 +15,5 @@ logs
html/
diagrams/
.comet.config
settings.md
labml_app.log

View File

@@ -19,3 +19,4 @@ indicators:
name: optim.*
options:
comet: false
web_api: http://localhost:5005/api/v1/track?

View File

@@ -0,0 +1,900 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This creates arithmetic problems."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Arithmetic Dataset"/>
<meta name="twitter:description" content="This creates arithmetic problems."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/experiments/arithmetic_dataset.html"/>
<meta property="og:title" content="Arithmetic Dataset"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Arithmetic Dataset"/>
<meta property="og:description" content="This creates arithmetic problems."/>
<title>Arithmetic Dataset</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/experiments/arithmetic_dataset.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">experiments</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/experiments/arithmetic_dataset.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<p><em>This is based on code by <a href="https://twitter.com/gharik">Georges Harik (@gharik)</a>.</em></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">11</span><span></span><span class="kn">import</span> <span class="nn">random</span>
<span class="lineno">12</span><span class="kn">import</span> <span class="nn">string</span>
<span class="lineno">13</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span>
<span class="lineno">14</span>
<span class="lineno">15</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">Text</span>
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">Dataset</span>
<span class="lineno">18</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">monit</span><span class="p">,</span> <span class="n">logger</span><span class="p">,</span> <span class="n">tracker</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.nlp_autoregression</span> <span class="kn">import</span> <span class="n">NLPAutoRegressionConfigs</span><span class="p">,</span> <span class="n">transpose_batch</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Arithmetic Dataset</h2>
<p>This creates arithmetic addition problems, with the workings and the solution. Only addition is implemented so far.</p>
<p>It&#x27;s based on character-level tokenization.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">24</span><span class="k">class</span> <span class="nc">ArithmeticDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span></pre></div>
</div>
</div>
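<div class='section'>
<div class='docs'>
<p>As a quick illustration of the character-level tokenization mentioned above, the following standalone sketch builds the same vocabulary and round-trips a short problem string (the sample string is illustrative):</p>
</div>
<div class='code'>
<div class="highlight"><pre>
import string

# Token id to string: the ten digits followed by 'x', 'e', space, '=', newline, '?', '+' and ';'
itos = list(string.digits + 'xe =\n?+;')
# Character to token id
stoi = {c: i for i, c in enumerate(itos)}

assert len(itos) == 18  # 10 digits + 8 other characters

# Encoding and then decoding a short problem string round-trips exactly
ids = [stoi[c] for c in 'x=1+2;']
assert ''.join(itos[i] for i in ids) == 'x=1+2;'
</pre></div>
</div>
</div>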
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code>seq_len</code> is the sequence length of the generated math problems; we pack as many problems as possible up to this length</li><li><code>max_digits</code> is the maximum number of digits in the operand integers</li><li><code>n_sequences</code> is the number of sequences per epoch</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">34</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">max_digits</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_sequences</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">41</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_sequences</span> <span class="o">=</span> <span class="n">n_sequences</span>
<span class="lineno">42</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_digits</span> <span class="o">=</span> <span class="n">max_digits</span>
<span class="lineno">43</span> <span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span> <span class="o">=</span> <span class="n">seq_len</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Token id to string </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">45</span> <span class="bp">self</span><span class="o">.</span><span class="n">itos</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">string</span><span class="o">.</span><span class="n">digits</span> <span class="o">+</span> <span class="s1">&#39;xe =</span><span class="se">\n</span><span class="s1">?+;&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Character to token id </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span> <span class="bp">self</span><span class="o">.</span><span class="n">stoi</span> <span class="o">=</span> <span class="p">{</span><span class="n">c</span><span class="p">:</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">c</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">itos</span><span class="p">)}</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p> Generates an integer with <code  class="highlight"><span></span><span class="n">n_digits</span></code>
 digits</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="nd">@staticmethod</span>
<span class="lineno">50</span> <span class="k">def</span> <span class="nf">make_int</span><span class="p">(</span><span class="n">n_digits</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">54</span> <span class="n">res</span> <span class="o">=</span> <span class="mi">0</span>
<span class="lineno">55</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_digits</span><span class="p">):</span>
<span class="lineno">56</span> <span class="n">d</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randrange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">random</span><span class="o">.</span><span class="n">randrange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span>
<span class="lineno">57</span> <span class="n">res</span> <span class="o">=</span> <span class="n">res</span> <span class="o">*</span> <span class="mi">10</span> <span class="o">+</span> <span class="n">d</span>
<span class="lineno">58</span>
<span class="lineno">59</span> <span class="k">return</span> <span class="n">res</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p> Generates the workings for <code  class="highlight"><span></span><span class="n">x</span> <span class="o">+</span> <span class="n">y</span></code>
. For example, for <code  class="highlight"><span></span><span class="mi">11</span><span class="o">+</span><span class="mi">29</span></code>
 it generates <code  class="highlight"><span></span>1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1</code>
.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">61</span> <span class="nd">@staticmethod</span>
<span class="lineno">62</span> <span class="k">def</span> <span class="nf">get_add_explanation</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">69</span> <span class="n">carry</span> <span class="o">=</span> <span class="mi">0</span>
<span class="lineno">70</span> <span class="n">e</span> <span class="o">=</span> <span class="mi">0</span>
<span class="lineno">71</span> <span class="n">explanation</span> <span class="o">=</span> <span class="p">[]</span>
<span class="lineno">72</span> <span class="k">while</span> <span class="n">x</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">y</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">carry</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">73</span> <span class="n">rx</span><span class="p">,</span> <span class="n">ry</span> <span class="o">=</span> <span class="n">x</span> <span class="o">%</span> <span class="mi">10</span><span class="p">,</span> <span class="n">y</span> <span class="o">%</span> <span class="mi">10</span>
<span class="lineno">74</span> <span class="n">total</span> <span class="o">=</span> <span class="n">rx</span> <span class="o">+</span> <span class="n">ry</span> <span class="o">+</span> <span class="n">carry</span>
<span class="lineno">75</span> <span class="n">explanation</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">rx</span><span class="si">}</span><span class="s2">e</span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s2">+</span><span class="si">{</span><span class="n">ry</span><span class="si">}</span><span class="s2">e</span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s2">+</span><span class="si">{</span><span class="n">carry</span><span class="si">}</span><span class="s2">e</span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s2">==</span><span class="si">{</span><span class="n">total</span><span class="si">}</span><span class="s2">e</span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="lineno">76</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">carry</span> <span class="o">=</span> <span class="n">x</span> <span class="o">//</span> <span class="mi">10</span><span class="p">,</span> <span class="n">y</span> <span class="o">//</span> <span class="mi">10</span><span class="p">,</span> <span class="n">total</span> <span class="o">//</span> <span class="mi">10</span>
<span class="lineno">77</span> <span class="n">e</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="lineno">78</span>
<span class="lineno">79</span> <span class="k">return</span> <span class="s1">&#39; &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">explanation</span><span class="p">)</span></pre></div>
</div>
</div>
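<div class='section'>
<div class='docs'>
<p>To make the workings format concrete, here is a minimal standalone sketch of the same digit-by-digit loop, checked against the <code>11 + 29</code> example above (the function name is illustrative):</p>
</div>
<div class='code'>
<div class="highlight"><pre>
def add_explanation(x: int, y: int) -&gt; str:
    # Digit-by-digit addition with carries, least significant digit first
    carry, e, steps = 0, 0, []
    while x &gt; 0 or y &gt; 0 or carry &gt; 0:
        rx, ry = x % 10, y % 10
        total = rx + ry + carry
        steps.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
        x, y, carry = x // 10, y // 10, total // 10
        e += 1
    return ' '.join(steps)

# 11 + 29: units give 1+9+0 == 10, tens give 1+2+1 (carry) == 4
assert add_explanation(11, 29) == '1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1'
</pre></div>
</div>
</div>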
<div class='section' id='section-10'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Creates an arithmetic addition problem, with the workings shown before the answer.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span> <span class="k">def</span> <span class="nf">make_add_problem</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">86</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">make_int</span><span class="p">(</span><span class="n">n_digits</span><span class="o">=</span><span class="n">random</span><span class="o">.</span><span class="n">randrange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_digits</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
<span class="lineno">87</span> <span class="n">y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">make_int</span><span class="p">(</span><span class="n">n_digits</span><span class="o">=</span><span class="n">random</span><span class="o">.</span><span class="n">randrange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_digits</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
<span class="lineno">88</span>
<span class="lineno">89</span> <span class="n">explanation</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_add_explanation</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="lineno">90</span> <span class="k">return</span> <span class="sa">f</span><span class="s2">&quot;x=</span><span class="si">{</span><span class="n">x</span><span class="si">}</span><span class="s2">+</span><span class="si">{</span><span class="n">y</span><span class="si">}</span><span class="s2">; </span><span class="si">{</span><span class="n">explanation</span><span class="si">}</span><span class="s2"> x==</span><span class="si">{</span><span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="si">}</span><span class="se">\n</span><span class="s2">&quot;</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p> Get an arithmetic problem and its answer. This is used for evaluation.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">92</span> <span class="k">def</span> <span class="nf">get_qa</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">96</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">make_int</span><span class="p">(</span><span class="n">n_digits</span><span class="o">=</span><span class="n">random</span><span class="o">.</span><span class="n">randrange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_digits</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
<span class="lineno">97</span> <span class="n">y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">make_int</span><span class="p">(</span><span class="n">n_digits</span><span class="o">=</span><span class="n">random</span><span class="o">.</span><span class="n">randrange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_digits</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
<span class="lineno">98</span>
<span class="lineno">99</span> <span class="k">return</span> <span class="sa">f</span><span class="s1">&#39;x=</span><span class="si">{</span><span class="n">x</span><span class="si">}</span><span class="s1">+</span><span class="si">{</span><span class="n">y</span><span class="si">}</span><span class="s1">;&#39;</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="si">}</span><span class="s1">&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p> Generate multiple problems and pack them into a sequence.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">101</span> <span class="k">def</span> <span class="nf">get_packed_math_input</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">105</span> <span class="n">s_enc</span> <span class="o">=</span> <span class="p">[]</span>
<span class="lineno">106</span> <span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">s_enc</span><span class="p">)</span> <span class="o">&lt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span><span class="p">:</span>
<span class="lineno">107</span> <span class="n">s_part</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">make_add_problem</span><span class="p">()</span>
<span class="lineno">108</span> <span class="n">s_part_enc</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="s1">&#39;?&#39;</span> <span class="o">+</span> <span class="n">s_part</span><span class="p">)</span>
<span class="lineno">109</span> <span class="n">s_enc</span> <span class="o">=</span> <span class="n">s_enc</span> <span class="o">+</span> <span class="n">s_part_enc</span>
<span class="lineno">110</span> <span class="k">return</span> <span class="n">s_enc</span></pre></div>
</div>
</div>
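<div class='section'>
<div class='docs'>
<p>The packing loop above keeps appending encoded problems, each prefixed with <code>?</code>, until the sequence holds at least <code>seq_len + 1</code> tokens. A rough standalone sketch, with stubs standing in for <code>make_add_problem</code> and <code>encode</code> (the stubs and values are illustrative only):</p>
</div>
<div class='code'>
<div class="highlight"><pre>
def make_problem_stub() -&gt; str:
    # Stand-in for make_add_problem(); the real one returns e.g. "x=11+29; ...workings... x==40\n"
    return "x=1+2; 1e0+2e0+0e0==3e0 x==3\n"

def encode_stub(s: str) -&gt; list:
    # Stand-in for the character-to-token-id lookup
    return [ord(c) for c in s]

seq_len = 64
s_enc = []
while len(s_enc) &lt;= seq_len:
    # Each problem is prefixed with '?' and appended until the sequence is long enough
    s_enc = s_enc + encode_stub('?' + make_problem_stub())

assert len(s_enc) &gt; seq_len  # at least seq_len + 1 tokens before __getitem__ trims it
</pre></div>
</div>
</div>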
<div class='section' id='section-16'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p> Encode a given string</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">112</span> <span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">s</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="k">return</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">stoi</span><span class="p">[</span><span class="n">c</span><span class="p">]</span> <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="n">s</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p> Decode a list of token ids</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">118</span> <span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">arr</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span> <span class="k">return</span> <span class="s1">&#39;&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">c</span><span class="p">]</span> <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="n">arr</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p> Get an input and target pair for auto-regressive modelling</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">128</span> <span class="n">s</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_packed_math_input</span><span class="p">())</span>
<span class="lineno">129</span> <span class="k">return</span> <span class="n">s</span><span class="p">[:</span><span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span><span class="p">],</span> <span class="n">s</span><span class="p">[</span><span class="mi">1</span><span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
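<div class='section'>
<div class='docs'>
<p>A small standalone sketch of the input/target pairing used for auto-regressive training; the packed token ids are replaced with toy values:</p>
</div>
<div class='code'>
<div class="highlight"><pre>
import torch

seq_len = 8
# The packing loop guarantees at least seq_len + 1 tokens; toy ids are used here
s = torch.tensor(list(range(seq_len + 1)))

# Input and target are the same sequence shifted by one position
inputs, targets = s[:seq_len], s[1:seq_len + 1]
assert inputs.shape == targets.shape == (seq_len,)
assert (inputs[1:] == targets[:-1]).all()  # the target at step t is the input at step t + 1
</pre></div>
</div>
</div>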
<div class='section' id='section-22'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p> Number of sequences per epoch</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">131</span> <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">135</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_sequences</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<h2>Arithmetic Task Experiment Configurations</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">138</span><span class="k">class</span> <span class="nc">ArithmeticAutoregression</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Maximum number of digits per operand integer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">143</span> <span class="n">max_digits</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Number of training sequences per epoch </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">145</span> <span class="n">train_sequences_per_epoch</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">**</span> <span class="mi">12</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Training data loader </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">147</span> <span class="n">train_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="s1">&#39;arithmetic_train_loader&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Number of problems in evaluation </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">149</span> <span class="n">n_tests</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>No validation dataset is needed </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">151</span> <span class="n">validator</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Number of times to run evaluations per epoch </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">153</span> <span class="n">inner_iterations</span> <span class="o">=</span> <span class="mi">4</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Number of tokens in the vocabulary </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">155</span> <span class="n">n_tokens</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">ArithmeticDataset</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">itos</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<h3>Evaluation</h3>
<p>We use the sampling function to evaluate the model on a set of problems.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">157</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">158</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Skip in the first epoch </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">166</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_loop</span><span class="o">.</span><span class="n">idx</span> <span class="o">&lt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="lineno">167</span> <span class="k">return</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Create a dataset to generate problems </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">170</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">ArithmeticDataset</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_digits</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Get a set of problems and answers </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">172</span> <span class="n">qa</span> <span class="o">=</span> <span class="p">[</span><span class="n">dataset</span><span class="o">.</span><span class="n">get_qa</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_tests</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Collect the problems only </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">174</span> <span class="n">questions</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">qa</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>Create a tensor with only the initial token </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">177</span> <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="n">dataset</span><span class="o">.</span><span class="n">stoi</span><span class="p">[</span><span class="n">p</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">questions</span><span class="p">]])</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Move to device </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">179</span> <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>Number of sequences that have completed </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">182</span> <span class="n">finished</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">questions</span><span class="p">),))</span><span class="o">.</span><span class="n">bool</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Token id of the new-line character; this marks the end of the answer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">184</span> <span class="n">new_line</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">stoi</span><span class="p">[</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Sampled results </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">187</span> <span class="n">results</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">questions</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Sample up to the sequence length </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">190</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">iterate</span><span class="p">(</span><span class="s1">&#39;Sample&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>If all the sequences have completed we skip this </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">192</span> <span class="k">if</span> <span class="n">finished</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">finished</span><span class="p">):</span>
<span class="lineno">193</span> <span class="k">continue</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Get the model output </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">196</span> <span class="n">output</span><span class="p">,</span> <span class="o">*</span><span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Get the model prediction (greedy) </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">198</span> <span class="n">output</span> <span class="o">=</span> <span class="n">output</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p>Find which sequences have finished </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">201</span> <span class="n">finished</span> <span class="o">=</span> <span class="n">finished</span> <span class="o">|</span> <span class="p">(</span><span class="n">output</span> <span class="o">==</span> <span class="n">new_line</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<p>Skip if all have finished </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">203</span> <span class="k">if</span> <span class="n">finished</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">finished</span><span class="p">):</span>
<span class="lineno">204</span> <span class="k">continue</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<p>Override with the question </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">207</span> <span class="k">for</span> <span class="n">j</span><span class="p">,</span> <span class="n">p</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">questions</span><span class="p">):</span>
<span class="lineno">208</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">p</span><span class="p">)</span> <span class="o">&gt;</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:</span>
<span class="lineno">209</span> <span class="n">output</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">stoi</span><span class="p">[</span><span class="n">p</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]]</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>Add the next token to the input </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">212</span> <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">data</span><span class="p">,</span> <span class="n">output</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>Get the sampled results </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">215</span> <span class="k">for</span> <span class="n">j</span><span class="p">,</span> <span class="n">c</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">output</span><span class="p">):</span>
<span class="lineno">216</span> <span class="n">results</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">c</span><span class="p">]</span></pre></div>
</div>
</div>
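<div class='section'>
<div class='docs'>
<p>Putting the sampling steps above together, the core of the loop looks roughly like the standalone sketch below. It omits the <code>finished</code> mask and the new-line bookkeeping, and the stub model returning random logits is purely illustrative:</p>
</div>
<div class='code'>
<div class="highlight"><pre>
import torch

def model_stub(data: torch.Tensor) -&gt; torch.Tensor:
    # Stand-in for the real model: logits of shape [seq, batch, n_tokens]
    return torch.randn(data.shape[0], data.shape[1], 18)

questions = ['x=1+2;', 'x=12+3;']
itos = list('0123456789') + list('xe =\n?+;')
stoi = {c: i for i, c in enumerate(itos)}

# Start with only the first character of each question; data is [seq, batch]
data = torch.tensor([[stoi[q[0]] for q in questions]])
results = [q[0] for q in questions]

for i in range(16):
    output = model_stub(data)[-1].argmax(dim=-1)      # greedy pick at the last position
    for j, q in enumerate(questions):
        if len(q) &gt; i + 1:
            output[j] = stoi[q[i + 1]]                # force the known question characters
    data = torch.cat([data, output[None, :]], dim=0)  # feed the token back in
    for j, c in enumerate(output):
        results[j] += itos[int(c)]
</pre></div>
</div>
</div>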
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p>Discard everything after the answer in the results </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">219</span> <span class="n">results</span> <span class="o">=</span> <span class="p">[</span><span class="n">r</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">results</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p>Log a sample </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">222</span> <span class="n">res_sample</span> <span class="o">=</span> <span class="n">results</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">&#39;;&#39;</span><span class="p">)</span>
<span class="lineno">223</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">([(</span><span class="n">res_sample</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">Text</span><span class="o">.</span><span class="n">key</span><span class="p">),</span> <span class="p">(</span><span class="s1">&#39;;&#39;</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">subtle</span><span class="p">),</span> <span class="p">(</span><span class="s1">&#39;;&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">res_sample</span><span class="p">[</span><span class="mi">1</span><span class="p">:]),</span> <span class="n">Text</span><span class="o">.</span><span class="n">none</span><span class="p">)])</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>Get the answers </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">226</span> <span class="n">results</span> <span class="o">=</span> <span class="p">[</span><span class="n">r</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">&#39;x==&#39;</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">results</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<p>Count the number of correct answers </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">229</span> <span class="n">correct</span> <span class="o">=</span> <span class="mi">0</span>
<span class="lineno">230</span> <span class="k">for</span> <span class="n">r</span><span class="p">,</span> <span class="n">_qa</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">results</span><span class="p">,</span> <span class="n">qa</span><span class="p">):</span>
<span class="lineno">231</span> <span class="k">if</span> <span class="n">r</span> <span class="o">==</span> <span class="n">_qa</span><span class="p">[</span><span class="mi">1</span><span class="p">]:</span>
<span class="lineno">232</span> <span class="n">correct</span> <span class="o">+=</span> <span class="mi">1</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='docs'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<p>Log the score </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">235</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s1">&#39;score&#39;</span><span class="p">,</span> <span class="n">correct</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">results</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
<p> Training data loader</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">238</span><span class="nd">@option</span><span class="p">(</span><span class="n">ArithmeticAutoregression</span><span class="o">.</span><span class="n">train_loader</span><span class="p">)</span>
<span class="lineno">239</span><span class="k">def</span> <span class="nf">arithmetic_train_loader</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">ArithmeticAutoregression</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">243</span> <span class="k">return</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">ArithmeticDataset</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">max_digits</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">train_sequences_per_epoch</span><span class="p">),</span>
<span class="lineno">244</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span>
<span class="lineno">245</span> <span class="n">collate_fn</span><span class="o">=</span><span class="n">transpose_batch</span><span class="p">,</span>
<span class="lineno">246</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span></pre></div>
</div>
</div>
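<div class='section'>
<div class='docs'>
<p>For reference, a minimal way to build the loader; the module path is taken from the repository link above, and <code>transpose_batch</code> is left out so the default collation keeps batches batch-first (batch size and lengths are illustrative):</p>
</div>
<div class='code'>
<div class="highlight"><pre>
from torch.utils.data import DataLoader

from labml_nn.experiments.arithmetic_dataset import ArithmeticDataset

dataset = ArithmeticDataset(seq_len=256, max_digits=4, n_sequences=2 ** 12)
loader = DataLoader(dataset, batch_size=32, num_workers=0)

inputs, targets = next(iter(loader))
# Without a collate function the batch is batch-first: [batch_size, seq_len]
assert inputs.shape == targets.shape == (32, 256)
</pre></div>
</div>
</div>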
<div class='section' id='section-58'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<p> Code to test generated problems</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">249</span><span class="k">def</span> <span class="nf">_test</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">253</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">ArithmeticDataset</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
<span class="lineno">254</span>
<span class="lineno">255</span> <span class="nb">print</span><span class="p">(</span><span class="n">dataset</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">dataset</span><span class="o">.</span><span class="n">get_packed_math_input</span><span class="p">()))</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">259</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">260</span> <span class="n">_test</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

View File

@@ -70,7 +70,8 @@
<a href='#section-0'>#</a>
</div>
<h1><a href="index.html">DeepNorm</a> Experiment</h1>
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/normalization/deep_norm/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://app.labml.ai/run/ec8e4dacb7f311ec8d1cd37d50b05c3d"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a></p> <p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/normalization/deep_norm/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://app.labml.ai/run/ec8e4dacb7f311ec8d1cd37d50b05c3d"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a> <a href="https://www.comet.ml/labml/deep-norm/61d817f80ff143c8825fba4aacd431d4?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&xAxis=step"><img alt="Open In Comet" src="https://images.labml.ai/images/comet.svg?experiment=deep_norm&file=experiment"></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">15</span><span></span><span class="kn">import</span> <span class="nn">copy</span>

View File

@@ -204,7 +204,7 @@
<url>
<loc>https://nn.labml.ai/normalization/deep_norm/index.html</loc>
<lastmod>2022-04-23T16:30:00+00:00</lastmod> <lastmod>2022-05-18T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
@@ -244,6 +244,13 @@
</url>
<url>
<loc>https://nn.labml.ai/experiments/arithmetic_dataset.html</loc>
<lastmod>2022-06-02T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/experiments/index.html</loc>
<lastmod>2020-12-26T16:30:00+00:00</lastmod>
@@ -603,14 +610,35 @@
<url>
<loc>https://nn.labml.ai/transformers/rope/index.html</loc>
<lastmod>2022-04-05T16:30:00+00:00</lastmod> <lastmod>2022-05-31T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/transformers/rope/value_pe/arithmetic_experiment.html</loc>
<lastmod>2022-06-02T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/transformers/rope/value_pe/index.html</loc>
<lastmod>2022-06-02T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/transformers/rope/value_pe/experiment.html</loc>
<lastmod>2022-05-31T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/transformers/rope/experiment.html</loc>
<lastmod>2022-03-12T16:30:00+00:00</lastmod> <lastmod>2022-05-31T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>

View File

@@ -92,7 +92,7 @@
<div class='code'>
<div class="highlight"><pre><span class="lineno">21</span><span class="k">def</span> <span class="nf">_rotary_pe_mha</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
<span class="lineno">22</span> <span class="kn">from</span> <span class="nn">labml_nn.transformers.rope</span> <span class="kn">import</span> <span class="n">RotaryPEMultiHeadAttention</span>
<span class="lineno">23</span> <span class="k">return</span> <span class="n">RotaryPEMultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)</span></pre></div> <span class="lineno">23</span> <span class="k">return</span> <span class="n">RotaryPEMultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="mf">1.</span><span class="p">)</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-2'> <div class='section' id='section-2'>
@ -157,7 +157,7 @@
</div> </div>
<div class='code'> <div class='code'>
<div class="highlight"><pre><span class="lineno">46</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;rotary_pe_transformer&quot;</span><span class="p">)</span></pre></div> <div class="highlight"><pre><span class="lineno">46</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;rotary_pe_transformer&quot;</span><span class="p">,</span> <span class="n">writers</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;screen&#39;</span><span class="p">})</span></pre></div>
</div> </div>
</div> </div>
<div class='section' id='section-7'> <div class='section' id='section-7'>

File diff suppressed because one or more lines are too long


@ -0,0 +1,403 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This experiment trains a transformer model with Rotary Positional Embeddings with Relative Distance (RoPER) on the arithmetic addition task."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Rotary Positional Embeddings with Relative distance (RoPER) Experiment"/>
<meta name="twitter:description" content="This experiment trains a transformer model with Rotary Positional Embeddings with Relative Distance (RoPER) on the arithmetic addition task."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/transformers/rope/value_pe/arithmetic_experiment.html"/>
<meta property="og:title" content="Rotary Positional Embeddings with Relative distance (RoPER) Experiment"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Rotary Positional Embeddings with Relative distance (RoPER) Experiment"/>
<meta property="og:description" content="This experiment trains a transformer model with Rotary Positional Embeddings with Relative Distance (RoPER) on the arithmetic addition task."/>
<title>Rotary Positional Embeddings with Relative distance (RoPER) Experiment</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/transformers/rope/value_pe/arithmetic_experiment.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../../index.html">transformers</a>
<a class="parent" href="../index.html">rope</a>
<a class="parent" href="index.html">value_pe</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/rope/value_pe/arithmetic_experiment.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Rotary Positional Embeddings with Relative distance (<a href="index.html">RoPER</a>) Experiment</h1>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">11</span><span></span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
<span class="lineno">12</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">calculate</span>
<span class="lineno">13</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.arithmetic_dataset</span> <span class="kn">import</span> <span class="n">ArithmeticAutoregression</span>
<span class="lineno">14</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">TransformerConfigs</span>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.rope.experiment</span> <span class="kn">import</span> <span class="n">Configs</span> <span class="k">as</span> <span class="n">RoPEConfigs</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<p> We inherit <a href="../experiment.html">RoPE experiment</a> and use it for <a href="../../experiments/arithmetic_dataset.html">arithmetic addition task</a>.</p>
<p>We add the option to change attention to use Rotary Positional Embeddings with Relative distance (RoPER) below.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">18</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">RoPEConfigs</span><span class="p">,</span> <span class="n">ArithmeticAutoregression</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">26</span> <span class="k">pass</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p> Use Rotary Positional Embeddings with Relative distance (<a href="index.html">RoPER</a>) in attention.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">29</span><span class="k">def</span> <span class="nf">_rotary_value_pe_mha</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">33</span> <span class="kn">from</span> <span class="nn">labml_nn.transformers.rope.value_pe</span> <span class="kn">import</span> <span class="n">RotaryValuePEMultiHeadAttention</span>
<span class="lineno">34</span> <span class="k">return</span> <span class="n">RotaryValuePEMultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Configuration options </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">38</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_attn</span><span class="p">,</span> <span class="s1">&#39;rotary_value&#39;</span><span class="p">,</span> <span class="n">_rotary_value_pe_mha</span><span class="p">)</span>
<span class="lineno">39</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_attn</span><span class="p">,</span> <span class="s1">&#39;rotary_value&#39;</span><span class="p">,</span> <span class="n">_rotary_value_pe_mha</span><span class="p">)</span>
<span class="lineno">40</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_mem_attn</span><span class="p">,</span> <span class="s1">&#39;rotary_value&#39;</span><span class="p">,</span> <span class="n">_rotary_value_pe_mha</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">43</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Create experiment </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">45</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;roper_addition&quot;</span><span class="p">,</span> <span class="n">comment</span><span class="o">=</span><span class="s2">&quot;rotary value 7&quot;</span><span class="p">,</span> <span class="n">writers</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;screen&#39;</span><span class="p">,</span> <span class="s1">&#39;labml&#39;</span><span class="p">,</span> <span class="s1">&#39;comet&#39;</span><span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Create configs </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Override configurations </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span>
<span class="lineno">50</span> <span class="s1">&#39;max_digits&#39;</span><span class="p">:</span> <span class="mi">7</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>No fixed positional embeddings </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">53</span> <span class="s1">&#39;transformer.src_embed&#39;</span><span class="p">:</span> <span class="s1">&#39;no_pos&#39;</span><span class="p">,</span>
<span class="lineno">54</span> <span class="s1">&#39;transformer.tgt_embed&#39;</span><span class="p">:</span> <span class="s1">&#39;no_pos&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Encoder with RoPER attention </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">57</span> <span class="s1">&#39;transformer.encoder_attn&#39;</span><span class="p">:</span> <span class="s1">&#39;rotary_value&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Encoder with RoPE attention &#x27;transformer.encoder_attn&#x27;: &#x27;rotary&#x27;, </p>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="s1">&#39;model&#39;</span><span class="p">:</span> <span class="s1">&#39;rotary_pe_transformer&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Use a context size of <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">512</span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span> <span class="s1">&#39;seq_len&#39;</span><span class="p">:</span> <span class="mi">512</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Train for 20 epochs </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">67</span> <span class="s1">&#39;epochs&#39;</span><span class="p">:</span> <span class="mi">20</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Batch size <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">16</span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">69</span> <span class="s1">&#39;batch_size&#39;</span><span class="p">:</span> <span class="mi">16</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Model size </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">72</span> <span class="s1">&#39;d_model&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span>
<span class="lineno">73</span> <span class="s1">&#39;transformer.ffn.d_ff&#39;</span><span class="p">:</span> <span class="mi">512</span><span class="p">,</span>
<span class="lineno">74</span> <span class="s1">&#39;transformer.n_heads&#39;</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span>
<span class="lineno">75</span> <span class="s1">&#39;transformer.dropout&#39;</span><span class="p">:</span> <span class="mf">0.0</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Use <a href="../../optimizers/noam.html">Adam optimizer</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">78</span> <span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Adam&#39;</span><span class="p">,</span>
<span class="lineno">79</span> <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">2.5e-4</span><span class="p">,</span>
<span class="lineno">80</span> <span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Set models for saving and loading </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">83</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">&#39;model&#39;</span><span class="p">:</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Start the experiment </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">86</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Run training </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">88</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">92</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">93</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../../../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>


@ -0,0 +1,454 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This experiment trains a transformer model with Rotary Positional Embeddings (RoPE) on tiny Shakespeare dataset."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Rotary Positional Embeddings (RoPE) Experiment"/>
<meta name="twitter:description" content="This experiment trains a transformer model with Rotary Positional Embeddings (RoPE) on tiny Shakespeare dataset."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/transformers/rope/value_pe/experiment.html"/>
<meta property="og:title" content="Rotary Positional Embeddings (RoPE) Experiment"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Rotary Positional Embeddings (RoPE) Experiment"/>
<meta property="og:description" content="This experiment trains a transformer model with Rotary Positional Embeddings (RoPE) on tiny Shakespeare dataset."/>
<title>Rotary Positional Embeddings (RoPE) Experiment</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/transformers/rope/value_pe/experiment.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../../index.html">transformers</a>
<a class="parent" href="../index.html">rope</a>
<a class="parent" href="index.html">value_pe</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/rope/value_pe/experiment.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Rotary Positional Embeddings (RoPE) Experiment</h1>
<p>This is an annotated PyTorch experiment to train a transformer model with Rotary Positional Embeddings (RoPE).</p>
<p><a href="https://app.labml.ai/run/1cf508e693be11ecacc98de8b38a61fe"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">14</span><span></span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">calculate</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml_nn.transformers</span> <span class="kn">import</span> <span class="n">TransformerConfigs</span>
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.rope.experiment</span> <span class="kn">import</span> <span class="n">Configs</span> <span class="k">as</span> <span class="n">RoPEConfigs</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h3>Rotary PE attention</h3>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">22</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">RoPEConfigs</span><span class="p">):</span> <span class="c1"># , ArithmeticAutoregression):</span>
<span class="lineno">23</span> <span class="k">pass</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">26</span><span class="k">def</span> <span class="nf">_rotary_value_pe_mha</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TransformerConfigs</span><span class="p">):</span>
<span class="lineno">27</span> <span class="kn">from</span> <span class="nn">labml_nn.transformers.rope.value_pe</span> <span class="kn">import</span> <span class="n">RotaryValuePEMultiHeadAttention</span>
<span class="lineno">28</span> <span class="k">return</span> <span class="n">RotaryValuePEMultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Configuration options </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">32</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">encoder_attn</span><span class="p">,</span> <span class="s1">&#39;rotary_value&#39;</span><span class="p">,</span> <span class="n">_rotary_value_pe_mha</span><span class="p">)</span>
<span class="lineno">33</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_attn</span><span class="p">,</span> <span class="s1">&#39;rotary_value&#39;</span><span class="p">,</span> <span class="n">_rotary_value_pe_mha</span><span class="p">)</span>
<span class="lineno">34</span><span class="n">calculate</span><span class="p">(</span><span class="n">TransformerConfigs</span><span class="o">.</span><span class="n">decoder_mem_attn</span><span class="p">,</span> <span class="s1">&#39;rotary_value&#39;</span><span class="p">,</span> <span class="n">_rotary_value_pe_mha</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">37</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Create experiment </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">39</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;rotary_shakespeare&quot;</span><span class="p">,</span> <span class="n">comment</span><span class="o">=</span><span class="s2">&quot;rotary value&quot;</span><span class="p">,</span> <span class="n">writers</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;screen&#39;</span><span class="p">,</span> <span class="s1">&#39;labml&#39;</span><span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Create configs </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">41</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Override configurations </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">43</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>No fixed positional embeddings </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">45</span> <span class="s1">&#39;transformer.src_embed&#39;</span><span class="p">:</span> <span class="s1">&#39;no_pos&#39;</span><span class="p">,</span>
<span class="lineno">46</span> <span class="s1">&#39;transformer.tgt_embed&#39;</span><span class="p">:</span> <span class="s1">&#39;no_pos&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Encoder with RoPE </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="s1">&#39;transformer.encoder_attn&#39;</span><span class="p">:</span> <span class="s1">&#39;rotary_value&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>&#x27;transformer.encoder_attn&#x27;: &#x27;rotary&#x27;, </p>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">53</span> <span class="s1">&#39;model&#39;</span><span class="p">:</span> <span class="s1">&#39;rotary_pe_transformer&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Use character level tokenizer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="s1">&#39;tokenizer&#39;</span><span class="p">:</span> <span class="s1">&#39;character&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Prompt separator is blank </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">58</span> <span class="s1">&#39;prompt_separator&#39;</span><span class="p">:</span> <span class="s1">&#39;&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Starting prompt for sampling </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">60</span> <span class="s1">&#39;prompt&#39;</span><span class="p">:</span> <span class="s1">&#39;It is &#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Use Tiny Shakespeare dataset </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="s1">&#39;text&#39;</span><span class="p">:</span> <span class="s1">&#39;tiny_shakespeare&#39;</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Use a context size of <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">512</span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span> <span class="s1">&#39;seq_len&#39;</span><span class="p">:</span> <span class="mi">512</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Train for 24 epochs </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">67</span> <span class="s1">&#39;epochs&#39;</span><span class="p">:</span> <span class="mi">24</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Batch size <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">16</span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">69</span> <span class="s1">&#39;batch_size&#39;</span><span class="p">:</span> <span class="mi">16</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Switch between training and validation <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">4</span></span></span></span> times per epoch </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">72</span> <span class="s1">&#39;inner_iterations&#39;</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Model size </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">75</span> <span class="s1">&#39;d_model&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span>
<span class="lineno">76</span> <span class="s1">&#39;transformer.ffn.d_ff&#39;</span><span class="p">:</span> <span class="mi">512</span><span class="p">,</span>
<span class="lineno">77</span> <span class="s1">&#39;transformer.n_heads&#39;</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span>
<span class="lineno">78</span> <span class="s1">&#39;transformer.dropout&#39;</span><span class="p">:</span> <span class="mf">0.0</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Use <a href="../../optimizers/noam.html">Adam optimizer</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">81</span> <span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Adam&#39;</span><span class="p">,</span>
<span class="lineno">82</span> <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">2.5e-4</span><span class="p">,</span>
<span class="lineno">83</span>
<span class="lineno">84</span> <span class="s1">&#39;dataloader_shuffle_with_replacement&#39;</span><span class="p">:</span> <span class="kc">True</span>
<span class="lineno">85</span> <span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Set models for saving and loading </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">88</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">&#39;model&#39;</span><span class="p">:</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Start the experiment </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Run training </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">93</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">97</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">98</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../../../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

File diff suppressed because one or more lines are too long


@ -0,0 +1,260 @@
"""
---
title: Arithmetic Dataset
summary: >
This creates arithmetic problems.
---
*This is based on code by [Georges Harik (@gharik)](https://twitter.com/gharik).*
"""
import random
import string
from typing import List
import torch
from labml.logger import Text
from torch.utils.data import DataLoader, Dataset
from labml import monit, logger, tracker
from labml.configs import option
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs, transpose_batch
class ArithmeticDataset(Dataset):
"""
## Arithmetic Dataset
This creates arithmetic addition problems and solutions with workings.
We've only implemented addition so far.
It's based on a character level tokenization.
"""
def __init__(self, seq_len: int, max_digits: int, n_sequences: int):
"""
        :param seq_len: is the sequence length of generated math problems.
            We fill as many problems as possible up to this length
        :param max_digits: is the maximum number of digits in the operand integers
        :param n_sequences: is the number of sequences per epoch
"""
self.n_sequences = n_sequences
self.max_digits = max_digits
self.seq_len = seq_len
# Token id to string
self.itos = list(string.digits + 'xe =\n?+;')
# Character to token id
self.stoi = {c: i for i, c in enumerate(self.itos)}
@staticmethod
def make_int(n_digits: int):
"""
        Generates an integer with `n_digits` digits
"""
res = 0
for i in range(n_digits):
            d = random.randrange(1, 10) if i == 0 else random.randrange(0, 10)
res = res * 10 + d
return res
@staticmethod
def get_add_explanation(x: int, y: int):
"""
Generates the workings for `x + y`.
        For example, for `11+29` it generates
        `1e0+9e0+0e0==10e0 1e1+2e1+1e1==4e1`.
"""
carry = 0
e = 0
explanation = []
while x > 0 or y > 0 or carry > 0:
rx, ry = x % 10, y % 10
total = rx + ry + carry
explanation.append(f"{rx}e{e}+{ry}e{e}+{carry}e{e}=={total}e{e}")
x, y, carry = x // 10, y // 10, total // 10
e += 1
return ' '.join(explanation)
    # Make an addition problem with its workings
def make_add_problem(self):
"""
Creates an arithmetic addition problem with workings and answer.
"""
x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
explanation = self.get_add_explanation(x, y)
return f"x={x}+{y}; {explanation} x=={x + y}\n"
def get_qa(self):
"""
Get arithmetic problem and answer. This is used for evaluation.
"""
x = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
y = self.make_int(n_digits=random.randrange(1, self.max_digits + 1))
return f'x={x}+{y};', f'{x + y}'
def get_packed_math_input(self):
"""
Generate multiple problems and pack them into a sequence.
"""
s_enc = []
while len(s_enc) <= self.seq_len:
s_part = self.make_add_problem()
s_part_enc = self.encode('?' + s_part)
s_enc = s_enc + s_part_enc
return s_enc
def encode(self, s: str):
"""
Encode a given string
"""
return [self.stoi[c] for c in s]
def decode(self, arr: List[int]):
"""
Decode a list of token ids
"""
return ''.join([self.itos[c] for c in arr])
def __getitem__(self, idx: int):
"""
        Get an input and target pair for auto-regressive modelling
"""
s = torch.tensor(self.get_packed_math_input())
return s[:self.seq_len], s[1:self.seq_len + 1]
def __len__(self):
"""
Number of sequences per epoch
"""
return self.n_sequences
class ArithmeticAutoregression(NLPAutoRegressionConfigs):
"""
## Arithmetic Task Experiment Configurations
"""
# Maximum number of digits per operand integer
max_digits: int = 4
# Number of training sequences per epoch
train_sequences_per_epoch: int = 2 ** 12
# Training data loader
train_loader: DataLoader = 'arithmetic_train_loader'
# Number of problems in evaluation
n_tests: int = 64
    # No validation dataset is needed
validator = None
# Number of times to run evaluations per epoch
inner_iterations = 4
# Number of tokens in the vocabulary
n_tokens = len(ArithmeticDataset(1, 1, 1).itos)
@torch.no_grad()
def sample(self):
"""
### Evaluation
We use the sampling function to evaluate the model on a set of problems
"""
# Skip in the first epoch
if self.training_loop.idx < 1:
return
# Create a dataset to generate problems
dataset = ArithmeticDataset(self.seq_len, self.max_digits, 1)
# Get a set of problems and answers
qa = [dataset.get_qa() for _ in range(self.n_tests)]
# Collect the problems only
questions = [p[0] for p in qa]
# Create a tensor with only the initial token
data = torch.tensor([[dataset.stoi[p[0]] for p in questions]])
# Move to device
data = data.to(self.device)
# Number of sequences that have completed
finished = torch.zeros((len(questions),)).bool().to(self.device)
# Token id of the new line character - this marks end of the answer
new_line = dataset.stoi['\n']
# Sampled results
results = [p[0] for p in questions]
        # Sample up to the sequence length
for i in monit.iterate('Sample', self.seq_len - 1):
# If all the sequences have completed we skip this
if finished.sum() == len(finished):
continue
# Get the model output
output, *_ = self.model(data)
# Get the model prediction (greedy)
output = output[-1].argmax(dim=-1)
# Find which sequences have finished
finished = finished | (output == new_line)
# Skip if all have finished
if finished.sum() == len(finished):
continue
# Override with the question
for j, p in enumerate(questions):
if len(p) > i + 1:
output[j] = dataset.stoi[p[i + 1]]
# Add the next token to the input
data = torch.cat([data, output[None, :]], dim=0)
# Get the sampled results
for j, c in enumerate(output):
results[j] += dataset.itos[c]
# Discard everything after the answer in the results
results = [r.split('\n')[0] for r in results]
# Log a sample
res_sample = results[0].split(';')
logger.log([(res_sample[0], Text.key), (';', Text.subtle), (';'.join(res_sample[1:]), Text.none)])
# Get the answers
results = [r.split('x==')[-1] for r in results]
# Count the number of correct answers
correct = 0
for r, _qa in zip(results, qa):
if r == _qa[1]:
correct += 1
# Log the score
tracker.save('score', correct / len(results))
@option(ArithmeticAutoregression.train_loader)
def arithmetic_train_loader(c: ArithmeticAutoregression):
"""
Training data loader
"""
return DataLoader(ArithmeticDataset(c.seq_len, c.max_digits, c.train_sequences_per_epoch),
batch_size=c.batch_size,
collate_fn=transpose_batch,
num_workers=4)
def _test():
"""
Code to test generated problems
"""
dataset = ArithmeticDataset(256, 8, 10)
print(dataset.decode(dataset.get_packed_math_input()))
#
if __name__ == '__main__':
_test()
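# A minimal usage sketch (an editor's illustration, separate from the file above), assuming
# the module ships as `labml_nn.experiments.arithmetic_dataset`. It shows what a generated
# problem and a packed (input, target) training pair look like.
from labml_nn.experiments.arithmetic_dataset import ArithmeticDataset

dataset = ArithmeticDataset(seq_len=64, max_digits=3, n_sequences=1)
# A single problem with its workings, e.g.
# "x=57+691; 7e0+1e0+0e0==8e0 5e1+9e1+0e1==14e1 0e2+6e2+1e2==7e2 x==748\n"
print(dataset.make_add_problem())
# An (input, target) pair for auto-regressive training; the target is the input shifted by one token
inp, target = dataset[0]
assert inp.shape == (64,) and target.shape == (64,)
print(dataset.decode(inp.tolist()))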


@@ -115,37 +115,67 @@ class RotaryPositionalEmbeddings(nn.Module):
    \end{pmatrix} \\
    \end{align}
    """

    def __init__(self, d: int, base: int = 10_000):
        """
        * `d` is the number of features $d$
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()
+        self.base = base
+        self.d = d
+        self.cos_cached = None
+        self.sin_cached = None
+
+    def _build_cache(self, x: torch.Tensor):
+        """
+        Cache $\cos$ and $\sin$ values
+        """
+        # Return if cache is already built
+        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
+            return
+        # Get sequence length
+        seq_len = x.shape[0]
        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
-        self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)
+        theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
-    def forward(self, x: torch.Tensor):
-        """
-        * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
-        """
-        # Extract the shape
-        seq_len, batch_size, n_heads, d = x.shape
-        # $\frac{d}{2}$
-        d_2 = d // 2
        # Create position indexes `[0, 1, ..., seq_len - 1]`
-        seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
+        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
        # Calculate the product of position index and $\theta_i$
-        idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)
+        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)
        # Concatenate so that for row $m$ we have
        # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
-        # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., -x^{(\frac{d}{2})}]$
-        neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
+        # Cache them
+        self.cos_cached = idx_theta2.cos()[:, None, None, :]
+        self.sin_cached = idx_theta2.sin()[:, None, None, :]
+
+    def _neg_half(self, x: torch.Tensor):
+        # $\frac{d}{2}$
+        d_2 = self.d // 2
+        # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
+        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
+
+    def forward(self, x: torch.Tensor):
+        """
+        * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
+        """
+        # Cache $\cos$ and $\sin$ values
+        self._build_cache(x)
+        # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
+        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
+        # Calculate
+        # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
+        neg_half_x = self._neg_half(x_rope)
        # Calculate
        #
@@ -157,10 +187,10 @@ class RotaryPositionalEmbeddings(nn.Module):
        # \end{align}
        #
        # for $i \in {1, 2, ..., \frac{d}{2}}$
-        rx = (x * idx_theta2.cos()[:, None, None, :]) + (neg_half_x * idx_theta2.sin()[:, None, None, :])
+        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
        #
-        return rx
+        return torch.cat((x_rope, x_pass), dim=-1)

class RotaryPEMultiHeadAttention(MultiHeadAttention):
@@ -170,15 +200,13 @@ class RotaryPEMultiHeadAttention(MultiHeadAttention):
    We override [multi-head attention from original transformer](../mha.html).
    """

-    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
-        # The linear transformations do not need a bias since we
-        # explicitly include it when calculating scores.
-        # However having a bias for `value` might make sense.
-        super().__init__(heads, d_model, dropout_prob, bias=False)
+    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
+        super().__init__(heads, d_model, dropout_prob)
        # Rotary positional embedding layers
-        self.query_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
-        self.key_rotary_pe = RotaryPositionalEmbeddings(self.d_k)
+        d_rope = int(self.d_k * rope_percentage)
+        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
+        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        """


@@ -20,7 +20,7 @@ from labml_nn.transformers.basic.autoregressive_experiment import Autoregressive
# ### Rotary PE attention
def _rotary_pe_mha(c: TransformerConfigs):
    from labml_nn.transformers.rope import RotaryPEMultiHeadAttention
-    return RotaryPEMultiHeadAttention(c.n_heads, c.d_model)
+    return RotaryPEMultiHeadAttention(c.n_heads, c.d_model, 1.)
# Configuration options
@@ -43,7 +43,7 @@ def _model(c: Configs):
def main():
    # Create experiment
-    experiment.create(name="rotary_pe_transformer")
+    experiment.create(name="rotary_pe_transformer", writers={'screen'})
    # Create configs
    conf = Configs()
    # Override configurations


@ -0,0 +1,246 @@
"""
---
title: Rotary Positional Embeddings with Relative distance (RoPER)
summary: >
This is an implementation of RoPER which adds relative distance information to embeddings on
top of RoPE introduced in RoFormer: Enhanced Transformer with Rotary Position Embedding
---
*RoPER is work by [Georges Harik (@gharik)](https://twitter.com/gharik),
and this implementation is based on his original code.*
# Rotary Positional Embeddings with Relative distance (RoPER)
[Rotary Positional Embeddings (RoPE)](https://papers.labml.ai/paper/2104.09864) includes
relative positions in attention score calculation.
However, the embeddings themselves do not get any positional information,
[except what they can get implicitly from causal attention](https://papers.labml.ai/paper/2c364684b15b11ecac827bce58715ee7).
RoPER adds relative positional information explicitly to value embeddings.
Specifically, it adds the relative positions of the tokens it paid attention to.
We use the same rotary positional embeddings to rotate the values in attention.
Then, after taking the weighted sum,
we rotate the result in the opposite direction.
This is equivalent to rotating each of the values (before attention) relative to the current position.
Here's [the training code](experiment.html) for training a transformer model with RoPER
on an arithmetic addition task, where we can see a significant improvement over RoPE.
### Relative distances in embeddings
For any head, let $a_{n,i}$ be the attention from position $n$ to position $i$,
and $v_i$ be the value embeddings at position $i$. Let's denote individual features
as $v^{(1)}_i, v^{(2)}_i, \dots$.
Normally, we would take the weighted sum of value embeddings
$$o^{(j)}_n = \sum_i a_{n,i} v^{(j)}_i$$
This doesn't explicitly add any distance information about the positions $i$ to the final
result $o^{(j)}_n$.
RoPER pairs features like RoPE and transforms them.
For a pair $v^{(1)}_m$ and $v^{(2)}_m$ it transforms them by
$RoPE\big(v^{(1)}_m, v^{(2)}_m, m\big)$.
Let us denote the transformed features by $\hat{v}^{(1)}_m, \hat{v}^{(2)}_m$.
Then it rotates the weighted sum $\hat{o}^{(j)}_n$ in the reverse direction with
$RoPE\big(\hat{o}^{(1)}_n, \hat{o}^{(2)}_n, -n\big)$.
*Note the $-n$.*
Note that,
\begin{align}
RoPE\big(x^{(1)}_m, x^{(2)}_m, m\big) &=
\begin{pmatrix}
\cos m \theta & - \sin m \theta \\
\sin m \theta & \cos m \theta
\end{pmatrix}
\begin{pmatrix}
x^{(1)}_m \\
x^{(2)}_m \\
\end{pmatrix} \\
&=
\begin{pmatrix}
x^{(1)}_m \cos m\theta - x^{(2)}_m \sin m \theta \\
x^{(2)}_m \cos m\theta + x^{(1)}_m \sin m \theta \\
\end{pmatrix} \\
\end{align}
The final output after the transformations is,
\begin{align}
RoPE\big(\hat{o}^{(1)}_n, \hat{o}^{(2)}_n, -n\big) &= \\
\begin{pmatrix}
\hat{o}^{(1)}_n \cos n\theta + \hat{o}^{(2)}_n \sin n \theta \\
\hat{o}^{(2)}_n \cos n\theta - \hat{o}^{(1)}_n \sin n \theta \\
\end{pmatrix} \\
\end{align}
*Note that $\sin (-n \theta) = -\sin n \theta$.*
Let's expand the first term $\hat{o}^{(1)}_n \cos n\theta + \hat{o}^{(2)}_n \sin n \theta$,
\begin{align}
\hat{o}^{(1)}_n \cos n\theta + \hat{o}^{(2)}_n \sin n \theta &= \\
\sum_i a_{n,i} \hat{v}^{(1)}_i \cos n\theta + \sum_i a_{n,i} \hat{v}^{(2)}_i \sin n \theta &= \\
\sum_i a_{n,i} \Big( v^{(1)}_i \cos i\theta - v^{(2)}_i \sin i \theta \Big) \cos n\theta &+ \\
\sum_i a_{n,i} \Big( v^{(2)}_i \cos i\theta + v^{(1)}_i \sin i \theta \Big) \sin n \theta &= \\
\sum_i a_{n,i} v^{(1)}_i \Big( \cos i\theta \cos n\theta + \sin i \theta \sin n \theta \Big) &+ \\
\sum_i a_{n,i} v^{(2)}_i \Big( \cos i\theta \sin n\theta - \sin i \theta \cos n \theta \Big) &= \\
\sum_i a_{n,i} v^{(1)}_i \cos (i - n) \theta - \sum_i a_{n,i} v^{(2)}_i \sin (i - n) \theta
\end{align}
Similarly, we can show that the second term is equal to,
$$\sum_i a_{n,i} v^{(2)}_i \cos (i - n) \theta + \sum_i a_{n,i} v^{(1)}_i \sin (i - n) \theta$$
Which gives,
\begin{align}
RoPE\big(\hat{o}^{(1)}_n, \hat{o}^{(2)}_n, -n\big) &= \\
\begin{pmatrix}
\sum_i a_{n,i} v^{(1)}_i \cos (i - n) \theta - \sum_i a_{n,i} v^{(2)}_i \sin (i - n) \theta \\
\sum_i a_{n,i} v^{(2)}_i \cos (i - n) \theta + \sum_i a_{n,i} v^{(1)}_i \sin (i - n) \theta \\
\end{pmatrix} &= \\
\sum_i a_{n,i} RoPE \big (v^{(1)}_i, v^{(2)}_i, i - n \big)
\end{align}
That is, the attention-weighted average of the values, each rotated by its relative distance $(i - n)$ from the current position.
[Here's an experiment](arithmetic_experiment.html) that uses RoPER on an arithmetic addition task.
"""
from typing import Optional
import torch
from labml_nn.transformers.rope import RotaryPositionalEmbeddings, RotaryPEMultiHeadAttention
class ReverseRotaryPositionalEmbeddings(RotaryPositionalEmbeddings):
"""
## RoPE module that rotates in the opposite direction
This inherits from [RoPE rotation implementation](../index.html) and changes the direction.
"""
def forward(self, x: torch.Tensor):
"""
* `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
"""
# Cache $\cos$ and $\sin$ values
self._build_cache(x)
# Split the features, we can choose to apply rotary embeddings only to a partial set of features.
x_rope, x_pass = x[..., :self.d], x[..., self.d:]
# Calculate
# $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
neg_half_x = self._neg_half(x_rope)
# Calculate
#
# \begin{align}
# \begin{pmatrix}
# x^{(i)}_m \cos -m \theta_i - x^{(i + \frac{d}{2})}_m \sin -m \theta_i \\
# x^{(i + \frac{d}{2})}_m \cos -m\theta_i + x^{(i)}_m \sin -m \theta_i \\
# \end{pmatrix} = \\
# \begin{pmatrix}
# x^{(i)}_m \cos m \theta_i + x^{(i + \frac{d}{2})}_m \sin m \theta_i \\
# x^{(i + \frac{d}{2})}_m \cos m\theta_i - x^{(i)}_m \sin m \theta_i \\
# \end{pmatrix} \\
# \end{align}
#
# for $i \in {1, 2, ..., \frac{d}{2}}$
x_rope = (x_rope * self.cos_cached[:x.shape[0]]) - (neg_half_x * self.sin_cached[:x.shape[0]])
#
return torch.cat((x_rope, x_pass), dim=-1)
class RotaryValuePEMultiHeadAttention(RotaryPEMultiHeadAttention):
"""
## Multi-head attention with rotary positional embeddings
We override [multi-head attention from original transformer](../mha.html).
"""
def __init__(self, heads: int, d_model: int,
rope_percentage: float = 0.5, rope_value_percentage: float = 0.5,
dropout_prob: float = 0.0):
super().__init__(heads, d_model, rope_percentage, dropout_prob)
# Rotary positional embedding layers
d_rope_value = int(self.d_k * rope_value_percentage)
self.value_rotary_pe = RotaryPositionalEmbeddings(d_rope_value)
self.value_reverse_rotary_pe = ReverseRotaryPositionalEmbeddings(d_rope_value)
def forward(self, *,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None):
"""
`query`, `key` and `value` are the tensors that store
collection of *query*, *key* and *value* vectors.
They have shape `[seq_len, batch_size, d_model]`.
`mask` has shape `[seq_len, seq_len, batch_size]` and
`mask[i, j, b]` indicates whether for batch `b`,
query at position `i` has access to key-value at position `j`.
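For example, continuing the construction sketch in the class docstring, a causal (auto-regressive) mask shared across the batch might look like this (a sketch assuming, as in the base multi-head attention, that a batch dimension of `1` is broadcast):

```python
import torch
seq_len = 100
# [seq_len, seq_len, 1]; mask[i, j, 0] == 1 iff query i may attend to key j <= i
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1)
out = mha(query=x, key=x, value=x, mask=mask)
```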
"""
# `query`, `key` and `value` have shape `[seq_len, batch_size, d_model]`
seq_len, batch_size, _ = query.shape
if mask is not None:
mask = self.prepare_mask(mask, query.shape, key.shape)
# Prepare `query`, `key` and `value` for attention computation.
# These will then have shape `[seq_len, batch_size, heads, d_k]`.
query = self.query(query)
key = self.key(key)
value = self.value(value)
# Compute attention scores $Q K^\top$.
# This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.
scores = self.get_scores(query, key)
# Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
scores *= self.scale
# Apply mask
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# $softmax$ attention along the key sequence dimension
# $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
attn = self.softmax(scores)
# Apply dropout
attn = self.dropout(attn)
# Rotate value embeddings before taking the weighted sum so that they contain positional information
value = self.value_rotary_pe(value)
# Multiply by values
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# Rotate in the opposite direction so that each output embedding holds its relative positional information
x = self.value_reverse_rotary_pe(x)
# Save attentions for any other calculations
self.attn = attn.detach()
# Concatenate multiple heads
x = x.reshape(seq_len, batch_size, -1)
# Output layer
return self.output(x)

View File

@ -0,0 +1,93 @@
"""
---
title: Rotary Positional Embeddings with Relative distance (RoPER) Experiment
summary: This experiment trains a transformer model with Rotary Positional Embeddings with
Relative Distance (RoPER) on the arithmetic addition task.
---
# Rotary Positional Embeddings with Relative distance ([RoPER](index.html)) Experiment
"""
from labml import experiment
from labml.configs import calculate
from labml_nn.experiments.arithmetic_dataset import ArithmeticAutoregression
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs
class Configs(RoPEConfigs, ArithmeticAutoregression):
"""
We inherit the [RoPE experiment](../experiment.html) and use it for the
[arithmetic addition task](../../experiments/arithmetic_dataset.html).
Below, we add the option to switch the attention implementation to
Rotary Positional Embeddings with Relative distance (RoPER).
"""
pass
def _rotary_value_pe_mha(c: TransformerConfigs):
"""
Use Rotary Positional Embeddings with Relative distance ([RoPER](index.html)) in attention.
"""
from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention
return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 1.)
# Configuration options
calculate(TransformerConfigs.encoder_attn, 'rotary_value', _rotary_value_pe_mha)
calculate(TransformerConfigs.decoder_attn, 'rotary_value', _rotary_value_pe_mha)
calculate(TransformerConfigs.decoder_mem_attn, 'rotary_value', _rotary_value_pe_mha)
def main():
# Create experiment
experiment.create(name="roper_addition", comment="rotary value 7", writers={'screen', 'labml', 'comet'})
# Create configs
conf = Configs()
# Override configurations
experiment.configs(conf, {
'max_digits': 7,
# No fixed positional embeddings
'transformer.src_embed': 'no_pos',
'transformer.tgt_embed': 'no_pos',
# Encoder with RoPER attention
'transformer.encoder_attn': 'rotary_value',
# Encoder with RoPE attention
# 'transformer.encoder_attn': 'rotary',
#
'model': 'rotary_pe_transformer',
# Use a context size of $512$
'seq_len': 512,
# Train for $20$ epochs
'epochs': 20,
# Batch size $16$
'batch_size': 16,
# Model size
'd_model': 128,
'transformer.ffn.d_ff': 512,
'transformer.n_heads': 4,
'transformer.dropout': 0.0,
# Use [Adam optimizer](../../optimizers/adam.html)
'optimizer.optimizer': 'Adam',
'optimizer.learning_rate': 2.5e-4,
})
# Set models for saving and loading
experiment.add_pytorch_models({'model': conf.model})
# Start the experiment
with experiment.start():
# Run training
conf.run()
#
if __name__ == '__main__':
main()

View File

@ -0,0 +1,98 @@
"""
---
title: Rotary Positional Embeddings (RoPE) Experiment
summary: This experiment trains a transformer model with Rotary Positional Embeddings (RoPE) on tiny Shakespeare dataset.
---
# Rotary Positional Embeddings (RoPE) Experiment
This is an annotated PyTorch experiment to train a transformer model with Rotary Positional Embeddings (RoPE).
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/1cf508e693be11ecacc98de8b38a61fe)
"""
from labml import experiment
from labml.configs import calculate
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs
# ### Rotary PE attention
class Configs(RoPEConfigs):
pass
def _rotary_value_pe_mha(c: TransformerConfigs):
from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention
return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 1.)
# Configuration options
calculate(TransformerConfigs.encoder_attn, 'rotary_value', _rotary_value_pe_mha)
calculate(TransformerConfigs.decoder_attn, 'rotary_value', _rotary_value_pe_mha)
calculate(TransformerConfigs.decoder_mem_attn, 'rotary_value', _rotary_value_pe_mha)
def main():
# Create experiment
experiment.create(name="rotary_shakespeare", comment="rotary value", writers={'screen', 'labml'})
# Create configs
conf = Configs()
# Override configurations
experiment.configs(conf, {
# No fixed positional embeddings
'transformer.src_embed': 'no_pos',
'transformer.tgt_embed': 'no_pos',
# Encoder with RoPER attention
'transformer.encoder_attn': 'rotary_value',
# 'transformer.encoder_attn': 'rotary',
#
'model': 'rotary_pe_transformer',
# Use character level tokenizer
'tokenizer': 'character',
# Prompt separator is blank
'prompt_separator': '',
# Starting prompt for sampling
'prompt': 'It is ',
# Use Tiny Shakespeare dataset
'text': 'tiny_shakespeare',
# Use a context size of $512$
'seq_len': 512,
# Train for $24$ epochs
'epochs': 24,
# Batch size $16$
'batch_size': 16,
# Switch between training and validation $4$ times
# per epoch
'inner_iterations': 4,
# Model size
'd_model': 128,
'transformer.ffn.d_ff': 512,
'transformer.n_heads': 4,
'transformer.dropout': 0.0,
# Use [Adam optimizer](../../optimizers/adam.html)
'optimizer.optimizer': 'Adam',
'optimizer.learning_rate': 2.5e-4,
'dataloader_shuffle_with_replacement': True
})
# Set models for saving and loading
experiment.add_pytorch_models({'model': conf.model})
# Start the experiment
with experiment.start():
# Run training
conf.run()
#
if __name__ == '__main__':
main()

View File

@ -5,7 +5,7 @@ with open("readme.md", "r") as f:
 setuptools.setup(
     name='labml-nn',
-    version='0.4.122',
+    version='0.4.123',
     author="Varuna Jayasiri, Nipun Wijerathne",
     author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
     description="🧑‍🏫 Implementations/tutorials of deep learning papers with side-by-side notes 📝; including transformers (original, xl, switch, feedback, vit), optimizers (adam, radam, adabelief), gans(dcgan, cyclegan, stylegan2), 🎮 reinforcement learning (ppo, dqn), capsnet, distillation, etc. 🧠",
@ -20,8 +20,8 @@ setuptools.setup(
     'labml_helpers', 'labml_helpers.*',
     'test',
     'test.*')),
-    install_requires=['labml>=0.4.147',
-                      'labml-helpers>=0.4.84',
+    install_requires=['labml>=0.4.151',
+                      'labml-helpers>=0.4.86',
                       'torch',
                       'torchtext',
                       'torchvision',