在 Cora 数据集上训练图形注意力网络 v2 (GATv2)

View Run

13import torch
14from torch import nn
15
16from labml import experiment
17from labml.configs import option
18from labml_helpers.module import Module
19from labml_nn.graphs.gat.experiment import Configs as GATConfigs
20from labml_nn.graphs.gatv2 import GraphAttentionV2Layer

Graph 注意力网络 v2 (GATv2)

这个图形关注网络有两个图形关注层

23class GATv2(Module):
  • in_features 是每个节点的要素数
  • n_hidden 是第一个图形关注层中的要素数
  • n_classes 是类的数量
  • n_heads 是图表关注层中的头部数量
  • dropout 是辍学概率
  • share_weights 如果设置为 True,则同一矩阵将应用于每条边的源节点和目标节点
30    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float,
31                 share_weights: bool = True):
40        super().__init__()

我们连接头部的第一个图形注意层

43        self.layer1 = GraphAttentionV2Layer(in_features, n_hidden, n_heads,
44                                            is_concat=True, dropout=dropout, share_weights=share_weights)

第一个图形关注层之后的激活功能

46        self.activation = nn.ELU()

最后一张图关注层,我们平均头部

48        self.output = GraphAttentionV2Layer(n_hidden, n_classes, 1,
49                                            is_concat=False, dropout=dropout, share_weights=share_weights)

辍学

51        self.dropout = nn.Dropout(dropout)
  • x 是形状的特征向量[n_nodes, in_features]
  • adj_mat 是形式的邻接矩阵[n_nodes, n_nodes, n_heads][n_nodes, n_nodes, 1]
53    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):

将丢失应用于输入

60        x = self.dropout(x)

第一个图形关注层

62        x = self.layer1(x, adj_mat)

激活功能

64        x = self.activation(x)

辍学

66        x = self.dropout(x)

logits 的输出层(未激活)

68        return self.output(x, adj_mat)

配置

由于实验与 GAT 实验相同,但使用 G ATv2 模型,我们扩展了相同的配置并更改了模型。

71class Configs(GATConfigs):

是否共享边的源节点和目标节点的权重

80    share_weights: bool = False

设置模型

82    model: GATv2 = 'gat_v2_model'

创建 GATv2 模型

85@option(Configs.model)
86def gat_v2_model(c: Configs):
90    return GATv2(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout, c.share_weights).to(c.device)
93def main():

创建配置

95    conf = Configs()

创建实验

97    experiment.create(name='gatv2')

计算配置。

99    experiment.configs(conf, {

Adam 优化器

101        'optimizer.optimizer': 'Adam',
102        'optimizer.learning_rate': 5e-3,
103        'optimizer.weight_decay': 5e-4,
104
105        'dropout': 0.7,
106    })

开始观看实验

109    with experiment.start():

运行训练

111        conf.run()

115if __name__ == '__main__':
116    main()