mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 09:31:42 +08:00
1126 lines
66 KiB
HTML
1126 lines
66 KiB
HTML
<!DOCTYPE html>
|
|
<html lang="en">
|
|
<head>
|
|
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
|
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
|
<meta name="description" content="This is a reusable trainer for auto-regressive tasks"/>
|
|
|
|
<meta name="twitter:card" content="summary"/>
|
|
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
|
<meta name="twitter:title" content="NLP auto-regression trainer"/>
|
|
<meta name="twitter:description" content="This is a reusable trainer for auto-regressive tasks"/>
|
|
<meta name="twitter:site" content="@labmlai"/>
|
|
<meta name="twitter:creator" content="@labmlai"/>
|
|
|
|
<meta property="og:url" content="https://nn.labml.ai/experiments/nlp_autoregression.html"/>
|
|
<meta property="og:title" content="NLP auto-regression trainer"/>
|
|
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
|
<meta property="og:site_name" content="LabML Neural Networks"/>
|
|
<meta property="og:type" content="object"/>
|
|
<meta property="og:title" content="NLP auto-regression trainer"/>
|
|
<meta property="og:description" content="This is a reusable trainer for auto-regressive tasks"/>
|
|
|
|
<title>NLP auto-regression trainer</title>
|
|
<link rel="shortcut icon" href="/icon.png"/>
|
|
<link rel="stylesheet" href="../pylit.css?v=1">
|
|
<link rel="canonical" href="https://nn.labml.ai/experiments/nlp_autoregression.html"/>
|
|
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
|
|
|
<!-- Global site tag (gtag.js) - Google Analytics -->
|
|
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
|
<script>
|
|
// Google Analytics (gtag.js) bootstrap: commands are queued here until the async loader script runs.
// Reuse an existing dataLayer queue if one was already created; otherwise start an empty one.
window.dataLayer = window.dataLayer || [];
|
|
|
|
// gtag(...) appends its raw `arguments` object to the dataLayer queue.
// NOTE(review): gtag.js expects the Arguments object itself rather than an array copy — keep as-is.
function gtag() {
|
|
dataLayer.push(arguments);
|
|
}
|
|
|
|
// Record the page-load timestamp for gtag.js.
gtag('js', new Date());
|
|
|
|
// Configure the Google Analytics property G-4V3HC8HBLH for this page.
gtag('config', 'G-4V3HC8HBLH');
|
|
</script>
|
|
</head>
|
|
<body>
|
|
<div id='container'>
|
|
<div id="background"></div>
|
|
<div class='section'>
|
|
<div class='docs'>
|
|
<p>
|
|
<a class="parent" href="/">home</a>
|
|
<a class="parent" href="index.html">experiments</a>
|
|
</p>
|
|
<p>
|
|
|
|
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/experiments/nlp_autoregression.py">
|
|
<img alt="Github"
|
|
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
|
style="max-width:100%;"/></a>
|
|
<a href="https://twitter.com/labmlai"
|
|
rel="nofollow">
|
|
<img alt="Twitter"
|
|
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
|
style="max-width:100%;"/></a>
|
|
</p>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-0'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-0'>#</a>
|
|
</div>
|
|
<h1>Auto-regressive NLP model trainer</h1>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">11</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Callable</span>
|
|
<span class="lineno">12</span>
|
|
<span class="lineno">13</span><span class="kn">import</span> <span class="nn">torch</span>
|
|
<span class="lineno">14</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
|
|
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">RandomSampler</span>
|
|
<span class="lineno">16</span>
|
|
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">lab</span><span class="p">,</span> <span class="n">monit</span><span class="p">,</span> <span class="n">logger</span><span class="p">,</span> <span class="n">tracker</span>
|
|
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
|
|
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">Text</span>
|
|
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_helpers.datasets.text</span> <span class="kn">import</span> <span class="n">TextDataset</span><span class="p">,</span> <span class="n">SequentialDataLoader</span><span class="p">,</span> <span class="n">SequentialUnBatchedDataset</span><span class="p">,</span> <span class="n">TextFileDataset</span>
|
|
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml_helpers.device</span> <span class="kn">import</span> <span class="n">DeviceConfigs</span>
|
|
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_helpers.metrics.accuracy</span> <span class="kn">import</span> <span class="n">Accuracy</span>
|
|
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
|
|
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml_helpers.train_valid</span> <span class="kn">import</span> <span class="n">TrainValidConfigs</span><span class="p">,</span> <span class="n">hook_model_outputs</span><span class="p">,</span> <span class="n">BatchIndex</span>
|
|
<span class="lineno">25</span><span class="kn">from</span> <span class="nn">labml_nn.optimizers.configs</span> <span class="kn">import</span> <span class="n">OptimizerConfigs</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-1'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-1'>#</a>
|
|
</div>
|
|
<h3>Cross entropy loss</h3>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">28</span><span class="k">class</span> <span class="nc">CrossEntropyLoss</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-2'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-2'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">33</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="lineno">34</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
|
<span class="lineno">35</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-3'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-3'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">37</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">outputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
|
|
<span class="lineno">38</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss</span><span class="p">(</span><span class="n">outputs</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">outputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]),</span> <span class="n">targets</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-4'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-4'>#</a>
|
|
</div>
|
|
<p> <a id="NLPAutoRegressionConfigs"></a></p>
|
|
<h2>Trainer configurations</h2>
|
|
<p>This has the basic configurations for NLP auto-regressive task training. All the properties are configurable.</p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">41</span><span class="k">class</span> <span class="nc">NLPAutoRegressionConfigs</span><span class="p">(</span><span class="n">TrainValidConfigs</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-5'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-5'>#</a>
|
|
</div>
|
|
<p>Optimizer </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">52</span> <span class="n">optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-6'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-6'>#</a>
|
|
</div>
|
|
<p>Training device </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">54</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">DeviceConfigs</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-7'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-7'>#</a>
|
|
</div>
|
|
<p>Autoregressive model </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">57</span> <span class="n">model</span><span class="p">:</span> <span class="n">Module</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-8'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-8'>#</a>
|
|
</div>
|
|
<p>Text dataset </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">59</span> <span class="n">text</span><span class="p">:</span> <span class="n">TextDataset</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-9'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-9'>#</a>
|
|
</div>
|
|
<p>Batch size </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">61</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">16</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-10'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-10'>#</a>
|
|
</div>
|
|
<p>Length of the sequence, or context size </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">63</span> <span class="n">seq_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">512</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-11'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-11'>#</a>
|
|
</div>
|
|
<p>Number of tokens in the vocabulary </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">65</span> <span class="n">n_tokens</span><span class="p">:</span> <span class="nb">int</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-12'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-12'>#</a>
|
|
</div>
|
|
<p>Tokenizer </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">67</span> <span class="n">tokenizer</span><span class="p">:</span> <span class="n">Callable</span> <span class="o">=</span> <span class="s1">'character'</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-13'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-13'>#</a>
|
|
</div>
|
|
<p>Text prompt to start sampling (for illustration) </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">prompt</span><span class="p">:</span> <span class="nb">str</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-14'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-14'>#</a>
|
|
</div>
|
|
<p>The token separator when sampling (blank for character-level tokenization) </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">72</span> <span class="n">prompt_separator</span><span class="p">:</span> <span class="nb">str</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-15'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-15'>#</a>
|
|
</div>
|
|
<p>Whether to periodically save models </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">75</span> <span class="n">is_save_models</span> <span class="o">=</span> <span class="kc">True</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-16'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-16'>#</a>
|
|
</div>
|
|
<p>Loss function </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">78</span> <span class="n">loss_func</span> <span class="o">=</span> <span class="n">CrossEntropyLoss</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-17'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-17'>#</a>
|
|
</div>
|
|
<p>Accuracy function </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">80</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">Accuracy</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-18'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-18'>#</a>
|
|
</div>
|
|
<p>Model embedding size </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">82</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">512</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-19'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-19'>#</a>
|
|
</div>
|
|
<p>Gradient clipping </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">84</span> <span class="n">grad_norm_clip</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-20'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-20'>#</a>
|
|
</div>
|
|
<p>Training data loader </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">87</span> <span class="n">train_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="s1">'shuffled_train_loader'</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-21'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-21'>#</a>
|
|
</div>
|
|
<p>Validation data loader </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">89</span> <span class="n">valid_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="s1">'shuffled_valid_loader'</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-22'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-22'>#</a>
|
|
</div>
|
|
<p>Whether the data loaders shuffle with replacement </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">92</span> <span class="n">dataloader_shuffle_with_replacement</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-23'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-23'>#</a>
|
|
</div>
|
|
<p>Whether to log model parameters and gradients (once per epoch). These are summarized as per-layer stats, but they can still produce many indicators for very deep networks. </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">97</span> <span class="n">is_log_model_params_grads</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-24'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-24'>#</a>
|
|
</div>
|
|
<p>Whether to log model activations (once per epoch). These are summarized as per-layer stats, but they can still produce many indicators for very deep networks. </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">102</span> <span class="n">is_log_model_activations</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-25'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-25'>#</a>
|
|
</div>
|
|
<h3>Initialization</h3>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">104</span> <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-26'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-26'>#</a>
|
|
</div>
|
|
<p>Set tracker configurations </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">109</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s2">"accuracy.*"</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
|
|
<span class="lineno">110</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_scalar</span><span class="p">(</span><span class="s2">"loss.*"</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
|
|
<span class="lineno">111</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_text</span><span class="p">(</span><span class="s2">"sampled"</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-27'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-27'>#</a>
|
|
</div>
|
|
<p>Add a hook to log module outputs </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">113</span> <span class="n">hook_model_outputs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">,</span> <span class="s1">'model'</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-28'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-28'>#</a>
|
|
</div>
|
|
<p>Add accuracy as a state module. The name is a bit confusing: state modules are meant to hold state (such as RNN states) separately for training and validation. Registering accuracy here keeps its metric stats separate for training and validation. </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">118</span> <span class="bp">self</span><span class="o">.</span><span class="n">state_modules</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">accuracy</span><span class="p">]</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-29'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-29'>#</a>
|
|
</div>
|
|
<p>Override to calculate and log other metrics </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">120</span> <span class="k">def</span> <span class="nf">other_metrics</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-30'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-30'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">122</span> <span class="k">pass</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-31'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-31'>#</a>
|
|
</div>
|
|
<h3>Training or validation step</h3>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">124</span> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="nb">any</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="n">BatchIndex</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-32'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-32'>#</a>
|
|
</div>
|
|
<p>Set training/eval mode </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">130</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-33'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-33'>#</a>
|
|
</div>
|
|
<p>Move data to the device </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">133</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="n">batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-34'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-34'>#</a>
|
|
</div>
|
|
<p>Update global step (number of tokens processed) when in training mode </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">136</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span>
|
|
<span class="lineno">137</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-35'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-35'>#</a>
|
|
</div>
|
|
<p>Capture model outputs (activations) only on the last batch, and only if activation logging is enabled </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">140</span> <span class="k">with</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">is_log_activations</span><span class="o">=</span><span class="n">batch_idx</span><span class="o">.</span><span class="n">is_last</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_log_model_activations</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-36'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-36'>#</a>
|
|
</div>
|
|
<p>Get model outputs. The model returns a tuple so that RNNs can also return their states; only the first element (the output) is used here, because state handling is not implemented yet. 😜 </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">144</span> <span class="n">output</span><span class="p">,</span> <span class="o">*</span><span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-37'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-37'>#</a>
|
|
</div>
|
|
<p>Calculate and log loss </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">147</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
|
|
<span class="lineno">148</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s2">"loss."</span><span class="p">,</span> <span class="n">loss</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-38'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-38'>#</a>
|
|
</div>
|
|
<p>Calculate and log accuracy </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">151</span> <span class="bp">self</span><span class="o">.</span><span class="n">accuracy</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
|
|
<span class="lineno">152</span> <span class="bp">self</span><span class="o">.</span><span class="n">accuracy</span><span class="o">.</span><span class="n">track</span><span class="p">()</span>
|
|
<span class="lineno">153</span>
|
|
<span class="lineno">154</span> <span class="bp">self</span><span class="o">.</span><span class="n">other_metrics</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-39'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-39'>#</a>
|
|
</div>
|
|
<p>Train the model </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">157</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="o">.</span><span class="n">is_train</span><span class="p">:</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-40'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-40'>#</a>
|
|
</div>
|
|
<p>Calculate gradients </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">159</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-41'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-41'>#</a>
|
|
</div>
|
|
<p>Clip gradients </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">161</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_grad_norm_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">max_norm</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">grad_norm_clip</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-42'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-42'>#</a>
|
|
</div>
|
|
<p>Take optimizer step </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">163</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-43'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-43'>#</a>
|
|
</div>
|
|
<p>Log the model parameters and gradients on the last batch of every epoch </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">165</span> <span class="k">if</span> <span class="n">batch_idx</span><span class="o">.</span><span class="n">is_last</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_log_model_params_grads</span><span class="p">:</span>
|
|
<span class="lineno">166</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="s1">'model'</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-44'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-44'>#</a>
|
|
</div>
|
|
<p>Clear the gradients </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">168</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-45'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-45'>#</a>
|
|
</div>
|
|
<p>Save the tracked metrics </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">171</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-46'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-46'>#</a>
|
|
</div>
|
|
<h3>Sampling function to generate samples periodically while training</h3>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">173</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-47'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-47'>#</a>
|
|
</div>
|
|
<p>Starting prompt </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">179</span> <span class="n">prompt</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prompt</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-48'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-48'>#</a>
|
|
</div>
|
|
<p>Collect output for printing </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">181</span> <span class="n">log</span> <span class="o">=</span> <span class="p">[(</span><span class="n">prompt</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">subtle</span><span class="p">)]</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-49'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-49'>#</a>
|
|
</div>
|
|
<p>Sample 25 tokens </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">183</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">iterate</span><span class="p">(</span><span class="s1">'Sample'</span><span class="p">,</span> <span class="mi">25</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-50'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-50'>#</a>
|
|
</div>
|
|
<p>Tokenize the prompt </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">185</span> <span class="n">data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">text_to_i</span><span class="p">(</span><span class="n">prompt</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="lineno">186</span> <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-51'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-51'>#</a>
|
|
</div>
|
|
<p>Get the model output </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">188</span> <span class="n">output</span><span class="p">,</span> <span class="o">*</span><span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-52'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-52'>#</a>
|
|
</div>
|
|
<p>Get the model prediction (greedy) </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">190</span> <span class="n">output</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-53'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-53'>#</a>
|
|
</div>
|
|
<p>Add the prediction to the prompt </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">192</span> <span class="n">prompt</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prompt_separator</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">output</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]]</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-54'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-54'>#</a>
|
|
</div>
|
|
<p>Add the prediction for logging </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">194</span> <span class="n">log</span> <span class="o">+=</span> <span class="p">[(</span><span class="bp">self</span><span class="o">.</span><span class="n">prompt_separator</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">itos</span><span class="p">[</span><span class="n">output</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span> <span class="n">Text</span><span class="o">.</span><span class="n">value</span><span class="p">)]</span>
|
|
<span class="lineno">195</span>
|
|
<span class="lineno">196</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">({</span><span class="s1">'sampled'</span><span class="p">:</span> <span class="n">prompt</span><span class="p">})</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-55'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-55'>#</a>
|
|
</div>
|
|
<p>Print the sampled output </p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">198</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">log</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-56'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-56'>#</a>
|
|
</div>
|
|
<h3>Default <a href="../optimizers/configs.html">optimizer configurations</a></h3>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">201</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="o">.</span><span class="n">optimizer</span><span class="p">)</span>
|
|
<span class="lineno">202</span><span class="k">def</span> <span class="nf">_optimizer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-57'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-57'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">207</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">OptimizerConfigs</span><span class="p">()</span>
|
|
<span class="lineno">208</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">parameters</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">()</span>
|
|
<span class="lineno">209</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="s1">'Adam'</span>
|
|
<span class="lineno">210</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">d_model</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">d_model</span>
|
|
<span class="lineno">211</span>
|
|
<span class="lineno">212</span> <span class="k">return</span> <span class="n">optimizer</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-58'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-58'>#</a>
|
|
</div>
|
|
<p> Get number of tokens</p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">215</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="o">.</span><span class="n">n_tokens</span><span class="p">)</span>
|
|
<span class="lineno">216</span><span class="k">def</span> <span class="nf">_n_tokens</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-59'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-59'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">220</span> <span class="k">return</span> <span class="n">c</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">n_tokens</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-60'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-60'>#</a>
|
|
</div>
|
|
<h3>Basic English tokenizer</h3>
|
|
<p>We use a character-level tokenizer in this experiment. You can switch to the basic English tokenizer by setting,</p>
|
|
<pre class="highlight lang-"><code><span></span><span class="s1">'tokenizer'</span><span class="p">:</span> <span class="s1">'basic_english'</span><span class="p">,</span></code></pre>
|
|
<p>in the configurations dictionary when starting the experiment.</p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">223</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="o">.</span><span class="n">tokenizer</span><span class="p">)</span>
|
|
<span class="lineno">224</span><span class="k">def</span> <span class="nf">basic_english</span><span class="p">():</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-61'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-61'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">238</span> <span class="kn">from</span> <span class="nn">torchtext.data</span> <span class="kn">import</span> <span class="n">get_tokenizer</span>
|
|
<span class="lineno">239</span> <span class="k">return</span> <span class="n">get_tokenizer</span><span class="p">(</span><span class="s1">'basic_english'</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-62'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-62'>#</a>
|
|
</div>
|
|
<h3>Character level tokenizer</h3>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">242</span><span class="k">def</span> <span class="nf">character_tokenizer</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-63'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-63'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">246</span> <span class="k">return</span> <span class="nb">list</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-64'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-64'>#</a>
|
|
</div>
|
|
<h3>Character level tokenizer configuration</h3>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">249</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="o">.</span><span class="n">tokenizer</span><span class="p">)</span>
|
|
<span class="lineno">250</span><span class="k">def</span> <span class="nf">character</span><span class="p">():</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-65'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-65'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">254</span> <span class="k">return</span> <span class="n">character_tokenizer</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-66'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-66'>#</a>
|
|
</div>
|
|
<h3>Tiny Shakespeare dataset</h3>
|
|
<p>It will be downloaded from the URL if it is not present.</p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">257</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="o">.</span><span class="n">text</span><span class="p">)</span>
|
|
<span class="lineno">258</span><span class="k">def</span> <span class="nf">tiny_shakespeare</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-67'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-67'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">264</span> <span class="k">return</span> <span class="n">TextFileDataset</span><span class="p">(</span>
|
|
<span class="lineno">265</span> <span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">'tiny_shakespeare.txt'</span><span class="p">,</span>
|
|
<span class="lineno">266</span> <span class="n">c</span><span class="o">.</span><span class="n">tokenizer</span><span class="p">,</span>
|
|
<span class="lineno">267</span> <span class="n">url</span><span class="o">=</span><span class="s1">'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-68'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-68'>#</a>
|
|
</div>
|
|
<h3>Sequential training data loader</h3>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">270</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="o">.</span><span class="n">train_loader</span><span class="p">)</span>
|
|
<span class="lineno">271</span><span class="k">def</span> <span class="nf">sequential_train_loader</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-69'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-69'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">275</span> <span class="k">return</span> <span class="n">SequentialDataLoader</span><span class="p">(</span><span class="n">text</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">train</span><span class="p">,</span>
|
|
<span class="lineno">276</span> <span class="n">dataset</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">text</span><span class="p">,</span>
|
|
<span class="lineno">277</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="lineno">278</span> <span class="n">seq_len</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">seq_len</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-70'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-70'>#</a>
|
|
</div>
|
|
<h3>Sequential validation data loader</h3>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">281</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="o">.</span><span class="n">valid_loader</span><span class="p">)</span>
|
|
<span class="lineno">282</span><span class="k">def</span> <span class="nf">sequential_valid_loader</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-71'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-71'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">286</span> <span class="k">return</span> <span class="n">SequentialDataLoader</span><span class="p">(</span><span class="n">text</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">valid</span><span class="p">,</span>
|
|
<span class="lineno">287</span> <span class="n">dataset</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">text</span><span class="p">,</span>
|
|
<span class="lineno">288</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="lineno">289</span> <span class="n">seq_len</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">seq_len</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-72'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-72'>#</a>
|
|
</div>
|
|
<h3>Transpose batch</h3>
|
|
<p><code class="highlight"><span></span><span class="n">DataLoader</span></code>
|
|
stacks the samples of each batch along the first dimension. We need to transpose them so that the sequence dimension comes first.</p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">292</span><span class="k">def</span> <span class="nf">transpose_batch</span><span class="p">(</span><span class="n">batch</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-73'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-73'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">300</span> <span class="n">transposed_data</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">batch</span><span class="p">))</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-74'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-74'>#</a>
|
|
</div>
|
|
<p>Stack the batch along the second dimension <code class="highlight"><span></span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span></code>
|
|
</p>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">302</span> <span class="n">src</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">transposed_data</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="lineno">303</span> <span class="n">tgt</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">transposed_data</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="lineno">304</span>
|
|
<span class="lineno">305</span> <span class="k">return</span> <span class="n">src</span><span class="p">,</span> <span class="n">tgt</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-75'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-75'>#</a>
|
|
</div>
|
|
<h3>Shuffled training data loader</h3>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">308</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="o">.</span><span class="n">train_loader</span><span class="p">)</span>
|
|
<span class="lineno">309</span><span class="k">def</span> <span class="nf">shuffled_train_loader</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-76'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-76'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">313</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">SequentialUnBatchedDataset</span><span class="p">(</span><span class="n">text</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">train</span><span class="p">,</span>
|
|
<span class="lineno">314</span> <span class="n">dataset</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">text</span><span class="p">,</span>
|
|
<span class="lineno">315</span> <span class="n">seq_len</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">seq_len</span><span class="p">)</span>
|
|
<span class="lineno">316</span> <span class="n">sampler</span> <span class="o">=</span> <span class="n">RandomSampler</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">replacement</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dataloader_shuffle_with_replacement</span><span class="p">)</span>
|
|
<span class="lineno">317</span>
|
|
<span class="lineno">318</span> <span class="k">return</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span>
|
|
<span class="lineno">319</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="lineno">320</span> <span class="n">collate_fn</span><span class="o">=</span><span class="n">transpose_batch</span><span class="p">,</span>
|
|
<span class="lineno">321</span> <span class="n">sampler</span><span class="o">=</span><span class="n">sampler</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-77'>
|
|
<div class='docs doc-strings'>
|
|
<div class='section-link'>
|
|
<a href='#section-77'>#</a>
|
|
</div>
|
|
<h3>Shuffled validation data loader</h3>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">324</span><span class="nd">@option</span><span class="p">(</span><span class="n">NLPAutoRegressionConfigs</span><span class="o">.</span><span class="n">valid_loader</span><span class="p">)</span>
|
|
<span class="lineno">325</span><span class="k">def</span> <span class="nf">shuffled_valid_loader</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">NLPAutoRegressionConfigs</span><span class="p">):</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='section' id='section-78'>
|
|
<div class='docs'>
|
|
<div class='section-link'>
|
|
<a href='#section-78'>#</a>
|
|
</div>
|
|
|
|
</div>
|
|
<div class='code'>
|
|
<div class="highlight"><pre><span class="lineno">329</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">SequentialUnBatchedDataset</span><span class="p">(</span><span class="n">text</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">valid</span><span class="p">,</span>
|
|
<span class="lineno">330</span> <span class="n">dataset</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">text</span><span class="p">,</span>
|
|
<span class="lineno">331</span> <span class="n">seq_len</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">seq_len</span><span class="p">)</span>
|
|
<span class="lineno">332</span> <span class="n">sampler</span> <span class="o">=</span> <span class="n">RandomSampler</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">replacement</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">dataloader_shuffle_with_replacement</span><span class="p">)</span>
|
|
<span class="lineno">333</span>
|
|
<span class="lineno">334</span> <span class="k">return</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span>
|
|
<span class="lineno">335</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="lineno">336</span> <span class="n">collate_fn</span><span class="o">=</span><span class="n">transpose_batch</span><span class="p">,</span>
|
|
<span class="lineno">337</span> <span class="n">sampler</span><span class="o">=</span><span class="n">sampler</span><span class="p">)</span></pre></div>
|
|
</div>
|
|
</div>
|
|
<div class='footer'>
|
|
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
|
<a href="https://labml.ai">labml.ai</a>
|
|
</div>
|
|
</div>
|
|
<script src="../interactive.js?v=1"></script>
|
|
<script>
|
|
function handleImages() {
    // Every <img> that is a direct child of a <p> gets the click-to-zoom
    // behaviour installed by handleImage.
    var paragraphImages = document.querySelectorAll('p>img')
    Array.prototype.forEach.call(paragraphImages, function (image) {
        handleImage(image)
    })
}
|
|
|
|
function handleImage(img) {
    // Center the image inside its parent paragraph.
    img.parentElement.style.textAlign = 'center'

    // Build the (initially detached) modal: a wrapper div holding the
    // enlarged image plus a close control.
    var modal = document.createElement('div')
    modal.id = 'modal'

    var modalContent = document.createElement('div')
    modal.appendChild(modalContent)

    var modalImage = document.createElement('img')
    modalContent.appendChild(modalImage)

    // The close control stays a <span class="close"> so existing page CSS
    // still applies, but it is exposed as a keyboard-operable button.
    var closeButton = document.createElement('span')
    closeButton.classList.add('close')
    closeButton.textContent = 'x'
    closeButton.setAttribute('role', 'button')
    closeButton.setAttribute('tabindex', '0')
    closeButton.setAttribute('aria-label', 'Close')
    modal.appendChild(closeButton)

    // Detach the modal (if it is attached) and stop listening for Escape.
    function closeModal() {
        if (modal.parentNode) {
            modal.parentNode.removeChild(modal)
        }
        document.removeEventListener('keydown', onDocumentKeyDown)
    }

    // Escape anywhere in the document dismisses the open modal.
    function onDocumentKeyDown(event) {
        if (event.key === 'Escape') {
            closeModal()
        }
    }

    img.addEventListener('click', function () {
        modalImage.src = img.src
        // Carry the original description over so the enlarged copy is not
        // an unlabeled image for screen readers.
        modalImage.alt = img.alt
        document.body.appendChild(modal)
        // Same function reference, so repeated opens do not stack listeners.
        document.addEventListener('keydown', onDocumentKeyDown)
        closeButton.focus()
    })

    closeButton.addEventListener('click', closeModal)
    // Enter/Space activate the close control, matching native button behaviour.
    closeButton.addEventListener('keydown', function (event) {
        if (event.key === 'Enter' || event.key === ' ') {
            event.preventDefault()
            closeModal()
        }
    })
}
|
|
|
|
// Runs immediately: this script sits at the end of <body>, so the paragraph images above it are already parsed.
handleImages()
|
|
</script>
|
|
</body>
|
|
</html> |