mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 01:13:00 +08:00
1848 lines
159 KiB
HTML
1848 lines
159 KiB
HTML
<!DOCTYPE html>
|
||
<html lang="en">
|
||
<head>
|
||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||
<meta name="description" content="This is an annotated implementation/tutorial of the Feedback Transformer in PyTorch."/>
|
||
|
||
<meta name="twitter:card" content="summary"/>
|
||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||
<meta name="twitter:title" content="Feedback Transformer"/>
|
||
<meta name="twitter:description" content="This is an annotated implementation/tutorial of the Feedback Transformer in PyTorch."/>
|
||
<meta name="twitter:site" content="@labmlai"/>
|
||
<meta name="twitter:creator" content="@labmlai"/>
|
||
|
||
<meta property="og:url" content="https://nn.labml.ai/transformers/feedback/index.html"/>
|
||
<meta property="og:title" content="Feedback Transformer"/>
|
||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||
<meta property="og:type" content="object"/>
|
||
<meta property="og:title" content="Feedback Transformer"/>
|
||
<meta property="og:description" content="This is an annotated implementation/tutorial of the Feedback Transformer in PyTorch."/>
|
||
|
||
<title>Feedback Transformer</title>
|
||
<link rel="shortcut icon" href="/icon.png"/>
|
||
<link rel="stylesheet" href="../../pylit.css?v=1">
|
||
<link rel="canonical" href="https://nn.labml.ai/transformers/feedback/index.html"/>
|
||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
||
|
||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||
<script>
|
||
window.dataLayer = window.dataLayer || [];
|
||
|
||
function gtag() {
|
||
dataLayer.push(arguments);
|
||
}
|
||
|
||
gtag('js', new Date());
|
||
|
||
gtag('config', 'G-4V3HC8HBLH');
|
||
</script>
|
||
</head>
|
||
<body>
|
||
<div id='container'>
|
||
<div id="background"></div>
|
||
<div class='section'>
|
||
<div class='docs'>
|
||
<p>
|
||
<a class="parent" href="/">home</a>
|
||
<a class="parent" href="../index.html">transformers</a>
|
||
<a class="parent" href="index.html">feedback</a>
|
||
</p>
|
||
<p>
|
||
|
||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/feedback/__init__.py">
|
||
<img alt="Github"
|
||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||
style="max-width:100%;"/></a>
|
||
<a href="https://twitter.com/labmlai"
|
||
rel="nofollow">
|
||
<img alt="Twitter"
|
||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||
style="max-width:100%;"/></a>
|
||
</p>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-0'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-0'>#</a>
|
||
</div>
|
||
<h1>Feedback Transformer</h1>
|
||
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the paper <a href="https://papers.labml.ai/paper/2002.09402">Accessing Higher-level Representations in Sequential Transformers with Feedback Memory</a>.</p>
|
||
<p>Normal transformers process tokens in parallel. Each transformer layer pays attention to the outputs of the previous layer. The Feedback Transformer pays attention to the outputs of all layers in previous steps. This adds recurrence, so we need to process the sequence token-by-token. This slows down the training significantly (about 5X - 10X depending on the sequence length). However, when predicting, the Feedback Transformer is faster because you can predict the next token if you cache the memory vectors.</p>
|
||
<p>In order to speed up the training, the paper discusses starting with a short sequence length and gradually increasing it. They also discuss using a pretrained parallel transformer as the starting point.</p>
|
||
<p>The original feedback transformer doesn't keep the outputs of all layers. Instead, it keeps a weighted sum of the outputs of all layers. This reduces the memory used for caching during prediction. The first half of this file implements this.</p>
|
||
<p>The updated feedback transformer shares weights <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.132216em;vertical-align:-0.2831079999999999em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.849108em;"><span style="top:-2.4168920000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.01968em;">l</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2831079999999999em;"><span></span></span></span></span></span></span></span></span></span> and <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.096108em;vertical-align:-0.247em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.849108em;"><span style="top:-2.4530000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">v</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" 
style="margin-right:0.01968em;">l</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span></span></span></span> used to calculate keys and values among the layers. We then calculate the keys and values for each step only once and keep them cached. The <a href="#shared_kv">second half</a> of this file implements this. We implemented a custom PyTorch function to improve performance.</p>
|
||
<p>Here's <a href="experiment.html">the training code</a> and a notebook for training a feedback transformer on the Tiny Shakespeare dataset.</p>
|
||
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/feedback/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://app.labml.ai/run/d8eb9416530a11eb8fb50242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a></p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">43</span><span></span><span class="kn">import</span> <span class="nn">math</span>
|
||
<span class="lineno">44</span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span>
|
||
<span class="lineno">45</span>
|
||
<span class="lineno">46</span><span class="kn">import</span> <span class="nn">torch</span>
|
||
<span class="lineno">47</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
|
||
<span class="lineno">48</span>
|
||
<span class="lineno">49</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
|
||
<span class="lineno">50</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.feed_forward</span> <span class="kn">import</span> <span class="n">FeedForward</span>
|
||
<span class="lineno">51</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.mha</span> <span class="kn">import</span> <span class="n">PrepareForMultiHeadAttention</span>
|
||
<span class="lineno">52</span><span class="kn">from</span> <span class="nn">labml_nn.utils</span> <span class="kn">import</span> <span class="n">clone_module_list</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-1'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-1'>#</a>
|
||
</div>
|
||
<h2>Feedback Attention</h2>
|
||
<p>This module computes recurrent attention, similar to the attention from the original transformer paper.</p>
|
||
<p><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop"><span class="mord mathnormal">A</span><span class="mord mathnormal">tt</span><span class="mord mathnormal">e</span><span class="mord mathnormal">n</span><span class="mord mathnormal">t</span><span class="mord mathnormal">i</span><span class="mord mathnormal">o</span><span class="mord mathnormal">n</span></span><span class="mopen">(</span><span class="mord mathnormal">Q</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" style="margin-right:0.07153em;">K</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" style="margin-right:0.22222em;">V</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:3.0000299999999998em;vertical-align:-1.25003em;"></span><span class="mord"><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6944399999999998em;"><span style="top:-2.20556em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">se</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop"><span class="mop"><span class="mord mathnormal">so</span><span class="mord mathnormal" style="margin-right:0.10764em;">f</span><span class="mord mathnormal">t</span><span class="mord 
mathnormal">ma</span><span class="mord mathnormal">x</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.030548em;"><span></span></span></span></span></span></span><span class="mord"><span class="delimsizing size4">(</span></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.5261079999999998em;"><span style="top:-2.25278em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.85722em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em;"><span class="mord"><span class="mord mathnormal">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.81722em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em;"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
|
||
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
|
||
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
|
||
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
|
||
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
|
||
c69,-144,104.5,-217.7,106.5,-221
|
||
l0 -0
|
||
c5.3,-9.3,12,-14,20,-14
|
||
H400000v40H845.2724
|
||
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
|
||
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
|
||
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.18278000000000005em;"><span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal">Q</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.849108em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">⊤</span></span></span></span></span></span></span></span><span class="mord mathnormal" style="margin-right:0.07153em;">K</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.93em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord"><span class="delimsizing size4">)</span></span><span class="mord mathnormal" style="margin-right:0.22222em;">V</span></span></span></span></span></p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">55</span><span class="k">class</span> <span class="nc">FeedbackAttention</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-2'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-2'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">heads</span></code> is the number of attention heads </li>
|
||
<li><code class="highlight"><span></span><span class="n">d_model</span></code>
|
||
is the number of features in the transformer </li>
|
||
<li><code class="highlight"><span></span><span class="n">dropout_prob</span></code>
|
||
is the attention dropout probability </li>
|
||
<li><code class="highlight"><span></span><span class="n">is_kv_precomputed</span></code>
|
||
is whether key, value tensors are already calculated</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">66</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
||
<span class="lineno">67</span> <span class="n">is_kv_precomputed</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-3'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-3'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">75</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-4'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-4'>#</a>
|
||
</div>
|
||
<p>Number of features per head </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">78</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span> <span class="o">=</span> <span class="n">d_model</span> <span class="o">//</span> <span class="n">heads</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-5'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-5'>#</a>
|
||
</div>
|
||
<p> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">80</span> <span class="bp">self</span><span class="o">.</span><span class="n">heads</span> <span class="o">=</span> <span class="n">heads</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-6'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-6'>#</a>
|
||
</div>
|
||
<p>These transform the <code class="highlight"><span></span><span class="n">query</span></code>
|
||
for multi-headed attention. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">83</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span> <span class="o">=</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-7'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-7'>#</a>
|
||
</div>
|
||
<p>These transform the <code class="highlight"><span></span><span class="n">key</span></code>
|
||
and <code class="highlight"><span></span><span class="n">value</span></code>
|
||
for multi-headed attention. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">85</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">is_kv_precomputed</span><span class="p">:</span>
|
||
<span class="lineno">86</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span> <span class="o">=</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||
<span class="lineno">87</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-8'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-8'>#</a>
|
||
</div>
|
||
<p>Keys and values are already calculated </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">89</span> <span class="k">else</span><span class="p">:</span>
|
||
<span class="lineno">90</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="lineno">91</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-9'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-9'>#</a>
|
||
</div>
|
||
<p>Output layer </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">94</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-10'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-10'>#</a>
|
||
</div>
|
||
<p>Dropout </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">96</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-11'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-11'>#</a>
|
||
</div>
|
||
<p>Scaling factor before the softmax </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">98</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-12'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-12'>#</a>
|
||
</div>
|
||
<p>Softmax for attention along the time dimension of <code class="highlight"><span></span><span class="n">key</span></code>
|
||
</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">101</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-13'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-13'>#</a>
|
||
</div>
|
||
<p>Number of relative positions </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">104</span> <span class="bp">self</span><span class="o">.</span><span class="n">P</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">**</span> <span class="mi">12</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-14'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-14'>#</a>
|
||
</div>
|
||
<p>Relative positional embeddings for key relative to the query. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">107</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_pos_embeddings</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">P</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-15'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-15'>#</a>
|
||
</div>
|
||
<p>Relative positional embedding bias for key relative to the query. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">109</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_pos_bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">P</span><span class="p">,</span> <span class="n">heads</span><span class="p">)),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-16'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-16'>#</a>
|
||
</div>
|
||
<p>Positional embeddings for the query are independent of the position of the query </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">111</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_pos_bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_k</span><span class="p">)),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-17'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-17'>#</a>
|
||
</div>
|
||
<p>We store attentions so that they can be used for logging, or other computations if needed </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">114</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-18'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-18'>#</a>
|
||
</div>
|
||
<h3>Get attention scores</h3>
|
||
<p>We use relative positional encodings for attention, similar to <a href="../relative_mha.html">relative multi-head attention from the Transformer-XL paper</a>.</p>
|
||
<p>Attention from current step's query to key in step <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.85396em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqr" style=""><span class="mord mathnormal" style="margin-right:0.05724em">j</span></span></span></span></span> (relative to current step) is,</p>
|
||
<span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:7.281318em;vertical-align:-3.3906589999999994em;"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:3.8906590000000003em;"><span style="top:-6.1218900000000005em;"><span class="pstrut" style="height:3.130339em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span><span style="top:-4.562782em;"><span class="pstrut" style="height:3.130339em;"></span><span class="mord"></span></span><span style="top:-2.9805660000000005em;"><span class="pstrut" style="height:3.130339em;"></span><span class="mord"></span></span><span style="top:-1.167119000000001em;"><span class="pstrut" style="height:3.130339em;"></span><span class="mord"></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:3.3906589999999994em;"><span></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:3.8906590000000003em;"><span style="top:-6.1218900000000005em;"><span class="pstrut" style="height:3.130339em;"></span><span class="mord"><span 
class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord"><span class="mord mathnormal">Q</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8991079999999999em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">⊤</span></span></span></span></span></span></span></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07153em;">K</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.07153em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span><span style="top:-4.562782em;"><span class="pstrut" style="height:3.130339em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord mathnormal" style="margin-right:0.01968em;">l</span><span class="mord mathnormal">i</span><span class="mord"><span class="mord mathnormal">n</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span 
class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07847em;">X</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.714392em;"><span style="top:-3.1130000000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8991079999999999em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span 
class="mord mtight">⊤</span></span></span></span></span></span></span></span><span class="mord mathnormal" style="margin-right:0.01968em;">l</span><span class="mord mathnormal">i</span><span class="mord"><span class="mord mathnormal">n</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07847em;">X</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.899108em;"><span style="top:-2.4530000000000003em;margin-left:-0.07847em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span><span style="top:-3.1130000000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.383108em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mord mathnormal" 
style="margin-right:0.13889em;">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span><span style="top:-2.9805660000000005em;"><span class="pstrut" style="height:3.130339em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mopen">(</span><span class="mord mathnormal">Q</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord coloredeq eqq" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">U</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.891331em;"><span style="top:-3.1130000000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">Q</span></span></span></span></span></span></span></span></span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8991079999999999em;"><span style="top:-3.113em;margin-right:0.05em;"><span 
class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">⊤</span></span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07153em;">K</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.07153em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord coloredeq eql" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">U</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.891331em;"><span style="top:-2.4530000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span><span style="top:-3.1130000000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span 
class="vlist" style="height:0.383108em;"><span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span><span style="top:-1.167119000000001em;"><span class="pstrut" style="height:3.130339em;"></span><span class="mord"><span class="mord"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mord"><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8991079999999999em;"><span style="top:-2.069561em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight" style="color:lightgreen"><span class="mord mathnormal mtight" style="">A</span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop"><span class="mord"><span class="mord mathnormal">Q</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8991079999999999em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">⊤</span></span></span></span></span></span></span></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07153em;">K</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.07153em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" 
style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.0304389999999999em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.899108em;"><span style="top:-1.972561em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight" style="color:lightgreen"><span class="mord mathnormal mtight" style="margin-right:0.05017em">B</span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop"><span class="mord"><span class="mord mathnormal">Q</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8991079999999999em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">⊤</span></span></span></span></span></span></span></span><span class="mord coloredeq eql" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">U</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.891331em;"><span style="top:-2.4530000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span><span 
style="top:-3.1130000000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.383108em;"><span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.127439em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.1303389999999998em;"><span style="top:-2.1999em;margin-left:0em;"><span class="pstrut" style="height:3.130339em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight" style="color:lightgreen"><span class="mord mathnormal mtight" style="margin-right:0.07153em">C</span></span></span></span><span style="top:-3.130339em;"><span class="pstrut" style="height:3.130339em;"></span><span><span class="mop"><span class="mord"><span class="mord"><span class="mord coloredeq eqq" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">U</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.891331em;"><span style="top:-3.1130000000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">Q</span></span></span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" 
style="height:1.130339em;"><span style="top:-3.344231em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">⊤</span></span></span></span></span></span></span></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07153em;">K</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.07153em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.030439em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.130339em;"><span style="top:-2.1029em;margin-left:0em;"><span class="pstrut" style="height:3.130339em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight" style="color:lightgreen"><span class="mord mathnormal mtight" style="margin-right:0.02778em">D</span></span></span></span></span><span style="top:-3.130339em;"><span class="pstrut" style="height:3.130339em;"></span><span><span class="mop"><span class="mord"><span class="mord"><span class="mord coloredeq eqq" style=""><span class="mord" 
style=""><span class="mord mathnormal" style="margin-right:0.10903em">U</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.891331em;"><span style="top:-3.1130000000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">Q</span></span></span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:1.130339em;"><span style="top:-3.344231em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">⊤</span></span></span></span></span></span></span></span><span class="mord coloredeq eql" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">U</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.891331em;"><span style="top:-2.4530000000000003em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span><span style="top:-3.1130000000000004em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.383108em;"><span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" 
style="height:1.127439em;"><span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:3.3906589999999994em;"><span></span></span></span></span></span></span></span></span></span></span></span><p>where <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord mathnormal">Q</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07153em;">K</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.07153em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span> are linear transformations of original embeddings <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.2438799999999999em;vertical-align:-0.394772em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07847em;">X</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.664392em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight"
style="margin-right:0.03588em;">q</span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07847em;">X</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.849108em;"><span style="top:-2.441336em;margin-left:-0.07847em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.394772em;"><span></span></span></span></span></span></span></span></span></span> and <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.236103em;vertical-align:-0.394772em;"></span><span class="mord coloredeq eqq" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">U</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">Q</span></span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eql" style=""><span class="mord" 
style=""><span class="mord mathnormal" style="margin-right:0.10903em">U</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-2.441336em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.394772em;"><span></span></span></span></span></span></span></span></span></span></span> are linear transformations of positional encodings <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.13889em;">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" 
style="margin-right:0.13889em;">P</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.13889em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span>.</p>
<p>We replace the term <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord" style="color:lightgreen"><span class="mord mathnormal" style="margin-right:0.02778em">D</span></span></span></span></span> with <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqp" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="k">def</span> <span class="nf">get_scores</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.236103em;vertical-align:-0.394772em;"></span><span class="mord coloredeq eql" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">U</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-2.441336em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.394772em;"><span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">144</span> <span class="n">key_pos_emb</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_pos_embeddings</span><span class="p">[</span><span class="o">-</span><span class="n">key</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]:]</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8413309999999999em;vertical-align:0em;"></span><span class="mord coloredeq eqq" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">U</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">Q</span></span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">146</span> <span class="n">query_pos_bias</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_pos_bias</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqp" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">148</span> <span class="n">key_pos_bias</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_pos_bias</span><span class="p">[</span><span class="o">-</span><span class="n">key</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]:]</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.879547em;vertical-align:-1.0304389999999999em;"></span><span class="mord"><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8491080000000001em;"><span style="top:-2.069561em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight" style="color:lightgreen"><span class="mord mathnormal mtight" style="">A</span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop"><span class="mord"><span class="mord mathnormal">Q</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.849108em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">⊤</span></span></span></span></span></span></span></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07153em;">K</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.07153em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.0304389999999999em;"><span></span></span></span></span></span></span><span 
class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:2.110778em;vertical-align:-1.0304389999999999em;"></span><span class="mord"><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.080339em;"><span style="top:-2.1499em;margin-left:0em;"><span class="pstrut" style="height:3.080339em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight" style="color:lightgreen"><span class="mord mathnormal mtight" style="margin-right:0.07153em">C</span></span></span></span><span style="top:-3.080339em;"><span class="pstrut" style="height:3.080339em;"></span><span><span class="mop"><span class="mord"><span class="mord"><span class="mord coloredeq eqq" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">U</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="">Q</span></span></span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:1.080339em;"><span style="top:-3.294231em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">⊤</span></span></span></span></span></span></span></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07153em;">K</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span 
style="top:-2.5500000000000003em;margin-left:-0.07153em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.0304389999999999em;"><span></span></span></span></span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">151</span> <span class="n">ac</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bhd,jbhd->jbh'</span><span class="p">,</span> <span class="n">query</span> <span class="o">+</span> <span class="n">query_pos_bias</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-23'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-23'>#</a>
|
||
</div>
|
||
<p><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.988211em;vertical-align:-1.1391029999999998em;"></span><span class="mord"><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8491080000000002em;"><span style="top:-1.9608970000000003em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight" style="color:lightgreen"><span class="mord mathnormal mtight" style="margin-right:0.05017em">B</span></span></span></span><span style="top:-3.0000000000000004em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop"><span class="mord"><span class="mord mathnormal">Q</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.849108em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">⊤</span></span></span></span></span></span></span></span><span class="mord coloredeq eql" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10903em">U</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-2.441336em;margin-left:-0.10903em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span></span></span></span><span 
class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.394772em;"><span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.1391029999999998em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.713769em;vertical-align:-1.0304389999999999em;"></span><span class="mord"><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6833300000000001em;"><span style="top:-2.069561em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight" style="color:lightgreen"><span class="mord mathnormal mtight" style="margin-right:0.02778em">D</span></span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop"><span class="mord coloredeq eqp" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.05764em">S</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" 
style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.0304389999999999em;"><span></span></span></span></span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">153</span> <span class="n">bd</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'bhd,jhd->jbh'</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">key_pos_emb</span><span class="p">)</span> <span class="o">+</span> <span class="n">key_pos_bias</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:]</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-24'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-24'>#</a>
|
||
</div>
|
||
<p><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.969438em;vertical-align:-0.286108em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqr" style=""><span class="mord mathnormal mtight" style="margin-right:0.05724em">j</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">156</span> <span class="k">return</span> <span class="n">ac</span> <span class="o">+</span> <span class="n">bd</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-25'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-25'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">query</span></code>
|
||
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
|
||
</li>
|
||
<li><code class="highlight"><span></span><span class="n">key</span></code>
|
||
and <code class="highlight"><span></span><span class="n">value</span></code>
|
||
 have shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
|
||
</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">158</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
||
<span class="lineno">159</span> <span class="n">query</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="lineno">160</span> <span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="lineno">161</span> <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-26'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-26'>#</a>
|
||
</div>
|
||
<p>Prepare <code class="highlight"><span></span><span class="n">query</span></code>
|
||
, <code class="highlight"><span></span><span class="n">key</span></code>
|
||
and <code class="highlight"><span></span><span class="n">value</span></code>
|
||
 for attention computation. <code class="highlight"><span></span><span class="n">key</span></code>
|
||
and <code class="highlight"><span></span><span class="n">value</span></code>
|
||
will then have shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">]</span></code>
|
||
and <code class="highlight"><span></span><span class="n">query</span></code>
|
||
will have shape <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">]</span></code>
|
||
</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">170</span> <span class="n">query</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">query</span><span class="p">)</span>
|
||
<span class="lineno">171</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span><span class="p">:</span>
|
||
<span class="lineno">172</span> <span class="n">key</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
|
||
<span class="lineno">173</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span><span class="p">:</span>
|
||
<span class="lineno">174</span> <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span><span class="p">(</span><span class="n">value</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-27'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-27'>#</a>
|
||
</div>
|
||
<p>Compute attention scores. Results in a tensor of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">heads</span><span class="p">]</span></code>
|
||
</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">178</span> <span class="n">scores</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_scores</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-28'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-28'>#</a>
|
||
</div>
|
||
<p>Scale scores <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.383108em;vertical-align:-0.538em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.845108em;"><span style="top:-2.5864385em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord sqrt mtight"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8622307142857143em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mtight" style="padding-left:0.833em;"><span class="mord mtight"><span class="mord mathnormal mtight">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03148em;">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.8222307142857144em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail mtight" style="min-width:0.853em;height:1.08em;"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
|
||
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
|
||
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
|
||
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
|
||
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
|
||
c69,-144,104.5,-217.7,106.5,-221
|
||
l0 -0
|
||
c5.3,-9.3,12,-14,20,-14
|
||
H400000v40H845.2724
|
||
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
|
||
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
|
||
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.17776928571428574em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.538em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">181</span> <span class="n">scores</span> <span class="o">*=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-29'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-29'>#</a>
|
||
</div>
|
||
<p>Softmax </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">184</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">scores</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-30'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-30'>#</a>
|
||
</div>
|
||
<p>Apply dropout </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">187</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-31'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-31'>#</a>
|
||
</div>
|
||
<p>Multiply by the values </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">190</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">"jbh,jbhd->bhd"</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-32'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-32'>#</a>
|
||
</div>
|
||
<p>Concatenate multiple heads </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">193</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-33'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-33'>#</a>
|
||
</div>
|
||
<p>Output layer </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">196</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-34'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-34'>#</a>
|
||
</div>
|
||
<h2>Feedback Transformer Layer</h2>
|
||
<p>This implements a single transformer layer in the feedback transformer.</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">199</span><span class="k">class</span> <span class="nc">FeedbackTransformerLayer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-35'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-35'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
|
||
is the number of features in the transformer </li>
|
||
<li><code class="highlight"><span></span><span class="n">attn</span></code>
|
||
is the feedback attention module </li>
|
||
<li><code class="highlight"><span></span><span class="n">feed_forward</span></code>
|
||
is the position-wise feed forward layer </li>
|
||
<li><code class="highlight"><span></span><span class="n">dropout_prob</span></code>
|
||
is the dropout probability for dropout layers after attention and feed-forward</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">206</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
||
<span class="lineno">207</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="lineno">208</span> <span class="n">attn</span><span class="p">:</span> <span class="n">FeedbackAttention</span><span class="p">,</span>
|
||
<span class="lineno">209</span> <span class="n">feed_forward</span><span class="p">:</span> <span class="n">FeedForward</span><span class="p">,</span>
|
||
<span class="lineno">210</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-36'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-36'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">217</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-37'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-37'>#</a>
|
||
</div>
|
||
<p>Transformer size <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.84444em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">m</span><span class="mord mathnormal mtight">o</span><span class="mord mathnormal mtight">d</span><span class="mord mathnormal mtight">e</span><span class="mord mathnormal mtight" style="margin-right:0.01968em;">l</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">219</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">d_model</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-38'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-38'>#</a>
|
||
</div>
|
||
<p> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">221</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span>
|
||
<span class="lineno">222</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span>
|
||
<span class="lineno">223</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-39'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-39'>#</a>
|
||
</div>
|
||
<p>Normalization layers </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">226</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
|
||
<span class="lineno">227</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-40'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-40'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">229</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
||
<span class="lineno">230</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="lineno">231</span> <span class="n">key</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
|
||
<span class="lineno">232</span> <span class="n">value</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-41'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-41'>#</a>
|
||
</div>
|
||
<p>If there is memory </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">234</span> <span class="k">if</span> <span class="n">key</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-42'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-42'>#</a>
|
||
</div>
|
||
<p>Normalize the vectors before doing self attention </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">236</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-43'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-43'>#</a>
|
||
</div>
|
||
        <p>Run through attention: the query comes from the current step, while the keys and values come from the memory </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">238</span> <span class="n">self_attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-44'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-44'>#</a>
|
||
</div>
|
||
<p>Add the self attention results </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">240</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">self_attn</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-45'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-45'>#</a>
|
||
</div>
|
||
<p>Normalize for feed-forward </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">243</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-46'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-46'>#</a>
|
||
</div>
|
||
<p>Pass through the feed-forward network </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">245</span> <span class="n">ff</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-47'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-47'>#</a>
|
||
</div>
|
||
<p>Add the feed-forward results back </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">247</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">ff</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-48'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-48'>#</a>
|
||
</div>
|
||
<p> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">250</span> <span class="k">return</span> <span class="n">x</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-49'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-49'>#</a>
|
||
</div>
|
||
<h2>Feedback Transformer Module</h2>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">253</span><span class="k">class</span> <span class="nc">FeedbackTransformer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-50'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-50'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">layer</span></code>
|
||
is the feedback transformer layer, which we clone for each layer </li>
|
||
<li><code class="highlight"><span></span><span class="n">n_layers</span></code>
|
||
is the number of layers in the transformer</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">258</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">FeedbackTransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-51'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-51'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">264</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-52'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-52'>#</a>
|
||
</div>
|
||
<p>Make copies of the transformer layer </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">266</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-53'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-53'>#</a>
|
||
</div>
|
||
<p>Final normalization layer </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">268</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-54'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-54'>#</a>
|
||
</div>
|
||
<p>Memory vectors are computed as a weighted sum of the representations of each layer. This parameter holds the weights for that sum: one for the input and one for each layer output. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">271</span> <span class="bp">self</span><span class="o">.</span><span class="n">weights</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">n_layers</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-55'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-55'>#</a>
|
||
</div>
|
||
<p>Softmax for weights before taking the weighted sum </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">273</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-56'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-56'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">x_seq</span></code>
|
||
is the input with shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
|
||
</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">275</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x_seq</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-57'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-57'>#</a>
|
||
</div>
|
||
<p>Split the input into a list along the sequence axis </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">281</span> <span class="n">x_seq</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unbind</span><span class="p">(</span><span class="n">x_seq</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-58'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-58'>#</a>
|
||
</div>
|
||
<p>List to store the outputs </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">283</span> <span class="n">res</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-59'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-59'>#</a>
|
||
</div>
|
||
<p>List to store the memory vectors </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">285</span> <span class="n">mem</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-60'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-60'>#</a>
|
||
</div>
|
||
<p>For each input step </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">287</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">x_seq</span><span class="p">:</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-61'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-61'>#</a>
|
||
</div>
|
||
<p>List to store layer outputs </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">289</span> <span class="n">layer_outputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="p">]</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-62'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-62'>#</a>
|
||
</div>
|
||
<p>If there are memory vectors, stack them into a tensor </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">292</span> <span class="n">mem_tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">mem</span><span class="p">)</span> <span class="k">if</span> <span class="n">mem</span> <span class="k">else</span> <span class="kc">None</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-63'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-63'>#</a>
|
||
</div>
|
||
<p>Run through each layer </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">295</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-64'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-64'>#</a>
|
||
</div>
|
||
<p>Get layer output </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">297</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">mem_tensor</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">mem_tensor</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-65'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-65'>#</a>
|
||
</div>
|
||
<p>Append them to the list of layer outputs </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">299</span> <span class="n">layer_outputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-66'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-66'>#</a>
|
||
</div>
|
||
<p>Stack the layer outputs into a tensor </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">302</span> <span class="n">layer_outputs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">layer_outputs</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-67'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-67'>#</a>
|
||
</div>
|
||
<p>Calculate the memory vector as a weighted sum of layer outputs </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">304</span> <span class="n">mem</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'lbd,l->bd'</span><span class="p">,</span> <span class="n">layer_outputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weights</span><span class="p">)))</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-68'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-68'>#</a>
|
||
</div>
|
||
<p>Append the output to results </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">306</span> <span class="n">res</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-69'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-69'>#</a>
|
||
</div>
|
||
<p>Stack the output tensors </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">309</span> <span class="n">res</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">res</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-70'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-70'>#</a>
|
||
</div>
|
||
<p>Normalize the output </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">311</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">res</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-71'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-71'>#</a>
|
||
</div>
|
||
<p><a id="shared_kv"></a></p>
|
||
<h1>Shared keys and values among layers</h1>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-72'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-72'>#</a>
|
||
</div>
|
||
<h3>Stack Function implementation</h3>
|
||
<p>We implement a custom function instead of appending to a Python list and then doing <code class="highlight"><span></span><span class="n">torch</span><span class="o">.</span><span class="n">stack</span></code>
|
||
. This greatly improves the performance over calling <code class="highlight"><span></span><span class="n">torch</span><span class="o">.</span><span class="n">stack</span></code>
|
||
at each step along the sequence. Every time <code class="highlight"><span></span><span class="n">torch</span><span class="o">.</span><span class="n">stack</span></code>
|
||
is called, it creates a new tensor, while this method and the accompanying class <code class="highlight"><span></span><span class="n">Stack</span></code>
|
||
share memory for each step.</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">318</span><span class="k">class</span> <span class="nc">StackFunction</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-73'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-73'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">ctx</span></code>
|
||
is the context of the function (which lets us cache stuff) </li>
|
||
<li><code class="highlight"><span></span><span class="n">memory</span></code>
|
||
is the shared memory tensor where we stack and store the values of each step (keys and values) </li>
|
||
<li><code class="highlight"><span></span><span class="n">memory_grad</span></code>
|
||
is the shared memory tensor to store and accumulate gradients of each step </li>
|
||
<li><code class="highlight"><span></span><span class="n">last</span></code>
|
||
is the last value stacked </li>
|
||
<li><code class="highlight"><span></span><span class="n">n</span></code>
|
||
is the number of steps (i.e. size of the stack)</li></ul>
|
||
<p>This returns the stacked tensor for steps up to <code class="highlight"><span></span><span class="n">n</span></code>
|
||
.</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">330</span> <span class="nd">@staticmethod</span>
|
||
<span class="lineno">331</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">memory</span><span class="p">,</span> <span class="n">memory_grad</span><span class="p">,</span> <span class="n">last</span><span class="p">,</span> <span class="n">n</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-74'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-74'>#</a>
|
||
</div>
|
||
<p>Cache accumulated gradients </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">343</span> <span class="n">ctx</span><span class="o">.</span><span class="n">_mem_grad</span> <span class="o">=</span> <span class="n">memory_grad</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-75'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-75'>#</a>
|
||
</div>
|
||
<p>Cache the size of the stack </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">345</span> <span class="n">ctx</span><span class="o">.</span><span class="n">_n</span> <span class="o">=</span> <span class="n">n</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-76'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-76'>#</a>
|
||
</div>
|
||
<p>Return the stack </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">347</span> <span class="k">return</span> <span class="n">memory</span><span class="p">[:</span><span class="n">n</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-77'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-77'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">grad_output</span></code>
|
||
is the gradient with respect to the output of the <code class="highlight"><span></span><span class="n">forward</span></code>
|
||
function</li></ul>
|
||
<p>This accumulates the gradients in the shared memory tensor and returns the gradients with respect to the <code class="highlight"><span></span><span class="n">last</span></code>
|
||
result in the stack.</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">349</span> <span class="nd">@staticmethod</span>
|
||
<span class="lineno">350</span> <span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">grad_output</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-78'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-78'>#</a>
|
||
</div>
|
||
<p>Get the current size of the stack </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">358</span> <span class="n">n</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">_n</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-79'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-79'>#</a>
|
||
</div>
|
||
<p>Get the accumulated gradients </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">360</span> <span class="n">memory_grad</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">_mem_grad</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-80'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-80'>#</a>
|
||
</div>
|
||
<p>Add the gradients </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">362</span> <span class="n">memory_grad</span><span class="p">[:</span><span class="n">n</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">+=</span> <span class="n">grad_output</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-81'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-81'>#</a>
|
||
</div>
|
||
<p>Return the gradients w.r.t. the last value in the stack </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">364</span> <span class="k">return</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="n">memory_grad</span><span class="p">[</span><span class="n">n</span><span class="p">],</span> <span class="kc">None</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-82'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-82'>#</a>
|
||
</div>
|
||
<h3>Stack Module</h3>
|
||
<p>This uses the stack function defined above, and does the necessary initializations.</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">367</span><span class="k">class</span> <span class="nc">Stack</span><span class="p">:</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-83'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-83'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">max_len</span></code>
|
||
is the maximum size of the stack</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">374</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-84'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-84'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">378</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_len</span> <span class="o">=</span> <span class="n">max_len</span>
|
||
<span class="lineno">379</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="lineno">380</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory_grad</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="lineno">381</span> <span class="bp">self</span><span class="o">.</span><span class="n">last</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="lineno">382</span> <span class="bp">self</span><span class="o">.</span><span class="n">n</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span>
|
||
<span class="lineno">383</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_get_n</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-85'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-85'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">n</span></code>
|
||
is the size of the stack </li>
|
||
<li><code class="highlight"><span></span><span class="n">value</span></code>
|
||
is the tensor that needs to be added to the stack</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">385</span> <span class="k">def</span> <span class="nf">append</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">n</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-86'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-86'>#</a>
|
||
</div>
|
||
<p>You need to get (use) the stack after adding a value. Otherwise this implementation fails </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">393</span> <span class="k">assert</span> <span class="n">n</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_get_n</span> <span class="o">==</span> <span class="n">n</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">n</span><span class="si">}</span><span class="s2">, </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">last_get_n</span><span class="si">}</span><span class="s2">"</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-87'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-87'>#</a>
|
||
</div>
|
||
<p>Do this without gradients </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">396</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-88'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-88'>#</a>
|
||
</div>
|
||
<p>Initialize the shared memory tensor to keep the stack </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">398</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o">!=</span> <span class="n">value</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-89'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-89'>#</a>
|
||
</div>
|
||
<p>This should only happen when the stack is empty </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">400</span> <span class="k">assert</span> <span class="n">n</span> <span class="o">==</span> <span class="mi">0</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-90'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-90'>#</a>
|
||
</div>
|
||
<p>Create a tensor for the stack </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">402</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory</span> <span class="o">=</span> <span class="n">value</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_len</span><span class="p">,</span> <span class="o">*</span><span class="n">value</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-91'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-91'>#</a>
|
||
</div>
|
||
<p>Create a tensor to accumulate the gradients </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">404</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory_grad</span> <span class="o">=</span> <span class="n">value</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-92'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-92'>#</a>
|
||
</div>
|
||
<p>The memory is already initialized but we are resetting the stack.</p>
|
||
<p>This could have been another function like <code class="highlight"><span></span><span class="n">reset</span></code>
|
||
, but we found this easier to use. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">409</span> <span class="k">elif</span> <span class="n">n</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-93'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-93'>#</a>
|
||
</div>
|
||
<p>Reset accumulated gradients </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">411</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory_grad</span><span class="o">.</span><span class="n">fill_</span><span class="p">(</span><span class="mf">0.</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-94'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-94'>#</a>
|
||
</div>
|
||
<p>Set the value in the correct position of the stack </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">414</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-95'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-95'>#</a>
|
||
</div>
|
||
<p>Keep track of the stack (for debugging) </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">416</span> <span class="bp">self</span><span class="o">.</span><span class="n">n</span> <span class="o">=</span> <span class="n">n</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-96'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-96'>#</a>
|
||
</div>
|
||
<p>Keep track of the last value added to the stack. We need this to be passed on to <code class="highlight"><span></span><span class="n">StackFunction</span></code>
|
||
in order to get the gradients propagated backwards. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">421</span> <span class="bp">self</span><span class="o">.</span><span class="n">last</span> <span class="o">=</span> <span class="n">value</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-97'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-97'>#</a>
|
||
</div>
|
||
<p> Returns the stack</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">423</span> <span class="k">def</span> <span class="nf">get</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-98'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-98'>#</a>
|
||
</div>
|
||
<p>Keep track of the size of the stack when it was used. This is used for a sanity check in <code class="highlight"><span></span><span class="n">append</span></code>
|
||
. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">430</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_get_n</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">n</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-99'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-99'>#</a>
|
||
</div>
|
||
<p>Take it all through <code class="highlight"><span></span><span class="n">StackFunction</span></code>
|
||
so that <code class="highlight"><span></span><span class="n">StackFunction</span><span class="o">.</span><span class="n">backward</span></code>
|
||
is called by PyTorch during backpropagation. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">433</span> <span class="k">return</span> <span class="n">StackFunction</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory_grad</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-100'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-100'>#</a>
|
||
</div>
|
||
<p> To release memory</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">435</span> <span class="k">def</span> <span class="nf">free</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-101'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-101'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">440</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="lineno">441</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory_grad</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="lineno">442</span> <span class="bp">self</span><span class="o">.</span><span class="n">last</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-102'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-102'>#</a>
|
||
</div>
|
||
<h2>Updated Feedback Transformer Module</h2>
|
||
<p>This is the updated feedback transformer module that caches the keys and values.</p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">445</span><span class="k">class</span> <span class="nc">FeedbackTransformerKV</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-103'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-103'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">layer</span></code>
|
||
is the feedback transformer layer, which we clone for each layer </li>
|
||
<li><code class="highlight"><span></span><span class="n">n_layers</span></code>
|
||
is the number of layers in the transformer </li>
|
||
<li><code class="highlight"><span></span><span class="n">d_model</span></code>
|
||
is the number of features in the transformer </li>
|
||
<li><code class="highlight"><span></span><span class="n">heads</span></code> is the number of attention heads</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">452</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">FeedbackTransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-104'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-104'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">460</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-105'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-105'>#</a>
|
||
</div>
|
||
<p>Make copies of the transformer layer </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">462</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-106'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-106'>#</a>
|
||
</div>
|
||
<p>Final normalization layer </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">464</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-107'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-107'>#</a>
|
||
</div>
|
||
<p>Memory vectors are computed as a weighted sum of representations of each layer. This is the weights parameter for that. </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">467</span> <span class="bp">self</span><span class="o">.</span><span class="n">weights</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">n_layers</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-108'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-108'>#</a>
|
||
</div>
|
||
<p>Softmax for weights before taking the weighted sum </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">469</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-109'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-109'>#</a>
|
||
</div>
|
||
<p>Number of features in a head </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">472</span> <span class="n">d_k</span> <span class="o">=</span> <span class="n">d_model</span> <span class="o">//</span> <span class="n">heads</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-110'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-110'>#</a>
|
||
</div>
|
||
<p>Module to transform embeddings (memory) to get keys </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">474</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span> <span class="o">=</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-111'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-111'>#</a>
|
||
</div>
|
||
<p>Module to transform embeddings (memory) to get values </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">476</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-112'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-112'>#</a>
|
||
</div>
|
||
<p>Memory for stacked keys </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">479</span> <span class="bp">self</span><span class="o">.</span><span class="n">mem_key</span> <span class="o">=</span> <span class="n">Stack</span><span class="p">(</span><span class="mi">512</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-113'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-113'>#</a>
|
||
</div>
|
||
<p>Memory for stacked values </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">481</span> <span class="bp">self</span><span class="o">.</span><span class="n">mem_value</span> <span class="o">=</span> <span class="n">Stack</span><span class="p">(</span><span class="mi">512</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-114'>
|
||
<div class='docs doc-strings'>
|
||
<div class='section-link'>
|
||
<a href='#section-114'>#</a>
|
||
</div>
|
||
<ul><li><code class="highlight"><span></span><span class="n">x_seq</span></code>
|
||
is the input with shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
|
||
</li></ul>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">483</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x_seq</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-115'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-115'>#</a>
|
||
</div>
|
||
<p>Split the input to a list along the sequence axis </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">489</span> <span class="n">x_seq</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unbind</span><span class="p">(</span><span class="n">x_seq</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-116'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-116'>#</a>
|
||
</div>
|
||
<p>List to store the outputs </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">491</span> <span class="n">res</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-117'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-117'>#</a>
|
||
</div>
|
||
<p>For each input step </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">493</span> <span class="k">for</span> <span class="n">step</span><span class="p">,</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">x_seq</span><span class="p">):</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-118'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-118'>#</a>
|
||
</div>
|
||
<p>List to store layer outputs </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">495</span> <span class="n">layer_outputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="p">]</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-119'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-119'>#</a>
|
||
</div>
|
||
<p>Stack of keys and values </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">498</span> <span class="n">key_tensor</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="lineno">499</span> <span class="n">value_tensor</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-120'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-120'>#</a>
|
||
</div>
|
||
<p>Get the keys and values tensors if we are beyond the initial step </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">501</span> <span class="k">if</span> <span class="n">step</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="lineno">502</span> <span class="n">key_tensor</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mem_key</span><span class="o">.</span><span class="n">get</span><span class="p">()</span>
|
||
<span class="lineno">503</span> <span class="n">value_tensor</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mem_value</span><span class="o">.</span><span class="n">get</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-121'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-121'>#</a>
|
||
</div>
|
||
<p>Run through each layer </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">506</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-122'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-122'>#</a>
|
||
</div>
|
||
<p>Get layer output </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">508</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">key_tensor</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">value_tensor</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-123'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-123'>#</a>
|
||
</div>
|
||
<p>Append them to the list of layer outputs </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">510</span> <span class="n">layer_outputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-124'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-124'>#</a>
|
||
</div>
|
||
<p>Stack the layer outputs to a tensor </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">513</span> <span class="n">layer_outputs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">layer_outputs</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-125'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-125'>#</a>
|
||
</div>
|
||
<p>Calculate the memory vector as a weighted sum of layer outputs </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">515</span> <span class="n">mem</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'lbd,l->bd'</span><span class="p">,</span> <span class="n">layer_outputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weights</span><span class="p">))</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-126'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-126'>#</a>
|
||
</div>
|
||
<p>Calculate the keys from memory and add them to the stack </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">517</span> <span class="bp">self</span><span class="o">.</span><span class="n">mem_key</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="n">mem</span><span class="p">))</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-127'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-127'>#</a>
|
||
</div>
|
||
<p>Calculate the values from memory and add them to the stack </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">519</span> <span class="bp">self</span><span class="o">.</span><span class="n">mem_value</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span><span class="p">(</span><span class="n">mem</span><span class="p">))</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-128'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-128'>#</a>
|
||
</div>
|
||
<p>Append the output to results </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">521</span> <span class="n">res</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-129'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-129'>#</a>
|
||
</div>
|
||
<p>Stack the output tensors </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">524</span> <span class="n">res</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">res</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-130'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-130'>#</a>
|
||
</div>
|
||
<p>Normalize the output </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">526</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">res</span><span class="p">)</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-131'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-131'>#</a>
|
||
</div>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
<div class="highlight"><pre><span class="lineno">528</span> <span class="k">def</span> <span class="nf">free</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="lineno">529</span> <span class="bp">self</span><span class="o">.</span><span class="n">mem_key</span><span class="o">.</span><span class="n">free</span><span class="p">()</span>
|
||
<span class="lineno">530</span> <span class="bp">self</span><span class="o">.</span><span class="n">mem_value</span><span class="o">.</span><span class="n">free</span><span class="p">()</span></pre></div>
|
||
</div>
|
||
</div>
|
||
<div class='footer'>
|
||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||
<a href="https://labml.ai">labml.ai</a>
|
||
</div>
|
||
</div>
|
||
<script src="../../interactive.js?v=1"></script>
|
||
<script>
|
||
/**
 * Run handleImage on every image that is a direct child of a paragraph.
 */
function handleImages() {
    // `p>img` matches only <img> elements whose parent is a <p>.
    var paragraphImages = document.querySelectorAll('p>img')

    // Borrow Array's forEach so this also works where NodeList lacks its own.
    Array.prototype.forEach.call(paragraphImages, function (image) {
        handleImage(image)
    })
}
|
||
|
||
/**
 * Make an image open in a modal overlay when it is clicked.
 *
 * Centres the image inside its parent element, then builds a detached modal
 * (`div#modal` containing a `div` with an `img`, plus a `span.close` control)
 * and installs two click handlers: clicking the image attaches the modal to
 * `document.body` showing the same `src`; clicking the close control
 * detaches it again.
 *
 * @param {HTMLImageElement} img - the image to make clickable
 */
function handleImage(img) {
    img.parentElement.style.textAlign = 'center'

    var modal = document.createElement('div')
    modal.id = 'modal'

    var modalContent = document.createElement('div')
    modal.appendChild(modalContent)

    var modalImage = document.createElement('img')
    modalContent.appendChild(modalImage)

    var closeControl = document.createElement('span')
    closeControl.classList.add('close')
    closeControl.textContent = 'x'
    modal.appendChild(closeControl)

    img.onclick = function () {
        document.body.appendChild(modal)
        // Read the source at click time so a later change to img.src is picked up.
        modalImage.src = img.src
    }

    closeControl.onclick = function () {
        // removeChild throws if the modal is not currently a child of body,
        // so only detach it when it is actually attached there.
        if (modal.parentNode === document.body) {
            document.body.removeChild(modal)
        }
    }
}
|
||
|
||
// Attach the click handlers to every `p>img` element present when this script runs.
handleImages()
|
||
</script>
|
||
</body>
|
||
</html> |