Varuna Jayasiri ed681b48b9 add docs
2021-06-07 15:03:34 +05:30


<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This is an annotated implementation/tutorial of Pay Attention to MLPs (gMLP) in PyTorch."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Pay Attention to MLPs (gMLP)"/>
<meta name="twitter:description" content="This is an annotated implementation/tutorial of Pay Attention to MLPs (gMLP) in PyTorch."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/transformers/gmlp/index.html"/>
<meta property="og:title" content="Pay Attention to MLPs (gMLP)"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="This is an annotated implementation/tutorial of Pay Attention to MLPs (gMLP) in PyTorch."/>
<title>Pay Attention to MLPs (gMLP)</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/transformers/gmlp/index.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/transformers/gmlp/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Pay Attention to MLPs (gMLP)</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the paper
<a href="https://papers.labml.ai/paper/2105.08050">Pay Attention to MLPs</a>.</p>
<p>This paper introduces a Multilayer Perceptron (MLP) based architecture with gating,
which they name <strong>gMLP</strong>. It consists of a stack of $L$ <em>gMLP</em> blocks.</p>
<p>Here is <a href="experiment.html">the training code</a> for an autoregressive model based on gMLP.</p>
<p><a href="https://app.labml.ai/run/01bd941ac74c11eb890c1d9196651a4a"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">21</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span>
<span class="lineno">22</span>
<span class="lineno">23</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>gMLP Block</h2>
<p>Each block applies the following transformations to the input embeddings
$X \in \mathbb{R}^{n \times d}$, where $n$ is the sequence length
and $d$ is the dimensionality of the embeddings:</p>
<p>
<script type="math/tex; mode=display">\begin{align}
Z &= \sigma(XU) \\
\tilde{Z} &= s(Z) \\
Y &= \tilde{Z}V \\
\end{align}</script>
</p>
<p>where $U$ and $V$ are learnable projection weights.
$s(\cdot)$ is the Spatial Gating Unit defined below.
The output dimensionality of $s(\cdot)$ is half that of $Z$.
$\sigma$ is an activation function such as
<a href="https://pytorch.org/docs/stable/generated/torch.nn.GELU.html">GeLU</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">27</span><span class="k">class</span> <span class="nc">GMLPBlock</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul>
<li><code>d_model</code> is the dimensionality ($d$) of $X$</li>
<li><code>d_ffn</code> is the dimensionality of $Z$</li>
<li><code>seq_len</code> is the length of the token sequence ($n$)</li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">48</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_ffn</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">54</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Normalization layer for Pre-Norm</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_model</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Activation function $\sigma$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">58</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Projection layer for $Z = \sigma(XU)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">60</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ffn</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Spatial Gating Unit $s(\cdot)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="bp">self</span><span class="o">.</span><span class="n">sgu</span> <span class="o">=</span> <span class="n">SpacialGatingUnit</span><span class="p">(</span><span class="n">d_ffn</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Projection layer for $Y = \tilde{Z}V$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">64</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_ffn</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Embedding size (required by <a href="../models.html#Encoder">Encoder</a>).
We use the encoder module from the transformer architecture and plug in the
<em>gMLP</em> block as a replacement for the <a href="../models.html#Encoder">Transformer Layer</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">d_model</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<ul>
<li><code>x</code> is the input embedding tensor $X$ of shape <code>[seq_len, batch_size, d_model]</code></li>
<li><code>mask</code> is a boolean mask of shape <code>[seq_len, seq_len, 1]</code> that controls the visibility of tokens
among each other.</li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Keep a copy for shortcut connection</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">77</span> <span class="n">shortcut</span> <span class="o">=</span> <span class="n">x</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Normalize $X$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Projection and activation $Z = \sigma(XU)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">81</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">proj1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Spatial Gating Unit $\tilde{Z} = s(Z)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">83</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sgu</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Final projection $Y = \tilde{Z}V$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">85</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj2</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Add the shortcut connection</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">88</span> <span class="k">return</span> <span class="n">z</span> <span class="o">+</span> <span class="n">shortcut</span></pre></div>
</div>
</div>
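<div class='section'>
<div class='docs'>
<p>As a rough illustration of the three block equations, the sketch below traces the tensor shapes
with plain tensor operations. It is not the implementation above: the helper name
<code>gmlp_block_sketch</code>, the random matrices <code>u</code> and <code>v</code>, and the simplified gate
(which assumes $W = 0$ and $b = 1$, so that $f_{W,b}(Z_2)$ is all ones, and skips the normalization inside
the gating unit) are assumptions made only for this example. The result has the same shape as the input,
<code>[seq_len, batch_size, d_model]</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch
import torch.nn.functional as F


def gmlp_block_sketch(x: torch.Tensor, u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: trace the shapes of one gMLP block (no bias terms)."""
    # Pre-norm over the embedding dimension
    x_norm = F.layer_norm(x, x.shape[-1:])
    # Z = sigma(XU), shape [seq_len, batch_size, d_ffn]
    z = F.gelu(x_norm @ u)
    # Split along the channel dimension; each half has d_ffn // 2 channels
    z1, z2 = torch.chunk(z, 2, dim=-1)
    # Simplified gate (assumption): W = 0 and b = 1, so f(Z_2) is all ones
    z_tilde = z1 * torch.ones_like(z2)
    # Y = Z~ V, shape [seq_len, batch_size, d_model], plus the shortcut
    return z_tilde @ v + x


seq_len, batch_size, d_model, d_ffn = 6, 2, 8, 32
x = torch.randn(seq_len, batch_size, d_model)
u = torch.randn(d_model, d_ffn) * 0.02
v = torch.randn(d_ffn // 2, d_model) * 0.02
y = gmlp_block_sketch(x, u, v)
print(y.shape)
```
</pre></div>
</div>
</div>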
<div class='section' id='section-17'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<h2>Spatial Gating Unit</h2>
<p>
<script type="math/tex; mode=display">s(Z) = Z_1 \odot f_{W,b}(Z_2)</script>
</p>
<p>where $f_{W,b}(Z) = W Z + b$ is a linear transformation along the sequence dimension,
and $\odot$ is element-wise multiplication.
$Z$ is split into two parts of equal size, $Z_1$ and $Z_2$, along the channel dimension (embedding dimension).</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span><span class="k">class</span> <span class="nc">SpacialGatingUnit</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<ul>
<li><code>d_z</code> is the dimensionality of $Z$</li>
<li><code>seq_len</code> is the sequence length</li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">101</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_z</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">106</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Normalization layer before applying $f_{W,b}(\cdot)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">108</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">([</span><span class="n">d_z</span> <span class="o">//</span> <span class="mi">2</span><span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Weight $W$ in $f_{W,b}(\cdot)$.</p>
<p>The paper notes that it&rsquo;s important to initialize the weights to small values and the bias to $1$,
so that at the start of training $s(\cdot)$ is close to the identity (apart from the split).</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">113</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">)</span><span class="o">.</span><span class="n">uniform_</span><span class="p">(</span><span class="o">-</span><span class="mf">0.01</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Bias $b$ in $f_{W,b}(\cdot)$</p>
<p>The paper notes that it&rsquo;s important to initialize bias to $1$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">117</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">seq_len</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
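<div class='section'>
<div class='docs'>
<p>To see why this initialization keeps $s(\cdot)$ close to the identity, here is a small numerical check.
It is a sketch under assumptions: the tensor names and sizes are made up for the example, the inputs are
random stand-ins for $Z_1$ and $Z_2$, and the normalization of $Z_2$ is omitted. With near-zero weights
and a bias of one, $f_{W,b}(Z_2)$ stays close to $1$, so $Z_1 \odot f_{W,b}(Z_2)$ stays close to $Z_1$.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch

torch.manual_seed(0)
seq_len, batch_size, d_half = 16, 4, 8

# Near-zero weights and a bias of one, mirroring the initialization described above
weight = torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01)
bias = torch.ones(seq_len)

# Stand-ins for the two halves Z_1 and Z_2 (random values, for illustration only)
z1 = torch.randn(seq_len, batch_size, d_half)
z2 = torch.randn(seq_len, batch_size, d_half)

# f_{W,b}(Z_2) = W Z_2 + b, applied along the sequence dimension
gate = torch.einsum('ij,jbd->ibd', weight, z2) + bias[:, None, None]

# The gate stays close to 1, so Z_1 * gate stays close to Z_1
max_gate_deviation = (gate - 1).abs().max().item()
max_output_deviation = (z1 * gate - z1).abs().max().item()
print(max_gate_deviation, max_output_deviation)
```
</pre></div>
</div>
</div>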
<div class='section' id='section-23'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<ul>
<li><code>z</code> is the input $Z$ of shape <code>[seq_len, batch_size, d_z]</code></li>
<li><code>mask</code> is a boolean mask of shape <code>[seq_len, seq_len, 1]</code> that controls the visibility of tokens
among each other. The last dimension, of size <code>1</code>, is the batch dimension; other transformer
implementations use it, and it is kept here for compatibility.</li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">119</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">z</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Get sequence length</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">128</span> <span class="n">seq_len</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Split $Z$ into $Z_1$ and $Z_2$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">130</span> <span class="n">z1</span><span class="p">,</span> <span class="n">z2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">chunk</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Check mask</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">133</span> <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p><code>mask</code> has shape <code>[seq_len_q, seq_len_k, batch_size]</code>.
The batch dimension should be of size <code>1</code> because this implementation supports
only the same mask for all samples in the batch.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">137</span> <span class="k">assert</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">seq_len</span>
<span class="lineno">138</span> <span class="k">assert</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">seq_len</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Here we only support the same mask for all samples</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">140</span> <span class="k">assert</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Remove the batch dimension</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">142</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Normalize $Z_2$ before $f_{W,b}(\cdot)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">145</span> <span class="n">z2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">z2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Get the weight matrix; truncate if larger than <code>seq_len</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">147</span> <span class="n">weight</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">[:</span><span class="n">seq_len</span><span class="p">,</span> <span class="p">:</span><span class="n">seq_len</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Apply mask to the weights.</p>
<p>If $W_{i,j}$ is $0$ then $f_{W,b}(Z_2)_i$ will not get any information
from token $j$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">152</span> <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">153</span> <span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span> <span class="o">*</span> <span class="n">mask</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>$f_{W,b}(Z_2) = W Z_2 + b$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="n">z2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;ij,jbd-&gt;ibd&#39;</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">z2</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">[:</span><span class="n">seq_len</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span></pre></div>
</div>
</div>
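<div class='section'>
<div class='docs'>
<p>The einsum above contracts over the sequence dimension rather than the channel dimension.
As a sanity check, the sketch below compares it with an explicit matrix multiplication for each
sample in the batch. The tensor names and sizes are assumptions made for the example.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch

torch.manual_seed(0)
seq_len, batch_size, d_half = 5, 3, 4
weight = torch.randn(seq_len, seq_len)
z2 = torch.randn(seq_len, batch_size, d_half)

# Same contraction as in the forward pass: sum over the source position j
out_einsum = torch.einsum('ij,jbd->ibd', weight, z2)

# Reference: for each sample, multiply W [seq_len, seq_len] with Z_2 [seq_len, d_half]
out_loop = torch.stack([weight @ z2[:, b, :] for b in range(batch_size)], dim=1)

print(torch.allclose(out_einsum, out_loop, atol=1e-5))
```
</pre></div>
</div>
</div>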
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>$Z_1 \odot f_{W,b}(Z_2)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">159</span> <span class="k">return</span> <span class="n">z1</span> <span class="o">*</span> <span class="n">z2</span></pre></div>
</div>
</div>
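<div class='section'>
<div class='docs'>
<p>As an illustration of how the mask restricts information flow, the following sketch builds a causal
(lower-triangular) mask in the <code>[seq_len, seq_len, 1]</code> layout described above and checks that
changing the last token does not change the gate at earlier positions. The function name
<code>gate_sketch</code> is hypothetical, and it leaves out the normalization and the multiplication by
$Z_1$ to keep the check focused on $f_{W,b}$.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import torch


def gate_sketch(weight: torch.Tensor, bias: torch.Tensor,
                z2: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: masked f_{W,b}(Z_2) along the sequence dimension."""
    # Drop the trailing batch dimension of size 1 and zero out hidden positions
    masked_weight = weight * mask[:, :, 0]
    return torch.einsum('ij,jbd->ibd', masked_weight, z2) + bias[:, None, None]


torch.manual_seed(0)
seq_len, batch_size, d_half = 6, 2, 4
weight = torch.randn(seq_len, seq_len)
bias = torch.ones(seq_len)

# Causal mask: position i may only see positions j up to and including i
mask = torch.tril(torch.ones(seq_len, seq_len)).bool().unsqueeze(-1)

z2 = torch.randn(seq_len, batch_size, d_half)
z2_changed = z2.clone()
z2_changed[-1] += 10.0  # perturb only the last token

out = gate_sketch(weight, bias, z2, mask)
out_changed = gate_sketch(weight, bias, z2_changed, mask)

# Earlier positions are unaffected; only the last position can differ
print(torch.allclose(out[:-1], out_changed[:-1]))
print(torch.allclose(out[-1], out_changed[-1]))
```
</pre></div>
</div>
</div>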
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
</body>
</html>