This commit is contained in:
Varuna Jayasiri
2021-09-17 12:06:41 +05:30
parent f87879e780
commit 40eb9cab4e
2 changed files with 86 additions and 96 deletions

View File

@ -398,7 +398,9 @@ We route to the expert with highest probability</p>
<div class='code'>
<div class="highlight"><pre><span class="lineno">142</span> <span class="k">if</span> <span class="n">dropped</span><span class="p">:</span>
<span class="lineno">143</span> <span class="n">dropped</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">dropped</span><span class="p">)</span>
<span class="lineno">144</span> <span class="n">final_output</span><span class="p">[</span><span class="n">dropped</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">dropped</span><span class="p">,</span> <span class="p">:]</span></pre></div>
<span class="lineno">144</span> <span class="n">final_output</span><span class="p">[</span><span class="n">dropped</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">dropped</span><span class="p">,</span> <span class="p">:]</span>
<span class="lineno">145</span>
<span class="lineno">146</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_scale_prob</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
@ -406,10 +408,11 @@ We route to the expert with highest probability</p>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Scale the expert outputs by the routing probabilities</p>
<p>Multiply by the expert outputs by the probabilities $y = p_i(x) E_i(x)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">147</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_scale_prob</span><span class="p">:</span></pre></div>
<div class="highlight"><pre><span class="lineno">148</span> <span class="n">final_output</span> <span class="o">=</span> <span class="n">final_output</span> <span class="o">*</span> <span class="n">route_prob_max</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="lineno">149</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
@ -417,11 +420,11 @@ We route to the expert with highest probability</p>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Multiply by the expert outputs by the probabilities $y = p_i(x) E_i(x)$</p>
<p>Don&rsquo;t scale the values but multiply by $\frac{p}{\hat{p}} = 1$ so that the gradients flow
(this is something we experimented with).</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">149</span> <span class="n">final_output</span> <span class="o">=</span> <span class="n">final_output</span> <span class="o">*</span> <span class="n">route_prob_max</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="lineno">150</span> <span class="k">else</span><span class="p">:</span></pre></div>
<div class="highlight"><pre><span class="lineno">152</span> <span class="n">final_output</span> <span class="o">=</span> <span class="n">final_output</span> <span class="o">*</span> <span class="p">(</span><span class="n">route_prob_max</span> <span class="o">/</span> <span class="n">route_prob_max</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
@ -429,11 +432,10 @@ We route to the expert with highest probability</p>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Don&rsquo;t scale the values but multiply by $\frac{p}{\hat{p}} = 1$ so that the gradients flow
(this is something we experimented with).</p>
<p>Change the shape of the final output back to <code>[seq_len, batch_size, d_model]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">153</span> <span class="n">final_output</span> <span class="o">=</span> <span class="n">final_output</span> <span class="o">*</span> <span class="p">(</span><span class="n">route_prob_max</span> <span class="o">/</span> <span class="n">route_prob_max</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">155</span> <span class="n">final_output</span> <span class="o">=</span> <span class="n">final_output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
@ -441,17 +443,6 @@ We route to the expert with highest probability</p>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Change the shape of the final output back to <code>[seq_len, batch_size, d_model]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="n">final_output</span> <span class="o">=</span> <span class="n">final_output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Return
* the final output
* number of tokens routed to each expert
@ -460,26 +451,26 @@ We route to the expert with highest probability</p>
These are used for the load balancing loss and logging</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">164</span> <span class="k">return</span> <span class="n">final_output</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">dropped</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">163</span> <span class="k">return</span> <span class="n">final_output</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">dropped</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='section' id='section-29'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-30'>#</a>
<a href='#section-29'>#</a>
</div>
<h1>Switch Transformer Block</h1>
<p>This is the same as <a href="../models.html#TransformerLayer">normal transformer block</a>
with handling extra outputs of switch feedforward module.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">167</span><span class="k">class</span> <span class="nc">SwitchTransformerLayer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">166</span><span class="k">class</span> <span class="nc">SwitchTransformerLayer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='section' id='section-30'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-31'>#</a>
<a href='#section-30'>#</a>
</div>
<ul>
<li><code>d_model</code> is the token embedding size</li>
@ -489,11 +480,28 @@ with handling extra outputs of switch feedforward module.</p>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">175</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">176</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">177</span> <span class="n">attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span><span class="p">,</span>
<span class="lineno">178</span> <span class="n">feed_forward</span><span class="p">:</span> <span class="n">SwitchFeedForward</span><span class="p">,</span>
<span class="lineno">179</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">174</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">175</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">176</span> <span class="n">attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span><span class="p">,</span>
<span class="lineno">177</span> <span class="n">feed_forward</span><span class="p">:</span> <span class="n">SwitchFeedForward</span><span class="p">,</span>
<span class="lineno">178</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">185</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">186</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">d_model</span>
<span class="lineno">187</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span>
<span class="lineno">188</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span>
<span class="lineno">189</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span>
<span class="lineno">190</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
<span class="lineno">191</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
@ -504,13 +512,9 @@ with handling extra outputs of switch feedforward module.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">186</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">187</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">d_model</span>
<span class="lineno">188</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span>
<span class="lineno">189</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span>
<span class="lineno">190</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span>
<span class="lineno">191</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
<span class="lineno">192</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span></pre></div>
<div class="highlight"><pre><span class="lineno">193</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">194</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">195</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
@ -518,12 +522,10 @@ with handling extra outputs of switch feedforward module.</p>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Normalize the vectors before doing self attention</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">194</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">195</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">196</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
@ -531,10 +533,10 @@ with handling extra outputs of switch feedforward module.</p>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Normalize the vectors before doing self attention</p>
<p>Run through self attention, i.e. keys and values are from self</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">198</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">199</span> <span class="n">self_attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
@ -542,10 +544,10 @@ with handling extra outputs of switch feedforward module.</p>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Run through self attention, i.e. keys and values are from self</p>
<p>Add the self attention results</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">200</span> <span class="n">self_attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">201</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">self_attn</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
@ -553,10 +555,10 @@ with handling extra outputs of switch feedforward module.</p>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Add the self attention results</p>
<p>Normalize for feed-forward</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">202</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">self_attn</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">204</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
@ -564,10 +566,10 @@ with handling extra outputs of switch feedforward module.</p>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>Normalize for feed-forward</p>
<p>Pass through the switching feed-forward network</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">206</span> <span class="n">ff</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
@ -575,34 +577,35 @@ with handling extra outputs of switch feedforward module.</p>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Pass through the switching feed-forward network</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">207</span> <span class="n">ff</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>Add the feed-forward results back</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">209</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">ff</span><span class="p">)</span>
<span class="lineno">210</span>
<span class="lineno">211</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span></pre></div>
<div class="highlight"><pre><span class="lineno">208</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">ff</span><span class="p">)</span>
<span class="lineno">209</span>
<span class="lineno">210</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='section' id='section-39'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-40'>#</a>
<a href='#section-39'>#</a>
</div>
<h2>Switch Transformer</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">214</span><span class="k">class</span> <span class="nc">SwitchTransformer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">213</span><span class="k">class</span> <span class="nc">SwitchTransformer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
<div class='docs'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">218</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">SwitchTransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">219</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
@ -610,11 +613,10 @@ with handling extra outputs of switch feedforward module.</p>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
<p>Make copies of the transformer layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">219</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">SwitchTransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">220</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">221</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
@ -622,10 +624,10 @@ with handling extra outputs of switch feedforward module.</p>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Make copies of the transformer layer</p>
<p>Final normalization layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">222</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">223</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
@ -633,10 +635,10 @@ with handling extra outputs of switch feedforward module.</p>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>Final normalization layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">224</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
<div class="highlight"><pre><span class="lineno">225</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
@ -644,10 +646,15 @@ with handling extra outputs of switch feedforward module.</p>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>Run through each transformer layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">226</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">227</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[],</span> <span class="p">[]</span>
<span class="lineno">228</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
<span class="lineno">229</span> <span class="n">x</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">n_d</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="lineno">230</span> <span class="n">counts</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
<span class="lineno">231</span> <span class="n">route_prob</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p</span><span class="p">)</span>
<span class="lineno">232</span> <span class="n">n_dropped</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">n_d</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
@ -655,15 +662,10 @@ with handling extra outputs of switch feedforward module.</p>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Run through each transformer layer</p>
<p>Finally, normalize the vectors</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">228</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[],</span> <span class="p">[]</span>
<span class="lineno">229</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
<span class="lineno">230</span> <span class="n">x</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">n_d</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="lineno">231</span> <span class="n">counts</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
<span class="lineno">232</span> <span class="n">route_prob</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p</span><span class="p">)</span>
<span class="lineno">233</span> <span class="n">n_dropped</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">n_d</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">234</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
@ -671,21 +673,10 @@ with handling extra outputs of switch feedforward module.</p>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<p>Finally, normalize the vectors</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">235</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">237</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">counts</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">route_prob</span><span class="p">),</span> <span class="n">n_dropped</span></pre></div>
<div class="highlight"><pre><span class="lineno">236</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">counts</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">route_prob</span><span class="p">),</span> <span class="n">n_dropped</span></pre></div>
</div>
</div>
<div class='footer'>