Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
(synced 2025-11-01 03:43:09 +08:00)

refactor gat experiment

@@ -17,7 +17,7 @@ import torch
 from torch import nn

 from labml import lab, monit, tracker, experiment
-from labml.configs import BaseConfigs
+from labml.configs import BaseConfigs, option, calculate
 from labml.utils import download
 from labml_helpers.device import DeviceConfigs
 from labml_helpers.module import Module
@@ -194,26 +194,6 @@ class Configs(BaseConfigs):
     # Optimizer
     optimizer: torch.optim.Adam

-    def initialize(self):
-        """
-        Initialize
-        """
-        # Create the dataset
-        self.dataset = CoraDataset(self.include_edges)
-        # Get the number of classes
-        self.n_classes = len(self.dataset.classes)
-        # Number of features in the input
-        self.in_features = self.dataset.features.shape[1]
-        # Create the model
-        self.model = GAT(self.in_features, self.n_hidden, self.n_classes, self.n_heads, self.dropout)
-        # Move the model to the device
-        self.model.to(self.device)
-        # Configurable optimizer, so that we can set the configurations
-        # such as learning rate by passing the dictionary later.
-        optimizer_conf = OptimizerConfigs()
-        optimizer_conf.parameters = self.model.parameters()
-        self.optimizer = optimizer_conf
-
     def run(self):
         """
         ### Training loop
@@ -276,6 +256,38 @@ class Configs(BaseConfigs):
             tracker.save()


+@option(Configs.dataset)
+def cora_dataset(c: Configs):
+    """
+    Create Cora dataset
+    """
+    return CoraDataset(c.include_edges)
+
+
+# Get the number of classes
+calculate(Configs.n_classes, lambda c: len(c.dataset.classes))
+# Number of features in the input
+calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])
+
+
+@option(Configs.model)
+def gat_model(c: Configs):
+    """
+    Create GAT model
+    """
+    return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)
+
+
+@option(Configs.optimizer)
+def _optimizer(c: Configs):
+    """
+    Create configurable optimizer
+    """
+    opt_conf = OptimizerConfigs()
+    opt_conf.parameters = c.model.parameters()
+    return opt_conf
+
+
 def main():
     # Create configurations
     conf = Configs()
@@ -288,8 +300,6 @@ def main():
         'optimizer.learning_rate': 5e-3,
         'optimizer.weight_decay': 5e-4,
     })
-    # Initialize
-    conf.initialize()

     # Start and watch the experiment
     with experiment.start():

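Note on the pattern above: the eager `initialize()` method is replaced with labml's declarative config providers, where `@option(Configs.x)` registers a function that computes `x` when it is first needed and `calculate(Configs.x, fn)` registers one inline. For intuition, here is a minimal self-contained sketch of the idea, a toy stand-in rather than labml's actual machinery, showing how such a registry can resolve dependent options lazily and in dependency order (the `LazyConfigs` class and the stand-in values are illustrations, not labml APIs):

# Toy sketch of lazy, dependency-ordered config resolution (illustration only;
# labml's real BaseConfigs/option/calculate machinery is more general).
class LazyConfigs:
    def __init__(self):
        self._providers = {}  # config name -> function that computes the value
        self._cache = {}      # config name -> computed value

    def option(self, name):
        # Register a provider via a decorator, like labml's `@option(Configs.x)`
        def decorator(fn):
            self._providers[name] = fn
            return fn
        return decorator

    def calculate(self, name, fn):
        # Register a provider inline, like labml's `calculate(Configs.x, fn)`
        self._providers[name] = fn

    def __getattr__(self, name):
        # Compute on first access; a provider that reads other configs
        # triggers their computation recursively, giving dependency order.
        if name.startswith('_'):
            raise AttributeError(name)
        if name not in self._cache:
            self._cache[name] = self._providers[name](self)
        return self._cache[name]


c = LazyConfigs()
c.calculate('dataset', lambda c: ['paper_a', 'paper_b'])  # stand-in dataset
c.calculate('n_classes', lambda c: len(c.dataset))        # depends on `dataset`


@c.option('model')
def model(c):
    return f'GAT(n_classes={c.n_classes})'  # depends on `n_classes`


# Accessing `model` pulls in `n_classes`, which pulls in `dataset`
print(c.model)  # GAT(n_classes=2)
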
@@ -10,102 +10,14 @@ summary: >
 [](https://app.labml.ai/run/34b1e2f6ed6f11ebb860997901a2d1e3)
 """

-from typing import Dict
-
-import numpy as np
 import torch
 from torch import nn

-from labml import lab, monit, tracker, experiment
-from labml.configs import BaseConfigs
-from labml.utils import download
-from labml_helpers.device import DeviceConfigs
+from labml import experiment
+from labml.configs import option
 from labml_helpers.module import Module
+from labml_nn.graphs.gat.experiment import Configs as GATConfigs
 from labml_nn.graphs.gatv2 import GraphAttentionV2Layer
-from labml_nn.optimizers.configs import OptimizerConfigs


-class CoraDataset:
-    """
-    ## [Cora Dataset](https://linqs.soe.ucsc.edu/data)
-
-    The Cora dataset is a dataset of research papers.
-    For each paper we are given a binary feature vector that indicates the presence of words.
-    Each paper is classified into one of 7 classes.
-    The dataset also has the citation network.
-
-    The papers are the nodes of the graph and the edges are the citations.
-
-    The task is to classify the nodes into the 7 classes with the feature vectors and
-    citation network as input.
-    """
-    # Labels for each node
-    labels: torch.Tensor
-    # Set of class names and a unique integer index for each
-    classes: Dict[str, int]
-    # Feature vectors for all nodes
-    features: torch.Tensor
-    # Adjacency matrix with the edge information.
-    # `adj_mat[i][j]` is `True` if there is an edge from `i` to `j`.
-    adj_mat: torch.Tensor
-
-    @staticmethod
-    def _download():
-        """
-        Download the dataset
-        """
-        if not (lab.get_data_path() / 'cora').exists():
-            download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
-                                   lab.get_data_path() / 'cora.tgz')
-            download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())
-
-    def __init__(self, include_edges: bool = True):
-        """
-        Load the dataset
-        """
-
-        # Whether to include edges.
-        # This is to test how much accuracy is lost if we ignore the citation network.
-        self.include_edges = include_edges
-
-        # Download the dataset
-        self._download()
-
-        # Read the paper ids, feature vectors, and labels
-        with monit.section('Read content file'):
-            content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))
-        # Load the citations; it's a list of pairs of integers
-        with monit.section('Read citations file'):
-            citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)
-
-        # Get the feature vectors
-        features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))
-        # Normalize the feature vectors
-        self.features = features / features.sum(dim=1, keepdim=True)
-
-        # Get the class names and assign a unique integer to each of them
-        self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}
-        # Get the labels as those integers
-        self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)
-
-        # Get the paper ids
-        paper_ids = np.array(content[:, 0], dtype=np.int32)
-        # Map of paper id to index
-        ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}
-
-        # Empty adjacency matrix - an identity matrix
-        self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)
-
-        # Mark the citations in the adjacency matrix
-        if self.include_edges:
-            for e in citations:
-                # The pair of paper indexes
-                e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]
-                # We build a symmetrical graph, where if paper $i$ referenced
-                # paper $j$ we place an edge from $i$ to $j$ as well as an edge
-                # from $j$ to $i$.
-                self.adj_mat[e1][e2] = True
-                self.adj_mat[e2][e1] = True
-

 class GATv2(Module):
@@ -115,7 +27,8 @@ class GATv2(Module):
     This graph attention network has two [graph attention layers](index.html).
     """

-    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float, share_weights: bool = True):
+    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float,
+                 share_weights: bool = True):
         """
         * `in_features` is the number of features per node
         * `n_hidden` is the number of features in the first graph attention layer
@@ -127,11 +40,13 @@ class GATv2(Module):
         super().__init__()

         # First graph attention layer where we concatenate the heads
-        self.layer1 = GraphAttentionV2Layer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout, share_weights=share_weights)
+        self.layer1 = GraphAttentionV2Layer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout,
+                                            share_weights=share_weights)
         # Activation function after first graph attention layer
         self.activation = nn.ELU()
         # Final graph attention layer where we average the heads
-        self.output = GraphAttentionV2Layer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout, share_weights=share_weights)
+        self.output = GraphAttentionV2Layer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout,
+                                            share_weights=share_weights)
         # Dropout
         self.dropout = nn.Dropout(dropout)

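To make these two layers concrete, here is a rough shape sketch with Cora-sized dimensions (2708 nodes, 1433 features, 7 classes). It assumes the `layer(x, adj_mat)` call convention used by `GATv2.forward`; the sizes and dropout value are illustrative, not prescribed by the experiment:

# Shape sketch for the two-layer GATv2 (illustrative sizes; assumes the
# `layer(x, adj_mat)` call convention used in `GATv2.forward`).
import torch

from labml_nn.graphs.gatv2 import GraphAttentionV2Layer

n_nodes, in_features, n_hidden, n_heads, n_classes = 2708, 1433, 64, 8, 7
x = torch.rand(n_nodes, in_features)
# Boolean adjacency with self-loops; `unsqueeze(-1)` adds the heads dimension,
# matching `edges_adj.unsqueeze(-1)` in the training loop.
adj_mat = torch.eye(n_nodes, dtype=torch.bool).unsqueeze(-1)  # [n_nodes, n_nodes, 1]

# First layer splits 64 output features across 8 concatenated heads
layer1 = GraphAttentionV2Layer(in_features, n_hidden, n_heads, is_concat=True, dropout=0.6)
h = layer1(x, adj_mat)  # -> [n_nodes, n_hidden]
# Output layer uses a single (averaged) head to produce class logits
output = GraphAttentionV2Layer(n_hidden, n_classes, 1, is_concat=False, dropout=0.6)
logits = output(h, adj_mat)  # -> [n_nodes, n_classes]
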
@@ -153,128 +68,26 @@ class GATv2(Module):
         return self.output(x, adj_mat)


-def accuracy(output: torch.Tensor, labels: torch.Tensor):
-    """
-    A simple function to calculate the accuracy
-    """
-    return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)
-
-
-class Configs(BaseConfigs):
+class Configs(GATConfigs):
     """
     ## Configurations

+    Since the experiment is the same as the [GAT experiment](../gat/experiment.html) but with the
+    [GATv2 model](index.html), we extend the same configs and change the model.
     """

-    # Model
-    model: GATv2
-    # Number of nodes to train on
-    training_samples: int = 500
-    # Number of features per node in the input
-    in_features: int
-    # Number of features in the first graph attention layer
-    n_hidden: int = 64
-    # Number of heads
-    n_heads: int = 8
-    # Number of classes for classification
-    n_classes: int
-    # Dropout probability
-    dropout: float = 0.7
-    # Whether to include the citation network
-    include_edges: bool = True
-    # Dataset
-    dataset: CoraDataset
-    # Number of training iterations
-    epochs: int = 1_000
-    # Loss function
-    loss_func = nn.CrossEntropyLoss()
-    # Device to train on
-    #
-    # This creates configs for the device, so that
-    # we can change the device by passing a config value
-    device: torch.device = DeviceConfigs()
-    # Optimizer
-    optimizer: torch.optim.Adam
+    # Whether to share weights for source and target nodes of edges
+    share_weights: bool = True
+    # Set the model
+    model: GATv2 = 'gat_v2_model'

-    def initialize(self):
-        """
-        Initialize
-        """
-        # Create the dataset
-        self.dataset = CoraDataset(self.include_edges)
-        # Get the number of classes
-        self.n_classes = len(self.dataset.classes)
-        # Number of features in the input
-        self.in_features = self.dataset.features.shape[1]
-        # Create the model
-        self.model = GATv2(self.in_features, self.n_hidden, self.n_classes, self.n_heads, self.dropout)
-        # Move the model to the device
-        self.model.to(self.device)
-        # Configurable optimizer, so that we can set the configurations
-        # such as learning rate by passing the dictionary later.
-        optimizer_conf = OptimizerConfigs()
-        optimizer_conf.parameters = self.model.parameters()
-        self.optimizer = optimizer_conf
-
-    def run(self):
-        """
-        ### Training loop
-
-        We do full batch training since the dataset is small.
-        If we were to sample and train, we would have to sample a set of
-        nodes for each training step along with the edges that span
-        across those selected nodes.
-        """
-        # Move the feature vectors to the device
-        features = self.dataset.features.to(self.device)
-        # Move the labels to the device
-        labels = self.dataset.labels.to(self.device)
-        # Move the adjacency matrix to the device
-        edges_adj = self.dataset.adj_mat.to(self.device)
-        # Add an empty third dimension for the heads
-        edges_adj = edges_adj.unsqueeze(-1)
-
-        # Random indexes
-        idx_rand = torch.randperm(len(labels))
-        # Nodes for training
-        idx_train = idx_rand[:self.training_samples]
-        # Nodes for validation
-        idx_valid = idx_rand[self.training_samples:]
-
-        # Training loop
-        for epoch in monit.loop(self.epochs):
-            # Set the model to training mode
-            self.model.train()
-            # Make all the gradients zero
-            self.optimizer.zero_grad()
-            # Evaluate the model
-            output = self.model(features, edges_adj)
-            # Get the loss for training nodes
-            loss = self.loss_func(output[idx_train], labels[idx_train])
-            # Calculate gradients
-            loss.backward()
-            # Take optimization step
-            self.optimizer.step()
-            # Log the loss
-            tracker.add('loss.train', loss)
-            # Log the accuracy
-            tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))
-
-            # Set mode to evaluation mode for validation
-            self.model.eval()
-
-            # No need to compute gradients
-            with torch.no_grad():
-                # Evaluate the model again
-                output = self.model(features, edges_adj)
-                # Calculate the loss for validation nodes
-                loss = self.loss_func(output[idx_valid], labels[idx_valid])
-                # Log the loss
-                tracker.add('loss.valid', loss)
-                # Log the accuracy
-                tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))
-
-            # Save logs
-            tracker.save()
+
+@option(Configs.model)
+def gat_v2_model(c: Configs):
+    """
+    Create GATv2 model
+    """
+    return GATv2(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout, c.share_weights).to(c.device)


 def main():
@@ -288,9 +101,9 @@ def main():
         'optimizer.optimizer': 'Adam',
         'optimizer.learning_rate': 5e-3,
         'optimizer.weight_decay': 5e-4,
+
+        'dropout': 0.7,
     })
-    # Initialize
-    conf.initialize()

     # Start and watch the experiment
     with experiment.start():

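With the `conf.initialize()` calls removed from both files, everything is wired through config overrides. The sketch below is a hypothetical driver mirroring the `main()` shown above (the experiment name and the `share_weights` override are illustrative; it assumes `labml` and `labml_nn` are importable):

# Hypothetical run sketch mirroring `main()` above.
from labml import experiment
from labml_nn.graphs.gatv2.experiment import Configs


def run():
    # Create configurations
    conf = Configs()
    # Create an experiment (the name here is illustrative)
    experiment.create(name='gatv2')
    # Override configs; the providers (`cora_dataset`, `gat_v2_model`, ...)
    # are resolved lazily, so no explicit initialize step is needed
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 5e-3,
        'optimizer.weight_decay': 5e-4,

        'dropout': 0.7,
        'share_weights': False,  # example override: separate source/target weights
    })
    # Start and watch the experiment
    with experiment.start():
        # Run the training loop (inherited from the GAT experiment's Configs)
        conf.run()


if __name__ == '__main__':
    run()
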
Varuna Jayasiri