Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-08-14 17:41:37 +08:00)
GATv2 refactoring (#70)
* fixed link, add clarification
* updated dropout + experiment link
@ -56,7 +56,7 @@ implementations.
#### ✨ Graph Neural Networks

* [Graph Attention Networks (GAT)](graphs/gat/index.html)
* [Graph Attention Networks v2 (GATv2)](gatv2/index.html)
* [Graph Attention Networks v2 (GATv2)](graphs/gatv2/index.html)

#### ✨ [Counterfactual Regret Minimization (CFR)](cfr/index.html)
@ -4,26 +4,20 @@ title: Graph Attention Networks v2 (GATv2)
summary: >
  A PyTorch implementation/tutorial of Graph Attention Networks v2.
---

# Graph Attention Networks v2 (GATv2)

This is a [PyTorch](https://pytorch.org) implementation of the GATv2 operator from the paper
[How Attentive are Graph Attention Networks?](https://arxiv.org/abs/2105.14491).

GATv2s work on graph data.
A graph consists of nodes and edges connecting nodes.
For example, in the Cora dataset the nodes are research papers and the edges are citations that
connect the papers.

The GATv2 operator which fixes the static attention problem of the standard GAT:
The GATv2 operator fixes the static attention problem of the standard GAT:
since the linear layers in the standard GAT are applied right after each other, the ranking
of attended nodes is unconditioned on the query node.
In contrast, in GATv2, every node can attend to any other node.
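To make the static vs. dynamic attention distinction concrete, here is a minimal, self-contained sketch of the two scoring functions (illustrative only, not this repository's implementation; the names `W`, `W_l`, `W_r` and `a` are assumptions):

```python
import torch
import torch.nn.functional as F

n_nodes, in_features, n_hidden = 4, 8, 16
h = torch.randn(n_nodes, in_features)

# GAT (static attention): the LeakyReLU comes *after* the single scoring layer `a`,
# so the ranking over keys j is the same for every query i.
W = torch.nn.Linear(in_features, n_hidden, bias=False)
a_gat = torch.nn.Linear(2 * n_hidden, 1, bias=False)
g = W(h)
pairs = torch.cat([g.repeat_interleave(n_nodes, dim=0), g.repeat(n_nodes, 1)], dim=-1)
e_gat = F.leaky_relu(a_gat(pairs), 0.2).view(n_nodes, n_nodes)

# GATv2 (dynamic attention): the LeakyReLU sits *between* the two linear maps,
# so the scores seen by `a` already depend on the query node i.
W_l = torch.nn.Linear(in_features, n_hidden, bias=False)
W_r = torch.nn.Linear(in_features, n_hidden, bias=False)
a_v2 = torch.nn.Linear(n_hidden, 1, bias=False)
g_l, g_r = W_l(h), W_r(h)
e_v2 = a_v2(F.leaky_relu(g_l.unsqueeze(1) + g_r.unsqueeze(0), 0.2)).squeeze(-1)

print(e_gat.shape, e_v2.shape)  # both [n_nodes, n_nodes]; e[i, j] scores node i attending to node j
```

In the GAT form, the argmax over `j` of `e_gat[i, j]` is the same for every query `i`, which is exactly the static-attention limitation that GATv2 removes.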

Here is [the training code](experiment.html) for training
a two-layer GATv2 on the Cora dataset.

[](https://app.labml.ai/run/8e27ad82ed2611ebabb691fb2028a868)
[](https://app.labml.ai/run/34b1e2f6ed6f11ebb860997901a2d1e3)
"""

import torch
@ -35,10 +29,8 @@ from labml_helpers.module import Module
class GraphAttentionV2Layer(Module):
"""
## Graph attention v2 layer

This is a single graph attention v2 layer.
A GATv2 is made up of multiple such layers.

It takes
$$\mathbf{h} = \{ \overrightarrow{h_1}, \overrightarrow{h_2}, \dots, \overrightarrow{h_N} \}$$,
where $\overrightarrow{h_i} \in \mathbb{R}^F$ as input
@ -97,7 +89,6 @@ class GraphAttentionV2Layer(Module):
* `h`, $\mathbf{h}$ is the input node embeddings of shape `[n_nodes, in_features]`.
* `adj_mat` is the adjacency matrix of shape `[n_nodes, n_nodes, n_heads]`.
We use shape `[n_nodes, n_nodes, 1]` since the adjacency is the same for each head.

The adjacency matrix represents the edges (or connections) among nodes.
`adj_mat[i][j]` is `True` if there is an edge from node `i` to node `j`.
"""
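As a quick illustration of these two inputs, here is a small sketch (toy values; assumptions, not repository code) that builds `h` and `adj_mat` from an edge list:

```python
import torch

n_nodes, in_features = 4, 8
edges = [(0, 1), (1, 2), (2, 3), (3, 0)]              # illustrative edge list

# Node embeddings: shape [n_nodes, in_features]
h = torch.randn(n_nodes, in_features)

# Adjacency matrix: shape [n_nodes, n_nodes, 1]; the trailing dimension of 1
# is broadcast across heads because the graph structure is shared by all heads
adj_mat = torch.zeros(n_nodes, n_nodes, 1, dtype=torch.bool)
for i, j in edges:
    adj_mat[i, j, 0] = True
    adj_mat[j, i, 0] = True                           # assuming an undirected graph

# layer = GraphAttentionV2Layer(...)                  # constructor arguments omitted here
# out = layer(h, adj_mat)
```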
@ -125,7 +116,7 @@ class GraphAttentionV2Layer(Module):
# $a$ is the attention mechanism that calculates the attention score.
# The paper sums
# $\overrightarrow{{g_l}_i}$, $\overrightarrow{{g_r}_j}$
# followed by a $\text{LeakyReLU}$
# and does a linear transformation with a weight vector $\mathbf{a} \in \mathbb{R}^{F'}$
#
@ -133,6 +124,12 @@ class GraphAttentionV2Layer(Module):
# \Big[
# \overrightarrow{{g_l}_i} + \overrightarrow{{g_r}_j}
# \Big] \Big)$$
# Note: The paper describes $e_{ij}$ as
# $$e_{ij} = \mathbf{a}^\top \text{LeakyReLU} \Big( \mathbf{W}
# \Big[
# \overrightarrow{h_i} \Vert \overrightarrow{h_j}
# \Big] \Big)$$
# which is equivalent to the definition we use here.
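A quick numerical check of this equivalence (illustrative tensors only): concatenating and applying a single weight matrix $\mathbf{W} = [\mathbf{W}_l \Vert \mathbf{W}_r]$ is the same linear map as applying $\mathbf{W}_l$ and $\mathbf{W}_r$ separately and summing.

```python
import torch
import torch.nn.functional as F

in_features, n_hidden = 8, 16
h_i, h_j = torch.randn(in_features), torch.randn(in_features)
W_l, W_r = torch.randn(n_hidden, in_features), torch.randn(n_hidden, in_features)
a = torch.randn(n_hidden)

W = torch.cat([W_l, W_r], dim=1)                              # W = [W_l || W_r]
e_concat = a @ F.leaky_relu(W @ torch.cat([h_i, h_j]), 0.2)   # the paper's formulation
e_sum = a @ F.leaky_relu(W_l @ h_i + W_r @ h_j, 0.2)          # the formulation used here

assert torch.allclose(e_concat, e_sum, atol=1e-6)
```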

# First we calculate
# $\Big[\overrightarrow{{g_l}_i} + \overrightarrow{{g_r}_j} \Big]$
@ -148,7 +145,7 @@ class GraphAttentionV2Layer(Module):
# \overrightarrow{{g_r}_2}, \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_r}_2}, ...\}$$
# where each node embedding is repeated `n_nodes` times.
g_r_repeat_interleave = g_r.repeat_interleave(n_nodes, dim=0)
# Now we sum to get
# Now we add the two tensors to get
# $$\{\overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_1},
# \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_2},
# \dots, \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_N},
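To see what the `repeat` / `repeat_interleave` combination produces, here is a toy, single-head check (the `g_l_repeat` line is an assumption from context; only the `g_r` line appears in this hunk), verified against plain broadcasting:

```python
import torch

n_nodes, n_hidden = 3, 2
g_l = torch.randn(n_nodes, n_hidden)
g_r = torch.randn(n_nodes, n_hidden)

g_l_repeat = g_l.repeat(n_nodes, 1)                            # g_l_1, g_l_2, ..., g_l_N, g_l_1, g_l_2, ...
g_r_repeat_interleave = g_r.repeat_interleave(n_nodes, dim=0)  # g_r_1, g_r_1, ..., g_r_2, g_r_2, ...

# Element-wise sum over all ordered pairs, reshaped to [n_nodes, n_nodes, n_hidden]
g_sum = (g_l_repeat + g_r_repeat_interleave).view(n_nodes, n_nodes, n_hidden)

# The same tensor via broadcasting: g_sum[i, j] holds g_l_j + g_r_i
assert torch.allclose(g_sum, g_l.unsqueeze(0) + g_r.unsqueeze(1))
```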
@ -202,4 +199,4 @@ class GraphAttentionV2Layer(Module):
# Take the mean of the heads
else:
# $$\overrightarrow{h'_i} = \frac{1}{K} \sum_{k=1}^{K} \overrightarrow{h'^k_i}$$
return attn_res.mean(dim=1)
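For intuition on how the heads are combined, here is a shapes-only sketch (the tensors are illustrative; the concatenation branch is not shown in this hunk and is assumed from the usual GAT convention of concatenating heads in intermediate layers and averaging them in the final layer):

```python
import torch

n_nodes, n_heads, n_hidden = 4, 8, 16
attn_res = torch.randn(n_nodes, n_heads, n_hidden)         # per-head outputs

concat = attn_res.reshape(n_nodes, n_heads * n_hidden)     # intermediate layers: concatenate the heads
mean = attn_res.mean(dim=1)                                # final layer: average the heads

print(concat.shape, mean.shape)                            # torch.Size([4, 128]) torch.Size([4, 16])
```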
@ -7,7 +7,7 @@ summary: >

# Train a Graph Attention Network v2 (GATv2) on the Cora dataset

[](https://app.labml.ai/run/8e27ad82ed2611ebabb691fb2028a868)
[](https://app.labml.ai/run/34b1e2f6ed6f11ebb860997901a2d1e3)
"""

from typing import Dict
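For orientation, a rough sketch of how such a labml experiment is usually launched (hypothetical snippet; the exact config keys and the `conf.run()` method are assumptions, not taken from this diff):

```python
# Hypothetical launch script; `Configs` and `conf.run()` are assumed from context.
from labml import experiment

def main():
    conf = Configs()                                  # the experiment's configuration class
    experiment.create(name='gatv2')                   # register a labml experiment
    experiment.configs(conf, {'dropout': 0.7})        # override defaults, e.g. the dropout probability
    with experiment.start():                          # start tracking and run the training loop
        conf.run()                                    # assumed training entry point

if __name__ == '__main__':
    main()
```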
@ -178,7 +178,7 @@ class Configs(BaseConfigs):
# Number of classes for classification
n_classes: int
# Dropout probability
dropout: float = 0.6
dropout: float = 0.7
# Whether to include the citation network
include_edges: bool = True
# Dataset
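For context on what this probability controls, an illustrative sketch (not this repository's code) of where dropout typically acts in GAT-style training, following the original GAT setup of dropping both the layer inputs and the normalized attention coefficients:

```python
import torch
import torch.nn as nn

dropout = nn.Dropout(p=0.6)                         # the configured probability

h = torch.randn(4, 8)                               # node features entering a layer
attn = torch.softmax(torch.randn(4, 4), dim=-1)     # normalized attention coefficients

h = dropout(h)                                      # dropout on the layer inputs
attn = dropout(attn)                                # dropout on the attention weights
```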
@ -1,4 +1,4 @@
# [Graph Attention Networks v2 (GATv2)](https://nn.labml.ai/graphs/gatv2/index.html)
# [Graph Attention Networks v2 (GATv2)](https://nn.labml.ai/graph/gatv2/index.html)

This is a [PyTorch](https://pytorch.org) implementation of the GATv2 operator from the paper
[How Attentive are Graph Attention Networks?](https://arxiv.org/abs/2105.14491).
@ -8,12 +8,12 @@ A graph consists of nodes and edges connecting nodes.
For example, in the Cora dataset the nodes are research papers and the edges are citations that
connect the papers.

The GATv2 operator which fixes the static attention problem of the standard GAT:
The GATv2 operator fixes the static attention problem of the standard GAT:
since the linear layers in the standard GAT are applied right after each other, the ranking
of attended nodes is unconditioned on the query node.
In contrast, in GATv2, every node can attend to any other node.

Here is [the training code](https://nn.labml.ai/graphs/gatv2/experiment.html) for training
a two-layer GAT on the Cora dataset.
Here is [the training code](https://nn.labml.ai/graph/gatv2/experiment.html) for training
a two-layer GATv2 on the Cora dataset.

[](https://app.labml.ai/run/8e27ad82ed2611ebabb691fb2028a868)
[](https://app.labml.ai/run/34b1e2f6ed6f11ebb860997901a2d1e3)