<!--
Files
Varuna Jayasiri 2038b11d29 ja translation
2023-05-10 17:00:29 -04:00

2158 lines
207 KiB
HTML
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
-->

<!DOCTYPE html>
<html lang="ja">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="これは GPT-NeoX のモデル定義です。"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="GPT-NeoX モデル定義"/>
<meta name="twitter:description" content="これは GPT-NeoX のモデル定義です。"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/neox/model.html"/>
<meta property="og:title" content="GPT-NeoX モデル定義"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="GPT-NeoX モデル定義"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="GPT-NeoX モデル定義"/>
<meta property="og:description" content="これは GPT-NeoX のモデル定義です。"/>
<title>GPT-NeoX モデル定義</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/neox/model.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
// Google Analytics (gtag.js) bootstrap. Make sure the global `dataLayer`
// command queue exists before the async gtag.js script loaded above
// finishes downloading; queued commands are picked up once it runs.
window.dataLayer = window.dataLayer || [];
// Queue one gtag command by pushing this call's Arguments object.
// NOTE(review): gtag.js is believed to rely on receiving the Arguments
// object itself (not an array copy), so keep this exact form rather than
// switching to rest parameters — confirm against the gtag.js docs first.
function gtag() {
dataLayer.push(arguments);
}
// Record the load timestamp, then configure measurement ID G-4V3HC8HBLH.
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">neox</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/neox/model.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>GPT-NeoX モデル</h1>
<p>これは、GPT-NeoX モデルの各レイヤーのコードと、20B のチェックポイントをロードするコードです。</p>
<p>各レイヤーの <code class="highlight"><span></span><span class="n">load_state</span></code>
メソッドは、そのレイヤーのチェックポイントをロードします。チェックポイントをロードするヘルパーは <a href="checkpoint.html"><code class="highlight"><span></span><span class="n">checkpoint</span><span class="o">.</span><span class="n">py</span></code>
</a> にあります。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">16</span><span></span><span class="kn">import</span> <span class="nn">copy</span>
<span class="lineno">17</span><span class="kn">import</span> <span class="nn">math</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Set</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Generator</span><span class="p">,</span> <span class="n">Tuple</span>
<span class="lineno">19</span>
<span class="lineno">20</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">torch.cuda.amp</span> <span class="kn">import</span> <span class="n">autocast</span>
<span class="lineno">23</span>
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">monit</span><span class="p">,</span> <span class="n">logger</span>
<span class="lineno">25</span><span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">Text</span>
<span class="lineno">26</span><span class="kn">from</span> <span class="nn">labml_nn.neox</span> <span class="kn">import</span> <span class="n">checkpoint</span>
<span class="lineno">27</span><span class="kn">from</span> <span class="nn">labml_nn.neox.utils.cache</span> <span class="kn">import</span> <span class="n">get_cache</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">30</span><span class="k">class</span> <span class="nc">NeoXModule</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">31</span> <span class="k">def</span> <span class="nf">load_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p1</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">p2</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span>
<span class="lineno">32</span> <span class="k">pass</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<h2>埋め込みレイヤー</h2>
<p>これは、チェックポイントをロードするコードを含む標準の埋め込みレイヤーです。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">35</span><span class="k">class</span> <span class="nc">Embedding</span><span class="p">(</span><span class="n">NeoXModule</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">n_vocab</span></code>
は語彙のサイズです</li>
<li><code class="highlight"><span></span><span class="n">n_hidden</span></code>
は埋め込みのサイズです</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">42</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">50_432</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">6_144</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">48</span>
<span class="lineno">49</span> <span class="bp">self</span><span class="o">.</span><span class="n">emb</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">n_vocab</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
は形状 <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">]</span></code>
のトークン ID です</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">51</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">55</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">emb</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>チェックポイントをロードするコード</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">57</span> <span class="k">def</span> <span class="nf">load_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p1</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">p2</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">61</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Load embedding layer&#39;</span><span class="p">):</span>
<span class="lineno">62</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">merge_params_dim_0</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">emb</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="s1">&#39;word_embeddings.weight&#39;</span><span class="p">,</span> <span class="n">p1</span><span class="p">,</span> <span class="n">p2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<h2>ロータリーポジショナルエンベディング</h2>
<p>GPT-NeoX は<a href="https://papers.labml.ai/paper/2104.09864">回転位置埋め込み（RoPE）</a>を使用しています。</p>
<p>理論に関する詳しい注釈付きの RoPE の実装は<a href="https://nn.labml.ai/transformers/rope/index.html">こちら</a>にあります。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span><span class="k">class</span> <span class="nc">RoPE</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">d_rope</span></code>
は RoPE 埋め込みを適用する特徴量の数です</li>
<li><code class="highlight"><span></span><span class="n">base</span></code>
は <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.27379em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqk" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eql" style="">10000</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:1.12379em;"><span style="top:-3.3973400000000002em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0377857142857143em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" 
style="">d</span></span></span></span><span style="top:-3.2255000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em"></span></span><span style="top:-3.5020714285714285em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">2</span><span class="mopen mtight" style="">(</span><span class="mord mathnormal mtight" style="">i</span><span class="mbin mtight" style="">−</span><span class="mord mtight" style="">1</span><span class="mclose mtight" style="">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.344em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span> の基底で、デフォルトは <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eql" style=""><span class="mord" style="">10000</span></span></span></span></span></span> です</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">75</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_rope</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">base</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10_000.</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">80</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>特徴量ごとの <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqk" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> を保存します</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">83</span> <span class="bp">self</span><span class="o">.</span><span class="n">theta</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqf" style=""><span class="mop" style=""><span style="">c</span><span style="">o</span><span style="">s</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="">m</span></span><span class="mord" style=""><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqk" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></span> と <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqg" style=""><span class="mop" style=""><span style="">s</span><span style="">i</span><span style="">n</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="">m</span></span><span class="mord" style=""><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqk" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span 
class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></span> のキャッシュ</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">85</span> <span class="bp">self</span><span class="o">.</span><span class="n">cos_cached</span> <span class="o">=</span> <span class="kc">None</span>
<span class="lineno">86</span> <span class="bp">self</span><span class="o">.</span><span class="n">sin_cached</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.27379em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqk" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eql" style="">10000</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:1.12379em;"><span style="top:-3.3973400000000002em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0377857142857143em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" 
style="">d</span></span></span></span><span style="top:-3.2255000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em"></span></span><span style="top:-3.5020714285714285em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">2</span><span class="mopen mtight" style="">(</span><span class="mord mathnormal mtight" style="">i</span><span class="mbin mtight" style="">−</span><span class="mord mtight" style="">1</span><span class="mclose mtight" style="">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.344em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span> の基底</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">89</span> <span class="bp">self</span><span class="o">.</span><span class="n">base</span> <span class="o">=</span> <span class="n">base</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>RoPE を適用する特徴量の数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_rope</span> <span class="o">=</span> <span class="n">d_rope</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<h3>特徴量の回転</h3>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.22902em;vertical-align:-0.25em;"></span><span class="mopen">[</span><span class="mord">−</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.97902em;"><span style="top:-3.363em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight"><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8800285714285714em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight">2</span></span></span></span><span style="top:-3.2255000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em;"></span></span><span style="top:-3.384em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">d</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.344em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span><span class="mbin mtight">+</span><span class="mord mtight">1</span><span class="mclose mtight">)</span></span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">−</span><span class="mord"><span class="mord mathnormal">x</span><span 
class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.97902em;"><span style="top:-3.363em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight"><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8800285714285714em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight">2</span></span></span></span><span style="top:-3.2255000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em;"></span></span><span style="top:-3.384em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">d</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.344em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span><span class="mbin mtight">+</span><span class="mord mtight">2</span><span class="mclose mtight">)</span></span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">...</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">−</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8879999999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span 
class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mathnormal mtight">d</span><span class="mclose mtight">)</span></span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8879999999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight">1</span><span class="mclose mtight">)</span></span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8879999999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight">2</span><span class="mclose mtight">)</span></span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">...</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.97902em;"><span style="top:-3.363em;margin-right:0.05em;"><span 
class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight"><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8800285714285714em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight">2</span></span></span></span><span style="top:-3.2255000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em;"></span></span><span style="top:-3.384em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">d</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.344em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span><span class="mclose mtight">)</span></span></span></span></span></span></span></span></span><span class="mclose">]</span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">93</span> <span class="nd">@staticmethod</span>
<span class="lineno">94</span> <span class="k">def</span> <span class="nf">rotate_half</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">100</span> <span class="n">x1</span><span class="p">,</span> <span class="n">x2</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">:</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">//</span> <span class="mi">2</span><span class="p">],</span> <span class="n">x</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">//</span> <span class="mi">2</span><span class="p">:]</span>
<span class="lineno">101</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="o">-</span><span class="n">x2</span><span class="p">,</span> <span class="n">x1</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
の形状は <code class="highlight"><span></span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">]</span></code> です
</li>
<li><code class="highlight"><span></span><span class="n">offset</span></code>
は <code class="highlight"><span></span><span class="n">x</span></code>
の開始位置です。以前の位置のキーとクエリをキャッシュしている場合、この値は <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.5782em;vertical-align:-0.0391em;"></span><span class="mrel">&gt;</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">0</span></span></span></span></span> になります</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">103</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">offset</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>実際のシーケンス長を取得</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">111</span> <span class="n">seq_len</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">3</span><span class="p">]</span> <span class="o">+</span> <span class="n">offset</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqk" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span></span></span></span></span></span> を初期化</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">114</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">theta</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.27379em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqk" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord" style=""><span class="mord" style=""><span class="mord coloredeq eql" style="">10000</span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:1.12379em;"><span style="top:-3.3973400000000002em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">−</span><span class="mord mtight" style=""><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0377857142857143em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" 
style="">d</span></span></span></span><span style="top:-3.2255000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em"></span></span><span style="top:-3.5020714285714285em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">2</span><span class="mopen mtight" style="">(</span><span class="mord mathnormal mtight" style="">i</span><span class="mbin mtight" style="">−</span><span class="mord mtight" style="">1</span><span class="mclose mtight" style="">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.344em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="n">theta</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">base</span> <span class="o">**</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_rope</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_rope</span><span class="p">))</span>
<span class="lineno">117</span> <span class="bp">self</span><span class="o">.</span><span class="n">theta</span> <span class="o">=</span> <span class="n">theta</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqf" style=""><span class="mop" style=""><span style="">c</span><span style="">o</span><span style="">s</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="">m</span></span><span class="mord" style=""><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqk" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></span> と <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqg" style=""><span class="mop" style=""><span style="">s</span><span style="">i</span><span style="">n</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="">m</span></span><span class="mord" style=""><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqk" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span 
class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></span> を初期化してキャッシュ</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">120</span> <span class="k">if</span> <span class="p">(</span>
<span class="lineno">121</span> <span class="bp">self</span><span class="o">.</span><span class="n">cos_cached</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span>
<span class="lineno">122</span> <span class="n">seq_len</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">cos_cached</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="ow">or</span>
<span class="lineno">123</span> <span class="bp">self</span><span class="o">.</span><span class="n">cos_cached</span><span class="o">.</span><span class="n">device</span> <span class="o">!=</span> <span class="n">x</span><span class="o">.</span><span class="n">device</span> <span class="ow">or</span>
<span class="lineno">124</span> <span class="bp">self</span><span class="o">.</span><span class="n">cos_cached</span><span class="o">.</span><span class="n">dtype</span> <span class="o">!=</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span>
<span class="lineno">125</span> <span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>位置インデックス <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="">m</span></span></span></span></span></span> を取得</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">127</span> <span class="n">seq_idx</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">type_as</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">theta</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="">m</span></span><span class="mord" style=""><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqk" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">129</span> <span class="n">idx_theta</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">&quot;s,d-&gt;sd&quot;</span><span class="p">,</span> <span class="n">seq_idx</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">theta</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>行 <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="">m</span></span></span></span></span></span> が次のようになるように連結します</p>
<p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.26202em;vertical-align:-0.5120199999999999em;"></span><span class="mopen">[</span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="">m</span></span><span class="mord"><span class="mord coloredeq eqk" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">0</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="">m</span></span><span class="mord"><span class="mord coloredeq eqk" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">...</span><span class="mpunct">,</span><span class="mspace" 
style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="">m</span></span><span class="mord"><span class="mord coloredeq eqk" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.7287800000000004em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8800285714285714em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight">2</span></span></span></span><span style="top:-3.2255000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em;"></span></span><span style="top:-3.384em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">d</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.344em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.5120199999999999em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" 
style="">m</span></span><span class="mord"><span class="mord coloredeq eqk" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">0</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="">m</span></span><span class="mord"><span class="mord coloredeq eqk" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">...</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqn" style=""><span class="mord mathnormal" style="">m</span></span><span class="mord"><span class="mord coloredeq eqk" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t 
vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.7287800000000004em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8800285714285714em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight">2</span></span></span></span><span style="top:-3.2255000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em;"></span></span><span style="top:-3.384em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">d</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.344em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.5120199999999999em;"><span></span></span></span></span></span></span><span class="mclose">]</span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">133</span> <span class="n">idx_theta2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">idx_theta</span><span class="p">,</span> <span class="n">idx_theta</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqf" style=""><span class="mop" style=""><span style="">c</span><span style="">o</span><span style="">s</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="">m</span></span><span class="mord" style=""><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqk" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></span> と <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqg" style=""><span class="mop" style=""><span style="">s</span><span style="">i</span><span style="">n</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="">m</span></span><span class="mord" style=""><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqk" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span 
class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span></span> を fp32 で計算</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="k">with</span> <span class="n">autocast</span><span class="p">(</span><span class="n">enabled</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="lineno">137</span> <span class="n">idx_theta2</span> <span class="o">=</span> <span class="n">idx_theta2</span><span class="o">.</span><span class="n">float</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>ヘッド次元を追加</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">139</span> <span class="bp">self</span><span class="o">.</span><span class="n">cos_cached</span> <span class="o">=</span> <span class="n">idx_theta2</span><span class="o">.</span><span class="n">cos</span><span class="p">()[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
<span class="lineno">140</span> <span class="bp">self</span><span class="o">.</span><span class="n">sin_cached</span> <span class="o">=</span> <span class="n">idx_theta2</span><span class="o">.</span><span class="n">sin</span><span class="p">()[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:]</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>それらをキャッシュする</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">143</span> <span class="bp">self</span><span class="o">.</span><span class="n">cos_cached</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cos_cached</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="lineno">144</span> <span class="bp">self</span><span class="o">.</span><span class="n">sin_cached</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sin_cached</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>特徴量を分割します。RoPE は最初の <code class="highlight"><span></span><span class="n">d_rope</span></code>
個の特徴量にのみ適用されます</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">147</span> <span class="n">x_rope</span><span class="p">,</span> <span class="n">x_pass</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">d_rope</span><span class="p">],</span> <span class="n">x</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_rope</span><span class="p">:]</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>キャッシュから sin と cos の値を取得</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">150</span> <span class="n">cos</span><span class="p">,</span> <span class="n">sin</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cos_cached</span><span class="p">[</span><span class="n">offset</span><span class="p">:</span> <span class="n">seq_len</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">sin_cached</span><span class="p">[</span><span class="n">offset</span><span class="p">:</span> <span class="n">seq_len</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>RoPE 埋め込み</p>
<span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:3.42324em;vertical-align:-1.4616200000000001em;"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.9616199999999997em;"><span style="top:-3.96162em;"><span class="pstrut" style="height:3.8116199999999996em;"></span><span class="mord"><span class="minner"><span class="mopen delimcenter" style="top:0em;"><span class="delimsizing size4">(</span></span><span class="mord"><span class="mtable"><span class="col-align-c"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.8116199999999998em;"><span style="top:-3.81162em;"><span class="pstrut" style="height:3.20162em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0448em;"><span style="top:-2.5834080000000004em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqn" style=""><span class="mord mathnormal mtight" style="">m</span></span></span></span><span style="top:-3.2198em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mathnormal mtight">i</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.11659199999999997em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">cos</span><span class="mspace" 
style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="">m</span></span><span class="mord" style=""><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqk" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.20162em;"><span style="top:-2.883408em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqn" style=""><span class="mord mathnormal mtight" style="">m</span></span></span></span><span style="top:-3.5856000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mathnormal mtight">i</span><span class="mbin mtight">+</span><span class="mord mtight"><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span 
class="vlist-r"><span class="vlist" style="height:0.8800285714285714em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight">2</span></span></span></span><span style="top:-3.2255000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em;"></span></span><span style="top:-3.384em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">d</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.344em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.11659199999999997em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">sin</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="">m</span></span><span class="mord" style=""><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqk" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span 
class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span><span style="top:-2.25em;"><span class="pstrut" style="height:3.20162em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.20162em;"><span style="top:-2.883408em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqn" style=""><span class="mord mathnormal mtight" style="">m</span></span></span></span><span style="top:-3.5856000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mathnormal mtight">i</span><span class="mbin mtight">+</span><span class="mord mtight"><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8800285714285714em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight">2</span></span></span></span><span style="top:-3.2255000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em;"></span></span><span style="top:-3.384em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">d</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.344em;"><span></span></span></span></span></span><span 
class="mclose nulldelimiter sizing reset-size3 size6"></span></span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.11659199999999997em;"><span></span></span></span></span></span></span><span class="mord coloredeq eqf" style=""><span class="mop" style=""><span style="">c</span><span style="">o</span><span style="">s</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="">m</span></span><span class="mord" style=""><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqk" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0448em;"><span style="top:-2.5834080000000004em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqn" style=""><span class="mord mathnormal mtight" style="">m</span></span></span></span><span 
style="top:-3.2198em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mathnormal mtight">i</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.11659199999999997em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">sin</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqh" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqn" style="">m</span></span><span class="mord" style=""><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord mathnormal coloredeq eqk" style="margin-right:0.02778em">θ</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.3116200000000002em;"><span></span></span></span></span></span></span></span><span class="mclose delimcenter" style="top:0em;"><span class="delimsizing size4">)</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" 
style="height:1.4616200000000001em;"><span></span></span></span></span></span></span></span></span></span></span></span></span><p>ただし <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69862em;vertical-align:-0.0391em;"></span><span class="mord mathnormal">i</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">∈</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.2251079999999999em;vertical-align:-0.345em;"></span><span class="mord"><span class="mord">1</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">2</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">...</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8801079999999999em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">d</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose 
nulldelimiter"></span></span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span> <span class="n">x_rope</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_rope</span> <span class="o">*</span> <span class="n">cos</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">rotate_half</span><span class="p">(</span><span class="n">x_rope</span><span class="p">)</span> <span class="o">*</span> <span class="n">sin</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>RoPE 埋め込みを適用しなかった残りの特徴量と連結します</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">165</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">x_rope</span><span class="p">,</span> <span class="n">x_pass</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<h2>アテンションレイヤー</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span><span class="k">class</span> <span class="nc">AttentionLayer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">n_hidden</span></code>
埋め込みの特徴量の数</li>
<li><code class="highlight"><span></span><span class="n">n_heads</span></code>
アテンションヘッドの数</li>
<li><code class="highlight"><span></span><span class="n">rope_percentage</span></code>
RoPE 埋め込みを適用する特徴量の割合</li>
<li><code class="highlight"><span></span><span class="n">mask_fill</span></code>
アテンション行列のマスクされた位置に埋める値</li>
<li><code class="highlight"><span></span><span class="n">is_flash_attention</span></code>
<a href="https://github.com/HazyResearch/flash-attention">FlashAttention</a> を使用するかどうかを指定します</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">173</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">6_144</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span><span class="p">,</span> <span class="n">rope_percentage</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.25</span><span class="p">,</span>
<span class="lineno">174</span> <span class="n">mask_fill</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">-</span><span class="mf">10_000.0</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <span class="n">is_flash_attention</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">183</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">184</span>
<span class="lineno">185</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">=</span> <span class="n">n_heads</span>
<span class="lineno">186</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_fill</span> <span class="o">=</span> <span class="n">mask_fill</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>クエリ、キー、値の線形レイヤー</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">189</span> <span class="bp">self</span><span class="o">.</span><span class="n">qkv_lin</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_hidden</span><span class="p">,</span> <span class="n">n_hidden</span> <span class="o">*</span> <span class="mi">3</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>最終線形レイヤー</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_hidden</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>ヘッドあたりの特徴量の数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">194</span> <span class="n">d_k</span> <span class="o">=</span> <span class="n">n_hidden</span> <span class="o">//</span> <span class="n">n_heads</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>RoPE 埋め込みモジュール</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">196</span> <span class="bp">self</span><span class="o">.</span><span class="n">rope</span> <span class="o">=</span> <span class="n">RoPE</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">d_k</span> <span class="o">*</span> <span class="n">rope_percentage</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>アテンションスケーリングファクター</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">199</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">d_k</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>因果マスクをキャッシュするための変数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">202</span> <span class="bp">self</span><span class="o">.</span><span class="n">causal_mask</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>アテンションソフトマックスモジュール</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p><a href="https://github.com/HazyResearch/flash-attention">フラッシュ・アテンション</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">208</span> <span class="k">if</span> <span class="n">is_flash_attention</span><span class="p">:</span>
<span class="lineno">209</span> <span class="k">try</span><span class="p">:</span>
<span class="lineno">210</span> <span class="kn">from</span> <span class="nn">flash_attn.flash_attention</span> <span class="kn">import</span> <span class="n">FlashAttention</span>
<span class="lineno">211</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash_attention</span> <span class="o">=</span> <span class="n">FlashAttention</span><span class="p">()</span>
<span class="lineno">212</span> <span class="k">except</span> <span class="ne">ImportError</span><span class="p">:</span>
<span class="lineno">213</span> <span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="s1">&#39;Install flash attention github.com/HazyResearch/flash-attention. &#39;</span>
<span class="lineno">214</span> <span class="s1">&#39;Falling back to normal attention&#39;</span><span class="p">,</span> <span class="n">Text</span><span class="o">.</span><span class="n">warning</span><span class="p">)</span>
<span class="lineno">215</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash_attention</span> <span class="o">=</span> <span class="kc">None</span>
<span class="lineno">216</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">217</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash_attention</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<h4>因果マスクの計算</h4>
<ul><li><code class="highlight"><span></span><span class="n">attn</span></code>
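形状は <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">query_seq_len</span><span class="p">,</span> <span class="n">key_seq_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">]</span></code> です</li></ul>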
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">219</span> <span class="k">def</span> <span class="nf">_get_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">attn</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p>クエリとキーの長さ</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">227</span> <span class="n">nq</span><span class="p">,</span> <span class="n">nk</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">:</span><span class="mi">3</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<p>マスク作成</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">230</span> <span class="k">if</span> <span class="p">(</span>
<span class="lineno">231</span> <span class="bp">self</span><span class="o">.</span><span class="n">causal_mask</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span>
<span class="lineno">232</span> <span class="bp">self</span><span class="o">.</span><span class="n">causal_mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">!=</span> <span class="n">nq</span> <span class="ow">or</span>
<span class="lineno">233</span> <span class="bp">self</span><span class="o">.</span><span class="n">causal_mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">!=</span> <span class="n">nk</span> <span class="ow">or</span>
<span class="lineno">234</span> <span class="bp">self</span><span class="o">.</span><span class="n">causal_mask</span><span class="o">.</span><span class="n">device</span> <span class="o">!=</span> <span class="n">attn</span><span class="o">.</span><span class="n">device</span>
<span class="lineno">235</span> <span class="p">):</span>
<span class="lineno">236</span> <span class="bp">self</span><span class="o">.</span><span class="n">causal_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">triu</span><span class="p">(</span><span class="n">attn</span><span class="o">.</span><span class="n">new_ones</span><span class="p">([</span><span class="n">nq</span><span class="p">,</span> <span class="n">nk</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">),</span> <span class="mi">1</span> <span class="o">+</span> <span class="n">nk</span> <span class="o">-</span> <span class="n">nq</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<p>キャッシュされたマスクを返します</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">239</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">causal_mask</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="kc">None</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
形状は <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">]</span></code> です
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">241</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>クエリ、キー、値の埋め込み (すべて連結) を取得します。最後の次元のサイズは <code class="highlight"><span></span><span class="n">n_hidden</span></code> から <code class="highlight"><span></span><span class="mi">3</span> <span class="n">x</span> <span class="n">n_hidden</span></code> に変わります。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">247</span> <span class="n">qkv</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">qkv_lin</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p>アテンションヘッドごとに分割し、形状を <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">]</span></code> に変更します
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">250</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">qkv</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p>最後の次元を 3 等分してクエリ、キー、値に分割します。それぞれの形状は <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">]</span></code> です
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">252</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="n">qkv</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">//</span> <span class="mi">3</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>以前のトークンの状態をキャッシュする場合</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">255</span> <span class="k">if</span> <span class="n">get_cache</span><span class="p">()</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;use_cache&#39;</span><span class="p">,</span> <span class="kc">False</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<p>ステート ID を取得します。前のステートを取得したり、次のステートを保存したりするのに使います。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">257</span> <span class="n">prev_state_id</span><span class="p">,</span> <span class="n">next_state_id</span> <span class="o">=</span> <span class="n">get_cache</span><span class="p">()</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;state_ids&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='docs'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<p>キャッシュがある場合</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">259</span> <span class="k">if</span> <span class="n">prev_state_id</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
<p>過去のキーと値を取得します。これらの形状は <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">prev_seq_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">]</span></code> です
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">261</span> <span class="n">k_past</span><span class="p">,</span> <span class="n">v_past</span> <span class="o">=</span> <span class="n">get_cache</span><span class="p">()</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;attn_kv_</span><span class="si">{</span><span class="n">prev_state_id</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<p>現在の埋め込みのオフセット</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">263</span> <span class="n">offset</span> <span class="o">=</span> <span class="n">k_past</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<p>RoPE 埋め込みを追加</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">266</span> <span class="n">q</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rope</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">offset</span><span class="o">=</span><span class="n">offset</span><span class="p">)</span>
<span class="lineno">267</span> <span class="n">k</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rope</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">offset</span><span class="o">=</span><span class="n">offset</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<p>過去のキーと値を現在のものと連結します</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">270</span> <span class="n">k</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">k_past</span><span class="p">,</span> <span class="n">k</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">271</span> <span class="n">v</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">v_past</span><span class="p">,</span> <span class="n">v</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">272</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
<p>RoPE 埋め込みを追加</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">274</span> <span class="n">q</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rope</span><span class="p">(</span><span class="n">q</span><span class="p">)</span>
<span class="lineno">275</span> <span class="n">k</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rope</span><span class="p">(</span><span class="n">k</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
<p>現在の状態を保存する</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">278</span> <span class="n">get_cache</span><span class="p">()</span><span class="o">.</span><span class="n">push</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;attn_kv_</span><span class="si">{</span><span class="n">next_state_id</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span> <span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">))</span>
<span class="lineno">279</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
<p>キャッシュがない場合は、RoPE 埋め込みを追加するだけです</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">281</span> <span class="n">q</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rope</span><span class="p">(</span><span class="n">q</span><span class="p">)</span>
<span class="lineno">282</span> <span class="n">k</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rope</span><span class="p">(</span><span class="n">k</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='docs'>
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
<p>フラッシュアテンションが有効で、クエリとキーの長さが等しく、ヘッドサイズが 128 以下の場合はフラッシュアテンションを使用します</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">285</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash_attention</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">k</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="ow">and</span> <span class="n">q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">&lt;=</span> <span class="mi">128</span><span class="p">:</span>
<span class="lineno">286</span> <span class="n">output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">compute_flash_attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-64'>
<div class='docs'>
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
<p>それ以外の場合は、通常のアテンションを計算します</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">288</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">289</span> <span class="n">output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">compute_attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-65'>
<div class='docs'>
<div class='section-link'>
<a href='#section-65'>#</a>
</div>
<p><code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">]</span></code>
から <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">]</span></code> に形状を変更します</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">292</span> <span class="n">output</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">*</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-66'>
<div class='docs'>
<div class='section-link'>
<a href='#section-66'>#</a>
</div>
<p>最終線形レイヤー</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">295</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">output</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-67'>
<div class='docs'>
<div class='section-link'>
<a href='#section-67'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">297</span> <span class="k">def</span> <span class="nf">compute_flash_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">q</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-68'>
<div class='docs'>
<div class='section-link'>
<a href='#section-68'>#</a>
</div>
<p>q、k、v をスタックして、形状 <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">]</span></code>
のテンソルにします。その後、ヘッドサイズが 32、64、128 のいずれかになるように最後の次元をゼロでパディングします。</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">299</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="lineno">300</span> <span class="n">d_k</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="lineno">301</span> <span class="k">if</span> <span class="n">d_k</span> <span class="o">&lt;=</span> <span class="mi">32</span><span class="p">:</span>
<span class="lineno">302</span> <span class="n">pad</span> <span class="o">=</span> <span class="mi">32</span> <span class="o">-</span> <span class="n">d_k</span>
<span class="lineno">303</span> <span class="k">elif</span> <span class="n">d_k</span> <span class="o">&lt;=</span> <span class="mi">64</span><span class="p">:</span>
<span class="lineno">304</span> <span class="n">pad</span> <span class="o">=</span> <span class="mi">64</span> <span class="o">-</span> <span class="n">d_k</span>
<span class="lineno">305</span> <span class="k">elif</span> <span class="n">d_k</span> <span class="o">&lt;=</span> <span class="mi">128</span><span class="p">:</span>
<span class="lineno">306</span> <span class="n">pad</span> <span class="o">=</span> <span class="mi">128</span> <span class="o">-</span> <span class="n">d_k</span>
<span class="lineno">307</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">308</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Head size </span><span class="si">{</span><span class="n">d_k</span><span class="si">}</span><span class="s1"> too large for flash attention&#39;</span><span class="p">)</span>
<span class="lineno">309</span>
<span class="lineno">310</span> <span class="k">if</span> <span class="n">pad</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="lineno">311</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">qkv</span><span class="p">,</span> <span class="n">qkv</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="o">*</span><span class="n">qkv</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">pad</span><span class="p">)),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="lineno">312</span>
<span class="lineno">313</span> <span class="n">output</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">flash_attention</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="n">causal</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-69'>
<div class='docs'>
<div class='section-link'>
<a href='#section-69'>#</a>
</div>
<p>出力の形状は <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span> <span class="o">+</span> <span class="n">padding</span><span class="p">]</span></code>
なので、パディング部分を取り除きます</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">315</span> <span class="n">output</span> <span class="o">=</span> <span class="n">output</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="p">:</span><span class="n">d_k</span><span class="p">]</span>
<span class="lineno">316</span>
<span class="lineno">317</span> <span class="k">return</span> <span class="n">output</span></pre></div>
</div>
</div>
<div class='section' id='section-70'>
<div class='docs'>
<div class='section-link'>
<a href='#section-70'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">319</span> <span class="k">def</span> <span class="nf">compute_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">q</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-71'>
<div class='docs'>
<div class='section-link'>
<a href='#section-71'>#</a>
</div>
<p>アテンション計算の fp16 への自動キャストを無効にする</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">321</span> <span class="k">with</span> <span class="n">autocast</span><span class="p">(</span><span class="n">enabled</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="lineno">322</span> <span class="k">if</span> <span class="n">q</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-72'>
<div class='docs'>
<div class='section-link'>
<a href='#section-72'>#</a>
</div>
<p>現在の dtype が fp16 の場合は fp32 に変換</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">324</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;bihk,bjhk-&gt;bijh&#39;</span><span class="p">,</span> <span class="n">q</span><span class="o">.</span><span class="n">float</span><span class="p">(),</span> <span class="n">k</span><span class="o">.</span><span class="n">float</span><span class="p">())</span>
<span class="lineno">325</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-73'>
<div class='docs'>
<div class='section-link'>
<a href='#section-73'>#</a>
</div>
<p>それ以外（bfloat16 など）の場合はキャストしません</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">327</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;bihk,bjhk-&gt;bijh&#39;</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-74'>
<div class='docs'>
<div class='section-link'>
<a href='#section-74'>#</a>
</div>
<p>アテンションをスケーリングします</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">330</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div>
</div>
</div>
<div class='section' id='section-75'>
<div class='docs'>
<div class='section-link'>
<a href='#section-75'>#</a>
</div>
<p>因果マスク（causal mask）を取得します</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">333</span> <span class="n">mask</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_mask</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-76'>
<div class='docs'>
<div class='section-link'>
<a href='#section-76'>#</a>
</div>
<p>マスクを適用</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">335</span> <span class="n">attn</span><span class="o">.</span><span class="n">masked_fill_</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_fill</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-77'>
<div class='docs'>
<div class='section-link'>
<a href='#section-77'>#</a>
</div>
<p>アテンションのソフトマックス</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">338</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-78'>
<div class='docs'>
<div class='section-link'>
<a href='#section-78'>#</a>
</div>
<p>アテンションで重み付けした値を取得します</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">341</span> <span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;bijh,bjhk-&gt;bihk&#39;</span><span class="p">,</span> <span class="n">attn</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">dtype</span><span class="p">),</span> <span class="n">v</span><span class="p">)</span>
<span class="lineno">342</span>
<span class="lineno">343</span> <span class="k">return</span> <span class="n">output</span></pre></div>
</div>
</div>
<div class='section' id='section-79'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-79'>#</a>
</div>
<h2>フィードフォワードネットワーク</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">346</span><span class="k">class</span> <span class="nc">FFNLayer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-80'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-80'>#</a>
</div>
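<ul><li><code class="highlight"><span></span><span class="n">n_hidden</span></code>
は埋め込みサイズです</li>
<li><code class="highlight"><span></span><span class="n">d_ff</span></code>
は中間層のサイズです（0 の場合は <code class="highlight"><span></span><span class="n">n_hidden</span> <span class="o">*</span> <span class="mi">4</span></code>
になります）</li></ul>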
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">351</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">6_144</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-81'>
<div class='docs'>
<div class='section-link'>
<a href='#section-81'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">355</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">356</span>
<span class="lineno">357</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">d_ff</span><span class="p">:</span>
<span class="lineno">358</span> <span class="n">d_ff</span> <span class="o">=</span> <span class="n">n_hidden</span> <span class="o">*</span> <span class="mi">4</span></pre></div>
</div>
</div>
<div class='section' id='section-82'>
<div class='docs'>
<div class='section-link'>
<a href='#section-82'>#</a>
</div>
<p>拡張線形層</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">361</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_h_h4</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_hidden</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-83'>
<div class='docs'>
<div class='section-link'>
<a href='#section-83'>#</a>
</div>
<p>GELU 活性化関数</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">363</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-84'>
<div class='docs'>
<div class='section-link'>
<a href='#section-84'>#</a>
</div>
<p>収縮線形層</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">365</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_h4_h</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_ff</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-85'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-85'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
は形状 <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">]</span></code>
の入力です</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">367</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-86'>
<div class='docs'>
<div class='section-link'>
<a href='#section-86'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">371</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_h_h4</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">372</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">373</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_h4_h</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="lineno">374</span>
<span class="lineno">375</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-87'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-87'>#</a>
</div>
<h2>Transformer レイヤー</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">378</span><span class="k">class</span> <span class="nc">TransformerLayer</span><span class="p">(</span><span class="n">NeoXModule</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-88'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-88'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">n_hidden</span></code>
は埋め込みサイズです</li>
<li><code class="highlight"><span></span><span class="n">n_heads</span></code>
はヘッドの数です</li>
<li><code class="highlight"><span></span><span class="n">is_flash_attention</span></code>
は<a href="https://github.com/HazyResearch/flash-attention">フラッシュアテンション</a>を使用するかどうかを指定します</li></ul>
<p><em>この実装にはドロップアウトは含まれていません</em></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">383</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">6_144</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <span class="n">is_flash_attention</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-89'>
<div class='docs'>
<div class='section-link'>
<a href='#section-89'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">392</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-90'>
<div class='docs'>
<div class='section-link'>
<a href='#section-90'>#</a>
</div>
<p>アテンション前のレイヤー正規化</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">395</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_ln_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-91'>
<div class='docs'>
<div class='section-link'>
<a href='#section-91'>#</a>
</div>
<p>FFN 前のレイヤー正規化</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">397</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_ln_ffn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-92'>
<div class='docs'>
<div class='section-link'>
<a href='#section-92'>#</a>
</div>
<p>アテンションレイヤー</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">400</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention</span> <span class="o">=</span> <span class="n">AttentionLayer</span><span class="p">(</span><span class="n">n_hidden</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">is_flash_attention</span><span class="o">=</span><span class="n">is_flash_attention</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-93'>
<div class='docs'>
<div class='section-link'>
<a href='#section-93'>#</a>
</div>
<p>FFN レイヤー</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">402</span> <span class="bp">self</span><span class="o">.</span><span class="n">ffn</span> <span class="o">=</span> <span class="n">FFNLayer</span><span class="p">(</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-94'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-94'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
は形状 <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">]</span></code>
の埋め込みです</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">404</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-95'>
<div class='docs'>
<div class='section-link'>
<a href='#section-95'>#</a>
</div>
<p>残差接続</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">410</span> <span class="n">residual</span> <span class="o">=</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-96'>
<div class='docs'>
<div class='section-link'>
<a href='#section-96'>#</a>
</div>
<p>NeoXはアテンションネットワークとフィードフォワードネットワークを並行して実行します</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">412</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">pre_ln_attn</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
<span class="lineno">413</span> <span class="n">ffn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ffn</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">pre_ln_ffn</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-97'>
<div class='docs'>
<div class='section-link'>
<a href='#section-97'>#</a>
</div>
<p>アテンションと FFN の出力に残差接続を加算します</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">415</span> <span class="k">return</span> <span class="n">attn</span> <span class="o">+</span> <span class="n">ffn</span> <span class="o">+</span> <span class="n">residual</span></pre></div>
</div>
</div>
<div class='section' id='section-98'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-98'>#</a>
</div>
<p>チェックポイントをロードするコード</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">417</span> <span class="k">def</span> <span class="nf">load_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p1</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">p2</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-99'>
<div class='docs'>
<div class='section-link'>
<a href='#section-99'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">421</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Load transformer layer&#39;</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-100'>
<div class='docs'>
<div class='section-link'>
<a href='#section-100'>#</a>
</div>
<p>アテンション出力変換</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">423</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">merge_params_sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">attention</span><span class="o">.</span><span class="n">output</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span> <span class="s1">&#39;attention.dense.bias&#39;</span><span class="p">,</span> <span class="n">p1</span><span class="p">,</span> <span class="n">p2</span><span class="p">)</span>
<span class="lineno">424</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">merge_params_dim_1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">attention</span><span class="o">.</span><span class="n">output</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="s1">&#39;attention.dense.weight&#39;</span><span class="p">,</span> <span class="n">p1</span><span class="p">,</span> <span class="n">p2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-101'>
<div class='docs'>
<div class='section-link'>
<a href='#section-101'>#</a>
</div>
<p>アテンションクエリ、キー、値の変換</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">427</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">merge_params_dim_0</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">attention</span><span class="o">.</span><span class="n">qkv_lin</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span> <span class="s1">&#39;attention.query_key_value.bias&#39;</span><span class="p">,</span> <span class="n">p1</span><span class="p">,</span> <span class="n">p2</span><span class="p">)</span>
<span class="lineno">428</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">merge_params_dim_0</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">attention</span><span class="o">.</span><span class="n">qkv_lin</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="s1">&#39;attention.query_key_value.weight&#39;</span><span class="p">,</span> <span class="n">p1</span><span class="p">,</span> <span class="n">p2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-102'>
<div class='docs'>
<div class='section-link'>
<a href='#section-102'>#</a>
</div>
<p>アテンション前のレイヤーノルム</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">431</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">merge_params_duplicate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">pre_ln_attn</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span> <span class="s1">&#39;input_layernorm.bias&#39;</span><span class="p">,</span> <span class="n">p1</span><span class="p">,</span> <span class="n">p2</span><span class="p">)</span>
<span class="lineno">432</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">merge_params_duplicate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">pre_ln_attn</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="s1">&#39;input_layernorm.weight&#39;</span><span class="p">,</span> <span class="n">p1</span><span class="p">,</span> <span class="n">p2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-103'>
<div class='docs'>
<div class='section-link'>
<a href='#section-103'>#</a>
</div>
<p>FFN の 1 番目の線形変換 (h → 4h)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">435</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">merge_params_dim_0</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">dense_h_h4</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span> <span class="s1">&#39;mlp.dense_h_to_4h.bias&#39;</span><span class="p">,</span> <span class="n">p1</span><span class="p">,</span> <span class="n">p2</span><span class="p">)</span>
<span class="lineno">436</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">merge_params_dim_0</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">dense_h_h4</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="s1">&#39;mlp.dense_h_to_4h.weight&#39;</span><span class="p">,</span> <span class="n">p1</span><span class="p">,</span> <span class="n">p2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-104'>
<div class='docs'>
<div class='section-link'>
<a href='#section-104'>#</a>
</div>
<p>FFN の 2 番目の線形変換 (4h → h)</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">439</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">merge_params_sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">dense_h4_h</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span> <span class="s1">&#39;mlp.dense_4h_to_h.bias&#39;</span><span class="p">,</span> <span class="n">p1</span><span class="p">,</span> <span class="n">p2</span><span class="p">)</span>
<span class="lineno">440</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">merge_params_dim_1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">dense_h4_h</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="s1">&#39;mlp.dense_4h_to_h.weight&#39;</span><span class="p">,</span> <span class="n">p1</span><span class="p">,</span> <span class="n">p2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-105'>
<div class='docs'>
<div class='section-link'>
<a href='#section-105'>#</a>
</div>
<p>FFN 前のレイヤーノルム</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">443</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">merge_params_duplicate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">pre_ln_ffn</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span> <span class="s1">&#39;post_attention_layernorm.bias&#39;</span><span class="p">,</span> <span class="n">p1</span><span class="p">,</span> <span class="n">p2</span><span class="p">)</span>
<span class="lineno">444</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">merge_params_duplicate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">pre_ln_ffn</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="s1">&#39;post_attention_layernorm.weight&#39;</span><span class="p">,</span> <span class="n">p1</span><span class="p">,</span> <span class="n">p2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-106'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-106'>#</a>
</div>
<h2>最終正規化レイヤー</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">447</span><span class="k">class</span> <span class="nc">FinalNorm</span><span class="p">(</span><span class="n">NeoXModule</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-107'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-107'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">n_hidden</span></code>
は埋め込みサイズです</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">452</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">6_144</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-108'>
<div class='docs'>
<div class='section-link'>
<a href='#section-108'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">456</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">457</span>
<span class="lineno">458</span> <span class="bp">self</span><span class="o">.</span><span class="n">ln</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-109'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-109'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
は形状 <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">]</span></code> の埋め込みです
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">460</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-110'>
<div class='docs'>
<div class='section-link'>
<a href='#section-110'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">464</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">ln</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-111'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-111'>#</a>
</div>
<p>チェックポイントをロードするコード</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">466</span> <span class="k">def</span> <span class="nf">load_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p1</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">p2</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-112'>
<div class='docs'>
<div class='section-link'>
<a href='#section-112'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">470</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Load final normalization layer&#39;</span><span class="p">):</span>
<span class="lineno">471</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">merge_params_duplicate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ln</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span> <span class="s1">&#39;norm.bias&#39;</span><span class="p">,</span> <span class="n">p1</span><span class="p">,</span> <span class="n">p2</span><span class="p">)</span>
<span class="lineno">472</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">merge_params_duplicate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ln</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="s1">&#39;norm.weight&#39;</span><span class="p">,</span> <span class="n">p1</span><span class="p">,</span> <span class="n">p2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-113'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-113'>#</a>
</div>
<p>読み出し層</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">475</span><span class="k">class</span> <span class="nc">ReadoutLayer</span><span class="p">(</span><span class="n">NeoXModule</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-114'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-114'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">n_hidden</span></code>
は埋め込みサイズです</li>
<li><code class="highlight"><span></span><span class="n">n_vocab</span></code>
はボキャブラリーのサイズです</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">480</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">6_144</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">50_432</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-115'>
<div class='docs'>
<div class='section-link'>
<a href='#section-115'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">485</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">486</span>
<span class="lineno">487</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_hidden</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-116'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-116'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
は形状 <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">]</span></code> の埋め込みです
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">489</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-117'>
<div class='docs'>
<div class='section-link'>
<a href='#section-117'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">493</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-118'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-118'>#</a>
</div>
<p>チェックポイントをロードするコード</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">495</span> <span class="k">def</span> <span class="nf">load_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p1</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">p2</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-119'>
<div class='docs'>
<div class='section-link'>
<a href='#section-119'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">499</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Load final linear layer&#39;</span><span class="p">):</span>
<span class="lineno">500</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">merge_params_dim_0</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">linear</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="s1">&#39;final_linear.weight&#39;</span><span class="p">,</span> <span class="n">p1</span><span class="p">,</span> <span class="n">p2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-120'>
<div class='docs'>
<div class='section-link'>
<a href='#section-120'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">503</span><span class="k">class</span> <span class="nc">LayerGenerator</span><span class="p">:</span>
<span class="lineno">504</span> <span class="n">pre_created_layers</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="n">Optional</span><span class="p">[</span><span class="n">NeoXModule</span><span class="p">]]</span></pre></div>
</div>
</div>
<div class='section' id='section-121'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-121'>#</a>
</div>
<h3>レイヤーを作成するためのジェネレーター</h3>
<p>レイヤーはチェックポイントと同じ順序で生成されます。</p>
<p><code class="highlight"><span></span><span class="kc">None</span></code>
は、レイヤーが利用できない場合に返されます。レイヤーインデックスは NeoX と同じものを使用していますが、NeoX には本実装では不要な変換レイヤーが 2 つあるためです。</p>
<ul><li><code class="highlight"><span></span><span class="n">n_vocab</span></code>
ボキャブラリ内のトークンの数です</li>
<li><code class="highlight"><span></span><span class="n">n_hidden</span></code>
は埋め込み内のフィーチャの数です</li>
<li><code class="highlight"><span></span><span class="n">n_layers</span></code>
はトランスフォーマーレイヤーの数です</li>
<li><code class="highlight"><span></span><span class="n">n_heads</span></code>
アテンション・ヘッドの数です</li>
<li><code class="highlight"><span></span><span class="n">filter_layers</span></code>
使用するレイヤーのセットです。None の場合はすべてのレイヤーが使用されます。これは、レイヤー数の少ないモデルの小さいバージョンをテストする場合に使用します</li>
<li><code class="highlight"><span></span><span class="n">is_clone_layers</span></code>
トランスフォーマーレイヤーのクローンを作成するかどうかを指定します (少し速くなります)</li>
<li><code class="highlight"><span></span><span class="n">dtype</span></code>
モデルのデータ型です</li>
<li><code class="highlight"><span></span><span class="n">device</span></code>
モデルのデバイスです</li>
<li><code class="highlight"><span></span><span class="n">is_llm_int8</span></code>
int8 量子化を使用するかどうかを指定します</li>
<li><code class="highlight"><span></span><span class="n">llm_int8_threshold</span></code>
<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="margin-right:0.0037em">α</span></span></span></span></span></span>外れ値の特徴を分離するための閾値です</li>
<li><code class="highlight"><span></span><span class="n">is_flash_attention</span></code>
<a href="https://github.com/HazyResearch/flash-attention">FlashAttention</a> を使用するかどうかを指定します</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">506</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">50_432</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">6_144</span><span class="p">,</span>
<span class="lineno">507</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">44</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span><span class="p">,</span>
<span class="lineno">508</span> <span class="n">filter_layers</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Set</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="lineno">509</span> <span class="n">is_clone_layers</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="lineno">510</span> <span class="n">dtype</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
<span class="lineno">511</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cpu&#39;</span><span class="p">),</span>
<span class="lineno">512</span> <span class="n">is_llm_int8</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="lineno">513</span> <span class="n">llm_int8_threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">6.0</span><span class="p">,</span>
<span class="lineno">514</span> <span class="n">is_flash_attention</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="lineno">515</span> <span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-122'>
<div class='docs'>
<div class='section-link'>
<a href='#section-122'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">538</span> <span class="k">if</span> <span class="n">filter_layers</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">539</span> <span class="n">filter_layers</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">n_layers</span> <span class="o">+</span> <span class="mi">3</span><span class="p">))</span>
<span class="lineno">540</span>
<span class="lineno">541</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_vocab</span> <span class="o">=</span> <span class="n">n_vocab</span>
<span class="lineno">542</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span> <span class="o">=</span> <span class="n">n_hidden</span>
<span class="lineno">543</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span> <span class="o">=</span> <span class="n">n_layers</span>
<span class="lineno">544</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">=</span> <span class="n">n_heads</span>
<span class="lineno">545</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter_layers</span> <span class="o">=</span> <span class="n">filter_layers</span>
<span class="lineno">546</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_clone_layers</span> <span class="o">=</span> <span class="n">is_clone_layers</span>
<span class="lineno">547</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
<span class="lineno">548</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span>
<span class="lineno">549</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_llm_int8</span> <span class="o">=</span> <span class="n">is_llm_int8</span>
<span class="lineno">550</span> <span class="bp">self</span><span class="o">.</span><span class="n">llm_int8_threshold</span> <span class="o">=</span> <span class="n">llm_int8_threshold</span>
<span class="lineno">551</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_flash_attention</span> <span class="o">=</span> <span class="n">is_flash_attention</span>
<span class="lineno">552</span>
<span class="lineno">553</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_created_layers</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span>
<span class="lineno">554</span> <span class="n">transformer_layer</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="lineno">555</span> <span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-123'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-123'>#</a>
</div>
<h4>レイヤーを使用できるように準備します</h4>
<p>レイヤーをデバイスに移動し、正しいデータ型に変換します。</p>
<ul><li><code class="highlight"><span></span><span class="n">layer</span></code>
準備するレイヤーです</li>
</ul>
<p><em>準備したレイヤーを返します</em></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">557</span> <span class="k">def</span> <span class="nf">_prepare_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">NeoXModule</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-124'>
<div class='docs'>
<div class='section-link'>
<a href='#section-124'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">566</span> <span class="k">return</span> <span class="n">layer</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-125'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-125'>#</a>
</div>
<p><a id="post_load_prepare"></a></p>
<h3>チェックポイントをロードした後のレイヤー変換</h3>
<p>この関数は、チェックポイントを読み込んだ後にレイヤー変換を実装します。</p>
<p>現在、適用されるのは int8 量子化のみです。</p>
<ul><li><code class="highlight"><span></span><span class="n">layer</span></code>
準備するレイヤーです</li>
<li><code class="highlight"><span></span><span class="n">is_llm_int8</span></code>
int8 量子化を使用するかどうかを指定します</li>
<li><code class="highlight"><span></span><span class="n">device</span></code>
モデルのデバイスです</li>
<li><code class="highlight"><span></span><span class="n">llm_int8_threshold</span></code>
<span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqj" style=""><span class="mord mathnormal" style="margin-right:0.0037em">α</span></span></span></span></span></span>外れ値の特徴を分離するための閾値です</li>
</ul>
<p><em>準備したレイヤーを返します</em></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">568</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">569</span> <span class="k">def</span> <span class="nf">post_load_prepare</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">NeoXModule</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">570</span> <span class="n">is_llm_int8</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="lineno">571</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="lineno">572</span> <span class="n">llm_int8_threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="lineno">573</span> <span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-126'>
<div class='docs'>
<div class='section-link'>
<a href='#section-126'>#</a>
</div>
<p>指定しない場合はデフォルト値を取得</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">591</span> <span class="k">if</span> <span class="n">is_llm_int8</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">592</span> <span class="n">is_llm_int8</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_llm_int8</span>
<span class="lineno">593</span> <span class="k">if</span> <span class="n">device</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">594</span> <span class="n">device</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span>
<span class="lineno">595</span> <span class="k">if</span> <span class="n">llm_int8_threshold</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">596</span> <span class="n">llm_int8_threshold</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">llm_int8_threshold</span></pre></div>
</div>
</div>
<div class='section' id='section-127'>
<div class='docs'>
<div class='section-link'>
<a href='#section-127'>#</a>
</div>
<p>int8 量子化を使用しない場合はスキップ</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">599</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">is_llm_int8</span><span class="p">:</span>
<span class="lineno">600</span> <span class="k">return</span> <span class="n">layer</span></pre></div>
</div>
</div>
<div class='section' id='section-128'>
<div class='docs'>
<div class='section-link'>
<a href='#section-128'>#</a>
</div>
<p>トランスフォーマーレイヤーの線形レイヤーのみを変換します</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">603</span> <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">TransformerLayer</span><span class="p">):</span>
<span class="lineno">604</span> <span class="k">return</span> <span class="n">layer</span></pre></div>
</div>
</div>
<div class='section' id='section-129'>
<div class='docs'>
<div class='section-link'>
<a href='#section-129'>#</a>
</div>
<p><a href="./utils/llm_int8.html">ユーティリティ</a>で定義されている <code class="highlight"><span></span><span class="n">make_llm_int8_linear</span></code> を使用します</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">607</span> <span class="kn">from</span> <span class="nn">labml_nn.neox.utils.llm_int8</span> <span class="kn">import</span> <span class="n">make_llm_int8_linear</span></pre></div>
</div>
</div>
<div class='section' id='section-130'>
<div class='docs'>
<div class='section-link'>
<a href='#section-130'>#</a>
</div>
<p>線形レイヤーの変換</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">610</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Convert to int8&#39;</span><span class="p">):</span>
<span class="lineno">611</span> <span class="n">layer</span><span class="o">.</span><span class="n">attention</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">make_llm_int8_linear</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">attention</span><span class="o">.</span><span class="n">output</span><span class="p">,</span>
<span class="lineno">612</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
<span class="lineno">613</span> <span class="n">threshold</span><span class="o">=</span><span class="n">llm_int8_threshold</span><span class="p">)</span>
<span class="lineno">614</span> <span class="n">layer</span><span class="o">.</span><span class="n">attention</span><span class="o">.</span><span class="n">qkv_lin</span> <span class="o">=</span> <span class="n">make_llm_int8_linear</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">attention</span><span class="o">.</span><span class="n">qkv_lin</span><span class="p">,</span>
<span class="lineno">615</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
<span class="lineno">616</span> <span class="n">threshold</span><span class="o">=</span><span class="n">llm_int8_threshold</span><span class="p">)</span>
<span class="lineno">617</span> <span class="n">layer</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">dense_h_h4</span> <span class="o">=</span> <span class="n">make_llm_int8_linear</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">dense_h_h4</span><span class="p">,</span>
<span class="lineno">618</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
<span class="lineno">619</span> <span class="n">threshold</span><span class="o">=</span><span class="n">llm_int8_threshold</span><span class="p">)</span>
<span class="lineno">620</span> <span class="n">layer</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">dense_h4_h</span> <span class="o">=</span> <span class="n">make_llm_int8_linear</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">ffn</span><span class="o">.</span><span class="n">dense_h4_h</span><span class="p">,</span>
<span class="lineno">621</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
<span class="lineno">622</span> <span class="n">threshold</span><span class="o">=</span><span class="n">llm_int8_threshold</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-131'>
<div class='docs'>
<div class='section-link'>
<a href='#section-131'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">624</span> <span class="k">return</span> <span class="n">layer</span></pre></div>
</div>
</div>
<div class='section' id='section-132'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-132'>#</a>
</div>
<h4>レイヤーを作成してキャッシュします</h4>
<p>パラメーターの初期化には時間がかかるため、キャッシュされたレイヤーをコピーする方が、新しいレイヤーを初期化するよりも高速です。</p>
<ul><li><code class="highlight"><span></span><span class="n">name</span></code>
レイヤーの名前です</li>
<li><code class="highlight"><span></span><span class="n">creator</span></code>
レイヤーを作成する関数です</li>
</ul>
<p><em>作成されたレイヤーまたはキャッシュされたレイヤーのコピーを返します</em></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">626</span> <span class="k">def</span> <span class="nf">_create_and_cache_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">creator</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[],</span> <span class="n">NeoXModule</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-133'>
<div class='docs'>
<div class='section-link'>
<a href='#section-133'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">638</span> <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_clone_layers</span><span class="p">:</span>
<span class="lineno">639</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_layer</span><span class="p">(</span><span class="n">creator</span><span class="p">())</span>
<span class="lineno">640</span>
<span class="lineno">641</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_created_layers</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">642</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_created_layers</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_layer</span><span class="p">(</span><span class="n">creator</span><span class="p">())</span>
<span class="lineno">643</span>
<span class="lineno">644</span> <span class="n">layer</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">pre_created_layers</span><span class="p">[</span><span class="n">name</span><span class="p">])</span>
<span class="lineno">645</span> <span class="k">return</span> <span class="n">layer</span></pre></div>
</div>
</div>
<div class='section' id='section-134'>
<div class='docs'>
<div class='section-link'>
<a href='#section-134'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">647</span> <span class="k">def</span> <span class="nf">_create_transformer_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">648</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_create_and_cache_layer</span><span class="p">(</span>
<span class="lineno">649</span> <span class="s1">&#39;transformer_layer&#39;</span><span class="p">,</span>
<span class="lineno">650</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">TransformerLayer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">is_flash_attention</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">is_flash_attention</span><span class="p">)</span>
<span class="lineno">651</span> <span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-135'>
<div class='docs'>
<div class='section-link'>
<a href='#section-135'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">653</span> <span class="k">def</span> <span class="nf">_create_embedding_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">654</span> <span class="k">return</span> <span class="n">Embedding</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_vocab</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-136'>
<div class='docs'>
<div class='section-link'>
<a href='#section-136'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">656</span> <span class="k">def</span> <span class="nf">_create_final_norm_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">657</span> <span class="k">return</span> <span class="n">FinalNorm</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-137'>
<div class='docs'>
<div class='section-link'>
<a href='#section-137'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">659</span> <span class="k">def</span> <span class="nf">_create_readout_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="lineno">660</span> <span class="k">return</span> <span class="n">ReadoutLayer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_vocab</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-138'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-138'>#</a>
</div>
<h3>レイヤーを取得するためのジェネレーター</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">662</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">663</span> <span class="k">def</span> <span class="nf">get_layers</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Generator</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">NeoXModule</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">str</span><span class="p">]],</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-139'>
<div class='docs'>
<div class='section-link'>
<a href='#section-139'>#</a>
</div>
<p>埋め込みレイヤー</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">668</span> <span class="k">if</span> <span class="mi">0</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter_layers</span><span class="p">:</span>
<span class="lineno">669</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Embedding layer&#39;</span><span class="p">):</span>
<span class="lineno">670</span> <span class="n">layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_layer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_create_embedding_layer</span><span class="p">())</span>
<span class="lineno">671</span> <span class="k">yield</span> <span class="n">layer</span><span class="p">,</span> <span class="p">(</span><span class="s1">&#39;layer_00-model_00-model_states.pt&#39;</span><span class="p">,</span> <span class="s1">&#39;layer_00-model_01-model_states.pt&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-140'>
<div class='docs'>
<div class='section-link'>
<a href='#section-140'>#</a>
</div>
<p>トランスフォーマー層</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">674</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-141'>
<div class='docs'>
<div class='section-link'>
<a href='#section-141'>#</a>
</div>
<p>トランスフォーマー層</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">676</span> <span class="k">if</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter_layers</span><span class="p">:</span>
<span class="lineno">677</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Transformer Layer </span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">):</span>
<span class="lineno">678</span> <span class="k">yield</span> <span class="bp">self</span><span class="o">.</span><span class="n">_create_transformer_layer</span><span class="p">(),</span> \
<span class="lineno">679</span> <span class="p">(</span><span class="sa">f</span><span class="s1">&#39;layer_</span><span class="si">{</span><span class="n">i</span> <span class="o">+</span> <span class="mi">2</span> <span class="si">:</span><span class="s1">02d</span><span class="si">}</span><span class="s1">-model_00-model_states.pt&#39;</span><span class="p">,</span>
<span class="lineno">680</span> <span class="sa">f</span><span class="s1">&#39;layer_</span><span class="si">{</span><span class="n">i</span> <span class="o">+</span> <span class="mi">2</span> <span class="si">:</span><span class="s1">02d</span><span class="si">}</span><span class="s1">-model_01-model_states.pt&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-142'>
<div class='docs'>
<div class='section-link'>
<a href='#section-142'>#</a>
</div>
<p>最終正規化レイヤー</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">683</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span> <span class="o">+</span> <span class="mi">1</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter_layers</span><span class="p">:</span>
<span class="lineno">684</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Final norm layer&#39;</span><span class="p">):</span>
<span class="lineno">685</span> <span class="n">layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_layer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_create_final_norm_layer</span><span class="p">())</span>
<span class="lineno">686</span> <span class="k">yield</span> <span class="n">layer</span><span class="p">,</span> <span class="p">(</span><span class="s1">&#39;layer_47-model_00-model_states.pt&#39;</span><span class="p">,</span> <span class="s1">&#39;layer_47-model_01-model_states.pt&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-143'>
<div class='docs'>
<div class='section-link'>
<a href='#section-143'>#</a>
</div>
<p>読み出し層</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">689</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span> <span class="o">+</span> <span class="mi">2</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter_layers</span><span class="p">:</span>
<span class="lineno">690</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Readout layer&#39;</span><span class="p">):</span>
<span class="lineno">691</span> <span class="n">layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_layer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_create_readout_layer</span><span class="p">())</span>
<span class="lineno">692</span> <span class="k">yield</span> <span class="n">layer</span><span class="p">,</span> <span class="p">(</span><span class="s1">&#39;layer_48-model_00-model_states.pt&#39;</span><span class="p">,</span> <span class="s1">&#39;layer_48-model_01-model_states.pt&#39;</span><span class="p">)</span>
<span class="lineno">693</span>
<span class="lineno">694</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_created_layers</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="lineno">695</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_created_layers</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-144'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-144'>#</a>
</div>
<h3>レイヤーの総数を返します</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">697</span> <span class="nd">@property</span>
<span class="lineno">698</span> <span class="k">def</span> <span class="nf">total_layers</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-145'>
<div class='docs'>
<div class='section-link'>
<a href='#section-145'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">702</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_layers</span> <span class="o">+</span> <span class="mi">3</span></pre></div>
</div>
</div>
<div class='section' id='section-146'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-146'>#</a>
</div>
<h3>レイヤーをロードするジェネレーター</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">704</span> <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
<span class="lineno">705</span> <span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Generator</span><span class="p">[</span><span class="n">NeoXModule</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-147'>
<div class='docs'>
<div class='section-link'>
<a href='#section-147'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">709</span> <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s2">&quot;Layers&quot;</span><span class="p">):</span>
<span class="lineno">710</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">files</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_layers</span><span class="p">()):</span>
<span class="lineno">711</span> <span class="k">if</span> <span class="n">files</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">712</span> <span class="n">layer</span><span class="o">.</span><span class="n">load_state</span><span class="p">(</span><span class="o">*</span><span class="n">checkpoint</span><span class="o">.</span><span class="n">load_checkpoint_files</span><span class="p">(</span><span class="n">files</span><span class="p">))</span>
<span class="lineno">713</span>
<span class="lineno">714</span> <span class="n">layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">post_load_prepare</span><span class="p">(</span><span class="n">layer</span><span class="p">)</span>
<span class="lineno">715</span>
<span class="lineno">716</span> <span class="n">monit</span><span class="o">.</span><span class="n">progress</span><span class="p">(</span><span class="nb">min</span><span class="p">(</span><span class="mf">0.99</span><span class="p">,</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_layers</span><span class="p">))</span>
<span class="lineno">717</span> <span class="k">yield</span> <span class="n">layer</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../interactive.js?v=1"></script>
<script>
// Finds every image placed directly inside a paragraph and hands each one to
// handleImage, which attaches the click-to-open modal viewer.
function handleImages() {
    var paragraphImages = document.querySelectorAll('p>img')
    Array.prototype.forEach.call(paragraphImages, function (image) {
        handleImage(image)
    })
}
// Centers `img` inside its parent paragraph and attaches a modal viewer:
// clicking the image appends a `#modal` overlay showing the same image;
// the overlay is closed by the "x" control (click, Enter or Space) or by Escape.
// The close control stays a <span class="close"> so existing stylesheet rules apply,
// but it is made focusable and given a button role and accessible name.
function handleImage(img) {
    img.parentElement.style.textAlign = 'center'

    var modal = document.createElement('div')
    modal.id = 'modal'
    var modalContent = document.createElement('div')
    modal.appendChild(modalContent)
    var modalImage = document.createElement('img')
    modalContent.appendChild(modalImage)

    var closeButton = document.createElement('span')
    closeButton.classList.add('close')
    closeButton.textContent = 'x'
    closeButton.setAttribute('role', 'button')
    closeButton.setAttribute('tabindex', '0')
    closeButton.setAttribute('aria-label', 'Close')
    modal.appendChild(closeButton)

    // Document-level Escape handler; registered only while this modal is open
    // so multiple images on the page do not stack up permanent listeners.
    function onDocumentKeyDown(event) {
        if (event.key === 'Escape') {
            closeModal()
        }
    }

    // Detach the overlay (if attached) and drop the Escape handler.
    function closeModal() {
        if (modal.parentNode) {
            modal.parentNode.removeChild(modal)
        }
        document.removeEventListener('keydown', onDocumentKeyDown)
    }

    img.addEventListener('click', function () {
        modalImage.src = img.src
        // Carry the text alternative over so the enlarged copy is not announced as unlabeled.
        modalImage.alt = img.alt
        document.body.appendChild(modal)
        // Same function reference, so repeated opens do not register duplicates.
        document.addEventListener('keydown', onDocumentKeyDown)
        closeButton.focus()
    })

    closeButton.addEventListener('click', closeModal)
    closeButton.addEventListener('keydown', function (event) {
        if (event.key === 'Enter' || event.key === ' ') {
            event.preventDefault()
            closeModal()
        }
    })
}
// Run once when this inline script executes: attach the modal image viewer to every
// <img> that is a direct child of a <p> (selector 'p>img' in handleImages).
handleImages()
</script>
</body>
</html>