GATv2 refactoring (#70)

* fixed link, add clarification

* updated dropout + experiment link
This commit is contained in:
Shaked Brody
2021-07-26 11:22:54 +03:00
committed by GitHub
parent 895cad46ad
commit 671a93c299
4 changed files with 19 additions and 22 deletions

View File

@ -56,7 +56,7 @@ implementations.
#### ✨ Graph Neural Networks
* [Graph Attention Networks (GAT)](graphs/gat/index.html)
* [Graph Attention Networks v2 (GATv2)](gatv2/index.html)
* [Graph Attention Networks v2 (GATv2)](graphs/gatv2/index.html)
#### ✨ [Counterfactual Regret Minimization (CFR)](cfr/index.html)

View File

@ -4,26 +4,20 @@ title: Graph Attention Networks v2 (GATv2)
summary: >
A PyTorch implementation/tutorial of Graph Attention Networks v2.
---
# Graph Attention Networks v2 (GATv2)
This is a [PyTorch](https://pytorch.org) implementation of the GATv2 operator from the paper
[How Attentive are Graph Attention Networks?](https://arxiv.org/abs/2105.14491).
GATv2s work on graph data.
A graph consists of nodes and edges connecting nodes.
For example, in Cora dataset the nodes are research papers and the edges are citations that
connect the papers.
The GATv2 operator which fixes the static attention problem of the standard GAT:
The GATv2 operator fixes the static attention problem of the standard GAT:
since the linear layers in the standard GAT are applied right after each other, the ranking
of attended nodes is unconditioned on the query node.
In contrast, in GATv2, every node can attend to any other node.
Here is [the training code](experiment.html) for training
a two-layer GATv2 on Cora dataset.
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/8e27ad82ed2611ebabb691fb2028a868)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/34b1e2f6ed6f11ebb860997901a2d1e3)
"""
import torch
@ -35,10 +29,8 @@ from labml_helpers.module import Module
class GraphAttentionV2Layer(Module):
"""
## Graph attention v2 layer
This is a single graph attention v2 layer.
A GATv2 is made up of multiple such layers.
It takes
$$\mathbf{h} = \{ \overrightarrow{h_1}, \overrightarrow{h_2}, \dots, \overrightarrow{h_N} \}$$,
where $\overrightarrow{h_i} \in \mathbb{R}^F$ as input
@ -97,7 +89,6 @@ class GraphAttentionV2Layer(Module):
* `h`, $\mathbf{h}$ is the input node embeddings of shape `[n_nodes, in_features]`.
* `adj_mat` is the adjacency matrix of shape `[n_nodes, n_nodes, n_heads]`.
We use shape `[n_nodes, n_nodes, 1]` since the adjacency is the same for each head.
Adjacency matrix represent the edges (or connections) among nodes.
`adj_mat[i][j]` is `True` if there is an edge from node `i` to node `j`.
"""
@ -125,7 +116,7 @@ class GraphAttentionV2Layer(Module):
# $a$ is the attention mechanism, that calculates the attention score.
# The paper sums
# $\overrightarrow{{g_l}_i}$, $\overrightarrow{{g_r}_j}$
# followed by a $\text{LeakyReLU}$
# followed by a $\text{LeakyReLU}$
# and does a linear transformation with a weight vector $\mathbf{a} \in \mathbb{R}^{F'}$
#
#
@ -133,6 +124,12 @@ class GraphAttentionV2Layer(Module):
# \Big[
# \overrightarrow{{g_l}_i} + \overrightarrow{{g_r}_j}
# \Big] \Big)$$
# Note: The paper desrcibes $e_{ij}$ as
# $$e_{ij} = \mathbf{a}^\top \text{LeakyReLU} \Big( \mathbf{W}
# \Big[
# \overrightarrow{h_i} \Vert \overrightarrow{h_j}
# \Big] \Big)$$
# which is equivalent to the definition we use here.
# First we calculate
# $\Big[\overrightarrow{{g_l}_i} + \overrightarrow{{g_r}_j} \Big]$
@ -148,7 +145,7 @@ class GraphAttentionV2Layer(Module):
# \overrightarrow{{g_r}_2}, \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_r}_2}, ...\}$$
# where each node embedding is repeated `n_nodes` times.
g_r_repeat_interleave = g_r.repeat_interleave(n_nodes, dim=0)
# Now we sum to get
# Now we add the two tensors to get
# $$\{\overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_1},
# \overrightarrow{{g_l}_1}, + \overrightarrow{{g_r}_2},
# \dots, \overrightarrow{{g_l}_1} +\overrightarrow{{g_r}_N},
@ -202,4 +199,4 @@ class GraphAttentionV2Layer(Module):
# Take the mean of the heads
else:
# $$\overrightarrow{h'_i} = \frac{1}{K} \sum_{k=1}^{K} \overrightarrow{h'^k_i}$$
return attn_res.mean(dim=1)
return attn_res.mean(dim=1)

View File

@ -7,7 +7,7 @@ summary: >
# Train a Graph Attention Network v2 (GATv2) on Cora dataset
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/8e27ad82ed2611ebabb691fb2028a868)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/34b1e2f6ed6f11ebb860997901a2d1e3)
"""
from typing import Dict
@ -178,7 +178,7 @@ class Configs(BaseConfigs):
# Number of classes for classification
n_classes: int
# Dropout probability
dropout: float = 0.6
dropout: float = 0.7
# Whether to include the citation network
include_edges: bool = True
# Dataset

View File

@ -1,4 +1,4 @@
# [Graph Attention Networks v2 (GATv2)](https://nn.labml.ai/graphs/gatv2/index.html)
# [Graph Attention Networks v2 (GATv2)](https://nn.labml.ai/graph/gatv2/index.html)
This is a [PyTorch](https://pytorch.org) implementation of the GATv2 opeartor from the paper
[How Attentive are Graph Attention Networks?](https://arxiv.org/abs/2105.14491).
@ -8,12 +8,12 @@ A graph consists of nodes and edges connecting nodes.
For example, in Cora dataset the nodes are research papers and the edges are citations that
connect the papers.
The GATv2 operator which fixes the static attention problem of the standard GAT:
The GATv2 operator fixes the static attention problem of the standard GAT:
since the linear layers in the standard GAT are applied right after each other, the ranking
of attended nodes is unconditioned on the query node.
In contrast, in GATv2, every node can attend to any other node.
Here is [the training code](https://nn.labml.ai/graphs/gatv2/experiment.html) for training
a two-layer GAT on Cora dataset.
Here is [the training code](https://nn.labml.ai/graph/gatv2/experiment.html) for training
a two-layer GATv2 on Cora dataset.
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/8e27ad82ed2611ebabb691fb2028a868)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/34b1e2f6ed6f11ebb860997901a2d1e3)