from typing import Dict

import numpy as np
import torch
from torch import nn

from labml import lab, monit, tracker, experiment
from labml.configs import BaseConfigs, option, calculate
from labml.utils import download
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_nn.graphs.gat import GraphAttentionLayer
from labml_nn.optimizers.configs import OptimizerConfigs

The Cora dataset is a dataset of research papers. For each paper, we are given a binary feature vector that indicates the presence of words. Each paper is classified into one of 7 classes. The dataset also includes the citation network.
The papers are the nodes of the graph and the edges are the citations.
The task is to classify the nodes into the 7 classes, using the feature vectors and the citation network as input.
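As a rough sketch of what the loader below produces (the sizes are the standard Cora statistics, quoted here for orientation rather than asserted by this code):

dataset = CoraDataset(include_edges=True)
dataset.features.shape  # torch.Size([2708, 1433]) -- one binary word-presence vector per paper
dataset.labels.shape    # torch.Size([2708])       -- one of 7 class indices per paper
dataset.adj_mat.shape   # torch.Size([2708, 2708]) -- boolean citation adjacency matrix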
class CoraDataset:

Labels for each node
    labels: torch.Tensor

Set of class names and a unique integer index for each class
    classes: Dict[str, int]

Feature vectors for all nodes
    features: torch.Tensor

Adjacency matrix with the edge information. adj_mat[i][j] is True if there is an edge from node i to node j.
    adj_mat: torch.Tensor

Download the dataset
    @staticmethod
    def _download():
        if not (lab.get_data_path() / 'cora').exists():
            download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
                                   lab.get_data_path() / 'cora.tgz')
            download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())

Load the dataset
    def __init__(self, include_edges: bool = True):

Whether to include edges. This is to test how much accuracy is lost if we ignore the citation network.
        self.include_edges = include_edges

Download the dataset
        self._download()

Read the paper ids, feature vectors, and labels
        with monit.section('Read content file'):
            content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))
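For orientation, each row of cora.content is a paper id, followed by the binary word features, with the class name last; this is why the columns are sliced the way they are below. An illustrative (hypothetical) row:

# content[i] ~ ['31336', '0', '0', '1', ..., 'Neural_Networks']
#                paper id, word-presence features ..., class label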
Load the citations; it's a list of pairs of paper ids.
        with monit.section('Read citations file'):
            citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)

Get the feature vectors
        features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))

Normalize the feature vectors
        self.features = features / features.sum(dim=1, keepdim=True)
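A minimal illustration of this row normalization, with toy values of my own:

x = torch.tensor([[1., 0., 1.],
                  [1., 1., 1.]])
x / x.sum(dim=1, keepdim=True)
# tensor([[0.5000, 0.0000, 0.5000],
#         [0.3333, 0.3333, 0.3333]])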
Get the class names and assign a unique integer to each of them
        self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}

Get the labels as those integers
        self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)

Get the paper ids
        paper_ids = np.array(content[:, 0], dtype=np.int32)

Map of paper id to index
        ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}

Empty adjacency matrix - an identity matrix, i.e. each node starts connected only to itself
        self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)

Mark the citations in the adjacency matrix
        if self.include_edges:
            for e in citations:

The pair of paper indexes
                e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]

We build a symmetric graph: if paper i references paper j, we place an edge from i to j as well as an edge from j to i.
                self.adj_mat[e1][e2] = True
                self.adj_mat[e2][e1] = True

class GAT(Module):

in_features is the number of features per node
n_hidden is the number of features in the first graph attention layer
n_classes is the number of classes
n_heads is the number of heads in the graph attention layers
dropout is the dropout probability

    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):
        super().__init__()

First graph attention layer where we concatenate the heads
        self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)

Activation function after first graph attention layer
        self.activation = nn.ELU()

Final graph attention layer where we average the heads
        self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)

Dropout
        self.dropout = nn.Dropout(dropout)

x is the feature vectors of shape [n_nodes, in_features]
adj_mat is the adjacency matrix of the form [n_nodes, n_nodes, n_heads] or [n_nodes, n_nodes, 1]

    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):

Apply dropout to the input
        x = self.dropout(x)

First graph attention layer
        x = self.layer1(x, adj_mat)

Activation function
        x = self.activation(x)

Dropout
        x = self.dropout(x)

Output layer (without activation) for logits
        return self.output(x, adj_mat)
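A quick shape check of the model on random inputs (a sketch using the default sizes from this file; the node count is arbitrary, and the adjacency matrix here has self-loops only):

model = GAT(in_features=1433, n_hidden=64, n_classes=7, n_heads=8, dropout=0.6)
x = torch.rand(2708, 1433)
adj_mat = torch.eye(2708, dtype=torch.bool).unsqueeze(-1)  # [n_nodes, n_nodes, 1]
model(x, adj_mat).shape  # torch.Size([2708, 7]) -- one logit vector per node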
A simple function to calculate the accuracy

def accuracy(output: torch.Tensor, labels: torch.Tensor):
    return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)

class Configs(BaseConfigs):

Model
    model: GAT

Number of nodes to train on
    training_samples: int = 500

Number of features per node in the input
    in_features: int

Number of features in the first graph attention layer
    n_hidden: int = 64

Number of heads
    n_heads: int = 8

Number of classes for classification
    n_classes: int

Dropout probability
    dropout: float = 0.6

Whether to include the citation network
    include_edges: bool = True

Dataset
    dataset: CoraDataset

Number of training iterations
    epochs: int = 1_000

Loss function
    loss_func = nn.CrossEntropyLoss()

Device to train on. This creates configs for the device, so that we can change the device by passing a config value.
    device: torch.device = DeviceConfigs()

Optimizer
    optimizer: torch.optim.Adam

We do full-batch training since the dataset is small. If we were to sample and train, we would have to sample a set of nodes for each training step, along with the edges that span across those selected nodes.
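For reference, a hedged sketch of what such sampling could look like (hypothetical, not part of this implementation): pick a random subset of nodes each step and take the sub-graph they induce.

idx = torch.randperm(n_nodes)[:batch_size]  # hypothetical batch of node indexes
batch_features = features[idx]              # feature vectors of the sampled nodes
batch_adj = adj_mat[idx][:, idx]            # only the edges among the sampled nodes
batch_labels = labels[idx]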
    def run(self):

Move the feature vectors to the device
        features = self.dataset.features.to(self.device)

Move the labels to the device
        labels = self.dataset.labels.to(self.device)

Move the adjacency matrix to the device
        edges_adj = self.dataset.adj_mat.to(self.device)

Add an empty third dimension for the heads
        edges_adj = edges_adj.unsqueeze(-1)

Random indexes
        idx_rand = torch.randperm(len(labels))

Nodes for training
        idx_train = idx_rand[:self.training_samples]

Nodes for validation
        idx_valid = idx_rand[self.training_samples:]

Training loop
        for epoch in monit.loop(self.epochs):

Set the model to training mode
            self.model.train()

Make all the gradients zero
            self.optimizer.zero_grad()

Evaluate the model
            output = self.model(features, edges_adj)

Get the loss for training nodes
            loss = self.loss_func(output[idx_train], labels[idx_train])

Calculate gradients
            loss.backward()

Take an optimization step
            self.optimizer.step()

Log the loss
            tracker.add('loss.train', loss)

Log the accuracy
            tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))

Set the model to evaluation mode for validation
            self.model.eval()

No need to compute gradients
            with torch.no_grad():

Evaluate the model again
                output = self.model(features, edges_adj)

Calculate the loss for validation nodes
                loss = self.loss_func(output[idx_valid], labels[idx_valid])

Log the loss
                tracker.add('loss.valid', loss)

Log the accuracy
                tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))

Save logs
            tracker.save()

Create Cora dataset
@option(Configs.dataset)
def cora_dataset(c: Configs):
    return CoraDataset(c.include_edges)

Get the number of classes
calculate(Configs.n_classes, lambda c: len(c.dataset.classes))

Number of features in the input
calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])

Create GAT model
@option(Configs.model)
def gat_model(c: Configs):
    return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)

Create configurable optimizer
@option(Configs.optimizer)
def _optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf

def main():

Create configurations
    conf = Configs()

Create an experiment
    experiment.create(name='gat')

Calculate configurations.
    experiment.configs(conf, {

Adam optimizer
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 5e-3,
        'optimizer.weight_decay': 5e-4,
    })
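Since include_edges is an ordinary config value, the same entry point can be reused to measure how much accuracy the citation network contributes; a sketch, assuming the standard labml config-override mechanism:

    experiment.configs(conf, {
        'include_edges': False,
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 5e-3,
        'optimizer.weight_decay': 5e-4,
    })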
Start and watch the experiment
    with experiment.start():

Run the training
        conf.run()

if __name__ == '__main__':
    main()