typos in readmes

This commit is contained in:
Varuna Jayasiri
2021-02-19 09:23:27 +05:30
parent 3b1e75da62
commit ccb9ee2e4c
8 changed files with 142 additions and 138 deletions

View File

@ -82,11 +82,11 @@ network parameters during training.
For example, let’s say there are two layers $l_1$ and $l_2$.
During the beginning of the training $l_1$ outputs (inputs to $l_2$)
could be in distribution $\mathcal{N}(0.5, 1)$.
Then, after some training steps, it could move to $\mathcal{N}(0.5, 1)$.
Then, after some training steps, it could move to $\mathcal{N}(0.6, 1.5)$.
This is <em>internal covariate shift</em>.</p>
<p>Internal covariate shift will adversely affect training speed because the later layers
($l_2$ in the above example) has to adapt to this shifted distribution.</p>
<p>By stabilizing the distribution batch normalization minimizes the internal covariate shift.</p>
($l_2$ in the above example) have to adapt to this shifted distribution.</p>
<p>By stabilizing the distribution, batch normalization minimizes the internal covariate shift.</p>
<h2>Normalization</h2>
<p>It is known that whitening improves training speed and convergence.
<em>Whitening</em> is linearly transforming inputs to have zero mean, unit variance,
@ -95,9 +95,9 @@ and be uncorrelated.</p>
<p>Normalizing outside the gradient computation using pre-computed (detached)
means and variances doesn&rsquo;t work. For instance. (ignoring variance), let
<script type="math/tex; mode=display">\hat{x} = x - \mathbb{E}[x]</script>
where $x = u + b$ and $b$ is a trained bias.
and $\mathbb{E}[x]$ is outside gradient computation (pre-computed constant).</p>
<p>Note that $\hat{x}$ has no effect of $b$.
where $x = u + b$ and $b$ is a trained bias
and $\mathbb{E}[x]$ is an outside gradient computation (pre-computed constant).</p>
<p>Note that $\hat{x}$ has no effect on $b$.
Therefore,
$b$ will increase or decrease based
$\frac{\partial{\mathcal{L}}}{\partial x}$,
@ -106,14 +106,14 @@ The paper notes that similar explosions happen with variances.</p>
<h3>Batch Normalization</h3>
<p>Whitening is computationally expensive because you need to de-correlate and
the gradients must flow through the full whitening calculation.</p>
<p>The paper introduces simplified version which they call <em>Batch Normalization</em>.
<p>The paper introduces a simplified version which they call <em>Batch Normalization</em>.
First simplification is that it normalizes each feature independently to have
zero mean and unit variance:
<script type="math/tex; mode=display">\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}]}}</script>
where $x = (x^{(1)} &hellip; x^{(d)})$ is the $d$-dimensional input.</p>
<p>The second simplification is to use estimates of mean $\mathbb{E}[x^{(k)}]$
and variance $Var[x^{(k)}]$ from the mini-batch
for normalization; instead of calculating the mean and variance across whole dataset.</p>
for normalization; instead of calculating the mean and variance across the whole dataset.</p>
<p>Normalizing each feature to zero mean and unit variance could affect what the layer
can represent.
As an example paper illustrates that, if the inputs to a sigmoid are normalized
@ -126,8 +126,8 @@ where $y^{(k)}$ is the output of the batch normalization layer.</p>
like $Wu + b$ the bias parameter $b$ gets cancelled due to normalization.
So you can and should omit bias parameter in linear transforms right before the
batch normalization.</p>
<p>Batch normalization also makes the back propagation invariant to the scale of the weights.
And empirically it improves generalization, so it has regularization effects too.</p>
<p>Batch normalization also makes the back propagation invariant to the scale of the weights
and empirically it improves generalization, so it has regularization effects too.</p>
<h2>Inference</h2>
<p>We need to know $\mathbb{E}[x^{(k)}]$ and $Var[x^{(k)}]$ in order to
perform the normalization.
@ -135,8 +135,8 @@ So during inference, you either need to go through the whole (or part of) datase
and find the mean and variance, or you can use an estimate calculated during training.
The usual practice is to calculate an exponential moving average of
mean and variance during the training phase and use that for inference.</p>
<p>Here&rsquo;s <a href="https://nn.labml.ai/normalization/layer_norm/mnist.html">the training code</a> and a notebook for training
a CNN classifier that use batch normalization for MNIST dataset.</p>
<p>Here&rsquo;s <a href="mnist.html">the training code</a> and a notebook for training
a CNN classifier that uses batch normalization for MNIST dataset.</p>
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/batch_norm/mnist.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
<a href="https://web.lab-ml.com/run?uuid=011254fe647011ebbb8e0242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>

View File

@ -78,8 +78,9 @@
Our implementation only has a few million parameters and doesn&rsquo;t do model parallel distributed training.
It does single GPU training, but we implement the concept of switching as described in the paper.</p>
<p>The Switch Transformer uses different parameters for each token by switching among parameters
based on the token. Thererfore, only a fraction of parameters are chosen for each token. So you
can have more parameters but less computational cost.</p>
based on the token.
Therefore, only a fraction of parameters are chosen for each token.
So you can have more parameters but less computational cost.</p>
<p>The switching happens at the Position-wise Feedforward network (FFN) of each transformer block.
Position-wise feedforward network consists of two sequentially fully connected layers.
In switch transformer we have multiple FFNs (multiple experts),
@ -97,13 +98,13 @@ discusses dropping tokens when routing is not balanced.</p>
<a href="https://web.lab-ml.com/run?uuid=c4656c605b9311eba13d0242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">39</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">40</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">41</span>
<span class="lineno">42</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">43</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.mha</span> <span class="kn">import</span> <span class="n">MultiHeadAttention</span>
<span class="lineno">44</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.feed_forward</span> <span class="kn">import</span> <span class="n">FeedForward</span>
<span class="lineno">45</span><span class="kn">from</span> <span class="nn">labml_nn.utils</span> <span class="kn">import</span> <span class="n">clone_module_list</span></pre></div>
<div class="highlight"><pre><span class="lineno">40</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">41</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">42</span>
<span class="lineno">43</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="lineno">44</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.mha</span> <span class="kn">import</span> <span class="n">MultiHeadAttention</span>
<span class="lineno">45</span><span class="kn">from</span> <span class="nn">labml_nn.transformers.feed_forward</span> <span class="kn">import</span> <span class="n">FeedForward</span>
<span class="lineno">46</span><span class="kn">from</span> <span class="nn">labml_nn.utils</span> <span class="kn">import</span> <span class="n">clone_module_list</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
@ -114,7 +115,7 @@ discusses dropping tokens when routing is not balanced.</p>
<h2>Routing among multiple FFNs</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">48</span><span class="k">class</span> <span class="nc">SwitchFeedForward</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">49</span><span class="k">class</span> <span class="nc">SwitchFeedForward</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
@ -134,13 +135,13 @@ discusses dropping tokens when routing is not balanced.</p>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">53</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">54</span> <span class="n">capacity_factor</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
<span class="lineno">55</span> <span class="n">drop_tokens</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
<span class="lineno">56</span> <span class="n">is_scale_prob</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
<span class="lineno">57</span> <span class="n">n_experts</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">58</span> <span class="n">expert</span><span class="p">:</span> <span class="n">FeedForward</span><span class="p">,</span>
<span class="lineno">59</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">54</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">55</span> <span class="n">capacity_factor</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
<span class="lineno">56</span> <span class="n">drop_tokens</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
<span class="lineno">57</span> <span class="n">is_scale_prob</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
<span class="lineno">58</span> <span class="n">n_experts</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">59</span> <span class="n">expert</span><span class="p">:</span> <span class="n">FeedForward</span><span class="p">,</span>
<span class="lineno">60</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
@ -151,12 +152,12 @@ discusses dropping tokens when routing is not balanced.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">71</span>
<span class="lineno">72</span> <span class="bp">self</span><span class="o">.</span><span class="n">capacity_factor</span> <span class="o">=</span> <span class="n">capacity_factor</span>
<span class="lineno">73</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_scale_prob</span> <span class="o">=</span> <span class="n">is_scale_prob</span>
<span class="lineno">74</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span> <span class="o">=</span> <span class="n">n_experts</span>
<span class="lineno">75</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop_tokens</span> <span class="o">=</span> <span class="n">drop_tokens</span></pre></div>
<div class="highlight"><pre><span class="lineno">71</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">72</span>
<span class="lineno">73</span> <span class="bp">self</span><span class="o">.</span><span class="n">capacity_factor</span> <span class="o">=</span> <span class="n">capacity_factor</span>
<span class="lineno">74</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_scale_prob</span> <span class="o">=</span> <span class="n">is_scale_prob</span>
<span class="lineno">75</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span> <span class="o">=</span> <span class="n">n_experts</span>
<span class="lineno">76</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop_tokens</span> <span class="o">=</span> <span class="n">drop_tokens</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
@ -167,7 +168,7 @@ discusses dropping tokens when routing is not balanced.</p>
<p>make copies of the FFNs</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">78</span> <span class="bp">self</span><span class="o">.</span><span class="n">experts</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">expert</span><span class="p">,</span> <span class="n">n_experts</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">79</span> <span class="bp">self</span><span class="o">.</span><span class="n">experts</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">expert</span><span class="p">,</span> <span class="n">n_experts</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
@ -178,8 +179,8 @@ discusses dropping tokens when routing is not balanced.</p>
<p>Routing layer and softmax</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">80</span> <span class="bp">self</span><span class="o">.</span><span class="n">switch</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_experts</span><span class="p">)</span>
<span class="lineno">81</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">81</span> <span class="bp">self</span><span class="o">.</span><span class="n">switch</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_experts</span><span class="p">)</span>
<span class="lineno">82</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
@ -192,7 +193,7 @@ discusses dropping tokens when routing is not balanced.</p>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">83</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">84</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
@ -203,7 +204,7 @@ discusses dropping tokens when routing is not balanced.</p>
<p>Capture the shape to change shapes later</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">89</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
<div class="highlight"><pre><span class="lineno">90</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
@ -214,7 +215,7 @@ discusses dropping tokens when routing is not balanced.</p>
<p>Flatten the sequence and batch dimensions</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">92</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
@ -228,7 +229,7 @@ where $N$ is the number of experts <code>n_experts</code> and
$h(\cdot)$ is the linear transformation of token embeddings.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">97</span> <span class="n">route_prob</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">switch</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">98</span> <span class="n">route_prob</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">switch</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
@ -240,7 +241,7 @@ $h(\cdot)$ is the linear transformation of token embeddings.</p>
We route to the expert with highest probability</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">101</span> <span class="n">route_prob_max</span><span class="p">,</span> <span class="n">routes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">route_prob</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">102</span> <span class="n">route_prob_max</span><span class="p">,</span> <span class="n">routes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">route_prob</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
@ -251,8 +252,8 @@ We route to the expert with highest probability</p>
<p>Scale the inputs to the experts by the routing probabilities</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">104</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_scale_prob</span><span class="p">:</span>
<span class="lineno">105</span> <span class="n">factor</span> <span class="o">=</span> <span class="n">route_prob_max</span></pre></div>
<div class="highlight"><pre><span class="lineno">105</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_scale_prob</span><span class="p">:</span>
<span class="lineno">106</span> <span class="n">factor</span> <span class="o">=</span> <span class="n">route_prob_max</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
@ -263,8 +264,8 @@ We route to the expert with highest probability</p>
<p>Don&rsquo;t scale the values but multiply by $\frac{p}{\hat{p}} = 1$ so that the gradients flow</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">107</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">108</span> <span class="n">factor</span> <span class="o">=</span> <span class="n">route_prob_max</span> <span class="o">/</span> <span class="n">route_prob_max</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">108</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">109</span> <span class="n">factor</span> <span class="o">=</span> <span class="n">route_prob_max</span> <span class="o">/</span> <span class="n">route_prob_max</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
@ -275,7 +276,7 @@ We route to the expert with highest probability</p>
<p>Multiply by the scaling factor</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">110</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">*</span> <span class="n">factor</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">111</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">*</span> <span class="n">factor</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
@ -286,7 +287,7 @@ We route to the expert with highest probability</p>
<p>Get indexes of tokens going to each expert</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">113</span> <span class="n">indexes_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">routes</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span><span class="o">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">as_tuple</span><span class="o">=</span><span class="kc">True</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">)]</span></pre></div>
<div class="highlight"><pre><span class="lineno">114</span> <span class="n">indexes_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">routes</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span><span class="o">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">as_tuple</span><span class="o">=</span><span class="kc">True</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
@ -297,7 +298,7 @@ We route to the expert with highest probability</p>
<p>Initialize an empty tensor to store outputs</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="n">final_output</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">117</span> <span class="n">final_output</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
@ -312,7 +313,7 @@ We route to the expert with highest probability</p>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">122</span> <span class="n">capacity</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">capacity_factor</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">123</span> <span class="n">capacity</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">capacity_factor</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
@ -323,7 +324,7 @@ We route to the expert with highest probability</p>
<p>Number of tokens routed to each expert.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">124</span> <span class="n">counts</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">new_tensor</span><span class="p">([</span><span class="nb">len</span><span class="p">(</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">)])</span></pre></div>
<div class="highlight"><pre><span class="lineno">125</span> <span class="n">counts</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">new_tensor</span><span class="p">([</span><span class="nb">len</span><span class="p">(</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">)])</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
@ -334,7 +335,7 @@ We route to the expert with highest probability</p>
<p>Initialize an empty list of dropped tokens</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">127</span> <span class="n">dropped</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
<div class="highlight"><pre><span class="lineno">128</span> <span class="n">dropped</span> <span class="o">=</span> <span class="p">[]</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
@ -345,7 +346,7 @@ We route to the expert with highest probability</p>
<p>Only drop tokens if <code>drop_tokens</code> is <code>True</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">129</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop_tokens</span><span class="p">:</span></pre></div>
<div class="highlight"><pre><span class="lineno">130</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop_tokens</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
@ -356,7 +357,7 @@ We route to the expert with highest probability</p>
<p>Drop tokens in each of the experts</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">131</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">132</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
@ -367,8 +368,8 @@ We route to the expert with highest probability</p>
<p>Ignore if the expert is not over capacity</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">133</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="o">&lt;=</span> <span class="n">capacity</span><span class="p">:</span>
<span class="lineno">134</span> <span class="k">continue</span></pre></div>
<div class="highlight"><pre><span class="lineno">134</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="o">&lt;=</span> <span class="n">capacity</span><span class="p">:</span>
<span class="lineno">135</span> <span class="k">continue</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
@ -379,7 +380,7 @@ We route to the expert with highest probability</p>
<p>Shuffle indexes before dropping</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">torch</span><span class="o">.</span><span class="n">randperm</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">]))]</span></pre></div>
<div class="highlight"><pre><span class="lineno">137</span> <span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">torch</span><span class="o">.</span><span class="n">randperm</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">]))]</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
@ -390,7 +391,7 @@ We route to the expert with highest probability</p>
<p>Collect the tokens over capacity as dropped tokens</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">138</span> <span class="n">dropped</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">capacity</span><span class="p">:])</span></pre></div>
<div class="highlight"><pre><span class="lineno">139</span> <span class="n">dropped</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">capacity</span><span class="p">:])</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
@ -401,7 +402,7 @@ We route to the expert with highest probability</p>
<p>Keep only the tokens upto the capacity of the expert</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">][:</span><span class="n">capacity</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">141</span> <span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">][:</span><span class="n">capacity</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
@ -412,7 +413,7 @@ We route to the expert with highest probability</p>
<p>Get outputs of the expert FFNs</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">143</span> <span class="n">route_outputs</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">experts</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">x</span><span class="p">[</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="p">:])</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">)]</span></pre></div>
<div class="highlight"><pre><span class="lineno">144</span> <span class="n">route_outputs</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">experts</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">x</span><span class="p">[</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="p">:])</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
@ -423,8 +424,8 @@ We route to the expert with highest probability</p>
<p>Assign to final output</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">146</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">):</span>
<span class="lineno">147</span> <span class="n">final_output</span><span class="p">[</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">route_outputs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">147</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_experts</span><span class="p">):</span>
<span class="lineno">148</span> <span class="n">final_output</span><span class="p">[</span><span class="n">indexes_list</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">route_outputs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
@ -435,9 +436,9 @@ We route to the expert with highest probability</p>
<p>Pass through the dropped tokens</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">150</span> <span class="k">if</span> <span class="n">dropped</span><span class="p">:</span>
<span class="lineno">151</span> <span class="n">dropped</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">dropped</span><span class="p">)</span>
<span class="lineno">152</span> <span class="n">final_output</span><span class="p">[</span><span class="n">dropped</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">dropped</span><span class="p">,</span> <span class="p">:]</span></pre></div>
<div class="highlight"><pre><span class="lineno">151</span> <span class="k">if</span> <span class="n">dropped</span><span class="p">:</span>
<span class="lineno">152</span> <span class="n">dropped</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">dropped</span><span class="p">)</span>
<span class="lineno">153</span> <span class="n">final_output</span><span class="p">[</span><span class="n">dropped</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">dropped</span><span class="p">,</span> <span class="p">:]</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
@ -448,7 +449,7 @@ We route to the expert with highest probability</p>
<p>Change the shape of the final output back to <code>[seq_len, batch_size, d_model]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">155</span> <span class="n">final_output</span> <span class="o">=</span> <span class="n">final_output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">156</span> <span class="n">final_output</span> <span class="o">=</span> <span class="n">final_output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
@ -464,7 +465,7 @@ We route to the expert with highest probability</p>
These are used for the load balancing loss and logging</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">163</span> <span class="k">return</span> <span class="n">final_output</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">dropped</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">164</span> <span class="k">return</span> <span class="n">final_output</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">dropped</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
@ -477,7 +478,7 @@ These are used for the load balancing loss and logging</p>
with handling extra outputs of switch feedforward module.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">166</span><span class="k">class</span> <span class="nc">SwitchTransformerLayer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">167</span><span class="k">class</span> <span class="nc">SwitchTransformerLayer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
@ -493,11 +494,11 @@ with handling extra outputs of switch feedforward module.</p>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">173</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">174</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">175</span> <span class="n">attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span><span class="p">,</span>
<span class="lineno">176</span> <span class="n">feed_forward</span><span class="p">:</span> <span class="n">SwitchFeedForward</span><span class="p">,</span>
<span class="lineno">177</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">174</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">175</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">176</span> <span class="n">attn</span><span class="p">:</span> <span class="n">MultiHeadAttention</span><span class="p">,</span>
<span class="lineno">177</span> <span class="n">feed_forward</span><span class="p">:</span> <span class="n">SwitchFeedForward</span><span class="p">,</span>
<span class="lineno">178</span> <span class="n">dropout_prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
@ -508,13 +509,13 @@ with handling extra outputs of switch feedforward module.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">184</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">185</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">d_model</span>
<span class="lineno">186</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span>
<span class="lineno">187</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span>
<span class="lineno">188</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span>
<span class="lineno">189</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
<span class="lineno">190</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span></pre></div>
<div class="highlight"><pre><span class="lineno">185</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">186</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">d_model</span>
<span class="lineno">187</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span>
<span class="lineno">188</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span> <span class="o">=</span> <span class="n">feed_forward</span>
<span class="lineno">189</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_prob</span><span class="p">)</span>
<span class="lineno">190</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span>
<span class="lineno">191</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
@ -525,9 +526,9 @@ with handling extra outputs of switch feedforward module.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">192</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">193</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">194</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">193</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">194</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="lineno">195</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
@ -538,7 +539,7 @@ with handling extra outputs of switch feedforward module.</p>
<p>Normalize the vectors before doing self attention</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">196</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_self_attn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
@ -549,7 +550,7 @@ with handling extra outputs of switch feedforward module.</p>
<p>Run through self attention, i.e. keys and values are from self</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">198</span> <span class="n">self_attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">199</span> <span class="n">self_attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">z</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
@ -560,7 +561,7 @@ with handling extra outputs of switch feedforward module.</p>
<p>Add the self attention results</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">200</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">self_attn</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">201</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">self_attn</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
@ -571,7 +572,7 @@ with handling extra outputs of switch feedforward module.</p>
<p>Normalize for feed-forward</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">203</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">204</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_ff</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
@ -582,7 +583,7 @@ with handling extra outputs of switch feedforward module.</p>
<p>Pass through the switching feed-forward network</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span> <span class="n">ff</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">206</span> <span class="n">ff</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feed_forward</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
@ -593,9 +594,9 @@ with handling extra outputs of switch feedforward module.</p>
<p>Add the feed-forward results back</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">207</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">ff</span><span class="p">)</span>
<span class="lineno">208</span>
<span class="lineno">209</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span></pre></div>
<div class="highlight"><pre><span class="lineno">208</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">ff</span><span class="p">)</span>
<span class="lineno">209</span>
<span class="lineno">210</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span></pre></div>
</div>
</div>
<div class='section' id='section-40'>
@ -606,7 +607,7 @@ with handling extra outputs of switch feedforward module.</p>
<h2>Switch Transformer</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">212</span><span class="k">class</span> <span class="nc">SwitchTransformer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">213</span><span class="k">class</span> <span class="nc">SwitchTransformer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
@ -617,8 +618,8 @@ with handling extra outputs of switch feedforward module.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">217</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">SwitchTransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">218</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
<div class="highlight"><pre><span class="lineno">218</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer</span><span class="p">:</span> <span class="n">SwitchTransformerLayer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">219</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
@ -629,7 +630,7 @@ with handling extra outputs of switch feedforward module.</p>
<p>Make copies of the transformer layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">220</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">221</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">clone_module_list</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">n_layers</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
@ -640,7 +641,7 @@ with handling extra outputs of switch feedforward module.</p>
<p>Final normalization layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">222</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
<div class="highlight"><pre><span class="lineno">223</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">layer</span><span class="o">.</span><span class="n">size</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
@ -651,7 +652,7 @@ with handling extra outputs of switch feedforward module.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">224</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">225</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
@ -662,12 +663,12 @@ with handling extra outputs of switch feedforward module.</p>
<p>Run through each transformer layer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">226</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[],</span> <span class="p">[]</span>
<span class="lineno">227</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
<span class="lineno">228</span> <span class="n">x</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">n_d</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="lineno">229</span> <span class="n">counts</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
<span class="lineno">230</span> <span class="n">route_prob</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p</span><span class="p">)</span>
<span class="lineno">231</span> <span class="n">n_dropped</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">n_d</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">227</span> <span class="n">counts</span><span class="p">,</span> <span class="n">route_prob</span><span class="p">,</span> <span class="n">n_dropped</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[],</span> <span class="p">[]</span>
<span class="lineno">228</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
<span class="lineno">229</span> <span class="n">x</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">n_d</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="lineno">230</span> <span class="n">counts</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
<span class="lineno">231</span> <span class="n">route_prob</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p</span><span class="p">)</span>
<span class="lineno">232</span> <span class="n">n_dropped</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">n_d</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
@ -678,7 +679,7 @@ with handling extra outputs of switch feedforward module.</p>
<p>Finally, normalize the vectors</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">233</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">234</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
@ -689,7 +690,7 @@ with handling extra outputs of switch feedforward module.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">235</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">counts</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">route_prob</span><span class="p">),</span> <span class="n">n_dropped</span></pre></div>
<div class="highlight"><pre><span class="lineno">236</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">counts</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">route_prob</span><span class="p">),</span> <span class="n">n_dropped</span></pre></div>
</div>
</div>
</div>

View File

@ -77,16 +77,17 @@
<a href="https://arxiv.org/abs/2101.03961">Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity</a>.
Our implementation only has a few million parameters and doesn&rsquo;t do model parallel distributed training.
It does single GPU training, but we implement the concept of switching as described in the paper.</p>
<p>The Switch Transformer uses different parameters for each token by switching among parameters,
based on the token. So only a fraction of parameters is chosen for each token, so you
can have more parameters but less computational cost.</p>
<p>The Switch Transformer uses different parameters for each token by switching among parameters
based on the token.
Therefore, only a fraction of parameters are chosen for each token.
So you can have more parameters but less computational cost.</p>
<p>The switching happens at the Position-wise Feedforward network (FFN) of each transformer block.
Position-wise feedforward network is a two sequentially fully connected layers.
In switch transformer we have multiple FFNs (multiple experts),
Position-wise feedforward network consists of two sequentially fully connected layers.
In switch transformer we have multiple FFNs (multiple experts),
and we chose which one to use based on a router.
The outputs a set of probabilities for picking a FFN,
and we pick the one with the highest probability and only evaluates that.
So essentially the computational cost is same as having a single FFN.
The output is a set of probabilities for picking a FFN,
and we pick the one with the highest probability and only evaluate that.
So essentially the computational cost is the same as having a single FFN.
In our implementation this doesn&rsquo;t parallelize well when you have many or large FFNs since it&rsquo;s all
happening on a single GPU.
In a distributed setup you would have each FFN (each very large) on a different device.</p>

View File

@ -81,7 +81,7 @@ equal to the length of the sequence trained in parallel.
All these positions have a fixed positional encoding.
Transformer XL increases this attention span by letting
each of the positions pay attention to precalculated past embeddings.
For instance if the context length is $l$ it will keep the embeddings of
For instance if the context length is $l$, it will keep the embeddings of
all layers for previous batch of length $l$ and feed them to current step.
If we use fixed-positional encodings these pre-calculated embeddings will have
the same positions as the current context.

View File

@ -11,13 +11,13 @@ network parameters during training.
For example, let's say there are two layers $l_1$ and $l_2$.
During the beginning of the training $l_1$ outputs (inputs to $l_2$)
could be in distribution $\mathcal{N}(0.5, 1)$.
Then, after some training steps, it could move to $\mathcal{N}(0.5, 1)$.
Then, after some training steps, it could move to $\mathcal{N}(0.6, 1.5)$.
This is *internal covariate shift*.
Internal covariate shift will adversely affect training speed because the later layers
($l_2$ in the above example) has to adapt to this shifted distribution.
($l_2$ in the above example) have to adapt to this shifted distribution.
By stabilizing the distribution batch normalization minimizes the internal covariate shift.
By stabilizing the distribution, batch normalization minimizes the internal covariate shift.
## Normalization
@ -30,10 +30,10 @@ and be uncorrelated.
Normalizing outside the gradient computation using pre-computed (detached)
means and variances doesn't work. For instance. (ignoring variance), let
$$\hat{x} = x - \mathbb{E}[x]$$
where $x = u + b$ and $b$ is a trained bias.
and $\mathbb{E}[x]$ is outside gradient computation (pre-computed constant).
where $x = u + b$ and $b$ is a trained bias
and $\mathbb{E}[x]$ is an outside gradient computation (pre-computed constant).
Note that $\hat{x}$ has no effect of $b$.
Note that $\hat{x}$ has no effect on $b$.
Therefore,
$b$ will increase or decrease based
$\frac{\partial{\mathcal{L}}}{\partial x}$,
@ -45,7 +45,7 @@ The paper notes that similar explosions happen with variances.
Whitening is computationally expensive because you need to de-correlate and
the gradients must flow through the full whitening calculation.
The paper introduces simplified version which they call *Batch Normalization*.
The paper introduces a simplified version which they call *Batch Normalization*.
First simplification is that it normalizes each feature independently to have
zero mean and unit variance:
$$\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}]}}$$
@ -53,7 +53,7 @@ where $x = (x^{(1)} ... x^{(d)})$ is the $d$-dimensional input.
The second simplification is to use estimates of mean $\mathbb{E}[x^{(k)}]$
and variance $Var[x^{(k)}]$ from the mini-batch
for normalization; instead of calculating the mean and variance across whole dataset.
for normalization; instead of calculating the mean and variance across the whole dataset.
Normalizing each feature to zero mean and unit variance could affect what the layer
can represent.
@ -69,8 +69,8 @@ like $Wu + b$ the bias parameter $b$ gets cancelled due to normalization.
So you can and should omit bias parameter in linear transforms right before the
batch normalization.
Batch normalization also makes the back propagation invariant to the scale of the weights.
And empirically it improves generalization, so it has regularization effects too.
Batch normalization also makes the back propagation invariant to the scale of the weights
and empirically it improves generalization, so it has regularization effects too.
## Inference
@ -81,8 +81,8 @@ and find the mean and variance, or you can use an estimate calculated during tra
The usual practice is to calculate an exponential moving average of
mean and variance during the training phase and use that for inference.
Here's [the training code](https://nn.labml.ai/normalization/layer_norm/mnist.html) and a notebook for training
a CNN classifier that use batch normalization for MNIST dataset.
Here's [the training code](mnist.html) and a notebook for training
a CNN classifier that uses batch normalization for MNIST dataset.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/normalization/batch_norm/mnist.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=011254fe647011ebbb8e0242ac1c0002)

View File

@ -13,8 +13,9 @@ Our implementation only has a few million parameters and doesn't do model parall
It does single GPU training, but we implement the concept of switching as described in the paper.
The Switch Transformer uses different parameters for each token by switching among parameters
based on the token. Thererfore, only a fraction of parameters are chosen for each token. So you
can have more parameters but less computational cost.
based on the token.
Therefore, only a fraction of parameters are chosen for each token.
So you can have more parameters but less computational cost.
The switching happens at the Position-wise Feedforward network (FFN) of each transformer block.
Position-wise feedforward network consists of two sequentially fully connected layers.

View File

@ -5,17 +5,18 @@ This is a miniature [PyTorch](https://pytorch.org) implementation of the paper
Our implementation only has a few million parameters and doesn't do model parallel distributed training.
It does single GPU training, but we implement the concept of switching as described in the paper.
The Switch Transformer uses different parameters for each token by switching among parameters,
based on the token. So only a fraction of parameters is chosen for each token, so you
can have more parameters but less computational cost.
The Switch Transformer uses different parameters for each token by switching among parameters
based on the token.
Therefore, only a fraction of parameters are chosen for each token.
So you can have more parameters but less computational cost.
The switching happens at the Position-wise Feedforward network (FFN) of each transformer block.
Position-wise feedforward network is a two sequentially fully connected layers.
In switch transformer we have multiple FFNs (multiple experts),
Position-wise feedforward network consists of two sequentially fully connected layers.
In switch transformer we have multiple FFNs (multiple experts),
and we chose which one to use based on a router.
The outputs a set of probabilities for picking a FFN,
and we pick the one with the highest probability and only evaluates that.
So essentially the computational cost is same as having a single FFN.
The output is a set of probabilities for picking a FFN,
and we pick the one with the highest probability and only evaluate that.
So essentially the computational cost is the same as having a single FFN.
In our implementation this doesn't parallelize well when you have many or large FFNs since it's all
happening on a single GPU.
In a distributed setup you would have each FFN (each very large) on a different device.

View File

@ -9,7 +9,7 @@ equal to the length of the sequence trained in parallel.
All these positions have a fixed positional encoding.
Transformer XL increases this attention span by letting
each of the positions pay attention to precalculated past embeddings.
For instance if the context length is $l$ it will keep the embeddings of
For instance if the context length is $l$, it will keep the embeddings of
all layers for previous batch of length $l$ and feed them to current step.
If we use fixed-positional encodings these pre-calculated embeddings will have
the same positions as the current context.