mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 17:41:37 +08:00
Refactor GAT experiment
@@ -17,7 +17,7 @@ import torch
from torch import nn

from labml import lab, monit, tracker, experiment
from labml.configs import BaseConfigs
from labml.configs import BaseConfigs, option, calculate
from labml.utils import download
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
@@ -194,26 +194,6 @@ class Configs(BaseConfigs):
    # Optimizer
    optimizer: torch.optim.Adam

    def initialize(self):
        """
        Initialize
        """
        # Create the dataset
        self.dataset = CoraDataset(self.include_edges)
        # Get the number of classes
        self.n_classes = len(self.dataset.classes)
        # Number of features in the input
        self.in_features = self.dataset.features.shape[1]
        # Create the model
        self.model = GAT(self.in_features, self.n_hidden, self.n_classes, self.n_heads, self.dropout)
        # Move the model to the device
        self.model.to(self.device)
        # Configurable optimizer, so that we can set the configurations
        # such as learning rate by passing the dictionary later.
        optimizer_conf = OptimizerConfigs()
        optimizer_conf.parameters = self.model.parameters()
        self.optimizer = optimizer_conf

    def run(self):
        """
        ### Training loop
@@ -276,6 +256,38 @@ class Configs(BaseConfigs):
            tracker.save()


@option(Configs.dataset)
def cora_dataset(c: Configs):
    """
    Create Cora dataset
    """
    return CoraDataset(c.include_edges)


# Get the number of classes
calculate(Configs.n_classes, lambda c: len(c.dataset.classes))
# Number of features in the input
calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])


@option(Configs.model)
def gat_model(c: Configs):
    """
    Create GAT model
    """
    return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)


@option(Configs.optimizer)
def _optimizer(c: Configs):
    """
    Create configurable optimizer
    """
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf


def main():
    # Create configurations
    conf = Configs()
@@ -288,8 +300,6 @@ def main():
        'optimizer.learning_rate': 5e-3,
        'optimizer.weight_decay': 5e-4,
    })
    # Initialize
    conf.initialize()

    # Start and watch the experiment
    with experiment.start():
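The hunks above replace the eager `initialize()` method with labml's lazy config options: each value is produced by a function registered with `@option` or `calculate`, and is only computed when the config is first needed, after any overrides passed to `experiment.configs(...)` have been applied. A string default such as `model: GATv2 = 'gat_v2_model'` (used in the GATv2 experiment below) selects a registered option by name. A minimal, self-contained sketch of the pattern; `ToyConfigs`, `n_classes`, and `default_model` are illustrative names, not part of this commit:

from labml.configs import BaseConfigs, option, calculate


class ToyConfigs(BaseConfigs):
    # Declared without values; they are filled in by the registered options
    n_classes: int
    model: str


@option(ToyConfigs.model)
def default_model(c: ToyConfigs):
    # Used for `model` unless an override is passed to `experiment.configs(...)`
    return f'model-with-{c.n_classes}-classes'


# One-liners can be registered with `calculate` instead of a decorated function
calculate(ToyConfigs.n_classes, lambda c: 7)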
@@ -10,102 +10,14 @@ summary: >
[](https://app.labml.ai/run/34b1e2f6ed6f11ebb860997901a2d1e3)
"""

from typing import Dict

import numpy as np
import torch
from torch import nn

from labml import lab, monit, tracker, experiment
from labml.configs import BaseConfigs
from labml.utils import download
from labml_helpers.device import DeviceConfigs
from labml import experiment
from labml.configs import option
from labml_helpers.module import Module
from labml_nn.graphs.gat.experiment import Configs as GATConfigs
from labml_nn.graphs.gatv2 import GraphAttentionV2Layer
from labml_nn.optimizers.configs import OptimizerConfigs

class CoraDataset:
    """
    ## [Cora Dataset](https://linqs.soe.ucsc.edu/data)

    Cora dataset is a dataset of research papers.
    For each paper we are given a binary feature vector that indicates the presence of words.
    Each paper is classified into one of 7 classes.
    The dataset also has the citation network.

    The papers are the nodes of the graph and the edges are the citations.

    The task is to classify the nodes into the 7 classes with feature vectors and
    citation network as input.
    """
    # Labels for each node
    labels: torch.Tensor
    # Set of class names and a unique integer index
    classes: Dict[str, int]
    # Feature vectors for all nodes
    features: torch.Tensor
    # Adjacency matrix with the edge information.
    # `adj_mat[i][j]` is `True` if there is an edge from `i` to `j`.
    adj_mat: torch.Tensor

    @staticmethod
    def _download():
        """
        Download the dataset
        """
        if not (lab.get_data_path() / 'cora').exists():
            download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
                                   lab.get_data_path() / 'cora.tgz')
            download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())

    def __init__(self, include_edges: bool = True):
        """
        Load the dataset
        """

        # Whether to include edges.
        # This is to test how much accuracy is lost if we ignore the citation network.
        self.include_edges = include_edges

        # Download dataset
        self._download()

        # Read the paper ids, feature vectors, and labels
        with monit.section('Read content file'):
            content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))
        # Load the citations, it's a list of pairs of integers.
        with monit.section('Read citations file'):
            citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)

        # Get the feature vectors
        features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))
        # Normalize the feature vectors
        self.features = features / features.sum(dim=1, keepdim=True)

        # Get the class names and assign a unique integer to each of them
        self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}
        # Get the labels as those integers
        self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)

        # Get the paper ids
        paper_ids = np.array(content[:, 0], dtype=np.int32)
        # Map of paper id to index
        ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}

        # Empty adjacency matrix - an identity matrix
        self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)

        # Mark the citations in the adjacency matrix
        if self.include_edges:
            for e in citations:
                # The pair of paper indexes
                e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]
                # We build a symmetrical graph, where if paper $i$ referenced
                # paper $j$ we place an edge from $i$ to $j$ as well as an edge
                # from $j$ to $i$.
                self.adj_mat[e1][e2] = True
                self.adj_mat[e2][e1] = True

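The loop above turns the directed citation list into an undirected adjacency matrix: it starts from an identity matrix, so every paper is connected to itself, and then sets both `[e1][e2]` and `[e2][e1]` for every citation. A toy illustration with three hypothetical papers (the indices are made up, not taken from the Cora data):

import torch

# Paper 0 cites paper 1, paper 2 cites paper 1 (hypothetical indices)
citations = [(0, 1), (2, 1)]
# Identity matrix: every node is connected to itself
adj_mat = torch.eye(3, dtype=torch.bool)
for e1, e2 in citations:
    # Symmetric edges, so citing and cited papers attend to each other
    adj_mat[e1][e2] = True
    adj_mat[e2][e1] = True
# tensor([[ True,  True, False],
#         [ True,  True,  True],
#         [False,  True,  True]])
print(adj_mat)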
class GATv2(Module):

@@ -115,7 +27,8 @@ class GATv2(Module):
    This graph attention network has two [graph attention layers](index.html).
    """

    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float, share_weights: bool = True):
    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float,
                 share_weights: bool = True):
        """
        * in_features is the number of features per node
        * n_hidden is the number of features in the first graph attention layer
@@ -127,11 +40,13 @@ class GATv2(Module):
        super().__init__()

        # First graph attention layer where we concatenate the heads
        self.layer1 = GraphAttentionV2Layer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout, share_weights=share_weights)
        self.layer1 = GraphAttentionV2Layer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout,
                                            share_weights=share_weights)
        # Activation function after first graph attention layer
        self.activation = nn.ELU()
        # Final graph attention layer where we average the heads
        self.output = GraphAttentionV2Layer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout, share_weights=share_weights)
        self.output = GraphAttentionV2Layer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout,
                                            share_weights=share_weights)
        # Dropout
        self.dropout = nn.Dropout(dropout)
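These two layers are the whole network: `layer1` concatenates its `n_heads` attention heads into `n_hidden` features, and `output` averages a single head down to `n_classes` scores per node. A toy usage sketch; the sizes and adjacency are made up, and the `model(x, adj_mat)` call assumes the forward signature implied by `return self.output(x, adj_mat)` in the next hunk:

import torch

# Made-up sizes for illustration; Cora itself has 2708 nodes and 1433 features
n_nodes, in_features = 4, 16
features = torch.rand(n_nodes, in_features)
# Self-loops only, with a trailing dimension so the mask broadcasts over all heads
adj_mat = torch.eye(n_nodes, dtype=torch.bool).unsqueeze(-1)

model = GATv2(in_features, n_hidden=64, n_classes=7, n_heads=8, dropout=0.7, share_weights=True)
model.eval()
with torch.no_grad():
    # One score per class for every node
    logits = model(features, adj_mat)
assert logits.shape == (n_nodes, 7)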
@@ -153,128 +68,26 @@ class GATv2(Module):
        return self.output(x, adj_mat)


def accuracy(output: torch.Tensor, labels: torch.Tensor):
    """
    A simple function to calculate the accuracy
    """
    return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)

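`accuracy` picks the highest-scoring class for each node with `argmax` and reports the fraction of nodes whose prediction matches the label. A tiny worked example with made-up tensors:

import torch

# Three nodes, two classes; predictions are the argmax over the last dimension
output = torch.tensor([[0.9, 0.1],
                       [0.2, 0.8],
                       [0.6, 0.4]])
labels = torch.tensor([0, 1, 1])
# argmax -> [0, 1, 0]; two of the three match, so the accuracy is 2/3
acc = output.argmax(dim=-1).eq(labels).sum().item() / len(labels)
print(acc)  # 0.666...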
class Configs(BaseConfigs):
class Configs(GATConfigs):
    """
    ## Configurations

    Since the experiment is the same as the [GAT experiment](../gat/experiment.html) but with
    the [GATv2 model](index.html), we extend the same configs and change the model.
    """

    # Model
    model: GATv2
    # Number of nodes to train on
    training_samples: int = 500
    # Number of features per node in the input
    in_features: int
    # Number of features in the first graph attention layer
    n_hidden: int = 64
    # Number of heads
    n_heads: int = 8
    # Number of classes for classification
    n_classes: int
    # Dropout probability
    dropout: float = 0.7
    # Whether to include the citation network
    include_edges: bool = True
    # Dataset
    dataset: CoraDataset
    # Number of training iterations
    epochs: int = 1_000
    # Loss function
    loss_func = nn.CrossEntropyLoss()
    # Device to train on
    #
    # This creates configs for device, so that
    # we can change the device by passing a config value
    device: torch.device = DeviceConfigs()
    # Optimizer
    optimizer: torch.optim.Adam
    # Whether to share weights for source and target nodes of edges
    share_weights: bool = True
    # Set the model
    model: GATv2 = 'gat_v2_model'

    def initialize(self):
        """
        Initialize
        """
        # Create the dataset
        self.dataset = CoraDataset(self.include_edges)
        # Get the number of classes
        self.n_classes = len(self.dataset.classes)
        # Number of features in the input
        self.in_features = self.dataset.features.shape[1]
        # Create the model
        self.model = GATv2(self.in_features, self.n_hidden, self.n_classes, self.n_heads, self.dropout)
        # Move the model to the device
        self.model.to(self.device)
        # Configurable optimizer, so that we can set the configurations
        # such as learning rate by passing the dictionary later.
        optimizer_conf = OptimizerConfigs()
        optimizer_conf.parameters = self.model.parameters()
        self.optimizer = optimizer_conf

    def run(self):
        """
        ### Training loop

        We do full batch training since the dataset is small.
        If we were to sample and train, we would have to sample a set of
        nodes for each training step along with the edges that span
        across those selected nodes.
        """
        # Move the feature vectors to the device
        features = self.dataset.features.to(self.device)
        # Move the labels to the device
        labels = self.dataset.labels.to(self.device)
        # Move the adjacency matrix to the device
        edges_adj = self.dataset.adj_mat.to(self.device)
        # Add an empty third dimension for the heads
        edges_adj = edges_adj.unsqueeze(-1)

        # Random indexes
        idx_rand = torch.randperm(len(labels))
        # Nodes for training
        idx_train = idx_rand[:self.training_samples]
        # Nodes for validation
        idx_valid = idx_rand[self.training_samples:]

        # Training loop
        for epoch in monit.loop(self.epochs):
            # Set the model to training mode
            self.model.train()
            # Make all the gradients zero
            self.optimizer.zero_grad()
            # Evaluate the model
            output = self.model(features, edges_adj)
            # Get the loss for training nodes
            loss = self.loss_func(output[idx_train], labels[idx_train])
            # Calculate gradients
            loss.backward()
            # Take optimization step
            self.optimizer.step()
            # Log the loss
            tracker.add('loss.train', loss)
            # Log the accuracy
            tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))

            # Set the model to evaluation mode for validation
            self.model.eval()

            # No need to compute gradients
            with torch.no_grad():
                # Evaluate the model again
                output = self.model(features, edges_adj)
                # Calculate the loss for validation nodes
                loss = self.loss_func(output[idx_valid], labels[idx_valid])
                # Log the loss
                tracker.add('loss.valid', loss)
                # Log the accuracy
                tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))

            # Save logs
            tracker.save()
@option(Configs.model)
def gat_v2_model(c: Configs):
    """
    Create GATv2 model
    """
    return GATv2(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout, c.share_weights).to(c.device)


def main():
@@ -288,9 +101,9 @@ def main():
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 5e-3,
        'optimizer.weight_decay': 5e-4,
        'dropout': 0.7,
    })
    # Initialize
    conf.initialize()

    # Start and watch the experiment
    with experiment.start():
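After this change, `main()` no longer calls `conf.initialize()`: the dataset, model, and optimizer are built on demand by the registered `@option` and `calculate` functions when the configs are first used. A hedged sketch of the resulting flow; `experiment.create(name='gatv2')` and `conf.run()` come from the surrounding file, not from this diff:

def main():
    # Create configurations
    conf = Configs()
    # Create an experiment
    experiment.create(name='gatv2')
    # Calculate configurations; overrides are applied before any option function runs
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 5e-3,
        'optimizer.weight_decay': 5e-4,
        'dropout': 0.7,
    })
    # Start and watch the experiment
    with experiment.start():
        # Run the training
        conf.run()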