这是论文 Gra ph 注意力网络的 Py Torch 实现。
GAT 处理图形数据。图由连接节点的节点和边组成。例如,在 Cora 数据集中,节点是研究论文,边缘是连接论文的引文。
GAT 使用蒙面自我注意力,有点类似于变形金刚。GAT 由堆叠在一起的图形注意层组成。每个图形注意层都将节点嵌入作为输入和输出转换的嵌入。节点嵌入会注意它所连接的其他节点的嵌入。图形关注层的详细信息与实现一起包括在内。
30import torch
31from torch import nn
32
33from labml_helpers.module import Module36class GraphAttentionLayer(Module):in_features
,是每个节点的输入要素数out_features
,是每个节点的输出要素数n_heads
,是注意头的数量is_concat
多头结果应该是串联还是求平均值dropout
是辍学概率leaky_relu_negative_slope
是泄漏的 relu 激活的负斜率50    def __init__(self, in_features: int, out_features: int, n_heads: int,
51                 is_concat: bool = True,
52                 dropout: float = 0.6,
53                 leaky_relu_negative_slope: float = 0.2):62        super().__init__()
63
64        self.is_concat = is_concat
65        self.n_heads = n_heads计算每头的尺寸数
68        if is_concat:
69            assert out_features % n_heads == 0如果我们要连接多个头
71            self.n_hidden = out_features // n_heads
72        else:如果我们平均多头
74            self.n_hidden = out_features用于初始变换的线性层;即在自我关注之前转换节点嵌入
78        self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)用于计算注意力分数的线性图层
80        self.attn = nn.Linear(self.n_hidden * 2, 1, bias=False)激活注意力分数
82        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)Softmax 需要计算注意力
84        self.softmax = nn.Softmax(dim=1)要应用的掉落层以引起注意
86        self.dropout = nn.Dropout(dropout)h
,是 shape 的输入节点嵌入[n_nodes, in_features]
。adj_mat
是形状的邻接矩阵[n_nodes, n_nodes, n_heads]
。我们使用形状,[n_nodes, n_nodes, 1]
因为每个头部的邻接是相同的。邻接矩阵表示节点之间的边(或连接)。adj_mat[i][j]
True
如果节点与节i
点之间存在边缘j
。
88    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):节点数量
99        n_nodes = h.shape[0]每个头部的初始变换。我们做单个线性变换,然后将其拆分为每个头部。
104        g = self.linear(h).view(n_nodes, self.n_heads, self.n_hidden)我们为每个头部计算这些。为简单起见,我们省略了。
是从一个节点到另一个节点的注意力分数(重要性)。我们为每个头部计算这个值。
是计算注意力分数的注意力机制。本文连接起来,然后使用权重向量后跟 a 进行线性变换。
135        g_repeat = g.repeat(n_nodes, 1, 1)g_repeat_interleave
获取每个节点嵌入重复n_nodes
次数的位置。
140        g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)现在我们连接来获得
148        g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1)重塑g_concat[i, j]
就是这样
150        g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)计算e
是形状的[n_nodes, n_nodes, n_heads, 1]
158        e = self.activation(self.attn(g_concat))移除大小的最后一个维度1
160        e = e.squeeze(-1)邻接矩阵的形状应[n_nodes, n_nodes, n_heads]
为[n_nodes, n_nodes, 1]
164        assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
165        assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
166        assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads基于邻接矩阵的掩码。如果没有从到的边缘,则设置为。
169        e = e.masked_fill(adj_mat == 0, float('-inf'))179        a = self.softmax(e)应用辍学正则化
182        a = self.dropout(a)计算每个头的最终输出
注意:本文包含了最后的激活。我们在Graph Attention Layer实现中省略了这一点,并将其用于GAT模型以匹配其他 PyTorch 模块的定义方式——激活作为单独的图层。
191        attn_res = torch.einsum('ijh,jhf->ihf', a, g)连接头部
194        if self.is_concat:196            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)以头脑的意思为例
198        else:200            return attn_res.mean(dim=1)