gatv2 docs

Varuna Jayasiri
2021-07-26 13:56:03 +05:30
parent 671a93c299
commit f22853f610
7 changed files with 88 additions and 83 deletions

View File

@ -7,20 +7,20 @@
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Distilling the Knowledge in a Neural Network)"/>
<meta name="twitter:title" content="Distilling the Knowledge in a Neural Network"/>
<meta name="twitter:description" content=""/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/distillation/readme.html"/>
<meta property="og:title" content="Distilling the Knowledge in a Neural Network)"/>
<meta property="og:title" content="Distilling the Knowledge in a Neural Network"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Distilling the Knowledge in a Neural Network)"/>
<meta property="og:title" content="Distilling the Knowledge in a Neural Network"/>
<meta property="og:description" content=""/>
<title>Distilling the Knowledge in a Neural Network)</title>
<title>Distilling the Knowledge in a Neural Network</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/distillation/readme.html"/>
@ -66,7 +66,7 @@
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1><a href="(https://nn.labml.ai/distillation/index.html)">Distilling the Knowledge in a Neural Network</a></h1>
<h1><a href="https://nn.labml.ai/distillation/index.html">Distilling the Knowledge in a Neural Network</a></h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation/tutorial of the paper
<a href="https://papers.labml.ai/paper/1503.02531">Distilling the Knowledge in a Neural Network</a>.</p>
<p>It&rsquo;s a way of training a small network using the knowledge in a trained larger network;

View File

@ -68,7 +68,7 @@
<a href='#section-0'>#</a>
</div>
<h1>Train a Graph Attention Network v2 (GATv2) on Cora dataset</h1>
<p><a href="https://app.labml.ai/run/8e27ad82ed2611ebabb691fb2028a868"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
<p><a href="https://app.labml.ai/run/34b1e2f6ed6f11ebb860997901a2d1e3"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">13</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span>
@ -609,7 +609,7 @@ from $j$ to $i$.</p>
<p>Dropout probability</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">181</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.6</span></pre></div>
<div class="highlight"><pre><span class="lineno">181</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.7</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>

View File

@ -69,24 +69,24 @@
</div>
<h1>Graph Attention Networks v2 (GATv2)</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the GATv2 operator from the paper
<a href="https://arxiv.org/abs/2105.14491">How Attentive are Graph Attention Networks?</a>.</p>
<p>GATv2s work on graph data.
<a href="https://arxiv.org/abs/2105.14491">How Attentive are Graph Attention Networks?</a>.
GATv2s work on graph data.
A graph consists of nodes and edges connecting nodes.
For example, in Cora dataset the nodes are research papers and the edges are citations that
connect the papers.</p>
<p>The GATv2 operator which fixes the static attention problem of the standard GAT:
connect the papers.
The GATv2 operator fixes the static attention problem of the standard GAT:
since the linear layers in the standard GAT are applied right after each other, the ranking
of attended nodes is unconditioned on the query node.
In contrast, in GATv2, every node can attend to any other node.</p>
<p>Here is <a href="experiment.html">the training code</a> for training
a two-layer GATv2 on Cora dataset.</p>
<p><a href="https://app.labml.ai/run/8e27ad82ed2611ebabb691fb2028a868"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
In contrast, in GATv2, every node can attend to any other node.
Here is <a href="experiment.html">the training code</a> for training
a two-layer GATv2 on Cora dataset.
<a href="https://app.labml.ai/run/34b1e2f6ed6f11ebb860997901a2d1e3"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">29</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">30</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">31</span>
<span class="lineno">32</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
<div class="highlight"><pre><span class="lineno">23</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">24</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">25</span>
<span class="lineno">26</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
@ -96,8 +96,8 @@ a two-layer GATv2 on Cora dataset.</p>
</div>
<h2>Graph attention v2 layer</h2>
<p>This is a single graph attention v2 layer.
A GATv2 is made up of multiple such layers.</p>
<p>It takes
A GATv2 is made up of multiple such layers.
It takes
<script type="math/tex; mode=display">\mathbf{h} = \{ \overrightarrow{h_1}, \overrightarrow{h_2}, \dots, \overrightarrow{h_N} \}</script>,
where $\overrightarrow{h_i} \in \mathbb{R}^F$ as input
and outputs
@ -105,7 +105,7 @@ and outputs
where $\overrightarrow{h&rsquo;_i} \in \mathbb{R}^{F&rsquo;}$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">35</span><span class="k">class</span> <span class="nc">GraphAttentionV2Layer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">29</span><span class="k">class</span> <span class="nc">GraphAttentionV2Layer</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
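As a shape-level sketch of how this layer is used (assuming it is importable as labml_nn.graphs.gatv2.GraphAttentionV2Layer, and using the constructor and call signatures shown in the sections below):

    import torch
    from labml_nn.graphs.gatv2 import GraphAttentionV2Layer  # assumed module path

    n_nodes, in_features = 4, 8
    h = torch.rand(n_nodes, in_features)                          # input node embeddings h
    adj_mat = torch.eye(n_nodes, dtype=torch.bool).unsqueeze(-1)  # [n_nodes, n_nodes, 1] mask (self-loops only)
    layer = GraphAttentionV2Layer(in_features=in_features, out_features=16,
                                  n_heads=4, is_concat=True, dropout=0.6)
    h_prime = layer(h, adj_mat)   # [n_nodes, 16]: four 4-dimensional heads, concatenated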
<div class='section' id='section-2'>
@ -124,11 +124,11 @@ where $\overrightarrow{h&rsquo;_i} \in \mathbb{R}^{F&rsquo;}$.</p>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_features</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_features</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">50</span> <span class="n">is_concat</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="lineno">51</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.6</span><span class="p">,</span>
<span class="lineno">52</span> <span class="n">leaky_relu_negative_slope</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.2</span><span class="p">,</span>
<span class="lineno">53</span> <span class="n">share_weights</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">41</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_features</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_features</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">42</span> <span class="n">is_concat</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="lineno">43</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.6</span><span class="p">,</span>
<span class="lineno">44</span> <span class="n">leaky_relu_negative_slope</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.2</span><span class="p">,</span>
<span class="lineno">45</span> <span class="n">share_weights</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
@ -139,11 +139,11 @@ where $\overrightarrow{h&rsquo;_i} \in \mathbb{R}^{F&rsquo;}$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">63</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">64</span>
<span class="lineno">65</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_concat</span> <span class="o">=</span> <span class="n">is_concat</span>
<span class="lineno">66</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">=</span> <span class="n">n_heads</span>
<span class="lineno">67</span> <span class="bp">self</span><span class="o">.</span><span class="n">share_weights</span> <span class="o">=</span> <span class="n">share_weights</span></pre></div>
<div class="highlight"><pre><span class="lineno">55</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">56</span>
<span class="lineno">57</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_concat</span> <span class="o">=</span> <span class="n">is_concat</span>
<span class="lineno">58</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">=</span> <span class="n">n_heads</span>
<span class="lineno">59</span> <span class="bp">self</span><span class="o">.</span><span class="n">share_weights</span> <span class="o">=</span> <span class="n">share_weights</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
@ -154,8 +154,8 @@ where $\overrightarrow{h&rsquo;_i} \in \mathbb{R}^{F&rsquo;}$.</p>
<p>Calculate the number of dimensions per head</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">70</span> <span class="k">if</span> <span class="n">is_concat</span><span class="p">:</span>
<span class="lineno">71</span> <span class="k">assert</span> <span class="n">out_features</span> <span class="o">%</span> <span class="n">n_heads</span> <span class="o">==</span> <span class="mi">0</span></pre></div>
<div class="highlight"><pre><span class="lineno">62</span> <span class="k">if</span> <span class="n">is_concat</span><span class="p">:</span>
<span class="lineno">63</span> <span class="k">assert</span> <span class="n">out_features</span> <span class="o">%</span> <span class="n">n_heads</span> <span class="o">==</span> <span class="mi">0</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
@ -166,8 +166,8 @@ where $\overrightarrow{h&rsquo;_i} \in \mathbb{R}^{F&rsquo;}$.</p>
<p>If we are concatenating the multiple heads</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">73</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span> <span class="o">=</span> <span class="n">out_features</span> <span class="o">//</span> <span class="n">n_heads</span>
<span class="lineno">74</span> <span class="k">else</span><span class="p">:</span></pre></div>
<div class="highlight"><pre><span class="lineno">65</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span> <span class="o">=</span> <span class="n">out_features</span> <span class="o">//</span> <span class="n">n_heads</span>
<span class="lineno">66</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
@ -178,7 +178,7 @@ where $\overrightarrow{h&rsquo;_i} \in \mathbb{R}^{F&rsquo;}$.</p>
<p>If we are averaging the multiple heads</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span> <span class="o">=</span> <span class="n">out_features</span></pre></div>
<div class="highlight"><pre><span class="lineno">68</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span> <span class="o">=</span> <span class="n">out_features</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
@ -190,7 +190,7 @@ where $\overrightarrow{h&rsquo;_i} \in \mathbb{R}^{F&rsquo;}$.</p>
i.e. to transform the source node embeddings before self-attention</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">80</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_l</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span> <span class="o">*</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">72</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_l</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span> <span class="o">*</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
@ -201,10 +201,10 @@ i.e. to transform the source node embeddings before self-attention</p>
<p>If <code>share_weights is True</code> the same linear layer is used for the target nodes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span> <span class="k">if</span> <span class="n">share_weights</span><span class="p">:</span>
<span class="lineno">83</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_r</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_l</span>
<span class="lineno">84</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">85</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_r</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">heads</span> <span class="o">*</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">74</span> <span class="k">if</span> <span class="n">share_weights</span><span class="p">:</span>
<span class="lineno">75</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_r</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_l</span>
<span class="lineno">76</span> <span class="k">else</span><span class="p">:</span>
<span class="lineno">77</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_r</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">heads</span> <span class="o">*</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
@ -215,7 +215,7 @@ i.e. to transform the source node embeddings before self-attention</p>
<p>Linear layer to compute attention score $e_{ij}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">87</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">79</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
@ -226,7 +226,7 @@ i.e. to transform the source node embeddings before self-attention</p>
<p>The activation for attention score $e_{ij}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">89</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="n">negative_slope</span><span class="o">=</span><span class="n">leaky_relu_negative_slope</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">81</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="n">negative_slope</span><span class="o">=</span><span class="n">leaky_relu_negative_slope</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
@ -237,7 +237,7 @@ i.e. to transform the source node embeddings before self-attention</p>
<p>Softmax to compute attention $\alpha_{ij}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">91</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">83</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
@ -248,7 +248,7 @@ i.e. to transform the source node embeddings before self-attention</p>
<p>Dropout layer to be applied for attention</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">93</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">85</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
@ -259,13 +259,13 @@ i.e. to transform the source node embeddings before self-attention</p>
<ul>
<li><code>h</code>, $\mathbf{h}$ is the input node embeddings of shape <code>[n_nodes, in_features]</code>.</li>
<li><code>adj_mat</code> is the adjacency matrix of shape <code>[n_nodes, n_nodes, n_heads]</code>.
We use shape <code>[n_nodes, n_nodes, 1]</code> since the adjacency is the same for each head.</li>
We use shape <code>[n_nodes, n_nodes, 1]</code> since the adjacency is the same for each head.
The adjacency matrix represents the edges (or connections) among nodes.
<code>adj_mat[i][j]</code> is <code>True</code> if there is an edge from node <code>i</code> to node <code>j</code>.</li>
</ul>
<p>Adjacency matrix represent the edges (or connections) among nodes.
<code>adj_mat[i][j]</code> is <code>True</code> if there is an edge from node <code>i</code> to node <code>j</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">95</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">h</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">adj_mat</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
<div class="highlight"><pre><span class="lineno">87</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">h</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">adj_mat</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
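For intuition, one way the adjacency tensor in this convention could be built from an edge list (a sketch with made-up edges; self-loops are added so every node has at least one connection):

    import torch

    n_nodes = 4
    edges = [(0, 1), (1, 2), (2, 0), (2, 3)]        # hypothetical directed edges
    adj_mat = torch.eye(n_nodes, dtype=torch.bool)  # self-loops
    for i, j in edges:
        adj_mat[i, j] = True                        # adj_mat[i][j] is True if there is an edge from i to j
    adj_mat = adj_mat.unsqueeze(-1)                 # [n_nodes, n_nodes, 1]: the same mask is shared by all heads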
<div class='section' id='section-14'>
@ -276,7 +276,7 @@ We use shape <code>[n_nodes, n_nodes, 1]</code> since the adjacency is the same
<p>Number of nodes</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">106</span> <span class="n">n_nodes</span> <span class="o">=</span> <span class="n">h</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
<div class="highlight"><pre><span class="lineno">97</span> <span class="n">n_nodes</span> <span class="o">=</span> <span class="n">h</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
@ -291,8 +291,8 @@ for each head.
We do two linear transformations and then split it up for each head.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">112</span> <span class="n">g_l</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_l</span><span class="p">(</span><span class="n">h</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span>
<span class="lineno">113</span> <span class="n">g_r</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_r</span><span class="p">(</span><span class="n">h</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">103</span> <span class="n">g_l</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_l</span><span class="p">(</span><span class="n">h</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span>
<span class="lineno">104</span> <span class="n">g_r</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_r</span><span class="p">(</span><span class="n">h</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
@ -311,14 +311,19 @@ We calculate this for each head.</p>
<p>$a$ is the attention mechanism that calculates the attention score.
The paper sums
$\overrightarrow{{g_l}_i}$, $\overrightarrow{{g_r}_j}$
followed by a $\text{LeakyReLU}$
followed by a $\text{LeakyReLU}$
and does a linear transformation with a weight vector $\mathbf{a} \in \mathbb{R}^{F&rsquo;}$</p>
<p>
<script type="math/tex; mode=display">e_{ij} = \mathbf{a}^\top \text{LeakyReLU} \Big(
\Big[
\overrightarrow{{g_l}_i} + \overrightarrow{{g_r}_j}
\Big] \Big)</script>
</p>
Note: The paper describes $e_{ij}$ as <br />
<script type="math/tex; mode=display">e_{ij} = \mathbf{a}^\top \text{LeakyReLU} \Big( \mathbf{W}
\Big[
\overrightarrow{h_i} \Vert \overrightarrow{h_j}
\Big] \Big)</script>
which is equivalent to the definition we use here.</p>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
@ -338,7 +343,7 @@ for all pairs of $i, j$.</p>
where each node embedding is repeated <code>n_nodes</code> times.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">145</span> <span class="n">g_l_repeat</span> <span class="o">=</span> <span class="n">g_l</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">142</span> <span class="n">g_l_repeat</span> <span class="o">=</span> <span class="n">g_l</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
@ -352,7 +357,7 @@ where each node embedding is repeated <code>n_nodes</code> times.</p>
where each node embedding is repeated <code>n_nodes</code> times.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">150</span> <span class="n">g_r_repeat_interleave</span> <span class="o">=</span> <span class="n">g_r</span><span class="o">.</span><span class="n">repeat_interleave</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">147</span> <span class="n">g_r_repeat_interleave</span> <span class="o">=</span> <span class="n">g_r</span><span class="o">.</span><span class="n">repeat_interleave</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
@ -360,7 +365,7 @@ where each node embedding is repeated <code>n_nodes</code> times.</p>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Now we sum to get
<p>Now we add the two tensors to get
<script type="math/tex; mode=display">\{\overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_1},
\overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_2},
\dots, \overrightarrow{{g_l}_1} +\overrightarrow{{g_r}_N},
@ -370,7 +375,7 @@ where each node embedding is repeated <code>n_nodes</code> times.</p>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">158</span> <span class="n">g_sum</span> <span class="o">=</span> <span class="n">g_l_repeat</span> <span class="o">+</span> <span class="n">g_r_repeat_interleave</span></pre></div>
<div class="highlight"><pre><span class="lineno">155</span> <span class="n">g_sum</span> <span class="o">=</span> <span class="n">g_l_repeat</span> <span class="o">+</span> <span class="n">g_r_repeat_interleave</span></pre></div>
</div>
</div>
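To see the two repetition patterns side by side, a toy sketch with three one-dimensional "embeddings" (the real tensors also carry head and feature dimensions):

    import torch

    g = torch.arange(3).unsqueeze(-1)                        # pretend 3 node embeddings: [[0], [1], [2]]
    print(g.repeat(3, 1).flatten().tolist())                 # [0, 1, 2, 0, 1, 2, 0, 1, 2]
    print(g.repeat_interleave(3, dim=0).flatten().tolist())  # [0, 0, 0, 1, 1, 1, 2, 2, 2]
    # Adding the two sequences and reshaping to [3, 3] pairs every embedding
    # with every other embedding, which is what g_sum holds per head.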
<div class='section' id='section-20'>
@ -381,7 +386,7 @@ where each node embedding is repeated <code>n_nodes</code> times.</p>
<p>Reshape so that <code>g_sum[i, j]</code> is $\overrightarrow{{g_l}_i} + \overrightarrow{{g_r}_j}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">160</span> <span class="n">g_sum</span> <span class="o">=</span> <span class="n">g_sum</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="n">n_nodes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">157</span> <span class="n">g_sum</span> <span class="o">=</span> <span class="n">g_sum</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="n">n_nodes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
@ -397,7 +402,7 @@ where each node embedding is repeated <code>n_nodes</code> times.</p>
<code>e</code> is of shape <code>[n_nodes, n_nodes, n_heads, 1]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span> <span class="n">e</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="n">g_sum</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">165</span> <span class="n">e</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="n">g_sum</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
@ -408,7 +413,7 @@ where each node embedding is repeated <code>n_nodes</code> times.</p>
<p>Remove the last dimension of size <code>1</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">170</span> <span class="n">e</span> <span class="o">=</span> <span class="n">e</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">167</span> <span class="n">e</span> <span class="o">=</span> <span class="n">e</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
@ -420,9 +425,9 @@ where each node embedding is repeated <code>n_nodes</code> times.</p>
<code>[n_nodes, n_nodes, n_heads]</code> or <code>[n_nodes, n_nodes, 1]</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">174</span> <span class="k">assert</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">n_nodes</span>
<span class="lineno">175</span> <span class="k">assert</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">n_nodes</span>
<span class="lineno">176</span> <span class="k">assert</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span></pre></div>
<div class="highlight"><pre><span class="lineno">171</span> <span class="k">assert</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">n_nodes</span>
<span class="lineno">172</span> <span class="k">assert</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">n_nodes</span>
<span class="lineno">173</span> <span class="k">assert</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">adj_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
@ -434,7 +439,7 @@ where each node embedding is repeated <code>n_nodes</code> times.</p>
$e_{ij}$ is set to $- \infty$ if there is no edge from $i$ to $j$.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">179</span> <span class="n">e</span> <span class="o">=</span> <span class="n">e</span><span class="o">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">adj_mat</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">))</span></pre></div>
<div class="highlight"><pre><span class="lineno">176</span> <span class="n">e</span> <span class="o">=</span> <span class="n">e</span><span class="o">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">adj_mat</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
@ -451,7 +456,7 @@ $e_{ij}$ is set to $- \infty$ if there is no edge from $i$ to $j$.</p>
makes $\exp(e_{ij}) \sim 0$ for unconnected pairs.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">189</span> <span class="n">a</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">e</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">186</span> <span class="n">a</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">e</span><span class="p">)</span></pre></div>
</div>
</div>
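A tiny numeric sketch (illustrative values) of how the $-\infty$ mask and the softmax combine to zero out unconnected pairs:

    import torch

    e = torch.tensor([[1.0, 2.0, 3.0]])        # raw scores for one node against three others
    adj = torch.tensor([[True, False, True]])  # pretend there is no edge to the middle node
    e = e.masked_fill(adj == 0, float('-inf'))
    a = torch.softmax(e, dim=1)
    print(a)   # tensor([[0.1192, 0.0000, 0.8808]]): the unconnected pair gets zero attention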
<div class='section' id='section-26'>
@ -462,7 +467,7 @@ makes $\exp(e_{ij}) \sim 0$ for unconnected pairs.</p>
<p>Apply dropout regularization</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">192</span> <span class="n">a</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">a</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">189</span> <span class="n">a</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">a</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
@ -475,7 +480,7 @@ makes $\exp(e_{ij}) \sim 0$ for unconnected pairs.</p>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">196</span> <span class="n">attn_res</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;ijh,jhf-&gt;ihf&#39;</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">g_r</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">193</span> <span class="n">attn_res</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;ijh,jhf-&gt;ihf&#39;</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">g_r</span><span class="p">)</span></pre></div>
</div>
</div>
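The einsum contracts over the source-node dimension $j$; a per-head loop that computes the same result (a sketch for intuition only):

    import torch

    n_nodes, n_heads, n_hidden = 3, 2, 4
    a = torch.rand(n_nodes, n_nodes, n_heads)     # attention weights a[i, j, head]
    g_r = torch.rand(n_nodes, n_heads, n_hidden)  # transformed embeddings g_r[j, head, :]
    per_head = [a[:, :, head] @ g_r[:, head, :] for head in range(n_heads)]  # each [n_nodes, n_hidden]
    attn_res = torch.stack(per_head, dim=1)       # stack heads back: [n_nodes, n_heads, n_hidden]
    assert torch.allclose(attn_res, torch.einsum('ijh,jhf->ihf', a, g_r))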
<div class='section' id='section-28'>
@ -486,7 +491,7 @@ makes $\exp(e_{ij}) \sim 0$ for unconnected pairs.</p>
<p>Concatenate the heads</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">199</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_concat</span><span class="p">:</span></pre></div>
<div class="highlight"><pre><span class="lineno">196</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_concat</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
@ -499,7 +504,7 @@ makes $\exp(e_{ij}) \sim 0$ for unconnected pairs.</p>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">201</span> <span class="k">return</span> <span class="n">attn_res</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">198</span> <span class="k">return</span> <span class="n">attn_res</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
@ -510,7 +515,7 @@ makes $\exp(e_{ij}) \sim 0$ for unconnected pairs.</p>
<p>Take the mean of the heads</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">203</span> <span class="k">else</span><span class="p">:</span></pre></div>
<div class="highlight"><pre><span class="lineno">200</span> <span class="k">else</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
@ -523,7 +528,7 @@ makes $\exp(e_{ij}) \sim 0$ for unconnected pairs.</p>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span> <span class="k">return</span> <span class="n">attn_res</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
<div class="highlight"><pre><span class="lineno">202</span> <span class="k">return</span> <span class="n">attn_res</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
</div>

View File

@ -67,20 +67,20 @@
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1><a href="https://nn.labml.ai/graphs/gatv2/index.html">Graph Attention Networks v2 (GATv2)</a></h1>
<h1><a href="https://nn.labml.ai/graph/gatv2/index.html">Graph Attention Networks v2 (GATv2)</a></h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the GATv2 operator from the paper
<a href="https://arxiv.org/abs/2105.14491">How Attentive are Graph Attention Networks?</a>.</p>
<p>GATv2s work on graph data.
A graph consists of nodes and edges connecting nodes.
For example, in Cora dataset the nodes are research papers and the edges are citations that
connect the papers.</p>
<p>The GATv2 operator which fixes the static attention problem of the standard GAT:
<p>The GATv2 operator fixes the static attention problem of the standard GAT:
since the linear layers in the standard GAT are applied right after each other, the ranking
of attended nodes is unconditioned on the query node.
In contrast, in GATv2, every node can attend to any other node.</p>
<p>Here is <a href="https://nn.labml.ai/graphs/gatv2/experiment.html">the training code</a> for training
a two-layer GAT on Cora dataset.</p>
<p><a href="https://app.labml.ai/run/8e27ad82ed2611ebabb691fb2028a868"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
<p>Here is <a href="https://nn.labml.ai/graph/gatv2/experiment.html">the training code</a> for training
a two-layer GATv2 on Cora dataset.</p>
<p><a href="https://app.labml.ai/run/34b1e2f6ed6f11ebb860997901a2d1e3"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>

View File

@ -115,7 +115,7 @@ implementations.</p>
<h4>✨ Graph Neural Networks</h4>
<ul>
<li><a href="graphs/gat/index.html">Graph Attention Networks (GAT)</a></li>
<li><a href="gatv2/index.html">Graph Attention Networks v2 (GATv2)</a></li>
<li><a href="graphs/gatv2/index.html">Graph Attention Networks v2 (GATv2)</a></li>
</ul>
<h4><a href="cfr/index.html">Counterfactual Regret Minimization (CFR)</a></h4>
<p>Solving games with incomplete information such as poker with CFR.</p>

View File

@ -281,7 +281,7 @@
<url>
<loc>https://nn.labml.ai/index.html</loc>
<lastmod>2021-07-25T16:30:00+00:00</lastmod>
<lastmod>2021-07-26T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
@ -743,14 +743,14 @@
<url>
<loc>https://nn.labml.ai/graphs/gatv2/index.html</loc>
<lastmod>2021-07-25T16:30:00+00:00</lastmod>
<lastmod>2021-07-26T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/graphs/gatv2/experiment.html</loc>
<lastmod>2021-07-25T16:30:00+00:00</lastmod>
<lastmod>2021-07-26T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>

View File

@ -1,4 +1,4 @@
# [Distilling the Knowledge in a Neural Network]((https://nn.labml.ai/distillation/index.html))
# [Distilling the Knowledge in a Neural Network](https://nn.labml.ai/distillation/index.html)
This is a [PyTorch](https://pytorch.org) implementation/tutorial of the paper
[Distilling the Knowledge in a Neural Network](https://papers.labml.ai/paper/1503.02531).