Files
2024-06-21 19:35:22 +05:30

1436 lines
94 KiB
HTML
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<!-- Encoding is declared first so the parser sees it within the first 1024 bytes -->
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This experiment trains a compressive transformer model on the Tiny Shakespeare dataset."/>
<!-- Twitter card metadata -->
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Compressive Transformer Experiment"/>
<meta name="twitter:description" content="This experiment trains a compressive transformer model on the Tiny Shakespeare dataset."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<!-- Open Graph metadata: each og:* property is declared exactly once -->
<meta property="og:url" content="https://nn.labml.ai/transformers/compressive/experiment.html"/>
<meta property="og:title" content="Compressive Transformer Experiment"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Compressive Transformer Experiment"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="This experiment trains a compressive transformer model on the Tiny Shakespeare dataset."/>
<title>Compressive Transformer Experiment</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/transformers/compressive/experiment.html"/>
<!-- KaTeX stylesheet for the server-rendered math in the notes (pinned with SRI) -->
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<!-- Breadcrumb links back up the documentation tree -->
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">transformers</a>
<a class="parent" href="index.html">compressive</a>
</p>
<!-- Social badges: the image is the whole link content, so its alt text names the link destination -->
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" rel="noopener" target="_blank">
<img alt="Star labmlai/annotated_deep_learning_paper_implementations on GitHub"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow noopener" target="_blank">
<img alt="Follow @labmlai on Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<!-- Link to the Python source this page annotates -->
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/compressive/experiment.py" rel="noopener" target="_blank">
View code on GitHub</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Compressive Transformer Experiment</h1>
<p>This is an annotated PyTorch experiment to train a compressive transformer model.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">11</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">NamedTuple</span>
<span class="lineno">12</span>
<span class="lineno">13</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">14</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="lineno">15</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span><span class="p">,</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">monit</span><span class="p">,</span> <span class="n">logger</span>
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">Text</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml_helpers.metrics.simple_state</span> <span class="kn">import</span> <span class="n">SimpleStateModule</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml_helpers.train_valid</span> <span class="kn">import</span> <span class="n">BatchIndex</span><span class="p">,</span> <span class="n">hook_model_outputs</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_nn.experiments.nlp_autoregression</span> <span class="kn">import</span> <span class="n">NLPAutoRegressionConfigs</span>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.compressive</span> <span class="kn">import</span> <span class="n">CompressiveTransformer</span><span class="p">,</span> <span class="n">AttentionReconstructionLoss</span><span class="p">,</span> \
<span class="lineno">24</span> <span class="n">CompressiveTransformerLayer</span><span class="p">,</span> <span class="n">Conv1dCompression</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">27</span><span class="k">class</span> <span class="nc">CompressedMemory</span><span class="p">(</span><span class="n">NamedTuple</span><span class="p">):</span>
<span class="lineno">28</span> <span class="n">mem</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span>
<span class="lineno">29</span> <span class="n">c_mem</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<h2>Autoregressive model</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">32</span><span class="k">class</span> <span class="nc">AutoregressiveModel</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">37</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">transformer</span><span class="p">:</span> <span class="n">CompressiveTransformer</span><span class="p">):</span>
<span class="lineno">38</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Token embedding module </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">40</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">n_vocab</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Transformer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">42</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer</span> <span class="o">=</span> <span class="n">transformer</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Final layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">44</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Masks </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">46</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_x</span> <span class="o">=</span> <span class="kc">None</span>
<span class="lineno">47</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_mem</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mem</span><span class="p">:</span> <span class="n">CompressedMemory</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Get memory and compressed memory </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">51</span> <span class="k">if</span> <span class="n">mem</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">52</span> <span class="n">mem</span><span class="p">,</span> <span class="n">c_mem</span> <span class="o">=</span> <span class="n">mem</span><span class="o">.</span><span class="n">mem</span><span class="p">,</span> <span class="n">mem</span><span class="o">.</span><span class="n">c_mem</span>
<span class="lineno">53</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">54</span> <span class="n">mem</span> <span class="o">=</span> <span class="p">[]</span>
<span class="lineno">55</span> <span class="n">c_mem</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Total length of the memory and compressed memory (for masks) </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">58</span> <span class="n">m_len</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">mem</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="k">if</span> <span class="n">mem</span> <span class="k">else</span> <span class="mi">0</span>
<span class="lineno">59</span> <span class="k">if</span> <span class="n">c_mem</span><span class="p">:</span>
<span class="lineno">60</span> <span class="n">m_len</span> <span class="o">+=</span> <span class="nb">len</span><span class="p">(</span><span class="n">c_mem</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Create a subsequent mask for tokens </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">63</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_x</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="lineno">64</span> <span class="kn">from</span> <span class="nn">labml_nn.transformers.utils</span> <span class="kn">import</span> <span class="n">subsequent_mask</span>
<span class="lineno">65</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_x</span> <span class="o">=</span> <span class="n">subsequent_mask</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Create an all-ones (full visibility) mask for the memory </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">67</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_mem</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_mem</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">&lt;</span> <span class="n">m_len</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_mem</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="lineno">68</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_mem</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_x</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">m_len</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Concatenate the masks if there is memory </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">71</span> <span class="k">if</span> <span class="n">m_len</span><span class="p">:</span>
<span class="lineno">72</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">mask_mem</span><span class="p">[:</span><span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="p">:</span><span class="n">m_len</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_x</span><span class="p">[:</span><span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="p">:</span><span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">)]),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Use only the subsequent mask otherwise </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">74</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">75</span> <span class="n">mask</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_x</span><span class="p">[:</span><span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="p">:</span><span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Token embeddings </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">78</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">src_embed</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Run it through the transformer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">80</span> <span class="n">res</span><span class="p">,</span> <span class="n">mem</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">mem</span><span class="p">,</span> <span class="n">c_mem</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Generate logits of the next token </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span> <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">(</span><span class="n">res</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Return the next-token logits and the memory output of the transformer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">84</span> <span class="k">return</span> <span class="n">res</span><span class="p">,</span> <span class="n">mem</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<h2>Configurations</h2>
<p>The default configurations can and will be overridden when we start the experiment.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">87</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">94</span> <span class="n">model</span><span class="p">:</span> <span class="n">AutoregressiveModel</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Token embedding size </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">97</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">128</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Number of attention heads </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">99</span> <span class="n">heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Dropout probability </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">101</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Number of features in FFN hidden layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">103</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">256</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Number of transformer layers </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">105</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">6</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Number of memories to keep </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">107</span> <span class="n">mem_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>State module to maintain memories when switching between training and validation </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">109</span> <span class="n">memory</span> <span class="o">=</span> <span class="n">SimpleStateModule</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Attention Reconstruction Loss </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">111</span> <span class="n">attention_reconstruction_loss</span><span class="p">:</span> <span class="n">AttentionReconstructionLoss</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Compression rate </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">113</span> <span class="n">compression_rate</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Compressed memory length </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">115</span> <span class="n">c_mem_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">128</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">117</span> <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Set tracker configurations </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">119</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s2">&quot;accuracy.*&quot;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
<span class="lineno">120</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s2">&quot;loss.*&quot;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Do not print the attention reconstruction loss in the terminal </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s2">&quot;ar_loss.*&quot;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Add a hook to log module outputs </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="n">hook_model_outputs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">,</span> <span class="s1">&#39;model&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>This will keep the accuracy metric stats and memories separate for training and validation. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">126</span> <span class="bp">self</span><span class="o">.</span><span class="n">state_modules</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">accuracy</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Concatenate new memories and compress the oldest memories.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">128</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">129</span> <span class="k">def</span> <span class="nf">merge_compress_memory</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mem</span><span class="p">:</span> <span class="n">CompressedMemory</span><span class="p">,</span> <span class="n">new_mem</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">])</span> \
<span class="lineno">130</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">CompressedMemory</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]]:</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>If the configurations specify not to use memory </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mem_len</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">c_mem_len</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">137</span> <span class="k">return</span> <span class="n">CompressedMemory</span><span class="p">([],</span> <span class="p">[]),</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Get memory and compressed memory </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">140</span> <span class="k">if</span> <span class="n">mem</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">141</span> <span class="n">mem</span><span class="p">,</span> <span class="n">c_mem</span> <span class="o">=</span> <span class="n">mem</span><span class="o">.</span><span class="n">mem</span><span class="p">,</span> <span class="n">mem</span><span class="o">.</span><span class="n">c_mem</span>
<span class="lineno">142</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">143</span> <span class="n">mem</span><span class="p">,</span> <span class="n">c_mem</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>Concatenate new memories with old memory </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">146</span> <span class="k">if</span> <span class="n">mem</span><span class="p">:</span>
<span class="lineno">147</span> <span class="n">mem</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">m</span><span class="p">,</span> <span class="n">x</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">m</span><span class="p">,</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">mem</span><span class="p">,</span> <span class="n">new_mem</span><span class="p">)]</span>
<span class="lineno">148</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">149</span> <span class="n">mem</span> <span class="o">=</span> <span class="n">new_mem</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>Compress the oldest memories if there are more memories than <code class="highlight"><span></span><span class="n">mem_len</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">152</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">mem</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">mem_len</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Calculate the number of compressed memories to make <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">n</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">c</span><span class="mord mathnormal mtight">m</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:2.40003em;vertical-align:-0.95003em;"></span><span class="mord"><span class="delimsizing size3">⌈</span></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.08968em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">c</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.5102em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqc" style=""><span class="mord mtight" style=""><span 
class="mord mathnormal mtight" style="">n</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8278285714285715em;"><span style="top:-2.214em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="">m</span></span></span><span style="top:-2.931em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">′</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286em;"><span></span></span></span></span></span></span></span><span class="mbin mtight">−</span><span class="mord mtight coloredeq eqd" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.10903em">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16454285714285719em;"><span style="top:-2.357em;margin-left:-0.10903em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="">m</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord"><span class="delimsizing size3">⌉</span></span></span></span></span></span>, where <span ><span class="katex"><span aria-hidden="true" 
class="katex-html"><span class="base"><span class="strut" style="height:0.998892em;vertical-align:-0.247em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style=""><span class="mord mathnormal" style="">n</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.751892em;"><span style="top:-2.4530000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">m</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">′</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span></span></span></span></span></span> is the number of memories we have and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">N</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">m</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> is the maximum number of memories we 
maintain (<code class="highlight"><span></span><span class="n">mem_len</span></code>
). </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="n">n_c_mem</span> <span class="o">=</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">mem</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">mem_len</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">compression_rate</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">compression_rate</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Number of memories to compress <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqb" style=""><span class="mord mathnormal" style="">c</span><span class="mord" style=""><span class="mord mathnormal" style="">n</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">c</span><span class="mord mathnormal mtight" style="">m</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">158</span> <span class="n">n_old</span> <span class="o">=</span> <span class="n">n_c_mem</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">compression_rate</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>A list to keep memories that need to be compressed for each layer. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">160</span> <span class="n">mem_to_compress</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>A list to keep the memories that do not get compressed for each layer. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span> <span class="n">uncompressed_mem</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Iterate through memories of each layer. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">164</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="n">mem</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p>Split the memories at <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqb" style=""><span class="mord mathnormal" style="">c</span><span class="mord" style=""><span class="mord mathnormal" style="">n</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">c</span><span class="mord mathnormal mtight" style="">m</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">166</span> <span class="n">cm</span><span class="p">,</span> <span class="n">m</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="p">[</span><span class="n">n_old</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">m</span><span class="p">)</span> <span class="o">-</span> <span class="n">n_old</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<p>Collect memories to compress </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span> <span class="n">mem_to_compress</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">cm</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<p>Collect remaining memories </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">170</span> <span class="n">uncompressed_mem</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">m</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>Update the memories </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">172</span> <span class="n">mem</span> <span class="o">=</span> <span class="n">uncompressed_mem</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>Compress the memories </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">175</span> <span class="n">new_c_mem</span> <span class="o">=</span> <span class="p">[]</span>
<span class="lineno">176</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">layers</span><span class="p">):</span>
<span class="lineno">177</span> <span class="n">new_c_mem</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">compress</span><span class="p">(</span><span class="n">mem_to_compress</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p>Concatenate newly compressed memories with old compressed memories </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">180</span> <span class="k">if</span> <span class="n">c_mem</span><span class="p">:</span>
<span class="lineno">181</span> <span class="n">c_mem</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">m</span><span class="p">,</span> <span class="n">nm</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">m</span><span class="p">,</span> <span class="n">nm</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">c_mem</span><span class="p">,</span> <span class="n">new_c_mem</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p>If there are no old compressed memories </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">183</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">184</span> <span class="n">c_mem</span> <span class="o">=</span> <span class="n">new_c_mem</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>Truncate old memories </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">187</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">c_mem</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">c_mem_len</span><span class="p">:</span>
<span class="lineno">188</span> <span class="n">c_mem</span> <span class="o">=</span> <span class="p">[</span><span class="n">m</span><span class="p">[</span><span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">c_mem_len</span><span class="p">:]</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="n">c_mem</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<p>No memories are compressed if the number of memories does not exceed <code class="highlight"><span></span><span class="n">mem_len</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">190</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">191</span> <span class="n">mem_to_compress</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='docs'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<p>Return memories and the memories that were compressed. Memories that were compressed are needed for the reconstruction loss computation. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">195</span> <span class="k">return</span> <span class="n">CompressedMemory</span><span class="p">(</span><span class="n">mem</span><span class="p">,</span> <span class="n">c_mem</span><span class="p">),</span> <span class="n">mem_to_compress</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
<h3>Training/validation step</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="nb">any</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="n">BatchIndex</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<p>Move data to the device </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">203</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<p>Update global step (number of tokens processed) when in training mode </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">206</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
<span class="lineno">207</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<p>Whether to capture model outputs </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">210</span> <span class="k">with</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">is_log_activations</span><span class="o">=</span><span class="n">batch_idx</span><span class="o">.</span><span class="n">is_last</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
<p>Get memories </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">212</span> <span class="n">mem</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">get</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
<p>Run the model </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">214</span> <span class="n">output</span><span class="p">,</span> <span class="n">new_mem</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">mem</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
<p>Merge and compress memory </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">216</span> <span class="n">mem</span><span class="p">,</span> <span class="n">mem_to_compress</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">merge_compress_memory</span><span class="p">(</span><span class="n">mem</span><span class="p">,</span> <span class="n">new_mem</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='docs'>
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
<p>Update memories </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">218</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="n">mem</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-64'>
<div class='docs'>
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
<p>Calculate and log cross entropy loss </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">221</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
<span class="lineno">222</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;loss.&quot;</span><span class="p">,</span> <span class="n">loss</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-65'>
<div class='docs'>
<div class='section-link'>
<a href='#section-65'>#</a>
</div>
<p>Calculate attention reconstruction loss if memories were compressed in this step </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">225</span> <span class="k">if</span> <span class="n">mem_to_compress</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-66'>
<div class='docs'>
<div class='section-link'>
<a href='#section-66'>#</a>
</div>
<p>Get attention reconstruction loss </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">227</span> <span class="n">ar_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention_reconstruction_loss</span><span class="p">(</span><span class="n">new_mem</span><span class="p">,</span> <span class="n">mem_to_compress</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-67'>
<div class='docs'>
<div class='section-link'>
<a href='#section-67'>#</a>
</div>
<p>Track attention reconstruction loss </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">229</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">&quot;ar_loss.&quot;</span><span class="p">,</span> <span class="n">ar_loss</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-68'>
<div class='docs'>
<div class='section-link'>
<a href='#section-68'>#</a>
</div>
<p>Add attention reconstruction loss to loss </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">231</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">loss</span> <span class="o">+</span> <span class="n">ar_loss</span></pre></div>
</div>
</div>
<div class='section' id='section-69'>
<div class='docs'>
<div class='section-link'>
<a href='#section-69'>#</a>
</div>
<p>Calculate and log accuracy </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">234</span> <span class="bp">self</span><span class="o">.</span><span class="n">accuracy</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
<span class="lineno">235</span> <span class="bp">self</span><span class="o">.</span><span class="n">accuracy</span><span class="o">.</span><span class="n">track</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-70'>
<div class='docs'>
<div class='section-link'>
<a href='#section-70'>#</a>
</div>
<p>Train the model </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">238</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-71'>
<div class='docs'>
<div class='section-link'>
<a href='#section-71'>#</a>
</div>
<p>Calculate gradients </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">240</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-72'>
<div class='docs'>
<div class='section-link'>
<a href='#section-72'>#</a>
</div>
<p>Clip gradients </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">242</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_grad_norm_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">max_norm</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">grad_norm_clip</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-73'>
<div class='docs'>
<div class='section-link'>
<a href='#section-73'>#</a>
</div>
<p>Take optimizer step </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">244</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-74'>
<div class='docs'>
<div class='section-link'>
<a href='#section-74'>#</a>
</div>
<p>Log the model parameters and gradients on last batch of every epoch </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">246</span> <span class="k">if</span> <span class="n">batch_idx</span><span class="o">.</span><span class="n">is_last</span><span class="p">:</span>
<span class="lineno">247</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">&#39;model&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-75'>
<div class='docs'>
<div class='section-link'>
<a href='#section-75'>#</a>
</div>
<p>Clear the gradients </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">249</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-76'>
<div class='docs'>
<div class='section-link'>
<a href='#section-76'>#</a>
</div>
<p>Save the tracked metrics </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">252</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-77'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-77'>#</a>
</div>
<h3>Sampling function to generate samples periodically while training</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">254</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-78'>
<div class='docs'>
<div class='section-link'>
<a href='#section-78'>#</a>
</div>
<p>Starting prompt </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">260</span> <span class="n">prompt</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prompt</span></pre></div>
</div>
</div>
<div class='section' id='section-79'>
<div class='docs'>
<div class='section-link'>
<a href='#section-79'>#</a>
</div>
<p>Collect output for printing </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">262</span> <span class="n">log</span> <span class="o">=</span> <span class="p">[(</span><span class="n">prompt</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">subtle</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-80'>
<div class='docs'>
<div class='section-link'>
<a href='#section-80'>#</a>
</div>
<p>Start with an empty compressed memory </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">264</span> <span class="n">mem</span> <span class="o">=</span> <span class="n">CompressedMemory</span><span class="p">([],</span> <span class="p">[])</span></pre></div>
</div>
</div>
<div class='section' id='section-81'>
<div class='docs'>
<div class='section-link'>
<a href='#section-81'>#</a>
</div>
<p>Sample 25 tokens </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">266</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">iterate</span><span class="p">(</span><span class="s1">&#39;Sample&#39;</span><span class="p">,</span> <span class="mi">25</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-82'>
<div class='docs'>
<div class='section-link'>
<a href='#section-82'>#</a>
</div>
<p>Tokenize the prompt </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">268</span> <span class="n">data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">text_to_i</span><span class="p">(</span><span class="n">prompt</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-83'>
<div class='docs'>
<div class='section-link'>
<a href='#section-83'>#</a>
</div>
<p>Move to device </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">270</span> <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-84'>
<div class='docs'>
<div class='section-link'>
<a href='#section-84'>#</a>
</div>
<p>Get the model output </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">272</span> <span class="n">output</span><span class="p">,</span> <span class="n">new_mem</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">mem</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-85'>
<div class='docs'>
<div class='section-link'>
<a href='#section-85'>#</a>
</div>
<p>Get the model prediction (greedy) </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">274</span> <span class="n">output</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-86'>
<div class='docs'>
<div class='section-link'>
<a href='#section-86'>#</a>
</div>
<p>Append the predicted character to the prompt </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">276</span> <span class="n">prompt</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prompt_separator</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">output</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]]</span></pre></div>
</div>
</div>
<div class='section' id='section-87'>
<div class='docs'>
<div class='section-link'>
<a href='#section-87'>#</a>
</div>
<p>Only feed the last character to the model in the next iteration; the rest go in as memories </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">278</span> <span class="n">prompt</span> <span class="o">=</span> <span class="n">prompt</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">:]</span></pre></div>
</div>
</div>
<div class='section' id='section-88'>
<div class='docs'>
<div class='section-link'>
<a href='#section-88'>#</a>
</div>
<p>Add the prediction for logging </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">280</span> <span class="n">log</span> <span class="o">+=</span> <span class="p">[(</span><span class="bp">self</span><span class="o">.</span><span class="n">prompt_separator</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">output</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span> <span class="n">Text</span><span class="o">.</span><span class="n">value</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-89'>
<div class='docs'>
<div class='section-link'>
<a href='#section-89'>#</a>
</div>
<p>Update and compress memory </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">282</span> <span class="n">mem</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">merge_compress_memory</span><span class="p">(</span><span class="n">mem</span><span class="p">,</span> <span class="n">new_mem</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-90'>
<div class='docs'>
<div class='section-link'>
<a href='#section-90'>#</a>
</div>
<p>Print the sampled output </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">285</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">log</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-91'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-91'>#</a>
</div>
<h3>Initialize the auto-regressive model</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">288</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
<span class="lineno">289</span><span class="k">def</span> <span class="nf">autoregressive_model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-92'>
<div class='docs'>
<div class='section-link'>
<a href='#section-92'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">293</span> <span class="kn">from</span> <span class="nn">labml_nn.transformers.xl</span> <span class="kn">import</span> <span class="n">RelativeMultiHeadAttention</span>
<span class="lineno">294</span> <span class="kn">from</span> <span class="nn">labml_nn.transformers.feed_forward</span> <span class="kn">import</span> <span class="n">FeedForward</span>
<span class="lineno">295</span> <span class="n">m</span> <span class="o">=</span> <span class="n">AutoregressiveModel</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">n_tokens</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">CompressiveTransformer</span><span class="p">(</span>
<span class="lineno">296</span> <span class="n">CompressiveTransformerLayer</span><span class="p">(</span><span class="n">d_model</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span>
<span class="lineno">297</span> <span class="n">self_attn</span><span class="o">=</span><span class="n">RelativeMultiHeadAttention</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">),</span>
<span class="lineno">298</span> <span class="n">feed_forward</span><span class="o">=</span><span class="n">FeedForward</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_ff</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">),</span>
<span class="lineno">299</span> <span class="n">dropout_prob</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">,</span>
<span class="lineno">300</span> <span class="n">compress</span><span class="o">=</span><span class="n">Conv1dCompression</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">compression_rate</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span><span class="p">)),</span> <span class="n">c</span><span class="o">.</span><span class="n">n_layers</span><span class="p">))</span>
<span class="lineno">301</span> <span class="k">return</span> <span class="n">m</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-93'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-93'>#</a>
</div>
<h3>Initialize the attention reconstruction loss</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">304</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">attention_reconstruction_loss</span><span class="p">)</span>
<span class="lineno">305</span><span class="k">def</span> <span class="nf">attention_reconstruction_loss</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-94'>
<div class='docs'>
<div class='section-link'>
<a href='#section-94'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">309</span> <span class="k">return</span> <span class="n">AttentionReconstructionLoss</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-95'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-95'>#</a>
</div>
<h3>Run the experiment</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">312</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-96'>
<div class='docs'>
<div class='section-link'>
<a href='#section-96'>#</a>
</div>
<p>Create experiment </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">317</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;compressive_transformer&quot;</span><span class="p">,</span> <span class="n">comment</span><span class="o">=</span><span class="s1">&#39;&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-97'>
<div class='docs'>
<div class='section-link'>
<a href='#section-97'>#</a>
</div>
<p>Create configs </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">319</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-98'>
<div class='docs'>
<div class='section-link'>
<a href='#section-98'>#</a>
</div>
<p>Load configurations </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">321</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span></pre></div>
</div>
</div>
<div class='section' id='section-99'>
<div class='docs'>
<div class='section-link'>
<a href='#section-99'>#</a>
</div>
<p>A dictionary of configurations to override </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">323</span> <span class="p">{</span><span class="s1">&#39;tokenizer&#39;</span><span class="p">:</span> <span class="s1">&#39;character&#39;</span><span class="p">,</span>
<span class="lineno">324</span> <span class="s1">&#39;text&#39;</span><span class="p">:</span> <span class="s1">&#39;tiny_shakespeare&#39;</span><span class="p">,</span>
<span class="lineno">325</span> <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">2.5e-4</span><span class="p">,</span>
<span class="lineno">326</span> <span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;AdamW&#39;</span><span class="p">,</span>
<span class="lineno">327</span> <span class="s1">&#39;prompt&#39;</span><span class="p">:</span> <span class="s1">&#39;It is&#39;</span><span class="p">,</span>
<span class="lineno">328</span> <span class="s1">&#39;prompt_separator&#39;</span><span class="p">:</span> <span class="s1">&#39;&#39;</span><span class="p">,</span>
<span class="lineno">329</span>
<span class="lineno">330</span> <span class="s1">&#39;train_loader&#39;</span><span class="p">:</span> <span class="s1">&#39;sequential_train_loader&#39;</span><span class="p">,</span>
<span class="lineno">331</span> <span class="s1">&#39;valid_loader&#39;</span><span class="p">:</span> <span class="s1">&#39;sequential_valid_loader&#39;</span><span class="p">,</span>
<span class="lineno">332</span>
<span class="lineno">333</span> <span class="s1">&#39;seq_len&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span>
<span class="lineno">334</span> <span class="s1">&#39;mem_len&#39;</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span>
<span class="lineno">335</span> <span class="s1">&#39;epochs&#39;</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span>
<span class="lineno">336</span> <span class="s1">&#39;batch_size&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span>
<span class="lineno">337</span> <span class="s1">&#39;inner_iterations&#39;</span><span class="p">:</span> <span class="mi">25</span><span class="p">,</span>
<span class="lineno">338</span> <span class="s1">&#39;compression_rate&#39;</span><span class="p">:</span> <span class="mi">2</span><span class="p">,</span>
<span class="lineno">339</span> <span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-100'>
<div class='docs'>
<div class='section-link'>
<a href='#section-100'>#</a>
</div>
<p>Set models for saving and loading </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">342</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">&#39;model&#39;</span><span class="p">:</span> <span class="n">conf</span><span class="o">.</span><span class="n">model</span><span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-101'>
<div class='docs'>
<div class='section-link'>
<a href='#section-101'>#</a>
</div>
<p>Start the experiment </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">345</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-102'>
<div class='docs'>
<div class='section-link'>
<a href='#section-102'>#</a>
</div>
<p><code class="highlight"><span></span><span class="n">TrainValidConfigs</span><span class="o">.</span><span class="n">run</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">347</span> <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-103'>
<div class='docs'>
<div class='section-link'>
<a href='#section-103'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">351</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">352</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
// Attach click-to-enlarge behaviour to every <img> that is a direct child of a <p>.
function handleImages() {
    var paragraphImages = document.querySelectorAll('p>img')
    // querySelectorAll returns a static NodeList; walk it with Array's forEach
    Array.prototype.forEach.call(paragraphImages, function (image) {
        handleImage(image)
    })
}
// Centre `img` in its parent paragraph and make it open an enlarged copy
// in a modal overlay when clicked. The modal is built once per image and is
// only attached to <body> while open. It can be closed either by clicking
// the "x" control or by pressing Escape.
function handleImage(img) {
    img.parentElement.style.textAlign = 'center'

    // Modal structure: <div id="modal"><div><img></div><span class="close">x</span></div>
    var modal = document.createElement('div')
    modal.id = 'modal'
    var modalContent = document.createElement('div')
    modal.appendChild(modalContent)
    var modalImage = document.createElement('img')
    modalContent.appendChild(modalImage)
    var span = document.createElement('span')
    span.classList.add('close')
    span.textContent = 'x'
    modal.appendChild(span)

    // Detach the modal (if it is open) and stop listening for Escape.
    // The guard keeps this safe to call when the modal is already closed.
    function closeModal() {
        if (modal.parentNode) {
            modal.parentNode.removeChild(modal)
        }
        document.removeEventListener('keydown', onKeyDown)
    }

    // Keyboard users cannot click the close control, so Escape also closes the modal
    function onKeyDown(event) {
        if (event.key === 'Escape') {
            closeModal()
        }
    }

    img.onclick = function () {
        modalImage.src = img.src
        // Carry the original alt text over so the enlarged image has an accessible name
        modalImage.alt = img.alt
        document.body.appendChild(modal)
        // Re-adding the same listener function is a no-op, so repeated opens do not stack handlers
        document.addEventListener('keydown', onKeyDown)
    }
    span.onclick = closeModal
}
// Wire up all paragraph images when this script runs; it sits at the end of <body>, so they are already parsed
handleImages()
</script>
</body>
</html>