
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="Documented implementation with explanations of a Compressive Transformer model."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Compressive Transformer"/>
<meta name="twitter:description" content="Documented implementation with explanations of a Compressive Transformer model."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/transformers/compressive/index.html"/>
<meta property="og:title" content="Compressive Transformer"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Compressive Transformer"/>
<meta property="og:description" content="Documented implementation with explanations of a Compressive Transformer model."/>
<title>Compressive Transformer</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/transformers/compressive/index.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">transformers</a>
<a class="parent" href="index.html">compressive</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/compressive/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Compressive Transformer</h1>
<p>This is an implementation of <a href="https://papers.labml.ai/paper/1911.05507">Compressive Transformers for Long-Range Sequence Modelling</a> in <a href="https://pytorch.org">PyTorch</a>.</p>
<p>This is an extension of <a href="../xl/index.html">Transformer XL</a> where past memories are compressed to give a longer attention range. That is, the furthest <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord mathnormal" style="">n</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="">c</span></span><span class="mord mathnormal mtight" style="">m</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqk" style=""><span class="mord mathnormal" style="">c</span></span></span></span></span> memories are compressed into <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord mathnormal" style="">n</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="">c</span></span><span class="mord mathnormal mtight" style="">m</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> memories, where <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqk" style=""><span class="mord mathnormal" style="">c</span></span></span></span></span> is the compression rate.</p>
<h2>Compression operation</h2>
<p>The compression operation is defined as <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="">c</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">:</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.8491079999999999em;vertical-align:0em;"></span><span class="mord"><span class="mord mathbb">R</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8491079999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">n</span><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="">c</span></span><span class="mbin mtight">×</span><span class="mord mathnormal mtight">d</span></span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel"></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.8491079999999999em;vertical-align:0em;"></span><span class="mord"><span class="mord mathbb">R</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8491079999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">n</span><span class="mbin mtight">×</span><span class="mord mathnormal mtight">d</span></span></span></span></span></span></span></span></span></span></span></span>. 
The paper introduces multiple choices for <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="">c</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span>, and we have implemented only 1D convolution, which seems to give the best results. Each layer has a separate compression operation <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.16678em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="">c</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.97234em;"><span style="top:-3.14734em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight coloredeq eql" style=""><span class="mord mathnormal mtight" style="">i</span></span><span class="mclose mtight">)</span></span></span></span></span></span></span></span></span></span></span></span>, where <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.65952em;vertical-align:0em;"></span><span class="mord coloredeq eql" style=""><span class="mord mathnormal" style="">i</span></span></span></span></span> is the layer number.</p>
<h2>Training compression operation</h2>
<p>Since training compression with BPTT requires maintaining a very large computational graph (many time steps), the paper proposes an <em>auto-encoding loss</em> and an <em>attention reconstruction loss</em>. The auto-encoding loss decodes the original memories from the compressed memories and calculates the loss. Attention reconstruction loss computes the multi-headed attention results on the compressed memory and on uncompressed memory and gets a mean squared error between them. We have implemented the latter here since it gives better results.</p>
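<p>The auto-encoding variant is not implemented in this file, but roughly, a transposed convolution can try to invert the compression, with the reconstruction error training the compression function. A minimal sketch, with illustrative sizes and names:</p>
<pre><code>import torch
import torch.nn as nn

# Hedged sketch of the auto-encoding loss idea, not the implementation below
d_model, rate = 16, 4
compress = nn.Conv1d(d_model, d_model, kernel_size=rate, stride=rate)
decode = nn.ConvTranspose1d(d_model, d_model, kernel_size=rate, stride=rate)

mem = torch.randn(2, d_model, 32)  # [batch, d_model, mem_len]
c_mem = compress(mem)              # [batch, d_model, mem_len // rate]
auto_encoding_loss = nn.functional.mse_loss(decode(c_mem), mem)</code></pre>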
<p>This implementation uses pre-layer normalization, while the paper uses post-layer normalization. Pre-layer norm applies layer normalization before the <a href="../feed_forward.html">FFN</a> and self-attention, and the pass-through in the residual connection is not normalized. This is supposed to be more stable in standard transformer setups.</p>
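<p>Schematically, the two arrangements differ as follows (a minimal sketch, where <code>sub</code> stands for either self-attention or the FFN and <code>norm</code> for layer normalization):</p>
<pre><code># Post-layer norm, as in the paper
def post_norm_block(x, sub, norm):
    return norm(x + sub(x))

# Pre-layer norm, as in this implementation; the residual pass-through stays unnormalized
def pre_norm_block(x, sub, norm):
    return x + sub(norm(x))</code></pre>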
<p>Here are <a href="experiment.html">the training code</a> and a notebook for training a compressive transformer model on the Tiny Shakespeare dataset.</p>
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/compressive/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://app.labml.ai/run/0d9b5338726c11ebb7c80242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">54</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">List</span>
<span class="lineno">55</span>
<span class="lineno">56</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">57</span><span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
<span class="lineno">58</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">59</span>
<span class="lineno">60</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span><span class="p">,</span> <span class="n">TypedModuleList</span>
<span class="lineno">61</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.feed_forward</span> <span class="kn">import</span> <span class="n">FeedForward</span>
<span class="lineno">62</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.mha</span> <span class="kn">import</span> <span class="n">PrepareForMultiHeadAttention</span>
<span class="lineno">63</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.xl.relative_mha</span> <span class="kn">import</span> <span class="n">RelativeMultiHeadAttention</span>
<span class="lineno">64</span><span class="kn">from</span> <span class="nn">labml_nn.utils</span> <span class="kn">import</span> <span class="n">clone_module_list</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>1D Convolution Compression <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="">c</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></h2>
<p>This is a simple wrapper around <a href="https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html"><code class="highlight"><span></span><span class="n">nn</span><span class="o">.</span><span class="n">Conv1d</span></code>
</a> with some tensor dimension permutations.</p>
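<p>A quick shape check of this module (the sizes are arbitrary): because the kernel size equals the stride, the sequence length shrinks by exactly the compression rate.</p>
<pre><code>compress = Conv1dCompression(compression_rate=4, d_model=128)
mem = torch.randn(32, 8, 128)      # [seq_len, batch, d_model]
c_mem = compress(mem)
assert c_mem.shape == (8, 8, 128)  # [seq_len // 4, batch, d_model]</code></pre>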
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">67</span><span class="k">class</span> <span class="nc">Conv1dCompression</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">compression_rate</span></code>
<span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqk" style=""><span class="mord mathnormal" style="">c</span></span></span></span></span> </li>
<li><code class="highlight"><span></span><span class="n">d_model</span></code>
is the embedding size</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">75</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">compression_rate</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">80</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">81</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv1d</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="n">compression_rate</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="n">compression_rate</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p> <code class="highlight"><span></span><span class="n">mem</span></code>
has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">83</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mem</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Permute the dimensions of <code class="highlight"><span></span><span class="n">mem</span></code>
so that we can run it through the convolution layer. The convolution layer accepts input in the form <code class="highlight"><span></span><span class="p">[</span><span class="n">batch</span><span class="p">,</span> <span class="n">features</span><span class="p">,</span> <span class="n">sequence</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">90</span> <span class="n">mem</span> <span class="o">=</span> <span class="n">mem</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Get compressed memory by running it through the convolution layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">92</span> <span class="n">c_mem</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">mem</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Permute back to form <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">94</span> <span class="k">return</span> <span class="n">c_mem</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<h2>Compressive Transformer Layer</h2>
<p>This is the implementation of a single compressive transformer layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">97</span><span class="k">class</span> <span class="nc">CompressiveTransformerLayer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
is the token embedding size </li>
<li><code class="highlight"><span></span><span class="n">self_attn</span></code>
is the <a href="../xl/relative_mha.html">self attention module</a> </li>
<li><code class="highlight"><span></span><span class="n">feed_forward</span></code>
is the <a href="../feed_forward.html">feed forward module</a> </li>
<li><code class="highlight"><span></span><span class="n">dropout_prob</span></code>
is the probability of dropping out after self attention and FFN </li>
<li><code class="highlight"><span></span><span class="n">compress</span></code>
is the compression function <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="">c</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span>. A construction sketch follows this list.</li></ul>
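<p>This sketch assumes the constructor arguments of <code>RelativeMultiHeadAttention</code> and <code>FeedForward</code> from the imported modules, which are not shown in this file:</p>
<pre><code># Hedged construction sketch; sizes are arbitrary
d_model, heads, dropout = 128, 4, 0.1
layer = CompressiveTransformerLayer(
    d_model=d_model,
    self_attn=RelativeMultiHeadAttention(heads, d_model, dropout),
    feed_forward=FeedForward(d_model, 4 * d_model, dropout),
    dropout_prob=dropout,
    compress=Conv1dCompression(compression_rate=4, d_model=d_model))</code></pre>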
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">103</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">104</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">105</span> <span class="n">self_attn</span><span class="p">:</span> <span class="n">RelativeMultiHeadAttention</span><span class="p">,</span>
<span class="lineno">106</span> <span class="n">feed_forward</span><span class="p">:</span> <span class="n">FeedForward</span><span class="p">,</span>
<span class="lineno">107</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
<span class="lineno">108</span> <span class="n">compress</span><span class="p">:</span> <span class="n">Conv1dCompression</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">117</span> <span class="bp">self</span><span class="o">.</span><span class="n">compress</span> <span class="o">=</span> <span class="n">compress</span>
<span class="lineno">118</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">d_model</span>
<span class="lineno">119</span> <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span> <span class="o">=</span> <span class="n">self_attn</span>
<span class="lineno">120</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span>
<span class="lineno">121</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span>
<span class="lineno">122</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
<span class="lineno">123</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p> Concatenate the normalized token embeddings with memory and compressed memory.</p>
<ul><li><code class="highlight"><span></span><span class="n">z</span></code>
is layer normalized token embeddings. </li>
<li><code class="highlight"><span></span><span class="n">mem</span></code>
and <code class="highlight"><span></span><span class="n">c_mem</span></code>
are memory and compressed memory (not normalized).</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">125</span> <span class="k">def</span> <span class="nf">concat_memory</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">z</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mem</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">c_mem</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>If there is no memory, just return the token embeddings </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">134</span> <span class="k">if</span> <span class="n">mem</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">135</span> <span class="k">return</span> <span class="n">z</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>If there is compressed memory, concatenate it with the memory </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">138</span> <span class="k">if</span> <span class="n">c_mem</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">139</span> <span class="n">mem</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">c_mem</span><span class="p">,</span> <span class="n">mem</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Run the memory through the normalization layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">142</span> <span class="n">mem</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">(</span><span class="n">mem</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Concatenate normalized memory and normalized token embeddings </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">144</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">mem</span><span class="p">,</span> <span class="n">z</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
is a tensor of token level feature vectors of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
</li>
<li><code class="highlight"><span></span><span class="n">mem</span></code>
is a tensor of the past token level feature vectors (memory) of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">mem_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
</li>
<li><code class="highlight"><span></span><span class="n">c_mem</span></code>
is a tensor of the compressed memory <code class="highlight"><span></span><span class="p">[</span><span class="n">c_mem_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
</li>
<li><code class="highlight"><span></span><span class="n">mask</span></code>
is a matrix of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">c_mem_len</span> <span class="o">+</span> <span class="n">mem_len</span> <span class="o">+</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">]</span></code>
or <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">c_mem_len</span> <span class="o">+</span> <span class="n">mem_len</span> <span class="o">+</span> <span class="n">seq_len</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span></code>
. <code class="highlight"><span></span><span class="n">mask</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span></code>
is true if token at <code class="highlight"><span></span><span class="n">i</span></code>
can see token at <code class="highlight"><span></span><span class="n">j</span></code>
. A sketch of constructing such a mask follows.</li></ul>
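<p>A minimal version, assuming every query may attend to all of the (compressed) memory and causally to the current segment; the helper name is hypothetical:</p>
<pre><code>def build_mask(seq_len: int, mem_len: int, c_mem_len: int) -&gt; torch.Tensor:
    past = c_mem_len + mem_len
    # Query i sees all `past` columns, and query column j only when j &lt;= i
    mask = torch.tril(torch.ones(seq_len, past + seq_len, dtype=torch.bool), diagonal=past)
    return mask[:, :, None]  # [seq_len, c_mem_len + mem_len + seq_len, 1]</code></pre>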
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">146</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">147</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">148</span> <span class="n">mem</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="lineno">149</span> <span class="n">c_mem</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="lineno">150</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Normalize the vectors before doing self attention </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">160</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Normalize and concatenate memory and compressed memory </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span> <span class="n">m_z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">concat_memory</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">mem</span><span class="p">,</span> <span class="n">c_mem</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Attention </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">164</span> <span class="n">self_attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">m_z</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">m_z</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Add the attention results </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">166</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">self_attn</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Normalize for feed-forward </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">169</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Pass through the feed-forward network </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">171</span> <span class="n">ff</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Add the feed-forward results back </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">173</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">ff</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<h2>Compressive Transformer Model</h2>
<p>This consists of multiple compressive transformer layers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">179</span><span class="k">class</span> <span class="nc">CompressiveTransformer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">186</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">CompressiveTransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">187</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Make copies of the transformer layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">189</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Final normalization layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
is a tensor of the token embeddings vectors of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
</li>
<li><code class="highlight"><span></span><span class="n">mem</span></code>
is a list of tensors of the past token level feature vectors of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">mem_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
for each layer </li>
<li><code class="highlight"><span></span><span class="n">c_mem</span></code>
is a list of tensors of the compressed memory <code class="highlight"><span></span><span class="p">[</span><span class="n">c_mem_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
for each layer </li>
<li><code class="highlight"><span></span><span class="n">mask</span></code>
is the masking matrix. A usage sketch follows.</li></ul>
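<p>For instance, reusing the hypothetical <code>layer</code> and <code>build_mask</code> from the sketches above:</p>
<pre><code>model = CompressiveTransformer(layer, n_layers=6)
x = torch.randn(16, 8, 128)                            # [seq_len, batch_size, d_model]
mask = build_mask(seq_len=16, mem_len=0, c_mem_len=0)  # no memories on the first batch
out, new_mem = model(x, mem=[], c_mem=[], mask=mask)   # new_mem holds one tensor per layer</code></pre>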
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">193</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mem</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">c_mem</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>List to store token level feature vectors, which will become the memories for the next sequential batch. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">204</span> <span class="n">new_mem</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Run through each transformer layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">206</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Add to the list of feature vectors </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">208</span> <span class="n">new_mem</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Memory </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">210</span> <span class="n">m</span> <span class="o">=</span> <span class="n">mem</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">if</span> <span class="n">mem</span> <span class="k">else</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Compressed Memory </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">212</span> <span class="n">cm</span> <span class="o">=</span> <span class="n">c_mem</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">if</span> <span class="n">c_mem</span> <span class="k">else</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Run through the compressive transformer layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">214</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mem</span><span class="o">=</span><span class="n">m</span><span class="p">,</span> <span class="n">c_mem</span><span class="o">=</span><span class="n">cm</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Finally, normalize the vectors </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">216</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">new_mem</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<h2>Attention Reconstruction Loss</h2>
<p>Attention reconstruction loss recreates the self-attention output with uncompressed memory and with compressed memory and calculates the mean squared error between the two. It does this without positional encoding.</p>
<p>When calculating and training the compression function <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="">c</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> with attention reconstruction loss, all parameters but <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="">c</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> are frozen. This includes key/value projections and bias/scaling after normalization.</p>
<p>Since this loss can be computed independently of the cross-entropy loss of the model, you can have a separate optimizer that only updates <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="">c</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span>. However, we use the same optimizer to update <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="">c</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span>, so when calculating the attention reconstruction loss, we detach all parameters other than <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="">c</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span> from the gradient computation.</p>
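<p>A toy sketch of that detaching trick (this is <em>not</em> the reconstruction loss computed below; the projection and the loss here are illustrative):</p>
<pre><code>import torch
import torch.nn.functional as F

proj = torch.nn.Linear(16, 16)  # stands in for a frozen key/value projection
compress = torch.nn.Conv1d(16, 16, kernel_size=4, stride=4)  # f_c

mem = torch.randn(8, 1, 16)     # [mem_len, batch, d_model]
c_mem = compress(mem.permute(1, 2, 0)).permute(2, 0, 1)

# Every parameter except f_c is detached before use
target = F.linear(mem, proj.weight.detach(), proj.bias.detach())
recon = F.linear(c_mem, proj.weight.detach(), proj.bias.detach())
F.mse_loss(recon.mean(0), target.mean(0)).backward()

assert proj.weight.grad is None          # frozen
assert compress.weight.grad is not None  # only f_c receives gradients</code></pre>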
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">219</span><span class="k">class</span> <span class="nc">AttentionReconstructionLoss</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p> <code class="highlight"><span></span><span class="n">layers</span></code>
is the list of Compressive Transformer layers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">237</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layers</span><span class="p">:</span> <span class="n">TypedModuleList</span><span class="p">[</span><span class="n">CompressiveTransformerLayer</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">241</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">layers</span>
<span class="lineno">242</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MSELoss</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p> This is a reimplementation of <a href="../mha.html#PrepareMHA">&#x27;PrepareForMultiHeadAttention&#x27;</a> where the projections are done with the parameters detached from gradient computation.</p>
<ul><li><code class="highlight"><span></span><span class="n">pmha</span></code>
is the <a href="../mha.html#PrepareMHA">&#x27;PrepareForMultiHeadAttention&#x27;</a> module </li>
<li><code class="highlight"><span></span><span class="n">x</span></code>
is a tensor of token embeddings</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">244</span> <span class="k">def</span> <span class="nf">prepare_for_attn</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pmha</span><span class="p">:</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Shape of the input except the embedding dimension: <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">]</span></code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">254</span> <span class="n">head_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Detach projection weights and bias </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">257</span> <span class="n">weight</span> <span class="o">=</span> <span class="n">pmha</span><span class="o">.</span><span class="n">linear</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
<span class="lineno">258</span> <span class="n">bias</span> <span class="o">=</span> <span class="n">pmha</span><span class="o">.</span><span class="n">linear</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span> <span class="k">if</span> <span class="n">pmha</span><span class="o">.</span><span class="n">linear</span><span class="o">.</span><span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>Linear transform </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">260</span> <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Split last dimension into heads </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">263</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">head_shape</span><span class="p">,</span> <span class="n">pmha</span><span class="o">.</span><span class="n">heads</span><span class="p">,</span> <span class="n">pmha</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Output has shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">]</span></code>
or <code class="highlight"><span></span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">heads</span><span class="p">,</span> <span class="n">d_k</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">266</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
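<div class='section'>
<div class='docs'>
<p>The reshaping above can be checked in isolation. A standalone sketch with illustrative sizes, covering both input shapes:</p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch

seq_len, batch_size, heads, d_k = 10, 2, 4, 8
x = torch.randn(seq_len, batch_size, heads * d_k)

head_shape = x.shape[:-1]            # [seq_len, batch_size]
x = x.view(*head_shape, heads, d_k)  # split d_model into heads
assert x.shape == (seq_len, batch_size, heads, d_k)

# A two-dimensional input gives [batch_size, heads, d_k]
x2 = torch.randn(batch_size, heads * d_k)
assert x2.view(*x2.shape[:-1], heads, d_k).shape == (batch_size, heads, d_k)</pre></div>
</div>
</div>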
<div class='section' id='section-46'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p> This is a reimplementation of <a href="../mha.html#MHA">&#x27;Multi-Head Attention&#x27;</a> which calls <code class="highlight"><span></span><span class="n">prepare_for_attn</span></code>
instead of <a href="../mha.html#PrepareMHA">&#x27;PrepareForMultiHeadAttention&#x27;</a> to detach projection parameters.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">268</span> <span class="k">def</span> <span class="nf">attn</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">RelativeMultiHeadAttention</span><span class="p">,</span> <span class="n">query</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<p>Calculate query, key and value projections </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">275</span> <span class="n">query</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_for_attn</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">query</span><span class="p">,</span> <span class="n">query</span><span class="p">)</span>
<span class="lineno">276</span> <span class="n">key</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_for_attn</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">key</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span>
<span class="lineno">277</span> <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_for_attn</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">value</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<p>Compute attention scores <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.043548em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqf" style=""><span class="mord mathnormal" style="">Q</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.07153em">K</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.849108em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""></span></span></span></span></span></span></span></span></span></span></span></span>. This gives a tensor of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">heads</span><span class="p">]</span></code>
. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">281</span> <span class="n">scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;ibhd,jbhd-&gt;ijbh&#39;</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span></pre></div>
</div>
</div>
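<div class='section'>
<div class='docs'>
<p>This einsum is just a batched matrix product. A quick sketch with toy sizes, confirming that for every batch and head it computes the query-key dot products:</p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch

q = torch.randn(5, 2, 4, 8)  # [seq_len, batch_size, heads, d_k]
k = torch.randn(7, 2, 4, 8)  # keys may have a different sequence length

scores = torch.einsum('ibhd,jbhd-&gt;ijbh', q, k)

# The same result via an explicit batched matmul
ref = torch.matmul(q.permute(1, 2, 0, 3),  # [b, h, i, d]
                   k.permute(1, 2, 3, 0))  # [b, h, d, j]
assert torch.allclose(scores, ref.permute(2, 3, 0, 1), atol=1e-6)</pre></div>
</div>
</div>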
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>Scale scores <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.633028em;vertical-align:-0.538em;"></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.095028em;"><span style="top:-2.5864385em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord sqrt mtight" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8622307142857143em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mtight" style="padding-left:0.833em"><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.8222307142857144em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail mtight" style="min-width:0.853em;height:1.08em"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.17776928571428574em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.446108em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqf" style=""><span class="mord mathnormal mtight" style="">Q</span><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9270285714285713em;"><span style="top:-2.931em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.538em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">284</span> <span class="n">scores</span> <span class="o">*=</span> <span class="n">layer</span><span class="o">.</span><span class="n">scale</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqh" style=""><span class="mord mathnormal" style="">so</span><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="mord mathnormal" style="">t</span><span class="mord mathnormal" style="">ma</span><span class="mord mathnormal" style="">x</span></span></span></span></span> attention along the key sequence dimension <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:3.0000299999999998em;vertical-align:-1.25003em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style=""><span class="mop op-limits" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6944399999999998em;"><span style="top:-2.20556em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">se</span><span class="mord mathnormal mtight" style="margin-right:0.03588em">q</span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop" style=""><span class="mord coloredeq eqh" style=""><span class="mord mathnormal" style="">so</span><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="mord mathnormal" style="">t</span><span class="mord mathnormal" style="">ma</span><span class="mord mathnormal" style="">x</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.030548em;"><span></span></span></span></span></span></span><span class="mord" style=""><span class="delimsizing size4" style=""><span style="">(</span></span></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.095028em;"><span style="top:-2.5864385em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord sqrt mtight" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8622307142857143em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord mtight" style="padding-left:0.833em"><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.8222307142857144em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail mtight" style="min-width:0.853em;height:1.08em"><svg height="1.08em" 
preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.17776928571428574em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.446108em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqf" style=""><span class="mord mathnormal mtight" style="">Q</span><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.07153em">K</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.9270285714285713em;"><span style="top:-2.931em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.538em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span><span class="mord" style=""><span class="delimsizing size4" style=""><span style="">)</span></span></span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">288</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">scores</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p>Multiply by values <span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:3.0000299999999998em;vertical-align:-1.25003em;"></span><span class="mord coloredeq eqc" style=""><span class="mord" style=""><span class="mop op-limits" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6944399999999998em;"><span style="top:-2.20556em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">se</span><span class="mord mathnormal mtight" style="margin-right:0.03588em">q</span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop" style=""><span class="mord coloredeq eqh" style=""><span class="mord mathnormal" style="">so</span><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="mord mathnormal" style="">t</span><span class="mord mathnormal" style="">ma</span><span class="mord mathnormal" style="">x</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.030548em;"><span></span></span></span></span></span></span><span class="mord" style=""><span class="delimsizing size4" style=""><span style="">(</span></span></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.5261079999999998em;"><span style="top:-2.25278em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord sqrt" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.85722em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em"><span class="mord" style=""><span class="mord mathnormal" style="">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.03148em">k</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.81722em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em"><svg height="1.08em" preserveaspectratio="xMinYMin slice" viewbox="0 0 400000 1080" width="400em" xmlns="http://www.w3.org/2000/svg"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.18278000000000005em;"><span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord coloredeq eqf" style=""><span class="mord mathnormal" style="">Q</span><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.07153em">K</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.849108em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""></span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.93em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span><span class="mord" style=""><span class="delimsizing size4" style=""><span style="">)</span></span></span></span><span class="mord mathnormal" style="margin-right:0.22222em;">V</span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">292</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">&quot;ijbh,jbhd-&gt;ibhd&quot;</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span></pre></div>
</div>
</div>
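<div class='section'>
<div class='docs'>
<p>Putting the three steps together, a toy end-to-end sketch of this detached attention path. The softmax runs over dim=1, the key-sequence dimension of the scores; the sizes and the explicit scale are illustrative:</p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch

i, j, b, h, d = 5, 7, 2, 4, 8
q = torch.randn(i, b, h, d)
k = torch.randn(j, b, h, d)
v = torch.randn(j, b, h, d)

scores = torch.einsum('ibhd,jbhd-&gt;ijbh', q, k) * d ** -0.5
attn = scores.softmax(dim=1)  # normalize over the key positions j
out = torch.einsum('ijbh,jbhd-&gt;ibhd', attn, v)
assert out.shape == (i, b, h, d)</pre></div>
</div>
</div>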
<div class='section' id='section-52'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p> Perform layer normalization with shift and scale parameters detached.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">294</span> <span class="k">def</span> <span class="nf">norm</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ln</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>Detach the shift (<code class="highlight"><span></span><span class="n">bias</span></code>
) and scale (<code class="highlight"><span></span><span class="n">weight</span></code>
) parameters </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">300</span> <span class="n">weight</span> <span class="o">=</span> <span class="n">ln</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span> <span class="k">if</span> <span class="n">ln</span><span class="o">.</span><span class="n">weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="kc">None</span>
<span class="lineno">301</span> <span class="n">bias</span> <span class="o">=</span> <span class="n">ln</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span> <span class="k">if</span> <span class="n">ln</span><span class="o">.</span><span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<p>Layer normalization </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">304</span> <span class="k">return</span> <span class="n">F</span><span class="o">.</span><span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">ln</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">ln</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span></pre></div>
</div>
</div>
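<div class='section'>
<div class='docs'>
<p>As with the projections, this can be sanity-checked in isolation: the detached call produces exactly the same values as calling the module, but gradients stop at the layer normalization parameters. A toy sketch:</p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch
import torch.nn.functional as F
from torch import nn

ln = nn.LayerNorm(16)
x = torch.randn(4, 16, requires_grad=True)

y = F.layer_norm(x, ln.normalized_shape, ln.weight.detach(), ln.bias.detach(), ln.eps)
assert torch.allclose(y, ln(x))  # identical values

y.sum().backward()
assert x.grad is not None                               # gradients reach the input
assert ln.weight.grad is None and ln.bias.grad is None  # but not the parameters</pre></div>
</div>
</div>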
<div class='section' id='section-55'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<p> This calculates the attention reconstruction loss for a single layer.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">306</span> <span class="k">def</span> <span class="nf">calc_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">CompressiveTransformerLayer</span><span class="p">,</span> <span class="n">h</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mem</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
<p>Detach the token embeddings and memory. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">312</span> <span class="n">h</span> <span class="o">=</span> <span class="n">h</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
<span class="lineno">313</span> <span class="n">mem</span> <span class="o">=</span> <span class="n">mem</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<p>Compress the memory with <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.16678em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="">c</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.97234em;"><span style="top:-3.14734em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight coloredeq eql" style=""><span class="mord mathnormal mtight" style="">i</span></span><span class="mclose mtight">)</span></span></span></span></span></span></span></span></span></span></span></span>. The parameters of <span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.16678em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord coloredeq eqj" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="">c</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.97234em;"><span style="top:-3.14734em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight coloredeq eql" style=""><span class="mord mathnormal mtight" style="">i</span></span><span class="mclose mtight">)</span></span></span></span></span></span></span></span></span></span></span></span> are the only parameters not detached from gradient computation. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">317</span> <span class="n">c_mem</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">compress</span><span class="p">(</span><span class="n">mem</span><span class="p">)</span></pre></div>
</div>
</div>
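<div class='section'>
<div class='docs'>
<p>A toy sketch of one plausible compression function: a strided one-dimensional convolution that shortens the memory by the compression rate, as in the paper's convolutional compression. The sizes are illustrative:</p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch
from torch import nn

c, d_model = 4, 16  # compression rate and embedding size
conv = nn.Conv1d(d_model, d_model, kernel_size=c, stride=c)

mem = torch.randn(12, 2, d_model)  # [mem_len, batch_size, d_model]
x = mem.permute(1, 2, 0)           # Conv1d expects [batch, channels, length]
c_mem = conv(x).permute(2, 0, 1)   # back to [mem_len // c, batch, d_model]
assert c_mem.shape == (12 // c, 2, d_model)</pre></div>
</div>
</div>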
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<p>Normalize the embeddings and memories </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">320</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span>
<span class="lineno">321</span> <span class="n">mem</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">,</span> <span class="n">mem</span><span class="p">)</span>
<span class="lineno">322</span> <span class="n">c_mem</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">,</span> <span class="n">c_mem</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<p>Calculate the attention with uncompressed memory </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">325</span> <span class="n">attn_mem</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">self_attn</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">mem</span><span class="p">,</span> <span class="n">mem</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
<p>Calculate the attention with compressed memory </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">327</span> <span class="n">attn_cmem</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">self_attn</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">c_mem</span><span class="p">,</span> <span class="n">c_mem</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
<p>Calculate the mean squared error </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">330</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">attn_cmem</span><span class="p">,</span> <span class="n">attn_mem</span><span class="p">)</span></pre></div>
</div>
</div>
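<div class='section'>
<div class='docs'>
<p>A toy sketch of the resulting gradient flow, with single-head stand-ins rather than the actual classes: since the embeddings and memory are detached, the mean squared error between the two attention outputs can only update the parameters of the compression function.</p>
</div>
<div class='code'>
<div class="highlight"><pre>import torch
from torch import nn

conv = nn.Conv1d(8, 8, kernel_size=2, stride=2)

def compress(m):
    # strided convolution over the sequence dimension
    return conv(m.permute(1, 2, 0)).permute(2, 0, 1)

def toy_attn(q, kv):
    # single-head stand-in for the attention above
    scores = torch.einsum('ibd,jbd-&gt;ijb', q, kv)
    return torch.einsum('ijb,jbd-&gt;ibd', scores.softmax(dim=1), kv)

h = torch.randn(4, 1, 8)    # stand-ins for the detached h and mem
mem = torch.randn(6, 1, 8)
c_mem = compress(mem)

loss = nn.MSELoss()(toy_attn(h, c_mem), toy_attn(h, mem))
loss.backward()
assert conv.weight.grad is not None  # only the compression function trains</pre></div>
</div>
</div>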
<div class='section' id='section-62'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">332</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">h</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">mem</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='docs'>
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
<p>Calculate the attention reconstruction loss for each layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">334</span> <span class="n">losses</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">calc_loss</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">h</span><span class="p">[</span><span class="n">n</span><span class="p">],</span> <span class="n">mem</span><span class="p">[</span><span class="n">n</span><span class="p">])</span> <span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-64'>
<div class='docs'>
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
<p>Sum of the losses </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">336</span> <span class="k">return</span> <span class="nb">sum</span><span class="p">(</span><span class="n">losses</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
    // Center inline images and attach a click-to-zoom modal viewer to each
    function handleImages() {
        var images = document.querySelectorAll('p>img')
        for (var i = 0; i < images.length; ++i) {
            handleImage(images[i])
        }
    }

    function handleImage(img) {
        img.parentElement.style.textAlign = 'center'

        // Build the modal: a container, the enlarged image, and a close button
        var modal = document.createElement('div')
        modal.id = 'modal'
        var modalContent = document.createElement('div')
        modal.appendChild(modalContent)
        var modalImage = document.createElement('img')
        modalContent.appendChild(modalImage)
        var span = document.createElement('span')
        span.classList.add('close')
        span.textContent = 'x'
        modal.appendChild(span)

        // Clicking the image shows the modal with the full-size copy
        img.onclick = function () {
            document.body.appendChild(modal)
            modalImage.src = img.src
        }
        // Clicking the close button removes the modal
        span.onclick = function () {
            document.body.removeChild(modal)
        }
    }

    handleImages()
</script>
</body>
</html>