これは、「グラフアテンションネットワークはどの程度注意深いのか?」という論文のGATv2演算子をPyTorchで実装したものです。
。GATv2は、GATと同様にグラフデータを処理します。グラフは、ノードとノードを接続するエッジで構成されます。たとえば、Coraデータセットでは、ノードは研究論文で、端は論文をつなぐ引用です
。GATv2 オペレータは、標準 GAT のスタティックアテンションの問題を解決します。スタティックアテンションとは、どのクエリノードでもキーノードへのアテンションのランク(順序)が同じであることです。GAT は、クエリノードからキーノードへのアテンションを次のように計算します
。どのクエリノードでも、キーのアテンションランク () は以下にのみ依存することに注意してください。したがって、キーのアテンションランクはすべてのクエリで同じ(静的)ままです。
GATv2はアテンションメカニズムを変更することで動的なアテンションを可能にします。
この論文は、GATの静的注意メカニズムが、合成辞書検索データセットのグラフ問題の一部で失敗することを示しています。これは完全に接続された二部グラフで、一方のノード(クエリノード)にはキーが関連付けられ、もう一方のノードセットにはキーと値の両方が関連付けられています。目標は、クエリノードの値を予測することです。GAT は静的処理が制限されているため、このタスクは失敗します。
57import torch
58from torch import nn
59
60from labml_helpers.module import Moduleこれはシングルグラフアテンションv2レイヤーです。GATv2は、このような複数のレイヤーで構成されています。入力として、where を、出力として、where を取ります。
63class GraphAttentionV2Layer(Module):in_features
、はノードあたりの入力フィーチャの数ですout_features
、はノードごとの出力フィーチャの数ですn_heads
、はアテンション・ヘッドの数is_concat
マルチヘッドの結果を連結すべきか平均化すべきかdropout
は脱落確率ですleaky_relu_negative_slope
リークのあるリレーアクティベーションの負の傾きですshare_weights
に設定するとTrue
、すべてのエッジのソースノードとターゲットノードに同じマトリックスが適用されます76    def __init__(self, in_features: int, out_features: int, n_heads: int,
77                 is_concat: bool = True,
78                 dropout: float = 0.6,
79                 leaky_relu_negative_slope: float = 0.2,
80                 share_weights: bool = False):90        super().__init__()
91
92        self.is_concat = is_concat
93        self.n_heads = n_heads
94        self.share_weights = share_weights頭あたりの寸法数の計算
97        if is_concat:
98            assert out_features % n_heads == 0複数のヘッドを連結する場合
100            self.n_hidden = out_features // n_heads
101        else:複数のヘッドを平均化する場合
103            self.n_hidden = out_features初期ソース変換用の線形レイヤー。つまり、自己処理の前にソースノードの埋め込みを変換する
107        self.linear_l = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)share_weights
True
ターゲットノードに同じリニアレイヤーが使用されている場合
109        if share_weights:
110            self.linear_r = self.linear_l
111        else:
112            self.linear_r = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)アテンションスコアを計算する線形レイヤー
114        self.attn = nn.Linear(self.n_hidden, 1, bias=False)アテンションスコアのアクティベーション
116        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)注意力を計算するソフトマックス
118        self.softmax = nn.Softmax(dim=1)注目すべきドロップアウト層
120        self.dropout = nn.Dropout(dropout)h
、はシェイプの入力ノード埋め込みです。[n_nodes, in_features]
adj_mat
[n_nodes, n_nodes, n_heads]
は形状の隣接行列です。[n_nodes, n_nodes, 1]
各ヘッドの隣接関係が同じなので、形状を使用します。隣接マトリックスは、ノード間のエッジ (または接続) を表します。adj_mat[i][j]
True
i
ノード間でエッジがある場合ですj
。122    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):ノード数
132        n_nodes = h.shape[0]各ヘッドの初期変換。線形変換を 2 回行い、それを各ヘッドに分割します
。138        g_l = self.linear_l(h).view(n_nodes, self.n_heads, self.n_hidden)
139        g_r = self.linear_r(h).view(n_nodes, self.n_heads, self.n_hidden)これらは頭ごとに計算します。わかりやすくするために省略しました。
ノードごとのアテンションスコア(重要度)です。これを頭ごとに計算します。
アテンションスコアを計算するアテンションメカニズムです。紙は合計し、その後にAとが続き、重みベクトルを使用して線形変換を行います
注:この論文では、どちらがここで使用している定義と同等であるかが説明されています。
177        g_l_repeat = g_l.repeat(n_nodes, 1, 1)g_r_repeat_interleave
n_nodes
各ノードの埋め込みが何度も繰り返される場所を取得します。
182        g_r_repeat_interleave = g_r.repeat_interleave(n_nodes, dim=0)次に、2 つのテンソルを追加して
190        g_sum = g_l_repeat + g_r_repeat_interleaveg_sum[i, j]
そのように形を変えてください 
192        g_sum = g_sum.view(n_nodes, n_nodes, self.n_heads, self.n_hidden)e
形状の計算 [n_nodes, n_nodes, n_heads, 1]
200        e = self.attn(self.activation(g_sum))サイズの最後のディメンションを削除 1
202        e = e.squeeze(-1)隣接マトリックスは、[n_nodes, n_nodes, n_heads]
またはの形状でなければなりません [n_nodes, n_nodes, 1]
206        assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
207        assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
208        assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads隣接マトリックスに基づくマスク。からまでのエッジがない場合は、に設定されます。
211        e = e.masked_fill(adj_mat == 0, float('-inf'))221        a = self.softmax(e)ドロップアウト正則化を適用
224        a = self.dropout(a)各ヘッドの最終出力を計算
228        attn_res = torch.einsum('ijh,jhf->ihf', a, g_r)ヘッドを連結してください
231        if self.is_concat:233            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)頭の中を平均して
235        else:237            return attn_res.mean(dim=1)