<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="This trains a Graph Attention Network v2 (GATv2) on the Cora dataset"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Train a Graph Attention Network v2 (GATv2) on the Cora dataset"/>
<meta name="twitter:description" content="This trains a Graph Attention Network v2 (GATv2) on the Cora dataset"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/graphs/gatv2/experiment.html"/>
<meta property="og:title" content="Train a Graph Attention Network v2 (GATv2) on the Cora dataset"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Train a Graph Attention Network v2 (GATv2) on the Cora dataset"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="This trains a Graph Attention Network v2 (GATv2) on the Cora dataset"/>
<title>Train a Graph Attention Network v2 (GATv2) on the Cora dataset</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1">
<link rel="canonical" href="https://nn.labml.ai/graphs/gatv2/experiment.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">graphs</a>
<a class="parent" href="index.html">gatv2</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/graphs/gatv2/experiment.py" target="_blank">
View code on Github</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Train a Graph Attention Network v2 (GATv2) on the Cora dataset</h1>
</div>
<div class='code'>
<div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>

<span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
<span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">option</span>
<span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
<span class="kn">from</span> <span class="nn">labml_nn.graphs.gat.experiment</span> <span class="kn">import</span> <span class="n">Configs</span> <span class="k">as</span> <span class="n">GATConfigs</span>
<span class="kn">from</span> <span class="nn">labml_nn.graphs.gatv2</span> <span class="kn">import</span> <span class="n">GraphAttentionV2Layer</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Graph Attention Network v2 (GATv2)</h2>
<p>This graph attention network has two <a href="index.html">graph attention layers</a>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="k">class</span> <span class="nc">GATv2</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">in_features</span></code>
is the number of features per node</li>
<li><code class="highlight"><span></span><span class="n">n_hidden</span></code>
is the number of features in the first graph attention layer</li>
<li><code class="highlight"><span></span><span class="n">n_classes</span></code>
is the number of classes</li>
<li><code class="highlight"><span></span><span class="n">n_heads</span></code>
is the number of heads in the graph attention layers</li>
<li><code class="highlight"><span></span><span class="n">dropout</span></code>
is the dropout probability</li>
<li><code class="highlight"><span></span><span class="n">share_weights</span></code>
if set to True, the same matrix is applied to the source and the target node of every edge</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_features</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
                 <span class="n">share_weights</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>First graph attention layer, where we concatenate the heads</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="bp">self</span><span class="o">.</span><span class="n">layer1</span> <span class="o">=</span> <span class="n">GraphAttentionV2Layer</span><span class="p">(</span><span class="n">in_features</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span>
                                            <span class="n">is_concat</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="n">dropout</span><span class="p">,</span> <span class="n">share_weights</span><span class="o">=</span><span class="n">share_weights</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Activation function after the first graph attention layer</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ELU</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Final graph attention layer, where we average the heads</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="bp">self</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">GraphAttentionV2Layer</span><span class="p">(</span><span class="n">n_hidden</span><span class="p">,</span> <span class="n">n_classes</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span>
                                            <span class="n">is_concat</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="n">dropout</span><span class="p">,</span> <span class="n">share_weights</span><span class="o">=</span><span class="n">share_weights</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Dropout</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span></pre></div>
</div>
</div>
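<div class='section'>
<div class='docs'>
<p>The two layers above combine their heads differently: the first concatenates them and the last averages them. The sketch below only tracks feature sizes. It assumes that a concatenating layer splits its output features evenly across heads, while an averaging layer gives every head the full output size. <code>per_head_size</code> and <code>layer_output_size</code> are hypothetical helpers for illustration and are not part of <code>labml_nn</code>; the sizes used are illustrative values, not read from the experiment configuration.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def per_head_size(out_features: int, n_heads: int, is_concat: bool) -> int:
    """Features produced by each head (assumed convention, see the text above)."""
    if is_concat:
        # Concatenated heads must add up to exactly out_features
        if out_features % n_heads != 0:
            raise ValueError("out_features must be divisible by n_heads")
        return out_features // n_heads
    # Averaged heads each produce the full output size
    return out_features


def layer_output_size(out_features: int, n_heads: int, is_concat: bool) -> int:
    """Width of the layer output after the heads are combined."""
    size = per_head_size(out_features, n_heads, is_concat)
    return size * n_heads if is_concat else size


# Illustrative sizes only (hypothetical values)
n_hidden, n_classes, n_heads = 64, 7, 8
print(per_head_size(n_hidden, n_heads, is_concat=True))       # 8
print(layer_output_size(n_hidden, n_heads, is_concat=True))   # 64
print(layer_output_size(n_classes, 1, is_concat=False))       # 7
```
</pre></div>
</div>
</div>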
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<ul><li><code class="highlight"><span></span><span class="n">x</span></code>
is the feature vectors of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">n_nodes</span><span class="p">,</span> <span class="n">in_features</span><span class="p">]</span></code>
</li>
<li><code class="highlight"><span></span><span class="n">adj_mat</span></code>
is the adjacency matrix of shape <code class="highlight"><span></span><span class="p">[</span><span class="n">n_nodes</span><span class="p">,</span> <span class="n">n_nodes</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">]</span></code>
or <code class="highlight"><span></span><span class="p">[</span><span class="n">n_nodes</span><span class="p">,</span> <span class="n">n_nodes</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span></code>
</li></ul>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">adj_mat</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
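<div class='section'>
<div class='docs'>
<p>To make the <code>[n_nodes, n_nodes, 1]</code> adjacency shape described above concrete, here is a minimal sketch that builds one from an edge list using plain nested lists instead of tensors. A trailing dimension of size 1 means all heads share the same connectivity. <code>build_adjacency</code> is a hypothetical helper, and adding self-loops and treating edges as undirected are assumptions of this sketch.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
def build_adjacency(n_nodes, edges):
    """Return nested lists of shape [n_nodes, n_nodes, 1]; True marks an edge."""
    adj = [[[False] for _ in range(n_nodes)] for _ in range(n_nodes)]
    for i in range(n_nodes):
        # Self-loop, so every node can attend to itself (assumption)
        adj[i][i][0] = True
    for src, dst in edges:
        adj[src][dst][0] = True
        # Mirror the edge to treat the graph as undirected (assumption)
        adj[dst][src][0] = True
    return adj


adj = build_adjacency(3, [(0, 1)])
shape = (len(adj), len(adj[0]), len(adj[0][0]))
print(shape)  # (3, 3, 1)
```
</pre></div>
</div>
</div>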
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>Apply dropout to the input</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
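<div class='section'>
<div class='docs'>
<p>As a reminder of what the dropout step above does during training, here is a minimal pure-Python sketch of inverted dropout on a flat list: each value is zeroed with probability <code>p</code> and the survivors are scaled by <code>1 / (1 - p)</code>, so the expected value is unchanged. The function name and the list-based interface are assumptions for illustration; the model itself uses <code>nn.Dropout</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import random


def dropout(values, p, training=True, rng=random):
    """Inverted dropout on a flat list of floats."""
    if not training or p == 0.0:
        # Evaluation mode: values pass through unchanged
        return list(values)
    keep = 1.0 - p
    # Zero each value with probability p; scale survivors by 1 / (1 - p)
    return [v / keep if rng.random() >= p else 0.0 for v in values]


random.seed(0)
out = dropout([1.0] * 10, p=0.7)
print(out)
```
</pre></div>
</div>
</div>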
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>First graph attention layer</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer1</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">adj_mat</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Activation function</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
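<div class='section'>
<div class='docs'>
<p>The activation used above is ELU. As a quick scalar illustration, ELU is the identity for positive inputs and <code>alpha * (exp(x) - 1)</code> otherwise; the sketch assumes <code>alpha = 1.0</code>, which is the default of <code>nn.ELU</code>. The function name <code>elu</code> is chosen here for illustration.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import math


def elu(x: float, alpha: float = 1.0) -> float:
    """ELU: x for positive inputs, alpha * (exp(x) - 1) otherwise."""
    return x if x > 0 else alpha * math.expm1(x)


print([round(elu(v), 4) for v in (-2.0, 0.0, 1.5)])  # [-0.8647, 0.0, 1.5]
```
</pre></div>
</div>
</div>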
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Dropout</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Output layer (without activation) for logits</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">adj_mat</span><span class="p">)</span></pre></div>
</div>
</div>
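<div class='section'>
<div class='docs'>
<p>The forward pass above delegates the attention computation to the two graph attention layers. For intuition, here is a rough single-head sketch of GATv2-style attention in pure Python. It assumes the scoring form <code>e_ij = a . LeakyReLU(W_l h_i + W_r h_j)</code> with a negative slope of 0.2, a softmax over each node's neighbours, and an output that is the attention-weighted sum of <code>W_r h_j</code>. The function names, the 2D boolean adjacency and the toy values are assumptions for illustration; this is not the <code>GraphAttentionV2Layer</code> implementation. Passing the same matrix as <code>w_l</code> and <code>w_r</code> mimics <code>share_weights=True</code>.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import math


def matvec(w, v):
    """Multiply matrix w (list of rows) by vector v."""
    return [sum(wk * vk for wk, vk in zip(row, v)) for row in w]


def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x


def gatv2_head(h, adj, w_l, w_r, a):
    """One attention head. h: node feature lists; adj[i][j] is True for an edge.

    Every node needs at least one neighbour (for example a self-loop).
    """
    g_l = [matvec(w_l, x) for x in h]
    g_r = [matvec(w_r, x) for x in h]
    out = []
    for i in range(len(h)):
        nbrs = [j for j in range(len(h)) if adj[i][j]]
        # Score: the nonlinearity is applied before the dot product with a
        e = [sum(ak * leaky_relu(p + q) for ak, p, q in zip(a, g_l[i], g_r[j]))
             for j in nbrs]
        # Numerically stable softmax over the neighbours of node i
        m = max(e)
        exp_e = [math.exp(v - m) for v in e]
        total = sum(exp_e)
        alpha = [v / total for v in exp_e]
        # Attention-weighted sum of transformed neighbour features
        out.append([sum(al * g_r[j][k] for al, j in zip(alpha, nbrs))
                    for k in range(len(a))])
    return out


# Toy graph: three nodes, self-loops, and one edge between nodes 0 and 1
h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
adj = [[True, True, False], [True, True, False], [False, False, True]]
w = [[1.0, 0.0], [0.0, 1.0]]  # shared weights: w_l and w_r are the same matrix
result = gatv2_head(h, adj, w_l=w, w_r=w, a=[1.0, -1.0])
print(result)
```
</pre></div>
</div>
</div>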
<div class='section' id='section-14'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<h2>Configurations</h2>
<p>Since the experiment is the same as the <a href="../gat/experiment.html">GAT experiment</a> but uses the <a href="index.html">GATv2 model</a>, we extend the same configs and change the model.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">GATConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Whether to share weights between the source and target nodes of edges</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="n">share_weights</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Set the model</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="n">model</span><span class="p">:</span> <span class="n">GATv2</span> <span class="o">=</span> <span class="s1">&#39;gat_v2_model&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Create the GATv2 model</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">gat_v2_model</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="k">return</span> <span class="n">GATv2</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">in_features</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_hidden</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_classes</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">dropout</span><span class="p">,</span> <span class="n">c</span><span class="o">.</span><span class="n">share_weights</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Create configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Create an experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;gatv2&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Calculate configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span> <span class="p">{</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Adam optimizer</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="s1">&#39;optimizer.optimizer&#39;</span><span class="p">:</span> <span class="s1">&#39;Adam&#39;</span><span class="p">,</span>
        <span class="s1">&#39;optimizer.learning_rate&#39;</span><span class="p">:</span> <span class="mf">5e-3</span><span class="p">,</span>
        <span class="s1">&#39;optimizer.weight_decay&#39;</span><span class="p">:</span> <span class="mf">5e-4</span><span class="p">,</span>

        <span class="s1">&#39;dropout&#39;</span><span class="p">:</span> <span class="mf">0.7</span><span class="p">,</span>
    <span class="p">})</span></pre></div>
</div>
</div>
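<div class='section'>
<div class='docs'>
<p>To show how the learning rate and weight decay set above enter an update, here is a scalar sketch of one Adam step. It assumes the textbook Adam update with weight decay folded into the gradient as an L2 penalty, and the common defaults <code>beta1 = 0.9</code>, <code>beta2 = 0.999</code>, <code>eps = 1e-8</code>; the optimizer the experiment actually builds may differ in details. <code>adam_step</code> is a hypothetical helper, and the parameter and gradient values are made up.</p>
</div>
<div class='code'>
<div class="highlight"><pre>
```python
import math


def adam_step(param, grad, m, v, t, lr=5e-3, weight_decay=5e-4,
              beta1=0.9, beta2=0.999, eps=1e-8):
    """One scalar Adam update at step t (starting from 1); returns (param, m, v)."""
    # L2 penalty folded into the gradient (assumed form of weight decay)
    grad = grad + weight_decay * param
    # Exponential moving averages of the gradient and its square
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    # Bias correction for the zero-initialised averages
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# Made-up starting values; the first step moves the parameter by about lr
p, m, v = adam_step(param=1.0, grad=0.5, m=0.0, v=0.0, t=1)
print(p)
```
</pre></div>
</div>
</div>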
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>Start and watch the experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre>    <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Run the training</p>
</div>
<div class='code'>
<div class="highlight"><pre>        <span class="n">conf</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
    <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>