This commit is contained in:
Varuna Jayasiri
2022-08-20 11:13:36 +05:30
parent e19d95f9c3
commit 4860cc680b
8 changed files with 394 additions and 227 deletions

View File

@ -3,24 +3,24 @@
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content=""/>
<meta name="description" content="Evaluate GPT-NeoX using LLM.int8() quantization on test suite"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="half_precision.py"/>
<meta name="twitter:description" content=""/>
<meta name="twitter:title" content="Evaluate GPT-NeoX using LLM.int8() quantization on test suite"/>
<meta name="twitter:description" content="Evaluate GPT-NeoX using LLM.int8() quantization on test suite"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/neox/evaluation/half_precision.html"/>
<meta property="og:title" content="half_precision.py"/>
<meta property="og:title" content="Evaluate GPT-NeoX using LLM.int8() quantization on test suite"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="half_precision.py"/>
<meta property="og:site_name" content="Evaluate GPT-NeoX using LLM.int8() quantization on test suite"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="half_precision.py"/>
<meta property="og:description" content=""/>
<meta property="og:title" content="Evaluate GPT-NeoX using LLM.int8() quantization on test suite"/>
<meta property="og:description" content="Evaluate GPT-NeoX using LLM.int8() quantization on test suite"/>
<title>half_precision.py</title>
<title>Evaluate GPT-NeoX using half precision (float16) on test suite</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/neox/evaluation/half_precision.html"/>
@ -71,32 +71,97 @@
</div>
</div>
<div class='section' id='section-0'>
<div class='docs'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Evaluate GPT-NeoX using half precision (float16) on test suite</h1>
<p>This code evaluates <a href="../index.html">GPT-NeoX</a> using half precision (float16), on a suite of tasks.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">13</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">14</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">15</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">labml_nn.neox.evaluation</span> <span class="kn">import</span> <span class="n">run_eval_harness</span>
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml_nn.neox.model</span> <span class="kn">import</span> <span class="n">LayerGenerator</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">1</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">2</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">3</span>
<span class="lineno">4</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">monit</span>
<span class="lineno">5</span><span class="kn">from</span> <span class="nn">labml_nn.neox.evaluation</span> <span class="kn">import</span> <span class="n">run_eval_harness</span>
<span class="lineno">6</span><span class="kn">from</span> <span class="nn">labml_nn.neox.model</span> <span class="kn">import</span> <span class="n">LayerGenerator</span>
<span class="lineno">7</span>
<span class="lineno">8</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">9</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda:0&#39;</span><span class="p">)</span>
<span class="lineno">10</span> <span class="n">layers</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">LayerGenerator</span><span class="p">(</span><span class="n">is_clone_layers</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="lineno">11</span> <span class="n">filter_layers</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="lineno">12</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span>
<span class="lineno">13</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span>
<span class="lineno">14</span> <span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">())</span>
<span class="lineno">15</span>
<span class="lineno">16</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Sequential&#39;</span><span class="p">):</span>
<span class="lineno">17</span> <span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">)</span>
<span class="lineno">18</span>
<span class="lineno">19</span> <span class="nb">print</span><span class="p">(</span><span class="n">run_eval_harness</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="s1">&#39;half_precision&#39;</span><span class="p">,</span> <span class="p">[</span><span class="s1">&#39;lambada&#39;</span><span class="p">],</span> <span class="n">device</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">20</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<p>Device </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">22</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda:0&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>Load layers </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">24</span> <span class="n">layers</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">LayerGenerator</span><span class="p">(</span><span class="n">is_clone_layers</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="lineno">25</span> <span class="n">filter_layers</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="lineno">26</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span>
<span class="lineno">27</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span>
<span class="lineno">28</span> <span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Create <code class="highlight"><span></span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span></code>
model </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">31</span> <span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Run <a href="index.html">evaluation harness</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">34</span> <span class="nb">print</span><span class="p">(</span><span class="n">run_eval_harness</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="s1">&#39;half_precision&#39;</span><span class="p">,</span> <span class="p">[</span><span class="s1">&#39;lambada&#39;</span><span class="p">],</span> <span class="n">device</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">38</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">39</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>

View File

@ -3,24 +3,24 @@
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content=""/>
<meta name="description" content="Evaluate GPT-NeoX using LLM.int8() quantization on test suite"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="llm_int8.py"/>
<meta name="twitter:description" content=""/>
<meta name="twitter:title" content="Evaluate GPT-NeoX using LLM.int8() quantization on test suite"/>
<meta name="twitter:description" content="Evaluate GPT-NeoX using LLM.int8() quantization on test suite"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/neox/evaluation/llm_int8.html"/>
<meta property="og:title" content="llm_int8.py"/>
<meta property="og:title" content="Evaluate GPT-NeoX using LLM.int8() quantization on test suite"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="llm_int8.py"/>
<meta property="og:site_name" content="Evaluate GPT-NeoX using LLM.int8() quantization on test suite"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="llm_int8.py"/>
<meta property="og:description" content=""/>
<meta property="og:title" content="Evaluate GPT-NeoX using LLM.int8() quantization on test suite"/>
<meta property="og:description" content="Evaluate GPT-NeoX using LLM.int8() quantization on test suite"/>
<title>llm_int8.py</title>
<title>Evaluate GPT-NeoX using LLM.int8() quantization on test suite</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/neox/evaluation/llm_int8.html"/>
@ -71,26 +71,21 @@
</div>
</div>
<div class='section' id='section-0'>
<div class='docs'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Evaluate GPT-NeoX using LLM.int8() quantization on test suite</h1>
<p>This code evaluates <a href="../index.html">GPT-NeoX</a> using <a href="../utils/llm_int8.html">LLM.int8() quantization</a> on a suite of tasks.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">1</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">2</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">3</span>
<span class="lineno">4</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">monit</span>
<span class="lineno">5</span><span class="kn">from</span> <span class="nn">labml_nn.neox.evaluation</span> <span class="kn">import</span> <span class="n">run_eval_harness</span>
<span class="lineno">6</span><span class="kn">from</span> <span class="nn">labml_nn.neox.model</span> <span class="kn">import</span> <span class="n">LayerGenerator</span>
<span class="lineno">7</span>
<span class="lineno">8</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">9</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda:0&#39;</span><span class="p">)</span>
<span class="lineno">10</span> <span class="n">layer_generator</span> <span class="o">=</span> <span class="n">LayerGenerator</span><span class="p">(</span><span class="n">is_clone_layers</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="lineno">11</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span>
<span class="lineno">12</span> <span class="n">device</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cpu&#39;</span><span class="p">),</span>
<span class="lineno">13</span> <span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">14</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">15</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">16</span>
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">monit</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml_nn.neox.evaluation</span> <span class="kn">import</span> <span class="n">run_eval_harness</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml_nn.neox.model</span> <span class="kn">import</span> <span class="n">LayerGenerator</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
@ -98,11 +93,10 @@
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<p>Load layers </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">15</span> <span class="n">layers</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">layer_generator</span><span class="o">.</span><span class="n">load</span><span class="p">())</span></pre></div>
<div class="highlight"><pre><span class="lineno">22</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
@ -110,22 +104,94 @@
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<p>Device </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">24</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda:0&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>Load layers in float16 into CPU. We convert the layers to int8 later, because doing that on the fly after loading layers to GPU causes CUDA memory fragmentation (about 3GB memory can get lost due to fragmentation). </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">29</span> <span class="n">layer_generator</span> <span class="o">=</span> <span class="n">LayerGenerator</span><span class="p">(</span><span class="n">is_clone_layers</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="lineno">30</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span>
<span class="lineno">31</span> <span class="n">device</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cpu&#39;</span><span class="p">),</span>
<span class="lineno">32</span> <span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Load layers </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">34</span> <span class="n">layers</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">layer_generator</span><span class="o">.</span><span class="n">load</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>This reduces CUDA memory fragmentation </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">18</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">iterate</span><span class="p">(</span><span class="s1">&#39;Convert to int8&#39;</span><span class="p">,</span> <span class="n">layers</span><span class="p">,</span> <span class="n">is_children_silent</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="lineno">19</span> <span class="n">layer_generator</span><span class="o">.</span><span class="n">post_load_prepare</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span>
<span class="lineno">20</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
<span class="lineno">21</span> <span class="n">is_llm_int8</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="lineno">22</span> <span class="n">llm_int8_threshold</span><span class="o">=</span><span class="mf">6.0</span><span class="p">,</span>
<span class="lineno">23</span> <span class="p">)</span>
<span class="lineno">24</span> <span class="n">layer</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="lineno">25</span>
<span class="lineno">26</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Sequential&#39;</span><span class="p">):</span>
<span class="lineno">27</span> <span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">)</span>
<span class="lineno">28</span>
<span class="lineno">29</span> <span class="nb">print</span><span class="p">(</span><span class="n">run_eval_harness</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="s1">&#39;half_precision&#39;</span><span class="p">,</span> <span class="p">[],</span> <span class="n">device</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">37</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">iterate</span><span class="p">(</span><span class="s1">&#39;Convert to int8&#39;</span><span class="p">,</span> <span class="n">layers</span><span class="p">,</span> <span class="n">is_children_silent</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="lineno">38</span> <span class="n">layer_generator</span><span class="o">.</span><span class="n">post_load_prepare</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span>
<span class="lineno">39</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
<span class="lineno">40</span> <span class="n">is_llm_int8</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="lineno">41</span> <span class="n">llm_int8_threshold</span><span class="o">=</span><span class="mf">6.0</span><span class="p">,</span>
<span class="lineno">42</span> <span class="p">)</span>
<span class="lineno">43</span> <span class="n">layer</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Create <code class="highlight"><span></span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span></code>
model </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">46</span> <span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Run <a href="index.html">evaluation harness</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="nb">print</span><span class="p">(</span><span class="n">run_eval_harness</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="s1">&#39;half_precision&#39;</span><span class="p">,</span> <span class="p">[],</span> <span class="n">device</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">53</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">54</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>

View File

@ -1643,7 +1643,8 @@
<div class='section-link'>
<a href='#section-118'>#</a>
</div>
<p> <a id="post_load_prepare"></a> ### Layer transformations after loading the checkpoint</p>
<p> <a id="post_load_prepare"></a></p>
<h3>Layer transformations after loading the checkpoint</h3>
<p>This function implements layer transformations after loading the checkpoint.</p>
<p>Currently, it only applies the int8 quantization.</p>
<ul><li><code class="highlight"><span></span><span class="n">layer</span></code>
@ -1675,12 +1676,12 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">537</span> <span class="k">if</span> <span class="n">is_llm_int8</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">538</span> <span class="n">is_llm_int8</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_llm_int8</span>
<span class="lineno">539</span> <span class="k">if</span> <span class="n">device</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">540</span> <span class="n">device</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span>
<span class="lineno">541</span> <span class="k">if</span> <span class="n">llm_int8_threshold</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">542</span> <span class="n">llm_int8_threshold</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">llm_int8_threshold</span></pre></div>
<div class="highlight"><pre><span class="lineno">538</span> <span class="k">if</span> <span class="n">is_llm_int8</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">539</span> <span class="n">is_llm_int8</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_llm_int8</span>
<span class="lineno">540</span> <span class="k">if</span> <span class="n">device</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">541</span> <span class="n">device</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span>
<span class="lineno">542</span> <span class="k">if</span> <span class="n">llm_int8_threshold</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">543</span> <span class="n">llm_int8_threshold</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">llm_int8_threshold</span></pre></div>
</div>
</div>
<div class='section' id='section-120'>
@ -1692,8 +1693,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">545</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">is_llm_int8</span><span class="p">:</span>
<span class="lineno">546</span> <span class="k">return</span> <span class="n">layer</span></pre></div>
<div class="highlight"><pre><span class="lineno">546</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">is_llm_int8</span><span class="p">:</span>
<span class="lineno">547</span> <span class="k">return</span> <span class="n">layer</span></pre></div>
</div>
</div>
<div class='section' id='section-121'>
@ -1705,8 +1706,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">549</span> <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">TransformerLayer</span><span class="p">):</span>
<span class="lineno">550</span> <span class="k">return</span> <span class="n">layer</span></pre></div>
<div class="highlight"><pre><span class="lineno">550</span> <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">TransformerLayer</span><span class="p">):</span>
<span class="lineno">551</span> <span class="k">return</span> <span class="n">layer</span></pre></div>
</div>
</div>
<div class='section' id='section-122'>
@ -1719,7 +1720,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">553</span> <span class="kn">from</span> <span class="nn">labml_nn.neox.utils.llm_int8</span> <span class="kn">import</span> <span class="n">make_llm_int8_linear</span></pre></div>
<div class="highlight"><pre><span class="lineno">554</span> <span class="kn">from</span> <span class="nn">labml_nn.neox.utils.llm_int8</span> <span class="kn">import</span> <span class="n">make_llm_int8_linear</span></pre></div>
</div>
</div>
<div class='section' id='section-123'>
@ -1731,19 +1732,19 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">556</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Convert to int8&#39;</span><span class="p">):</span>
<span class="lineno">557</span> <span class="n">layer</span><span class="o">.</span><span class="n">attention</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">make_llm_int8_linear</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">attention</span><span class="o">.</span><span class="n">output</span><span class="p">,</span>
<span class="lineno">558</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
<span class="lineno">559</span> <span class="n">threshold</span><span class="o">=</span><span class="n">llm_int8_threshold</span><span class="p">)</span>
<span class="lineno">560</span> <span class="n">layer</span><span class="o">.</span><span class="n">attention</span><span class="o">.</span><span class="n">qkv_lin</span> <span class="o">=</span> <span class="n">make_llm_int8_linear</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">attention</span><span class="o">.</span><span class="n">qkv_lin</span><span class="p">,</span>
<span class="lineno">561</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
<span class="lineno">562</span> <span class="n">threshold</span><span class="o">=</span><span class="n">llm_int8_threshold</span><span class="p">)</span>
<span class="lineno">563</span> <span class="n">layer</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">dense_h_h4</span> <span class="o">=</span> <span class="n">make_llm_int8_linear</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">dense_h_h4</span><span class="p">,</span>
<span class="lineno">564</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
<span class="lineno">565</span> <span class="n">threshold</span><span class="o">=</span><span class="n">llm_int8_threshold</span><span class="p">)</span>
<span class="lineno">566</span> <span class="n">layer</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">dense_h4_h</span> <span class="o">=</span> <span class="n">make_llm_int8_linear</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">dense_h4_h</span><span class="p">,</span>
<span class="lineno">567</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
<span class="lineno">568</span> <span class="n">threshold</span><span class="o">=</span><span class="n">llm_int8_threshold</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">557</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Convert to int8&#39;</span><span class="p">):</span>
<span class="lineno">558</span> <span class="n">layer</span><span class="o">.</span><span class="n">attention</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">make_llm_int8_linear</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">attention</span><span class="o">.</span><span class="n">output</span><span class="p">,</span>
<span class="lineno">559</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
<span class="lineno">560</span> <span class="n">threshold</span><span class="o">=</span><span class="n">llm_int8_threshold</span><span class="p">)</span>
<span class="lineno">561</span> <span class="n">layer</span><span class="o">.</span><span class="n">attention</span><span class="o">.</span><span class="n">qkv_lin</span> <span class="o">=</span> <span class="n">make_llm_int8_linear</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">attention</span><span class="o">.</span><span class="n">qkv_lin</span><span class="p">,</span>
<span class="lineno">562</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
<span class="lineno">563</span> <span class="n">threshold</span><span class="o">=</span><span class="n">llm_int8_threshold</span><span class="p">)</span>
<span class="lineno">564</span> <span class="n">layer</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">dense_h_h4</span> <span class="o">=</span> <span class="n">make_llm_int8_linear</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">dense_h_h4</span><span class="p">,</span>
<span class="lineno">565</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
<span class="lineno">566</span> <span class="n">threshold</span><span class="o">=</span><span class="n">llm_int8_threshold</span><span class="p">)</span>
<span class="lineno">567</span> <span class="n">layer</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">dense_h4_h</span> <span class="o">=</span> <span class="n">make_llm_int8_linear</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">dense_h4_h</span><span class="p">,</span>
<span class="lineno">568</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
<span class="lineno">569</span> <span class="n">threshold</span><span class="o">=</span><span class="n">llm_int8_threshold</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-124'>
@ -1755,7 +1756,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">570</span> <span class="k">return</span> <span class="n">layer</span></pre></div>
<div class="highlight"><pre><span class="lineno">571</span> <span class="k">return</span> <span class="n">layer</span></pre></div>
</div>
</div>
<div class='section' id='section-125'>
@ -1773,7 +1774,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">572</span> <span class="k">def</span> <span class="nf">_create_and_cache_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">creator</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[],</span> <span class="n">NeoXModule</span><span class="p">]):</span></pre></div>
<div class="highlight"><pre><span class="lineno">573</span> <span class="k">def</span> <span class="nf">_create_and_cache_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">creator</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[],</span> <span class="n">NeoXModule</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-126'>
@ -1784,14 +1785,14 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">584</span> <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_clone_layers</span><span class="p">:</span>
<span class="lineno">585</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_layer</span><span class="p">(</span><span class="n">creator</span><span class="p">())</span>
<span class="lineno">586</span>
<span class="lineno">587</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_created_layers</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">588</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_created_layers</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_layer</span><span class="p">(</span><span class="n">creator</span><span class="p">())</span>
<span class="lineno">589</span>
<span class="lineno">590</span> <span class="n">layer</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">pre_created_layers</span><span class="p">[</span><span class="n">name</span><span class="p">])</span>
<span class="lineno">591</span> <span class="k">return</span> <span class="n">layer</span></pre></div>
<div class="highlight"><pre><span class="lineno">585</span> <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_clone_layers</span><span class="p">:</span>
<span class="lineno">586</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_layer</span><span class="p">(</span><span class="n">creator</span><span class="p">())</span>
<span class="lineno">587</span>
<span class="lineno">588</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_created_layers</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">589</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_created_layers</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_layer</span><span class="p">(</span><span class="n">creator</span><span class="p">())</span>
<span class="lineno">590</span>
<span class="lineno">591</span> <span class="n">layer</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">pre_created_layers</span><span class="p">[</span><span class="n">name</span><span class="p">])</span>
<span class="lineno">592</span> <span class="k">return</span> <span class="n">layer</span></pre></div>
</div>
</div>
<div class='section' id='section-127'>
@ -1802,11 +1803,11 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">593</span> <span class="k">def</span> <span class="nf">_create_transformer_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">594</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_create_and_cache_layer</span><span class="p">(</span>
<span class="lineno">595</span> <span class="s1">&#39;transformer_layer&#39;</span><span class="p">,</span>
<span class="lineno">596</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">TransformerLayer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">)</span>
<span class="lineno">597</span> <span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">594</span> <span class="k">def</span> <span class="nf">_create_transformer_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">595</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_create_and_cache_layer</span><span class="p">(</span>
<span class="lineno">596</span> <span class="s1">&#39;transformer_layer&#39;</span><span class="p">,</span>
<span class="lineno">597</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">TransformerLayer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">)</span>
<span class="lineno">598</span> <span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-128'>
@ -1817,8 +1818,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">599</span> <span class="k">def</span> <span class="nf">_create_embedding_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">600</span> <span class="k">return</span> <span class="n">Embedding</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_vocab</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">600</span> <span class="k">def</span> <span class="nf">_create_embedding_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">601</span> <span class="k">return</span> <span class="n">Embedding</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_vocab</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-129'>
@ -1829,8 +1830,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">602</span> <span class="k">def</span> <span class="nf">_create_final_norm_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">603</span> <span class="k">return</span> <span class="n">FinalNorm</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">603</span> <span class="k">def</span> <span class="nf">_create_final_norm_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">604</span> <span class="k">return</span> <span class="n">FinalNorm</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-130'>
@ -1841,8 +1842,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">605</span> <span class="k">def</span> <span class="nf">_create_readout_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">606</span> <span class="k">return</span> <span class="n">ReadoutLayer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_vocab</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">606</span> <span class="k">def</span> <span class="nf">_create_readout_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">607</span> <span class="k">return</span> <span class="n">ReadoutLayer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_vocab</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-131'>
@ -1854,8 +1855,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">608</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">609</span> <span class="k">def</span> <span class="nf">get_layers</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Generator</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">NeoXModule</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">str</span><span class="p">]],</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">]:</span></pre></div>
<div class="highlight"><pre><span class="lineno">609</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">610</span> <span class="k">def</span> <span class="nf">get_layers</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Generator</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">NeoXModule</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">str</span><span class="p">]],</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-132'>
@ -1867,10 +1868,10 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">614</span> <span class="k">if</span> <span class="mi">0</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter_layers</span><span class="p">:</span>
<span class="lineno">615</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Embedding layer&#39;</span><span class="p">):</span>
<span class="lineno">616</span> <span class="n">layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_layer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_create_embedding_layer</span><span class="p">())</span>
<span class="lineno">617</span> <span class="k">yield</span> <span class="n">layer</span><span class="p">,</span> <span class="p">(</span><span class="s1">&#39;layer_00-model_00-model_states.pt&#39;</span><span class="p">,</span> <span class="s1">&#39;layer_00-model_01-model_states.pt&#39;</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">615</span> <span class="k">if</span> <span class="mi">0</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter_layers</span><span class="p">:</span>
<span class="lineno">616</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Embedding layer&#39;</span><span class="p">):</span>
<span class="lineno">617</span> <span class="n">layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_layer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_create_embedding_layer</span><span class="p">())</span>
<span class="lineno">618</span> <span class="k">yield</span> <span class="n">layer</span><span class="p">,</span> <span class="p">(</span><span class="s1">&#39;layer_00-model_00-model_states.pt&#39;</span><span class="p">,</span> <span class="s1">&#39;layer_00-model_01-model_states.pt&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-133'>
@ -1882,7 +1883,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">620</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">621</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-134'>
@ -1894,11 +1895,11 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">622</span> <span class="k">if</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter_layers</span><span class="p">:</span>
<span class="lineno">623</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Transformer Layer </span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">):</span>
<span class="lineno">624</span> <span class="k">yield</span> <span class="bp">self</span><span class="o">.</span><span class="n">_create_transformer_layer</span><span class="p">(),</span> \
<span class="lineno">625</span> <span class="p">(</span><span class="sa">f</span><span class="s1">&#39;layer_</span><span class="si">{</span><span class="n">i</span> <span class="o">+</span> <span class="mi">2</span> <span class="si">:</span><span class="s1">02d</span><span class="si">}</span><span class="s1">-model_00-model_states.pt&#39;</span><span class="p">,</span>
<span class="lineno">626</span> <span class="sa">f</span><span class="s1">&#39;layer_</span><span class="si">{</span><span class="n">i</span> <span class="o">+</span> <span class="mi">2</span> <span class="si">:</span><span class="s1">02d</span><span class="si">}</span><span class="s1">-model_01-model_states.pt&#39;</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">623</span> <span class="k">if</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter_layers</span><span class="p">:</span>
<span class="lineno">624</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Transformer Layer </span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">):</span>
<span class="lineno">625</span> <span class="k">yield</span> <span class="bp">self</span><span class="o">.</span><span class="n">_create_transformer_layer</span><span class="p">(),</span> \
<span class="lineno">626</span> <span class="p">(</span><span class="sa">f</span><span class="s1">&#39;layer_</span><span class="si">{</span><span class="n">i</span> <span class="o">+</span> <span class="mi">2</span> <span class="si">:</span><span class="s1">02d</span><span class="si">}</span><span class="s1">-model_00-model_states.pt&#39;</span><span class="p">,</span>
<span class="lineno">627</span> <span class="sa">f</span><span class="s1">&#39;layer_</span><span class="si">{</span><span class="n">i</span> <span class="o">+</span> <span class="mi">2</span> <span class="si">:</span><span class="s1">02d</span><span class="si">}</span><span class="s1">-model_01-model_states.pt&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-135'>
@ -1910,10 +1911,10 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">629</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span> <span class="o">+</span> <span class="mi">1</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter_layers</span><span class="p">:</span>
<span class="lineno">630</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Final norm layer&#39;</span><span class="p">):</span>
<span class="lineno">631</span> <span class="n">layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_layer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_create_final_norm_layer</span><span class="p">())</span>
<span class="lineno">632</span> <span class="k">yield</span> <span class="n">layer</span><span class="p">,</span> <span class="p">(</span><span class="s1">&#39;layer_47-model_00-model_states.pt&#39;</span><span class="p">,</span> <span class="s1">&#39;layer_47-model_01-model_states.pt&#39;</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">630</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span> <span class="o">+</span> <span class="mi">1</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter_layers</span><span class="p">:</span>
<span class="lineno">631</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Final norm layer&#39;</span><span class="p">):</span>
<span class="lineno">632</span> <span class="n">layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_layer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_create_final_norm_layer</span><span class="p">())</span>
<span class="lineno">633</span> <span class="k">yield</span> <span class="n">layer</span><span class="p">,</span> <span class="p">(</span><span class="s1">&#39;layer_47-model_00-model_states.pt&#39;</span><span class="p">,</span> <span class="s1">&#39;layer_47-model_01-model_states.pt&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-136'>
@ -1925,13 +1926,13 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">635</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span> <span class="o">+</span> <span class="mi">2</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter_layers</span><span class="p">:</span>
<span class="lineno">636</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Readout layer&#39;</span><span class="p">):</span>
<span class="lineno">637</span> <span class="n">layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_layer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_create_readout_layer</span><span class="p">())</span>
<span class="lineno">638</span> <span class="k">yield</span> <span class="n">layer</span><span class="p">,</span> <span class="p">(</span><span class="s1">&#39;layer_48-model_00-model_states.pt&#39;</span><span class="p">,</span> <span class="s1">&#39;layer_48-model_01-model_states.pt&#39;</span><span class="p">)</span>
<span class="lineno">639</span>
<span class="lineno">640</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_created_layers</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="lineno">641</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_created_layers</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
<div class="highlight"><pre><span class="lineno">636</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span> <span class="o">+</span> <span class="mi">2</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter_layers</span><span class="p">:</span>
<span class="lineno">637</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Readout layer&#39;</span><span class="p">):</span>
<span class="lineno">638</span> <span class="n">layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_layer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_create_readout_layer</span><span class="p">())</span>
<span class="lineno">639</span> <span class="k">yield</span> <span class="n">layer</span><span class="p">,</span> <span class="p">(</span><span class="s1">&#39;layer_48-model_00-model_states.pt&#39;</span><span class="p">,</span> <span class="s1">&#39;layer_48-model_01-model_states.pt&#39;</span><span class="p">)</span>
<span class="lineno">640</span>
<span class="lineno">641</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_created_layers</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="lineno">642</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_created_layers</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-137'>
@ -1943,8 +1944,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">643</span> <span class="nd">@property</span>
<span class="lineno">644</span> <span class="k">def</span> <span class="nf">total_layers</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">644</span> <span class="nd">@property</span>
<span class="lineno">645</span> <span class="k">def</span> <span class="nf">total_layers</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-138'>
@ -1955,7 +1956,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">648</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span> <span class="o">+</span> <span class="mi">3</span></pre></div>
<div class="highlight"><pre><span class="lineno">649</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span> <span class="o">+</span> <span class="mi">3</span></pre></div>
</div>
</div>
<div class='section' id='section-139'>
@ -1967,8 +1968,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">650</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">651</span> <span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Generator</span><span class="p">[</span><span class="n">NeoXModule</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">]:</span></pre></div>
<div class="highlight"><pre><span class="lineno">651</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">652</span> <span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Generator</span><span class="p">[</span><span class="n">NeoXModule</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-140'>
@ -1979,15 +1980,15 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">655</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s2">&quot;Layers&quot;</span><span class="p">):</span>
<span class="lineno">656</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">files</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_layers</span><span class="p">()):</span>
<span class="lineno">657</span> <span class="k">if</span> <span class="n">files</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">658</span> <span class="n">layer</span><span class="o">.</span><span class="n">load_state</span><span class="p">(</span><span class="o">*</span><span class="n">checkpoint</span><span class="o">.</span><span class="n">load_checkpoint_files</span><span class="p">(</span><span class="n">files</span><span class="p">))</span>
<span class="lineno">659</span>
<span class="lineno">660</span> <span class="n">layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">post_load_prepare</span><span class="p">(</span><span class="n">layer</span><span class="p">)</span>
<span class="lineno">661</span>
<span class="lineno">662</span> <span class="n">monit</span><span class="o">.</span><span class="n">progress</span><span class="p">(</span><span class="nb">min</span><span class="p">(</span><span class="mf">0.99</span><span class="p">,</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_layers</span><span class="p">))</span>
<span class="lineno">663</span> <span class="k">yield</span> <span class="n">layer</span></pre></div>
<div class="highlight"><pre><span class="lineno">656</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s2">&quot;Layers&quot;</span><span class="p">):</span>
<span class="lineno">657</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">files</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_layers</span><span class="p">()):</span>
<span class="lineno">658</span> <span class="k">if</span> <span class="n">files</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">659</span> <span class="n">layer</span><span class="o">.</span><span class="n">load_state</span><span class="p">(</span><span class="o">*</span><span class="n">checkpoint</span><span class="o">.</span><span class="n">load_checkpoint_files</span><span class="p">(</span><span class="n">files</span><span class="p">))</span>
<span class="lineno">660</span>
<span class="lineno">661</span> <span class="n">layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">post_load_prepare</span><span class="p">(</span><span class="n">layer</span><span class="p">)</span>
<span class="lineno">662</span>
<span class="lineno">663</span> <span class="n">monit</span><span class="o">.</span><span class="n">progress</span><span class="p">(</span><span class="nb">min</span><span class="p">(</span><span class="mf">0.99</span><span class="p">,</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_layers</span><span class="p">))</span>
<span class="lineno">664</span> <span class="k">yield</span> <span class="n">layer</span></pre></div>
</div>
</div>
<div class='footer'>

View File

@ -77,20 +77,18 @@
</div>
<h1>Generate Text with GPT-NeoX using LLM.int8() quantization</h1>
<p>This shows how to generate text from GPT-NeoX using <a href="../utils/llm_int8.html">LLM.int8() quantization</a>.</p>
<p>This needs a GPU with more than 45GB memory.</p>
<p>This needs a GPU with 24GB memory.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">15</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span>
<span class="lineno">16</span>
<span class="lineno">17</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">19</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">monit</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml_nn.neox.model</span> <span class="kn">import</span> <span class="n">LayerGenerator</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_nn.neox.samples.generate</span> <span class="kn">import</span> <span class="n">PROMPT</span><span class="p">,</span> <span class="n">infer</span>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">labml_nn.neox.utils</span> <span class="kn">import</span> <span class="n">get_tokens</span><span class="p">,</span> <span class="n">print_tokens</span>
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml_nn.neox.utils.cache</span> <span class="kn">import</span> <span class="n">get_cache</span></pre></div>
<div class="highlight"><pre><span class="lineno">15</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">17</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">monit</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml_nn.neox.model</span> <span class="kn">import</span> <span class="n">LayerGenerator</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_nn.neox.samples.generate</span> <span class="kn">import</span> <span class="n">PROMPT</span><span class="p">,</span> <span class="n">infer</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml_nn.neox.utils</span> <span class="kn">import</span> <span class="n">get_tokens</span><span class="p">,</span> <span class="n">print_tokens</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_nn.neox.utils.cache</span> <span class="kn">import</span> <span class="n">get_cache</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
@ -102,7 +100,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">27</span><span class="k">def</span> <span class="nf">generate</span><span class="p">():</span></pre></div>
<div class="highlight"><pre><span class="lineno">25</span><span class="k">def</span> <span class="nf">generate</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
@ -114,8 +112,8 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">33</span> <span class="n">cache</span> <span class="o">=</span> <span class="n">get_cache</span><span class="p">()</span>
<span class="lineno">34</span> <span class="n">cache</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="s1">&#39;use_cache&#39;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">31</span> <span class="n">cache</span> <span class="o">=</span> <span class="n">get_cache</span><span class="p">()</span>
<span class="lineno">32</span> <span class="n">cache</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="s1">&#39;use_cache&#39;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
@ -127,7 +125,7 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">37</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda:0&#39;</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">35</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda:0&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
@ -139,9 +137,12 @@
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">42</span> <span class="n">layer_generator</span> <span class="o">=</span> <span class="n">LayerGenerator</span><span class="p">(</span><span class="n">is_clone_layers</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="lineno">43</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span>
<span class="lineno">44</span> <span class="n">device</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cpu&#39;</span><span class="p">),</span></pre></div>
<div class="highlight"><pre><span class="lineno">40</span> <span class="n">layer_generator</span> <span class="o">=</span> <span class="n">LayerGenerator</span><span class="p">(</span><span class="n">is_clone_layers</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="lineno">41</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span>
<span class="lineno">42</span> <span class="n">device</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cpu&#39;</span><span class="p">),</span>
<span class="lineno">43</span> <span class="n">is_llm_int8</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="lineno">44</span> <span class="p">)</span>
<span class="lineno">45</span> <span class="n">layers</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">layer_generator</span><span class="o">.</span><span class="n">load</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
@ -149,12 +150,17 @@
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>is_llm_int8=True, </p>
<p>This reduces CUDA memory fragmentation </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">46</span> <span class="p">)</span>
<span class="lineno">47</span> <span class="n">layers</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">layer_generator</span><span class="o">.</span><span class="n">load</span><span class="p">())</span></pre></div>
<div class="highlight"><pre><span class="lineno">48</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">iterate</span><span class="p">(</span><span class="s1">&#39;Convert to int8&#39;</span><span class="p">,</span> <span class="n">layers</span><span class="p">,</span> <span class="n">is_children_silent</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="lineno">49</span> <span class="n">layer_generator</span><span class="o">.</span><span class="n">post_load_prepare</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span>
<span class="lineno">50</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
<span class="lineno">51</span> <span class="n">is_llm_int8</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="lineno">52</span> <span class="n">llm_int8_threshold</span><span class="o">=</span><span class="mf">6.0</span><span class="p">,</span>
<span class="lineno">53</span> <span class="p">)</span>
<span class="lineno">54</span> <span class="n">layer</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
@ -162,17 +168,12 @@
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>This reduces CUDA memory fragmentation </p>
<p>Create <code class="highlight"><span></span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span></code>
model </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">50</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">iterate</span><span class="p">(</span><span class="s1">&#39;Convert to int8&#39;</span><span class="p">,</span> <span class="n">layers</span><span class="p">,</span> <span class="n">is_children_silent</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="lineno">51</span> <span class="n">layer_generator</span><span class="o">.</span><span class="n">post_load_prepare</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span>
<span class="lineno">52</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
<span class="lineno">53</span> <span class="n">is_llm_int8</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="lineno">54</span> <span class="n">llm_int8_threshold</span><span class="o">=</span><span class="mf">6.0</span><span class="p">,</span>
<span class="lineno">55</span> <span class="p">)</span>
<span class="lineno">56</span> <span class="n">layer</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">57</span> <span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
@ -180,12 +181,12 @@
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Create <code class="highlight"><span></span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span></code>
model </p>
<p>Clear cache and print memory summary for debugging </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">59</span> <span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">60</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">empty_cache</span><span class="p">()</span>
<span class="lineno">61</span> <span class="nb">print</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">memory_summary</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
@ -193,12 +194,11 @@
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Clear cache and print memory summary for debugging </p>
<p>Get token ids </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">empty_cache</span><span class="p">()</span>
<span class="lineno">63</span> <span class="nb">print</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">memory_summary</span><span class="p">())</span></pre></div>
<div class="highlight"><pre><span class="lineno">64</span> <span class="n">ids</span> <span class="o">=</span> <span class="n">get_tokens</span><span class="p">(</span><span class="n">PROMPT</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
@ -206,11 +206,15 @@
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Get token ids </p>
<p>Run the model. We use the <a href="generate.html"><code class="highlight"><span></span><span class="n">infer</span></code>
</a> function defined in <a href="generate.html"><code class="highlight"><span></span><span class="n">generate</span><span class="o">.</span><span class="n">py</span></code>
</a> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">66</span> <span class="n">ids</span> <span class="o">=</span> <span class="n">get_tokens</span><span class="p">(</span><span class="n">PROMPT</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">68</span> <span class="n">cache</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="s1">&#39;state_ids&#39;</span><span class="p">,</span> <span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="lineno">69</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Infer&#39;</span><span class="p">):</span>
<span class="lineno">70</span> <span class="n">next_token</span> <span class="o">=</span> <span class="n">infer</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">ids</span><span class="p">,</span> <span class="n">device</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
@ -218,13 +222,11 @@
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Run the model </p>
<p>Append the predicted token </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">69</span> <span class="n">cache</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="s1">&#39;state_ids&#39;</span><span class="p">,</span> <span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="lineno">70</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Infer&#39;</span><span class="p">):</span>
<span class="lineno">71</span> <span class="n">next_token</span> <span class="o">=</span> <span class="n">infer</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">ids</span><span class="p">,</span> <span class="n">device</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">73</span> <span class="n">ids</span> <span class="o">+=</span> <span class="p">[</span><span class="n">next_token</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
@ -232,11 +234,11 @@
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Append the predicted token </p>
<p>Predict 100 tokens </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">74</span> <span class="n">ids</span> <span class="o">+=</span> <span class="p">[</span><span class="n">next_token</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">76</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">100</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
@ -244,11 +246,11 @@
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Predict 100 tokens </p>
<p>Set the state to use cached activations </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">77</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">100</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">78</span> <span class="n">cache</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="s1">&#39;state_ids&#39;</span><span class="p">,</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
@ -256,11 +258,12 @@
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Set the state to use cached activations </p>
<p>Get next token. Note that we only feed the last token to the model because we cache the key/value pairs of previous tokens. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span> <span class="n">cache</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="s1">&#39;state_ids&#39;</span><span class="p">,</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">81</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Infer&#39;</span><span class="p">):</span>
<span class="lineno">82</span> <span class="n">next_token</span> <span class="o">=</span> <span class="n">infer</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="p">[</span><span class="n">next_token</span><span class="p">],</span> <span class="n">device</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
@ -268,12 +271,11 @@
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Get next token. Note that we only feed the last token to the model because we cache the key/value pairs of previous tokens. </p>
<p>Append the predicted token </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Infer&#39;</span><span class="p">):</span>
<span class="lineno">83</span> <span class="n">next_token</span> <span class="o">=</span> <span class="n">infer</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="p">[</span><span class="n">next_token</span><span class="p">],</span> <span class="n">device</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">84</span> <span class="n">ids</span> <span class="o">+=</span> <span class="p">[</span><span class="n">next_token</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
@ -281,11 +283,11 @@
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Append the predicted token </p>
<p>Print </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">85</span> <span class="n">ids</span> <span class="o">+=</span> <span class="p">[</span><span class="n">next_token</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">86</span> <span class="n">print_tokens</span><span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="p">[</span><span class="n">ids</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
@ -293,24 +295,12 @@
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Print </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">87</span> <span class="n">print_tokens</span><span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="p">[</span><span class="n">ids</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">92</span> <span class="n">generate</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">90</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">91</span> <span class="n">generate</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>

View File

@ -1,19 +1,39 @@
"""
---
title: Evaluate GPT-NeoX in half precision (float16) on test suite
summary: >
Evaluate GPT-NeoX in half precision (float16) on test suite
---
# Evaluate GPT-NeoX in half precision (float16) on test suite
This code evaluates [GPT-NeoX](../index.html) in half precision (float16) on a suite of tasks.
"""
import torch
from torch import nn
from labml import monit
from labml_nn.neox.evaluation import run_eval_harness
from labml_nn.neox.model import LayerGenerator
if __name__ == '__main__':
def main():
# Device
device = torch.device('cuda:0')
# Load layers
layers = list(LayerGenerator(is_clone_layers=True,
filter_layers=None,
dtype=torch.float16,
device=device
).load())
with monit.section('Sequential'):
model = nn.Sequential(*layers)
# Create `nn.Sequential` model
model = nn.Sequential(*layers)
# Run [evaluation harness](index.html)
print(run_eval_harness(model, 'half_precision', ['lambada'], device))
#
if __name__ == '__main__':
main()

View File

@ -1,3 +1,16 @@
"""
---
title: Evaluate GPT-NeoX using LLM.int8() quantization on test suite
summary: >
Evaluate GPT-NeoX using LLM.int8() quantization on test suite
---
# Evaluate GPT-NeoX using LLM.int8() quantization on test suite
This code evaluates [GPT-NeoX](../index.html) using [LLM.int8() quantization](../utils/llm_int8.html)
on a suite of tasks.
"""
import torch
from torch import nn
@ -5,8 +18,14 @@ from labml import monit
from labml_nn.neox.evaluation import run_eval_harness
from labml_nn.neox.model import LayerGenerator
if __name__ == '__main__':
def main():
# Device
device = torch.device('cuda:0')
# Load layers in float16 into CPU. We convert the layers to int8 later, because doing that
# on the fly after loading layers to GPU causes CUDA memory fragmentation
# (about 3GB memory can get lost due to fragmentation).
layer_generator = LayerGenerator(is_clone_layers=True,
dtype=torch.float16,
device=torch.device('cpu'),
@ -23,7 +42,13 @@ if __name__ == '__main__':
)
layer.to(device)
with monit.section('Sequential'):
model = nn.Sequential(*layers)
# Create `nn.Sequential` model
model = nn.Sequential(*layers)
# Run [evaluation harness](index.html)
print(run_eval_harness(model, 'half_precision', [], device))
#
if __name__ == '__main__':
main()

View File

@ -520,6 +520,7 @@ class LayerGenerator:
):
"""
<a id="post_load_prepare"></a>
### Layer transformations after loading the checkpoint
This function implements layer transformations after loading the checkpoint.

View File

@ -9,11 +9,9 @@ summary: >
This shows how to generate text from GPT-NeoX using [LLM.int8() quantization](../utils/llm_int8.html).
This needs a GPU with more than 45GB memory.
This needs a GPU with 24GB memory.
"""
from typing import List
import torch
from torch import nn
@ -42,7 +40,7 @@ def generate():
layer_generator = LayerGenerator(is_clone_layers=True,
dtype=torch.float16,
device=torch.device('cpu'),
# is_llm_int8=True,
is_llm_int8=False,
)
layers = list(layer_generator.load())
@ -65,7 +63,8 @@ def generate():
# Get token ids
ids = get_tokens(PROMPT)
# Run the model
# Run the model.
# We use the [`infer`](generate.html) function defined in [`generate.py`](generate.html)
cache.set('state_ids', (None, 1))
with monit.section('Infer'):
next_token = infer(model, ids, device)[-1]