<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This is an annotated implementation/tutorial of a miniature version of the Switch Transformer in PyTorch."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Switch Transformer"/>
<meta name="twitter:description" content="This is an annotated implementation/tutorial of a miniature version of the Switch Transformer in PyTorch."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/transformers/switch/index.html"/>
<meta property="og:title" content="Switch Transformer"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Switch Transformer"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="This is an annotated implementation/tutorial of a miniature version of the Switch Transformer in PyTorch."/>
<title>Switch Transformer</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/transformers/switch/index.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/switch/__init__.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Switch Transformer</h1>
<p>This is a miniature <a href="https://pytorch.org">PyTorch</a> implementation of the paper <a href="https://arxiv.org/abs/2101.03961">Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity</a>. Our implementation has only a few million parameters and doesn&#x27;t do model-parallel distributed training. It trains on a single GPU, but it implements the switching concept described in the paper.</p>
<p>The Switch Transformer uses different parameters for each token by switching among sets of parameters based on the token. Only a fraction of the parameters is therefore used for any given token, so you can have more parameters without a proportional increase in computational cost.</p>
<p>The switching happens at the position-wise feedforward network (FFN) of each transformer block. A position-wise feedforward network consists of two fully connected layers applied in sequence. In the Switch Transformer we have multiple FFNs (multiple experts), and a router chooses which one to use. The router outputs a set of probabilities for picking an FFN, and we pick the one with the highest probability and evaluate only that one. So the computational cost per token is essentially the same as having a single FFN. In our implementation this doesn&#x27;t parallelize well when you have many or large FFNs, since it all happens on a single GPU. In a distributed setup you would place each FFN (each very large) on a different device.</p>
<p>The paper introduces another loss term to balance load among the experts (FFNs) and discusses dropping tokens when routing is not balanced.</p>
<p>Here&#x27;s <a href="experiment.html">the training code</a> and a notebook for training a switch transformer on the Tiny Shakespeare dataset.</p>
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/switch/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">39</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">40</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">41</span>
<span class="lineno">42</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">43</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.feed_forward</span> <span class="kn">import</span> <span class="n">FeedForward</span>
<span class="lineno">44</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.mha</span> <span class="kn">import</span> <span class="n">MultiHeadAttention</span>
<span class="lineno">45</span><span class="kn">from</span> <span class="nn">labml_nn.utils</span> <span class="kn">import</span> <span class="n">clone_module_list</span></pre></div>
</div>
</div>
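<div class='section'>
<div class='docs'>
<p>The load-balancing term mentioned above can be sketched as follows. This is a minimal illustration based on the paper&#x27;s description, not necessarily how this repository computes it: the function name <code>load_balancing_loss</code> is hypothetical, and the scaling coefficient that the paper applies to the term is left out. It multiplies the fraction of tokens sent to each expert by the mean routing probability of that expert, sums over experts, and scales by the number of experts.</p>

```python
import torch


def load_balancing_loss(route_prob: torch.Tensor, routes: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper.

    route_prob: [n_tokens, n_experts] router softmax outputs
    routes:     [n_tokens] index of the expert chosen for each token
    """
    n_tokens, n_experts = route_prob.shape
    # f_i: fraction of tokens dispatched to expert i
    f = torch.bincount(routes, minlength=n_experts).to(route_prob.dtype) / n_tokens
    # P_i: mean routing probability assigned to expert i
    p = route_prob.mean(dim=0)
    # N * sum_i f_i * P_i
    return n_experts * (f * p).sum()


# Uniform probabilities and evenly spread routes
uniform_prob = torch.full((8, 4), 0.25)
balanced_routes = torch.arange(8) % 4
loss_balanced = load_balancing_loss(uniform_prob, balanced_routes)

# Every token prefers expert 0
skewed_prob = torch.tensor([[0.7, 0.1, 0.1, 0.1]]).repeat(8, 1)
skewed_routes = torch.zeros(8, dtype=torch.long)
loss_skewed = load_balancing_loss(skewed_prob, skewed_routes)
```

<p>With uniform routing the term evaluates to 1, and it grows when both the token counts and the probabilities concentrate on the same expert (2.8 in the skewed case here).</p>
</div>
</div>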
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Routing among multiple FFNs</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">48</span><span class="k">class</span> <span class="nc">SwitchFeedForward</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">capacity_factor</span></code>
is the capacity of each expert as a factor relative to ideally balanced load </li>
<li><code class="highlight"><span></span><span class="n">drop_tokens</span></code>
specifies whether to drop tokens if more tokens are routed to an expert than the capacity </li>
<li><code class="highlight"><span></span><span class="n">is_scale_prob</span></code>
specifies whether to multiply the input to the FFN by the routing probability </li>
<li><code class="highlight"><span></span><span class="n">n_experts</span></code>
is the number of experts </li>
<li><code class="highlight"><span></span><span class="n">expert</span></code>
is the expert layer, an <a href="../feed_forward.html">FFN module</a>; its hidden-layer size (<code class="highlight"><span></span><span class="n">d_ff</span></code>) and dropout probability are set on this module rather than passed here </li>
<li><code class="highlight"><span></span><span class="n">d_model</span></code>
is the number of features in a token embedding</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">53</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">54</span> <span class="n">capacity_factor</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
<span class="lineno">55</span> <span class="n">drop_tokens</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
<span class="lineno">56</span> <span class="n">is_scale_prob</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
<span class="lineno">57</span> <span class="n">n_experts</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">58</span> <span class="n">expert</span><span class="p">:</span> <span class="n">FeedForward</span><span class="p">,</span>
<span class="lineno">59</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">71</span>
<span class="lineno">72</span> <span class="bp">self</span><span class="o">.</span><span class="n">capacity_factor</span> <span class="o">=</span> <span class="n">capacity_factor</span>
<span class="lineno">73</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_scale_prob</span> <span class="o">=</span> <span class="n">is_scale_prob</span>
<span class="lineno">74</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span> <span class="o">=</span> <span class="n">n_experts</span>
<span class="lineno">75</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop_tokens</span> <span class="o">=</span> <span class="n">drop_tokens</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Make copies of the FFNs </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">78</span> <span class="bp">self</span><span class="o">.</span><span class="n">experts</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">expert</span><span class="p">,</span> <span class="n">n_experts</span><span class="p">)</span></pre></div>
</div>
</div>
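<div class='section'>
<div class='docs'>
<p>To illustrate what copying the expert implies, here is a small sketch. It assumes that <code>clone_module_list</code> behaves like deep-copying the module into an <code>nn.ModuleList</code>; the helper <code>clone_experts</code> and the <code>nn.Linear</code> stand-in expert are hypothetical and used only for illustration.</p>

```python
import copy

import torch
from torch import nn


def clone_experts(expert: nn.Module, n_experts: int) -> nn.ModuleList:
    # Hypothetical stand-in for clone_module_list: independent deep copies
    return nn.ModuleList([copy.deepcopy(expert) for _ in range(n_experts)])


expert = nn.Linear(4, 4)  # stand-in for an FFN expert
experts = clone_experts(expert, 3)

# Copies start from identical values but do not share storage
same_values_at_start = torch.equal(experts[0].weight, experts[1].weight)
shares_storage = experts[0].weight.data_ptr() == experts[1].weight.data_ptr()

# Changing one copy leaves the others untouched
with torch.no_grad():
    experts[0].weight.add_(1.0)
diverged = not torch.equal(experts[0].weight, experts[1].weight)
```

<p>Under this assumption every expert starts from the same initial weights and only diverges as different tokens are routed to it during training.</p>
</div>
</div>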
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Routing layer and softmax </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">80</span> <span class="bp">self</span><span class="o">.</span><span class="n">switch</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_experts</span><span class="p">)</span>
<span class="lineno">81</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
is the input to the switching module with shape <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">83</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Capture the shape so the output can be reshaped back later </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">89</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Flatten the sequence and batch dimensions </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
</div>
</div>
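<div class='section'>
<div class='docs'>
<p>A small sketch of the flattening step with made-up sizes: after <code>view(-1, d_model)</code> each row is one token, and the saved shape is enough to restore the original layout. The tensor values and the variable names <code>flat</code> and <code>restored</code> are illustrative only.</p>

```python
import torch

seq_len, batch_size, d_model = 3, 2, 4
x = torch.arange(seq_len * batch_size * d_model, dtype=torch.float32)
x = x.view(seq_len, batch_size, d_model)

# Merge the sequence and batch dimensions: one row per token
flat = x.view(-1, d_model)

# The captured shape lets us undo the flattening
restored = flat.view(seq_len, batch_size, d_model)

# For a contiguous tensor, row s * batch_size + b holds token x[s, b]
row_for_s0_b1 = flat[0 * batch_size + 1]
```

<p>Routing then treats every row independently, so tokens from the same sequence can go to different experts.</p>
</div>
</div>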
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Get routing probabilities for each of the tokens. $$p_i(x) = \frac{e^{h(x)_i}}{\sum^N_j e^{h(x)_j}}$$ where $N$ is the number of experts <code class="highlight"><span></span><span class="n">n_experts</span></code>
and $h(\cdot)$ is the linear transformation of token embeddings. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">97</span> <span class="n">route_prob</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">switch</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
</div>
</div>
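<div class='section'>
<div class='docs'>
<p>The routing probabilities can be checked in isolation with a short sketch. The sizes and the random input below are arbitrary assumptions; the point is only that a linear layer followed by a softmax over the last dimension yields one probability distribution over experts per token.</p>

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model, n_experts = 8, 4

switch = nn.Linear(d_model, n_experts)  # h(x): one logit per expert
softmax = nn.Softmax(dim=-1)

x = torch.randn(6, d_model)             # six flattened tokens
route_prob = softmax(switch(x))         # shape [6, n_experts]

row_sums = route_prob.sum(dim=-1)       # each row sums to 1
```
</div>
</div>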
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Get the maximum routing probabilities and the routes. We route each token to the expert with the highest probability </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">101</span> <span class="n">route_prob_max</span><span class="p">,</span> <span class="n">routes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">route_prob</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
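<div class='section'>
<div class='docs'>
<p>As a worked example with hand-picked probabilities (illustrative values, not taken from a trained model), <code>torch.max</code> along the last dimension returns both the top probability and the index of the chosen expert for each token.</p>

```python
import torch

# Three tokens, three experts
route_prob = torch.tensor([[0.1, 0.7, 0.2],
                           [0.5, 0.3, 0.2],
                           [0.2, 0.2, 0.6]])

# values: highest probability per token, indices: chosen expert per token
route_prob_max, routes = torch.max(route_prob, dim=-1)
# routes is [1, 0, 2] and route_prob_max is [0.7, 0.5, 0.6]
```
</div>
</div>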
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Get indexes of tokens going to each expert </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">104</span> <span class="n">indexes_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">routes</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span><span class="o">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">as_tuple</span><span class="o">=</span><span class="kc">True</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">)]</span></pre></div>
</div>
</div>
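<div class='section'>
<div class='docs'>
<p>The grouping step can be seen on a tiny example. The route ids below are made up; the comprehension is the same construction as in the code above and produces, for each expert, the row indexes of the tokens assigned to it.</p>

```python
import torch

routes = torch.tensor([1, 0, 2, 1, 1])  # chosen expert for each of five tokens
n_experts = 3

# nonzero(as_tuple=True) returns one index tensor per dimension; routes is 1-D
indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(n_experts)]
# expert 0 gets token 1, expert 1 gets tokens 0, 3, 4, expert 2 gets token 2
```
</div>
</div>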
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Initialize a zero-filled tensor to store outputs </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">107</span> <span class="n">final_output</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Capacity of each expert. $$\mathrm{expert\;capacity} = \frac{\mathrm{tokens\;per\;batch}}{\mathrm{number\;of\;experts}} \times \mathrm{capacity\;factor}$$ </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">113</span> <span class="n">capacity</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">capacity_factor</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">)</span></pre></div>
</div>
</div>
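<div class='section'>
<div class='docs'>
<p>A few worked numbers make the capacity formula concrete. The helper name <code>expert_capacity</code> and the token counts are assumptions for illustration; the arithmetic mirrors the expression above, including the truncation done by <code>int()</code>.</p>

```python
def expert_capacity(n_tokens: int, n_experts: int, capacity_factor: float) -> int:
    # Hypothetical helper mirroring the capacity expression above
    return int(capacity_factor * n_tokens / n_experts)


balanced = expert_capacity(64, 4, 1.0)      # 16 tokens per expert
with_slack = expert_capacity(64, 4, 1.25)   # 20 tokens per expert
truncated = expert_capacity(10, 4, 1.0)     # int(2.5) gives 2
```

<p>Because the result is truncated, the last case offers only 8 slots for 10 tokens, so at least two tokens exceed capacity no matter how evenly they are routed.</p>
</div>
</div>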
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Number of tokens routed to each expert. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">115</span> <span class="n">counts</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">new_tensor</span><span class="p">([</span><span class="nb">len</span><span class="p">(</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">)])</span></pre></div>
</div>
</div>
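<div class='section'>
<div class='docs'>
<p>The sketch below reproduces the counting step on made-up routes and compares it with <code>torch.bincount</code>, which is shown only as a possible alternative and is not what the code above uses. Note that <code>new_tensor</code> gives the counts the dtype and device of <code>x</code>.</p>

```python
import torch

x = torch.zeros(5, 2)                   # five flattened tokens, d_model = 2
routes = torch.tensor([1, 0, 2, 1, 1])  # chosen expert per token
n_experts = 4
indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(n_experts)]

# Same construction as above: a float tensor because x is float
counts = x.new_tensor([len(indexes_list[i]) for i in range(n_experts)])

# Alternative: count directly from the route ids (integer result)
counts_bincount = torch.bincount(routes, minlength=n_experts)
```

<p>Both give one entry per expert, including a zero for an expert that received no tokens.</p>
</div>
</div>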
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Initialize an empty list of dropped tokens </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">118</span> <span class="n">dropped</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Only drop tokens if <code class="highlight"><span></span><span class="n">drop_tokens</span></code>
is <code class="highlight"><span></span><span class="kc">True</span></code>. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">120</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop_tokens</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Drop tokens in each of the experts </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Ignore if the expert is not over capacity </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="o">&lt;=</span> <span class="n">capacity</span><span class="p">:</span>
<span class="lineno">125</span> <span class="k">continue</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Shuffle indexes before dropping </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">127</span> <span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">torch</span><span class="o">.</span><span class="n">randperm</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">]))]</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Collect the tokens over capacity as dropped tokens </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">129</span> <span class="n">dropped</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">capacity</span><span class="p">:])</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Keep only the tokens up to the capacity of the expert </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">131</span> <span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">][:</span><span class="n">capacity</span><span class="p">]</span></pre></div>
</div>
</div>
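<div class='section'>
<div class='docs'>
<p>The shuffle-then-truncate steps for one over-capacity expert can be sketched like this. The indexes and the capacity are made-up values, and the names <code>kept</code> and <code>dropped</code> are illustrative; which tokens end up dropped depends on the random permutation.</p>

```python
import torch

torch.manual_seed(0)
capacity = 2
indexes = torch.tensor([0, 3, 4, 7, 9])  # tokens routed to one expert

# Shuffle so that the dropped tokens are not always the last ones
shuffled = indexes[torch.randperm(len(indexes))]

# Everything beyond the capacity is dropped, the rest is kept
dropped = shuffled[capacity:]
kept = shuffled[:capacity]
```

<p>Every token lands in exactly one of the two groups, so nothing is lost or duplicated by the split.</p>
</div>
</div>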
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Get outputs of the expert FFNs </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">134</span> <span class="n">expert_output</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">experts</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">x</span><span class="p">[</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="p">:])</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Assign to final output </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">137</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">):</span>
<span class="lineno">138</span> <span class="n">final_output</span><span class="p">[</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">expert_output</span><span class="p">[</span><span class="n">i</span><span class="p">]</span></pre></div>
</div>
</div>
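<div class='section'>
<div class='docs'>
<p>The gather, compute and scatter pattern above can be tried on a toy example. This is a sketch under stated assumptions: the experts are stand-in <code>nn.Linear</code> layers instead of feed-forward networks, and <code>indexes_list</code> and <code>dropped</code> are hand-written instead of coming from a router. It also shows dropped tokens keeping their input vectors.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
from torch import nn

torch.manual_seed(0)
n_tokens, d_model, n_experts = 6, 4, 2

# Stand-in experts (assumption): the real ones are feed-forward networks
experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_experts)])

x = torch.randn(n_tokens, d_model)
# Assumed routing result: token indexes kept per expert, plus one dropped token
indexes_list = [torch.tensor([0, 3]), torch.tensor([1, 4, 5])]
dropped = torch.tensor([2])

final_output = x.new_zeros(x.shape)
# Gather each expert's tokens, run the expert, scatter the results back by index
for i in range(n_experts):
    final_output[indexes_list[i], :] = experts[i](x[indexes_list[i], :])
# Dropped tokens skip the experts and keep their input vectors
final_output[dropped, :] = x[dropped, :]
```
</pre></div>
</div>
</div>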
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Pass through the dropped tokens </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">141</span> <span class="k">if</span> <span class="n">dropped</span><span class="p">:</span>
<span class="lineno">142</span> <span class="n">dropped</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">dropped</span><span class="p">)</span>
<span class="lineno">143</span> <span class="n">final_output</span><span class="p">[</span><span class="n">dropped</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">dropped</span><span class="p">,</span> <span class="p">:]</span>
<span class="lineno">144</span>
<span class="lineno">145</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_scale_prob</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Multiply the expert outputs by the probabilities <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.05764em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span> </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">147</span> <span class="n">final_output</span> <span class="o">=</span> <span class="n">final_output</span> <span class="o">*</span> <span class="n">route_prob_max</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="lineno">148</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Don&#x27;t scale the values but multiply by <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.228608em;vertical-align:-0.481108em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7475em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord accent mtight"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-2.7em;"><span class="pstrut" style="height:2.7em;"></span><span class="mord mathnormal mtight">p</span></span><span style="top:-2.7em;"><span class="pstrut" style="height:2.7em;"></span><span class="accent-body" style="left:-0.16666em;"><span class="mord mtight">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.19444em;"><span></span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.446108em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">p</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.481108em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" 
style="height:0.64444em;vertical-align:0em;"></span><span class="mord">1</span></span></span></span></span> so that the gradients flow (this is something we experimented with). </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">151</span> <span class="n">final_output</span> <span class="o">=</span> <span class="n">final_output</span> <span class="o">*</span> <span class="p">(</span><span class="n">route_prob_max</span> <span class="o">/</span> <span class="n">route_prob_max</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
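<div class='section'>
<div class='docs'>
<p>The effect of multiplying by <code>route_prob_max / route_prob_max.detach()</code> can be checked on a single value. This is a small sketch with made-up numbers (<code>p</code> and <code>expert_out</code> are illustrative, not taken from the model): the forward value is unchanged because the ratio is 1, while the gradient with respect to the probability is still non-zero, so the router can keep learning.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch

# Assumed toy values: one routing probability and one expert output
p = torch.tensor([0.7], requires_grad=True)
expert_out = torch.tensor([2.0])

# p / p.detach() equals 1 in the forward pass, so the output is not scaled
y = expert_out * (p / p.detach())
y.sum().backward()

print(y.item())       # same value as expert_out
print(p.grad.item())  # expert_out / p, i.e. the gradient still reaches p
```
</pre></div>
</div>
</div>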
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Change the shape of the final output back to <code class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">]</span></code>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">154</span> <span class="n">final_output</span> <span class="o">=</span> <span class="n">final_output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Return</p>
<ul><li>the final output </li>
<li>number of tokens routed to each expert </li>
<li>sum of probabilities for each expert </li>
<li>number of tokens dropped </li>
<li>routing probabilities of the selected experts</li></ul>
<p>These are used for the load balancing loss and logging </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">165</span> <span class="k">return</span> <span class="n">final_output</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">dropped</span><span class="p">),</span> <span class="n">route_prob_max</span></pre></div>
</div>
</div>
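<div class='section'>
<div class='docs'>
<p>As a sketch of how the returned token counts and summed probabilities might feed a load balancing loss, the example below computes the Switch Transformer auxiliary term, the number of experts times the sum over experts of (fraction of tokens routed to the expert) times (mean routing probability of the expert). <code>load_balancing_loss</code> is a hypothetical helper, the inputs are made-up, and the scaling coefficient applied to this term in training is left out.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch


def load_balancing_loss(counts, route_prob_sum):
    """Hypothetical helper: n_experts * sum_i f_i * P_i for one layer."""
    n_experts = counts.shape[-1]
    total = counts.sum(dim=-1, keepdim=True)
    route_frac = counts / total         # f_i: fraction of tokens sent to expert i
    mean_prob = route_prob_sum / total  # P_i: mean routing probability of expert i
    return n_experts * (route_frac * mean_prob).sum()


# Perfectly balanced routing: 8 tokens spread evenly over 4 experts
balanced = load_balancing_loss(torch.tensor([2., 2., 2., 2.]),
                               torch.tensor([2., 2., 2., 2.]))

# Fully collapsed routing: every token goes to expert 0 with probability 1
skewed = load_balancing_loss(torch.tensor([8., 0., 0., 0.]),
                             torch.tensor([8., 0., 0., 0.]))

print(balanced.item(), skewed.item())
```
</pre></div>
</div>
</div>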
<div class='section' id='section-29'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<h1>Switch Transformer Block</h1>
<p>This is the same as a <a href="../models.html#TransformerLayer">normal transformer block</a>, except that it also handles the extra outputs of the switch feed-forward module.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span><span class="k">class</span> <span class="nc">SwitchTransformerLayer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">d_model</span></code>
is the token embedding size </li>
<li><code class="highlight"><span></span><span class="n">attn</span></code>
is the attention module </li>
<li><code class="highlight"><span></span><span class="n">feed_forward</span></code>
is the feed forward module (which is the switching module in this case) </li>
<li><code class="highlight"><span></span><span class="n">dropout_prob</span></code>
is the dropout probability applied after self-attention and the FFN</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">177</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">178</span> <span class="n">attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span><span class="p">,</span>
<span class="lineno">179</span> <span class="n">feed_forward</span><span class="p">:</span> <span class="n">SwitchFeedForward</span><span class="p">,</span>
<span class="lineno">180</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">187</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">188</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">d_model</span>
<span class="lineno">189</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span>
<span class="lineno">190</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span>
<span class="lineno">191</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span>
<span class="lineno">192</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
<span class="lineno">193</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">195</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">196</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">197</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Normalize the vectors before doing self attention </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">199</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Run through self-attention, i.e. the queries, keys and values all come from the same input </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">201</span> <span class="n">self_attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Add the self attention results </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">203</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">self_attn</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Normalize for feed-forward </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">206</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>Pass through the switching feed-forward network </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">208</span> <span class="n">ff</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span><span class="p">,</span> <span class="n">route_prob_max</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Add the feed-forward results back </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">210</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">ff</span><span class="p">)</span>
<span class="lineno">211</span>
<span class="lineno">212</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span><span class="p">,</span> <span class="n">route_prob_max</span></pre></div>
</div>
</div>
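<div class='section'>
<div class='docs'>
<p>Both halves of the block above follow the same pre-norm residual pattern: normalize, apply a sub-layer, apply dropout, and add the result back to the input. The sketch below isolates that pattern. <code>PreNormResidual</code> is a hypothetical wrapper introduced only for illustration, and an <code>nn.Linear</code> stands in for the attention or switching feed-forward module.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
from torch import nn


class PreNormResidual(nn.Module):
    """Hypothetical wrapper computing x + dropout(sublayer(norm(x)))."""

    def __init__(self, d_model: int, sublayer: nn.Module, dropout_prob: float):
        super().__init__()
        self.norm = nn.LayerNorm([d_model])
        self.sublayer = sublayer
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize first, then add the sub-layer output back to the input
        return x + self.dropout(self.sublayer(self.norm(x)))


torch.manual_seed(0)
d_model = 8
# nn.Linear is a stand-in (assumption) for attention or the switch FFN
block = PreNormResidual(d_model, nn.Linear(d_model, d_model), dropout_prob=0.1)
block.eval()  # disable dropout so the output is deterministic

x = torch.randn(5, 2, d_model)  # [seq_len, batch_size, d_model]
y = block(x)
print(y.shape)
```
</pre></div>
</div>
</div>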
<div class='section' id='section-39'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<h2>Switch Transformer</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">215</span><span class="k">class</span> <span class="nc">SwitchTransformer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">220</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">SwitchTransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">221</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Make copies of the transformer layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">223</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Final normalization layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">225</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">227</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Run through each transformer layer </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">229</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span><span class="p">,</span> <span class="n">route_prob_max</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[],</span> <span class="p">[],</span> <span class="p">[]</span>
<span class="lineno">230</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
<span class="lineno">231</span> <span class="n">x</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">n_d</span><span class="p">,</span> <span class="n">p_max</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="lineno">232</span> <span class="n">counts</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
<span class="lineno">233</span> <span class="n">route_prob</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p</span><span class="p">)</span>
<span class="lineno">234</span> <span class="n">n_dropped</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">n_d</span><span class="p">)</span>
<span class="lineno">235</span> <span class="n">route_prob_max</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_max</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Finally, normalize the vectors </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">237</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
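<p>Return the final output along with the per-layer token counts, summed routing probabilities, numbers of dropped tokens, and routing probabilities of the selected experts </p>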
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">239</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">counts</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">route_prob</span><span class="p">),</span> <span class="n">n_dropped</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">route_prob_max</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>