This commit is contained in:
Varuna Jayasiri
2025-07-31 08:48:07 +05:30
parent 00f8714843
commit c4d2e8cd22
101 changed files with 11682 additions and 7441 deletions

View File

@ -84,7 +84,7 @@
<span class="lineno">20</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">21</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">lab</span><span class="p">,</span> <span class="n">monit</span>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_helpers.datasets.text</span> <span class="kn">import</span> <span class="n">TextFileDataset</span>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_nn.helpers.datasets</span> <span class="kn">import</span> <span class="n">TextFileDataset</span>
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.retro.bert_embeddings</span> <span class="kn">import</span> <span class="n">BERTChunkEmbeddings</span></pre></div>
</div>
</div>

View File

@ -84,7 +84,7 @@
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">Dataset</span> <span class="k">as</span> <span class="n">PyTorchDataset</span>
<span class="lineno">21</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">lab</span><span class="p">,</span> <span class="n">monit</span>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_helpers.datasets.text</span> <span class="kn">import</span> <span class="n">TextFileDataset</span><span class="p">,</span> <span class="n">TextDataset</span>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_nn.helpers.datasets</span> <span class="kn">import</span> <span class="n">TextFileDataset</span><span class="p">,</span> <span class="n">TextDataset</span>
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.retro.database</span> <span class="kn">import</span> <span class="n">RetroIndex</span></pre></div>
</div>
</div>

View File

@ -77,16 +77,15 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">14</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">RandomSampler</span>
<span class="lineno">17</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">monit</span><span class="p">,</span> <span class="n">lab</span><span class="p">,</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">experiment</span><span class="p">,</span> <span class="n">logger</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">Text</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_helpers.datasets.text</span> <span class="kn">import</span> <span class="n">TextFileDataset</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml_nn.optimizers.noam</span> <span class="kn">import</span> <span class="n">Noam</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.retro</span> <span class="kn">import</span> <span class="n">model</span> <span class="k">as</span> <span class="n">retro</span>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.retro.dataset</span> <span class="kn">import</span> <span class="n">Dataset</span><span class="p">,</span> <span class="n">RetroIndex</span>
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.retro.model</span> <span class="kn">import</span> <span class="n">RetroModel</span><span class="p">,</span> <span class="n">NearestNeighborEncoder</span></pre></div>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">monit</span><span class="p">,</span> <span class="n">lab</span><span class="p">,</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">experiment</span><span class="p">,</span> <span class="n">logger</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">Text</span>
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml_nn.helpers.datasets</span> <span class="kn">import</span> <span class="n">TextFileDataset</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml_nn.optimizers.noam</span> <span class="kn">import</span> <span class="n">Noam</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.retro</span> <span class="kn">import</span> <span class="n">model</span> <span class="k">as</span> <span class="n">retro</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.retro.dataset</span> <span class="kn">import</span> <span class="n">Dataset</span><span class="p">,</span> <span class="n">RetroIndex</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.retro.model</span> <span class="kn">import</span> <span class="n">RetroModel</span><span class="p">,</span> <span class="n">NearestNeighborEncoder</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">RandomSampler</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
@ -99,7 +98,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">27</span><span class="k">class</span> <span class="nc">Sampler</span><span class="p">:</span></pre></div>
<div class="highlight"><pre><span class="lineno">26</span><span class="k">class</span> <span class="nc">Sampler</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
@ -118,7 +117,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">34</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">retro</span><span class="o">.</span><span class="n">RetroModel</span><span class="p">,</span> <span class="n">tds</span><span class="p">:</span> <span class="n">TextFileDataset</span><span class="p">,</span> <span class="n">chunk_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">33</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">retro</span><span class="o">.</span><span class="n">RetroModel</span><span class="p">,</span> <span class="n">tds</span><span class="p">:</span> <span class="n">TextFileDataset</span><span class="p">,</span> <span class="n">chunk_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
@ -129,10 +128,10 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">41</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span> <span class="o">=</span> <span class="n">chunk_len</span>
<span class="lineno">42</span> <span class="bp">self</span><span class="o">.</span><span class="n">tds</span> <span class="o">=</span> <span class="n">tds</span>
<span class="lineno">43</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">model</span>
<span class="lineno">44</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span></pre></div>
<div class="highlight"><pre><span class="lineno">40</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span> <span class="o">=</span> <span class="n">chunk_len</span>
<span class="lineno">41</span> <span class="bp">self</span><span class="o">.</span><span class="n">tds</span> <span class="o">=</span> <span class="n">tds</span>
<span class="lineno">42</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">model</span>
<span class="lineno">43</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
@ -144,7 +143,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span> <span class="bp">self</span><span class="o">.</span><span class="n">index</span> <span class="o">=</span> <span class="n">RetroIndex</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">46</span> <span class="bp">self</span><span class="o">.</span><span class="n">index</span> <span class="o">=</span> <span class="n">RetroIndex</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
@ -156,7 +155,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="k">def</span> <span class="nf">retrieve_nearest_neighbours</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">chunk</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">48</span> <span class="k">def</span> <span class="nf">retrieve_nearest_neighbours</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">chunk</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
@ -168,7 +167,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">55</span> <span class="n">neighbor_offsets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">index</span><span class="p">([</span><span class="n">chunk</span><span class="p">],</span> <span class="kc">None</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">54</span> <span class="n">neighbor_offsets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">index</span><span class="p">([</span><span class="n">chunk</span><span class="p">],</span> <span class="kc">None</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
@ -181,8 +180,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">58</span> <span class="n">text</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tds</span><span class="o">.</span><span class="n">train</span>
<span class="lineno">59</span> <span class="n">neighbors</span> <span class="o">=</span> <span class="p">[</span><span class="n">text</span><span class="p">[</span><span class="n">j</span><span class="p">:</span> <span class="n">j</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span> <span class="o">*</span> <span class="mi">2</span><span class="p">]</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="n">neighbor_offsets</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span></pre></div>
<div class="highlight"><pre><span class="lineno">57</span> <span class="n">text</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tds</span><span class="o">.</span><span class="n">train</span>
<span class="lineno">58</span> <span class="n">neighbors</span> <span class="o">=</span> <span class="p">[</span><span class="n">text</span><span class="p">[</span><span class="n">j</span><span class="p">:</span> <span class="n">j</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span> <span class="o">*</span> <span class="mi">2</span><span class="p">]</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="n">neighbor_offsets</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
@ -194,7 +193,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="k">return</span> <span class="n">neighbors</span></pre></div>
<div class="highlight"><pre><span class="lineno">61</span> <span class="k">return</span> <span class="n">neighbors</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
@ -206,7 +205,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">64</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prompt</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">sample_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">63</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prompt</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">sample_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
@ -218,7 +217,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">neighbors_str</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
<div class="highlight"><pre><span class="lineno">69</span> <span class="n">neighbors_str</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
@ -230,7 +229,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">73</span> <span class="n">sampled</span> <span class="o">=</span> <span class="s1">&#39;&#39;</span></pre></div>
<div class="highlight"><pre><span class="lineno">72</span> <span class="n">sampled</span> <span class="o">=</span> <span class="s1">&#39;&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
@ -243,7 +242,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">sample_len</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">75</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">sample_len</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
@ -255,7 +254,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span> <span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">neighbors_str</span><span class="p">)</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="n">prompt</span><span class="p">)</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span><span class="p">:</span></pre></div>
<div class="highlight"><pre><span class="lineno">78</span> <span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">neighbors_str</span><span class="p">)</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="n">prompt</span><span class="p">)</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
@ -267,8 +266,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">81</span> <span class="n">off</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">neighbors_str</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span>
<span class="lineno">82</span> <span class="n">chunk</span> <span class="o">=</span> <span class="n">prompt</span><span class="p">[</span><span class="n">off</span><span class="p">:</span> <span class="n">off</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">80</span> <span class="n">off</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">neighbors_str</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span>
<span class="lineno">81</span> <span class="n">chunk</span> <span class="o">=</span> <span class="n">prompt</span><span class="p">[</span><span class="n">off</span><span class="p">:</span> <span class="n">off</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
@ -280,7 +279,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">84</span> <span class="n">neighbors_str</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">retrieve_nearest_neighbours</span><span class="p">(</span><span class="n">chunk</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">83</span> <span class="n">neighbors_str</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">retrieve_nearest_neighbours</span><span class="p">(</span><span class="n">chunk</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
@ -292,7 +291,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">87</span> <span class="n">src</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tds</span><span class="o">.</span><span class="n">text_to_i</span><span class="p">(</span><span class="n">prompt</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">86</span> <span class="n">src</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tds</span><span class="o">.</span><span class="n">text_to_i</span><span class="p">(</span><span class="n">prompt</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
@ -304,7 +303,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">89</span> <span class="n">neighbors</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">tds</span><span class="o">.</span><span class="n">text_to_i</span><span class="p">(</span><span class="n">n</span><span class="p">)</span> <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="n">chunk</span><span class="p">])</span> <span class="k">for</span> <span class="n">chunk</span> <span class="ow">in</span> <span class="n">neighbors_str</span><span class="p">])</span></pre></div>
<div class="highlight"><pre><span class="lineno">88</span> <span class="n">neighbors</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">tds</span><span class="o">.</span><span class="n">text_to_i</span><span class="p">(</span><span class="n">n</span><span class="p">)</span> <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="n">chunk</span><span class="p">])</span> <span class="k">for</span> <span class="n">chunk</span> <span class="ow">in</span> <span class="n">neighbors_str</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
@ -316,8 +315,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">92</span> <span class="n">src</span> <span class="o">=</span> <span class="n">src</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">93</span> <span class="n">neighbors</span> <span class="o">=</span> <span class="n">neighbors</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">91</span> <span class="n">src</span> <span class="o">=</span> <span class="n">src</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">92</span> <span class="n">neighbors</span> <span class="o">=</span> <span class="n">neighbors</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
@ -329,7 +328,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">96</span> <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">src</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:],</span> <span class="n">neighbors</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="p">:])</span></pre></div>
<div class="highlight"><pre><span class="lineno">95</span> <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">src</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:],</span> <span class="n">neighbors</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="p">:])</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
@ -341,7 +340,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">99</span> <span class="n">token</span> <span class="o">=</span> <span class="n">res</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">:]</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">98</span> <span class="n">token</span> <span class="o">=</span> <span class="n">res</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">:]</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
@ -353,8 +352,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">102</span> <span class="n">prompt</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tds</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">token</span><span class="o">.</span><span class="n">item</span><span class="p">()]</span>
<span class="lineno">103</span> <span class="n">sampled</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tds</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">token</span><span class="o">.</span><span class="n">item</span><span class="p">()]</span></pre></div>
<div class="highlight"><pre><span class="lineno">101</span> <span class="n">prompt</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tds</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">token</span><span class="o">.</span><span class="n">item</span><span class="p">()]</span>
<span class="lineno">102</span> <span class="n">sampled</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tds</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">token</span><span class="o">.</span><span class="n">item</span><span class="p">()]</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
@ -366,7 +365,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">106</span> <span class="k">return</span> <span class="n">sampled</span></pre></div>
<div class="highlight"><pre><span class="lineno">105</span> <span class="k">return</span> <span class="n">sampled</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
@ -378,7 +377,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">109</span><span class="k">class</span> <span class="nc">Trainer</span><span class="p">:</span></pre></div>
<div class="highlight"><pre><span class="lineno">108</span><span class="k">class</span> <span class="nc">Trainer</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
@ -397,8 +396,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">114</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">retro</span><span class="o">.</span><span class="n">RetroModel</span><span class="p">,</span>
<span class="lineno">115</span> <span class="n">dataloader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Optimizer</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">113</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">retro</span><span class="o">.</span><span class="n">RetroModel</span><span class="p">,</span>
<span class="lineno">114</span> <span class="n">dataloader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Optimizer</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
@ -409,11 +408,11 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer</span>
<span class="lineno">123</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span>
<span class="lineno">124</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataloader</span> <span class="o">=</span> <span class="n">dataloader</span>
<span class="lineno">125</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">model</span>
<span class="lineno">126</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">121</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer</span>
<span class="lineno">122</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span>
<span class="lineno">123</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataloader</span> <span class="o">=</span> <span class="n">dataloader</span>
<span class="lineno">124</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">model</span>
<span class="lineno">125</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
@ -425,7 +424,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">128</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">127</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
@ -437,7 +436,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">134</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">tgt</span><span class="p">,</span> <span class="n">neighbors</span><span class="p">)</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">enum</span><span class="p">(</span><span class="s1">&#39;Train&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataloader</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">133</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">tgt</span><span class="p">,</span> <span class="n">neighbors</span><span class="p">)</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">enum</span><span class="p">(</span><span class="s1">&#39;Train&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataloader</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
@ -449,7 +448,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="n">src</span><span class="p">,</span> <span class="n">tgt</span><span class="p">,</span> <span class="n">neighbors</span> <span class="o">=</span> <span class="n">src</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">tgt</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">neighbors</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">135</span> <span class="n">src</span><span class="p">,</span> <span class="n">tgt</span><span class="p">,</span> <span class="n">neighbors</span> <span class="o">=</span> <span class="n">src</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">tgt</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">neighbors</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
@ -461,7 +460,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">139</span> <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">neighbors</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">138</span> <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">neighbors</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
@ -473,7 +472,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">141</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">res</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">res</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]),</span> <span class="n">tgt</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">res</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">res</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]),</span> <span class="n">tgt</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
@ -485,7 +484,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">144</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">143</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
@ -497,7 +496,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">146</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">145</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
@ -509,7 +508,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">148</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">147</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
@ -521,8 +520,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">151</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">({</span><span class="s1">&#39;loss.train&#39;</span><span class="p">:</span> <span class="n">loss</span><span class="p">})</span>
<span class="lineno">152</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">src</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">150</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">({</span><span class="s1">&#39;loss.train&#39;</span><span class="p">:</span> <span class="n">loss</span><span class="p">})</span>
<span class="lineno">151</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">src</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
@ -534,7 +533,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">155</span><span class="k">def</span> <span class="nf">train</span><span class="p">():</span></pre></div>
<div class="highlight"><pre><span class="lineno">154</span><span class="k">def</span> <span class="nf">train</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
@ -546,7 +545,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">161</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;retro_small&#39;</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">160</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;retro_small&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
@ -558,7 +557,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">164</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda:0&#39;</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">163</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda:0&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
@ -570,10 +569,10 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">167</span> <span class="n">tds</span> <span class="o">=</span> <span class="n">TextFileDataset</span><span class="p">(</span>
<span class="lineno">168</span> <span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">&#39;tiny_shakespeare.txt&#39;</span><span class="p">,</span>
<span class="lineno">169</span> <span class="nb">list</span><span class="p">,</span>
<span class="lineno">170</span> <span class="n">url</span><span class="o">=</span><span class="s1">&#39;https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt&#39;</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">166</span> <span class="n">tds</span> <span class="o">=</span> <span class="n">TextFileDataset</span><span class="p">(</span>
<span class="lineno">167</span> <span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">&#39;tiny_shakespeare.txt&#39;</span><span class="p">,</span>
<span class="lineno">168</span> <span class="nb">list</span><span class="p">,</span>
<span class="lineno">169</span> <span class="n">url</span><span class="o">=</span><span class="s1">&#39;https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
@ -585,7 +584,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">173</span> <span class="n">train_dataset</span> <span class="o">=</span> <span class="n">Dataset</span><span class="p">(</span><span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">&#39;retro_train_dataset.json&#39;</span><span class="p">,</span> <span class="n">tds</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">172</span> <span class="n">train_dataset</span> <span class="o">=</span> <span class="n">Dataset</span><span class="p">(</span><span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">&#39;retro_train_dataset.json&#39;</span><span class="p">,</span> <span class="n">tds</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
@ -597,9 +596,9 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span> <span class="n">train_dl</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">,</span>
<span class="lineno">177</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span>
<span class="lineno">178</span> <span class="n">sampler</span><span class="o">=</span><span class="n">RandomSampler</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">,</span> <span class="n">replacement</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">175</span> <span class="n">train_dl</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">,</span>
<span class="lineno">176</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span>
<span class="lineno">177</span> <span class="n">sampler</span><span class="o">=</span><span class="n">RandomSampler</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">,</span> <span class="n">replacement</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
@ -611,11 +610,11 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">181</span> <span class="n">chunk_len</span> <span class="o">=</span> <span class="mi">16</span>
<span class="lineno">182</span> <span class="n">d_model</span> <span class="o">=</span> <span class="mi">128</span>
<span class="lineno">183</span> <span class="n">d_ff</span> <span class="o">=</span> <span class="mi">512</span>
<span class="lineno">184</span> <span class="n">n_heads</span> <span class="o">=</span> <span class="mi">16</span>
<span class="lineno">185</span> <span class="n">d_k</span> <span class="o">=</span> <span class="mi">16</span></pre></div>
<div class="highlight"><pre><span class="lineno">180</span> <span class="n">chunk_len</span> <span class="o">=</span> <span class="mi">16</span>
<span class="lineno">181</span> <span class="n">d_model</span> <span class="o">=</span> <span class="mi">128</span>
<span class="lineno">182</span> <span class="n">d_ff</span> <span class="o">=</span> <span class="mi">512</span>
<span class="lineno">183</span> <span class="n">n_heads</span> <span class="o">=</span> <span class="mi">16</span>
<span class="lineno">184</span> <span class="n">d_k</span> <span class="o">=</span> <span class="mi">16</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
@ -627,7 +626,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">188</span> <span class="n">nearest_neighbor_encoder</span> <span class="o">=</span> <span class="n">NearestNeighborEncoder</span><span class="p">(</span><span class="n">chunk_len</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="p">{</span><span class="mi">3</span><span class="p">},</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">187</span> <span class="n">nearest_neighbor_encoder</span> <span class="o">=</span> <span class="n">NearestNeighborEncoder</span><span class="p">(</span><span class="n">chunk_len</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="p">{</span><span class="mi">3</span><span class="p">},</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
@ -639,10 +638,10 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">190</span> <span class="n">model</span> <span class="o">=</span> <span class="n">RetroModel</span><span class="p">(</span><span class="n">tds</span><span class="o">.</span><span class="n">n_tokens</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span>
<span class="lineno">191</span> <span class="p">{</span><span class="mi">3</span><span class="p">,</span> <span class="mi">5</span><span class="p">},</span>
<span class="lineno">192</span> <span class="n">chunk_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">,</span>
<span class="lineno">193</span> <span class="n">encoder</span><span class="o">=</span><span class="n">nearest_neighbor_encoder</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">189</span> <span class="n">model</span> <span class="o">=</span> <span class="n">RetroModel</span><span class="p">(</span><span class="n">tds</span><span class="o">.</span><span class="n">n_tokens</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span>
<span class="lineno">190</span> <span class="p">{</span><span class="mi">3</span><span class="p">,</span> <span class="mi">5</span><span class="p">},</span>
<span class="lineno">191</span> <span class="n">chunk_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">,</span>
<span class="lineno">192</span> <span class="n">encoder</span><span class="o">=</span><span class="n">nearest_neighbor_encoder</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
@ -654,7 +653,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">195</span> <span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">194</span> <span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
@ -666,7 +665,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">Noam</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1.</span><span class="p">,</span> <span class="n">d_model</span><span class="o">=</span><span class="n">d_model</span><span class="p">,</span> <span class="n">warmup</span><span class="o">=</span><span class="mi">2_000</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">196</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">Noam</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1.</span><span class="p">,</span> <span class="n">d_model</span><span class="o">=</span><span class="n">d_model</span><span class="p">,</span> <span class="n">warmup</span><span class="o">=</span><span class="mi">2_000</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
@ -679,7 +678,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">199</span> <span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">train_dl</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">198</span> <span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">train_dl</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
@ -692,7 +691,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">201</span> <span class="n">sampler</span> <span class="o">=</span> <span class="n">Sampler</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">tds</span><span class="p">,</span> <span class="n">chunk_len</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">200</span> <span class="n">sampler</span> <span class="o">=</span> <span class="n">Sampler</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">tds</span><span class="p">,</span> <span class="n">chunk_len</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
@ -704,7 +703,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">203</span> <span class="n">prompt</span> <span class="o">=</span> <span class="s1">&#39;&#39;&#39;Second Citizen:</span><span class="se">\n</span><span class="s1">One word, good citizens.</span><span class="se">\n\n</span><span class="s1">First Citizen:&#39;&#39;&#39;</span></pre></div>
<div class="highlight"><pre><span class="lineno">202</span> <span class="n">prompt</span> <span class="o">=</span> <span class="s1">&#39;&#39;&#39;Second Citizen:</span><span class="se">\n</span><span class="s1">One word, good citizens.</span><span class="se">\n\n</span><span class="s1">First Citizen:&#39;&#39;&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
@ -716,7 +715,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">206</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">(</span><span class="n">model</span><span class="o">=</span><span class="n">model</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">205</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">(</span><span class="n">model</span><span class="o">=</span><span class="n">model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
@ -728,7 +727,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">209</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
<div class="highlight"><pre><span class="lineno">208</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
@ -741,7 +740,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">211</span> <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">loop</span><span class="p">(</span><span class="mi">32</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">210</span> <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">loop</span><span class="p">(</span><span class="mi">32</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
@ -753,7 +752,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">213</span> <span class="n">trainer</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">212</span> <span class="n">trainer</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
@ -765,7 +764,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">215</span> <span class="n">tracker</span><span class="o">.</span><span class="n">new_line</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">214</span> <span class="n">tracker</span><span class="o">.</span><span class="n">new_line</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
@ -778,8 +777,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">217</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">([(</span><span class="n">prompt</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">,</span> <span class="s1">&#39;</span><span class="se">\\</span><span class="s1">n</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">),</span> <span class="n">Text</span><span class="o">.</span><span class="n">subtle</span><span class="p">),</span>
<span class="lineno">218</span> <span class="p">(</span><span class="n">sampler</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">prompt</span><span class="p">,</span> <span class="mi">128</span><span class="p">)</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">,</span> <span class="s1">&#39;</span><span class="se">\\</span><span class="s1">n</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">),</span> <span class="n">Text</span><span class="o">.</span><span class="n">none</span><span class="p">)])</span></pre></div>
<div class="highlight"><pre><span class="lineno">216</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">([(</span><span class="n">prompt</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">,</span> <span class="s1">&#39;</span><span class="se">\\</span><span class="s1">n</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">),</span> <span class="n">Text</span><span class="o">.</span><span class="n">subtle</span><span class="p">),</span>
<span class="lineno">217</span> <span class="p">(</span><span class="n">sampler</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">prompt</span><span class="p">,</span> <span class="mi">128</span><span class="p">)</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">,</span> <span class="s1">&#39;</span><span class="se">\\</span><span class="s1">n</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">),</span> <span class="n">Text</span><span class="o">.</span><span class="n">none</span><span class="p">)])</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
@ -791,7 +790,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">220</span> <span class="n">experiment</span><span class="o">.</span><span class="n">save_checkpoint</span><span class="p">()</span></pre></div>
<div class="highlight"><pre></pre></div>
</div>
</div>
<div class='section' id='section-56'>
@ -803,8 +802,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">224</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">225</span> <span class="n">train</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">222</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">223</span> <span class="n">train</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>