13from typing import Dict
14
15import numpy as np
16import torch
17from torch import nn
18
19from labml import lab, monit, tracker, experiment
20from labml.configs import BaseConfigs, option, calculate
21from labml.utils import download
22from labml_helpers.device import DeviceConfigs
23from labml_helpers.module import Module
24from labml_nn.graphs.gat import GraphAttentionLayer
25from labml_nn.optimizers.configs import OptimizerConfigsCora 数据集是研究论文的数据集。对于每篇论文,我们都会得到一个表示单词存在的二进制特征向量。每篇论文分为 7 类中的一类。该数据集还具有引文网络。
论文是图表的节点,边缘是引文。
任务是使用特征向量和引文网络作为输入将边分为 7 个类。
28class CoraDataset:每个节点的标签
43    labels: torch.Tensor一组类名和一个唯一的整数索引
45    classes: Dict[str, int]所有节点的特征向量
47    features: torch.Tensor包含边信息的邻接矩阵。adj_mat[i][j]
True
如果存在从i
到的边缘j
。
50    adj_mat: torch.Tensor下载数据集
52    @staticmethod
53    def _download():57        if not (lab.get_data_path() / 'cora').exists():
58            download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
59                                   lab.get_data_path() / 'cora.tgz')
60            download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())加载数据集
62    def __init__(self, include_edges: bool = True):是否包括边缘。这是测试如果我们忽略引文网络会损失多少准确性。
69        self.include_edges = include_edges下载数据集
72        self._download()阅读纸张 ID、特征矢量和标签
75        with monit.section('Read content file'):
76            content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))加载引文,这是一个整数对的列表。
78        with monit.section('Read citations file'):
79            citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)获取特征向量
82        features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))归一化特征向量
84        self.features = features / features.sum(dim=1, keepdim=True)获取类名并为每个类分配一个唯一的整数
87        self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}获取这些整数的标签
89        self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)获取纸质证件
92        paper_ids = np.array(content[:, 0], dtype=np.int32)纸张 ID 到索引的映射
94        ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}空邻接矩阵-恒等矩阵
97        self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)在邻接矩阵中标记引用
100        if self.include_edges:
101            for e in citations:一对纸质索引
103                e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]我们构建一个对称的图形,如果纸张引用了纸张,我们会在其中放置一个从到的徽章以及从到。
107                self.adj_mat[e1][e2] = True
108                self.adj_mat[e2][e1] = Truein_features
是每个节点的要素数n_hidden
是第一个图形关注层中的要素数n_classes
是类的数量n_heads
是图表关注层中的头部数量dropout
是辍学概率118    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):126        super().__init__()我们连接头部的第一个图形注意层
129        self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)第一个图形关注层之后的激活功能
131        self.activation = nn.ELU()最后一张图关注层,我们平均头部
133        self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)辍学
135        self.dropout = nn.Dropout(dropout)x
是形状的特征向量[n_nodes, in_features]
adj_mat
是形式的邻接矩阵[n_nodes, n_nodes, n_heads]
或[n_nodes, n_nodes, 1]
137    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):将丢失应用于输入
144        x = self.dropout(x)第一个图形关注层
146        x = self.layer1(x, adj_mat)激活功能
148        x = self.activation(x)辍学
150        x = self.dropout(x)logits 的输出层(未激活)
152        return self.output(x, adj_mat)计算精度的简单函数
155def accuracy(output: torch.Tensor, labels: torch.Tensor):159    return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)162class Configs(BaseConfigs):型号
168    model: GAT要训练的节点数
170    training_samples: int = 500输入中每个节点的要素数
172    in_features: int第一个图形关注图层中的要素数
174    n_hidden: int = 64头数
176    n_heads: int = 8用于分类的类数
178    n_classes: int辍学概率
180    dropout: float = 0.6是否包括引文网络
182    include_edges: bool = True数据集
184    dataset: CoraDataset训练迭代次数
186    epochs: int = 1_000亏损函数
188    loss_func = nn.CrossEntropyLoss()193    device: torch.device = DeviceConfigs()优化器
195    optimizer: torch.optim.Adam197    def run(self):将特征向量移动到设备
207        features = self.dataset.features.to(self.device)将标签移到设备上
209        labels = self.dataset.labels.to(self.device)将邻接矩阵移至设备
211        edges_adj = self.dataset.adj_mat.to(self.device)为头部添加一个空的第三个维度
213        edges_adj = edges_adj.unsqueeze(-1)随机索引
216        idx_rand = torch.randperm(len(labels))训练节点
218        idx_train = idx_rand[:self.training_samples]用于验证的节点
220        idx_valid = idx_rand[self.training_samples:]训练循环
223        for epoch in monit.loop(self.epochs):将模型设置为训练模式
225            self.model.train()将所有渐变设为零
227            self.optimizer.zero_grad()评估模型
229            output = self.model(features, edges_adj)获得训练节点的损失
231            loss = self.loss_func(output[idx_train], labels[idx_train])计算梯度
233            loss.backward()采取优化步骤
235            self.optimizer.step()记录损失
237            tracker.add('loss.train', loss)记录准确性
239            tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))将模式设置为评估模式以进行验证
242            self.model.eval()无需计算梯度
245            with torch.no_grad():再次评估模型
247                output = self.model(features, edges_adj)计算验证节点的损失
249                loss = self.loss_func(output[idx_valid], labels[idx_valid])记录损失
251                tracker.add('loss.valid', loss)记录准确性
253                tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))保存日志
256            tracker.save()创建 Cora 数据集
259@option(Configs.dataset)
260def cora_dataset(c: Configs):264    return CoraDataset(c.include_edges)获取班级数
268calculate(Configs.n_classes, lambda c: len(c.dataset.classes))输入中的要素数量
270calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])创建 GAT 模型
273@option(Configs.model)
274def gat_model(c: Configs):278    return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)创建可配置的优化器
281@option(Configs.optimizer)
282def _optimizer(c: Configs):286    opt_conf = OptimizerConfigs()
287    opt_conf.parameters = c.model.parameters()
288    return opt_conf291def main():创建配置
293    conf = Configs()创建实验
295    experiment.create(name='gat')计算配置。
297    experiment.configs(conf, {Adam 优化器
299        'optimizer.optimizer': 'Adam',
300        'optimizer.learning_rate': 5e-3,
301        'optimizer.weight_decay': 5e-4,
302    })开始观看实验
305    with experiment.start():运行训练
307        conf.run()311if __name__ == '__main__':
312    main()