This is a PyTorch implementation of the paper Graph Attention Networks.
GATs work on graph data. A graph consists of nodes and edges connecting nodes. For example, in the Cora dataset the nodes are research papers and the edges are citations that connect the papers.
GAT uses masked self-attention, similar to the self-attention in transformers. A GAT consists of graph attention layers stacked on top of each other. Each graph attention layer gets node embeddings as inputs and outputs transformed embeddings. Each node embedding attends to the embeddings of the other nodes it is connected to. The details of the graph attention layers are included alongside the implementation.
Here is the training code for training a two-layer GAT on the Cora dataset; a sketch of such a two-layer model is given at the end of this section.
import torch
from torch import nn

This is a single graph attention layer. A GAT is made up of multiple such layers.

It takes $\mathbf{h} = \{ \overrightarrow{h_1}, \overrightarrow{h_2}, \dots, \overrightarrow{h_N} \}$, where $\overrightarrow{h_i} \in \mathbb{R}^F$, as input and outputs $\mathbf{h'} = \{ \overrightarrow{h'_1}, \overrightarrow{h'_2}, \dots, \overrightarrow{h'_N} \}$, where $\overrightarrow{h'_i} \in \mathbb{R}^{F'}$.
class GraphAttentionLayer(nn.Module):

- in_features, $F$, is the number of input features per node
- out_features, $F'$, is the number of output features per node
- n_heads, $K$, is the number of attention heads
- is_concat is whether the multi-head results should be concatenated or averaged
- dropout is the dropout probability
- leaky_relu_negative_slope is the negative slope for leaky relu activation

    def __init__(self, in_features: int, out_features: int, n_heads: int,
                 is_concat: bool = True,
                 dropout: float = 0.6,
                 leaky_relu_negative_slope: float = 0.2):
        super().__init__()
        self.is_concat = is_concat
        self.n_heads = n_heads

        # Calculate the number of dimensions per head
        if is_concat:
            assert out_features % n_heads == 0
            # If we are concatenating the multiple heads
            self.n_hidden = out_features // n_heads
        else:
            # If we are averaging the multiple heads
            self.n_hidden = out_features

        # Linear layer for the initial transformation;
        # i.e. to transform the node embeddings before self-attention
        self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
        # Linear layer to compute the attention score $e_{ij}$
        self.attn = nn.Linear(self.n_hidden * 2, 1, bias=False)
        # The activation for the attention score $e_{ij}$
        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
        # Softmax to compute attention $\alpha_{ij}$
        self.softmax = nn.Softmax(dim=1)
        # Dropout layer to be applied to the attention weights
        self.dropout = nn.Dropout(dropout)
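As a quick check of this bookkeeping, here is a tiny sketch with made-up sizes (not from the paper):

layer = GraphAttentionLayer(in_features=32, out_features=64, n_heads=8, is_concat=True)
assert layer.n_hidden == 8  # 64 // 8; the 8 heads are later concatenated back to 64 features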
h is the input node embeddings, of shape [n_nodes, in_features].

adj_mat is the adjacency matrix, of shape [n_nodes, n_nodes, n_heads]. We use shape [n_nodes, n_nodes, 1] since the adjacency is the same for each head.

The adjacency matrix represents the edges (or connections) among nodes. adj_mat[i][j] is True if there is an edge from node i to node j.
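For instance, here is a minimal sketch of building such an adjacency tensor from an edge list. The graph itself is made up for illustration, and self-loops are added so that each node also attends to itself:

n_nodes = 4
edges = [(0, 1), (1, 2), (2, 0), (3, 1)]  # hypothetical directed edges
adj_mat = torch.zeros(n_nodes, n_nodes, 1, dtype=torch.bool)
for i, j in edges:
    adj_mat[i, j, 0] = True
# Add self-loops so that every node attends to itself
adj_mat |= torch.eye(n_nodes, dtype=torch.bool).unsqueeze(-1)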
    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
        # Number of nodes
        n_nodes = h.shape[0]
        # The initial transformation, $\overrightarrow{g^k_i} = \mathbf{W}^k \overrightarrow{h_i}$, for each head;
        # we do a single linear transformation and then split it up for each head
        g = self.linear(h).view(n_nodes, self.n_heads, self.n_hidden)

Next we calculate the attention scores. We calculate these for each head $k$; we have omitted $\cdot^k$ for simplicity.

$$e_{ij} = a(\mathbf{W} \overrightarrow{h_i}, \mathbf{W} \overrightarrow{h_j}) = a(\overrightarrow{g_i}, \overrightarrow{g_j})$$
$e_{ij}$ is the attention score (importance) from node $j$ to node $i$. We calculate this for each head.

$a$ is the attention mechanism that calculates the attention score. The paper concatenates $\overrightarrow{g_i}$, $\overrightarrow{g_j}$ and does a linear transformation with a weight vector $\mathbf{a} \in \mathbb{R}^{2F'}$ followed by a $\text{LeakyReLU}$:

$$e_{ij} = \text{LeakyReLU} \Big( \mathbf{a}^\top \Big[ \overrightarrow{g_i} \Vert \overrightarrow{g_j} \Big] \Big)$$
First we calculate $\Big[ \overrightarrow{g_i} \Vert \overrightarrow{g_j} \Big]$ for all pairs of $i, j$.
g_repeat gets $\{\overrightarrow{g_1}, \overrightarrow{g_2}, \dots, \overrightarrow{g_N}, \overrightarrow{g_1}, \overrightarrow{g_2}, \dots, \overrightarrow{g_N}, \dots\}$ where the whole sequence of node embeddings is repeated n_nodes times.

        g_repeat = g.repeat(n_nodes, 1, 1)

g_repeat_interleave gets $\{\overrightarrow{g_1}, \overrightarrow{g_1}, \dots, \overrightarrow{g_1}, \overrightarrow{g_2}, \overrightarrow{g_2}, \dots, \overrightarrow{g_2}, \dots\}$ where each node embedding is repeated n_nodes times in a row.

        g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)
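The difference between the two is easy to see on a tiny made-up tensor:

x = torch.tensor([[1], [2], [3]])
x.repeat(3, 1)                 # rows: 1, 2, 3, 1, 2, 3, 1, 2, 3
x.repeat_interleave(3, dim=0)  # rows: 1, 1, 1, 2, 2, 2, 3, 3, 3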
Now we concatenate to get

$$\{\overrightarrow{g_1} \Vert \overrightarrow{g_1}, \overrightarrow{g_1} \Vert \overrightarrow{g_2}, \dots, \overrightarrow{g_1} \Vert \overrightarrow{g_N}, \overrightarrow{g_2} \Vert \overrightarrow{g_1}, \overrightarrow{g_2} \Vert \overrightarrow{g_2}, \dots, \overrightarrow{g_2} \Vert \overrightarrow{g_N}, \dots\}$$

        g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1)

Reshape so that g_concat[i, j] is $\overrightarrow{g_i} \Vert \overrightarrow{g_j}$.

        g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)
Calculate $e_{ij} = \text{LeakyReLU} \Big( \mathbf{a}^\top \Big[ \overrightarrow{g_i} \Vert \overrightarrow{g_j} \Big] \Big)$. e is of shape [n_nodes, n_nodes, n_heads, 1].

        e = self.activation(self.attn(g_concat))
        # Remove the last dimension of size 1
        e = e.squeeze(-1)
The adjacency matrix should have shape [n_nodes, n_nodes, n_heads] or [n_nodes, n_nodes, 1].

        assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
        assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
        assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads
        # Mask $e_{ij}$ based on the adjacency matrix:
        # $e_{ij}$ is set to $-\infty$ if there is no edge from $i$ to $j$
        e = e.masked_fill(adj_mat == 0, float('-inf'))
We then normalize the attention scores (or coefficients)

$$\alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{j' \in \mathcal{N}_i} \exp(e_{ij'})}$$

where $\mathcal{N}_i$ is the set of nodes connected to $i$. We do this by setting unconnected $e_{ij}$ to $-\infty$, which makes $\exp(e_{ij}) \sim 0$ for unconnected pairs.

        a = self.softmax(e)
        # Apply dropout regularization
        a = self.dropout(a)
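A small numeric sketch of why the $-\infty$ masking above works (the scores are made up):

scores = torch.tensor([0.5, 1.0, float('-inf')])  # the third node is not a neighbor
torch.softmax(scores, dim=0)                      # tensor([0.3775, 0.6225, 0.0000])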
Calculate the final output for each head

$$\overrightarrow{h'^k_i} = \sum_{j \in \mathcal{N}_i} \alpha^k_{ij} \overrightarrow{g^k_j}$$

Note: The paper includes the final activation $\sigma$ in $\overrightarrow{h'^k_i}$. We have omitted this from the graph attention layer implementation and apply it in the GAT model instead, to match how other PyTorch modules are defined, with the activation as a separate layer.

        attn_res = torch.einsum('ijh,jhf->ihf', a, g)
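The einsum computes, for every head h, a weighted sum of neighbor embeddings: attn_res[i, h, f] = sum over j of a[i, j, h] * g[j, h, f]. A loop-based equivalent, shown only for intuition and not part of the layer:

# Per head: [n_nodes, n_nodes] attention weights times [n_nodes, n_hidden] values
attn_res_check = torch.zeros_like(attn_res)
for head in range(a.shape[-1]):
    attn_res_check[:, head, :] = a[:, :, head] @ g[:, head, :]
# Matches the einsum, assuming every node has at least one neighbor (e.g. self-loops)
assert torch.allclose(attn_res_check, attn_res)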
        if self.is_concat:
            # Concatenate the heads
            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
        else:
            # Take the mean of the heads
            return attn_res.mean(dim=1)
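For reference, the two-layer model mentioned at the start could be assembled roughly as follows. This is a sketch, not the exact experiment code; the layer sizes, ELU activation, and dropout placement are assumptions modeled on the paper's transductive setup:

class GAT(nn.Module):
    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):
        super().__init__()
        # First graph attention layer, with concatenated heads
        self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)
        # Activation between the layers (kept outside the layers, as noted above)
        self.activation = nn.ELU()
        # Second (output) layer with a single averaged head producing class logits
        self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):
        x = self.dropout(x)
        x = self.activation(self.layer1(x, adj_mat))
        x = self.dropout(x)
        return self.output(x, adj_mat)

A hypothetical usage on the toy graph from the adjacency sketch above:

model = GAT(in_features=16, n_hidden=8, n_classes=3, n_heads=4, dropout=0.6)
logits = model(torch.randn(4, 16), adj_mat)  # -> shape [4, 3]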