mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-29 01:26:44 +08:00
2061 lines
231 KiB
HTML
2061 lines
231 KiB
HTML
<!DOCTYPE html>
|
||
<html lang="en">
|
||
<head>
|
||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||
<meta name="description" content="RETRO model with encoder for neighbors and autoregressive decoder"/>
|
||
|
||
<meta name="twitter:card" content="summary"/>
|
||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||
<meta name="twitter:title" content="RETRO model"/>
|
||
<meta name="twitter:description" content="RETRO model with encoder for neighbors and autoregressive decoder"/>
|
||
<meta name="twitter:site" content="@labmlai"/>
|
||
<meta name="twitter:creator" content="@labmlai"/>
|
||
|
||
<meta property="og:url" content="https://nn.labml.ai/transformers/retro/model.html"/>
|
||
<meta property="og:title" content="RETRO model"/>
|
||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||
<meta property="og:site_name" content="RETRO model"/>
|
||
<meta property="og:type" content="object"/>
|
||
<meta property="og:title" content="RETRO model"/>
|
||
<meta property="og:description" content="RETRO model with encoder for neighbors and autoregressive decoder"/>
|
||
|
||
<title>RETRO model</title>
|
||
<link rel="shortcut icon" href="/icon.png"/>
|
||
<link rel="stylesheet" href="../../pylit.css?v=1">
|
||
<link rel="canonical" href="https://nn.labml.ai/transformers/retro/model.html"/>
|
||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
||
|
||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||
<script>
  // Google tag (gtag.js) bootstrap, following Google's install snippet:
  // commands are queued on window.dataLayer, which the gtag.js library
  // loaded by the preceding async <script> consumes.
  window.dataLayer = window.dataLayer || [];

  // Push the `arguments` object itself, as Google's reference snippet does,
  // rather than an array built from rest parameters.
  function gtag() {
    dataLayer.push(arguments);
  }

  // Initialize with the current time, then configure the measurement ID.
  gtag('js', new Date());
  gtag('config', 'G-4V3HC8HBLH');
</script>
|
||
</head>
|
||
<body>
|
||
<div id='container'>
|
||
<div id="background"></div>
|
||
<div class='section'>
|
||
<div class='docs'>
|
||
<p>
|
||
<a class="parent" href="/">home</a>
|
||
<a class="parent" href="../index.html">transformers</a>
|
||
<a class="parent" href="index.html">retro</a>
|
||
</p>
|
||
<p>
|
||
<a href="https://github.com/sponsors/labmlai" target="_blank">
|
||
<img alt="Sponsor"
|
||
src="https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86"
|
||
style="max-width:100%;"/></a>
|
||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
|
||
<img alt="Github"
|
||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||
style="max-width:100%;"/></a>
|
||
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
|
||
<img alt="Twitter"
|
||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||
style="max-width:100%;"/></a>
|
||
</p>
|
||
<p>
|
||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/retro/model.py" target="_blank">
|
||
View code on Github</a>
|
||
</p>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-0'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-0'>#</a>
|
||
</div>
|
||
<h1>RETRO model</h1>
|
||
<p>This is the model definition for <a href="index.html">RETRO</a>.</p>
|
||
<p><a href="https://app.labml.ai/run/3113dd3ea1e711ec85ee295d18534021"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a></p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">16</span><span></span><span class="kn">import</span> <span class="nn">math</span>
|
||
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Set</span>
|
||
<span class="lineno">18</span>
|
||
<span class="lineno">19</span><span class="kn">import</span> <span class="nn">torch</span>
|
||
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
|
||
<span class="lineno">21</span>
|
||
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml.logger</span> <span class="kn">import</span> <span class="n">inspect</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-1'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-1'>#</a>
|
||
</div>
|
||
<h2><a href="../rope/index.html">RoPE embeddings</a></h2>
|
||
<p><em>We use rotary position embeddings in self-attention layers. We assume the positional information gets embedded in the embeddings, and therefore do not use rotary embeddings in causal attention. <a href="https://papers.labml.ai/paper/3999902edc8511eba3db37f65e372566">Non-causal self-attention needs explicit positional information because it cannot infer it</a>.</em></p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">25</span><span class="k">class</span> <span class="nc">RotaryPositionalEmbeddings</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-2'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-2'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">d</span></code>
|
||
is the number of features <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord coloredeq eqbq" style=""><span class="mord mathnormal" style="">d</span></span></span></span></span></span> </li>
|
||
<li><code class="highlight"><span></span><span class="n">base</span></code>
|
||
is the constant used for calculating <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqbk" style=""><span class="mord" style="">Θ</span></span></span></span></span></span></li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">36</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">base</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10_000</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-3'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-3'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">41</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-4'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-4'>#</a>
|
||
</div>
|
||
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqbk" style=""><span class="mord" style="">Θ</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.43445em;vertical-align:-0.345em;"></span><span class="mord"><span class="mord coloredeq eqbi" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord">1000</span><span class="mord"><span class="mord">0</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:1.08945em;"><span style="top:-3.363em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">−</span><span class="mord mtight"><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0377857142857143em;"><span 
style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqbq" style=""><span class="mord mathnormal mtight" style="">d</span></span></span></span></span><span style="top:-3.2255000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em;"></span></span><span style="top:-3.5020714285714285em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight">2</span><span class="mopen mtight">(</span><span class="mord mathnormal mtight">i</span><span class="mbin mtight">−</span><span class="mord mtight">1</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.344em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span></span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal">i</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">∈</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mopen">[</span><span class="mord">1</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">2</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">...</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbg" style=""><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t 
vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8801079999999999em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbq" style=""><span class="mord mathnormal mtight" style="">d</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span><span class="mclose">]</span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">43</span> <span class="bp">self</span><span class="o">.</span><span class="n">theta</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="mf">1.</span> <span class="o">/</span> <span class="p">(</span><span class="n">base</span> <span class="o">**</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">d</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="o">/</span> <span class="n">d</span><span class="p">)),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-5'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-5'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
|
||
is the Tensor at the head of a key or a query with shape <code class="highlight"><span></span><span class="p">[</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d</span><span class="p">]</span></code>
|
||
</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">45</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-6'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-6'>#</a>
|
||
</div>
|
||
<p>Extract the shape </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">50</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-7'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-7'>#</a>
|
||
</div>
|
||
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.2251079999999999em;vertical-align:-0.345em;"></span><span class="mord coloredeq eqbg" style=""><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8801079999999999em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbq" style=""><span class="mord mathnormal mtight" style="">d</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">53</span> <span class="n">d_2</span> <span class="o">=</span> <span class="n">d</span> <span class="o">//</span> <span class="mi">2</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-8'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-8'>#</a>
|
||
</div>
|
||
<p>Create position indexes <code class="highlight"><span></span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="o">...</span><span class="p">,</span> <span class="n">seq_len</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span></code>
|
||
</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">56</span> <span class="n">seq_idx</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">type_as</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">theta</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-9'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-9'>#</a>
|
||
</div>
|
||
<p>Calculate the product of position index and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbi" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">59</span> <span class="n">idx_theta</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'n,d->nd'</span><span class="p">,</span> <span class="n">seq_idx</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">theta</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-10'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-10'>#</a>
|
||
</div>
|
||
<p>Concatenate so that for row <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqbr" style=""><span class="mord mathnormal" style="">m</span></span></span></span></span></span> we have <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.26202em;vertical-align:-0.5120199999999999em;"></span><span class="mopen">[</span><span class="mord coloredeq eqbr" style=""><span class="mord mathnormal" style="">m</span></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">0</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbr" style=""><span class="mord mathnormal" style="">m</span></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" 
style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">...</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbr" style=""><span class="mord mathnormal" style="">m</span></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.7287800000000004em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqbg" style=""><span class="mord mtight" style=""><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8800285714285714em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span style="top:-3.2255000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em"></span></span><span style="top:-3.384em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbq" style=""><span class="mord mathnormal mtight" style="">d</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.344em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 
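size6"></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.5120199999999999em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbr" style=""><span class="mord mathnormal" style="">m</span></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">0</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbr" style=""><span class="mord mathnormal" style="">m</span></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">...</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbr" style=""><span class="mord mathnormal" style="">m</span></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.7287800000000004em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqbg" style=""><span class="mord mtight" style=""><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8800285714285714em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" 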
style=""><span class="mord mtight" style="">2</span></span></span></span><span style="top:-3.2255000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em"></span></span><span style="top:-3.384em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbq" style=""><span class="mord mathnormal mtight" style="">d</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.344em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.5120199999999999em;"><span></span></span></span></span></span></span><span class="mclose">]</span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">63</span> <span class="n">idx_theta2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">idx_theta</span><span class="p">,</span> <span class="n">idx_theta</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-11'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-11'>#</a>
|
||
</div>
|
||
<p>Calculate <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.22902em;vertical-align:-0.25em;"></span><span class="mopen">[</span><span class="mord">−</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.97902em;"><span style="top:-3.363em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight coloredeq eqbg" style=""><span class="mord mtight" style=""><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8800285714285714em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span style="top:-3.2255000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em"></span></span><span style="top:-3.384em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbq" style=""><span class="mord mathnormal mtight" style="">d</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.344em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span></span><span class="mbin mtight">+</span><span class="mord mtight">1</span><span class="mclose mtight">)</span></span></span></span></span></span></span></span></span><span 
class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">−</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.97902em;"><span style="top:-3.363em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight coloredeq eqbg" style=""><span class="mord mtight" style=""><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8800285714285714em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span style="top:-3.2255000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em"></span></span><span style="top:-3.384em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbq" style=""><span class="mord mathnormal mtight" style="">d</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.344em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span></span><span class="mbin mtight">+</span><span class="mord mtight">2</span><span class="mclose mtight">)</span></span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">...</span><span 
class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">−</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8879999999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight coloredeq eqbq" style=""><span class="mord mathnormal mtight" style="">d</span></span><span class="mclose mtight">)</span></span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8879999999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight">1</span><span class="mclose mtight">)</span></span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8879999999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight">2</span><span class="mclose mtight">)</span></span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" 
style="margin-right:0.16666666666666666em;"></span><span class="mord">...</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">−</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.97902em;"><span style="top:-3.363em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight coloredeq eqbg" style=""><span class="mord mtight" style=""><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8800285714285714em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span style="top:-3.2255000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em"></span></span><span style="top:-3.384em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbq" style=""><span class="mord mathnormal mtight" style="">d</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.344em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span></span><span class="mclose mtight">)</span></span></span></span></span></span></span></span></span><span class="mclose">]</span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">67</span> <span class="n">neg_half_x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="o">-</span><span class="n">x</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="n">d_2</span><span class="p">:],</span> <span class="n">x</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="p">:</span><span class="n">d_2</span><span class="p">]],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-12'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-12'>#</a>
|
||
</div>
|
||
<p>Calculate</p>
|
||
<span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:3.42324em;vertical-align:-1.4616200000000001em;"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.9616199999999997em;"><span style="top:-3.96162em;"><span class="pstrut" style="height:3.8116199999999996em;"></span><span class="mord"><span class="minner"><span class="mopen delimcenter" style="top:0em;"><span class="delimsizing size4">(</span></span><span class="mord"><span class="mtable"><span class="col-align-c"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.8116199999999998em;"><span style="top:-3.81162em;"><span class="pstrut" style="height:3.20162em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0448em;"><span style="top:-2.5834080000000004em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqbr" style=""><span class="mord mathnormal mtight" style="">m</span></span></span></span><span style="top:-3.2198em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mathnormal mtight">i</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.11659199999999997em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">cos</span><span class="mspace" 
style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbr" style=""><span class="mord mathnormal" style="">m</span></span><span class="mord coloredeq eqbi" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.20162em;"><span style="top:-2.883408em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqbr" style=""><span class="mord mathnormal mtight" style="">m</span></span></span></span><span style="top:-3.5856000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mathnormal mtight">i</span><span class="mbin mtight">+</span><span class="mord mtight coloredeq eqbg" style=""><span class="mord mtight" style=""><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span 
class="vlist" style="height:0.8800285714285714em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span style="top:-3.2255000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em"></span></span><span style="top:-3.384em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbq" style=""><span class="mord mathnormal mtight" style="">d</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.344em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span></span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.11659199999999997em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">sin</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbr" style=""><span class="mord mathnormal" style="">m</span></span><span class="mord coloredeq eqbi" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" 
style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><span style="top:-2.25em;"><span class="pstrut" style="height:3.20162em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.20162em;"><span style="top:-2.883408em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqbr" style=""><span class="mord mathnormal mtight" style="">m</span></span></span></span><span style="top:-3.5856000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mathnormal mtight">i</span><span class="mbin mtight">+</span><span class="mord mtight coloredeq eqbg" style=""><span class="mord mtight" style=""><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8800285714285714em;"><span style="top:-2.656em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span style="top:-3.2255000000000003em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line mtight" style="border-bottom-width:0.049em"></span></span><span style="top:-3.384em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbq" style=""><span class="mord mathnormal 
mtight" style="">d</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.344em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span></span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.11659199999999997em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">cos</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbr" style=""><span class="mord mathnormal" style="">m</span></span><span class="mord coloredeq eqbi" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0448em;"><span style="top:-2.5834080000000004em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span 
class="mord mtight coloredeq eqbr" style=""><span class="mord mathnormal mtight" style="">m</span></span></span></span><span style="top:-3.2198em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mathnormal mtight">i</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.11659199999999997em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">sin</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbr" style=""><span class="mord mathnormal" style="">m</span></span><span class="mord coloredeq eqbi" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.3116200000000002em;"><span></span></span></span></span></span></span></span><span class="mclose delimcenter" style="top:0em;"><span class="delimsizing size4">)</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" 
style="height:1.4616200000000001em;"><span></span></span></span></span></span></span></span></span></span></span></span></span><p>for <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.69862em;vertical-align:-0.0391em;"></span><span class="mord mathnormal">i</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">∈</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.2251079999999999em;vertical-align:-0.345em;"></span><span class="mord"><span class="mord">1</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">2</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">...</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbg" style=""><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8801079999999999em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">2</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqbq" style=""><span class="mord mathnormal mtight" style="">d</span></span></span></span></span></span><span 
class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.345em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">79</span> <span class="n">rx</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">*</span> <span class="n">idx_theta2</span><span class="o">.</span><span class="n">cos</span><span class="p">()[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:])</span> <span class="o">+</span> <span class="p">(</span><span class="n">neg_half_x</span> <span class="o">*</span> <span class="n">idx_theta2</span><span class="o">.</span><span class="n">sin</span><span class="p">()[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:])</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-13'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-13'>#</a>
|
||
</div>
|
||
<p> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">82</span> <span class="k">return</span> <span class="n">rx</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-14'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-14'>#</a>
|
||
</div>
|
||
<h2>Self-Attention Layer <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqw" style=""><span class="mord text" style=""><span class="mord" style="">A</span><span class="mord sizing reset-size6 size5" style=""><span class="mord" style="">TTN</span></span></span></span></span></span></span></span></h2>
|
||
<p>This applies causal and non-causal <a href="../mha.html">multi-headed self-attention</a>.</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">85</span><span class="k">class</span> <span class="nc">SelfAttention</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-15'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-15'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
|
||
is the number of features in transformer embeddings </li>
|
||
<li><code class="highlight"><span></span><span class="n">n_heads</span></code>
|
||
is the number of attention heads </li>
|
||
<li><code class="highlight"><span></span><span class="n">d_k</span></code>
|
||
is the number of features per head </li>
|
||
<li><code class="highlight"><span></span><span class="n">is_causal</span></code>
|
||
indicates whether this is causal attention (masked)</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">92</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_k</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">is_causal</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-16'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-16'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">99</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||
<span class="lineno">100</span>
|
||
<span class="lineno">101</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_causal</span> <span class="o">=</span> <span class="n">is_causal</span>
|
||
<span class="lineno">102</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">=</span> <span class="n">n_heads</span>
|
||
<span class="lineno">103</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">=</span> <span class="n">d_k</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-17'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-17'>#</a>
|
||
</div>
|
||
<p>To scale attentions before softmax by <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.383108em;vertical-align:-0.538em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.845108em;"><span style="top:-2.5864385em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord sqrt mtight"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8622307142857143em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mtight" style="padding-left:0.833em;"><span class="mord mtight"><span class="mord mtight coloredeq eqbq" style=""><span class="mord mathnormal mtight" style="">d</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.8222307142857144em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail mtight" style="min-width:0.853em;height:1.08em;"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
|
||
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
|
||
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
|
||
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
|
||
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
|
||
c69,-144,104.5,-217.7,106.5,-221
|
||
l0 -0
|
||
c5.3,-9.3,12,-14,20,-14
|
||
H400000v40H845.2724
|
||
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
|
||
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
|
||
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.17776928571428574em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.538em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">106</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-18'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-18'>#</a>
|
||
</div>
|
||
<p>Linear layers for query, key and value heads. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">109</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">)</span>
|
||
<span class="lineno">110</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">)</span>
|
||
<span class="lineno">111</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-19'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-19'>#</a>
|
||
</div>
|
||
<p>Pre-norm layer. The paper uses RMSNorm instead. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">114</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">d_model</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-20'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-20'>#</a>
|
||
</div>
|
||
<p>Softmax for attention probabilities </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">117</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-21'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-21'>#</a>
|
||
</div>
|
||
<p>Rotary positional embeddings </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">120</span> <span class="bp">self</span><span class="o">.</span><span class="n">rotary_pe</span> <span class="o">=</span> <span class="n">RotaryPositionalEmbeddings</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-22'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-22'>#</a>
|
||
</div>
|
||
<p>Final linear layer </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">123</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-23'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-23'>#</a>
|
||
</div>
|
||
<h3>Mask the attention matrix for causal attention</h3>
|
||
<ul><li><code class="highlight"><span></span><span class="n">attn</span></code>
|
||
is the attention matrix of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">]</span></code>
|
||
</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">125</span> <span class="k">def</span> <span class="nf">mask_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">attn</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-24'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-24'>#</a>
|
||
</div>
|
||
<p>No masking for non-causal attention </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">133</span> <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_causal</span><span class="p">:</span>
|
||
<span class="lineno">134</span> <span class="k">return</span> <span class="n">attn</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-25'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-25'>#</a>
|
||
</div>
|
||
<p>Create a lower-triangular mask so that each position can attend only to itself and earlier positions </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">137</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tril</span><span class="p">(</span><span class="n">attn</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span><span class="n">attn</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">:]))</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-26'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-26'>#</a>
|
||
</div>
|
||
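<p>Set masked-out scores (positions above the diagonal) to negative infinity so they get zero probability after the softmax </p>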
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">139</span> <span class="k">return</span> <span class="n">attn</span><span class="o">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s1">'-inf'</span><span class="p">))</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-27'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-27'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">h</span></code>
|
||
is the transformer embeddings of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
|
||
</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">141</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">h</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-28'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-28'>#</a>
|
||
</div>
|
||
<p>Residual connection </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">147</span> <span class="n">h_res</span> <span class="o">=</span> <span class="n">h</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-29'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-29'>#</a>
|
||
</div>
|
||
<p>Pre-normalization </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">150</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">h</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-30'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-30'>#</a>
|
||
</div>
|
||
<p>Get the queries, keys, and values and split them into heads. These will have shapes <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">]</span></code>
|
||
</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">154</span> <span class="n">mh_shape</span> <span class="o">=</span> <span class="p">(</span><span class="o">*</span><span class="n">h</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span>
|
||
<span class="lineno">155</span> <span class="n">q</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">h</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">mh_shape</span><span class="p">)</span>
|
||
<span class="lineno">156</span> <span class="n">k</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="n">h</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">mh_shape</span><span class="p">)</span>
|
||
<span class="lineno">157</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span><span class="p">(</span><span class="n">h</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">mh_shape</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-31'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-31'>#</a>
|
||
</div>
|
||
<p>Apply rotary positional embeddings </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">160</span> <span class="n">q</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rotary_pe</span><span class="p">(</span><span class="n">q</span><span class="p">)</span>
|
||
<span class="lineno">161</span> <span class="n">k</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rotary_pe</span><span class="p">(</span><span class="n">k</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-32'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-32'>#</a>
|
||
</div>
|
||
<p>Calculate attention scores (dot products of queries and keys) </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">164</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bihd,bjhd->bhij'</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-33'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-33'>#</a>
|
||
</div>
|
||
<p>Scale it by <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.383108em;vertical-align:-0.538em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.845108em;"><span style="top:-2.5864385em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord sqrt mtight"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8622307142857143em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mtight" style="padding-left:0.833em;"><span class="mord mtight"><span class="mord mtight coloredeq eqbq" style=""><span class="mord mathnormal mtight" style="">d</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.8222307142857144em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail mtight" style="min-width:0.853em;height:1.08em;"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
|
||
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
|
||
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
|
||
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
|
||
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
|
||
c69,-144,104.5,-217.7,106.5,-221
|
||
l0 -0
|
||
c5.3,-9.3,12,-14,20,-14
|
||
H400000v40H845.2724
|
||
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
|
||
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
|
||
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.17776928571428574em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.538em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">166</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-34'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-34'>#</a>
|
||
</div>
|
||
<p>Apply the causal mask (this is a no-op for non-causal attention) </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">169</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_attention</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-35'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-35'>#</a>
|
||
</div>
|
||
<p>Calculate attention probabilities </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">172</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-36'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-36'>#</a>
|
||
</div>
|
||
<p>Get values </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">175</span> <span class="n">h</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">"bhij,bjhd->bihd"</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-37'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-37'>#</a>
|
||
</div>
|
||
<p>Change from shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">]</span></code>
|
||
to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">]</span></code>
|
||
</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">179</span> <span class="n">h</span> <span class="o">=</span> <span class="n">h</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">*</span><span class="n">h</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">2</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-38'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-38'>#</a>
|
||
</div>
|
||
<p>Apply final linear layer. The result will have shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
|
||
</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">183</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">h</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-39'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-39'>#</a>
|
||
</div>
|
||
<p>Add the residual connection </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">186</span> <span class="k">return</span> <span class="n">h</span> <span class="o">+</span> <span class="n">h_res</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-40'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-40'>#</a>
|
||
</div>
|
||
<h2>Cross-Attention Layer <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqba" style=""><span class="mord text" style=""><span class="mord" style="">C</span><span class="mord sizing reset-size6 size5" style=""><span class="mord" style="">A</span></span></span></span></span></span></span></span></h2>
|
||
<p>This is similar to the self-attention layer defined above, except that it gets keys and values from a different set of embeddings than the queries.</p>
|
||
<p>This is used in the encoder to encode the retrieved chunks based on the input chunks.</p>
|
||
<p><em>We do not use any explicit positional embeddings here. We assume that the model can represent positional information in the embeddings implicitly.</em></p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">189</span><span class="k">class</span> <span class="nc">CrossAttention</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-41'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-41'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
|
||
is the number of features in transformer embeddings </li>
|
||
<li><code class="highlight"><span></span><span class="n">n_heads</span></code>
|
||
is the number of attention heads </li>
|
||
<li><code class="highlight"><span></span><span class="n">d_k</span></code>
|
||
is the number of features per head</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">203</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_k</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-42'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-42'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">209</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||
<span class="lineno">210</span>
|
||
<span class="lineno">211</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">=</span> <span class="n">n_heads</span>
|
||
<span class="lineno">212</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">=</span> <span class="n">d_k</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-43'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-43'>#</a>
|
||
</div>
|
||
<p>To scale attentions before softmax by <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.383108em;vertical-align:-0.538em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.845108em;"><span style="top:-2.5864385em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord sqrt mtight"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8622307142857143em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mtight" style="padding-left:0.833em;"><span class="mord mtight"><span class="mord mtight coloredeq eqbq" style=""><span class="mord mathnormal mtight" style="">d</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.8222307142857144em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail mtight" style="min-width:0.853em;height:1.08em;"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
|
||
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
|
||
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
|
||
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
|
||
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
|
||
c69,-144,104.5,-217.7,106.5,-221
|
||
l0 -0
|
||
c5.3,-9.3,12,-14,20,-14
|
||
H400000v40H845.2724
|
||
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
|
||
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
|
||
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.17776928571428574em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.538em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">215</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-44'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-44'>#</a>
|
||
</div>
|
||
<p>Linear layers for query, key and value heads. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">218</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">)</span>
|
||
<span class="lineno">219</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">)</span>
|
||
<span class="lineno">220</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-45'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-45'>#</a>
|
||
</div>
|
||
<p>Pre-norm layer for the query embeddings. The paper uses RMSNorm instead. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">223</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">d_model</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-46'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-46'>#</a>
|
||
</div>
|
||
<p>Softmax for attention probabilities </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">226</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-47'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-47'>#</a>
|
||
</div>
|
||
<p>Final linear layer </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">229</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-48'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-48'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">e</span></code>
|
||
are the retrieved nearest neighbor chunk embeddings with shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">chunks</span><span class="p">,</span> <span class="n">neighbors</span><span class="p">,</span> <span class="n">neighbor_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
|
||
</li>
|
||
<li><code class="highlight"><span></span><span class="n">h</span></code>
|
||
are the input chunks from which the nearest neighbors were retrieved with shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">chunks</span><span class="p">,</span> <span class="n">chunk_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
|
||
. This is already normalized.</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">231</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">e</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">h</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-49'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-49'>#</a>
|
||
</div>
|
||
<p>Residual connection </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">240</span> <span class="n">e_res</span> <span class="o">=</span> <span class="n">e</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-50'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-50'>#</a>
|
||
</div>
|
||
<p>Normalize retrieved chunks </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">243</span> <span class="n">e</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">e</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-51'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-51'>#</a>
|
||
</div>
|
||
<p>Get query from the retrieved chunks </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">246</span> <span class="n">q</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">e</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">e</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-52'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-52'>#</a>
|
||
</div>
|
||
<p>Get keys and values from the input chunks </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">248</span> <span class="n">k</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="n">h</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">h</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span>
|
||
<span class="lineno">249</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span><span class="p">(</span><span class="n">h</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">h</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-53'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-53'>#</a>
|
||
</div>
|
||
<p>Calculate attention scores for all chunks. Each retrieved neighbor will pay attention to the original chunk that retrieved it. This will have shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">chunks</span><span class="p">,</span> <span class="n">neighbors</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">neighbor_len</span><span class="p">,</span> <span class="n">chunk_len</span><span class="p">]</span></code>
|
||
</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">254</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bcnihd,bcjhd->bcnhij'</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-54'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-54'>#</a>
|
||
</div>
|
||
<p>Scale attention scores </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">256</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-55'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-55'>#</a>
|
||
</div>
|
||
<p>Calculate softmax across the last dimension </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">259</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-56'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-56'>#</a>
|
||
</div>
|
||
<p>Gather values </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">262</span> <span class="n">e</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">"bcnhij,bcjhd->bcnihd"</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-57'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-57'>#</a>
|
||
</div>
|
||
<p>Change from shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">chunks</span><span class="p">,</span> <span class="n">neighbors</span><span class="p">,</span> <span class="n">neighbor_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">]</span></code>
|
||
to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">chunks</span><span class="p">,</span> <span class="n">neighbors</span><span class="p">,</span> <span class="n">neighbor_len</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">]</span></code>
|
||
</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">266</span> <span class="n">e</span> <span class="o">=</span> <span class="n">e</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">*</span><span class="n">e</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">2</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-58'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-58'>#</a>
|
||
</div>
|
||
<p>Apply final linear layer. The result will have shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">chunks</span><span class="p">,</span> <span class="n">neighbors</span><span class="p">,</span> <span class="n">neighbor_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
|
||
</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">270</span> <span class="n">e</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">e</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-59'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-59'>#</a>
|
||
</div>
|
||
<p>Add residual connection </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">273</span> <span class="k">return</span> <span class="n">e</span> <span class="o">+</span> <span class="n">e_res</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-60'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-60'>#</a>
|
||
</div>
|
||
<h2>Chunked Cross-Attention Layer <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqy" style=""><span class="mord text" style=""><span class="mord" style="">C</span><span class="mord sizing reset-size6 size5" style=""><span class="mord" style="">CA</span></span></span></span></span></span></span></span></h2>
|
||
<p>This is similar to the cross-attention layer defined above.</p>
|
||
<p>This is used in the decoder to pay attention to the retrieved neighbor chunks.</p>
|
||
<p><em>We do not use any explicit positional embeddings here. We assume that the model can represent positional information in the embeddings implicitly.</em></p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">276</span><span class="k">class</span> <span class="nc">ChunkedCrossAttention</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-61'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-61'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
|
||
is the number of features in transformer embeddings </li>
|
||
<li><code class="highlight"><span></span><span class="n">n_heads</span></code>
|
||
is the number of attention heads </li>
|
||
<li><code class="highlight"><span></span><span class="n">d_k</span></code>
|
||
is the number of features per head </li>
|
||
<li><code class="highlight"><span></span><span class="n">chunk_len</span></code>
|
||
is the length of a chunk</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">288</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_k</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">chunk_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-62'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-62'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">296</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||
<span class="lineno">297</span>
|
||
<span class="lineno">298</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span> <span class="o">=</span> <span class="n">chunk_len</span>
|
||
<span class="lineno">299</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">=</span> <span class="n">n_heads</span>
|
||
<span class="lineno">300</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">=</span> <span class="n">d_k</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-63'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-63'>#</a>
|
||
</div>
|
||
<p>To scale attentions before softmax by <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.383108em;vertical-align:-0.538em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.845108em;"><span style="top:-2.5864385em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord sqrt mtight"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8622307142857143em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mtight" style="padding-left:0.833em;"><span class="mord mtight"><span class="mord mtight coloredeq eqbq" style=""><span class="mord mathnormal mtight" style="">d</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.8222307142857144em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail mtight" style="min-width:0.853em;height:1.08em;"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
|
||
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
|
||
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
|
||
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
|
||
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
|
||
c69,-144,104.5,-217.7,106.5,-221
|
||
l0 -0
|
||
c5.3,-9.3,12,-14,20,-14
|
||
H400000v40H845.2724
|
||
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
|
||
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
|
||
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.17776928571428574em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.538em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">303</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-64'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-64'>#</a>
|
||
</div>
|
||
<p>Linear layers for query, key and value heads. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">306</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">)</span>
|
||
<span class="lineno">307</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">)</span>
|
||
<span class="lineno">308</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-65'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-65'>#</a>
|
||
</div>
|
||
<p>Pre-norm layer for the query embeddings. The paper uses RMSNorm instead. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">311</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">d_model</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-66'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-66'>#</a>
|
||
</div>
|
||
<p>Softmax for attention probabilities </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">314</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-67'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-67'>#</a>
|
||
</div>
|
||
<p>Final linear layer </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">317</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-68'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-68'>#</a>
|
||
</div>
|
||
<p> <code class="highlight"><span></span><span class="n">h</span></code>
|
||
are the input embeddings of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>, and
|
||
<code class="highlight"><span></span><span class="n">e</span></code>
|
||
are the retrieved nearest neighbors of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">chunks</span><span class="p">,</span> <span class="n">neighbors</span><span class="p">,</span> <span class="n">neighbor_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
|
||
</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">319</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">h</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">e</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-69'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-69'>#</a>
|
||
</div>
|
||
<p>Get shape </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">326</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">chunks</span><span class="p">,</span> <span class="n">neighbors</span><span class="p">,</span> <span class="n">neighbor_len</span><span class="p">,</span> <span class="n">d_model</span> <span class="o">=</span> <span class="n">e</span><span class="o">.</span><span class="n">shape</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-70'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-70'>#</a>
|
||
</div>
|
||
<p>No attention if there are no chunks (for short inputs when sampling) </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">329</span> <span class="k">if</span> <span class="n">chunks</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="lineno">330</span> <span class="k">return</span> <span class="n">h</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-71'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-71'>#</a>
|
||
</div>
|
||
<p>Residual connection </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">333</span> <span class="n">h_res</span> <span class="o">=</span> <span class="n">h</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-72'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-72'>#</a>
|
||
</div>
|
||
<p>Remove the first <code class="highlight"><span></span><span class="n">chunk_len</span> <span class="o">-</span> <span class="mi">1</span></code>
|
||
embeddings. The input pays attention to neighbors that were retrieved and encoded using past tokens only, so that there is no information leakage. That is, the neighbors retrieved for the first chunk carry information from the first chunk. So by shifting the sequence to the left by <code class="highlight"><span></span><span class="n">chunk_len</span> <span class="o">-</span> <span class="mi">1</span></code>
|
||
we make sure that information only flows to the right. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">341</span> <span class="n">h</span> <span class="o">=</span> <span class="n">h</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:]</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-73'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-73'>#</a>
|
||
</div>
|
||
<p>Pre-norm </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">343</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">h</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-74'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-74'>#</a>
|
||
</div>
|
||
<p>Append empty embeddings to the end to be able to split the input into chunks </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">345</span> <span class="k">if</span> <span class="n">h</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o"><</span> <span class="n">chunks</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span><span class="p">:</span>
|
||
<span class="lineno">346</span> <span class="n">h</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">h</span><span class="p">,</span> <span class="n">h</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">chunks</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span> <span class="o">-</span> <span class="n">h</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">d_model</span><span class="p">)),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-75'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-75'>#</a>
|
||
</div>
|
||
<p>Reshape the input into chunks. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">348</span> <span class="n">h</span> <span class="o">=</span> <span class="n">h</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">chunks</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-76'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-76'>#</a>
|
||
</div>
|
||
<p>Get query from the input </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">351</span> <span class="n">q</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">h</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">h</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-77'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-77'>#</a>
|
||
</div>
|
||
<p>Get keys and values from the retrieved neighbors </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">353</span> <span class="n">k</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="n">e</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">e</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span>
|
||
<span class="lineno">354</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span><span class="p">(</span><span class="n">e</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">e</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-78'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-78'>#</a>
|
||
</div>
|
||
<p>Calculate attention scores for input chunks. Each chunk will pay attention to neighbors retrieved by the previous chunk. This will have shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">chunks</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">chunk_len</span><span class="p">,</span> <span class="n">neighbors</span><span class="p">,</span> <span class="n">neighbor_len</span><span class="p">]</span></code>
|
||
</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">359</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bcihd,bcnjhd->bchinj'</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-79'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-79'>#</a>
|
||
</div>
|
||
<p>Scale attention scores </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">361</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-80'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-80'>#</a>
|
||
</div>
|
||
<p>Apply softmax over the last two dimensions <code class="highlight"><span></span><span class="n">neighbors</span><span class="p">,</span> <span class="n">neighbor_len</span></code>
|
||
</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">364</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">attn</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">attn</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">2</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">attn</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-81'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-81'>#</a>
|
||
</div>
|
||
<p>Gather values </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">367</span> <span class="n">h</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">"bchinj,bcnjhd->bcihd"</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-82'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-82'>#</a>
|
||
</div>
|
||
<p>Change from shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">chunks</span><span class="p">,</span> <span class="n">chunk_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">]</span></code>
|
||
to <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">chunks</span> <span class="o">*</span> <span class="n">chunk_len</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">*</span> <span class="n">d_k</span><span class="p">]</span></code>
|
||
</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">371</span> <span class="n">h</span> <span class="o">=</span> <span class="n">h</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">chunks</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-83'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-83'>#</a>
|
||
</div>
|
||
<p>Apply final linear layer. The result will have shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">chunks</span> <span class="o">*</span> <span class="n">chunk_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
|
||
</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">375</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">h</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-84'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-84'>#</a>
|
||
</div>
|
||
<p>Prepend <code class="highlight"><span></span><span class="n">chunk_len</span> <span class="o">-</span> <span class="mi">1</span></code>
|
||
zero embeddings on the left, i.e. shift it back to the right </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">378</span> <span class="n">h</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">h</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="n">d_model</span><span class="p">),</span> <span class="n">h</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-85'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-85'>#</a>
|
||
</div>
|
||
<p>Truncate and add the residual connection </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">381</span> <span class="k">return</span> <span class="n">h</span><span class="p">[:,</span> <span class="p">:</span><span class="n">h_res</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]]</span> <span class="o">+</span> <span class="n">h_res</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-86'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-86'>#</a>
|
||
</div>
|
||
<h3>Position-wise Feed Forward Layer <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqz" style=""><span class="mord text" style=""><span class="mord" style="">F</span><span class="mord sizing reset-size6 size5" style=""><span class="mord" style="">FW</span></span></span></span></span></span></span></span></h3>
|
||
<p>This consists of two linear layers and an activation in the middle.</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">384</span><span class="k">class</span> <span class="nc">FeedForward</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-87'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-87'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
|
||
is the number of features in transformer embeddings </li>
|
||
<li><code class="highlight"><span></span><span class="n">d_ff</span></code>
|
||
is the number of features in the hidden layer</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">391</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-88'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-88'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">397</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-89'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-89'>#</a>
|
||
</div>
|
||
<p>The two linear layers </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">400</span> <span class="bp">self</span><span class="o">.</span><span class="n">lin1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">)</span>
|
||
<span class="lineno">401</span> <span class="bp">self</span><span class="o">.</span><span class="n">lin2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_ff</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-90'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-90'>#</a>
|
||
</div>
|
||
<p>ReLU Activation </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">404</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-91'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-91'>#</a>
|
||
</div>
|
||
<p>Pre-norm layer </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">407</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">d_model</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-92'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-92'>#</a>
|
||
</div>
|
||
<p> <code class="highlight"><span></span><span class="n">h</span></code>
|
||
are the embeddings of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
|
||
</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">409</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">h</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-93'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-93'>#</a>
|
||
</div>
|
||
<p>Residual </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">415</span> <span class="n">h_res</span> <span class="o">=</span> <span class="n">h</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-94'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-94'>#</a>
|
||
</div>
|
||
<p>Pre-norm </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">417</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">h</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-95'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-95'>#</a>
|
||
</div>
|
||
<p>First linear layer </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">419</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lin1</span><span class="p">(</span><span class="n">h</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-96'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-96'>#</a>
|
||
</div>
|
||
<p>Activation </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">421</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span><span class="p">(</span><span class="n">h</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-97'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-97'>#</a>
|
||
</div>
|
||
<p>Second linear layer </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">423</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lin2</span><span class="p">(</span><span class="n">h</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-98'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-98'>#</a>
|
||
</div>
|
||
<p>Add the residual connection </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">426</span> <span class="k">return</span> <span class="n">h</span> <span class="o">+</span> <span class="n">h_res</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-99'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-99'>#</a>
|
||
</div>
|
||
<h2>Nearest Neighbor Encoder <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqh" style=""><span class="mord text" style=""><span class="mord coloredeq eqbl" style=""><span class="mord" style="">E</span></span><span class="mord sizing reset-size6 size5" style=""><span class="mord" style="">NCOD</span><span class="mord coloredeq eqbl" style=""><span class="mord" style="">E</span></span><span class="mord" style="">R</span></span></span><span class="mopen" style="">(</span><span class="mord coloredeq eqm" style=""><span class="mord text" style=""><span class="mord" style="">R</span><span class="mord sizing reset-size6 size5" style=""><span class="mord coloredeq eqbl" style=""><span class="mord" style="">E</span></span><span class="mord" style="">T</span></span></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.07153em">C</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.07153em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">u</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style=""><span class="mclose" style="">)</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361079999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" 
style=""><span class="mord mtight" style="">1</span><span class="mrel mtight" style="">≤</span><span class="mord mathnormal mtight" style="">u</span><span class="mrel mtight" style="">≤</span><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.24517899999999998em;"><span></span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbm" style=""><span class="mord mathnormal" style="margin-right:0.08125em">H</span></span><span class="mclose" style="">)</span></span></span></span></span></span></h2>
|
||
<p>This module encodes the retrieved nearest neighbors</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">429</span><span class="k">class</span> <span class="nc">NearestNeighborEncoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-100'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-100'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">chunk_len</span></code>
|
||
is the length of a chunk </li>
|
||
<li><code class="highlight"><span></span><span class="n">n_layers</span></code>
|
||
is the number of layers in the encoder <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbb" style=""><span class="mord" style=""><span class="mord coloredeq eqbn" style=""><span class="mord mathnormal" style="">L</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord text mtight" style=""><span class="mord mtight" style="">enc</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </li>
|
||
<li><code class="highlight"><span></span><span class="n">ca_layers</span></code>
|
||
are the layers with cross attention <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbc" style=""><span class="mord" style=""><span class="mord coloredeq eqbo" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord text mtight" style=""><span class="mord mtight" style="">enc</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </li>
|
||
<li><code class="highlight"><span></span><span class="n">d_model</span></code>
|
||
is the number of features in embeddings </li>
|
||
<li><code class="highlight"><span></span><span class="n">n_heads</span></code>
|
||
is the number of heads in attention layers </li>
|
||
<li><code class="highlight"><span></span><span class="n">d_k</span></code>
|
||
is the size of attention heads </li>
|
||
<li><code class="highlight"><span></span><span class="n">d_ff</span></code>
|
||
is the size of the feed-forward networks' hidden layers</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">436</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">chunk_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">ca_layers</span><span class="p">:</span> <span class="n">Set</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span>
|
||
<span class="lineno">437</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_k</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-101'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-101'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">448</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||
<span class="lineno">449</span> <span class="bp">self</span><span class="o">.</span><span class="n">ca_layers</span> <span class="o">=</span> <span class="n">ca_layers</span>
|
||
<span class="lineno">450</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span> <span class="o">=</span> <span class="n">chunk_len</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-102'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-102'>#</a>
|
||
</div>
|
||
<p>Cross-attention layers </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">452</span> <span class="bp">self</span><span class="o">.</span><span class="n">ca</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">CrossAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">ca_layers</span><span class="p">))])</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-103'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-103'>#</a>
|
||
</div>
|
||
<p>Bi-directional self attention layers </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">454</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">SelfAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">is_causal</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_layers</span><span class="p">)])</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-104'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-104'>#</a>
|
||
</div>
|
||
<p>Feed forward layers </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">456</span> <span class="bp">self</span><span class="o">.</span><span class="n">ffw</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">FeedForward</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_layers</span><span class="p">)])</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-105'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-105'>#</a>
|
||
</div>
|
||
<p>Pre-normalization layer for <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqbm" style=""><span class="mord mathnormal" style="margin-right:0.08125em">H</span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">459</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_h</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">d_model</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-106'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-106'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">e</span></code>
|
||
are token embeddings of the retrieved nearest neighbors, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.20001em;vertical-align:-0.35001em;"></span><span class="mord text"><span class="mord coloredeq eqbl" style=""><span class="mord" style="">E</span></span><span class="mord sizing reset-size6 size5"><span class="mord">MB</span></span></span><span class="mord"><span class="delimsizing size1">(</span></span><span class="mord coloredeq eqm" style=""><span class="mord text" style=""><span class="mord" style="">R</span><span class="mord sizing reset-size6 size5" style=""><span class="mord coloredeq eqbl" style=""><span class="mord" style="">E</span></span><span class="mord" style="">T</span></span></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.07153em">C</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.07153em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">u</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style=""><span class="mclose" style="">)</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361079999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">1</span><span class="mrel mtight" style="">≤</span><span class="mord mathnormal mtight" 
style="">u</span><span class="mrel mtight" style="">≤</span><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.24517899999999998em;"><span></span></span></span></span></span></span></span><span class="mord"><span class="delimsizing size1">)</span></span></span></span></span></span> of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">chunks</span><span class="p">,</span> <span class="n">neighbors</span><span class="p">,</span> <span class="n">neighbor_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
|
||
</li></ul>
|
||
<ul><li><code class="highlight"><span></span><span class="n">h</span></code>
|
||
are the input token embeddings, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqbm" style=""><span class="mord mathnormal" style="margin-right:0.08125em">H</span></span></span></span></span></span> of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
|
||
</li></ul>
|
||
<p><em>The chunks <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqbf" style=""><span class="mord mathnormal" style="">u</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel" style="">∈</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mopen" style="">[</span><span class="mord" style="">1</span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="mclose" style="">]</span></span></span></span></span></span> and neighbors <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.85396em;vertical-align:-0.19444em;"></span><span class="mord mathnormal" style="margin-right:0.05724em;">j</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">∈</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">[</span><span class="mord">1</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" style="margin-right:0.03148em;">k</span><span class="mclose">]</span></span></span></span></span> are processed in parallel.</em></p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">461</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">e</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">h</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-107'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-107'>#</a>
|
||
</div>
|
||
<p>Get shape </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">474</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">chunks</span><span class="p">,</span> <span class="n">neighbors</span><span class="p">,</span> <span class="n">neighbor_len</span><span class="p">,</span> <span class="n">d_model</span> <span class="o">=</span> <span class="n">e</span><span class="o">.</span><span class="n">shape</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-108'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-108'>#</a>
|
||
</div>
|
||
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.1052em;vertical-align:-0.3551999999999999em;"></span><span class="mopen">(</span><span class="mord"><span class="mord coloredeq eqbm" style=""><span class="mord mathnormal" style="margin-right:0.08125em">H</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">u</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.34480000000000005em;"><span style="top:-2.5198em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqbf" style=""><span class="mord mathnormal mtight" style="">u</span><span class="mrel mtight" style="">∈</span><span class="mopen mtight" style="">[</span><span class="mord mtight" style="">1</span><span class="mpunct mtight" style="">,</span><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span><span class="mclose mtight" style="">]</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.3551999999999999em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span 
class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord text"><span class="mord">S</span><span class="mord sizing reset-size6 size5"><span class="mord coloredeq eqbo" style=""><span class="mord" style="">P</span></span><span class="mord coloredeq eqbn" style=""><span class="mord" style="">L</span></span><span class="mord">IT</span></span></span><span class="mopen">(</span><span class="mord coloredeq eqbm" style=""><span class="mord mathnormal" style="margin-right:0.08125em">H</span></span><span class="mclose">)</span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">477</span> <span class="n">h_split</span> <span class="o">=</span> <span class="n">h</span><span class="p">[:,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span> <span class="o">*</span> <span class="n">chunks</span><span class="p">,</span> <span class="p">:]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">chunks</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-109'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-109'>#</a>
|
||
</div>
|
||
<p>Pre-norm </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">480</span> <span class="n">h_split</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_h</span><span class="p">(</span><span class="n">h_split</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-110'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-110'>#</a>
|
||
</div>
|
||
<p>Keep the index of the cross attention layer </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">483</span> <span class="n">p_ca</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-111'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-111'>#</a>
|
||
</div>
|
||
<p>For all layers <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.946332em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.751892em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">∈</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">[</span><span class="mord">1</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbb" style=""><span class="mord" style=""><span class="mord coloredeq eqbn" style=""><span class="mord mathnormal" style="">L</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord text mtight" style=""><span class="mord mtight" style="">enc</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mclose">]</span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">485</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">)):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-112'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-112'>#</a>
|
||
</div>
|
||
<p>Bi-directional self attention <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.071664em;vertical-align:-0.247em;"></span><span class="mord"><span class="mord coloredeq eqbl" style=""><span class="mord mathnormal" style="margin-right:0.05764em">E</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.824664em;"><span style="top:-2.4530000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">u</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.0746639999999998em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord coloredeq eqw" style=""><span class="mord text" style=""><span class="mord" style="">A</span><span class="mord sizing reset-size6 size5" style=""><span class="mord" style="">TTN</span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord 
mtight">enc</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord coloredeq eqbl" style=""><span class="mord mathnormal" style="margin-right:0.05764em">E</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.824664em;"><span style="top:-2.4530000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">u</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">488</span> <span class="n">e</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">[</span><span class="n">p</span><span class="p">](</span><span class="n">e</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">neighbor_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">e</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-113'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-113'>#</a>
|
||
</div>
|
||
<p>Cross attention if <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.946332em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.751892em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">∈</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqbc" style=""><span class="mord" style=""><span class="mord coloredeq eqbo" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord text mtight" style=""><span class="mord mtight" style="">enc</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">491</span> <span class="k">if</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">ca_layers</span><span class="p">:</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-114'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-114'>#</a>
|
||
</div>
|
||
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.071664em;vertical-align:-0.247em;"></span><span class="mord"><span class="mord coloredeq eqbl" style=""><span class="mord mathnormal" style="margin-right:0.05764em">E</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.824664em;"><span style="top:-2.4530000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">u</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.0746639999999998em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord coloredeq eqba" style=""><span class="mord text" style=""><span class="mord" style="">C</span><span class="mord sizing reset-size6 size5" style=""><span class="mord" style="">A</span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">enc</span></span></span></span></span></span><span 
class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord coloredeq eqbl" style=""><span class="mord mathnormal" style="margin-right:0.05764em">E</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.824664em;"><span style="top:-2.4530000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">u</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord coloredeq eqbm" style=""><span class="mord mathnormal" style="margin-right:0.08125em">H</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">u</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">493</span> <span class="n">e</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ca</span><span class="p">[</span><span class="n">p_ca</span><span class="p">](</span><span class="n">e</span><span class="p">,</span> <span class="n">h_split</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-115'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-115'>#</a>
|
||
</div>
|
||
<p>Increment the cross attention index </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">495</span> <span class="n">p_ca</span> <span class="o">+=</span> <span class="mi">1</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-116'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-116'>#</a>
|
||
</div>
|
||
<p>Feed forward layer <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.071664em;vertical-align:-0.247em;"></span><span class="mord"><span class="mord coloredeq eqbl" style=""><span class="mord mathnormal" style="margin-right:0.05764em">E</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.824664em;"><span style="top:-2.4530000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">u</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.0746639999999998em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord coloredeq eqz" style=""><span class="mord text" style=""><span class="mord" style="">F</span><span class="mord sizing reset-size6 size5" style=""><span class="mord" style="">FW</span></span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">enc</span></span></span></span></span></span><span 
class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord coloredeq eqbl" style=""><span class="mord mathnormal" style="margin-right:0.05764em">E</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.824664em;"><span style="top:-2.4530000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">u</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">498</span> <span class="n">e</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ffw</span><span class="p">[</span><span class="n">p</span><span class="p">](</span><span class="n">e</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-117'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-117'>#</a>
|
||
</div>
|
||
<p>return <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqbl" style=""><span class="mord mathnormal" style="margin-right:0.05764em">E</span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">501</span> <span class="k">return</span> <span class="n">e</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-118'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-118'>#</a>
|
||
</div>
|
||
<h2>Retro Model</h2>
|
||
<p>This is the Retro decoder</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">504</span><span class="k">class</span> <span class="nc">RetroModel</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-119'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-119'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">n_vocab</span></code>
|
||
is the number of tokens in the vocabulary </li>
|
||
<li><code class="highlight"><span></span><span class="n">d_model</span></code>
|
||
is the number of features in embeddings </li>
|
||
<li><code class="highlight"><span></span><span class="n">n_layers</span></code>
|
||
is the number of layers in the decoder <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqbn" style=""><span class="mord mathnormal" style="">L</span></span></span></span></span></span> </li>
|
||
<li><code class="highlight"><span></span><span class="n">ca_layers</span></code>
|
||
are the layers with cross attention <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqbo" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span></span></span></span></span></span> </li>
|
||
<li><code class="highlight"><span></span><span class="n">chunk_len</span></code>
|
||
is the length of a chunk </li>
|
||
<li><code class="highlight"><span></span><span class="n">n_heads</span></code>
|
||
is the number of heads in attention layers </li>
|
||
<li><code class="highlight"><span></span><span class="n">d_k</span></code>
|
||
is the size of attention heads </li>
|
||
<li><code class="highlight"><span></span><span class="n">d_ff</span></code>
|
||
is the size of the hidden layers of the feed-forward networks </li>
|
||
<li><code class="highlight"><span></span><span class="n">encoder</span></code>
|
||
is the nearest neighbor encoder</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">511</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">ca_layers</span><span class="p">:</span> <span class="n">Set</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">chunk_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="lineno">512</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_k</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">NearestNeighborEncoder</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-120'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-120'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">524</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||
<span class="lineno">525</span>
|
||
<span class="lineno">526</span> <span class="bp">self</span><span class="o">.</span><span class="n">ca_layers</span> <span class="o">=</span> <span class="n">ca_layers</span>
|
||
<span class="lineno">527</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-121'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-121'>#</a>
|
||
</div>
|
||
<p>Token embedding layer </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">530</span> <span class="bp">self</span><span class="o">.</span><span class="n">emb</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">n_vocab</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-122'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-122'>#</a>
|
||
</div>
|
||
<p>Chunked cross attention layers <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqy" style=""><span class="mord text" style=""><span class="mord" style="">C</span><span class="mord sizing reset-size6 size5" style=""><span class="mord" style="">CA</span></span></span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">532</span> <span class="bp">self</span><span class="o">.</span><span class="n">cca</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">(</span>
|
||
<span class="lineno">533</span> <span class="p">[</span><span class="n">ChunkedCrossAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">chunk_len</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">ca_layers</span><span class="p">))])</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-123'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-123'>#</a>
|
||
</div>
|
||
<p>Attention layers <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqw" style=""><span class="mord text" style=""><span class="mord" style="">A</span><span class="mord sizing reset-size6 size5" style=""><span class="mord" style="">TTN</span></span></span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">535</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">SelfAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">is_causal</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_layers</span><span class="p">)])</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-124'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-124'>#</a>
|
||
</div>
|
||
<p>Feed forward layers <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqz" style=""><span class="mord text" style=""><span class="mord" style="">F</span><span class="mord sizing reset-size6 size5" style=""><span class="mord" style="">FW</span></span></span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">537</span> <span class="bp">self</span><span class="o">.</span><span class="n">ffw</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">FeedForward</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_layers</span><span class="p">)])</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-125'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-125'>#</a>
|
||
</div>
|
||
<p>Readout layer <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqx" style=""><span class="mord text" style=""><span class="mord" style="">R</span><span class="mord sizing reset-size6 size5" style=""><span class="mord coloredeq eqbl" style=""><span class="mord" style="">E</span></span><span class="mord" style="">AD</span></span></span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">539</span> <span class="bp">self</span><span class="o">.</span><span class="n">read</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_vocab</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-126'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-126'>#</a>
|
||
</div>
|
||
<p>Pre-normalization layer for nearest neighbor embeddings from <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqh" style=""><span class="mord text" style=""><span class="mord coloredeq eqbl" style=""><span class="mord" style="">E</span></span><span class="mord sizing reset-size6 size5" style=""><span class="mord" style="">NCOD</span><span class="mord coloredeq eqbl" style=""><span class="mord" style="">E</span></span><span class="mord" style="">R</span></span></span><span class="mopen" style="">(</span><span class="mord coloredeq eqm" style=""><span class="mord text" style=""><span class="mord" style="">R</span><span class="mord sizing reset-size6 size5" style=""><span class="mord coloredeq eqbl" style=""><span class="mord" style="">E</span></span><span class="mord" style="">T</span></span></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.07153em">C</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.07153em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">u</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style=""><span class="mclose" style="">)</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361079999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" 
style=""><span class="mord mtight" style=""><span class="mord mtight" style="">1</span><span class="mrel mtight" style="">≤</span><span class="mord mathnormal mtight" style="">u</span><span class="mrel mtight" style="">≤</span><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.24517899999999998em;"><span></span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbm" style=""><span class="mord mathnormal" style="margin-right:0.08125em">H</span></span><span class="mclose" style="">)</span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">543</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_e</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">d_model</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-127'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-127'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
|
||
is the input sequence, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqbp" style=""><span class="mord mathnormal" style="margin-right:0.07847em">X</span></span></span></span></span></span> of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">]</span></code>
|
||
</li>
|
||
<li><code class="highlight"><span></span><span class="n">ret</span></code>
|
||
are the retrieved neighbors <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqm" style=""><span class="mord text" style=""><span class="mord" style="">R</span><span class="mord sizing reset-size6 size5" style=""><span class="mord coloredeq eqbl" style=""><span class="mord" style="">E</span></span><span class="mord" style="">T</span></span></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.07153em">C</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.07153em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">u</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style=""><span class="mclose" style="">)</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361079999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">1</span><span class="mrel mtight" style="">≤</span><span class="mord mathnormal mtight" style="">u</span><span class="mrel mtight" style="">≤</span><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" 
style="height:0.24517899999999998em;"><span></span></span></span></span></span></span></span></span></span></span></span> of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">chunks</span><span class="p">,</span> <span class="n">neighbors</span><span class="p">,</span> <span class="n">neighbor_len</span><span class="p">]</span></code>
|
||
</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">545</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">ret</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-128'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-128'>#</a>
|
||
</div>
|
||
<p>Get input embeddings <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqbm" style=""><span class="mord mathnormal" style="margin-right:0.08125em">H</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord text"><span class="mord coloredeq eqbl" style=""><span class="mord" style="">E</span></span><span class="mord sizing reset-size6 size5"><span class="mord">MB</span></span></span><span class="mopen">(</span><span class="mord coloredeq eqbp" style=""><span class="mord mathnormal" style="margin-right:0.07847em">X</span></span><span class="mclose">)</span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">554</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">emb</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-129'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-129'>#</a>
|
||
</div>
|
||
<p>Embeddings of the retrieved neighbors <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.071664em;vertical-align:-0.247em;"></span><span class="mord"><span class="mord coloredeq eqbl" style=""><span class="mord mathnormal" style="margin-right:0.05764em">E</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.824664em;"><span style="top:-2.4530000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">u</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.20001em;vertical-align:-0.35001em;"></span><span class="mord"><span class="mord text"><span class="mord coloredeq eqbl" style=""><span class="mord" style="">E</span></span><span class="mord sizing reset-size6 size5"><span class="mord">MB</span></span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">enc</span></span></span></span></span></span><span 
class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord"><span class="delimsizing size1">(</span></span><span class="mord text"><span class="mord">R</span><span class="mord sizing reset-size6 size5"><span class="mord coloredeq eqbl" style=""><span class="mord" style="">E</span></span><span class="mord">T</span></span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07153em;">C</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.07153em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">u</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.824664em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.05724em;">j</span></span></span></span></span></span></span></span><span class="mord"><span class="delimsizing size1">)</span></span></span></span></span></span>.</p>
|
||
<p>We use the same embeddings for both the input and the neighbors </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">560</span> <span class="n">ret_emb</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">emb</span><span class="p">(</span><span class="n">ret</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-130'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-130'>#</a>
|
||
</div>
|
||
<p>Keep the index of the next chunked cross-attention layer to apply </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">563</span> <span class="n">p_ca</span> <span class="o">=</span> <span class="mi">0</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-131'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-131'>#</a>
|
||
</div>
|
||
<p>For all layers <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.7335400000000001em;vertical-align:-0.19444em;"></span><span class="mord mathnormal">p</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">∈</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">[</span><span class="mord">1</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbn" style=""><span class="mord mathnormal" style="">L</span></span><span class="mclose">]</span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">565</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">)):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-132'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-132'>#</a>
|
||
</div>
|
||
<p>Causal self attention <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqbm" style=""><span class="mord mathnormal" style="margin-right:0.08125em">H</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqw" style=""><span class="mord text" style=""><span class="mord" style="">A</span><span class="mord sizing reset-size6 size5" style=""><span class="mord" style="">TTN</span></span></span></span><span class="mopen">(</span><span class="mord coloredeq eqbm" style=""><span class="mord mathnormal" style="margin-right:0.08125em">H</span></span><span class="mclose">)</span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">567</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">[</span><span class="n">p</span><span class="p">](</span><span class="n">h</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-133'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-133'>#</a>
|
||
</div>
|
||
<p>Get encoder embeddings before the first <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqy" style=""><span class="mord text" style=""><span class="mord" style="">C</span><span class="mord sizing reset-size6 size5" style=""><span class="mord" style="">CA</span></span></span></span></span></span></span></span> layer, when <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord mathnormal">p</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop">min</span><span class="mopen">(</span><span class="mord coloredeq eqbo" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span></span><span class="mclose">)</span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">571</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ca_layers</span> <span class="ow">and</span> <span class="n">p</span> <span class="o">==</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ca_layers</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-134'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-134'>#</a>
|
||
</div>
|
||
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqbl" style=""><span class="mord mathnormal" style="margin-right:0.05764em">E</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqh" style=""><span class="mord text" style=""><span class="mord coloredeq eqbl" style=""><span class="mord" style="">E</span></span><span class="mord sizing reset-size6 size5" style=""><span class="mord" style="">NCOD</span><span class="mord coloredeq eqbl" style=""><span class="mord" style="">E</span></span><span class="mord" style="">R</span></span></span><span class="mopen" style="">(</span><span class="mord coloredeq eqm" style=""><span class="mord text" style=""><span class="mord" style="">R</span><span class="mord sizing reset-size6 size5" style=""><span class="mord coloredeq eqbl" style=""><span class="mord" style="">E</span></span><span class="mord" style="">T</span></span></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.07153em">C</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.07153em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">u</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style=""><span class="mclose" 
style="">)</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361079999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">1</span><span class="mrel mtight" style="">≤</span><span class="mord mathnormal mtight" style="">u</span><span class="mrel mtight" style="">≤</span><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.24517899999999998em;"><span></span></span></span></span></span></span></span><span class="mpunct" style="">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbm" style=""><span class="mord mathnormal" style="margin-right:0.08125em">H</span></span><span class="mclose" style="">)</span></span></span></span></span></span></p>
|
||
<p>We passed the embeddings of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqm" style=""><span class="mord text" style=""><span class="mord" style="">R</span><span class="mord sizing reset-size6 size5" style=""><span class="mord coloredeq eqbl" style=""><span class="mord" style="">E</span></span><span class="mord" style="">T</span></span></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.07153em">C</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.07153em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">u</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose" style=""><span class="mclose" style="">)</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361079999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style="">1</span><span class="mrel mtight" style="">≤</span><span class="mord mathnormal mtight" style="">u</span><span class="mrel mtight" style="">≤</span><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" 
style="height:0.24517899999999998em;"><span></span></span></span></span></span></span></span></span></span></span></span> to the encoder. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">575</span> <span class="n">e</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">ret_emb</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-135'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-135'>#</a>
|
||
</div>
|
||
<p>Normalize encoder embeddings </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">577</span> <span class="n">e</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_e</span><span class="p">(</span><span class="n">e</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-136'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-136'>#</a>
|
||
</div>
|
||
<p>Chunked cross-attention if <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.7335400000000001em;vertical-align:-0.19444em;"></span><span class="mord mathnormal">p</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">∈</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqbo" style=""><span class="mord mathnormal" style="margin-right:0.13889em">P</span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">580</span> <span class="k">if</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">ca_layers</span><span class="p">:</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-137'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-137'>#</a>
|
||
</div>
|
||
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqbm" style=""><span class="mord mathnormal" style="margin-right:0.08125em">H</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqy" style=""><span class="mord text" style=""><span class="mord" style="">C</span><span class="mord sizing reset-size6 size5" style=""><span class="mord" style="">CA</span></span></span></span><span class="mopen">(</span><span class="mord coloredeq eqbm" style=""><span class="mord mathnormal" style="margin-right:0.08125em">H</span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqbl" style=""><span class="mord mathnormal" style="margin-right:0.05764em">E</span></span><span class="mclose">)</span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">582</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cca</span><span class="p">[</span><span class="n">p_ca</span><span class="p">](</span><span class="n">h</span><span class="p">,</span> <span class="n">e</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-138'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-138'>#</a>
|
||
</div>
|
||
<p>Increment chunked cross-attention index </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">584</span> <span class="n">p_ca</span> <span class="o">+=</span> <span class="mi">1</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-139'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-139'>#</a>
|
||
</div>
|
||
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqbm" style=""><span class="mord mathnormal" style="margin-right:0.08125em">H</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqz" style=""><span class="mord text" style=""><span class="mord" style="">F</span><span class="mord sizing reset-size6 size5" style=""><span class="mord" style="">FW</span></span></span></span><span class="mopen">(</span><span class="mord coloredeq eqbm" style=""><span class="mord mathnormal" style="margin-right:0.08125em">H</span></span><span class="mclose">)</span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">587</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ffw</span><span class="p">[</span><span class="n">p</span><span class="p">](</span><span class="n">h</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-140'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-140'>#</a>
|
||
</div>
|
||
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">O</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">←</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqx" style=""><span class="mord text" style=""><span class="mord" style="">R</span><span class="mord sizing reset-size6 size5" style=""><span class="mord coloredeq eqbl" style=""><span class="mord" style="">E</span></span><span class="mord" style="">AD</span></span></span></span><span class="mopen">(</span><span class="mord coloredeq eqbm" style=""><span class="mord mathnormal" style="margin-right:0.08125em">H</span></span><span class="mclose">)</span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">590</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="n">h</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-141'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-141'>#</a>
|
||
</div>
|
||
<h3>Test the model with fake data</h3>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">593</span><span class="k">def</span> <span class="nf">_test</span><span class="p">():</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-142'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-142'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">597</span> <span class="n">chunk_len</span> <span class="o">=</span> <span class="mi">4</span>
|
||
<span class="lineno">598</span> <span class="n">d_model</span> <span class="o">=</span> <span class="mi">8</span>
|
||
<span class="lineno">599</span> <span class="n">d_ff</span> <span class="o">=</span> <span class="mi">32</span>
|
||
<span class="lineno">600</span> <span class="n">n_heads</span> <span class="o">=</span> <span class="mi">2</span>
|
||
<span class="lineno">601</span> <span class="n">d_k</span> <span class="o">=</span> <span class="mi">4</span>
|
||
<span class="lineno">602</span>
|
||
<span class="lineno">603</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">'cuda:0'</span><span class="p">)</span>
|
||
<span class="lineno">604</span>
|
||
<span class="lineno">605</span> <span class="n">m</span> <span class="o">=</span> <span class="n">RetroModel</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="p">{</span><span class="mi">2</span><span class="p">,</span> <span class="mi">5</span><span class="p">},</span> <span class="n">chunk_len</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">,</span>
|
||
<span class="lineno">606</span> <span class="n">encoder</span><span class="o">=</span><span class="n">NearestNeighborEncoder</span><span class="p">(</span><span class="n">chunk_len</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="p">{</span><span class="mi">1</span><span class="p">},</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">))</span>
|
||
<span class="lineno">607</span>
|
||
<span class="lineno">608</span> <span class="n">m</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
|
||
<span class="lineno">609</span> <span class="n">x</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">3</span><span class="p">]</span>
|
||
<span class="lineno">610</span> <span class="n">ret</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="lineno">611</span> <span class="p">[[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">]],</span>
|
||
<span class="lineno">612</span> <span class="p">[[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">]],</span>
|
||
<span class="lineno">613</span> <span class="p">]</span>
|
||
<span class="lineno">614</span> <span class="n">res</span> <span class="o">=</span> <span class="n">m</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="n">x</span><span class="p">]</span> <span class="o">*</span> <span class="mi">10</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="n">ret</span><span class="p">]</span> <span class="o">*</span> <span class="mi">10</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">))</span>
|
||
<span class="lineno">615</span>
|
||
<span class="lineno">616</span> <span class="n">inspect</span><span class="p">(</span><span class="n">res</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-143'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-143'>#</a>
|
||
</div>
|
||
<p> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">620</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||
<span class="lineno">621</span> <span class="n">_test</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='footer'>
|
||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||
<a href="https://labml.ai">labml.ai</a>
|
||
</div>
|
||
</div>
|
||
<script src="../../interactive.js?v=1"></script>
|
||
<script>
|
||
/**
 * Attach the click-to-enlarge behaviour to every image that is a direct
 * child of a paragraph (CSS selector `p>img`).
 */
function handleImages() {
    // querySelectorAll returns a static NodeList, so borrowing Array#forEach
    // visits exactly the same images, in document order, as an index loop.
    var paragraphImages = document.querySelectorAll('p>img')
    Array.prototype.forEach.call(paragraphImages, function (image) {
        handleImage(image)
    })
}
|
||
|
||
/**
 * Make a documentation image open an enlarged copy in a modal when clicked.
 *
 * Centers the image's parent element, then builds a detached modal
 * (`div#modal` > `div` > `img`, plus a `span.close` control). Clicking the
 * image attaches the modal to `document.body` and loads the image's `src`
 * into it. Clicking the close control, or pressing Escape, detaches it again.
 *
 * @param {HTMLImageElement} img - the image to enhance
 */
function handleImage(img) {
    img.parentElement.style.textAlign = 'center'

    var modal = document.createElement('div')
    modal.id = 'modal'

    var modalContent = document.createElement('div')
    modal.appendChild(modalContent)

    var modalImage = document.createElement('img')
    // Carry over the alternative text so the enlarged copy is not an unlabeled image.
    modalImage.alt = img.alt
    modalContent.appendChild(modalImage)

    var span = document.createElement('span')
    span.classList.add('close')
    span.textContent = 'x'
    modal.appendChild(span)

    // Detach the modal only if it is currently shown, so extra calls
    // (e.g. Escape pressed while no modal is open) are harmless.
    function closeModal() {
        if (modal.parentNode === document.body) {
            document.body.removeChild(modal)
        }
    }

    img.addEventListener('click', function () {
        document.body.appendChild(modal)
        modalImage.src = img.src
    })

    span.addEventListener('click', closeModal)

    // Keyboard users need a way to dismiss the modal; the close control is a
    // non-focusable span, so support the conventional Escape key.
    document.addEventListener('keydown', function (event) {
        if (event.key === 'Escape') {
            closeModal()
        }
    })
}
|
||
|
||
// Enhance the paragraph images already parsed: this inline script sits at the end of <body>, after the page content.
handleImages()
|
||
</script>
|
||
</body>
|
||
</html> |