Varuna Jayasiri efd2673735 cleanup
2021-06-02 21:40:05 +05:30


<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="Documented implementation with explanations of a Compressive Transformer model."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Compressive Transformer"/>
<meta name="twitter:description" content="Documented implementation with explanations of a Compressive Transformer model."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/transformers/compressive/index.html"/>
<meta property="og:title" content="Compressive Transformer"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="Documented implementation with explanations of a Compressive Transformer model."/>
<title>Compressive Transformer</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/transformers/compressive/index.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">transformers</a>
<a class="parent" href="index.html">compressive</a>
</p>
<p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/transformers/compressive/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Compressive Transformer</h1>
<p>This is an implementation of
<a href="https://arxiv.org/abs/1911.05507">Compressive Transformers for Long-Range Sequence Modelling</a>
in <a href="https://pytorch.org">PyTorch</a>.</p>
<p>This is an extension of <a href="../xl/index.html">Transformer XL</a> where past memories
are compressed to give a longer attention range.
That is, the furthest $n_{cm} c$ memories are compressed into
$n_{cm}$ memories, where $c$ is the compression rate.</p>
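<p>A minimal sketch of the memory bookkeeping described above, in plain Python with memories as lists of per-token vectors. The names <code>update_memories</code>, <code>mean_compress</code>, <code>mem_len</code> and <code>c_mem_len</code> are hypothetical, and a simple block mean stands in for the learned $f_c$. This illustrates the idea under those assumptions; it is not the code used on this page.</p>

```python
from typing import Callable, List, Tuple

# Hypothetical helpers for illustration; not part of labml_nn.
Vector = List[float]


def mean_compress(block: List[Vector], rate: int) -> List[Vector]:
    """Stand-in for f_c: average each run of `rate` consecutive vectors."""
    out = []
    for start in range(0, len(block), rate):
        group = block[start:start + rate]
        out.append([sum(vals) / len(group) for vals in zip(*group)])
    return out


def update_memories(mem: List[Vector], c_mem: List[Vector], new: List[Vector],
                    mem_len: int, c_mem_len: int, rate: int,
                    compress: Callable[[List[Vector], int], List[Vector]] = mean_compress,
                    ) -> Tuple[List[Vector], List[Vector]]:
    """Append `new` to the memory and compress whatever no longer fits."""
    mem = mem + new
    overflow = len(mem) - mem_len
    if overflow > 0:
        # The oldest `overflow` memories are compressed instead of discarded
        # (assumes `overflow` is a multiple of `rate`).
        old, mem = mem[:overflow], mem[overflow:]
        c_mem = c_mem + compress(old, rate)
        # Keep only the newest `c_mem_len` compressed memories.
        c_mem = c_mem[max(0, len(c_mem) - c_mem_len):]
    return mem, c_mem
```

<p>For example, with <code>mem_len=4</code>, <code>c_mem_len=2</code> and <code>rate=2</code>, feeding tokens 1 to 10 two at a time leaves tokens 7 to 10 in memory, and the averages of tokens 3 and 4 and of tokens 5 and 6 in compressed memory: 8 past tokens (<code>mem_len + rate * c_mem_len</code>) covered by 6 slots.</p>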
<h2>Compression operation</h2>
<p>The compression operation is defined as
$f_c: \mathbb{R}^{nc \times d} \rightarrow \mathbb{R}^{n \times d}$.
The paper introduces multiple choices for $f_c$; we have implemented only
the 1D convolution, which seems to give the best results.
Each layer has a separate compression operation $f_c^{(i)}$ where
$i$ is the layer number.</p>
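<p>To make $f_c$ concrete: a 1D convolution whose kernel size and stride both equal $c$ maps each non-overlapping block of $c$ memory vectors to one compressed vector, mixing all $d$ features. The plain-Python sketch below spells out that arithmetic for a single sequence (no batch dimension). The weight layout <code>[d_out][d_in][c]</code> mirrors <code>nn.Conv1d.weight</code>, but <code>conv1d_compress</code> itself is a hypothetical illustration.</p>

```python
from typing import List

# Hypothetical helper for illustration; not part of labml_nn.
Matrix = List[List[float]]


def conv1d_compress(mem: Matrix, weight: List[List[List[float]]],
                    bias: List[float], rate: int) -> Matrix:
    """f_c: R^{(n*rate) x d} -> R^{n x d} via a strided 1D convolution.

    mem:    n*rate rows, each a d-dimensional memory vector
    weight: weight[o][i][k] connects input feature i at kernel offset k
            to output feature o (same layout as nn.Conv1d.weight)
    """
    d_out = len(weight)
    n = len(mem) // rate  # a trailing partial block is dropped, as Conv1d does
    out = []
    for j in range(n):
        block = mem[j * rate:(j + 1) * rate]  # the `rate` vectors compressed into slot j
        row = []
        for o in range(d_out):
            acc = bias[o]
            for k, vec in enumerate(block):
                for i, v in enumerate(vec):
                    acc += weight[o][i][k] * v
            row.append(acc)
        out.append(row)
    return out
```

<p>With weights of $1/c$ on the diagonal and zero bias this reduces to mean pooling, which the paper also considers; learning the weights lets each layer's $f_c^{(i)}$ decide what to keep.</p>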
<h2>Training compression operation</h2>
<p>Since training compression with BPTT requires maintaining
a very large computational graph (many time steps), the paper proposes
an <em>auto-encoding loss</em> and an <em>attention reconstruction loss</em>.
The auto-encoding loss decodes the original memories from the compressed memories
and measures how well they are reconstructed.
The attention reconstruction loss computes the multi-headed attention results
over the compressed memory and over the uncompressed memory, and takes the
mean squared error between them.
We have implemented the latter here since it gives better results.</p>
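<p>A plain-Python sketch of the attention reconstruction loss for a single head, assuming identity query/key/value projections and no positional encoding (the implementation on this page uses the layer's own multi-head projections instead). <code>attend</code> and <code>attention_reconstruction_loss</code> are hypothetical names, and <code>compress</code> can be any $f_c$.</p>

```python
import math
from typing import Callable, List

# Hypothetical helpers for illustration; not part of labml_nn.
Matrix = List[List[float]]


def attend(queries: Matrix, keys: Matrix, values: Matrix) -> Matrix:
    """Scaled dot-product attention for one head, without a mask."""
    scale = 1.0 / math.sqrt(len(keys[0]))
    out = []
    for q in queries:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
        top = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[i] for w, v in zip(weights, values)) / total
                    for i in range(len(values[0]))])
    return out


def attention_reconstruction_loss(h: Matrix, mem: Matrix,
                                  compress: Callable[[Matrix], Matrix]) -> float:
    """MSE between attention over the old memories and over their compression."""
    c_mem = compress(mem)
    target = attend(h, mem, mem)      # attention over uncompressed memory
    approx = attend(h, c_mem, c_mem)  # attention over compressed memory
    diffs = [(a - b) ** 2 for ra, rb in zip(target, approx) for a, b in zip(ra, rb)]
    return sum(diffs) / len(diffs)
```

<p>The loss is zero when $f_c$ loses nothing the queries care about (for instance when all compressed vectors are identical), and grows as attention over the compressed memory drifts from attention over the original.</p>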
<p>This implementation uses pre-layer normalization
while the paper uses post-layer normalization.
Pre-layer norm does the layer norm before <a href="../feed_forward.html">FFN</a> and
self-attention, and the pass-through in the residual connection is not normalized.
This is supposed to be more stable in standard transformer setups.</p>
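<p>The difference between the two orderings, sketched in plain Python for a single $d$-dimensional vector. <code>sublayer</code> stands for either self-attention or the FFN, and the layer norm here has no learned gain or bias; both are simplifying assumptions, and the function names are hypothetical.</p>

```python
import math
from typing import Callable, List

# Hypothetical helpers for illustration; not part of labml_nn.
Vector = List[float]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """Normalize to zero mean and unit variance (no learned gain/bias)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def pre_ln_block(x: Vector, sublayer: Callable[[Vector], Vector]) -> Vector:
    # Pre-LN (this page): normalize only the sublayer input;
    # the residual path carries x through unchanged.
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]


def post_ln_block(x: Vector, sublayer: Callable[[Vector], Vector]) -> Vector:
    # Post-LN (the paper): add the residual first, then normalize the sum.
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])
```

<p>With a sublayer that outputs zeros, <code>pre_ln_block</code> returns <code>x</code> unchanged while <code>post_ln_block</code> returns <code>layer_norm(x)</code>. Because the pre-LN residual stream is never normalized inside a block, the model below applies a final layer norm after the last layer.</p>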
<p>Here are <a href="experiment.html">the training code</a> and a notebook for training a compressive transformer
model on the Tiny Shakespeare dataset.</p>
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/compressive/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
<a href="https://app.labml.ai/run/0d9b5338726c11ebb7c80242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">54</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">List</span>
<span class="lineno">55</span>
<span class="lineno">56</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">57</span><span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
<span class="lineno">58</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">59</span>
<span class="lineno">60</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span><span class="p">,</span> <span class="n">TypedModuleList</span>
<span class="lineno">61</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.feed_forward</span> <span class="kn">import</span> <span class="n">FeedForward</span>
<span class="lineno">62</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.mha</span> <span class="kn">import</span> <span class="n">PrepareForMultiHeadAttention</span>
<span class="lineno">63</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.xl.relative_mha</span> <span class="kn">import</span> <span class="n">RelativeMultiHeadAttention</span>
<span class="lineno">64</span><span class="kn">from</span> <span class="nn">labml_nn.utils</span> <span class="kn">import</span> <span class="n">clone_module_list</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>1D Convolution Compression $f_c$</h2>
<p>This is a simple wrapper around
<a href="https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html"><code>nn.Conv1d</code></a>
with some tensor dimension permutations.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">67</span><span class="k">class</span> <span class="nc">Conv1dCompression</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul>
<li><code>compression_rate</code> $c$</li>
<li><code>d_model</code> is the embedding size</li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">75</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">compression_rate</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">80</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">81</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv1d</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="n">compression_rate</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="n">compression_rate</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p><code>mem</code> has shape <code>[seq_len, batch, d_model]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">83</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mem</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Permute the dimensions of <code>mem</code> so that we can run it through the convolution layer.
The convolution layer expects input in the form <code>[batch, features, sequence]</code></p>
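<p>A quick way to check the two permutations is to track shapes alone. The hypothetical plain-Python helpers below mimic how <code>Tensor.permute</code> reorders dimensions and compute the sequence length that <code>nn.Conv1d</code> produces with <code>kernel_size = stride = c</code> (no padding, no dilation).</p>

```python
from typing import Tuple

# Hypothetical helpers for illustration; not part of labml_nn.
Shape = Tuple[int, ...]


def permute_shape(shape: Shape, dims: Shape) -> Shape:
    """Shape after tensor.permute(*dims): output dim k is input dim dims[k]."""
    return tuple(shape[d] for d in dims)


def conv1d_out_len(length: int, kernel_size: int, stride: int) -> int:
    """Output length of nn.Conv1d without padding or dilation (assumes length >= kernel_size)."""
    return (length - kernel_size) // stride + 1


def compressed_shape(mem_shape: Shape, rate: int) -> Shape:
    """Shape of the compression output for a [seq_len, batch, d_model] input."""
    # [seq_len, batch, d_model] -> [batch, d_model, seq_len], the layout Conv1d expects
    batch, d_model, seq_len = permute_shape(mem_shape, (1, 2, 0))
    # Conv1d keeps d_model channels and shortens the sequence axis
    c_len = conv1d_out_len(seq_len, kernel_size=rate, stride=rate)
    # [batch, d_model, c_len] -> [c_len, batch, d_model]
    return permute_shape((batch, d_model, c_len), (2, 0, 1))
```

<p>Since <code>(1, 2, 0)</code> and <code>(2, 0, 1)</code> are inverse permutations, the output keeps the <code>[seq_len, batch, d_model]</code> layout, with <code>seq_len</code> divided (rounding down) by the compression rate.</p>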
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">90</span> <span class="n">mem</span> <span class="o">=</span> <span class="n">mem</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Get compressed memory by running it through the convolution layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">92</span> <span class="n">c_mem</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">mem</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Permute back to form <code>[seq_len, batch, d_model]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">94</span> <span class="k">return</span> <span class="n">c_mem</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<h2>Compressive Transformer Layer</h2>
<p>This is the implementation of a single compressive transformer layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">97</span><span class="k">class</span> <span class="nc">CompressiveTransformerLayer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<ul>
<li><code>d_model</code> is the token embedding size</li>
<li><code>self_attn</code> is the <a href="../xl/relative_mha.html">self attention module</a></li>
<li><code>feed_forward</code> is the <a href="../feed_forward.html">feed forward module</a></li>
<li><code>dropout_prob</code> is the probability of dropping out after self attention and FFN</li>
<li><code>compress</code> is the compression function $f_c$</li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">103</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">104</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">105</span> <span class="n">self_attn</span><span class="p">:</span> <span class="n">RelativeMultiHeadAttention</span><span class="p">,</span>
<span class="lineno">106</span> <span class="n">feed_forward</span><span class="p">:</span> <span class="n">FeedForward</span><span class="p">,</span>
<span class="lineno">107</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
<span class="lineno">108</span> <span class="n">compress</span><span class="p">:</span> <span class="n">Conv1dCompression</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">117</span> <span class="bp">self</span><span class="o">.</span><span class="n">compress</span> <span class="o">=</span> <span class="n">compress</span>
<span class="lineno">118</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">d_model</span>
<span class="lineno">119</span> <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span> <span class="o">=</span> <span class="n">self_attn</span>
<span class="lineno">120</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span>
<span class="lineno">121</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span>
<span class="lineno">122</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
<span class="lineno">123</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Concatenate the normalized token embeddings with memory and compressed memory.</p>
<ul>
<li><code>z</code> is the layer-normalized token embeddings.</li>
<li><code>mem</code> and <code>c_mem</code> are memory and compressed memory (not normalized).</li>
</ul>
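<p>The resulting key/value sequence is ordered oldest to newest: compressed memory, then memory, then the current tokens. A plain-Python sketch of that ordering, using token labels instead of vectors and skipping the normalization (the labels and the function name are hypothetical):</p>

```python
from typing import List, Optional


# Hypothetical helper for illustration; not part of labml_nn.
def concat_order(z: List[str], mem: Optional[List[str]],
                 c_mem: Optional[List[str]]) -> List[str]:
    """Mirror concat_memory's ordering along the sequence dimension."""
    if mem is None:
        # No memory: attend only over the current tokens
        return z
    if c_mem is not None:
        # Compressed memory is older, so it goes first
        mem = c_mem + mem
    return mem + z
```

<p>The key length is therefore <code>c_mem_len + mem_len + seq_len</code>, which is what the second dimension of the attention mask has to match.</p>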
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">125</span> <span class="k">def</span> <span class="nf">concat_memory</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">z</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mem</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">c_mem</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>If there is no memory, just return the token embeddings</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">134</span> <span class="k">if</span> <span class="n">mem</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">135</span> <span class="k">return</span> <span class="n">z</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>If there is compressed memory, prepend it to the memory</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">138</span> <span class="k">if</span> <span class="n">c_mem</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">139</span> <span class="n">mem</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">c_mem</span><span class="p">,</span> <span class="n">mem</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Run the memory through the normalization layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">142</span> <span class="n">mem</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">(</span><span class="n">mem</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Concatenate normalized memory and normalized token embeddings</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">144</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">mem</span><span class="p">,</span> <span class="n">z</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<ul>
<li><code>x</code> is a tensor of token level feature vectors of shape <code>[seq_len, batch_size, d_model]</code></li>
<li><code>mem</code> is a tensor of the past token level feature vectors (memory) of shape <code>[mem_len, batch_size, d_model]</code></li>
<li><code>c_mem</code> is a tensor of the compressed memory of shape <code>[c_mem_len, batch_size, d_model]</code></li>
<li><code>mask</code> is a matrix of shape <code>[seq_len, c_mem_len + mem_len + seq_len, batch_size]</code> or <code>[seq_len, c_mem_len + mem_len + seq_len, 1]</code>.
<code>mask[i, j]</code> is true if token at <code>i</code> can see token at <code>j</code>.</li>
</ul>
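<p>One way to build a mask of the documented shape, as a plain-Python sketch: every query token may see all compressed and uncompressed memories, plus current tokens up to and including its own position. The function name is hypothetical and the trailing batch dimension is left out for brevity; the experiment code may construct its mask differently.</p>

```python
from typing import List


# Hypothetical helper for illustration; not part of labml_nn.
def causal_memory_mask(seq_len: int, mem_len: int, c_mem_len: int) -> List[List[bool]]:
    """mask[i][j] is True if query token i may attend to key position j.

    Keys are ordered [c_mem, mem, current tokens], so the key length is
    c_mem_len + mem_len + seq_len.
    """
    m_len = c_mem_len + mem_len
    return [[j < m_len + i + 1 for j in range(m_len + seq_len)]
            for i in range(seq_len)]
```
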
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">146</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">147</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">148</span> <span class="n">mem</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="lineno">149</span> <span class="n">c_mem</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="lineno">150</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Normalize the vectors before doing self attention</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">160</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Normalize and concatenate memory and compressed memory</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span> <span class="n">m_z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">concat_memory</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">mem</span><span class="p">,</span> <span class="n">c_mem</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Attention</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">164</span> <span class="n">self_attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">self_attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">m_z</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">m_z</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Add the attention results</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">166</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">self_attn</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Normalize for feed-forward</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">169</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Pass through the feed-forward network</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">171</span> <span class="n">ff</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Add the feed-forward results back</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">173</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">ff</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span> <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<h2>Compressive Transformer Model</h2>
<p>This consists of multiple compressive transformer layers</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">179</span><span class="k">class</span> <span class="nc">CompressiveTransformer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">186</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">CompressiveTransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">187</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Make copies of the transformer layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">189</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Final normalization layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<ul>
<li><code>x</code> is a tensor of the token embedding vectors of shape <code>[seq_len, batch_size, d_model]</code></li>
<li><code>mem</code> is a list of tensors of the past token level feature vectors of shape
<code>[mem_len, batch_size, d_model]</code> for each layer</li>
<li><code>c_mem</code> is a list of tensors of the compressed memory
<code>[c_mem_len, batch_size, d_model]</code> for each layer</li>
<li><code>mask</code> is the masking matrix</li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">193</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mem</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">c_mem</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>List to store token level feature vectors,
which will become the memories for the next sequential batch.</p>
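<p>Note that each layer's memory is the <em>input</em> to that layer (the output of the layer below), recorded before the layer runs. A plain-Python sketch with stand-in layers (hypothetical functions on scalars) shows the bookkeeping:</p>

```python
from typing import Callable, List, Tuple

# Hypothetical helper for illustration; not part of labml_nn.
Layer = Callable[[float], float]


def run_layers(x: float, layers: List[Layer]) -> Tuple[float, List[float]]:
    """Return the final output and the per-layer memories for the next segment."""
    new_mem = []
    for layer in layers:
        # Store the layer's input; the real code calls x.detach() so gradients
        # do not flow into the next segment through the memory.
        new_mem.append(x)
        x = layer(x)
    return x, new_mem
```
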
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">204</span> <span class="n">new_mem</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Run through each transformer layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">206</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Add to the list of feature vectors</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">208</span> <span class="n">new_mem</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Memory</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">210</span> <span class="n">m</span> <span class="o">=</span> <span class="n">mem</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">if</span> <span class="n">mem</span> <span class="k">else</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Compressed Memory</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">212</span> <span class="n">cm</span> <span class="o">=</span> <span class="n">c_mem</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">if</span> <span class="n">c_mem</span> <span class="k">else</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Run through the Compressive Transformer layer</p>
</div>
<div class='code'>
<div class="highlight"><pre>            <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mem</span><span class="o">=</span><span class="n">m</span><span class="p">,</span> <span class="n">c_mem</span><span class="o">=</span><span class="n">cm</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Finally, normalize the vectors</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">new_mem</span></pre></div>
</div>
</div>
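<div class='section'>
<div class='docs'>
<p>A minimal PyTorch sketch of the memory bookkeeping in the loop above: each layer's input is stored, detached, as that layer's new memory. <code>ToyLayer</code> and <code>run_layers</code> are hypothetical stand-ins (the toy layer is a single <code>nn.Linear</code> that ignores its memories), not the Compressive Transformer layer itself. After one call, <code>new_mem</code> holds one tensor per layer, each shaped <code>[seq_len, batch_size, d_model]</code> and carrying no gradient history.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
from typing import List, Optional

import torch
import torch.nn as nn


class ToyLayer(nn.Module):
    """Hypothetical stand-in for a Compressive Transformer layer.

    It ignores both memories and applies a single linear map, which is
    enough to show how the new memories are collected."""

    def __init__(self, d_model: int):
        super().__init__()
        self.linear = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor, mem: Optional[torch.Tensor],
                c_mem: Optional[torch.Tensor]) -> torch.Tensor:
        return self.linear(x)


def run_layers(layers: nn.ModuleList, x: torch.Tensor,
               mem: List[torch.Tensor], c_mem: List[torch.Tensor]):
    new_mem = []
    for i, layer in enumerate(layers):
        # The input to layer i, detached, becomes layer i's memory
        new_mem.append(x.detach())
        # Empty lists (e.g. the first segment) mean there is no memory yet
        m = mem[i] if mem else None
        cm = c_mem[i] if c_mem else None
        x = layer(x, m, cm)
    return x, new_mem


torch.manual_seed(0)
seq_len, batch_size, d_model = 4, 2, 8
layers = nn.ModuleList([ToyLayer(d_model) for _ in range(3)])
x = torch.randn(seq_len, batch_size, d_model, requires_grad=True)
out, new_mem = run_layers(layers, x, [], [])
```
</pre></div>
</div>
</div>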
<div class='section' id='section-37'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<h2>Attention Reconstruction Loss</h2>
<p>The attention reconstruction loss computes the self-attention output twice, once attending to the
uncompressed memory and once attending to the compressed memory, and takes the mean squared error
between the two results. Positional encodings are not used in this computation.</p>
<p>When the compression function $f_c$ is trained with the attention
reconstruction loss, all parameters other than those of $f_c$ are frozen.
This includes the query, key and value projections and the shift and scale parameters of layer normalization.</p>
<p>Since this loss can be computed independently of the model's cross-entropy loss,
you could use a separate optimizer that only updates $f_c$.
Here, however, the same optimizer updates $f_c$ and the rest of the model, so when calculating the
attention reconstruction loss we detach every parameter except those of $f_c$
from the gradient computation.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="k">class</span> <span class="nc">AttentionReconstructionLoss</span><span class="p">:</span></pre></div>
</div>
</div>
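<div class='section'>
<div class='docs'>
<p>A minimal PyTorch sketch of the single-optimizer approach described above, on toy single-head tensors without a batch dimension. The projection <code>proj</code> (standing in for the frozen projections) and the compression function <code>f_c</code> (assumed here to be an <code>nn.Conv1d</code> with kernel size and stride equal to the compression rate) are hypothetical. Both modules are registered with one optimizer, yet the reconstruction loss only produces gradients for <code>f_c</code> because <code>proj</code> is applied through detached copies of its parameters, so <code>optimizer.step()</code> leaves <code>proj</code> unchanged.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, mem_len, d_model, rate = 4, 6, 8, 3

# Hypothetical stand-ins: a projection (shared by keys and values) that must
# stay frozen for this loss, and a compression function f_c
proj = nn.Linear(d_model, d_model)
f_c = nn.Conv1d(d_model, d_model, kernel_size=rate, stride=rate)

# One optimizer over every parameter, as described in the text
optimizer = torch.optim.SGD(list(proj.parameters()) + list(f_c.parameters()), lr=0.1)


def attend(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
    """Single-head attention that uses the projection only through
    detached copies of its parameters."""
    w, b = proj.weight.detach(), proj.bias.detach()
    k = F.linear(kv, w, b)
    v = F.linear(kv, w, b)
    scores = q @ k.t() / d_model ** 0.5  # [seq_len, kv_len]
    return scores.softmax(dim=-1) @ v  # [seq_len, d_model]


h = torch.randn(seq_len, d_model)  # token embeddings (no gradient)
mem = torch.randn(mem_len, d_model)  # uncompressed memory (no gradient)
# Conv1d expects [batch, channels, length], so move d_model to the channel axis
c_mem = f_c(mem.t().unsqueeze(0)).squeeze(0).t()  # [mem_len // rate, d_model]

loss = F.mse_loss(attend(h, c_mem), attend(h, mem))

w_before = proj.weight.detach().clone()
optimizer.zero_grad()
loss.backward()
# SGD skips parameters whose .grad is None, so proj is not updated
optimizer.step()
```
</pre></div>
</div>
</div>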
<div class='section' id='section-38'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p><code>layers</code> is the list of Compressive Transformer layers</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layers</span><span class="p">:</span> <span class="n">TypedModuleList</span><span class="p">[</span><span class="n">CompressiveTransformerLayer</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">layers</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MSELoss</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<p>This is a reimplementation of <a href="../mha.html#PrepareMHA">&lsquo;PrepareForMultiHeadAttention&rsquo;</a>
where the projections are done with the parameters detached from gradient computation.</p>
<ul>
<li><code>pmha</code> is the <a href="../mha.html#PrepareMHA">&lsquo;PrepareForMultiHeadAttention&rsquo;</a> module</li>
<li><code>x</code> is the tensor with the token embeddings</li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="k">def</span> <span class="nf">prepare_for_attn</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pmha</span><span class="p">:</span> <span class="n">PrepareForMultiHeadAttention</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Shape of the input except embedding dimension; <code>[seq_len, batch_size]</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="n">head_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Detach projection weights and bias</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="n">weight</span> <span class="o">=</span> <span class="n">pmha</span><span class="o">.</span><span class="n">linear</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
        <span class="n">bias</span> <span class="o">=</span> <span class="n">pmha</span><span class="o">.</span><span class="n">linear</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span> <span class="k">if</span> <span class="n">pmha</span><span class="o">.</span><span class="n">linear</span><span class="o">.</span><span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>Linear transform</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Split last dimension into heads</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">head_shape</span><span class="p">,</span> <span class="n">pmha</span><span class="o">.</span><span class="n">heads</span><span class="p">,</span> <span class="n">pmha</span><span class="o">.</span><span class="n">d_k</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Output has shape <code>[seq_len, batch_size, heads, d_k]</code> or <code>[batch_size, heads, d_k]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="k">return</span> <span class="n">x</span></pre></div>
</div>
</div>
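<div class='section'>
<div class='docs'>
<p>To check that the detached projection matches the regular module, here is a PyTorch sketch with a hypothetical <code>ToyPrepareMHA</code> class that only mirrors the attributes <code>prepare_for_attn</code> reads (<code>linear</code>, <code>heads</code>, <code>d_k</code>); it is not the labml module. The standalone <code>prepare_for_attn</code> below follows the method above. Its output equals the module's output, but it is not connected to the projection parameters in the autograd graph.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyPrepareMHA(nn.Module):
    """Hypothetical module with the attributes used by prepare_for_attn."""

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool = True):
        super().__init__()
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        self.heads = heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        head_shape = x.shape[:-1]
        return self.linear(x).view(*head_shape, self.heads, self.d_k)


def prepare_for_attn(pmha: ToyPrepareMHA, x: torch.Tensor) -> torch.Tensor:
    head_shape = x.shape[:-1]
    # Detached copies, so no gradient reaches the projection parameters
    weight = pmha.linear.weight.detach()
    bias = pmha.linear.bias.detach() if pmha.linear.bias is not None else None
    x = F.linear(x, weight, bias)
    # Split the last dimension into heads
    return x.view(*head_shape, pmha.heads, pmha.d_k)


torch.manual_seed(0)
pmha = ToyPrepareMHA(d_model=16, heads=4, d_k=4)
x = torch.randn(5, 2, 16)  # [seq_len, batch_size, d_model]
out = prepare_for_attn(pmha, x)  # [seq_len, batch_size, heads, d_k]
```
</pre></div>
</div>
</div>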
<div class='section' id='section-46'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
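<p>This is a reimplementation of <a href="../mha.html#MHA">&lsquo;Multi-Head Attention&rsquo;</a> which calls
<code>prepare_for_attn</code> instead of <a href="../mha.html#PrepareMHA">&lsquo;PrepareForMultiHeadAttention&rsquo;</a>
to detach the projection parameters.
It only uses the content scores $Q K^\top$, without the relative positional terms of
<code>RelativeMultiHeadAttention</code> and without a mask.</p>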
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="k">def</span> <span class="nf">attn</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">RelativeMultiHeadAttention</span><span class="p">,</span> <span class="n">query</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
<p>Calculate query, key and value projections</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="n">query</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_for_attn</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">query</span><span class="p">,</span> <span class="n">query</span><span class="p">)</span>
        <span class="n">key</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_for_attn</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">key</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span>
        <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_for_attn</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">value</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<p>Compute attention scores $Q K^\top$.
This gives a tensor of shape <code>[seq_len, key_len, batch_size, heads]</code>, where <code>key_len</code> is the length of the memory being attended to.</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="n">scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;ibhd,jbhd-&gt;ijbh&#39;</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="n">scores</span> <span class="o">*=</span> <span class="n">layer</span><span class="o">.</span><span class="n">scale</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>$softmax$ attention along the key sequence dimension
$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="n">attn</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">scores</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p>Multiply by values
<script type="math/tex; mode=display">\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">&quot;ijbh,jbhd-&gt;ibhd&quot;</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span></pre></div>
</div>
</div>
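<div class='section'>
<div class='docs'>
<p>The einsum steps can be traced on random tensors. This PyTorch sketch uses arbitrary sizes with a key sequence shorter than the query sequence (as when attending to compressed memory), applies the softmax over the key index <code>j</code> (dimension 1 of <code>scores</code>; we assume this is what <code>layer.softmax</code> does), and compares the result against an equivalent batched matrix product.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch

torch.manual_seed(0)
q_len, k_len, batch_size, heads, d_k = 5, 3, 2, 4, 8
query = torch.randn(q_len, batch_size, heads, d_k)
key = torch.randn(k_len, batch_size, heads, d_k)
value = torch.randn(k_len, batch_size, heads, d_k)

# Scores for every (query i, key j) pair: [q_len, k_len, batch_size, heads]
scores = torch.einsum('ibhd,jbhd->ijbh', query, key)
scores = scores * d_k ** -0.5  # scale by 1 / sqrt(d_k)

# Softmax over the key dimension j
attn = torch.softmax(scores, dim=1)

# Weighted sum of values: [q_len, batch_size, heads, d_k]
out = torch.einsum('ijbh,jbhd->ibhd', attn, value)

# Reference: move (batch, head) to the front and use plain matrix products
q = query.permute(1, 2, 0, 3)  # [batch_size, heads, q_len, d_k]
k = key.permute(1, 2, 0, 3)    # [batch_size, heads, k_len, d_k]
v = value.permute(1, 2, 0, 3)  # [batch_size, heads, k_len, d_k]
ref = torch.softmax(q @ k.transpose(-2, -1) * d_k ** -0.5, dim=-1) @ v
ref = ref.permute(2, 0, 1, 3)  # back to [q_len, batch_size, heads, d_k]
```
</pre></div>
</div>
</div>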
<div class='section' id='section-52'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
<p>Perform layer normalization with shift and scale parameters detached.</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="k">def</span> <span class="nf">norm</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ln</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>Detach the shift (<code>bias</code>) and scaling (<code>weight</code>) parameters</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="n">weight</span> <span class="o">=</span> <span class="n">ln</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span> <span class="k">if</span> <span class="n">ln</span><span class="o">.</span><span class="n">weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="kc">None</span>
        <span class="n">bias</span> <span class="o">=</span> <span class="n">ln</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span> <span class="k">if</span> <span class="n">ln</span><span class="o">.</span><span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="kc">None</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
<p>Layer normalization</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="k">return</span> <span class="n">F</span><span class="o">.</span><span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">ln</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">ln</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span></pre></div>
</div>
</div>
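<div class='section'>
<div class='docs'>
<p>A quick PyTorch check of the detached layer normalization: <code>F.layer_norm</code> with detached <code>weight</code> and <code>bias</code> returns the same values as calling the <code>nn.LayerNorm</code> module, but a backward pass leaves the module's parameters without gradients while gradients still reach the input. The sizes are arbitrary, and the function also handles a layer norm created with <code>elementwise_affine=False</code>, whose <code>weight</code> and <code>bias</code> are <code>None</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def norm(ln: nn.LayerNorm, x: torch.Tensor) -> torch.Tensor:
    # Detached copies of the scale (weight) and shift (bias) parameters
    weight = ln.weight.detach() if ln.weight is not None else None
    bias = ln.bias.detach() if ln.bias is not None else None
    return F.layer_norm(x, ln.normalized_shape, weight, bias, ln.eps)


torch.manual_seed(0)
ln = nn.LayerNorm(8)
# Move the parameters away from their defaults so the check is meaningful
with torch.no_grad():
    ln.weight.uniform_(0.5, 1.5)
    ln.bias.uniform_(-0.5, 0.5)

x = torch.randn(4, 2, 8, requires_grad=True)
y = norm(ln, x)
# Any scalar function of y works here; a random weighting avoids trivial zeros
(y * torch.randn_like(y)).sum().backward()
```
</pre></div>
</div>
</div>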
<div class='section' id='section-55'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
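<p>Calculate the attention reconstruction loss for a single layer, where <code>h</code> is the input to the layer and <code>mem</code> is the layer&rsquo;s uncompressed memory</p>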
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="k">def</span> <span class="nf">calc_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">CompressiveTransformerLayer</span><span class="p">,</span> <span class="n">h</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mem</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
<p>Detach the token embeddings and memory.</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="n">h</span> <span class="o">=</span> <span class="n">h</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
        <span class="n">mem</span> <span class="o">=</span> <span class="n">mem</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<p>Compress the memory with $f_c^{(i)}$.
The parameters of $f_c^{(i)}$ are the only parameters not detached from gradient computation.</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="n">c_mem</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">compress</span><span class="p">(</span><span class="n">mem</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
<p>Normalize the embeddings and memories</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span>
        <span class="n">mem</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">,</span> <span class="n">mem</span><span class="p">)</span>
        <span class="n">c_mem</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">,</span> <span class="n">c_mem</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
<p>Calculate the attention with uncompressed memory</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="n">attn_mem</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">self_attn</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">mem</span><span class="p">,</span> <span class="n">mem</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
<p>Calculate the attention with compressed memory</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="n">attn_cmem</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">self_attn</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">c_mem</span><span class="p">,</span> <span class="n">c_mem</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
<p>Calculate the mean squared error between the two attention outputs</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">attn_cmem</span><span class="p">,</span> <span class="n">attn_mem</span><span class="p">)</span></pre></div>
</div>
</div>
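<div class='section'>
<div class='docs'>
<p>Putting the pieces together, here is a single-head sketch of the per-layer loss without a batch dimension, assuming PyTorch. <code>ln</code>, <code>w_q</code>, <code>w_k</code>, <code>w_v</code> and <code>f_c</code> are hypothetical stand-ins for the layer's normalization, projections and compression function (the latter assumed to be an <code>nn.Conv1d</code> with kernel size and stride equal to the compression rate). The steps follow <code>calc_loss</code>: detach the inputs, compress, normalize with the shared detached layer norm, attend with detached projections, and take the mean squared error. The short loop then trains <code>f_c</code> alone with its own optimizer, the separate-optimizer alternative mentioned at the start of this section; the frozen parameters never receive gradients.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, mem_len, d_model, rate = 4, 6, 8, 3

# Hypothetical single-head layer parts; only f_c is trained here
ln = nn.LayerNorm(d_model)
w_q, w_k, w_v = (nn.Linear(d_model, d_model) for _ in range(3))
f_c = nn.Conv1d(d_model, d_model, kernel_size=rate, stride=rate)


def detached_linear(lin: nn.Linear, x: torch.Tensor) -> torch.Tensor:
    return F.linear(x, lin.weight.detach(), lin.bias.detach())


def detached_norm(x: torch.Tensor) -> torch.Tensor:
    return F.layer_norm(x, ln.normalized_shape, ln.weight.detach(), ln.bias.detach(), ln.eps)


def attn(h: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
    q, k, v = detached_linear(w_q, h), detached_linear(w_k, kv), detached_linear(w_v, kv)
    scores = q @ k.t() / d_model ** 0.5  # [seq_len, kv_len]
    return scores.softmax(dim=-1) @ v  # [seq_len, d_model]


def calc_loss(h: torch.Tensor, mem: torch.Tensor) -> torch.Tensor:
    h, mem = h.detach(), mem.detach()
    # Compress [mem_len, d_model] to [mem_len // rate, d_model]
    c_mem = f_c(mem.t().unsqueeze(0)).squeeze(0).t()
    h, mem, c_mem = detached_norm(h), detached_norm(mem), detached_norm(c_mem)
    return F.mse_loss(attn(h, c_mem), attn(h, mem))


h = torch.randn(seq_len, d_model)
mem = torch.randn(mem_len, d_model)
optimizer = torch.optim.Adam(f_c.parameters(), lr=1e-2)
losses = []
for _ in range(100):
    loss = calc_loss(h, mem)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```
</pre></div>
</div>
</div>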
<div class='section' id='section-62'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">h</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">mem</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='docs'>
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
<p>Calculate the losses for each layer</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="n">losses</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">calc_loss</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">h</span><span class="p">[</span><span class="n">n</span><span class="p">],</span> <span class="n">mem</span><span class="p">[</span><span class="n">n</span><span class="p">])</span> <span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-64'>
<div class='docs'>
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
<p>Sum of the losses</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="k">return</span> <span class="nb">sum</span><span class="p">(</span><span class="n">losses</span><span class="p">)</span></pre></div>
</div>
</div>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
</body>
</html>