ප්රස්තාරයඅවධානය ජාල (GAT)

මෙය PyTorch කඩදාසි ප්රස්ථාර අවධානය යොමු කිරීමේ ජාල ක්රියාත්මක කිරීමයි.

GATSප්රස්තාර දත්ත මත ක්රියා කරයි. ප්රස්ථාරයක් නෝඩ් සහ දාර සම්බන්ධ කරන නෝඩ් වලින් සමන්විත වේ. උදාහරණයක් ලෙස, කෝරා දත්ත කට්ටලයේ නෝඩ් පර්යේෂණ පත්රිකා වන අතර දාර යනු පත්රිකා සම්බන්ධ කරන උපුටා දැක්වීම් වේ.

GATවිසින් වෙස්මූඩ් ස්වයං අවධානය භාවිතා කරයි, ට්රාන්ස්ෆෝමර්වලට සමාන ආකාරයේ. GAT එකිනෙකට ඉහළින් ගොඩගැසී ඇති ප්රස්ථාර අවධානය ස්ථර වලින් සමන්විත වේ. එක් එක් ප්රස්තාරය අවධානය ස්ථරය යෙදවුම් සහ ප්රතිදානයන් පරිවර්තනය කාවැද්දීම් ලෙස node එකක් මතම ඊට අදාල කාවැද්දීම් ලැබෙන. නෝඩ් කාවැද්දීම් එය සම්බන්ධ කර ඇති වෙනත් නෝඩ් වල කාවැද්දීම් කෙරෙහි අවධානය යොමු කරයි. ක්රියාත්මක කිරීම සමඟ ප්රස්ථාර අවධානය ස්ථර පිළිබඳ විස්තර ඇතුළත් වේ.

කෝරා දත්ත කට්ටලය මත ස්ථර දෙකක GAT පුහුණු කිරීම සඳහා පුහුණු කේතය මෙන්න.

View Run

30import torch
31from torch import nn
32
33from labml_helpers.module import Module

ප්රස්තාරයඅවධානය ස්ථරය

මෙයතනි ප්රස්ථාර අවධානය යොමු කරන ස්ථරයකි. GAT එවැනි ස්ථර කිහිපයකින් සෑදී ඇත.

ආදානසහ ප්රතිදානයන් ලෙස කොතැනද , එය අවශ්ය වේ .

36class GraphAttentionLayer(Module):
  • in_features , , node එකක් මතම ඊට අදාල ආදාන ලක්ෂණ සංඛ්යාව
  • out_features , , node එකක් මතම ඊට අදාල ප්රතිදානය විශේෂාංග සංඛ්යාව වේ
  • n_heads , , අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ
  • is_concat බහු-හිස ප්රති results ල සංයුක්ත කළ යුතුද නැතහොත් සාමාන්යය විය යුතුද යන්න
  • dropout අතහැර දැමීමේ සම්භාවිතාව
  • leaky_relu_negative_slope යනු කාන්දු වන රිලූ සක්රිය කිරීම සඳහා negative ණ බෑවුමයි
50    def __init__(self, in_features: int, out_features: int, n_heads: int,
51                 is_concat: bool = True,
52                 dropout: float = 0.6,
53                 leaky_relu_negative_slope: float = 0.2):
62        super().__init__()
63
64        self.is_concat = is_concat
65        self.n_heads = n_heads

හිසකටමානයන් ගණන ගණනය කරන්න

68        if is_concat:
69            assert out_features % n_heads == 0

අපිබහු හිස් සංකෝචනය කරන්නේ නම්

71            self.n_hidden = out_features // n_heads
72        else:

අපිබහු හිස් සාමාන්යය කරන්නේ නම්

74            self.n_hidden = out_features

මූලිකපරිවර්තනය සඳහා රේඛීය ස්ථරය; i.e. ස්වයං අවධානය පෙර node එකක් මතම ඊට අදාල කාවැද්දීම් පරිවර්තනය කිරීමට

78        self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)

අවධානයලකුණු ගණනය කිරීම සඳහා රේඛීය ස්ථරය

80        self.attn = nn.Linear(self.n_hidden * 2, 1, bias=False)

අවධානයලකුණු සඳහා සක්රිය

82        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)

අවධානයගණනය කිරීමට සොෆ්ට්මැක්ස්

84        self.softmax = nn.Softmax(dim=1)

අවධානයසඳහා යෙදිය යුතු ස්තරය

86        self.dropout = nn.Dropout(dropout)
  • h , හැඩයේ ආදාන නෝඩ් කාවැද්දීම් [n_nodes, in_features] වේ.
  • adj_mat යනු හැඩයේ විඝටන අනුකෘතියකි [n_nodes, n_nodes, n_heads] . එක් එක් හිස සඳහා adjacency එක සමාන [n_nodes, n_nodes, 1] බැවින් අපි හැඩය භාවිතා කරමු.

සමපාතඅනුකෘතිය නෝඩ් අතර දාර (හෝ සම්බන්ධතා) නියෝජනය කරයි. adj_mat[i][j] නෝඩ් සිට නෝඩ් සිට නෝඩ් i දක්වා දාරයක් තිබේ True නම් j .

88    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):

නෝඩ්ගණන

99        n_nodes = h.shape[0]

එක්එක් හිස සඳහා ආරම්භක පරිවර්තනය. අපි තනි රේඛීය පරිවර්තනයක් කර එක් එක් හිස සඳහා එය බෙදුවෙමු.

104        g = self.linear(h).view(n_nodes, self.n_heads, self.n_hidden)

අවධානයලකුණු ගණනය කරන්න

එක්එක් හිස සඳහා අපි මේවා ගණනය කරමු . සරලබව සඳහා අපි මඟ හරවා ඇත්තෙමු.

අවධානය ලකුණු (වැදගත්කම) node එකක් මතම ඊට අදාල සිට node එකක් මතම ඊට අදාල අපි එක් එක් හිස සඳහා මෙය ගණනය කරමු.

අවධානය යොමු කිරීමේ යාන්ත්රණය, එය අවධානය ලකුණු ගණනය කරයි. කඩදාසි concatenates , සහ a විසින් අනුගමනය බර දෛශිකයක් සමග රේඛීය පරිවර්තනයක් කරන්නේ .

පළමුවඅපි සියලු යුගල සඳහා ගණනය කරමු .

g_repeat එක් එක් node එකක් මතම ඊට අදාල කාවැද්දීම නැවත නැවත n_nodes වතාවක් කොහෙද ලැබෙන.

135        g_repeat = g.repeat(n_nodes, 1, 1)

g_repeat_interleave එක් එක් node එකක් මතම ඊට අදාල කාවැද්දීම නැවත නැවත n_nodes වතාවක් කොහෙද ලැබෙන.

140        g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)

දැන්අපි ලබා ගැනීමට එකඟ වෙමු

148        g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1)

ඒනිසා නැවත g_concat[i, j] හැඩගස්වා

150        g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)

ගණනයහැඩයෙන් e යුක්ත වේ [n_nodes, n_nodes, n_heads, 1]

158        e = self.activation(self.attn(g_concat))

ප්රමාණයේඅවසාන මානය ඉවත් කරන්න 1

160        e = e.squeeze(-1)

මෙමadjacency න්යාසය හැඩය [n_nodes, n_nodes, n_heads] හෝ තිබිය යුතුය[n_nodes, n_nodes, 1]

164        assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
165        assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
166        assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads

මැස්සිඅනුකෘතිය මත පදනම් වූ මාස්ක්. සිට දාරයක් නොමැති නම් දක්වා සකසා ඇත.

169        e = e.masked_fill(adj_mat == 0, float('-inf'))

ඉන්පසුඅපි අවධානය යොමු කිරීමේ ලකුණු සාමාන්යකරණය කරමු (හෝ සංගුණක)

සම්බන්ධවූ නෝඩ් කට්ටලය කොහේද?

අපිමෙය කරන්නේ අසම්බන්ධිත යුගල සඳහා සම්බන්ධ නොවූ සැකසීමෙනි.

179        a = self.softmax(e)

අතහැරදැමීමේ විධිමත් කිරීම යොදන්න

182        a = self.dropout(a)

එක්එක් හිස සඳහා අවසාන ප්රතිදානය ගණනය කරන්න

සටහන: කඩදාසි වල අවසාන සක්රිය කිරීම ඇතුළත් වේ අපි මෙය ප්රස්තාරය අවධානය යොමු කිරීමේ ස්ථර ක්රියාත්මක කිරීමෙන් මඟ හැර ඇති අතර වෙනත් PyTorch මොඩියුල අර්ථ දක්වා ඇති ආකාරය සමඟ ගැලපෙන පරිදි GAT ආකෘතිය මත එය භාවිතා කරන්න - සක්රිය කිරීම වෙනම ස්ථරයක් ලෙස.

191        attn_res = torch.einsum('ijh,jhf->ihf', a, g)

ප්රධානීන්සංයුක්ත කරන්න

194        if self.is_concat:

196            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)

ප්රධානීන්ගේමධ්යන්යය ගන්න

198        else:

200            return attn_res.mean(dim=1)