</div>
<h1>Graph Attention Networks v2 (GATv2)</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the GATv2 operator from the paper
<a href="https://arxiv.org/abs/2105.14491">How Attentive are Graph Attention Networks?</a>.</p>
<p>GATv2s work on graph data.
A graph consists of nodes and edges connecting nodes.
For example, in the Cora dataset the nodes are research papers and the edges are citations that
connect the papers.</p>
<p>The GATv2 operator fixes the static attention problem of the standard GAT:
since the linear layers in the standard GAT are applied right after each other, the ranking
of attended nodes is unconditioned on the query node.
In contrast, in GATv2, every node can attend to any other node.</p>
<p>Here is <a href="experiment.html">the training code</a> for training
a two-layer GATv2 on the Cora dataset.</p>
<p><a href="https://app.labml.ai/run/34b1e2f6ed6f11ebb860997901a2d1e3"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">23</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">25</span>
<span class="lineno">26</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
</div>
<h2>Graph attention v2 layer</h2>
<p>This is a single graph attention v2 layer.
A GATv2 is made up of multiple such layers.</p>
<p>It takes
<script type="math/tex; mode=display">\mathbf{h} = \{ \overrightarrow{h_1}, \overrightarrow{h_2}, \dots, \overrightarrow{h_N} \}</script>,
where $\overrightarrow{h_i} \in \mathbb{R}^F$, as input
and outputs
<script type="math/tex; mode=display">\mathbf{h'} = \{ \overrightarrow{h'_1}, \overrightarrow{h'_2}, \dots, \overrightarrow{h'_N} \}</script>,
where $\overrightarrow{h'_i} \in \mathbb{R}^{F'}$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">29</span><span class="k">class</span> <span class="nc">GraphAttentionV2Layer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">41</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_features</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_features</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">42</span>                 <span class="n">is_concat</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="lineno">43</span>                 <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.6</span><span class="p">,</span>
<span class="lineno">44</span>                 <span class="n">leaky_relu_negative_slope</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.2</span><span class="p">,</span>
<span class="lineno">45</span>                 <span class="n">share_weights</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">55</span>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">56</span>
<span class="lineno">57</span>        <span class="bp">self</span><span class="o">.</span><span class="n">is_concat</span> <span class="o">=</span> <span class="n">is_concat</span>
<span class="lineno">58</span>        <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">=</span> <span class="n">n_heads</span>
<span class="lineno">59</span>        <span class="bp">self</span><span class="o">.</span><span class="n">share_weights</span> <span class="o">=</span> <span class="n">share_weights</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<p>Calculate the number of dimensions per head</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span>        <span class="k">if</span> <span class="n">is_concat</span><span class="p">:</span>
<span class="lineno">63</span>            <span class="k">assert</span> <span class="n">out_features</span> <span class="o">%</span> <span class="n">n_heads</span> <span class="o">==</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<p>If we are concatenating the multiple heads</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span>            <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span> <span class="o">=</span> <span class="n">out_features</span> <span class="o">//</span> <span class="n">n_heads</span>
<span class="lineno">66</span>        <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<p>If we are averaging the multiple heads</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span>            <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span> <span class="o">=</span> <span class="n">out_features</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<p>Linear layer for initial source transformation;
i.e. to transform the source node embeddings before self-attention</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">72</span>        <span class="bp">self</span><span class="o">.</span><span class="n">linear_l</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span> <span class="o">*</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<p>If <code>share_weights</code> is <code>True</code>, the same linear layer is used for the target nodes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">74</span>        <span class="k">if</span> <span class="n">share_weights</span><span class="p">:</span>
<span class="lineno">75</span>            <span class="bp">self</span><span class="o">.</span><span class="n">linear_r</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_l</span>
<span class="lineno">76</span>        <span class="k">else</span><span class="p">:</span>
<span class="lineno">77</span>            <span class="bp">self</span><span class="o">.</span><span class="n">linear_r</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span> <span class="o">*</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<p>Linear layer to compute attention score $e_{ij}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">79</span>        <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<p>The activation for attention score $e_{ij}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">81</span>        <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="n">negative_slope</span><span class="o">=</span><span class="n">leaky_relu_negative_slope</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<p>Softmax to compute attention $\alpha_{ij}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">83</span>        <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<p>Dropout layer to be applied for attention</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">85</span>        <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<ul>
<li><code>h</code>, $\mathbf{h}$ is the input node embeddings of shape <code>[n_nodes, in_features]</code>.</li>
<li><code>adj_mat</code> is the adjacency matrix of shape <code>[n_nodes, n_nodes, n_heads]</code>.
We use shape <code>[n_nodes, n_nodes, 1]</code> since the adjacency is the same for each head.
The adjacency matrix represents the edges (or connections) among nodes.
<code>adj_mat[i][j]</code> is <code>True</code> if there is an edge from node <code>i</code> to node <code>j</code>;
see the sketch below.</li>
</ul>
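<p>For example, a hypothetical 3-node graph with edges <code>0 -> 1</code> and <code>1 -> 2</code>
(plus self-loops, so every node has at least one neighbor to attend to) could be built like this:</p>
<div class="highlight"><pre>import torch

# Boolean adjacency matrix for a 3-node chain with self-loops
adj_mat = torch.eye(3, dtype=torch.bool)
adj_mat[0, 1] = True  # edge from node 0 to node 1
adj_mat[1, 2] = True  # edge from node 1 to node 2
# Add a trailing dimension of size 1, shared by all heads: [n_nodes, n_nodes, 1]
adj_mat = adj_mat.unsqueeze(-1)
</pre></div>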
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">87</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">h</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">adj_mat</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<p>Number of nodes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">97</span>        <span class="n">n_nodes</span> <span class="o">=</span> <span class="n">h</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
We do two linear transformations and then split it up for each head.</p>
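<p>As a small shape sketch (with made-up sizes), this is what the transformation and the
per-head split produce:</p>
<div class="highlight"><pre>import torch
from torch import nn

# Hypothetical sizes: 3 nodes, 8 input features, 4 heads, 2 dimensions per head
n_nodes, in_features, n_heads, n_hidden = 3, 8, 4, 2
linear_l = nn.Linear(in_features, n_hidden * n_heads, bias=False)
h = torch.randn(n_nodes, in_features)
g_l = linear_l(h).view(n_nodes, n_heads, n_hidden)
print(g_l.shape)  # torch.Size([3, 4, 2]): one n_hidden vector per node, per head
</pre></div>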
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">103</span>        <span class="n">g_l</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_l</span><span class="p">(</span><span class="n">h</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span>
<span class="lineno">104</span>        <span class="n">g_r</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_r</span><span class="p">(</span><span class="n">h</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<p>$a$ is the attention mechanism that calculates the attention score.
The paper sums
$\overrightarrow{{g_l}_i}$ and $\overrightarrow{{g_r}_j}$,
followed by a $\text{LeakyReLU}$,
and does a linear transformation with a weight vector $\mathbf{a} \in \mathbb{R}^{F'}$</p>
<p>
<script type="math/tex; mode=display">e_{ij} = \mathbf{a}^\top \text{LeakyReLU} \Big(
\Big[
\overrightarrow{{g_l}_i} + \overrightarrow{{g_r}_j}
\Big] \Big)</script>
</p>
<p>Note: The paper describes $e_{ij}$ as <br />
<script type="math/tex; mode=display">e_{ij} = \mathbf{a}^\top \text{LeakyReLU} \Big( \mathbf{W}
\Big[
\overrightarrow{h_i} \Vert \overrightarrow{h_j}
\Big] \Big)</script>
which is equivalent to the definition we use here.</p>
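<p>A quick numerical check of that equivalence (a sketch with random weights; <code>w_l</code>
and <code>w_r</code> stand in for the weights of <code>linear_l</code> and <code>linear_r</code>,
and $\mathbf{W}$ is just the two stacked side by side):</p>
<div class="highlight"><pre>import torch

torch.manual_seed(0)
f_in, f_hidden = 5, 4
w_l = torch.randn(f_hidden, f_in)  # plays the role of linear_l's weight
w_r = torch.randn(f_hidden, f_in)  # plays the role of linear_r's weight
a = torch.randn(f_hidden)
h_i, h_j = torch.randn(f_in), torch.randn(f_in)

# Sum form used in this implementation
e_sum = a @ torch.nn.functional.leaky_relu(w_l @ h_i + w_r @ h_j)
# Concatenated form from the paper, with W = [w_l | w_r]
w = torch.cat([w_l, w_r], dim=1)
e_cat = a @ torch.nn.functional.leaky_relu(w @ torch.cat([h_i, h_j]))
assert torch.allclose(e_sum, e_cat)
</pre></div>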
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
where each node embedding is repeated <code>n_nodes</code> times.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">142</span>        <span class="n">g_l_repeat</span> <span class="o">=</span> <span class="n">g_l</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
where each node embedding is repeated <code>n_nodes</code> times.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">147</span>        <span class="n">g_r_repeat_interleave</span> <span class="o">=</span> <span class="n">g_r</span><span class="o">.</span><span class="n">repeat_interleave</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Now we add the two tensors to get
<script type="math/tex; mode=display">\{\overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_1},
\overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_2},
\dots, \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_N},
\dots\}</script>
</p>
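<p>A toy demonstration of how <code>repeat</code> and <code>repeat_interleave</code> line up to
enumerate all pairs (scalar "embeddings" for readability):</p>
<div class="highlight"><pre>import torch

g = torch.tensor([[1], [2], [3]])  # three 1-dimensional "node embeddings"
print(g.repeat(3, 1).flatten().tolist())                 # [1, 2, 3, 1, 2, 3, 1, 2, 3]
print(g.repeat_interleave(3, dim=0).flatten().tolist())  # [1, 1, 1, 2, 2, 2, 3, 3, 3]
# Adding the two aligned tensors therefore produces every pairwise sum exactly once
</pre></div>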
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">155</span>        <span class="n">g_sum</span> <span class="o">=</span> <span class="n">g_l_repeat</span> <span class="o">+</span> <span class="n">g_r_repeat_interleave</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<p>Reshape so that <code>g_sum[i, j]</code> is $\overrightarrow{{g_l}_i} + \overrightarrow{{g_r}_j}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">157</span>        <span class="n">g_sum</span> <span class="o">=</span> <span class="n">g_sum</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="n">n_nodes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<code>e</code> is of shape <code>[n_nodes, n_nodes, n_heads, 1]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">165</span>        <span class="n">e</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="n">g_sum</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<p>Remove the last dimension of size <code>1</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">167</span>        <span class="n">e</span> <span class="o">=</span> <span class="n">e</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<code>[n_nodes, n_nodes, n_heads]</code> or <code>[n_nodes, n_nodes, 1]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">171</span>        <span class="k">assert</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">n_nodes</span>
<span class="lineno">172</span>        <span class="k">assert</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">n_nodes</span>
<span class="lineno">173</span>        <span class="k">assert</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
$e_{ij}$ is set to $- \infty$ if there is no edge from $i$ to $j$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">176</span>        <span class="n">e</span> <span class="o">=</span> <span class="n">e</span><span class="o">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">adj_mat</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s1">'-inf'</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
makes $\exp(e_{ij}) \sim 0$ for unconnected pairs.</p>
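<p>A tiny sketch of the masking-then-softmax behavior (made-up scores for a single node
and head):</p>
<div class="highlight"><pre>import torch

e = torch.tensor([[0.5, 1.0, 2.0]])         # raw scores e_1j for nodes j = 1..3
mask = torch.tensor([[True, False, True]])  # no edge to the second node
e = e.masked_fill(mask == 0, float('-inf'))
print(torch.softmax(e, dim=1))              # tensor([[0.1824, 0.0000, 0.8176]])
</pre></div>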
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">186</span>        <span class="n">a</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">e</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<p>Apply dropout regularization</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">189</span>        <span class="n">a</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">a</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
</p>
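<p>The <code>einsum</code> below is a per-head weighted sum over <code>j</code> of the
<code>g_r</code> embeddings. A small sketch of what it computes, checked against an explicit
loop (made-up shapes):</p>
<div class="highlight"><pre>import torch

n_nodes, n_heads, n_hidden = 3, 2, 4
a = torch.softmax(torch.randn(n_nodes, n_nodes, n_heads), dim=1)
g_r = torch.randn(n_nodes, n_heads, n_hidden)
res = torch.einsum('ijh,jhf->ihf', a, g_r)

# The equivalent explicit loop: output i is the attention-weighted sum over j
expected = torch.zeros(n_nodes, n_heads, n_hidden)
for i in range(n_nodes):
    for j in range(n_nodes):
        expected[i] += a[i, j].unsqueeze(-1) * g_r[j]
assert torch.allclose(res, expected, atol=1e-6)
</pre></div>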
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">193</span>        <span class="n">attn_res</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">'ijh,jhf-&gt;ihf'</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">g_r</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<p>Concatenate the heads</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">196</span>        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_concat</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">198</span>            <span class="k">return</span> <span class="n">attn_res</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<p>Take the mean of the heads</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">200</span>        <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
</p>
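<p>Putting the two branches side by side, a toy shape check of concatenating versus
averaging the heads (sizes made up for the example):</p>
<div class="highlight"><pre>import torch

# Toy per-head results: 3 nodes, 2 heads, 4 dimensions per head
attn_res = torch.randn(3, 2, 4)
print(attn_res.reshape(3, 2 * 4).shape)  # torch.Size([3, 8]): heads concatenated
print(attn_res.mean(dim=1).shape)        # torch.Size([3, 4]): heads averaged
</pre></div>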
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">202</span>            <span class="k">return</span> <span class="n">attn_res</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
</div>