<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A PyTorch implementation/tutorial of Graph Attention Networks."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Graph Attention Networks (GAT)"/>
<meta name="twitter:description" content="A PyTorch implementation/tutorial of Graph Attention Networks."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/graphs/gat/index.html"/>
<meta property="og:title" content="Graph Attention Networks (GAT)"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Graph Attention Networks (GAT)"/>
<meta property="og:description" content="A PyTorch implementation/tutorial of Graph Attention Networks."/>
<title>Graph Attention Networks (GAT)</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/graphs/gat/index.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">graphs</a>
<a class="parent" href="index.html">gat</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/graphs/gat/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Graph Attention Networks (GAT)</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the paper
<a href="https://papers.labml.ai/paper/1710.10903">Graph Attention Networks</a>.</p>
<p>GATs work on graph data.
A graph consists of nodes and edges connecting nodes.
For example, in the Cora dataset the nodes are research papers and the edges are citations that
connect the papers.</p>
<p>GAT uses masked self-attention, similar in spirit to <a href="../../transformers/mha.html">transformers</a>.
A GAT consists of graph attention layers stacked on top of each other.
Each graph attention layer takes node embeddings as input and outputs transformed embeddings.
Each node embedding attends to the embeddings of the nodes it is connected to.
The details of the graph attention layers are included alongside the implementation.</p>
<p>Here is <a href="experiment.html">the training code</a> for training
a two-layer GAT on the Cora dataset.</p>
<p><a href="https://app.labml.ai/run/d6c636cadf3511eba2f1e707f612f95d"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">30</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">31</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">32</span>
<span class="lineno">33</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Graph attention layer</h2>
<p>This is a single graph attention layer.
A GAT is made up of multiple such layers.</p>
<p>It takes
<script type="math/tex; mode=display">\mathbf{h} = \{ \overrightarrow{h_1}, \overrightarrow{h_2}, \dots, \overrightarrow{h_N} \}</script>,
where $\overrightarrow{h_i} \in \mathbb{R}^F$ as input
and outputs
<script type="math/tex; mode=display">\mathbf{h'} = \{ \overrightarrow{h'_1}, \overrightarrow{h'_2}, \dots, \overrightarrow{h'_N} \}</script>,
where $\overrightarrow{h'_i} \in \mathbb{R}^{F'}$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">36</span><span class="k">class</span> <span class="nc">GraphAttentionLayer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul>
<li><code>in_features</code>, $F$, is the number of input features per node</li>
<li><code>out_features</code>, $F'$, is the number of output features per node</li>
<li><code>n_heads</code>, $K$, is the number of attention heads</li>
<li><code>is_concat</code> is whether the multi-head results should be concatenated or averaged (see the sketch after this list)</li>
<li><code>dropout</code> is the dropout probability</li>
<li><code>leaky_relu_negative_slope</code> is the negative slope for leaky relu activation</li>
</ul>
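<p>For instance, this is how the two modes might be configured; the sizes here are
illustrative (roughly those of a Cora-style setup) and not taken from this file.</p>
<div class="highlight"><pre># Concatenation: out_features must be divisible by n_heads;
# each of the 8 heads produces 64 // 8 = 8 dimensions
layer1 = GraphAttentionLayer(in_features=1433, out_features=64, n_heads=8, is_concat=True)

# Averaging: the single head produces all 7 output dimensions,
# and head outputs are averaged
layer2 = GraphAttentionLayer(in_features=64, out_features=7, n_heads=1, is_concat=False)</pre></div>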
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">50</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_features</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_features</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">51</span> <span class="n">is_concat</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="lineno">52</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.6</span><span class="p">,</span>
<span class="lineno">53</span> <span class="n">leaky_relu_negative_slope</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.2</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">62</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">63</span>
<span class="lineno">64</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_concat</span> <span class="o">=</span> <span class="n">is_concat</span>
<span class="lineno">65</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">=</span> <span class="n">n_heads</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Calculate the number of dimensions per head</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="k">if</span> <span class="n">is_concat</span><span class="p">:</span>
<span class="lineno">69</span> <span class="k">assert</span> <span class="n">out_features</span> <span class="o">%</span> <span class="n">n_heads</span> <span class="o">==</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>If we are concatenating the multiple heads</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">71</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span> <span class="o">=</span> <span class="n">out_features</span> <span class="o">//</span> <span class="n">n_heads</span>
<span class="lineno">72</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>If we are averaging the multiple heads</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">74</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span> <span class="o">=</span> <span class="n">out_features</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Linear layer for initial transformation;
i.e. to transform the node embeddings before self-attention</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">78</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span> <span class="o">*</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>Linear layer to compute attention score $e_{ij}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">80</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>The activation for attention score $e_{ij}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="n">negative_slope</span><span class="o">=</span><span class="n">leaky_relu_negative_slope</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Softmax to compute attention $\alpha_{ij}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">84</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Dropout layer to be applied for attention</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">86</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<ul>
<li><code>h</code>, $\mathbf{h}$, is the input node embeddings of shape <code>[n_nodes, in_features]</code>.</li>
<li><code>adj_mat</code> is the adjacency matrix of shape <code>[n_nodes, n_nodes, n_heads]</code>.
We use shape <code>[n_nodes, n_nodes, 1]</code> since the adjacency is the same for each head.</li>
</ul>
<p>The adjacency matrix represents the edges (or connections) among nodes.
<code>adj_mat[i][j]</code> is <code>True</code> if there is an edge from node <code>i</code> to node <code>j</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">88</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">h</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">adj_mat</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Number of nodes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">99</span> <span class="n">n_nodes</span> <span class="o">=</span> <span class="n">h</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>The initial transformation,
<script type="math/tex; mode=display">\overrightarrow{g^k_i} = \mathbf{W}^k \overrightarrow{h_i}</script>
for each head.
We apply a single linear transformation and then split the result up for each head.</p>
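<p>A short sketch of the shapes involved, with made-up sizes:</p>
<div class="highlight"><pre>import torch
from torch import nn

n_nodes, in_features, n_heads, n_hidden = 4, 8, 2, 8  # illustrative sizes
linear = nn.Linear(in_features, n_hidden * n_heads, bias=False)

h = torch.rand(n_nodes, in_features)            # [4, 8]
g = linear(h).view(n_nodes, n_heads, n_hidden)  # [4, 2, 8]: one vector per node per head</pre></div>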
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">104</span> <span class="n">g</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">h</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<h4>Calculate attention score</h4>
<p>We calculate these for each head $k$. <em>We have omitted $\cdot^k$ for simplicity</em>.</p>
<p>
<script type="math/tex; mode=display">e_{ij} = a(\mathbf{W} \overrightarrow{h_i}, \mathbf{W} \overrightarrow{h_j}) =
a(\overrightarrow{g_i}, \overrightarrow{g_j})</script>
</p>
<p>$e_{ij}$ is the attention score (importance) from node $j$ to node $i$.
We calculate this for each head.</p>
<p>$a$ is the attention mechanism, that calculates the attention score.
The paper concatenates
$\overrightarrow{g_i}$, $\overrightarrow{g_j}$
and does a linear transformation with a weight vector $\mathbf{a} \in \mathbb{R}^{2 F'}$
followed by a $\text{LeakyReLU}$.</p>
<p>
<script type="math/tex; mode=display">e_{ij} = \text{LeakyReLU} \Big(
\mathbf{a}^\top \Big[
\overrightarrow{g_i} \Vert \overrightarrow{g_j}
\Big] \Big)</script>
</p>
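<p>A sketch of this scoring for a single pair of nodes, using the same pieces the
layer builds in <code>__init__</code> (sizes made up):</p>
<div class="highlight"><pre>import torch
from torch import nn

n_hidden = 8  # illustrative
attn = nn.Linear(2 * n_hidden, 1, bias=False)  # plays the role of the weight vector a
activation = nn.LeakyReLU(negative_slope=0.2)

g_i, g_j = torch.rand(n_hidden), torch.rand(n_hidden)
e_ij = activation(attn(torch.cat([g_i, g_j])))  # one unnormalized attention score</pre></div>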
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>First we calculate
$\Big[\overrightarrow{g_i} \Vert \overrightarrow{g_j} \Big]$
for all pairs of $i, j$.</p>
<p><code>g_repeat</code> gets
<script type="math/tex; mode=display">\{\overrightarrow{g_1}, \overrightarrow{g_2}, \dots, \overrightarrow{g_N},
\overrightarrow{g_1}, \overrightarrow{g_2}, \dots, \overrightarrow{g_N}, ...\}</script>
where the whole sequence of node embeddings is repeated <code>n_nodes</code> times.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">135</span> <span class="n">g_repeat</span> <span class="o">=</span> <span class="n">g</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p><code>g_repeat_interleave</code> gets
<script type="math/tex; mode=display">\{\overrightarrow{g_1}, \overrightarrow{g_1}, \dots, \overrightarrow{g_1},
\overrightarrow{g_2}, \overrightarrow{g_2}, \dots, \overrightarrow{g_2}, ...\}</script>
where each node embedding is repeated <code>n_nodes</code> times.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">140</span> <span class="n">g_repeat_interleave</span> <span class="o">=</span> <span class="n">g</span><span class="o">.</span><span class="n">repeat_interleave</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>Now we concatenate to get
<script type="math/tex; mode=display">\{\overrightarrow{g_1} \Vert \overrightarrow{g_1},
\overrightarrow{g_1}, \Vert \overrightarrow{g_2},
\dots, \overrightarrow{g_1} \Vert \overrightarrow{g_N},
\overrightarrow{g_2} \Vert \overrightarrow{g_1},
\overrightarrow{g_2}, \Vert \overrightarrow{g_2},
\dots, \overrightarrow{g_2} \Vert \overrightarrow{g_N}, ...\}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">148</span> <span class="n">g_concat</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">g_repeat_interleave</span><span class="p">,</span> <span class="n">g_repeat</span><span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Reshape so that <code>g_concat[i, j]</code> is $\overrightarrow{g_i} \Vert \overrightarrow{g_j}$</p>
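<p>To see the pairing concretely, here is a sketch with a single head and scalar
embeddings (made-up values):</p>
<div class="highlight"><pre>import torch

# 3 nodes, 1 head, 1 feature, so each embedding is a single number
g = torch.tensor([1., 2., 3.]).view(3, 1, 1)

g_repeat = g.repeat(3, 1, 1)                         # 1, 2, 3, 1, 2, 3, 1, 2, 3
g_repeat_interleave = g.repeat_interleave(3, dim=0)  # 1, 1, 1, 2, 2, 2, 3, 3, 3
g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1).view(3, 3, 1, 2)

print(g_concat[0, 2, 0])  # tensor([1., 3.]), i.e. g_0 concatenated with g_2</pre></div>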
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">150</span> <span class="n">g_concat</span> <span class="o">=</span> <span class="n">g_concat</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="n">n_nodes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Calculate
<script type="math/tex; mode=display">e_{ij} = \text{LeakyReLU} \Big(
\mathbf{a}^\top \Big[
\overrightarrow{g_i} \Vert \overrightarrow{g_j}
\Big] \Big)</script>
<code>e</code> is of shape <code>[n_nodes, n_nodes, n_heads, 1]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">158</span> <span class="n">e</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="n">g_concat</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Remove the last dimension of size <code>1</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">160</span> <span class="n">e</span> <span class="o">=</span> <span class="n">e</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>The adjacency matrix should have shape
<code>[n_nodes, n_nodes, n_heads]</code> or <code>[n_nodes, n_nodes, 1]</code></p>
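<p>One way such a matrix might be built from an edge list (a sketch; self-loops are
added since each node should also attend to itself):</p>
<div class="highlight"><pre>import torch

n_nodes = 4
edges = [(0, 1), (1, 2), (2, 3)]  # hypothetical edge list

adj_mat = torch.zeros(n_nodes, n_nodes, 1, dtype=torch.bool)
for i, j in edges:
    adj_mat[i, j, 0] = True
    adj_mat[j, i, 0] = True  # assuming an undirected graph
adj_mat |= torch.eye(n_nodes, dtype=torch.bool).unsqueeze(-1)  # self-loops</pre></div>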
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">164</span> <span class="k">assert</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">n_nodes</span>
<span class="lineno">165</span> <span class="k">assert</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">n_nodes</span>
<span class="lineno">166</span> <span class="k">assert</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Mask $e_{ij}$ based on adjacency matrix.
$e_{ij}$ is set to $- \infty$ if there is no edge from $i$ to $j$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">169</span> <span class="n">e</span> <span class="o">=</span> <span class="n">e</span><span class="o">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">adj_mat</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>We then normalize attention scores (or coefficients)
<script type="math/tex; mode=display">\alpha_{ij} = \text{softmax}_j(e_{ij}) =
\frac{\exp(e_{ij})}{\sum_{j \in \mathcal{N}_i} \exp(e_{ij})}</script>
</p>
<p>where $\mathcal{N}_i$ is the set of nodes connected to $i$.</p>
<p>We do this by setting unconnected $e_{ij}$ to $- \infty$, which
makes $\exp(e_{ij}) \approx 0$ for unconnected pairs.</p>
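<p>A small numeric check of this mask-then-softmax, with one head and illustrative scores:</p>
<div class="highlight"><pre>import torch

e = torch.tensor([[1.0, 2.0, 4.0]])
mask = torch.tensor([[True, True, False]])  # the third node is not a neighbor

e = e.masked_fill(~mask, float('-inf'))
a = torch.softmax(e, dim=1)
print(a)  # tensor([[0.2689, 0.7311, 0.0000]]): weights over neighbors sum to 1</pre></div>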
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">179</span> <span class="n">a</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">e</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Apply dropout regularization</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">182</span> <span class="n">a</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">a</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>Calculate final output for each head
<script type="math/tex; mode=display">\overrightarrow{h'^k_i} = \sum_{j \in \mathcal{N}_i} \alpha^k_{ij} \overrightarrow{g^k_j}</script>
</p>
<p><em>Note:</em> The paper includes the final activation $\sigma$ in $\overrightarrow{h'^k_i}$.
We have omitted it from the graph attention layer implementation
and apply it in the GAT model instead, to match how other PyTorch modules are
defined, with the activation as a separate layer.</p>
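<p>The <code>einsum</code> on the next line is a per-head weighted sum over source nodes.
An equivalent (much slower) spelled-out version, as a sketch with made-up sizes:</p>
<div class="highlight"><pre>import torch

n_nodes, n_heads, n_hidden = 4, 2, 8        # illustrative sizes
a = torch.rand(n_nodes, n_nodes, n_heads)   # attention weights
g = torch.rand(n_nodes, n_heads, n_hidden)  # transformed node embeddings

# attn_res[i, k] is the attention-weighted sum of g[j, k] over source nodes j
attn_res = torch.zeros(n_nodes, n_heads, n_hidden)
for i in range(n_nodes):
    for k in range(n_heads):
        for j in range(n_nodes):
            attn_res[i, k] += a[i, j, k] * g[j, k]

assert torch.allclose(attn_res, torch.einsum('ijh,jhf->ihf', a, g))</pre></div>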
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">191</span> <span class="n">attn_res</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;ijh,jhf-&gt;ihf&#39;</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">g</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Concatenate the heads</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">194</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_concat</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>
<script type="math/tex; mode=display">\overrightarrow{h'_i} = \Bigg\Vert_{k=1}^{K} \overrightarrow{h'^k_i}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">196</span> <span class="k">return</span> <span class="n">attn_res</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Take the mean of the heads</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">198</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>
<script type="math/tex; mode=display">\overrightarrow{h'_i} = \frac{1}{K} \sum_{k=1}^{K} \overrightarrow{h'^k_i}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">200</span> <span class="k">return</span> <span class="n">attn_res</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>