mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-26 08:41:23 +08:00
Zero3 memory optimizations (#140)
This commit is contained in:
400
docs/neox/utils/cache.html
Normal file
400
docs/neox/utils/cache.html
Normal file
@ -0,0 +1,400 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content="Cache for intermediate activations for faster inference."/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="Cache for Intermediate Activations"/>
|
||||
<meta name="twitter:description" content="Cache for intermediate activations for faster inference."/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/neox/utils/cache.html"/>
|
||||
<meta property="og:title" content="Cache for Intermediate Activations"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="Cache for Intermediate Activations"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="Cache for Intermediate Activations"/>
|
||||
<meta property="og:description" content="Cache for intermediate activations for faster inference."/>
|
||||
|
||||
<title>Cache for Intermediate Activations</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../pylit.css?v=1">
|
||||
<link rel="canonical" href="https://nn.labml.ai/neox/utils/cache.html"/>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
||||
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
// Append this call's `arguments` object to the global `dataLayer` array
// (created above if it did not already exist).
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../index.html">neox</a>
|
||||
<a class="parent" href="index.html">utils</a>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://github.com/sponsors/labmlai" target="_blank">
|
||||
<img alt="Sponsor"
|
||||
src="https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/neox/utils/cache.py" target="_blank">
|
||||
View code on Github</a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>Cache for Intermediate Activations</h1>
|
||||
<p>During inference the model outputs token by token. We use this simple cache to store the keys and values of the attention layers, so that we don't have to recompute them for previous tokens.</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">15</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<h2>Cache</h2>
|
||||
<p>This maintains a key-value cache and queues; values pushed to a queue are popped in the same order. The queues are useful since we have multiple attention layers.</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">18</span><span class="k">class</span> <span class="nc">Cache</span><span class="p">:</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">26</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="lineno">27</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache</span> <span class="o">=</span> <span class="p">{}</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
<h3>Clear cache</h3>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">29</span> <span class="k">def</span> <span class="nf">clear_all</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">33</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache</span> <span class="o">=</span> <span class="p">{}</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
<h3>Push a value to a queue</h3>
|
||||
<ul><li><code class="highlight"><span></span><span class="n">name</span></code>
|
||||
is the name of the queue </li>
|
||||
<li><code class="highlight"><span></span><span class="n">value</span></code>
|
||||
is the value to be pushed</li></ul>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">35</span> <span class="k">def</span> <span class="nf">push</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Any</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
<p>Create an empty queue if it's not present </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">44</span> <span class="k">if</span> <span class="n">name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache</span><span class="p">:</span>
|
||||
<span class="lineno">45</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
<p>Push to the queue </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">48</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">value</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-8'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-8'>#</a>
|
||||
</div>
|
||||
<h3>Return the size of the queue</h3>
|
||||
<ul><li><code class="highlight"><span></span><span class="n">name</span></code>
|
||||
is the name of the queue </li>
|
||||
<li><em>Returns</em> the size of the queue if it exists, else None</li></ul>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">50</span> <span class="k">def</span> <span class="nf">q_size</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-9'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-9'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">58</span> <span class="k">if</span> <span class="n">name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache</span><span class="p">:</span>
|
||||
<span class="lineno">59</span> <span class="k">return</span> <span class="kc">None</span>
|
||||
<span class="lineno">60</span>
|
||||
<span class="lineno">61</span> <span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_cache</span><span class="p">[</span><span class="n">name</span><span class="p">])</span> <span class="o">!=</span> <span class="nb">list</span><span class="p">:</span>
|
||||
<span class="lineno">62</span> <span class="k">return</span> <span class="kc">None</span>
|
||||
<span class="lineno">63</span>
|
||||
<span class="lineno">64</span> <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_cache</span><span class="p">[</span><span class="n">name</span><span class="p">])</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-10'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
<h3>Pop from a queue</h3>
|
||||
<ul><li><code class="highlight"><span></span><span class="n">name</span></code>
|
||||
is the name of the queue </li>
|
||||
<li><em>Returns</em> the value</li></ul>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">66</span> <span class="k">def</span> <span class="nf">pop</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-11'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">73</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-12'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
<h3>Cache a value</h3>
|
||||
<ul><li><code class="highlight"><span></span><span class="n">key</span></code>
|
||||
is the name of the value to be cached </li>
|
||||
<li><code class="highlight"><span></span><span class="n">value</span></code>
|
||||
is the value</li></ul>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">75</span> <span class="k">def</span> <span class="nf">set</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Any</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-13'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-13'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">82</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-14'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-14'>#</a>
|
||||
</div>
|
||||
<h3>Retrieve a value from cache</h3>
|
||||
<ul><li><code class="highlight"><span></span><span class="n">key</span></code>
|
||||
is the name used when caching </li>
|
||||
<li><code class="highlight"><span></span><span class="n">default</span></code>
|
||||
is the default value returned if the key is not in the cache </li>
|
||||
<li><em>Returns</em> the cached value</li></ul>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">84</span> <span class="k">def</span> <span class="nf">get</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">default</span><span class="p">:</span> <span class="n">Any</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-15'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-15'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">92</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">default</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-16'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-16'>#</a>
|
||||
</div>
|
||||
<h3>Clear a cache value</h3>
|
||||
<ul><li><code class="highlight"><span></span><span class="n">key</span></code>
|
||||
is the name used when caching</li></ul>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">94</span> <span class="k">def</span> <span class="nf">clear</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-17'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-17'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">100</span> <span class="k">del</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache</span><span class="p">[</span><span class="n">key</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-18'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-18'>#</a>
|
||||
</div>
|
||||
<p>Singleton for cache </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">104</span><span class="n">_INSTANCE</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-19'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-19'>#</a>
|
||||
</div>
|
||||
<h3>Get the cache instance</h3>
|
||||
<ul><li><em>Returns</em> the cache instance</li></ul>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">107</span><span class="k">def</span> <span class="nf">get_cache</span><span class="p">()</span> <span class="o">-></span> <span class="n">Cache</span><span class="p">:</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-20'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-20'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">113</span> <span class="k">global</span> <span class="n">_INSTANCE</span>
|
||||
<span class="lineno">114</span>
|
||||
<span class="lineno">115</span> <span class="k">if</span> <span class="n">_INSTANCE</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="lineno">116</span> <span class="n">_INSTANCE</span> <span class="o">=</span> <span class="n">Cache</span><span class="p">()</span>
|
||||
<span class="lineno">117</span>
|
||||
<span class="lineno">118</span> <span class="k">return</span> <span class="n">_INSTANCE</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="../../interactive.js?v=1"></script>
|
||||
<script>
|
||||
// Run handleImage() on every <img> that is a direct child of a <p>.
function handleImages() {
|
||||
// 'p>img' is a child combinator: images nested deeper (e.g. inside a link) are not matched.
var images = document.querySelectorAll('p>img')
|
||||
|
||||
for (var i = 0; i < images.length; ++i) {
|
||||
handleImage(images[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Centre the image's parent element and wire up a click-to-open modal:
// clicking `img` appends the modal to <body> with a copy of the image;
// clicking the modal's close <span> removes it again.
function handleImage(img) {
|
||||
img.parentElement.style.textAlign = 'center'
|
||||
|
||||
// Build the modal once, detached from the document; it is attached only on click.
var modal = document.createElement('div')
|
||||
// NOTE(review): every image's modal is given the same id 'modal'.
modal.id = 'modal'
|
||||
|
||||
var modalContent = document.createElement('div')
|
||||
modal.appendChild(modalContent)
|
||||
|
||||
var modalImage = document.createElement('img')
|
||||
modalContent.appendChild(modalImage)
|
||||
|
||||
// Close control, shown as a literal 'x'.
var span = document.createElement('span')
|
||||
span.classList.add('close')
|
||||
span.textContent = 'x'
|
||||
modal.appendChild(span)
|
||||
|
||||
// Open: attach the modal and point its image at the clicked image's source.
img.onclick = function () {
|
||||
console.log('clicked')
|
||||
document.body.appendChild(modal)
|
||||
modalImage.src = img.src
|
||||
}
|
||||
|
||||
// Close: detach the modal from <body>.
span.onclick = function () {
|
||||
document.body.removeChild(modal)
|
||||
}
|
||||
}
|
||||
|
||||
handleImages()
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
297
docs/neox/utils/finetune.html
Normal file
297
docs/neox/utils/finetune.html
Normal file
@ -0,0 +1,297 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content=""/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="finetune.py"/>
|
||||
<meta name="twitter:description" content=""/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/neox/utils/finetune.html"/>
|
||||
<meta property="og:title" content="finetune.py"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="finetune.py"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="finetune.py"/>
|
||||
<meta property="og:description" content=""/>
|
||||
|
||||
<title>finetune.py</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../pylit.css?v=1">
|
||||
<link rel="canonical" href="https://nn.labml.ai/neox/utils/finetune.html"/>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
||||
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
// Append this call's `arguments` object to the global `dataLayer` array
// (created above if it did not already exist).
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../index.html">neox</a>
|
||||
<a class="parent" href="index.html">utils</a>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://github.com/sponsors/labmlai" target="_blank">
|
||||
<img alt="Sponsor"
|
||||
src="https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/neox/utils/finetune.py" target="_blank">
|
||||
View code on Github</a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">1</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Dict</span>
|
||||
<span class="lineno">2</span>
|
||||
<span class="lineno">3</span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="lineno">4</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
|
||||
<span class="lineno">5</span>
|
||||
<span class="lineno">6</span><span class="kn">from</span> <span class="nn">labml_nn.neox.model</span> <span class="kn">import</span> <span class="n">TransformerLayer</span><span class="p">,</span> <span class="n">NeoXModule</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">9</span><span class="k">class</span> <span class="nc">FineTuner</span><span class="p">:</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">10</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">NeoXModule</span><span class="p">]):</span>
|
||||
<span class="lineno">11</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">layers</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">13</span> <span class="k">def</span> <span class="nf">get_trainable_params</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">]:</span>
|
||||
<span class="lineno">14</span> <span class="n">params</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="lineno">15</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">):</span>
|
||||
<span class="lineno">16</span> <span class="n">params</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_layer_trainable_params</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">prefix</span><span class="o">=</span><span class="sa">f</span><span class="s1">'layer_</span><span class="si">{</span><span class="n">i</span> <span class="si">:</span><span class="s1">02d</span><span class="si">}</span><span class="s1">'</span><span class="p">))</span>
|
||||
<span class="lineno">17</span>
|
||||
<span class="lineno">18</span> <span class="k">return</span> <span class="n">params</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">20</span> <span class="k">def</span> <span class="nf">get_layer_trainable_params</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">NeoXModule</span><span class="p">,</span> <span class="n">prefix</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">]:</span>
|
||||
<span class="lineno">21</span> <span class="k">raise</span> <span class="ne">NotImplementedError</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">23</span> <span class="k">def</span> <span class="nf">set_trainable_params</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="lineno">24</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
<p>Set <code class="highlight"><span></span><span class="n">requires_grad</span></code>
|
||||
to <code class="highlight"><span></span><span class="kc">False</span></code>
|
||||
for the entire layer. </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">26</span> <span class="n">layer</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">False</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
<p> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">28</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_trainable_params</span><span class="p">()</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
|
||||
<span class="lineno">29</span> <span class="n">p</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-8'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-8'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">31</span> <span class="k">def</span> <span class="nf">state_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="lineno">32</span> <span class="k">return</span> <span class="p">{</span><span class="n">n</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span> <span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_trainable_params</span><span class="p">()</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-9'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-9'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">34</span> <span class="k">def</span> <span class="nf">load_state_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span>
|
||||
<span class="lineno">35</span> <span class="n">params</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_trainable_params</span><span class="p">()</span>
|
||||
<span class="lineno">36</span> <span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
|
||||
<span class="lineno">37</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">state_dict</span><span class="p">[</span><span class="n">n</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||||
<span class="lineno">38</span>
|
||||
<span class="lineno">39</span> <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="n">state_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
|
||||
<span class="lineno">40</span> <span class="k">assert</span> <span class="n">n</span> <span class="ow">in</span> <span class="n">params</span><span class="p">,</span> <span class="n">n</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-10'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">43</span><span class="k">class</span> <span class="nc">FineTuneBiases</span><span class="p">(</span><span class="n">FineTuner</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-11'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">44</span> <span class="k">def</span> <span class="nf">get_layer_trainable_params</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">NeoXModule</span><span class="p">,</span> <span class="n">prefix</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">]:</span>
|
||||
<span class="lineno">45</span> <span class="n">params</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="lineno">46</span>
|
||||
<span class="lineno">47</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">TransformerLayer</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-12'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
<p>No need to train the MLP output bias because it is added to the attention output, whose bias is already trained </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">49</span> <span class="n">params</span><span class="p">[</span><span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">prefix</span><span class="si">}</span><span class="s1">.attention.output.bias'</span><span class="p">]</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">attention</span><span class="o">.</span><span class="n">output</span><span class="o">.</span><span class="n">bias</span>
|
||||
<span class="lineno">50</span> <span class="n">params</span><span class="p">[</span><span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">prefix</span><span class="si">}</span><span class="s1">.attention.qkv_lin.bias'</span><span class="p">]</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">attention</span><span class="o">.</span><span class="n">qkv_lin</span><span class="o">.</span><span class="n">bias</span>
|
||||
<span class="lineno">51</span> <span class="n">params</span><span class="p">[</span><span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">prefix</span><span class="si">}</span><span class="s1">.ffn.dense_h_h4.bias'</span><span class="p">]</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">dense_h_h4</span><span class="o">.</span><span class="n">bias</span>
|
||||
<span class="lineno">52</span> <span class="k">else</span><span class="p">:</span>
|
||||
<span class="lineno">53</span> <span class="k">pass</span>
|
||||
<span class="lineno">54</span>
|
||||
<span class="lineno">55</span> <span class="k">return</span> <span class="n">params</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="../../interactive.js?v=1"></script>
|
||||
<script>
|
||||
// Run handleImage on every img element that is a direct child of a p element.
function handleImages() {
|
||||
var images = document.querySelectorAll('p>img')
|
||||
|
||||
for (var i = 0; i < images.length; ++i) {
|
||||
handleImage(images[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Centre the image inside its parent and wire up a modal for it.
// The modal is built detached from the document: a wrapper div holding a
// content div with an img, plus a span with class 'close' as the close control.
function handleImage(img) {
|
||||
img.parentElement.style.textAlign = 'center'
|
||||
|
||||
var modal = document.createElement('div')
|
||||
// NOTE(review): each call creates its own modal element, all with the id 'modal'.
modal.id = 'modal'
|
||||
|
||||
var modalContent = document.createElement('div')
|
||||
modal.appendChild(modalContent)
|
||||
|
||||
var modalImage = document.createElement('img')
|
||||
modalContent.appendChild(modalImage)
|
||||
|
||||
var span = document.createElement('span')
|
||||
span.classList.add('close')
|
||||
span.textContent = 'x'
|
||||
modal.appendChild(span)
|
||||
|
||||
// Clicking the image attaches the modal to the document body and points the
// modal image at the clicked image's source.
img.onclick = function () {
|
||||
// NOTE(review): looks like a leftover debug log; confirm before removing.
console.log('clicked')
|
||||
document.body.appendChild(modal)
|
||||
modalImage.src = img.src
|
||||
}
|
||||
|
||||
// Clicking the close control detaches the modal from the document body.
span.onclick = function () {
|
||||
document.body.removeChild(modal)
|
||||
}
|
||||
}
|
||||
|
||||
handleImages()
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
391
docs/neox/utils/index.html
Normal file
391
docs/neox/utils/index.html
Normal file
@ -0,0 +1,391 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content="Utilities and helper functions"/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="Utilities and Helpers"/>
|
||||
<meta name="twitter:description" content="Utilities and helper functions"/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/neox/utils/index.html"/>
|
||||
<meta property="og:title" content="Utilities and Helpers"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="Utilities and Helpers"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="Utilities and Helpers"/>
|
||||
<meta property="og:description" content="Utilities and helper functions"/>
|
||||
|
||||
<title>Utilities and Helpers</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../pylit.css?v=1">
|
||||
<link rel="canonical" href="https://nn.labml.ai/neox/utils/index.html"/>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
||||
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
// Push this call's arguments object onto the global dataLayer array.
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../index.html">neox</a>
|
||||
<a class="parent" href="index.html">utils</a>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://github.com/sponsors/labmlai" target="_blank">
|
||||
<img alt="Sponsor"
|
||||
src="https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/neox/utils/__init__.py" target="_blank">
|
||||
View code on Github</a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>Utilities and Helpers</h1>
|
||||
<ul><li><a href="cache.html">Cache for intermediate activations (for faster inference)</a> </li>
|
||||
<li><a href="finetune.html">Tools for finetuning</a> </li>
|
||||
<li><a href="trainer.html">Trainer</a> </li>
|
||||
<li><a href="text_dataset.html">Text dataset</a></li></ul>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">15</span><span></span><span class="kn">import</span> <span class="nn">typing</span>
|
||||
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span>
|
||||
<span class="lineno">17</span>
|
||||
<span class="lineno">18</span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="lineno">19</span>
|
||||
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">logger</span>
|
||||
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">Text</span>
|
||||
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_nn.neox.tokenizer</span> <span class="kn">import</span> <span class="n">get_tokenizer</span>
|
||||
<span class="lineno">23</span>
|
||||
<span class="lineno">24</span><span class="k">if</span> <span class="n">typing</span><span class="o">.</span><span class="n">TYPE_CHECKING</span><span class="p">:</span>
|
||||
<span class="lineno">25</span> <span class="kn">from</span> <span class="nn">tokenizers</span> <span class="kn">import</span> <span class="n">Tokenizer</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<p>Tokenizer singleton </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">28</span><span class="n">_TOKENIZER</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="s1">'Tokenizer'</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
<h3>Get token ids</h3>
|
||||
<ul><li><code class="highlight"><span></span><span class="n">text</span></code>
|
||||
is the text to tokenize </li>
|
||||
</ul><p><em>Returns</em> the token ids</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">31</span><span class="k">def</span> <span class="nf">get_tokens</span><span class="p">(</span><span class="n">text</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">38</span> <span class="k">global</span> <span class="n">_TOKENIZER</span>
|
||||
<span class="lineno">39</span> <span class="k">if</span> <span class="n">_TOKENIZER</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="lineno">40</span> <span class="n">_TOKENIZER</span> <span class="o">=</span> <span class="n">get_tokenizer</span><span class="p">()</span>
|
||||
<span class="lineno">41</span> <span class="k">return</span> <span class="n">_TOKENIZER</span><span class="o">.</span><span class="n">encode_batch</span><span class="p">([</span><span class="n">text</span><span class="p">])[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">ids</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
<h3>Print tokens from model outputs</h3>
|
||||
<p>Pretty prints target tokens alongside outputs from the model(s).</p>
|
||||
<ul><li><code class="highlight"><span></span><span class="n">ids</span></code>
|
||||
are the target token ids </li>
|
||||
<li><code class="highlight"><span></span><span class="n">xs</span></code>
|
||||
are the model(s) outputs</li></ul>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">44</span><span class="k">def</span> <span class="nf">print_token_outputs</span><span class="p">(</span><span class="n">ids</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="o">*</span><span class="n">xs</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">53</span> <span class="n">ids</span> <span class="o">=</span> <span class="n">ids</span> <span class="o">+</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
|
||||
<span class="lineno">54</span> <span class="n">xs</span> <span class="o">=</span> <span class="p">[[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">xs</span><span class="p">]</span>
|
||||
<span class="lineno">55</span>
|
||||
<span class="lineno">56</span> <span class="n">print_tokens</span><span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="n">xs</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
<h3>Print tokens</h3>
|
||||
<p>Pretty prints tokens for comparison</p>
|
||||
<ul><li><code class="highlight"><span></span><span class="n">target</span></code>
|
||||
are the target token ids </li>
|
||||
<li><code class="highlight"><span></span><span class="n">others</span></code>
|
||||
are the sampled outputs from the model(s)</li></ul>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">59</span><span class="k">def</span> <span class="nf">print_tokens</span><span class="p">(</span><span class="n">target</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">others</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
<p>Load tokenizer </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">70</span> <span class="k">global</span> <span class="n">_TOKENIZER</span>
|
||||
<span class="lineno">71</span> <span class="k">if</span> <span class="n">_TOKENIZER</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="lineno">72</span> <span class="n">_TOKENIZER</span> <span class="o">=</span> <span class="n">get_tokenizer</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-8'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-8'>#</a>
|
||||
</div>
|
||||
<p>Convert the tokens to list of strings </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">75</span> <span class="n">text</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="lineno">76</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">target</span><span class="p">)):</span>
|
||||
<span class="lineno">77</span> <span class="n">tokens</span> <span class="o">=</span> <span class="p">[</span><span class="n">_TOKENIZER</span><span class="o">.</span><span class="n">decode</span><span class="p">([</span><span class="n">target</span><span class="p">[</span><span class="n">i</span><span class="p">]])</span> <span class="k">if</span> <span class="n">target</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">!=</span> <span class="o">-</span><span class="mi">1</span> <span class="k">else</span> <span class="s1">'---'</span><span class="p">]</span>
|
||||
<span class="lineno">78</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">others</span><span class="p">)):</span>
|
||||
<span class="lineno">79</span> <span class="n">tokens</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">_TOKENIZER</span><span class="o">.</span><span class="n">decode</span><span class="p">([</span><span class="n">others</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="n">i</span><span class="p">]])</span> <span class="k">if</span> <span class="n">others</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="n">i</span><span class="p">]</span> <span class="o">!=</span> <span class="o">-</span><span class="mi">1</span> <span class="k">else</span> <span class="s1">'---'</span><span class="p">)</span>
|
||||
<span class="lineno">80</span>
|
||||
<span class="lineno">81</span> <span class="n">text</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">tokens</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-9'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-9'>#</a>
|
||||
</div>
|
||||
<p>Stats </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">84</span> <span class="n">correct</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="n">others</span><span class="p">]</span>
|
||||
<span class="lineno">85</span> <span class="n">total</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-10'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
<p>Iterate through tokens </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">88</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">target</span><span class="p">)):</span>
|
||||
<span class="lineno">89</span> <span class="n">parts</span> <span class="o">=</span> <span class="p">[(</span><span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">: '</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">meta</span><span class="p">)]</span>
|
||||
<span class="lineno">90</span> <span class="n">parts</span> <span class="o">+=</span> <span class="p">[(</span><span class="s1">'"'</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">subtle</span><span class="p">),</span> <span class="p">(</span><span class="n">text</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">],</span> <span class="n">Text</span><span class="o">.</span><span class="n">subtle</span><span class="p">),</span> <span class="p">(</span><span class="s1">'"'</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">subtle</span><span class="p">),</span> <span class="s1">'</span><span class="se">\t</span><span class="s1">'</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-11'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
<p>Empty target </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">93</span> <span class="k">if</span> <span class="n">target</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span>
|
||||
<span class="lineno">94</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">others</span><span class="p">)):</span>
|
||||
<span class="lineno">95</span> <span class="n">parts</span> <span class="o">+=</span> <span class="p">[(</span><span class="s1">'"'</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">subtle</span><span class="p">),</span> <span class="p">(</span><span class="n">text</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span> <span class="n">Text</span><span class="o">.</span><span class="n">subtle</span><span class="p">),</span> <span class="p">(</span><span class="s1">'"'</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">subtle</span><span class="p">),</span> <span class="s1">'</span><span class="se">\t</span><span class="s1">'</span><span class="p">]</span>
|
||||
<span class="lineno">96</span>
|
||||
<span class="lineno">97</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">parts</span><span class="p">)</span>
|
||||
<span class="lineno">98</span> <span class="k">continue</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-12'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
<p>Number of tokens </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">101</span> <span class="n">total</span> <span class="o">+=</span> <span class="mi">1</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-13'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-13'>#</a>
|
||||
</div>
|
||||
<p>Other outputs </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">104</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">others</span><span class="p">)):</span>
|
||||
<span class="lineno">105</span> <span class="n">correct</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">others</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="n">target</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">else</span> <span class="mi">0</span>
|
||||
<span class="lineno">106</span>
|
||||
<span class="lineno">107</span> <span class="n">parts</span> <span class="o">+=</span> <span class="p">[(</span><span class="s1">'"'</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">subtle</span><span class="p">),</span>
|
||||
<span class="lineno">108</span> <span class="p">(</span><span class="n">text</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span> <span class="n">Text</span><span class="o">.</span><span class="n">success</span> <span class="k">if</span> <span class="n">others</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="n">target</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">else</span> <span class="n">Text</span><span class="o">.</span><span class="n">danger</span><span class="p">),</span>
|
||||
<span class="lineno">109</span> <span class="p">(</span><span class="s1">'"'</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">subtle</span><span class="p">),</span> <span class="s1">'</span><span class="se">\t</span><span class="s1">'</span><span class="p">]</span>
|
||||
<span class="lineno">110</span>
|
||||
<span class="lineno">111</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">parts</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-14'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-14'>#</a>
|
||||
</div>
|
||||
<p>Stats </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">114</span> <span class="n">parts</span> <span class="o">=</span> <span class="p">[(</span><span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">total</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">highlight</span><span class="p">),</span> <span class="s1">'</span><span class="se">\t</span><span class="s1">'</span><span class="p">]</span>
|
||||
<span class="lineno">115</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">others</span><span class="p">)):</span>
|
||||
<span class="lineno">116</span> <span class="n">parts</span> <span class="o">+=</span> <span class="p">[(</span><span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">correct</span><span class="p">[</span><span class="n">j</span><span class="p">]</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">value</span><span class="p">),</span> <span class="s1">'</span><span class="se">\t</span><span class="s1">'</span><span class="p">]</span>
|
||||
<span class="lineno">117</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">parts</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-15'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-15'>#</a>
|
||||
</div>
|
||||
<h3>Balance layers</h3>
|
||||
<p>Split the <code class="highlight"><span></span><span class="n">n_layers</span></code>
|
||||
into <code class="highlight"><span></span><span class="n">n_chunks</span></code>
|
||||
. This is used for pipeline parallel training.</p>
|
||||
<ul><li><code class="highlight"><span></span><span class="n">n_layers</span></code>
|
||||
is the number of layers </li>
|
||||
<li><code class="highlight"><span></span><span class="n">n_chunks</span></code>
|
||||
is the number of chunks </li>
|
||||
</ul><p><em>Returns</em> a list with the number of layers for each chunk</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">120</span><span class="k">def</span> <span class="nf">balance_layers_simple</span><span class="p">(</span><span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_chunks</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-16'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-16'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">130</span> <span class="n">balance</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="lineno">131</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_chunks</span><span class="p">):</span>
|
||||
<span class="lineno">132</span> <span class="n">balance</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">n_layers</span> <span class="o">-</span> <span class="nb">sum</span><span class="p">(</span><span class="n">balance</span><span class="p">))</span> <span class="o">//</span> <span class="p">(</span><span class="n">n_chunks</span> <span class="o">-</span> <span class="n">i</span><span class="p">))</span>
|
||||
<span class="lineno">133</span>
|
||||
<span class="lineno">134</span> <span class="k">return</span> <span class="nb">list</span><span class="p">(</span><span class="nb">reversed</span><span class="p">(</span><span class="n">balance</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="../../interactive.js?v=1"></script>
|
||||
<script>
|
||||
/**
 * Attach the click-to-open modal behaviour to every image that is a
 * direct child of a paragraph.
 */
function handleImages() {
    var paragraphImages = document.querySelectorAll('p>img')

    // Iterate the array-like result generically, one image per call.
    Array.prototype.forEach.call(paragraphImages, function (image) {
        handleImage(image)
    })
}
|
||||
|
||||
/**
 * Make an image open in a modal element when it is clicked.
 *
 * Centres the image within its parent element, builds a detached modal
 * (`div#modal` containing a `div` with an `img`, plus a `span.close`) and
 * installs two click handlers: clicking the image appends the modal to
 * `document.body` and copies the image source into it; clicking the
 * close element removes the modal from `document.body` again.
 *
 * @param {HTMLImageElement} img image element to enhance
 */
function handleImage(img) {
    img.parentElement.style.textAlign = 'center'

    // Build the modal once, up front; it is attached to the page only while open.
    var modal = document.createElement('div')
    modal.id = 'modal'

    var modalContent = document.createElement('div')
    modal.appendChild(modalContent)

    var modalImage = document.createElement('img')
    modalContent.appendChild(modalImage)

    // Close control, rendered as a literal "x".
    var span = document.createElement('span')
    span.classList.add('close')
    span.textContent = 'x'
    modal.appendChild(span)

    // Open: attach the modal and show the clicked image's current source.
    img.onclick = function () {
        document.body.appendChild(modal)
        modalImage.src = img.src
    }

    // Close: detach the modal from the page.
    span.onclick = function () {
        document.body.removeChild(modal)
    }
}
|
||||
|
||||
// Install the click-to-open modal handlers on all `p>img` images.
handleImages()
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
447
docs/neox/utils/text_dataset.html
Normal file
447
docs/neox/utils/text_dataset.html
Normal file
@ -0,0 +1,447 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content="Loads text datasets to fine-tune GPT-NeoX"/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="Text Dataset for GPT-NeoX"/>
|
||||
<meta name="twitter:description" content="Loads text datasets to fine-tune GPT-NeoX"/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/neox/utils/text_dataset.html"/>
|
||||
<meta property="og:title" content="Text Dataset for GPT-NeoX"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="Text Dataset for GPT-NeoX"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="Text Dataset for GPT-NeoX"/>
|
||||
<meta property="og:description" content="Loads text datasets to fine-tune GPT-NeoX"/>
|
||||
|
||||
<title>Text Dataset for GPT-NeoX</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../pylit.css?v=1">
|
||||
<link rel="canonical" href="https://nn.labml.ai/neox/utils/text_dataset.html"/>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
||||
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
// Reuse window.dataLayer if it already exists, otherwise start with an empty array.
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
// Push this call's `arguments` object onto dataLayer as a single entry.
// NOTE(review): presumably consumed by the gtag.js script loaded above — confirm.
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
// Queue a 'js' entry carrying the current time.
gtag('js', new Date());
|
||||
|
||||
// Queue a 'config' entry for the same ID that appears in the gtag.js URL.
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../index.html">neox</a>
|
||||
<a class="parent" href="index.html">utils</a>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://github.com/sponsors/labmlai" target="_blank">
|
||||
<img alt="Sponsor"
|
||||
src="https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/neox/utils/text_dataset.py" target="_blank">
|
||||
View code on Github</a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>Text Dataset for GPT-NeoX</h1>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">10</span><span></span><span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">PurePath</span><span class="p">,</span> <span class="n">Path</span>
|
||||
<span class="lineno">11</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">List</span>
|
||||
<span class="lineno">12</span>
|
||||
<span class="lineno">13</span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="lineno">14</span><span class="kn">import</span> <span class="nn">torch.utils.data</span>
|
||||
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">lab</span>
|
||||
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">monit</span>
|
||||
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">inspect</span>
|
||||
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml.utils.download</span> <span class="kn">import</span> <span class="n">download_file</span>
|
||||
<span class="lineno">19</span>
|
||||
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_nn.neox.tokenizer</span> <span class="kn">import</span> <span class="n">get_tokenizer</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<h3>Load text file</h3>
|
||||
<ul><li><code class="highlight"><span></span><span class="n">path</span></code>
|
||||
is the location of the text file </li>
|
||||
<li><code class="highlight"><span></span><span class="n">url</span></code>
|
||||
is the URL to download the file from </li>
|
||||
<li><code class="highlight"><span></span><span class="n">filter_subset</span></code>
|
||||
is the number of characters to keep from the start of the text. Use this during testing when trying large datasets </li>
|
||||
</ul><p><em>Returns</em> the text content</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">23</span><span class="k">def</span> <span class="nf">load_text</span><span class="p">(</span><span class="n">path</span><span class="p">:</span> <span class="n">PurePath</span><span class="p">,</span> <span class="n">url</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <span class="n">filter_subset</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">34</span> <span class="n">path</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="n">path</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
<p>Download if it doesn't exist </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">37</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
|
||||
<span class="lineno">38</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">url</span><span class="p">:</span>
|
||||
<span class="lineno">39</span> <span class="k">raise</span> <span class="ne">FileNotFoundError</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">path</span><span class="p">))</span>
|
||||
<span class="lineno">40</span> <span class="k">else</span><span class="p">:</span>
|
||||
<span class="lineno">41</span> <span class="n">download_file</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">path</span><span class="p">)</span>
|
||||
<span class="lineno">42</span>
|
||||
<span class="lineno">43</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s2">"Load data"</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
<p>Load data </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">45</span> <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">path</span><span class="p">),</span> <span class="s1">'r'</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
|
||||
<span class="lineno">46</span> <span class="n">text</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
<p>Filter </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">48</span> <span class="k">if</span> <span class="n">filter_subset</span><span class="p">:</span>
|
||||
<span class="lineno">49</span> <span class="n">text</span> <span class="o">=</span> <span class="n">text</span><span class="p">[:</span><span class="n">filter_subset</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
<p> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">52</span> <span class="k">return</span> <span class="n">text</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
<h2>Dataset for fine-tuning GPT-NeoX</h2>
|
||||
<p>This is not optimized for very large datasets.</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">55</span><span class="k">class</span> <span class="nc">NeoXDataset</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-8'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-8'>#</a>
|
||||
</div>
|
||||
<ul><li><code class="highlight"><span></span><span class="n">tokens</span></code>
|
||||
is the list of token ids </li>
|
||||
<li><code class="highlight"><span></span><span class="n">seq_len</span></code>
|
||||
is the sequence length of a single training sample</li></ul>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">62</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">tokens</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">seq_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-9'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-9'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">68</span> <span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span> <span class="o">=</span> <span class="n">seq_len</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-10'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
<p>Number of samples </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">n_samples</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">tokens</span><span class="p">)</span> <span class="o">//</span> <span class="n">seq_len</span>
|
||||
<span class="lineno">71</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_samples</span> <span class="o">=</span> <span class="n">n_samples</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-11'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
<p>Truncate </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">73</span> <span class="n">tokens</span> <span class="o">=</span> <span class="n">tokens</span><span class="p">[:</span><span class="n">n_samples</span> <span class="o">*</span> <span class="n">seq_len</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-12'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
<p>Create a PyTorch tensor </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">75</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">tokens</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-13'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-13'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">77</span> <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="lineno">78</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_samples</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-14'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-14'>#</a>
|
||||
</div>
|
||||
<h3>Get a sample</h3>
|
||||
<ul><li><code class="highlight"><span></span><span class="n">idx</span></code>
|
||||
is the index of the sample </li>
|
||||
</ul><p><em>Returns</em> the input and the target</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">80</span> <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-15'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-15'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">87</span> <span class="n">offset</span> <span class="o">=</span> <span class="n">idx</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span>
|
||||
<span class="lineno">88</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokens</span><span class="p">[</span><span class="n">offset</span><span class="p">:</span><span class="n">offset</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokens</span><span class="p">[</span><span class="n">offset</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:</span><span class="n">offset</span> <span class="o">+</span> <span class="mi">1</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">seq_len</span><span class="p">]</span>
|
||||
<span class="lineno">89</span>
|
||||
<span class="lineno">90</span>
|
||||
<span class="lineno">91</span><span class="n">DATASETS</span> <span class="o">=</span> <span class="p">{</span>
|
||||
<span class="lineno">92</span> <span class="s1">'tiny_shakespeare'</span><span class="p">:</span> <span class="p">{</span>
|
||||
<span class="lineno">93</span> <span class="s1">'file'</span><span class="p">:</span> <span class="s1">'tiny_shakespeare.txt'</span><span class="p">,</span>
|
||||
<span class="lineno">94</span> <span class="s1">'url'</span><span class="p">:</span> <span class="s1">'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'</span>
|
||||
<span class="lineno">95</span> <span class="p">}</span>
|
||||
<span class="lineno">96</span><span class="p">}</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-16'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-16'>#</a>
|
||||
</div>
|
||||
<h3>Load Dataset</h3>
|
||||
<ul><li><code class="highlight"><span></span><span class="n">seq_len</span></code>
|
||||
is the sequence length of a single training sample </li>
|
||||
<li><code class="highlight"><span></span><span class="n">dataset_name</span></code>
|
||||
is the name of the dataset </li>
|
||||
</ul><p><em>Returns</em> the dataset</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">99</span><span class="k">def</span> <span class="nf">get_training_data</span><span class="p">(</span><span class="n">seq_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span> <span class="n">dataset_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">'tiny_shakespeare'</span><span class="p">,</span> <span class="n">truncate</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-17'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-17'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">108</span> <span class="n">ds</span> <span class="o">=</span> <span class="n">DATASETS</span><span class="p">[</span><span class="n">dataset_name</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-18'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-18'>#</a>
|
||||
</div>
|
||||
<p>Load the content </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">110</span> <span class="n">text</span> <span class="o">=</span> <span class="n">load_text</span><span class="p">(</span><span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="n">ds</span><span class="p">[</span><span class="s1">'file'</span><span class="p">],</span> <span class="n">ds</span><span class="p">[</span><span class="s1">'url'</span><span class="p">])</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-19'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-19'>#</a>
|
||||
</div>
|
||||
<p>Tokenize </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">112</span> <span class="n">tokenizer</span> <span class="o">=</span> <span class="n">get_tokenizer</span><span class="p">()</span>
|
||||
<span class="lineno">113</span> <span class="n">tokens</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">encode_batch</span><span class="p">([</span><span class="n">text</span><span class="p">])[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="lineno">114</span>
|
||||
<span class="lineno">115</span> <span class="k">if</span> <span class="n">truncate</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
|
||||
<span class="lineno">116</span> <span class="n">token_ids</span> <span class="o">=</span> <span class="n">tokens</span><span class="o">.</span><span class="n">ids</span><span class="p">[:</span><span class="n">truncate</span> <span class="o">*</span> <span class="n">seq_len</span><span class="p">]</span>
|
||||
<span class="lineno">117</span> <span class="k">else</span><span class="p">:</span>
|
||||
<span class="lineno">118</span> <span class="n">token_ids</span> <span class="o">=</span> <span class="n">tokens</span><span class="o">.</span><span class="n">ids</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-20'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-20'>#</a>
|
||||
</div>
|
||||
<p> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">121</span> <span class="k">return</span> <span class="n">NeoXDataset</span><span class="p">(</span><span class="n">token_ids</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-21'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-21'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">124</span><span class="k">def</span> <span class="nf">_test</span><span class="p">():</span>
|
||||
<span class="lineno">125</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">get_training_data</span><span class="p">()</span>
|
||||
<span class="lineno">126</span>
|
||||
<span class="lineno">127</span> <span class="n">inspect</span><span class="p">(</span><span class="n">tokens</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">dataset</span><span class="o">.</span><span class="n">tokens</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-22'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-22'>#</a>
|
||||
</div>
|
||||
<p> </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">131</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||
<span class="lineno">132</span> <span class="n">_test</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="../../interactive.js?v=1"></script>
|
||||
<script>
|
||||
/**
 * Attach the click-to-open modal behaviour to every image that is a
 * direct child of a paragraph.
 */
function handleImages() {
    var paragraphImages = document.querySelectorAll('p>img')

    // Iterate the array-like result generically, one image per call.
    Array.prototype.forEach.call(paragraphImages, function (image) {
        handleImage(image)
    })
}
|
||||
|
||||
/**
 * Make an image open in a modal element when it is clicked.
 *
 * Centres the image within its parent element, builds a detached modal
 * (`div#modal` containing a `div` with an `img`, plus a `span.close`) and
 * installs two click handlers: clicking the image appends the modal to
 * `document.body` and copies the image source into it; clicking the
 * close element removes the modal from `document.body` again.
 *
 * @param {HTMLImageElement} img image element to enhance
 */
function handleImage(img) {
    img.parentElement.style.textAlign = 'center'

    // Build the modal once, up front; it is attached to the page only while open.
    var modal = document.createElement('div')
    modal.id = 'modal'

    var modalContent = document.createElement('div')
    modal.appendChild(modalContent)

    var modalImage = document.createElement('img')
    modalContent.appendChild(modalImage)

    // Close control, rendered as a literal "x".
    var span = document.createElement('span')
    span.classList.add('close')
    span.textContent = 'x'
    modal.appendChild(span)

    // Open: attach the modal and show the clicked image's current source.
    img.onclick = function () {
        document.body.appendChild(modal)
        modalImage.src = img.src
    }

    // Close: detach the modal from the page.
    span.onclick = function () {
        document.body.removeChild(modal)
    }
}
|
||||
|
||||
// Install the click-to-open modal handlers on all `p>img` images.
handleImages()
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
541
docs/neox/utils/trainer.html
Normal file
541
docs/neox/utils/trainer.html
Normal file
@ -0,0 +1,541 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content=""/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="trainer.py"/>
|
||||
<meta name="twitter:description" content=""/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/neox/utils/trainer.html"/>
|
||||
<meta property="og:title" content="trainer.py"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="trainer.py"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="trainer.py"/>
|
||||
<meta property="og:description" content=""/>
|
||||
|
||||
<title>trainer.py</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../pylit.css?v=1">
|
||||
<link rel="canonical" href="https://nn.labml.ai/neox/utils/trainer.html"/>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
||||
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
// Reuse window.dataLayer if it already exists, otherwise start with an empty array.
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
// Push this call's `arguments` object onto dataLayer as a single entry.
// NOTE(review): presumably consumed by the gtag.js script loaded above — confirm.
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
// Queue a 'js' entry carrying the current time.
gtag('js', new Date());
|
||||
|
||||
// Queue a 'config' entry for the same ID that appears in the gtag.js URL.
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../index.html">neox</a>
|
||||
<a class="parent" href="index.html">utils</a>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://github.com/sponsors/labmlai" target="_blank">
|
||||
<img alt="Sponsor"
|
||||
src="https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/neox/utils/trainer.py" target="_blank">
|
||||
View code on Github</a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">1</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Set</span><span class="p">,</span> <span class="n">List</span>
|
||||
<span class="lineno">2</span>
|
||||
<span class="lineno">3</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
|
||||
<span class="lineno">4</span><span class="kn">import</span> <span class="nn">torch.optim</span>
|
||||
<span class="lineno">5</span><span class="kn">import</span> <span class="nn">torch.utils.data</span>
|
||||
<span class="lineno">6</span><span class="kn">from</span> <span class="nn">torch.cuda</span> <span class="kn">import</span> <span class="n">amp</span>
|
||||
<span class="lineno">7</span><span class="kn">from</span> <span class="nn">torch.cuda.amp</span> <span class="kn">import</span> <span class="n">GradScaler</span>
|
||||
<span class="lineno">8</span>
|
||||
<span class="lineno">9</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">monit</span><span class="p">,</span> <span class="n">tracker</span>
|
||||
<span class="lineno">10</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">BaseConfigs</span><span class="p">,</span> <span class="n">option</span>
|
||||
<span class="lineno">11</span><span class="kn">from</span> <span class="nn">labml_nn.neox.utils.finetune</span> <span class="kn">import</span> <span class="n">FineTuner</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<h3>Get trainable parameters</h3>
|
||||
<ul><li><code class="highlight"><span></span><span class="n">model</span></code>
|
||||
is the model to train </li>
|
||||
</ul><p><em>Returns</em> a list of parameters for training</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">14</span><span class="k">def</span> <span class="nf">get_trainable_params</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
<p>Get all parameters </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">23</span> <span class="n">params</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
<p>Filter parameters that require gradients </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">25</span> <span class="n">trainable_params</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">params</span> <span class="k">if</span> <span class="n">p</span><span class="o">.</span><span class="n">requires_grad</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
<p>Return the trainable parameters </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">28</span> <span class="k">return</span> <span class="n">trainable_params</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">31</span><span class="k">class</span> <span class="nc">TrainerConf</span><span class="p">(</span><span class="n">BaseConfigs</span><span class="p">):</span>
|
||||
<span class="lineno">32</span> <span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span>
|
||||
<span class="lineno">33</span> <span class="n">layers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">]</span>
|
||||
<span class="lineno">34</span> <span class="n">optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Optimizer</span> <span class="o">=</span> <span class="s1">'Adam'</span>
|
||||
<span class="lineno">35</span> <span class="n">train_loader</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span>
|
||||
<span class="lineno">36</span> <span class="n">valid_loader</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="lineno">37</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">'cuda:0'</span><span class="p">)</span>
|
||||
<span class="lineno">38</span> <span class="n">scaler</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">GradScaler</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'Default'</span>
|
||||
<span class="lineno">39</span> <span class="n">is_amp</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="lineno">40</span> <span class="n">dtype</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span>
|
||||
<span class="lineno">41</span>
|
||||
<span class="lineno">42</span> <span class="n">is_clone_layers</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="lineno">43</span>
|
||||
<span class="lineno">44</span> <span class="n">loss_func</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">()</span>
|
||||
<span class="lineno">45</span> <span class="n">checkpoints_per_epoch</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="lineno">46</span> <span class="n">samples_per_epoch</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="lineno">47</span>
|
||||
<span class="lineno">48</span> <span class="n">grad_norm</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="mf">1.0</span>
|
||||
<span class="lineno">49</span> <span class="n">learning_rate</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">3e-4</span>
|
||||
<span class="lineno">50</span> <span class="n">max_seq_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1024</span>
|
||||
<span class="lineno">51</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span>
|
||||
<span class="lineno">52</span> <span class="n">epochs</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">16</span>
|
||||
<span class="lineno">53</span>
|
||||
<span class="lineno">54</span> <span class="n">n_gpus</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">device_count</span><span class="p">()</span>
|
||||
<span class="lineno">55</span>
|
||||
<span class="lineno">56</span> <span class="n">filter_layers</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Set</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
<ul><li><code class="highlight"><span></span><span class="n">dataset_split</span></code>
|
||||
train/valid </li>
|
||||
<li><code class="highlight"><span></span><span class="n">sample</span></code>
|
||||
is the sample </li>
|
||||
</ul><p><em>Returns</em> the loss, output and the target</p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">58</span> <span class="k">def</span> <span class="nf">get_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">,</span> <span class="n">dataset_split</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">64</span> <span class="n">data</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">sample</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-8'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-8'>#</a>
|
||||
</div>
|
||||
<p>Forward pass </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">67</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">'Forward pass'</span><span class="p">):</span>
|
||||
<span class="lineno">68</span> <span class="n">output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-9'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-9'>#</a>
|
||||
</div>
|
||||
<p>Move targets to the same device as output </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">70</span> <span class="n">target</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-10'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
<p>Calculate loss </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">72</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">target</span><span class="o">.</span><span class="n">numel</span><span class="p">(),</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">target</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span>
|
||||
<span class="lineno">73</span>
|
||||
<span class="lineno">74</span> <span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">target</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-11'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">76</span> <span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="lineno">77</span> <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">loop</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">epochs</span><span class="p">):</span>
|
||||
<span class="lineno">78</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_epoch</span><span class="p">()</span>
|
||||
<span class="lineno">79</span> <span class="n">tracker</span><span class="o">.</span><span class="n">new_line</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-12'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">81</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">idx</span><span class="p">):</span>
|
||||
<span class="lineno">82</span> <span class="k">pass</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-13'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-13'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">84</span> <span class="k">def</span> <span class="nf">save_checkpoint</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">idx</span><span class="p">):</span>
|
||||
<span class="lineno">85</span> <span class="k">pass</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-14'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-14'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">87</span> <span class="k">def</span> <span class="nf">get_iterators</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-15'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-15'>#</a>
|
||||
</div>
|
||||
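<p>Build the list of iterators to mix: training batches, validation batches (if a validation loader is set), and the optional sampling and checkpoint steps </p>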
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">89</span> <span class="n">iterators</span> <span class="o">=</span> <span class="p">[(</span><span class="s1">'train'</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">)]</span>
|
||||
<span class="lineno">90</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="lineno">91</span> <span class="n">iterators</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="s1">'valid'</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span><span class="p">))</span>
|
||||
<span class="lineno">92</span>
|
||||
<span class="lineno">93</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">samples_per_epoch</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
|
||||
<span class="lineno">94</span> <span class="n">iterators</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">sample</span><span class="p">,</span> <span class="p">[</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">samples_per_epoch</span><span class="p">)]))</span>
|
||||
<span class="lineno">95</span>
|
||||
<span class="lineno">96</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">checkpoints_per_epoch</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
|
||||
<span class="lineno">97</span> <span class="n">iterators</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">save_checkpoint</span><span class="p">,</span> <span class="p">[</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoints_per_epoch</span><span class="p">)]))</span>
|
||||
<span class="lineno">98</span>
|
||||
<span class="lineno">99</span> <span class="k">return</span> <span class="n">iterators</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-16'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-16'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">101</span> <span class="k">def</span> <span class="nf">train_epoch</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-17'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-17'>#</a>
|
||||
</div>
|
||||
<p>Set the model to training mode </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">103</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>
|
||||
<span class="lineno">104</span>
|
||||
<span class="lineno">105</span> <span class="n">iterators</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_iterators</span><span class="p">()</span>
|
||||
<span class="lineno">106</span> <span class="k">for</span> <span class="n">split_name</span><span class="p">,</span> <span class="n">sample</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">mix</span><span class="p">(</span><span class="mi">1024</span><span class="p">,</span> <span class="o">*</span><span class="n">iterators</span><span class="p">):</span>
|
||||
<span class="lineno">107</span> <span class="k">if</span> <span class="n">split_name</span> <span class="o">==</span> <span class="s1">'train'</span><span class="p">:</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-18'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-18'>#</a>
|
||||
</div>
|
||||
<p>Set gradients to zero </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">109</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
|
||||
<span class="lineno">110</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">()</span>
|
||||
<span class="lineno">111</span>
|
||||
<span class="lineno">112</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">set_grad_enabled</span><span class="p">(</span><span class="n">split_name</span> <span class="o">==</span> <span class="s1">'train'</span><span class="p">):</span>
|
||||
<span class="lineno">113</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_amp</span><span class="p">:</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-19'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-19'>#</a>
|
||||
</div>
|
||||
<p>Forward pass </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">115</span> <span class="k">with</span> <span class="n">amp</span><span class="o">.</span><span class="n">autocast</span><span class="p">():</span>
|
||||
<span class="lineno">116</span> <span class="n">loss</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_loss</span><span class="p">(</span><span class="n">sample</span><span class="p">,</span> <span class="n">split_name</span><span class="p">)</span>
|
||||
<span class="lineno">117</span> <span class="k">else</span><span class="p">:</span>
|
||||
<span class="lineno">118</span> <span class="n">loss</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_loss</span><span class="p">(</span><span class="n">sample</span><span class="p">,</span> <span class="n">split_name</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-20'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-20'>#</a>
|
||||
</div>
|
||||
<p>Get predictions </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">121</span> <span class="n">pred</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-21'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-21'>#</a>
|
||||
</div>
|
||||
<p>Calculate accuracy </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">123</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">pred</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">target</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">/</span> <span class="p">(</span><span class="n">target</span> <span class="o">!=</span> <span class="o">-</span><span class="mi">100</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
|
||||
<span class="lineno">124</span>
|
||||
<span class="lineno">125</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add</span><span class="p">({</span><span class="sa">f</span><span class="s1">'loss.</span><span class="si">{</span><span class="n">split_name</span><span class="si">}</span><span class="s1">'</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'acc.</span><span class="si">{</span><span class="n">split_name</span><span class="si">}</span><span class="s1">'</span><span class="p">:</span> <span class="n">accuracy</span> <span class="o">*</span> <span class="mi">100</span><span class="p">})</span>
|
||||
<span class="lineno">126</span>
|
||||
<span class="lineno">127</span> <span class="k">if</span> <span class="n">split_name</span> <span class="o">==</span> <span class="s1">'train'</span><span class="p">:</span>
|
||||
<span class="lineno">128</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scaler</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-22'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-22'>#</a>
|
||||
</div>
|
||||
<p>Scale the loss with the gradient scaler </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">130</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scaler</span><span class="o">.</span><span class="n">scale</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-23'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-23'>#</a>
|
||||
</div>
|
||||
<p>Backward pass (disabled debug logging: <code>tracker.add({'loss.scaled': loss})</code>) </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">133</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">'Backward pass'</span><span class="p">):</span>
|
||||
<span class="lineno">134</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-24'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-24'>#</a>
|
||||
</div>
|
||||
<p>Optimize </p>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">137</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">'Optimize'</span><span class="p">):</span>
|
||||
<span class="lineno">138</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scaler</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="lineno">139</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
|
||||
<span class="lineno">140</span> <span class="k">else</span><span class="p">:</span>
|
||||
<span class="lineno">141</span> <span class="bp">self</span><span class="o">.</span><span class="n">scaler</span><span class="o">.</span><span class="n">unscale_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="p">)</span>
|
||||
<span class="lineno">142</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">grad_norm</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="lineno">143</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">clip_grad_norm_</span><span class="p">(</span><span class="n">get_trainable_params</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">grad_norm</span><span class="p">)</span>
|
||||
<span class="lineno">144</span> <span class="bp">self</span><span class="o">.</span><span class="n">scaler</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="p">)</span>
|
||||
<span class="lineno">145</span> <span class="bp">self</span><span class="o">.</span><span class="n">scaler</span><span class="o">.</span><span class="n">update</span><span class="p">()</span>
|
||||
<span class="lineno">146</span>
|
||||
<span class="lineno">147</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-25'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-25'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">150</span><span class="nd">@option</span><span class="p">(</span><span class="n">TrainerConf</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="s1">'Adam'</span><span class="p">)</span>
|
||||
<span class="lineno">151</span><span class="k">def</span> <span class="nf">adam_optimizer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TrainerConf</span><span class="p">):</span>
|
||||
<span class="lineno">152</span> <span class="k">if</span> <span class="n">c</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">:</span>
|
||||
<span class="lineno">153</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">get_trainable_params</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">model</span><span class="p">),</span> <span class="n">lr</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">)</span>
|
||||
<span class="lineno">154</span> <span class="k">elif</span> <span class="n">c</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">:</span>
|
||||
<span class="lineno">155</span> <span class="kn">from</span> <span class="nn">labml_nn.optimizers.adam_fp16</span> <span class="kn">import</span> <span class="n">AdamFP16</span>
|
||||
<span class="lineno">156</span> <span class="k">return</span> <span class="n">AdamFP16</span><span class="p">(</span><span class="n">get_trainable_params</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">model</span><span class="p">),</span> <span class="n">lr</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">)</span>
|
||||
<span class="lineno">157</span> <span class="k">else</span><span class="p">:</span>
|
||||
<span class="lineno">158</span> <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">()</span>
|
||||
<span class="lineno">159</span>
|
||||
<span class="lineno">160</span>
|
||||
<span class="lineno">161</span><span class="nd">@option</span><span class="p">(</span><span class="n">TrainerConf</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="s1">'SGD'</span><span class="p">)</span>
|
||||
<span class="lineno">162</span><span class="k">def</span> <span class="nf">sgd_optimizer</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TrainerConf</span><span class="p">):</span>
|
||||
<span class="lineno">163</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">SGD</span><span class="p">(</span><span class="n">get_trainable_params</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">model</span><span class="p">),</span> <span class="n">lr</span><span class="o">=</span><span class="n">c</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">)</span>
|
||||
<span class="lineno">164</span>
|
||||
<span class="lineno">165</span>
|
||||
<span class="lineno">166</span><span class="nd">@option</span><span class="p">(</span><span class="n">TrainerConf</span><span class="o">.</span><span class="n">scaler</span><span class="p">,</span> <span class="s1">'Default'</span><span class="p">)</span>
|
||||
<span class="lineno">167</span><span class="k">def</span> <span class="nf">grad_scaler</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">TrainerConf</span><span class="p">):</span>
|
||||
<span class="lineno">168</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">c</span><span class="o">.</span><span class="n">is_amp</span><span class="p">:</span>
|
||||
<span class="lineno">169</span> <span class="k">return</span> <span class="kc">None</span>
|
||||
<span class="lineno">170</span>
|
||||
<span class="lineno">171</span> <span class="k">if</span> <span class="n">c</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">:</span>
|
||||
<span class="lineno">172</span> <span class="kn">from</span> <span class="nn">labml_nn.optimizers.adam_fp16</span> <span class="kn">import</span> <span class="n">GradScalerFP16</span>
|
||||
<span class="lineno">173</span> <span class="k">return</span> <span class="n">GradScalerFP16</span><span class="p">()</span>
|
||||
<span class="lineno">174</span> <span class="k">else</span><span class="p">:</span>
|
||||
<span class="lineno">175</span> <span class="k">return</span> <span class="n">GradScaler</span><span class="p">()</span>
|
||||
<span class="lineno">176</span>
|
||||
<span class="lineno">177</span>
|
||||
<span class="lineno">178</span><span class="k">class</span> <span class="nc">PipelineParallelTrainerConf</span><span class="p">(</span><span class="n">TrainerConf</span><span class="p">):</span>
|
||||
<span class="lineno">179</span> <span class="n">is_checkpointing</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
|
||||
<span class="lineno">180</span> <span class="n">chunks</span><span class="p">:</span> <span class="nb">int</span>
|
||||
<span class="lineno">181</span>
|
||||
<span class="lineno">182</span> <span class="n">fine_tuner</span><span class="p">:</span> <span class="n">FineTuner</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="../../interactive.js?v=1"></script>
|
||||
<script>
|
||||
/**
 * Find every <img> that is a direct child of a <p> and hand each one to
 * handleImage().
 */
function handleImages() {
    var paragraphImages = document.querySelectorAll('p>img')

    // querySelectorAll returns a static, array-like NodeList, so borrowing
    // Array.prototype.forEach visits the same elements in the same order
    // as an index loop would.
    Array.prototype.forEach.call(paragraphImages, function (image) {
        handleImage(image)
    })
}
|
||||
|
||||
/**
 * Make an image open an enlarged copy of itself when it is clicked.
 *
 * Centers the image inside its parent, builds a detached element with id
 * `modal` (a content wrapper holding an <img>, plus a close control) and
 * installs the handlers that attach that element to, and detach it from,
 * `document.body`.
 *
 * The close control can be operated with the mouse, with Enter/Space when
 * it has focus, and the modal can also be dismissed with Escape.
 *
 * @param {HTMLImageElement} img - the image to enhance
 */
function handleImage(img) {
    // Guard against an image that is not attached to a parent element.
    if (img.parentElement) {
        img.parentElement.style.textAlign = 'center'
    }

    var modal = document.createElement('div')
    modal.id = 'modal'

    var modalContent = document.createElement('div')
    modal.appendChild(modalContent)

    var modalImage = document.createElement('img')
    modalContent.appendChild(modalImage)

    var span = document.createElement('span')
    span.classList.add('close')
    span.textContent = 'x'
    // The close control is a plain <span>; expose it as a focusable,
    // named button so keyboard and screen-reader users can operate it.
    span.setAttribute('role', 'button')
    span.setAttribute('tabindex', '0')
    span.setAttribute('aria-label', 'Close')
    modal.appendChild(span)

    // Detach the modal. Safe to call when the modal is already detached:
    // removeChild() would throw if the node were not a child.
    function closeModal() {
        document.removeEventListener('keydown', onDocumentKeyDown)
        if (modal.parentNode) {
            modal.parentNode.removeChild(modal)
        }
    }

    // Escape dismisses the modal while it is open.
    function onDocumentKeyDown(event) {
        if (event.key === 'Escape') {
            closeModal()
        }
    }

    img.onclick = function () {
        modalImage.src = img.src
        // Carry the alternative text over so the enlarged copy is described too.
        modalImage.alt = img.alt || ''
        document.body.appendChild(modal)
        // Registering the same listener twice is a no-op, so reopening is safe.
        document.addEventListener('keydown', onDocumentKeyDown)
    }

    span.onclick = function () {
        closeModal()
    }

    // Buttons activate on Enter and Space; mirror that for the span.
    span.onkeydown = function (event) {
        if (event.key === 'Enter' || event.key === ' ') {
            event.preventDefault()
            closeModal()
        }
    }
}
|
||||
|
||||
// Run immediately: this script sits just before </body>, so the markup above it has already been parsed.
handleImages()
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
Reference in New Issue
Block a user