11from typing import Dict
12
13import numpy as np
14import torch
15from torch import nn
16
17from labml import lab, monit, tracker, experiment
18from labml.configs import BaseConfigs, option, calculate
19from labml.utils import download
20from labml_helpers.device import DeviceConfigs
21from labml_helpers.module import Module
22from labml_nn.graphs.gat import GraphAttentionLayer
23from labml_nn.optimizers.configs import OptimizerConfigsCoraデータセットは研究論文のデータセットです。各論文には、単語の存在を示すバイナリ特徴ベクトルが与えられます。各論文は7つのクラスのいずれかに分類されます。データセットには引用ネットワークもあります
。論文はグラフの節点で、端は引用です。
タスクは、特徴ベクトルと引用ネットワークを入力として、ノードを7つのクラスに分類することです。
26class CoraDataset:各ノードのラベル
41    labels: torch.Tensorクラス名と一意の整数インデックスのセット
43    classes: Dict[str, int]全ノードの特徴ベクトル
45    features: torch.Tensorエッジ情報を含む隣接マトリックス。adj_mat[i][j]
True
i
j
もしもから端があったらね
48    adj_mat: torch.Tensorデータセットのダウンロード
50    @staticmethod
51    def _download():55        if not (lab.get_data_path() / 'cora').exists():
56            download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
57                                   lab.get_data_path() / 'cora.tgz')
58            download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())データセットの読み込み
60    def __init__(self, include_edges: bool = True):エッジを含めるかどうか。これは、引用ネットワークを無視すると精度がどれだけ失われるかをテストするものです
。67        self.include_edges = include_edgesデータセットのダウンロード
70        self._download()論文ID、特徴ベクトル、ラベルを読む
73        with monit.section('Read content file'):
74            content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))引用をロードします。整数のペアのリストです。
76        with monit.section('Read citations file'):
77            citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)特徴ベクトルを取得
80        features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))特徴ベクトルを正規化
82        self.features = features / features.sum(dim=1, keepdim=True)クラス名を取得し、それぞれに一意の整数を割り当てます。
85        self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}ラベルをそれらの整数として取得
87        self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)紙の ID を入手
90        paper_ids = np.array(content[:, 0], dtype=np.int32)紙IDと索引のマップ
92        ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}空の隣接行列-単位行列
95        self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)引用文献を隣接マトリックスに記入
98        if self.include_edges:
99            for e in citations:一対のペーパーインデックス
101                e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]対称的なグラフを作成します。紙が参照している紙の場合は、端を端から端に、端を端として配置します。
105                self.adj_mat[e1][e2] = True
106                self.adj_mat[e2][e1] = True109class GAT(Module):in_features
はノードあたりのフィーチャ数n_hidden
は最初のグラフアテンションレイヤーに含まれるフィーチャの数ですn_classes
はクラスの数n_heads
グラフ・アテンション・レイヤーのヘッド数ですdropout
は脱落確率です116    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):124        super().__init__()ヘッドを連結する最初のグラフ・アテンション・レイヤー
127        self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)最初のグラフアテンションレイヤー後のアクティベーション機能
129        self.activation = nn.ELU()ヘッドを平均化する最後のグラフ・アテンション・レイヤー
131        self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)ドロップアウト
133        self.dropout = nn.Dropout(dropout)x
は形状の特徴ベクトルです [n_nodes, in_features]
adj_mat
[n_nodes, n_nodes, n_heads]
は次の形式の隣接行列です [n_nodes, n_nodes, 1]
135    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):入力にドロップアウトを適用
142        x = self.dropout(x)最初のグラフアテンションレイヤー
144        x = self.layer1(x, adj_mat)アクティベーション機能
146        x = self.activation(x)ドロップアウト
148        x = self.dropout(x)ロジットの出力レイヤー (アクティベーションなし)
150        return self.output(x, adj_mat)精度を計算する簡単な関数
153def accuracy(output: torch.Tensor, labels: torch.Tensor):157    return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)160class Configs(BaseConfigs):モデル
166    model: GATトレーニングするノード数
168    training_samples: int = 500入力内のノードあたりのフィーチャ数
170    in_features: int最初のグラフアテンションレイヤーに含まれるフィーチャの数
172    n_hidden: int = 64ヘッド数
174    n_heads: int = 8分類するクラス数
176    n_classes: int脱落確率
178    dropout: float = 0.6引用ネットワークを含めるかどうか
180    include_edges: bool = Trueデータセット
182    dataset: CoraDatasetトレーニングの反復回数
184    epochs: int = 1_000損失関数
186    loss_func = nn.CrossEntropyLoss()191    device: torch.device = DeviceConfigs()オプティマイザー
193    optimizer: torch.optim.Adamデータセットが小さいので、フルバッチトレーニングを行います。サンプリングしてトレーニングする場合、トレーニングステップごとに一連のノードと、選択したノードにまたがるエッジをサンプリングする必要があります。
195    def run(self):特徴ベクトルをデバイスに移動します
205        features = self.dataset.features.to(self.device)ラベルをデバイスに移動
207        labels = self.dataset.labels.to(self.device)隣接マトリックスをデバイスに移動
209        edges_adj = self.dataset.adj_mat.to(self.device)頭部に空の 3 番目のディメンションを追加
211        edges_adj = edges_adj.unsqueeze(-1)ランダムインデックス
214        idx_rand = torch.randperm(len(labels))トレーニング用ノード
216        idx_train = idx_rand[:self.training_samples]検証用ノード
218        idx_valid = idx_rand[self.training_samples:]トレーニングループ
221        for epoch in monit.loop(self.epochs):モデルをトレーニングモードに設定
223            self.model.train()すべてのグラデーションをゼロにする
225            self.optimizer.zero_grad()モデルの評価
227            output = self.model(features, edges_adj)トレーニングノードで損失を被る
229            loss = self.loss_func(output[idx_train], labels[idx_train])勾配の計算
231            loss.backward()最適化の一歩を踏み出す
233            self.optimizer.step()損失を記録する
235            tracker.add('loss.train', loss)精度を記録する
237            tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))検証用にモードを評価モードに設定
240            self.model.eval()勾配を計算する必要はありません
243            with torch.no_grad():モデルを再度評価してください
245                output = self.model(features, edges_adj)検証ノードの損失の計算
247                loss = self.loss_func(output[idx_valid], labels[idx_valid])損失を記録する
249                tracker.add('loss.valid', loss)精度を記録する
251                tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))ログを保存
254            tracker.save()Cora データセットの作成
257@option(Configs.dataset)
258def cora_dataset(c: Configs):262    return CoraDataset(c.include_edges)クラス数を取得
266calculate(Configs.n_classes, lambda c: len(c.dataset.classes))入力内のフィーチャの数
268calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])GAT モデルの作成
271@option(Configs.model)
272def gat_model(c: Configs):276    return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)設定可能なオプティマイザーの作成
279@option(Configs.optimizer)
280def _optimizer(c: Configs):284    opt_conf = OptimizerConfigs()
285    opt_conf.parameters = c.model.parameters()
286    return opt_conf289def main():構成の作成
291    conf = Configs()テストを作成
293    experiment.create(name='gat')構成を計算します。
295    experiment.configs(conf, {アダム・オプティマイザー
297        'optimizer.optimizer': 'Adam',
298        'optimizer.learning_rate': 5e-3,
299        'optimizer.weight_decay': 5e-4,
300    })実験を開始して見る
303    with experiment.start():トレーニングを実行
305        conf.run()309if __name__ == '__main__':
310    main()