mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 17:41:37 +08:00
134 lines
13 KiB
HTML
134 lines
13 KiB
HTML
<!DOCTYPE html>
|
||
<html>
|
||
<head>
|
||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||
<meta name="description" content=""/>
|
||
|
||
<meta name="twitter:card" content="summary"/>
|
||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||
<meta name="twitter:title" content="Compressive Transformer"/>
|
||
<meta name="twitter:description" content=""/>
|
||
<meta name="twitter:site" content="@labmlai"/>
|
||
<meta name="twitter:creator" content="@labmlai"/>
|
||
|
||
<meta property="og:url" content="https://nn.labml.ai/transformers/compressive/readme.html"/>
|
||
<meta property="og:title" content="Compressive Transformer"/>
|
||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||
<meta property="og:type" content="object"/>
|
||
<meta property="og:title" content="Compressive Transformer"/>
|
||
<meta property="og:description" content=""/>
|
||
|
||
<title>Compressive Transformer</title>
|
||
<link rel="shortcut icon" href="/icon.png"/>
|
||
<link rel="stylesheet" href="../../pylit.css">
|
||
<link rel="canonical" href="https://nn.labml.ai/transformers/compressive/readme.html"/>
|
||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
|
||
|
||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||
<script>
|
||
window.dataLayer = window.dataLayer || [];
|
||
|
||
function gtag() {
|
||
dataLayer.push(arguments);
|
||
}
|
||
|
||
gtag('js', new Date());
|
||
|
||
gtag('config', 'G-4V3HC8HBLH');
|
||
</script>
|
||
</head>
|
||
<body>
|
||
<div id='container'>
|
||
<div id="background"></div>
|
||
<div class='section'>
|
||
<div class='docs'>
|
||
<p>
|
||
<a class="parent" href="/">home</a>
|
||
<a class="parent" href="../index.html">transformers</a>
|
||
<a class="parent" href="index.html">compressive</a>
|
||
</p>
|
||
<p>
|
||
|
||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/compressive/readme.md">
|
||
<img alt="Github"
|
||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||
style="max-width:100%;"/></a>
|
||
<a href="https://twitter.com/labmlai"
|
||
rel="nofollow">
|
||
<img alt="Twitter"
|
||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||
style="max-width:100%;"/></a>
|
||
</p>
|
||
</div>
|
||
</div>
|
||
<div class='section' id='section-0'>
|
||
<div class='docs'>
|
||
<div class='section-link'>
|
||
<a href='#section-0'>#</a>
|
||
</div>
|
||
<h1><a href="https://nn.labml.ai/transformers/compressive/index.html">Compressive Transformer</a></h1>
|
||
<p>This is an implementation of <a href="https://papers.labml.ai/paper/1911.05507">Compressive Transformers for Long-Range Sequence Modelling</a> in <a href="https://pytorch.org">PyTorch</a>.</p>
|
||
<p>This is an extension of <a href="https://nn.labml.ai/transformers/xl/index.html">Transformer XL</a> where past memories are compressed to give a longer attention range. That is, the furthest <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">n</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">c</span><span class="mord mathnormal mtight">m</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord mathnormal">c</span></span></span></span> memories are compressed into <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">n</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">c</span><span class="mord mathnormal mtight">m</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> memories, where <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal">c</span></span></span></span> is the compression rate.</p>
|
||
<h2>Compression operation</h2>
|
||
<p>The compression operation is defined as <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em;">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">c</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">:</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.8491079999999999em;vertical-align:0em;"></span><span class="mord"><span class="mord mathbb">R</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8491079999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">n</span><span class="mord mathnormal mtight">c</span><span class="mbin mtight">×</span><span class="mord mathnormal mtight">d</span></span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">→</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.8491079999999999em;vertical-align:0em;"></span><span class="mord"><span class="mord mathbb">R</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8491079999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">n</span><span class="mbin mtight">×</span><span class="mord mathnormal mtight">d</span></span></span></span></span></span></span></span></span></span></span></span>. The paper introduces multiple choices for <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em;">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">c</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> and we have only implemented 1D convolution which seems to give the best results. Each layer has a separate compression operation <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.23924em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em;">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0448em;"><span style="top:-2.5834080000000004em;margin-left:-0.10764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">c</span></span></span><span style="top:-3.2198em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mathnormal mtight">i</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.11659199999999997em;"><span></span></span></span></span></span></span></span></span></span> where <span class="katex"><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.65952em;vertical-align:0em;"></span><span class="mord mathnormal">i</span></span></span></span> is the layer number.</p>
|
||
<h2>Training compression operation</h2>
|
||
<p>Since training compression with BPTT requires maintaining a very large computational graph (many time steps), the paper proposes an <em>auto-encoding loss</em> and an <em>attention reconstruction loss</em>. The auto-encoding loss decodes the original memories from the compressed memories and calculates the loss. Attention reconstruction loss computes the multi-headed attention results on the compressed memory and on uncompressed memory and gets a mean squared error between them. We have implemented the latter here since it gives better results.</p>
|
||
<p>This implementation uses pre-layer normalization while the paper uses post-layer normalization. Pre-layer norm does the layer norm before <a href="../feedforward.html">FFN</a> and self-attention, and the pass-through in the residual connection is not normalized. This is supposed to be more stable in standard transformer setups.</p>
|
||
<p>Here are <a href="https://nn.labml.ai/transformers/compressive/experiment.html">the training code</a> and a notebook for training a compressive transformer model on the Tiny Shakespeare dataset.</p>
|
||
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/compressive/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://app.labml.ai/run/0d9b5338726c11ebb7c80242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen"></a> </p>
|
||
|
||
</div>
|
||
<div class='code'>
|
||
|
||
</div>
|
||
</div>
|
||
<div class='footer'>
|
||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||
<a href="https://labml.ai">labml.ai</a>
|
||
</div>
|
||
</div>
|
||
<script>
|
||
function handleImages() {
|
||
var images = document.querySelectorAll('p>img')
|
||
|
||
console.log(images);
|
||
for (var i = 0; i < images.length; ++i) {
|
||
handleImage(images[i])
|
||
}
|
||
}
|
||
|
||
function handleImage(img) {
|
||
img.parentElement.style.textAlign = 'center'
|
||
|
||
var modal = document.createElement('div')
|
||
modal.id = 'modal'
|
||
|
||
var modalContent = document.createElement('div')
|
||
modal.appendChild(modalContent)
|
||
|
||
var modalImage = document.createElement('img')
|
||
modalContent.appendChild(modalImage)
|
||
|
||
var span = document.createElement('span')
|
||
span.classList.add('close')
|
||
span.textContent = 'x'
|
||
modal.appendChild(span)
|
||
|
||
img.onclick = function () {
|
||
console.log('clicked')
|
||
document.body.appendChild(modal)
|
||
modalImage.src = img.src
|
||
}
|
||
|
||
span.onclick = function () {
|
||
document.body.removeChild(modal)
|
||
}
|
||
}
|
||
|
||
handleImages()
|
||
</script>
|
||
</body>
|
||
</html> |