13from typing import Dict
14
15import numpy as np
16import torch
17from torch import nn
18
19from labml import lab, monit, tracker, experiment
20from labml.configs import BaseConfigs, option, calculate
21from labml.utils import download
22from labml_helpers.device import DeviceConfigs
23from labml_helpers.module import Module
24from labml_nn.graphs.gat import GraphAttentionLayer
25from labml_nn.optimizers.configs import OptimizerConfigsCora 数据集是研究论文的数据集。对于每篇论文,我们都会得到一个表示单词存在的二进制特征向量。每篇论文分为 7 类中的一类。该数据集还具有引文网络。
论文是图表的节点,边缘是引文。
任务是使用特征向量和引文网络作为输入将边分为 7 个类。
28class CoraDataset:每个节点的标签
43 labels: torch.Tensor一组类名和一个唯一的整数索引
45 classes: Dict[str, int]所有节点的特征向量
47 features: torch.Tensor包含边信息的邻接矩阵。adj_mat[i][j]
True
如果存在从i
到的边缘j
。
50 adj_mat: torch.Tensor下载数据集
52 @staticmethod
53 def _download():57 if not (lab.get_data_path() / 'cora').exists():
58 download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
59 lab.get_data_path() / 'cora.tgz')
60 download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())加载数据集
62 def __init__(self, include_edges: bool = True):是否包括边缘。这是测试如果我们忽略引文网络会损失多少准确性。
69 self.include_edges = include_edges下载数据集
72 self._download()阅读纸张 ID、特征矢量和标签
75 with monit.section('Read content file'):
76 content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))加载引文,这是一个整数对的列表。
78 with monit.section('Read citations file'):
79 citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)获取特征向量
82 features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))归一化特征向量
84 self.features = features / features.sum(dim=1, keepdim=True)获取类名并为每个类分配一个唯一的整数
87 self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}获取这些整数的标签
89 self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)获取纸质证件
92 paper_ids = np.array(content[:, 0], dtype=np.int32)纸张 ID 到索引的映射
94 ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}空邻接矩阵-恒等矩阵
97 self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)在邻接矩阵中标记引用
100 if self.include_edges:
101 for e in citations:一对纸质索引
103 e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]我们构建一个对称的图形,如果纸张引用了纸张,我们会在其中放置一个从到的徽章以及从到。
107 self.adj_mat[e1][e2] = True
108 self.adj_mat[e2][e1] = Truein_features
是每个节点的要素数n_hidden
是第一个图形关注层中的要素数n_classes
是类的数量n_heads
是图表关注层中的头部数量dropout
是辍学概率118 def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):126 super().__init__()我们连接头部的第一个图形注意层
129 self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)第一个图形关注层之后的激活功能
131 self.activation = nn.ELU()最后一张图关注层,我们平均头部
133 self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)辍学
135 self.dropout = nn.Dropout(dropout)x
是形状的特征向量[n_nodes, in_features]
adj_mat
是形式的邻接矩阵[n_nodes, n_nodes, n_heads]
或[n_nodes, n_nodes, 1]
137 def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):将丢失应用于输入
144 x = self.dropout(x)第一个图形关注层
146 x = self.layer1(x, adj_mat)激活功能
148 x = self.activation(x)辍学
150 x = self.dropout(x)logits 的输出层(未激活)
152 return self.output(x, adj_mat)计算精度的简单函数
155def accuracy(output: torch.Tensor, labels: torch.Tensor):159 return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)162class Configs(BaseConfigs):型号
168 model: GAT要训练的节点数
170 training_samples: int = 500输入中每个节点的要素数
172 in_features: int第一个图形关注图层中的要素数
174 n_hidden: int = 64头数
176 n_heads: int = 8用于分类的类数
178 n_classes: int辍学概率
180 dropout: float = 0.6是否包括引文网络
182 include_edges: bool = True数据集
184 dataset: CoraDataset训练迭代次数
186 epochs: int = 1_000亏损函数
188 loss_func = nn.CrossEntropyLoss()193 device: torch.device = DeviceConfigs()优化器
195 optimizer: torch.optim.Adam197 def run(self):将特征向量移动到设备
207 features = self.dataset.features.to(self.device)将标签移到设备上
209 labels = self.dataset.labels.to(self.device)将邻接矩阵移至设备
211 edges_adj = self.dataset.adj_mat.to(self.device)为头部添加一个空的第三个维度
213 edges_adj = edges_adj.unsqueeze(-1)随机索引
216 idx_rand = torch.randperm(len(labels))训练节点
218 idx_train = idx_rand[:self.training_samples]用于验证的节点
220 idx_valid = idx_rand[self.training_samples:]训练循环
223 for epoch in monit.loop(self.epochs):将模型设置为训练模式
225 self.model.train()将所有渐变设为零
227 self.optimizer.zero_grad()评估模型
229 output = self.model(features, edges_adj)获得训练节点的损失
231 loss = self.loss_func(output[idx_train], labels[idx_train])计算梯度
233 loss.backward()采取优化步骤
235 self.optimizer.step()记录损失
237 tracker.add('loss.train', loss)记录准确性
239 tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))将模式设置为评估模式以进行验证
242 self.model.eval()无需计算梯度
245 with torch.no_grad():再次评估模型
247 output = self.model(features, edges_adj)计算验证节点的损失
249 loss = self.loss_func(output[idx_valid], labels[idx_valid])记录损失
251 tracker.add('loss.valid', loss)记录准确性
253 tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))保存日志
256 tracker.save()创建 Cora 数据集
259@option(Configs.dataset)
260def cora_dataset(c: Configs):264 return CoraDataset(c.include_edges)获取班级数
268calculate(Configs.n_classes, lambda c: len(c.dataset.classes))输入中的要素数量
270calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])创建 GAT 模型
273@option(Configs.model)
274def gat_model(c: Configs):278 return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)创建可配置的优化器
281@option(Configs.optimizer)
282def _optimizer(c: Configs):286 opt_conf = OptimizerConfigs()
287 opt_conf.parameters = c.model.parameters()
288 return opt_conf291def main():创建配置
293 conf = Configs()创建实验
295 experiment.create(name='gat')计算配置。
297 experiment.configs(conf, {Adam 优化器
299 'optimizer.optimizer': 'Adam',
300 'optimizer.learning_rate': 5e-3,
301 'optimizer.weight_decay': 5e-4,
302 })开始观看实验
305 with experiment.start():运行训练
307 conf.run()311if __name__ == '__main__':
312 main()