This is a PyTorch implementation of the GATv2 operator from the paper "How Attentive are Graph Attention Networks?".
Like GAT, GATv2 works on graph data. A graph consists of nodes and edges connecting nodes. For example, in the Cora dataset the nodes are research papers and the edges are citations that connect the papers.
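As a small illustration of this representation, the sketch below builds a boolean adjacency tensor from a list of directed edges. The toy edge list, the helper name `edges_to_adj_mat`, and the choice to add self-loops are assumptions made for the example; the `[n_nodes, n_nodes, 1]` shape is the one the layer's `forward` method further down accepts.

```python
import torch


def edges_to_adj_mat(edges, n_nodes, add_self_loops=True):
    """Build a [n_nodes, n_nodes, 1] boolean adjacency tensor (hypothetical helper)."""
    adj_mat = torch.zeros(n_nodes, n_nodes, 1, dtype=torch.bool)
    for src, dst in edges:
        # adj_mat[i, j] is True if there is an edge from node i to node j
        adj_mat[src, dst, 0] = True
    if add_self_loops:
        # Assumption: let every node attend to itself as well
        idx = torch.arange(n_nodes)
        adj_mat[idx, idx, 0] = True
    return adj_mat


# Toy citation graph: paper 0 cites papers 1 and 2, paper 2 cites paper 1
edges = [(0, 1), (0, 2), (2, 1)]
adj_mat = edges_to_adj_mat(edges, n_nodes=3)
print(adj_mat.squeeze(-1).int())
```

The trailing dimension of size 1 lets the same adjacency broadcast over all attention heads.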
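The GATv2 operator fixes the static attention problem of the standard GAT. Static attention is when the attention to the key nodes has the same rank (order) for any query node. GAT computes attention from query node $i$ to key node $j$ as,
$$e_{ij} = \text{LeakyReLU} \Big( \mathbf{a}^\top \Big[ \mathbf{W} \overrightarrow{h_i} \Vert \mathbf{W} \overrightarrow{h_j} \Big] \Big) = \text{LeakyReLU} \Big( \mathbf{a}_1^\top \mathbf{W} \overrightarrow{h_i} + \mathbf{a}_2^\top \mathbf{W} \overrightarrow{h_j} \Big)$$
Note that for any query node $i$, the attention rank ($\text{argsort}$) of keys depends only on $\mathbf{a}_2^\top \mathbf{W} \overrightarrow{h_j}$, because $\text{LeakyReLU}$ is monotonic and the query term is the same for every key. Therefore the attention rank of keys remains the same (static) for all queries.
GATv2 allows dynamic attention by changing the attention mechanism,
$$e_{ij} = \mathbf{a}^\top \text{LeakyReLU} \Big( \mathbf{W} \Big[ \overrightarrow{h_i} \Vert \overrightarrow{h_j} \Big] \Big) = \mathbf{a}^\top \text{LeakyReLU} \Big( \mathbf{W}_l \overrightarrow{h_i} + \mathbf{W}_r \overrightarrow{h_j} \Big)$$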
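To make the difference concrete, here is a minimal sketch in plain Python. All numbers are hand-picked toy values (assumptions for illustration, not taken from the paper), and `rank_keys` and `gatv2_score` are hypothetical helper names. In the GAT-style score the query only shifts every key's score by the same amount before a monotonic function, so the key order cannot change; in the GATv2-style score the nonlinearity is applied before the dot product with $\mathbf{a}$, so it can.

```python
def leaky_relu(x, slope=0.2):
    return x if x >= 0 else slope * x


def rank_keys(scores):
    """Return key indices sorted from highest to lowest score."""
    return sorted(range(len(scores)), key=lambda j: -scores[j])


# GAT: e_ij = LeakyReLU(s_i + t_j), with scalars s_i = a_1^T W h_i and t_j = a_2^T W h_j
s = [1.0, -1.0]  # per-query terms (toy values)
t = [1.0, -1.0]  # per-key terms (toy values)
gat_ranks = [rank_keys([leaky_relu(s_i + t_j) for t_j in t]) for s_i in s]

# GATv2: e_ij = a^T LeakyReLU(W_l h_i + W_r h_j)
a = [1.0, 1.0]
q = [[1.0, -1.0], [-1.0, 1.0]]  # W_l h_i for each query (toy values)
k = [[1.0, -1.0], [-1.0, 1.0]]  # W_r h_j for each key (toy values)


def gatv2_score(q_i, k_j):
    return sum(a_d * leaky_relu(q_d + k_d) for a_d, q_d, k_d in zip(a, q_i, k_j))


gatv2_ranks = [rank_keys([gatv2_score(q_i, k_j) for k_j in k]) for q_i in q]

print(gat_ranks)    # key order per query under the GAT-style score
print(gatv2_ranks)  # key order per query under the GATv2-style score
```

With these values both GAT queries rank key 0 first, while the two GATv2 queries prefer different keys.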
The paper shows that GAT's static attention mechanism fails on some graph problems, using a synthetic dictionary lookup dataset. It is a fully connected bipartite graph where each node in one set (the query nodes) has a key associated with it, and each node in the other set has both a key and a value associated with it. The goal is to predict the values of the query nodes. GAT fails on this task because of its limited static attention.
Here is the training code for a two-layer GATv2 on the Cora dataset.
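import torch
from torch import nn

from labml_helpers.module import Module

This is a single graph attention v2 layer. A GATv2 is made up of multiple such layers. It takes $\mathbf{h} = \{ \overrightarrow{h_1}, \overrightarrow{h_2}, \dots, \overrightarrow{h_N} \}$, where $\overrightarrow{h_i} \in \mathbb{R}^F$, as input and outputs $\mathbf{h'} = \{ \overrightarrow{h'_1}, \overrightarrow{h'_2}, \dots, \overrightarrow{h'_N} \}$, where $\overrightarrow{h'_i} \in \mathbb{R}^{F'}$.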
class GraphAttentionV2Layer(Module):

`in_features`, $F$, is the number of input features per node
`out_features`, $F'$, is the number of output features per node
`n_heads`, $K$, is the number of attention heads
`is_concat` specifies whether the multi-head results should be concatenated or averaged
`dropout` is the dropout probability
`leaky_relu_negative_slope` is the negative slope for the leaky ReLU activation
`share_weights`, if set to `True`, applies the same matrix to the source and the target node of every edge

    def __init__(self, in_features: int, out_features: int, n_heads: int,
                 is_concat: bool = True,
                 dropout: float = 0.6,
                 leaky_relu_negative_slope: float = 0.2,
                 share_weights: bool = False):
        super().__init__()

        self.is_concat = is_concat
        self.n_heads = n_heads
        self.share_weights = share_weights

        # Calculate the number of dimensions per head
        if is_concat:
            assert out_features % n_heads == 0
            # If we are concatenating the multiple heads
            self.n_hidden = out_features // n_heads
        else:
            # If we are averaging the multiple heads
            self.n_hidden = out_features

        # Linear layer for initial source transformation;
        # i.e. to transform the source node embeddings before self-attention
        self.linear_l = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
        # If `share_weights` is `True`, the same linear layer is used for the target nodes
        if share_weights:
            self.linear_r = self.linear_l
        else:
            self.linear_r = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
        # Linear layer to compute attention score $e_{ij}$
        self.attn = nn.Linear(self.n_hidden, 1, bias=False)
        # The activation for attention score $e_{ij}$
        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
        # Softmax to compute attention $\alpha_{ij}$
        self.softmax = nn.Softmax(dim=1)
        # Dropout layer to be applied for attention
        self.dropout = nn.Dropout(dropout)

`h`, $\mathbf{h}$, is the input node embeddings of shape `[n_nodes, in_features]`.
`adj_mat` is the adjacency matrix of shape `[n_nodes, n_nodes, n_heads]`. We use shape `[n_nodes, n_nodes, 1]` since the adjacency is the same for each head. The adjacency matrix represents the edges (or connections) among nodes. `adj_mat[i][j]` is `True` if there is an edge from node `i` to node `j`.

    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
        # Number of nodes
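        n_nodes = h.shape[0]

The initial transformations, $\overrightarrow{{g_l}^k_i} = \mathbf{W}_l^k \overrightarrow{h_i}$ and $\overrightarrow{{g_r}^k_i} = \mathbf{W}_r^k \overrightarrow{h_i}$, for each head. We do two linear transformations and then split the result up for each head.

        g_l = self.linear_l(h).view(n_nodes, self.n_heads, self.n_hidden)
        g_r = self.linear_r(h).view(n_nodes, self.n_heads, self.n_hidden)

We calculate these for each head $k$. We have omitted $\cdot^k$ for simplicity.
$$e_{ij} = a(\mathbf{W}_l \overrightarrow{h_i}, \mathbf{W}_r \overrightarrow{h_j}) = a(\overrightarrow{{g_l}_i}, \overrightarrow{{g_r}_j})$$
$e_{ij}$ is the attention score (importance) from node $j$ to node $i$. We calculate this for each head.
$a$ is the attention mechanism that calculates the attention score. The paper sums $\overrightarrow{{g_l}_i}$ and $\overrightarrow{{g_r}_j}$, applies a $\text{LeakyReLU}$, and then does a linear transformation with a weight vector $\mathbf{a}$,
$$e_{ij} = \mathbf{a}^\top \text{LeakyReLU} \Big( \overrightarrow{{g_l}_i} + \overrightarrow{{g_r}_j} \Big)$$
Note: The paper describes $e_{ij}$ as $$e_{ij} = \mathbf{a}^\top \text{LeakyReLU} \Big( \mathbf{W} \Big[ \overrightarrow{h_i} \Vert \overrightarrow{h_j} \Big] \Big)$$ which is equivalent to the definition we use here.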
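The equivalence follows from splitting $\mathbf{W}$ into two column blocks, one that multiplies $\overrightarrow{h_i}$ and one that multiplies $\overrightarrow{h_j}$. A quick numerical check, as a sketch with random toy tensors (the sizes, the seed, and the variable names are assumptions for illustration):

```python
import torch
from torch import nn

torch.manual_seed(0)
in_features, n_hidden = 4, 3

# Random toy inputs and parameters (float64 keeps rounding differences tiny)
h_i = torch.randn(in_features, dtype=torch.float64)
h_j = torch.randn(in_features, dtype=torch.float64)
W_l = torch.randn(n_hidden, in_features, dtype=torch.float64)
W_r = torch.randn(n_hidden, in_features, dtype=torch.float64)
a = torch.randn(n_hidden, dtype=torch.float64)
leaky_relu = nn.LeakyReLU(negative_slope=0.2)

# Paper's form: one matrix W = [W_l || W_r] applied to the concatenation [h_i || h_j]
W = torch.cat([W_l, W_r], dim=1)
e_concat = a @ leaky_relu(W @ torch.cat([h_i, h_j]))

# Form used here: two separate linear maps whose outputs are summed
e_split = a @ leaky_relu(W_l @ h_i + W_r @ h_j)

print(torch.allclose(e_concat, e_split))
```

Using two `nn.Linear` layers instead of one wide matrix also makes it easy to tie them together, which is what `share_weights` does.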
First we calculate $\overrightarrow{{g_l}_i} + \overrightarrow{{g_r}_j}$ for all pairs of $i, j$.
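One way to see which pair ends up at which position is to run the `repeat` and `repeat_interleave` combination on a tiny tensor. The sketch below uses made-up one-dimensional "embeddings" (an assumption for illustration) so that every sum identifies its pair uniquely, and also shows a broadcasting expression that yields the same grid without materialising the repeated copies.

```python
import torch

n_nodes = 3
# Toy embeddings with n_heads = 1 and n_hidden = 1: shape [n_nodes, 1, 1]
g_l = torch.tensor([1., 2., 3.]).view(n_nodes, 1, 1)
g_r = torch.tensor([10., 20., 30.]).view(n_nodes, 1, 1)

# Tiles the whole sequence: l1, l2, l3, l1, l2, l3, ...
g_l_repeat = g_l.repeat(n_nodes, 1, 1)
# Repeats each element in place: r1, r1, r1, r2, r2, r2, ...
g_r_repeat_interleave = g_r.repeat_interleave(n_nodes, dim=0)

g_sum = (g_l_repeat + g_r_repeat_interleave).view(n_nodes, n_nodes, 1, 1)

# Same grid via broadcasting: first index follows g_r, second follows g_l
g_sum_broadcast = g_r.unsqueeze(1) + g_l.unsqueeze(0)

print(g_sum.squeeze())
```

In this toy run `g_sum[i, j]` holds `g_r[i] + g_l[j]`: the interleaved tensor changes slowest, so it determines the first index after the reshape.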
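`g_l_repeat` gets $$\{ \overrightarrow{{g_l}_1}, \overrightarrow{{g_l}_2}, \dots, \overrightarrow{{g_l}_N}, \overrightarrow{{g_l}_1}, \overrightarrow{{g_l}_2}, \dots, \overrightarrow{{g_l}_N}, \dots \}$$ where the whole sequence of node embeddings is repeated `n_nodes` times.

        g_l_repeat = g_l.repeat(n_nodes, 1, 1)

`g_r_repeat_interleave` gets $$\{ \overrightarrow{{g_r}_1}, \overrightarrow{{g_r}_1}, \dots, \overrightarrow{{g_r}_1}, \overrightarrow{{g_r}_2}, \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_r}_2}, \dots \}$$ where each node embedding is repeated `n_nodes` times in a row.

        g_r_repeat_interleave = g_r.repeat_interleave(n_nodes, dim=0)

Now we add the two tensors to get $$\{ \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_1}, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_1}, \dots, \overrightarrow{{g_l}_N} + \overrightarrow{{g_r}_1}, \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_2}, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_l}_N} + \overrightarrow{{g_r}_2}, \dots \}$$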
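        g_sum = g_l_repeat + g_r_repeat_interleave

Reshape into a grid of pairs. Because `repeat_interleave` makes the `g_r` index vary slowest, `g_sum[i, j]` is $\overrightarrow{{g_r}_i} + \overrightarrow{{g_l}_j}$; that is, the first index selects the `g_r` term and the second selects the `g_l` term. This is the same as $\overrightarrow{{g_l}_i} + \overrightarrow{{g_r}_j}$ when `share_weights` is `True`.

        g_sum = g_sum.view(n_nodes, n_nodes, self.n_heads, self.n_hidden)

Calculate $e_{ij}$ by applying $\text{LeakyReLU}$ to `g_sum[i, j]` followed by the linear map $\mathbf{a}^\top$. `e` is of shape `[n_nodes, n_nodes, n_heads, 1]`.

        e = self.attn(self.activation(g_sum))
        # Remove the last dimension of size `1`
        e = e.squeeze(-1)
        # The adjacency matrix should have shape `[n_nodes, n_nodes, n_heads]` or `[n_nodes, n_nodes, 1]`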
        assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
        assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
        assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads
        # Mask $e_{ij}$ based on the adjacency matrix.
        # $e_{ij}$ is set to $-\infty$ if there is no edge from $i$ to $j$.
        e = e.masked_fill(adj_mat == 0, float('-inf'))

We then normalize attention scores (or coefficients)
$$\alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{j' \in \mathcal{N}_i} \exp(e_{ij'})}$$
where $\mathcal{N}_i$ is the set of nodes connected to $i$.
We do this by setting unconnected $e_{ij}$ to $-\infty$, which makes $\exp(e_{ij}) \sim 0$ for unconnected pairs.
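The effect of masking before the softmax can be checked on a small example. This is a sketch with made-up scores and a made-up adjacency matrix (both assumptions for illustration); it uses `torch.softmax` directly rather than the layer's `nn.Softmax` module.

```python
import torch

# Raw scores e[i, j] for 3 nodes and a single head: shape [n_nodes, n_nodes, n_heads]
e = torch.tensor([[1.0, 2.0, 3.0],
                  [1.0, 1.0, 1.0],
                  [0.5, 0.0, 2.0]]).unsqueeze(-1)

# Toy adjacency with self-loops; adj_mat[i, j] is True if there is an edge from i to j
adj_mat = torch.tensor([[1, 1, 0],
                        [0, 1, 0],
                        [1, 1, 1]], dtype=torch.bool).unsqueeze(-1)

e_masked = e.masked_fill(adj_mat == 0, float('-inf'))
a = torch.softmax(e_masked, dim=1)  # normalise over the key index j

print(a.squeeze(-1))

# A row with no edges at all contains only -inf and its softmax is NaN,
# so this toy example gives every node a self-loop.
isolated = torch.softmax(torch.full((1, 3, 1), float('-inf')), dim=1)
print(isolated.squeeze(-1))
```

Each row of `a` sums to one over the connected nodes only, and unconnected pairs receive exactly zero weight.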
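        a = self.softmax(e)
        # Apply dropout regularization
        a = self.dropout(a)

Calculate final output for each head
$$\overrightarrow{h'^k_i} = \sum_{j \in \mathcal{N}_i} \alpha^k_{ij} \overrightarrow{{g_r}^k_j}$$

        attn_res = torch.einsum('ijh,jhf->ihf', a, g_r)
        # Concatenate the heads, $\overrightarrow{h'_i} = \Big\Vert_{k=1}^{K} \overrightarrow{h'^k_i}$
        if self.is_concat:
            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
        # Take the mean of the heads, $\overrightarrow{h'_i} = \frac{1}{K} \sum_{k=1}^{K} \overrightarrow{h'^k_i}$
        else:
            return attn_res.mean(dim=1)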
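Putting the steps together, the following is a condensed functional sketch of the same forward pass; it is not the class above. It uses broadcasting in place of `repeat` and `repeat_interleave`, skips dropout, and takes the weights as plain tensors. The function name `gatv2_forward`, its argument layout, and the toy sizes are assumptions for illustration.

```python
import torch
import torch.nn.functional as F


def gatv2_forward(h, adj_mat, w_l, w_r, attn, n_heads, is_concat=True, negative_slope=0.2):
    """Hypothetical helper mirroring the layer's forward pass without dropout.

    h: [n_nodes, in_features]
    adj_mat: [n_nodes, n_nodes, 1] or [n_nodes, n_nodes, n_heads]
    w_l, w_r: [in_features, n_heads * n_hidden]; attn: [n_hidden]
    """
    n_nodes = h.shape[0]
    g_l = (h @ w_l).view(n_nodes, n_heads, -1)
    g_r = (h @ w_r).view(n_nodes, n_heads, -1)
    # g_sum[i, j] = g_r[i] + g_l[j], the layout the repeat-based code produces
    g_sum = g_r.unsqueeze(1) + g_l.unsqueeze(0)
    # Scores of shape [n_nodes, n_nodes, n_heads]
    e = F.leaky_relu(g_sum, negative_slope) @ attn
    e = e.masked_fill(adj_mat == 0, float('-inf'))
    a = torch.softmax(e, dim=1)
    out = torch.einsum('ijh,jhf->ihf', a, g_r)
    return out.reshape(n_nodes, -1) if is_concat else out.mean(dim=1)


torch.manual_seed(0)
n_nodes, in_features, n_heads, n_hidden = 4, 5, 2, 3
h = torch.randn(n_nodes, in_features)
# Fully connected toy graph, which includes self-loops
adj_mat = torch.ones(n_nodes, n_nodes, 1, dtype=torch.bool)
w_l = torch.randn(in_features, n_heads * n_hidden)
w_r = torch.randn(in_features, n_heads * n_hidden)
attn = torch.randn(n_hidden)

out_concat = gatv2_forward(h, adj_mat, w_l, w_r, attn, n_heads, is_concat=True)
out_mean = gatv2_forward(h, adj_mat, w_l, w_r, attn, n_heads, is_concat=False)
print(out_concat.shape, out_mean.shape)
```

Concatenation gives `n_heads * n_hidden` features per node, while averaging gives `n_hidden`, which is why the constructor divides `out_features` by `n_heads` only in the concatenating case.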