<!DOCTYPE html> <html lang="en"> <head> <meta http-equiv="content-type" content="text/html;charset=utf-8"/> <meta name="viewport" content="width=device-width, initial-scale=1.0"/> <meta name="description" content=""/> <meta name="twitter:card" content="summary"/> <meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/> <meta name="twitter:title" content="Compressive Transformer"/> <meta name="twitter:description" content=""/> <meta name="twitter:site" content="@labmlai"/> <meta name="twitter:creator" content="@labmlai"/> <meta property="og:url" content="https://nn.labml.ai/transformers/compressive/readme.html"/> <meta property="og:title" content="Compressive Transformer"/> <meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/> <meta property="og:site_name" content="Compressive Transformer"/> <meta property="og:type" content="object"/> <meta property="og:description" content=""/> <title>Compressive Transformer</title> <link rel="shortcut icon" href="/icon.png"/> <link rel="stylesheet" href="../../pylit.css?v=1"> <link rel="canonical" href="https://nn.labml.ai/transformers/compressive/readme.html"/> <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous"> <!-- Global site tag (gtag.js) - Google Analytics --> <script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script> <script> window.dataLayer = window.dataLayer || []; function gtag() { dataLayer.push(arguments); } gtag('js', new Date()); gtag('config', 'G-4V3HC8HBLH'); </script> </head> <body> <div id='container'> <div id="background"></div> <div class='section'> <div class='docs'> <p> <a 
href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank"> <img alt="Github" src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social" style="max-width:100%;"/></a> <a href="https://twitter.com/labmlai" rel="nofollow" target="_blank"> <img alt="Twitter" src="https://img.shields.io/twitter/follow/labmlai?style=social" style="max-width:100%;"/></a> </p> <p> <a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/compressive/readme.md" target="_blank"> View code on Github</a> </p> </div> </div> <div class='section' id='section-0'> <div class='docs'> <div class='section-link'> <a href='#section-0'>#</a> </div> <h1><a href="https://nn.labml.ai/transformers/compressive/index.html">Compressive Transformer</a></h1>
<p>This is an implementation of <a href="https://papers.labml.ai/paper/1911.05507">Compressive Transformers for Long-Range Sequence Modelling</a> in <a href="https://pytorch.org">PyTorch</a>.</p>
<p>This is an extension of <a href="https://nn.labml.ai/transformers/xl/index.html">Transformer XL</a> where past memories are compressed to give a longer attention range. That is, the furthest <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mord mathnormal" style="">n</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqf" style="">c</span></span><span class="mord mathnormal mtight" style="">m</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mord coloredeq eqf" style=""><span class="mord mathnormal" style="">c</span></span></span></span></span></span> memories are compressed into <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.58056em;vertical-align:-0.15em;"></span><span class="mord coloredeq eqd" style=""><span class="mord" style=""><span class="mord mathnormal" style="">n</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqf" style="">c</span></span><span class="mord mathnormal mtight" style="">m</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> memories, where <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqf" style=""><span class="mord mathnormal" style="">c</span></span></span></span></span></span> is the compression rate.</p>
<h2>Compression operation</h2>
<p>The compression operation is defined as <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqe" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqf" style="">c</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">:</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.8491079999999999em;vertical-align:0em;"></span><span class="mord"><span class="mord mathbb">R</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8491079999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">n</span><span class="mord mtight coloredeq eqf" style=""><span class="mord mathnormal mtight" style="">c</span></span><span class="mbin mtight">×</span><span class="mord mathnormal mtight">d</span></span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">→</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.8491079999999999em;vertical-align:0em;"></span><span class="mord"><span class="mord mathbb">R</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8491079999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">n</span><span class="mbin mtight">×</span><span class="mord mathnormal mtight">d</span></span></span></span></span></span></span></span></span></span></span></span></span>. The paper introduces multiple choices for <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqe" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqf" style="">c</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span>, and we have only implemented 1D convolution, which seems to give the best results. Each layer has a separate compression operation <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.16678em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord coloredeq eqe" style=""><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:-0.10764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqf" style="">c</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.97234em;"><span style="top:-3.14734em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mtight coloredeq eqg" style=""><span class="mord mathnormal mtight" style="">i</span></span><span class="mclose mtight">)</span></span></span></span></span></span></span></span></span></span></span></span></span>, where <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.65952em;vertical-align:0em;"></span><span class="mord coloredeq eqg" style=""><span class="mord mathnormal" style="">i</span></span></span></span></span></span> is the layer number.</p>
<h2>Training compression operation</h2>
<p>Training the compression with backpropagation through time (BPTT) would require maintaining a very large computational graph spanning many time steps, so the paper instead proposes an <em>auto-encoding loss</em> and an <em>attention reconstruction loss</em>. The auto-encoding loss decodes the original memories from the compressed memories and computes a reconstruction loss against the originals. The attention reconstruction loss computes the multi-headed attention results on the compressed memory and on the uncompressed memory, and takes the mean squared error between them. We have implemented the latter here since it gives better results.</p>
<p>This implementation uses pre-layer normalization, while the paper uses post-layer normalization. Pre-layer norm applies layer normalization before the <a href="../feedforward.html">FFN</a> and self-attention, and the pass-through in the residual connection is not normalized. This is expected to be more stable in standard transformer setups.</p>
<p>Here are <a href="https://nn.labml.ai/transformers/compressive/experiment.html">the training code</a> and a notebook for training a compressive transformer model on the Tiny Shakespeare dataset.</p>
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/compressive/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></p> </div> <div class='code'> </div> </div> <div class='footer'> 
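<p>The 1D-convolution compression f<sub>c</sub> can be illustrated with a small NumPy sketch. It assumes the memory of one layer is a single <code>(n * c, d)</code> array with no batch dimension, and that the convolution uses kernel size and stride both equal to the compression rate <code>c</code> (in PyTorch terms, <code>nn.Conv1d(d, d, kernel_size=c, stride=c)</code>). With that choice every window of <code>c</code> consecutive memories is read exactly once, so the convolution reduces to a reshape followed by a contraction. The helper name <code>conv1d_compress</code> is hypothetical and not part of the labml code.</p>

```python
import numpy as np


def conv1d_compress(mem: np.ndarray, weight: np.ndarray, bias: np.ndarray, c: int) -> np.ndarray:
    """Sketch of f_c: R^{nc x d} -> R^{n x d} as a 1D convolution.

    Hypothetical helper (not the labml API). Assumes kernel size == stride == c.
    mem:    (n * c, d) memories, oldest first
    weight: (d_out, d_in, c), the same layout as a PyTorch Conv1d kernel
    bias:   (d_out,)
    """
    nc, d = mem.shape
    if nc % c != 0:
        raise ValueError("memory length must be a multiple of the compression rate c")
    n = nc // c
    # Split the memories into n non-overlapping windows of c time steps: (n, c, d)
    windows = mem.reshape(n, c, d)
    # out[j, o] = bias[o] + sum over k (step in window) and i (input channel)
    #             of weight[o, i, k] * windows[j, k, i]
    return np.einsum("jki,oik->jo", windows, weight) + bias


# Usage: compress 12 memories of width 8 into 4 with compression rate c = 3.
rng = np.random.default_rng(0)
n, c, d = 4, 3, 8
mem = rng.standard_normal((n * c, d))
weight = rng.standard_normal((d, d, c)) / np.sqrt(c * d)
bias = np.zeros(d)
compressed = conv1d_compress(mem, weight, bias, c)
print(compressed.shape)  # (4, 8)
```

<p>Each layer <em>i</em> would keep its own <code>weight</code> and <code>bias</code>, which gives the separate per-layer compression operation described in the text.</p>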
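<p>The statement that the furthest n<sub>cm</sub>&nbsp;c memories end up as n<sub>cm</sub> compressed memories can be made concrete with a minimal NumPy sketch of the memory bookkeeping. The helper <code>update_memories</code>, the capacities <code>n_m</code> (uncompressed) and <code>n_cm</code> (compressed), and rounding the overflow up to a multiple of <code>c</code> are assumptions made for illustration, not the labml implementation. Mean pooling stands in for the learned 1D convolution only to keep the example short and easy to trace.</p>

```python
from functools import partial

import numpy as np


def mean_pool_compress(x: np.ndarray, c: int) -> np.ndarray:
    """Stand-in for f_c: average each window of c consecutive memories."""
    return x.reshape(-1, c, x.shape[-1]).mean(axis=1)


def update_memories(mem, c_mem, new, n_m, n_cm, c, compress):
    """Hypothetical FIFO update of the memory and the compressed memory.

    mem:   (<= n_m, d) uncompressed memories, oldest first
    c_mem: (<= n_cm, d) compressed memories, oldest first
    new:   (t, d) hidden states from the latest segment
    """
    assert n_cm > 0 and n_m >= c
    buf = np.concatenate([mem, new], axis=0)
    overflow = buf.shape[0] - n_m
    if overflow <= 0:
        return buf, c_mem
    # Evict a whole number of windows of size c (round the overflow up).
    n_old = -(-overflow // c) * c
    oldest, buf = buf[:n_old], buf[n_old:]
    # Compress the evicted memories; keep only the newest n_cm compressed slots.
    c_mem = np.concatenate([c_mem, compress(oldest)], axis=0)[-n_cm:]
    return buf, c_mem


def steps(start: int, k: int, d: int) -> np.ndarray:
    """k rows of width d whose entries equal their time index (easy to trace)."""
    return np.repeat(np.arange(start, start + k, dtype=float)[:, None], d, axis=1)


d, n_m, n_cm, c = 2, 6, 4, 2
mem, c_mem = np.zeros((0, d)), np.zeros((0, d))
compress = partial(mean_pool_compress, c=c)
for start, k in [(0, 4), (4, 4), (8, 3)]:
    mem, c_mem = update_memories(mem, c_mem, steps(start, k, d), n_m, n_cm, c, compress)
print(mem[:, 0].tolist())    # [6.0, 7.0, 8.0, 9.0, 10.0]
print(c_mem[:, 0].tolist())  # [0.5, 2.5, 4.5]
```
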
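<p>The attention reconstruction loss can be sketched in NumPy as the mean squared error between attention over the uncompressed memories and attention over their compressed version, using the same queries. For brevity this uses a single attention head with hypothetical projection matrices <code>w_q</code>, <code>w_k</code>, <code>w_v</code>, and the function names are illustrative; it is not the labml implementation. In a real training loop this loss is meant to train the compression function, so gradients into the attention weights and hidden states would typically be stopped; NumPy has no gradients, so that detail does not appear here.</p>

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attend(h, mem, w_q, w_k, w_v):
    """Single-head scaled dot-product attention from queries h over memory mem."""
    q, k, v = h @ w_q, mem @ w_k, mem @ w_v
    scores = q @ k.T / np.sqrt(q.shape[-1])
    return softmax(scores, axis=-1) @ v


def attention_reconstruction_loss(h, mem, c_mem, w_q, w_k, w_v):
    """MSE between attention over the original and over the compressed memories."""
    target = attend(h, mem, w_q, w_k, w_v)   # what the layer sees without compression
    recon = attend(h, c_mem, w_q, w_k, w_v)  # what it sees after compression
    return float(np.mean((target - recon) ** 2))


# Usage with mean pooling (c = 2) standing in for the learned compression.
rng = np.random.default_rng(0)
d, c = 8, 2
h = rng.standard_normal((5, d))
mem = rng.standard_normal((6, d))
c_mem = mem.reshape(-1, c, d).mean(axis=1)
w_q, w_k, w_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
print(attention_reconstruction_loss(h, mem, c_mem, w_q, w_k, w_v))
```
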
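<p>The difference between pre- and post-layer normalization is where the normalization sits relative to the residual connection. The NumPy sketch below assumes a layer norm without learnable gain and bias, and a generic <code>sublayer</code> function standing in for self-attention or the FFN; the function names are illustrative only.</p>

```python
from typing import Callable

import numpy as np


def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Layer norm over the last axis, without learnable gain/bias (simplification)."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def pre_ln_block(x: np.ndarray, sublayer: Callable[[np.ndarray], np.ndarray]) -> np.ndarray:
    # Pre-LN: normalize only the sublayer input; the residual path carries x unchanged.
    return x + sublayer(layer_norm(x))


def post_ln_block(x: np.ndarray, sublayer: Callable[[np.ndarray], np.ndarray]) -> np.ndarray:
    # Post-LN: add the residual first, then normalize the sum.
    return layer_norm(x + sublayer(x))


# With a sublayer that outputs zeros, the difference is easy to see.
x = np.array([[1.0, 2.0, 3.0, 6.0]])
zero = np.zeros_like
print(pre_ln_block(x, zero))   # x passes through unnormalized
print(post_ln_block(x, zero))  # zero mean, unit variance up to eps
```
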
<a href="https://papers.labml.ai">Trending Research Papers</a> <a href="https://labml.ai">labml.ai</a> </div> </div> <script src="../../interactive.js?v=1"></script> </body> </html>