这是 Gatv2 运营商的 PyTorch 实现,来自论文 G raph 注意力网络有多专心?。
GATV2 处理的图形数据与 GAT 类似。图由连接节点的节点和边组成。例如,在 Cora 数据集中,节点是研究论文,边缘是连接论文的引文。
GATv2 操作员修复了标准 G AT 的静态注意力问题。静态关注是指对关键节点的关注对于任何查询节点具有相同的排名(顺序)。GAT 计算从查询节点到关键节点的注意力,
请注意,对于任何查询节点,键的关注等级 () 仅取决于。因此,对于所有查询,键的关注等级保持不变(静态)。
GATv2 通过改变注意力机制来实现动态关注,
本文表明,GATS静态注意机制在合成字典查找数据集的某些图形问题上失败了。这是一个完全连接的二部图,其中一组节点(查询节点)有一个与之关联的键,而另一组节点同时具有与之关联的键和值。目标是预测查询节点的值。由于静态注意力有限,GAT 无法完成此任务。
59import torch
60from torch import nn
61
62from labml_helpers.module import Module65class GraphAttentionV2Layer(Module):in_features
,是每个节点的输入要素数out_features
,是每个节点的输出要素数n_heads
,是注意头的数量is_concat
多头结果应该是串联还是求平均值dropout
是辍学概率leaky_relu_negative_slope
是泄漏的 relu 激活的负斜率share_weights
如果设置为True
,则同一矩阵将应用于每条边的源节点和目标节点78 def __init__(self, in_features: int, out_features: int, n_heads: int,
79 is_concat: bool = True,
80 dropout: float = 0.6,
81 leaky_relu_negative_slope: float = 0.2,
82 share_weights: bool = False):92 super().__init__()
93
94 self.is_concat = is_concat
95 self.n_heads = n_heads
96 self.share_weights = share_weights计算每头的尺寸数
99 if is_concat:
100 assert out_features % n_heads == 0如果我们要连接多个头
102 self.n_hidden = out_features // n_heads
103 else:如果我们平均多头
105 self.n_hidden = out_features用于初始源变换的线性层;即在自我关注之前转换源节点嵌入
109 self.linear_l = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)如果share_weights
是True
,则为目标节点使用相同的线性层
111 if share_weights:
112 self.linear_r = self.linear_l
113 else:
114 self.linear_r = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)用于计算注意力分数的线性图层
116 self.attn = nn.Linear(self.n_hidden, 1, bias=False)激活注意力分数
118 self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)Softmax 需要计算注意力
120 self.softmax = nn.Softmax(dim=1)要应用的掉落层以引起注意
122 self.dropout = nn.Dropout(dropout)h
,是 shape 的输入节点嵌入[n_nodes, in_features]
。adj_mat
是形状的邻接矩阵[n_nodes, n_nodes, n_heads]
。我们使用形状,[n_nodes, n_nodes, 1]
因为每个头部的邻接是相同的。邻接矩阵表示节点之间的边(或连接)。adj_mat[i][j]
True
如果节点与节i
点之间存在边缘j
。124 def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):节点数量
134 n_nodes = h.shape[0]每个头部的初始变换。我们做了两个线性变换,然后将其拆分为每个头部。
140 g_l = self.linear_l(h).view(n_nodes, self.n_heads, self.n_hidden)
141 g_r = self.linear_r(h).view(n_nodes, self.n_heads, self.n_hidden)我们为每个头部计算这些。为简单起见,我们省略了。
是从一个节点到另一个节点的注意力分数(重要性)。我们为每个头部计算这个值。
是计算注意力分数的注意力机制。本文求和,然后是 a,然后使用权重向量进行线性变换
注意:本文描述的内容等同于我们在此处使用的定义。
179 g_l_repeat = g_l.repeat(n_nodes, 1, 1)g_r_repeat_interleave
获取每个节点嵌入重复n_nodes
次数的位置。
184 g_r_repeat_interleave = g_r.repeat_interleave(n_nodes, dim=0)现在我们添加两个张量来获得
192 g_sum = g_l_repeat + g_r_repeat_interleave重塑g_sum[i, j]
就是这样
194 g_sum = g_sum.view(n_nodes, n_nodes, self.n_heads, self.n_hidden)计算e
是形状的[n_nodes, n_nodes, n_heads, 1]
202 e = self.attn(self.activation(g_sum))移除大小的最后一个维度1
204 e = e.squeeze(-1)邻接矩阵的形状应[n_nodes, n_nodes, n_heads]
为[n_nodes, n_nodes, 1]
208 assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
209 assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
210 assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads基于邻接矩阵的掩码。如果没有从到的边缘,则设置为。
213 e = e.masked_fill(adj_mat == 0, float('-inf'))223 a = self.softmax(e)应用辍学正则化
226 a = self.dropout(a)计算每个头的最终输出
230 attn_res = torch.einsum('ijh,jhf->ihf', a, g_r)连接头部
233 if self.is_concat:235 return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)以头脑的意思为例
237 else:239 return attn_res.mean(dim=1)