# [Feedback Transformer](https://nn.labml.ai/transformers/feedback/index.html)

This is a [PyTorch](https://pytorch.org) implementation of the paper
[Accessing Higher-level Representations in Sequential Transformers with Feedback Memory](https://papers.labml.ai/paper/2002.09402).

Normal transformers process tokens in parallel, and each transformer layer attends to the outputs of the previous layer. The Feedback Transformer instead attends to the outputs of all layers at previous steps. This adds recurrence, so tokens have to be processed one by one, which slows training down significantly (about 5x to 10x depending on the sequence length). At prediction time, however, the Feedback Transformer is faster, because you can predict the next token from the cached memory vectors.

To speed up training, the paper discusses starting with a short sequence length and gradually increasing it. It also discusses using a pretrained parallel transformer as the starting point.

The original Feedback Transformer doesn't keep the outputs of all layers. Instead, it keeps a weighted sum of the outputs of all layers, which reduces the memory used for caching during prediction. The first half of this file implements this.
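To make the recurrence and the weighted-sum memory concrete, here is a minimal sketch of the idea, not the implementation in this file: the module name, the linear-layer stand-ins for real transformer layers, and the shapes are illustrative assumptions.

```python
import torch
import torch.nn as nn


class FeedbackSketch(nn.Module):
    """Toy illustration: process tokens one by one and build a per-step memory
    vector as a softmax-weighted sum of that step's layer outputs."""

    def __init__(self, d_model: int, n_layers: int):
        super().__init__()
        # Stand-ins for real transformer layers (attention + feed-forward in the paper)
        self.layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_layers)])
        # One learnable weight per layer output (plus the input embedding)
        self.layer_weight = nn.Parameter(torch.zeros(n_layers + 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [seq_len, batch, d_model] -- processed token by token (the recurrence)
        memory, outputs = [], []
        for t in range(x.shape[0]):
            h = x[t]                      # [batch, d_model]
            layer_outputs = [h]
            for layer in self.layers:
                # A real layer attends over `memory`; here we just mix in its mean
                context = torch.stack(memory).mean(dim=0) if memory else torch.zeros_like(h)
                h = layer(h + context)
                layer_outputs.append(h)
            # The memory for this step is a weighted sum of all layer outputs
            weights = torch.softmax(self.layer_weight, dim=0)
            memory.append(sum(w * o for w, o in zip(weights, layer_outputs)))
            outputs.append(h)
        return torch.stack(outputs)       # [seq_len, batch, d_model]
```

Note that only one memory vector per step is stored, regardless of the number of layers, which is what keeps caching cheap at prediction time.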
The updated Feedback Transformer shares the weights used to calculate keys and values among the layers. We then calculate the keys and values for each step only once and keep them cached. The [second half](#shared_kv) of this file implements this. We implemented a custom PyTorch function to improve performance; a rough sketch of the key/value sharing idea follows the links below.

Here's [the training code](experiment.html) and a notebook for training a feedback transformer on the Tiny Shakespeare dataset.

[Colab Notebook](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/feedback/experiment.ipynb)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/feedback/experiment.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/d8eb9416530a11eb8fb50242ac1c0002)
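Below is a minimal sketch of the shared key/value caching idea, under assumed names and shapes; it is not the custom PyTorch function used in the implementation. Because every layer uses the same key and value projections, the keys and values for a step can be computed once from that step's memory vector and cached.

```python
import torch
import torch.nn as nn


class SharedKVCache(nn.Module):
    """Toy illustration: one key projection and one value projection shared by
    all layers, so keys/values are computed once per step and cached."""

    def __init__(self, d_model: int, heads: int, d_k: int):
        super().__init__()
        self.key = nn.Linear(d_model, heads * d_k)    # shared across layers
        self.value = nn.Linear(d_model, heads * d_k)  # shared across layers
        self.keys, self.values = [], []               # one cached entry per step

    def append(self, mem: torch.Tensor):
        # mem: [batch, d_model] -- the memory vector for the current step
        self.keys.append(self.key(mem))
        self.values.append(self.value(mem))

    def get(self):
        # [steps, batch, heads * d_k] each, ready for attention over past steps
        return torch.stack(self.keys), torch.stack(self.values)
```

Every layer at the current step can attend over these cached tensors directly, avoiding the per-layer key/value projections that a normal transformer would recompute.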