experiment link

This commit is contained in:
Varuna Jayasiri
2022-03-12 16:04:05 +05:30
parent a39d91dacd
commit 16bf5d0b10
6 changed files with 307 additions and 298 deletions

View File

@ -80,6 +80,7 @@
<li><a href="model.html">Model</a> </li>
<li><a href="dataset.html">Dataset</a>: Pre-calculate the nearest neighbors </li>
<li><a href="train.html">Training code</a></li></ul>
<p><a href="https://app.labml.ai/run/3113dd3ea1e711ec85ee295d18534021"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a></p>
</div>
<div class='code'>

File diff suppressed because it is too large Load Diff

View File

@ -71,20 +71,21 @@
</div>
<h1>RETRO training</h1>
<p>This is the training code for <a href="index.html">RETRO</a>.</p>
<p><a href="https://app.labml.ai/run/3113dd3ea1e711ec85ee295d18534021"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">14</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">RandomSampler</span>
<span class="lineno">17</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">monit</span><span class="p">,</span> <span class="n">lab</span><span class="p">,</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">experiment</span><span class="p">,</span> <span class="n">logger</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">Text</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_helpers.datasets.text</span> <span class="kn">import</span> <span class="n">TextFileDataset</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml_nn.optimizers.noam</span> <span class="kn">import</span> <span class="n">Noam</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.retro</span> <span class="kn">import</span> <span class="n">model</span> <span class="k">as</span> <span class="n">retro</span>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.retro.dataset</span> <span class="kn">import</span> <span class="n">Dataset</span><span class="p">,</span> <span class="n">RetroIndex</span>
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.retro.model</span> <span class="kn">import</span> <span class="n">RetroModel</span><span class="p">,</span> <span class="n">NearestNeighborEncoder</span></pre></div>
<div class="highlight"><pre><span class="lineno">16</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">RandomSampler</span>
<span class="lineno">19</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">monit</span><span class="p">,</span> <span class="n">lab</span><span class="p">,</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">experiment</span><span class="p">,</span> <span class="n">logger</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">Text</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_helpers.datasets.text</span> <span class="kn">import</span> <span class="n">TextFileDataset</span>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_nn.optimizers.noam</span> <span class="kn">import</span> <span class="n">Noam</span>
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.retro</span> <span class="kn">import</span> <span class="n">model</span> <span class="k">as</span> <span class="n">retro</span>
<span class="lineno">25</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.retro.dataset</span> <span class="kn">import</span> <span class="n">Dataset</span><span class="p">,</span> <span class="n">RetroIndex</span>
<span class="lineno">26</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.retro.model</span> <span class="kn">import</span> <span class="n">RetroModel</span><span class="p">,</span> <span class="n">NearestNeighborEncoder</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
@ -97,7 +98,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">27</span><span class="k">class</span> <span class="nc">Sampler</span><span class="p">:</span></pre></div>
<div class="highlight"><pre><span class="lineno">29</span><span class="k">class</span> <span class="nc">Sampler</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
@ -116,7 +117,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">34</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">retro</span><span class="o">.</span><span class="n">RetroModel</span><span class="p">,</span> <span class="n">tds</span><span class="p">:</span> <span class="n">TextFileDataset</span><span class="p">,</span> <span class="n">chunk_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">36</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">retro</span><span class="o">.</span><span class="n">RetroModel</span><span class="p">,</span> <span class="n">tds</span><span class="p">:</span> <span class="n">TextFileDataset</span><span class="p">,</span> <span class="n">chunk_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
@ -127,10 +128,10 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">41</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span> <span class="o">=</span> <span class="n">chunk_len</span>
<span class="lineno">42</span> <span class="bp">self</span><span class="o">.</span><span class="n">tds</span> <span class="o">=</span> <span class="n">tds</span>
<span class="lineno">43</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">model</span>
<span class="lineno">44</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span></pre></div>
<div class="highlight"><pre><span class="lineno">43</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span> <span class="o">=</span> <span class="n">chunk_len</span>
<span class="lineno">44</span> <span class="bp">self</span><span class="o">.</span><span class="n">tds</span> <span class="o">=</span> <span class="n">tds</span>
<span class="lineno">45</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">model</span>
<span class="lineno">46</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
@ -142,7 +143,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span> <span class="bp">self</span><span class="o">.</span><span class="n">index</span> <span class="o">=</span> <span class="n">RetroIndex</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">49</span> <span class="bp">self</span><span class="o">.</span><span class="n">index</span> <span class="o">=</span> <span class="n">RetroIndex</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
@ -154,7 +155,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="k">def</span> <span class="nf">retrieve_nearest_neighbours</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">chunk</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">51</span> <span class="k">def</span> <span class="nf">retrieve_nearest_neighbours</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">chunk</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
@ -166,7 +167,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">55</span> <span class="n">neighbor_offsets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">index</span><span class="p">([</span><span class="n">chunk</span><span class="p">],</span> <span class="kc">None</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">57</span> <span class="n">neighbor_offsets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">index</span><span class="p">([</span><span class="n">chunk</span><span class="p">],</span> <span class="kc">None</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
@ -179,8 +180,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">58</span> <span class="n">text</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tds</span><span class="o">.</span><span class="n">train</span>
<span class="lineno">59</span> <span class="n">neighbors</span> <span class="o">=</span> <span class="p">[</span><span class="n">text</span><span class="p">[</span><span class="n">j</span><span class="p">:</span> <span class="n">j</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span> <span class="o">*</span> <span class="mi">2</span><span class="p">]</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="n">neighbor_offsets</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span></pre></div>
<div class="highlight"><pre><span class="lineno">60</span> <span class="n">text</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tds</span><span class="o">.</span><span class="n">train</span>
<span class="lineno">61</span> <span class="n">neighbors</span> <span class="o">=</span> <span class="p">[</span><span class="n">text</span><span class="p">[</span><span class="n">j</span><span class="p">:</span> <span class="n">j</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span> <span class="o">*</span> <span class="mi">2</span><span class="p">]</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="n">neighbor_offsets</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
@ -192,7 +193,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="k">return</span> <span class="n">neighbors</span></pre></div>
<div class="highlight"><pre><span class="lineno">64</span> <span class="k">return</span> <span class="n">neighbors</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
@ -204,7 +205,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">64</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prompt</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">sample_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">66</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prompt</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">sample_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
@ -216,7 +217,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">neighbors_str</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
<div class="highlight"><pre><span class="lineno">72</span> <span class="n">neighbors_str</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
@ -228,7 +229,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">73</span> <span class="n">sampled</span> <span class="o">=</span> <span class="s1">&#39;&#39;</span></pre></div>
<div class="highlight"><pre><span class="lineno">75</span> <span class="n">sampled</span> <span class="o">=</span> <span class="s1">&#39;&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
@ -241,7 +242,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">sample_len</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">78</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">sample_len</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
@ -253,7 +254,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span> <span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">neighbors_str</span><span class="p">)</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="n">prompt</span><span class="p">)</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span><span class="p">:</span></pre></div>
<div class="highlight"><pre><span class="lineno">81</span> <span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">neighbors_str</span><span class="p">)</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="n">prompt</span><span class="p">)</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
@ -265,8 +266,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">81</span> <span class="n">off</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">neighbors_str</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span>
<span class="lineno">82</span> <span class="n">chunk</span> <span class="o">=</span> <span class="n">prompt</span><span class="p">[</span><span class="n">off</span><span class="p">:</span> <span class="n">off</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">83</span> <span class="n">off</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">neighbors_str</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span>
<span class="lineno">84</span> <span class="n">chunk</span> <span class="o">=</span> <span class="n">prompt</span><span class="p">[</span><span class="n">off</span><span class="p">:</span> <span class="n">off</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
@ -278,7 +279,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">84</span> <span class="n">neighbors_str</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">retrieve_nearest_neighbours</span><span class="p">(</span><span class="n">chunk</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">86</span> <span class="n">neighbors_str</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">retrieve_nearest_neighbours</span><span class="p">(</span><span class="n">chunk</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
@ -290,7 +291,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">87</span> <span class="n">src</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tds</span><span class="o">.</span><span class="n">text_to_i</span><span class="p">(</span><span class="n">prompt</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">89</span> <span class="n">src</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tds</span><span class="o">.</span><span class="n">text_to_i</span><span class="p">(</span><span class="n">prompt</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
@ -302,7 +303,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">89</span> <span class="n">neighbors</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">tds</span><span class="o">.</span><span class="n">text_to_i</span><span class="p">(</span><span class="n">n</span><span class="p">)</span> <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="n">chunk</span><span class="p">])</span> <span class="k">for</span> <span class="n">chunk</span> <span class="ow">in</span> <span class="n">neighbors_str</span><span class="p">])</span></pre></div>
<div class="highlight"><pre><span class="lineno">91</span> <span class="n">neighbors</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">tds</span><span class="o">.</span><span class="n">text_to_i</span><span class="p">(</span><span class="n">n</span><span class="p">)</span> <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="n">chunk</span><span class="p">])</span> <span class="k">for</span> <span class="n">chunk</span> <span class="ow">in</span> <span class="n">neighbors_str</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
@ -314,8 +315,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">92</span> <span class="n">src</span> <span class="o">=</span> <span class="n">src</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">93</span> <span class="n">neighbors</span> <span class="o">=</span> <span class="n">neighbors</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">94</span> <span class="n">src</span> <span class="o">=</span> <span class="n">src</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">95</span> <span class="n">neighbors</span> <span class="o">=</span> <span class="n">neighbors</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
@ -327,7 +328,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">96</span> <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">src</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:],</span> <span class="n">neighbors</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="p">:])</span></pre></div>
<div class="highlight"><pre><span class="lineno">98</span> <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">src</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:],</span> <span class="n">neighbors</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="p">:])</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
@ -339,7 +340,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">99</span> <span class="n">token</span> <span class="o">=</span> <span class="n">res</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">:]</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">101</span> <span class="n">token</span> <span class="o">=</span> <span class="n">res</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">:]</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
@ -351,8 +352,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">102</span> <span class="n">prompt</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tds</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">token</span><span class="o">.</span><span class="n">item</span><span class="p">()]</span>
<span class="lineno">103</span> <span class="n">sampled</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tds</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">token</span><span class="o">.</span><span class="n">item</span><span class="p">()]</span></pre></div>
<div class="highlight"><pre><span class="lineno">104</span> <span class="n">prompt</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tds</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">token</span><span class="o">.</span><span class="n">item</span><span class="p">()]</span>
<span class="lineno">105</span> <span class="n">sampled</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tds</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">token</span><span class="o">.</span><span class="n">item</span><span class="p">()]</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
@ -364,7 +365,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">106</span> <span class="k">return</span> <span class="n">sampled</span></pre></div>
<div class="highlight"><pre><span class="lineno">108</span> <span class="k">return</span> <span class="n">sampled</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
@ -376,7 +377,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">109</span><span class="k">class</span> <span class="nc">Trainer</span><span class="p">:</span></pre></div>
<div class="highlight"><pre><span class="lineno">111</span><span class="k">class</span> <span class="nc">Trainer</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
@ -395,8 +396,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">114</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">retro</span><span class="o">.</span><span class="n">RetroModel</span><span class="p">,</span>
<span class="lineno">115</span> <span class="n">dataloader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Optimizer</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">116</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">retro</span><span class="o">.</span><span class="n">RetroModel</span><span class="p">,</span>
<span class="lineno">117</span> <span class="n">dataloader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Optimizer</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
@ -407,11 +408,11 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer</span>
<span class="lineno">123</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span>
<span class="lineno">124</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataloader</span> <span class="o">=</span> <span class="n">dataloader</span>
<span class="lineno">125</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">model</span>
<span class="lineno">126</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">124</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer</span>
<span class="lineno">125</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span>
<span class="lineno">126</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataloader</span> <span class="o">=</span> <span class="n">dataloader</span>
<span class="lineno">127</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">model</span>
<span class="lineno">128</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
@ -423,7 +424,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">128</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">130</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
@ -435,7 +436,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">134</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">tgt</span><span class="p">,</span> <span class="n">neighbors</span><span class="p">)</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">enum</span><span class="p">(</span><span class="s1">&#39;Train&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataloader</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">136</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">tgt</span><span class="p">,</span> <span class="n">neighbors</span><span class="p">)</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">enum</span><span class="p">(</span><span class="s1">&#39;Train&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataloader</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
@ -447,7 +448,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="n">src</span><span class="p">,</span> <span class="n">tgt</span><span class="p">,</span> <span class="n">neighbors</span> <span class="o">=</span> <span class="n">src</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">tgt</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">neighbors</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">138</span> <span class="n">src</span><span class="p">,</span> <span class="n">tgt</span><span class="p">,</span> <span class="n">neighbors</span> <span class="o">=</span> <span class="n">src</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">tgt</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">neighbors</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
@ -459,7 +460,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">139</span> <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">neighbors</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">141</span> <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">neighbors</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
@ -471,7 +472,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">141</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">res</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">res</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]),</span> <span class="n">tgt</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">143</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">res</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">res</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]),</span> <span class="n">tgt</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
@ -483,7 +484,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">144</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">146</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
@ -495,7 +496,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">146</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">148</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
@ -507,7 +508,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">148</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">150</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
@ -519,8 +520,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">151</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">({</span><span class="s1">&#39;loss.train&#39;</span><span class="p">:</span> <span class="n">loss</span><span class="p">})</span>
<span class="lineno">152</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">src</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">153</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">({</span><span class="s1">&#39;loss.train&#39;</span><span class="p">:</span> <span class="n">loss</span><span class="p">})</span>
<span class="lineno">154</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">src</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
@ -532,7 +533,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">155</span><span class="k">def</span> <span class="nf">train</span><span class="p">():</span></pre></div>
<div class="highlight"><pre><span class="lineno">157</span><span class="k">def</span> <span class="nf">train</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
@ -544,7 +545,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">161</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;retro_small&#39;</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">163</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;retro_small&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
@ -556,7 +557,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">164</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda:0&#39;</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">166</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda:0&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
@ -568,10 +569,10 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">167</span> <span class="n">tds</span> <span class="o">=</span> <span class="n">TextFileDataset</span><span class="p">(</span>
<span class="lineno">168</span> <span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">&#39;tiny_shakespeare.txt&#39;</span><span class="p">,</span>
<span class="lineno">169</span> <span class="nb">list</span><span class="p">,</span>
<span class="lineno">170</span> <span class="n">url</span><span class="o">=</span><span class="s1">&#39;https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt&#39;</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">169</span> <span class="n">tds</span> <span class="o">=</span> <span class="n">TextFileDataset</span><span class="p">(</span>
<span class="lineno">170</span> <span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">&#39;tiny_shakespeare.txt&#39;</span><span class="p">,</span>
<span class="lineno">171</span> <span class="nb">list</span><span class="p">,</span>
<span class="lineno">172</span> <span class="n">url</span><span class="o">=</span><span class="s1">&#39;https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
@ -583,7 +584,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">173</span> <span class="n">train_dataset</span> <span class="o">=</span> <span class="n">Dataset</span><span class="p">(</span><span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">&#39;retro_train_dataset.json&#39;</span><span class="p">,</span> <span class="n">tds</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">175</span> <span class="n">train_dataset</span> <span class="o">=</span> <span class="n">Dataset</span><span class="p">(</span><span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">&#39;retro_train_dataset.json&#39;</span><span class="p">,</span> <span class="n">tds</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
@ -595,9 +596,9 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span> <span class="n">train_dl</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">,</span>
<span class="lineno">177</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span>
<span class="lineno">178</span> <span class="n">sampler</span><span class="o">=</span><span class="n">RandomSampler</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">,</span> <span class="n">replacement</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">178</span> <span class="n">train_dl</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">,</span>
<span class="lineno">179</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span>
<span class="lineno">180</span> <span class="n">sampler</span><span class="o">=</span><span class="n">RandomSampler</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">,</span> <span class="n">replacement</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
@ -609,11 +610,11 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">181</span> <span class="n">chunk_len</span> <span class="o">=</span> <span class="mi">16</span>
<span class="lineno">182</span> <span class="n">d_model</span> <span class="o">=</span> <span class="mi">128</span>
<span class="lineno">183</span> <span class="n">d_ff</span> <span class="o">=</span> <span class="mi">512</span>
<span class="lineno">184</span> <span class="n">n_heads</span> <span class="o">=</span> <span class="mi">16</span>
<span class="lineno">185</span> <span class="n">d_k</span> <span class="o">=</span> <span class="mi">16</span></pre></div>
<div class="highlight"><pre><span class="lineno">183</span> <span class="n">chunk_len</span> <span class="o">=</span> <span class="mi">16</span>
<span class="lineno">184</span> <span class="n">d_model</span> <span class="o">=</span> <span class="mi">128</span>
<span class="lineno">185</span> <span class="n">d_ff</span> <span class="o">=</span> <span class="mi">512</span>
<span class="lineno">186</span> <span class="n">n_heads</span> <span class="o">=</span> <span class="mi">16</span>
<span class="lineno">187</span> <span class="n">d_k</span> <span class="o">=</span> <span class="mi">16</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
@ -625,7 +626,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">188</span> <span class="n">nearest_neighbor_encoder</span> <span class="o">=</span> <span class="n">NearestNeighborEncoder</span><span class="p">(</span><span class="n">chunk_len</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="p">{</span><span class="mi">3</span><span class="p">},</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">190</span> <span class="n">nearest_neighbor_encoder</span> <span class="o">=</span> <span class="n">NearestNeighborEncoder</span><span class="p">(</span><span class="n">chunk_len</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="p">{</span><span class="mi">3</span><span class="p">},</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
@ -637,10 +638,10 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">190</span> <span class="n">model</span> <span class="o">=</span> <span class="n">RetroModel</span><span class="p">(</span><span class="n">tds</span><span class="o">.</span><span class="n">n_tokens</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span>
<span class="lineno">191</span> <span class="p">{</span><span class="mi">3</span><span class="p">,</span> <span class="mi">5</span><span class="p">},</span>
<span class="lineno">192</span> <span class="n">chunk_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">,</span>
<span class="lineno">193</span> <span class="n">encoder</span><span class="o">=</span><span class="n">nearest_neighbor_encoder</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">192</span> <span class="n">model</span> <span class="o">=</span> <span class="n">RetroModel</span><span class="p">(</span><span class="n">tds</span><span class="o">.</span><span class="n">n_tokens</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span>
<span class="lineno">193</span> <span class="p">{</span><span class="mi">3</span><span class="p">,</span> <span class="mi">5</span><span class="p">},</span>
<span class="lineno">194</span> <span class="n">chunk_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">,</span>
<span class="lineno">195</span> <span class="n">encoder</span><span class="o">=</span><span class="n">nearest_neighbor_encoder</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
@ -652,7 +653,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">195</span> <span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
@ -664,7 +665,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">Noam</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1.</span><span class="p">,</span> <span class="n">d_model</span><span class="o">=</span><span class="n">d_model</span><span class="p">,</span> <span class="n">warmup</span><span class="o">=</span><span class="mi">2_000</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">199</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">Noam</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1.</span><span class="p">,</span> <span class="n">d_model</span><span class="o">=</span><span class="n">d_model</span><span class="p">,</span> <span class="n">warmup</span><span class="o">=</span><span class="mi">2_000</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
@ -677,7 +678,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">199</span> <span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">train_dl</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">201</span> <span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">train_dl</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
@ -690,7 +691,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">201</span> <span class="n">sampler</span> <span class="o">=</span> <span class="n">Sampler</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">tds</span><span class="p">,</span> <span class="n">chunk_len</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">203</span> <span class="n">sampler</span> <span class="o">=</span> <span class="n">Sampler</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">tds</span><span class="p">,</span> <span class="n">chunk_len</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
@ -702,7 +703,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">203</span> <span class="n">prompt</span> <span class="o">=</span> <span class="s1">&#39;&#39;&#39;Second Citizen:</span><span class="se">\n</span><span class="s1">One word, good citizens.</span><span class="se">\n\n</span><span class="s1">First Citizen:&#39;&#39;&#39;</span></pre></div>
<div class="highlight"><pre><span class="lineno">205</span> <span class="n">prompt</span> <span class="o">=</span> <span class="s1">&#39;&#39;&#39;Second Citizen:</span><span class="se">\n</span><span class="s1">One word, good citizens.</span><span class="se">\n\n</span><span class="s1">First Citizen:&#39;&#39;&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
@ -714,7 +715,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">206</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">(</span><span class="n">model</span><span class="o">=</span><span class="n">model</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">208</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">(</span><span class="n">model</span><span class="o">=</span><span class="n">model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
@ -726,7 +727,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">209</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
<div class="highlight"><pre><span class="lineno">211</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
@ -739,7 +740,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">211</span> <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">loop</span><span class="p">(</span><span class="mi">32</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">213</span> <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">loop</span><span class="p">(</span><span class="mi">32</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
@ -751,7 +752,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">213</span> <span class="n">trainer</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">215</span> <span class="n">trainer</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
@ -763,7 +764,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">215</span> <span class="n">tracker</span><span class="o">.</span><span class="n">new_line</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">217</span> <span class="n">tracker</span><span class="o">.</span><span class="n">new_line</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
@ -776,8 +777,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">217</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">([(</span><span class="n">prompt</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">,</span> <span class="s1">&#39;</span><span class="se">\\</span><span class="s1">n</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">),</span> <span class="n">Text</span><span class="o">.</span><span class="n">subtle</span><span class="p">),</span>
<span class="lineno">218</span> <span class="p">(</span><span class="n">sampler</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">prompt</span><span class="p">,</span> <span class="mi">128</span><span class="p">)</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">,</span> <span class="s1">&#39;</span><span class="se">\\</span><span class="s1">n</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">),</span> <span class="n">Text</span><span class="o">.</span><span class="n">none</span><span class="p">)])</span></pre></div>
<div class="highlight"><pre><span class="lineno">219</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">([(</span><span class="n">prompt</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">,</span> <span class="s1">&#39;</span><span class="se">\\</span><span class="s1">n</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">),</span> <span class="n">Text</span><span class="o">.</span><span class="n">subtle</span><span class="p">),</span>
<span class="lineno">220</span> <span class="p">(</span><span class="n">sampler</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">prompt</span><span class="p">,</span> <span class="mi">128</span><span class="p">)</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">,</span> <span class="s1">&#39;</span><span class="se">\\</span><span class="s1">n</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">),</span> <span class="n">Text</span><span class="o">.</span><span class="n">none</span><span class="p">)])</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
@ -789,7 +790,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">220</span> <span class="n">experiment</span><span class="o">.</span><span class="n">save_checkpoint</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">222</span> <span class="n">experiment</span><span class="o">.</span><span class="n">save_checkpoint</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
@ -801,8 +802,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">224</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">225</span> <span class="n">train</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">226</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">227</span> <span class="n">train</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>

View File

@ -32,4 +32,6 @@ Components:
* [Model](model.html)
* [Dataset](dataset.html): Pre-calculate the nearest neighbors
* [Training code](train.html)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/3113dd3ea1e711ec85ee295d18534021)
"""

View File

@ -9,6 +9,8 @@ summary: >
This is the model definition for
[RETRO](index.html).
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/3113dd3ea1e711ec85ee295d18534021)
"""
import math

View File

@ -9,6 +9,8 @@ summary: >
This is the training code for
[RETRO](index.html).
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/3113dd3ea1e711ec85ee295d18534021)
"""
import torch