|
|
|
@ -78,8 +78,9 @@
|
|
|
|
|
Our implementation only has a few million parameters and doesn’t do model parallel distributed training.
|
|
|
|
|
It does single GPU training, but we implement the concept of switching as described in the paper.</p>
|
|
|
|
|
<p>The Switch Transformer uses different parameters for each token by switching among parameters
|
|
|
|
|
based on the token. Thererfore, only a fraction of parameters are chosen for each token. So you
|
|
|
|
|
can have more parameters but less computational cost.</p>
|
|
|
|
|
based on the token.
|
|
|
|
|
Therefore, only a fraction of parameters are chosen for each token.
|
|
|
|
|
So you can have more parameters but less computational cost.</p>
|
|
|
|
|
<p>The switching happens at the Position-wise Feedforward network (FFN) of each transformer block.
|
|
|
|
|
Position-wise feedforward network consists of two sequentially fully connected layers.
|
|
|
|
|
In switch transformer we have multiple FFNs (multiple experts),
|
|
|
|
@ -97,13 +98,13 @@ discusses dropping tokens when routing is not balanced.</p>
|
|
|
|
|
<a href="https://web.lab-ml.com/run?uuid=c4656c605b9311eba13d0242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">39</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
|
|
|
|
<span class="lineno">40</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
|
|
|
|
|
<span class="lineno">41</span>
|
|
|
|
|
<span class="lineno">42</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
|
|
|
|
|
<span class="lineno">43</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.mha</span> <span class="kn">import</span> <span class="n">MultiHeadAttention</span>
|
|
|
|
|
<span class="lineno">44</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.feed_forward</span> <span class="kn">import</span> <span class="n">FeedForward</span>
|
|
|
|
|
<span class="lineno">45</span><span class="kn">from</span> <span class="nn">labml_nn.utils</span> <span class="kn">import</span> <span class="n">clone_module_list</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">40</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
|
|
|
|
<span class="lineno">41</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
|
|
|
|
|
<span class="lineno">42</span>
|
|
|
|
|
<span class="lineno">43</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
|
|
|
|
|
<span class="lineno">44</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.mha</span> <span class="kn">import</span> <span class="n">MultiHeadAttention</span>
|
|
|
|
|
<span class="lineno">45</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.feed_forward</span> <span class="kn">import</span> <span class="n">FeedForward</span>
|
|
|
|
|
<span class="lineno">46</span><span class="kn">from</span> <span class="nn">labml_nn.utils</span> <span class="kn">import</span> <span class="n">clone_module_list</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-1'>
|
|
|
|
@ -114,7 +115,7 @@ discusses dropping tokens when routing is not balanced.</p>
|
|
|
|
|
<h2>Routing among multiple FFNs</h2>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">48</span><span class="k">class</span> <span class="nc">SwitchFeedForward</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">49</span><span class="k">class</span> <span class="nc">SwitchFeedForward</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-2'>
|
|
|
|
@ -134,13 +135,13 @@ discusses dropping tokens when routing is not balanced.</p>
|
|
|
|
|
</ul>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">53</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">54</span> <span class="n">capacity_factor</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">55</span> <span class="n">drop_tokens</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">56</span> <span class="n">is_scale_prob</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">57</span> <span class="n">n_experts</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">58</span> <span class="n">expert</span><span class="p">:</span> <span class="n">FeedForward</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">59</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">54</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">55</span> <span class="n">capacity_factor</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">56</span> <span class="n">drop_tokens</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">57</span> <span class="n">is_scale_prob</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">58</span> <span class="n">n_experts</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">59</span> <span class="n">expert</span><span class="p">:</span> <span class="n">FeedForward</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">60</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-3'>
|
|
|
|
@ -151,12 +152,12 @@ discusses dropping tokens when routing is not balanced.</p>
|
|
|
|
|
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">70</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
|
|
|
|
<span class="lineno">71</span>
|
|
|
|
|
<span class="lineno">72</span> <span class="bp">self</span><span class="o">.</span><span class="n">capacity_factor</span> <span class="o">=</span> <span class="n">capacity_factor</span>
|
|
|
|
|
<span class="lineno">73</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_scale_prob</span> <span class="o">=</span> <span class="n">is_scale_prob</span>
|
|
|
|
|
<span class="lineno">74</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span> <span class="o">=</span> <span class="n">n_experts</span>
|
|
|
|
|
<span class="lineno">75</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop_tokens</span> <span class="o">=</span> <span class="n">drop_tokens</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">71</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
|
|
|
|
<span class="lineno">72</span>
|
|
|
|
|
<span class="lineno">73</span> <span class="bp">self</span><span class="o">.</span><span class="n">capacity_factor</span> <span class="o">=</span> <span class="n">capacity_factor</span>
|
|
|
|
|
<span class="lineno">74</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_scale_prob</span> <span class="o">=</span> <span class="n">is_scale_prob</span>
|
|
|
|
|
<span class="lineno">75</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span> <span class="o">=</span> <span class="n">n_experts</span>
|
|
|
|
|
<span class="lineno">76</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop_tokens</span> <span class="o">=</span> <span class="n">drop_tokens</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-4'>
|
|
|
|
@ -167,7 +168,7 @@ discusses dropping tokens when routing is not balanced.</p>
|
|
|
|
|
<p>make copies of the FFNs</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">78</span> <span class="bp">self</span><span class="o">.</span><span class="n">experts</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">expert</span><span class="p">,</span> <span class="n">n_experts</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">79</span> <span class="bp">self</span><span class="o">.</span><span class="n">experts</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">expert</span><span class="p">,</span> <span class="n">n_experts</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-5'>
|
|
|
|
@ -178,8 +179,8 @@ discusses dropping tokens when routing is not balanced.</p>
|
|
|
|
|
<p>Routing layer and softmax</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">80</span> <span class="bp">self</span><span class="o">.</span><span class="n">switch</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_experts</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">81</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">81</span> <span class="bp">self</span><span class="o">.</span><span class="n">switch</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_experts</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">82</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-6'>
|
|
|
|
@ -192,7 +193,7 @@ discusses dropping tokens when routing is not balanced.</p>
|
|
|
|
|
</ul>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">83</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">84</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-7'>
|
|
|
|
@ -203,7 +204,7 @@ discusses dropping tokens when routing is not balanced.</p>
|
|
|
|
|
<p>Capture the shape to change shapes later</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">89</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">90</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-8'>
|
|
|
|
@ -214,7 +215,7 @@ discusses dropping tokens when routing is not balanced.</p>
|
|
|
|
|
<p>Flatten the sequence and batch dimensions</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">91</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">92</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-9'>
|
|
|
|
@ -228,7 +229,7 @@ where $N$ is the number of experts <code>n_experts</code> and
|
|
|
|
|
$h(\cdot)$ is the linear transformation of token embeddings.</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">97</span> <span class="n">route_prob</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">switch</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">98</span> <span class="n">route_prob</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">switch</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-10'>
|
|
|
|
@ -240,7 +241,7 @@ $h(\cdot)$ is the linear transformation of token embeddings.</p>
|
|
|
|
|
We route to the expert with highest probability</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">101</span> <span class="n">route_prob_max</span><span class="p">,</span> <span class="n">routes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">route_prob</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">102</span> <span class="n">route_prob_max</span><span class="p">,</span> <span class="n">routes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">route_prob</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-11'>
|
|
|
|
@ -251,8 +252,8 @@ We route to the expert with highest probability</p>
|
|
|
|
|
<p>Scale the inputs to the experts by the routing probabilities</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">104</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_scale_prob</span><span class="p">:</span>
|
|
|
|
|
<span class="lineno">105</span> <span class="n">factor</span> <span class="o">=</span> <span class="n">route_prob_max</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">105</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_scale_prob</span><span class="p">:</span>
|
|
|
|
|
<span class="lineno">106</span> <span class="n">factor</span> <span class="o">=</span> <span class="n">route_prob_max</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-12'>
|
|
|
|
@ -263,8 +264,8 @@ We route to the expert with highest probability</p>
|
|
|
|
|
<p>Don’t scale the values but multiply by $\frac{p}{\hat{p}} = 1$ so that the gradients flow</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">107</span> <span class="k">else</span><span class="p">:</span>
|
|
|
|
|
<span class="lineno">108</span> <span class="n">factor</span> <span class="o">=</span> <span class="n">route_prob_max</span> <span class="o">/</span> <span class="n">route_prob_max</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">108</span> <span class="k">else</span><span class="p">:</span>
|
|
|
|
|
<span class="lineno">109</span> <span class="n">factor</span> <span class="o">=</span> <span class="n">route_prob_max</span> <span class="o">/</span> <span class="n">route_prob_max</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-13'>
|
|
|
|
@ -275,7 +276,7 @@ We route to the expert with highest probability</p>
|
|
|
|
|
<p>Multiply by the scaling factor</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">110</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">*</span> <span class="n">factor</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">111</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">*</span> <span class="n">factor</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-14'>
|
|
|
|
@ -286,7 +287,7 @@ We route to the expert with highest probability</p>
|
|
|
|
|
<p>Get indexes of tokens going to each expert</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">113</span> <span class="n">indexes_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">routes</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span><span class="o">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">as_tuple</span><span class="o">=</span><span class="kc">True</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">)]</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">114</span> <span class="n">indexes_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">routes</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span><span class="o">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">as_tuple</span><span class="o">=</span><span class="kc">True</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">)]</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-15'>
|
|
|
|
@ -297,7 +298,7 @@ We route to the expert with highest probability</p>
|
|
|
|
|
<p>Initialize an empty tensor to store outputs</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">116</span> <span class="n">final_output</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">117</span> <span class="n">final_output</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-16'>
|
|
|
|
@ -312,7 +313,7 @@ We route to the expert with highest probability</p>
|
|
|
|
|
</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">122</span> <span class="n">capacity</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">capacity_factor</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">123</span> <span class="n">capacity</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">capacity_factor</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-17'>
|
|
|
|
@ -323,7 +324,7 @@ We route to the expert with highest probability</p>
|
|
|
|
|
<p>Number of tokens routed to each expert.</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">124</span> <span class="n">counts</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">new_tensor</span><span class="p">([</span><span class="nb">len</span><span class="p">(</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">)])</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">125</span> <span class="n">counts</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">new_tensor</span><span class="p">([</span><span class="nb">len</span><span class="p">(</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">)])</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-18'>
|
|
|
|
@ -334,7 +335,7 @@ We route to the expert with highest probability</p>
|
|
|
|
|
<p>Initialize an empty list of dropped tokens</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">127</span> <span class="n">dropped</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">128</span> <span class="n">dropped</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-19'>
|
|
|
|
@ -345,7 +346,7 @@ We route to the expert with highest probability</p>
|
|
|
|
|
<p>Only drop tokens if <code>drop_tokens</code> is <code>True</code>.</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">129</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop_tokens</span><span class="p">:</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">130</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop_tokens</span><span class="p">:</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-20'>
|
|
|
|
@ -356,7 +357,7 @@ We route to the expert with highest probability</p>
|
|
|
|
|
<p>Drop tokens in each of the experts</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">131</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">):</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">132</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">):</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-21'>
|
|
|
|
@ -367,8 +368,8 @@ We route to the expert with highest probability</p>
|
|
|
|
|
<p>Ignore if the expert is not over capacity</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">133</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="o"><=</span> <span class="n">capacity</span><span class="p">:</span>
|
|
|
|
|
<span class="lineno">134</span> <span class="k">continue</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">134</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="o"><=</span> <span class="n">capacity</span><span class="p">:</span>
|
|
|
|
|
<span class="lineno">135</span> <span class="k">continue</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-22'>
|
|
|
|
@ -379,7 +380,7 @@ We route to the expert with highest probability</p>
|
|
|
|
|
<p>Shuffle indexes before dropping</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">136</span> <span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">torch</span><span class="o">.</span><span class="n">randperm</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">]))]</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">137</span> <span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">torch</span><span class="o">.</span><span class="n">randperm</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">]))]</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-23'>
|
|
|
|
@ -390,7 +391,7 @@ We route to the expert with highest probability</p>
|
|
|
|
|
<p>Collect the tokens over capacity as dropped tokens</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">138</span> <span class="n">dropped</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">capacity</span><span class="p">:])</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">139</span> <span class="n">dropped</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">capacity</span><span class="p">:])</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-24'>
|
|
|
|
@ -401,7 +402,7 @@ We route to the expert with highest probability</p>
|
|
|
|
|
<p>Keep only the tokens upto the capacity of the expert</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">][:</span><span class="n">capacity</span><span class="p">]</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">141</span> <span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">][:</span><span class="n">capacity</span><span class="p">]</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-25'>
|
|
|
|
@ -412,7 +413,7 @@ We route to the expert with highest probability</p>
|
|
|
|
|
<p>Get outputs of the expert FFNs</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">143</span> <span class="n">route_outputs</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">experts</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">x</span><span class="p">[</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="p">:])</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">)]</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">144</span> <span class="n">route_outputs</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">experts</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">x</span><span class="p">[</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="p">:])</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">)]</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-26'>
|
|
|
|
@ -423,8 +424,8 @@ We route to the expert with highest probability</p>
|
|
|
|
|
<p>Assign to final output</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">146</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">):</span>
|
|
|
|
|
<span class="lineno">147</span> <span class="n">final_output</span><span class="p">[</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">route_outputs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">147</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">):</span>
|
|
|
|
|
<span class="lineno">148</span> <span class="n">final_output</span><span class="p">[</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">route_outputs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-27'>
|
|
|
|
@ -435,9 +436,9 @@ We route to the expert with highest probability</p>
|
|
|
|
|
<p>Pass through the dropped tokens</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">150</span> <span class="k">if</span> <span class="n">dropped</span><span class="p">:</span>
|
|
|
|
|
<span class="lineno">151</span> <span class="n">dropped</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">dropped</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">152</span> <span class="n">final_output</span><span class="p">[</span><span class="n">dropped</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">dropped</span><span class="p">,</span> <span class="p">:]</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">151</span> <span class="k">if</span> <span class="n">dropped</span><span class="p">:</span>
|
|
|
|
|
<span class="lineno">152</span> <span class="n">dropped</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">dropped</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">153</span> <span class="n">final_output</span><span class="p">[</span><span class="n">dropped</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">dropped</span><span class="p">,</span> <span class="p">:]</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-28'>
|
|
|
|
@ -448,7 +449,7 @@ We route to the expert with highest probability</p>
|
|
|
|
|
<p>Change the shape of the final output back to <code>[seq_len, batch_size, d_model]</code></p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">155</span> <span class="n">final_output</span> <span class="o">=</span> <span class="n">final_output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">156</span> <span class="n">final_output</span> <span class="o">=</span> <span class="n">final_output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-29'>
|
|
|
|
@ -464,7 +465,7 @@ We route to the expert with highest probability</p>
|
|
|
|
|
These are used for the load balancing loss and logging</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">163</span> <span class="k">return</span> <span class="n">final_output</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">dropped</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">164</span> <span class="k">return</span> <span class="n">final_output</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">dropped</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-30'>
|
|
|
|
@ -477,7 +478,7 @@ These are used for the load balancing loss and logging</p>
|
|
|
|
|
with handling extra outputs of switch feedforward module.</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">166</span><span class="k">class</span> <span class="nc">SwitchTransformerLayer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">167</span><span class="k">class</span> <span class="nc">SwitchTransformerLayer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-31'>
|
|
|
|
@ -493,11 +494,11 @@ with handling extra outputs of switch feedforward module.</p>
|
|
|
|
|
</ul>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">173</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">174</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">175</span> <span class="n">attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">176</span> <span class="n">feed_forward</span><span class="p">:</span> <span class="n">SwitchFeedForward</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">177</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">174</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">175</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">176</span> <span class="n">attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">177</span> <span class="n">feed_forward</span><span class="p">:</span> <span class="n">SwitchFeedForward</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">178</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-32'>
|
|
|
|
@ -508,13 +509,13 @@ with handling extra outputs of switch feedforward module.</p>
|
|
|
|
|
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">184</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
|
|
|
|
<span class="lineno">185</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">d_model</span>
|
|
|
|
|
<span class="lineno">186</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span>
|
|
|
|
|
<span class="lineno">187</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span>
|
|
|
|
|
<span class="lineno">188</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">189</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
|
|
|
|
|
<span class="lineno">190</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">185</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
|
|
|
|
<span class="lineno">186</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">d_model</span>
|
|
|
|
|
<span class="lineno">187</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span>
|
|
|
|
|
<span class="lineno">188</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span>
|
|
|
|
|
<span class="lineno">189</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">190</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
|
|
|
|
|
<span class="lineno">191</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-33'>
|
|
|
|
@ -525,9 +526,9 @@ with handling extra outputs of switch feedforward module.</p>
|
|
|
|
|
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">192</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">193</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">194</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">193</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">194</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
|
|
|
<span class="lineno">195</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-34'>
|
|
|
|
@ -538,7 +539,7 @@ with handling extra outputs of switch feedforward module.</p>
|
|
|
|
|
<p>Normalize the vectors before doing self attention</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">196</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-35'>
|
|
|
|
@ -549,7 +550,7 @@ with handling extra outputs of switch feedforward module.</p>
|
|
|
|
|
<p>Run through self attention, i.e. keys and values are from self</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">198</span> <span class="n">self_attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">199</span> <span class="n">self_attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-36'>
|
|
|
|
@ -560,7 +561,7 @@ with handling extra outputs of switch feedforward module.</p>
|
|
|
|
|
<p>Add the self attention results</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">200</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">self_attn</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">201</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">self_attn</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-37'>
|
|
|
|
@ -571,7 +572,7 @@ with handling extra outputs of switch feedforward module.</p>
|
|
|
|
|
<p>Normalize for feed-forward</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">203</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">204</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-38'>
|
|
|
|
@ -582,7 +583,7 @@ with handling extra outputs of switch feedforward module.</p>
|
|
|
|
|
<p>Pass through the switching feed-forward network</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">205</span> <span class="n">ff</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">206</span> <span class="n">ff</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-39'>
|
|
|
|
@ -593,9 +594,9 @@ with handling extra outputs of switch feedforward module.</p>
|
|
|
|
|
<p>Add the feed-forward results back</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">207</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">ff</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">208</span>
|
|
|
|
|
<span class="lineno">209</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">208</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">ff</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">209</span>
|
|
|
|
|
<span class="lineno">210</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-40'>
|
|
|
|
@ -606,7 +607,7 @@ with handling extra outputs of switch feedforward module.</p>
|
|
|
|
|
<h2>Switch Transformer</h2>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">212</span><span class="k">class</span> <span class="nc">SwitchTransformer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">213</span><span class="k">class</span> <span class="nc">SwitchTransformer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-41'>
|
|
|
|
@ -617,8 +618,8 @@ with handling extra outputs of switch feedforward module.</p>
|
|
|
|
|
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">217</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">SwitchTransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
|
|
|
|
<span class="lineno">218</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">218</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">SwitchTransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
|
|
|
|
<span class="lineno">219</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-42'>
|
|
|
|
@ -629,7 +630,7 @@ with handling extra outputs of switch feedforward module.</p>
|
|
|
|
|
<p>Make copies of the transformer layer</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">220</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">221</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-43'>
|
|
|
|
@ -640,7 +641,7 @@ with handling extra outputs of switch feedforward module.</p>
|
|
|
|
|
<p>Final normalization layer</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">222</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">223</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-44'>
|
|
|
|
@ -651,7 +652,7 @@ with handling extra outputs of switch feedforward module.</p>
|
|
|
|
|
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">224</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">225</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-45'>
|
|
|
|
@ -662,12 +663,12 @@ with handling extra outputs of switch feedforward module.</p>
|
|
|
|
|
<p>Run through each transformer layer</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">226</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[],</span> <span class="p">[]</span>
|
|
|
|
|
<span class="lineno">227</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
|
|
|
|
|
<span class="lineno">228</span> <span class="n">x</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">n_d</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">229</span> <span class="n">counts</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">230</span> <span class="n">route_prob</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">231</span> <span class="n">n_dropped</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">n_d</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">227</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[],</span> <span class="p">[]</span>
|
|
|
|
|
<span class="lineno">228</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
|
|
|
|
|
<span class="lineno">229</span> <span class="n">x</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">n_d</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">230</span> <span class="n">counts</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">231</span> <span class="n">route_prob</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p</span><span class="p">)</span>
|
|
|
|
|
<span class="lineno">232</span> <span class="n">n_dropped</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">n_d</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-46'>
|
|
|
|
@ -678,7 +679,7 @@ with handling extra outputs of switch feedforward module.</p>
|
|
|
|
|
<p>Finally, normalize the vectors</p>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">233</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">234</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
<div class='section' id='section-47'>
|
|
|
|
@ -689,7 +690,7 @@ with handling extra outputs of switch feedforward module.</p>
|
|
|
|
|
|
|
|
|
|
</div>
|
|
|
|
|
<div class='code'>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">235</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">counts</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">route_prob</span><span class="p">),</span> <span class="n">n_dropped</span></pre></div>
|
|
|
|
|
<div class="highlight"><pre><span class="lineno">236</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">counts</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">route_prob</span><span class="p">),</span> <span class="n">n_dropped</span></pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|
</div>
|
|
|
|
|