import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml_helpers.module import Module
from labml_nn.graphs.gat.experiment import Configs as GATConfigs
from labml_nn.graphs.gatv2 import GraphAttentionV2Layer
Graph Attention Network v2 (GATv2) model: this network has two graph attention layers.
class GATv2(Module):
* `in_features` is the number of features per node
* `n_hidden` is the number of features in the first graph attention layer
* `n_classes` is the number of classes
* `n_heads` is the number of heads in the graph attention layers
* `dropout` is the dropout probability
* `share_weights`: if set to `True`, the same matrix will be applied to the source and the target node of every edge

    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float,
                 share_weights: bool = True):
        super().__init__()
First graph attention layer where we concatenate the heads
        self.layer1 = GraphAttentionV2Layer(in_features, n_hidden, n_heads,
                                            is_concat=True, dropout=dropout, share_weights=share_weights)
Activation function after first graph attention layer
        self.activation = nn.ELU()
Final graph attention layer where we average the heads
        self.output = GraphAttentionV2Layer(n_hidden, n_classes, 1,
                                            is_concat=False, dropout=dropout, share_weights=share_weights)
Dropout layer
        self.dropout = nn.Dropout(dropout)
* `x` is the feature vectors of shape `[n_nodes, in_features]`
* `adj_mat` is the adjacency matrix of the form `[n_nodes, n_nodes, n_heads]` or `[n_nodes, n_nodes, 1]`

    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):
Apply dropout to the input
        x = self.dropout(x)
First graph attention layer
        x = self.layer1(x, adj_mat)
Activation function
        x = self.activation(x)
Dropout
        x = self.dropout(x)
Output layer (without activation) for logits
        return self.output(x, adj_mat)
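To make the expected shapes concrete, here is a minimal usage sketch. It is illustrative only and not part of the experiment code; the node count, feature sizes, and the fully-connected boolean adjacency matrix are all made-up values.

# Illustrative sketch: a tiny fully-connected graph of 4 nodes.
# All sizes here are arbitrary; note that `n_hidden` must be divisible
# by `n_heads` because the first layer concatenates the heads.
model = GATv2(in_features=8, n_hidden=16, n_classes=3, n_heads=4, dropout=0.5)
x = torch.rand(4, 8)                             # [n_nodes, in_features]
adj_mat = torch.ones(4, 4, 1, dtype=torch.bool)  # [n_nodes, n_nodes, 1], broadcast over heads
logits = model(x, adj_mat)                       # [n_nodes, n_classes]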
Since the experiment is the same as the GAT experiment, except with the GATv2 model, we extend the same configs and change only the model.
class Configs(GATConfigs):
Whether to share weights for source and target nodes of edges
    share_weights: bool = False
Set the model
    model: GATv2 = 'gat_v2_model'
Create the GATv2 model. The `@option` decorator registers this function as the calculator for `Configs.model`, so it runs lazily when the model config is first accessed, and `model` defaults to this option's name, `'gat_v2_model'`.
@option(Configs.model)
def gat_v2_model(c: Configs):
    return GATv2(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout, c.share_weights).to(c.device)
def main():
Create configurations
    conf = Configs()
Create an experiment
    experiment.create(name='gatv2')
Calculate configurations
    experiment.configs(conf, {
Adam optimizer
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 5e-3,
        'optimizer.weight_decay': 5e-4,

        'dropout': 0.7,
    })
Start and watch the experiment
    with experiment.start():
Run the training
        conf.run()


if __name__ == '__main__':
    main()
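To train with shared weights for the source and target nodes instead, one option is to override `share_weights` in the dictionary passed to `experiment.configs`. This is a sketch, assuming the usual labml mechanism where any `Configs` attribute name is a valid override key.

# Sketch: the same overrides as in main(), with `share_weights` switched on.
# `share_weights` is an attribute of Configs, so it can be overridden here.
experiment.configs(conf, {
    'optimizer.optimizer': 'Adam',
    'optimizer.learning_rate': 5e-3,
    'optimizer.weight_decay': 5e-4,
    'dropout': 0.7,
    'share_weights': True,
})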