මෙය PyTorch කඩදාසි ප්රස්ථාර අවධානය යොමු කිරීමේ ජාල ක්රියාත්මක කිරීමයි.
GATSප්රස්තාර දත්ත මත ක්රියා කරයි. ප්රස්ථාරයක් නෝඩ් සහ දාර සම්බන්ධ කරන නෝඩ් වලින් සමන්විත වේ. උදාහරණයක් ලෙස, කෝරා දත්ත කට්ටලයේ නෝඩ් පර්යේෂණ පත්රිකා වන අතර දාර යනු පත්රිකා සම්බන්ධ කරන උපුටා දැක්වීම් වේ.
GATවිසින් වෙස්මූඩ් ස්වයං අවධානය භාවිතා කරයි, ට්රාන්ස්ෆෝමර්වලට සමාන ආකාරයේ. GAT එකිනෙකට ඉහළින් ගොඩගැසී ඇති ප්රස්ථාර අවධානය ස්ථර වලින් සමන්විත වේ. එක් එක් ප්රස්තාරය අවධානය ස්ථරය යෙදවුම් සහ ප්රතිදානයන් පරිවර්තනය කාවැද්දීම් ලෙස node එකක් මතම ඊට අදාල කාවැද්දීම් ලැබෙන. නෝඩ් කාවැද්දීම් එය සම්බන්ධ කර ඇති වෙනත් නෝඩ් වල කාවැද්දීම් කෙරෙහි අවධානය යොමු කරයි. ක්රියාත්මක කිරීම සමඟ ප්රස්ථාර අවධානය ස්ථර පිළිබඳ විස්තර ඇතුළත් වේ.
කෝරා දත්ත කට්ටලය මත ස්ථර දෙකක GAT පුහුණු කිරීම සඳහා පුහුණු කේතය මෙන්න.
30import torch
31from torch import nn
32
33from labml_helpers.module import Module
මෙයතනි ප්රස්ථාර අවධානය යොමු කරන ස්ථරයකි. GAT එවැනි ස්ථර කිහිපයකින් සෑදී ඇත.
ආදානසහ ප්රතිදානයන් ලෙස කොතැනද , එය අවශ්ය වේ .
36class GraphAttentionLayer(Module):
in_features
, , node එකක් මතම ඊට අදාල ආදාන ලක්ෂණ සංඛ්යාව out_features
, , node එකක් මතම ඊට අදාල ප්රතිදානය විශේෂාංග සංඛ්යාව වේ n_heads
, , අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ is_concat
බහු-හිස ප්රති results ල සංයුක්ත කළ යුතුද නැතහොත් සාමාන්යය විය යුතුද යන්න dropout
අතහැර දැමීමේ සම්භාවිතාව leaky_relu_negative_slope
යනු කාන්දු වන රිලූ සක්රිය කිරීම සඳහා negative ණ බෑවුමයි50 def __init__(self, in_features: int, out_features: int, n_heads: int,
51 is_concat: bool = True,
52 dropout: float = 0.6,
53 leaky_relu_negative_slope: float = 0.2):
62 super().__init__()
63
64 self.is_concat = is_concat
65 self.n_heads = n_heads
හිසකටමානයන් ගණන ගණනය කරන්න
68 if is_concat:
69 assert out_features % n_heads == 0
අපිබහු හිස් සංකෝචනය කරන්නේ නම්
71 self.n_hidden = out_features // n_heads
72 else:
අපිබහු හිස් සාමාන්යය කරන්නේ නම්
74 self.n_hidden = out_features
මූලිකපරිවර්තනය සඳහා රේඛීය ස්ථරය; i.e. ස්වයං අවධානය පෙර node එකක් මතම ඊට අදාල කාවැද්දීම් පරිවර්තනය කිරීමට
78 self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
අවධානයලකුණු ගණනය කිරීම සඳහා රේඛීය ස්ථරය
80 self.attn = nn.Linear(self.n_hidden * 2, 1, bias=False)
අවධානයලකුණු සඳහා සක්රිය
82 self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
අවධානයගණනය කිරීමට සොෆ්ට්මැක්ස්
84 self.softmax = nn.Softmax(dim=1)
අවධානයසඳහා යෙදිය යුතු ස්තරය
86 self.dropout = nn.Dropout(dropout)
h
, හැඩයේ ආදාන නෝඩ් කාවැද්දීම් [n_nodes, in_features]
වේ. adj_mat
යනු හැඩයේ විඝටන අනුකෘතියකි [n_nodes, n_nodes, n_heads]
. එක් එක් හිස සඳහා adjacency එක සමාන [n_nodes, n_nodes, 1]
බැවින් අපි හැඩය භාවිතා කරමු. සමපාතඅනුකෘතිය නෝඩ් අතර දාර (හෝ සම්බන්ධතා) නියෝජනය කරයි. adj_mat[i][j]
නෝඩ් සිට නෝඩ් සිට නෝඩ් i
දක්වා දාරයක් තිබේ True
නම් j
.
88 def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
නෝඩ්ගණන
99 n_nodes = h.shape[0]
එක්එක් හිස සඳහා ආරම්භක පරිවර්තනය. අපි තනි රේඛීය පරිවර්තනයක් කර එක් එක් හිස සඳහා එය බෙදුවෙමු.
104 g = self.linear(h).view(n_nodes, self.n_heads, self.n_hidden)
එක්එක් හිස සඳහා අපි මේවා ගණනය කරමු . සරලබව සඳහා අපි මඟ හරවා ඇත්තෙමු.
අවධානය ලකුණු (වැදගත්කම) node එකක් මතම ඊට අදාල සිට node එකක් මතම ඊට අදාල අපි එක් එක් හිස සඳහා මෙය ගණනය කරමු.
අවධානය යොමු කිරීමේ යාන්ත්රණය, එය අවධානය ලකුණු ගණනය කරයි. කඩදාසි concatenates , සහ a විසින් අනුගමනය බර දෛශිකයක් සමග රේඛීය පරිවර්තනයක් කරන්නේ .
පළමුවඅපි සියලු යුගල සඳහා ගණනය කරමු .
g_repeat
එක් එක් node එකක් මතම ඊට අදාල කාවැද්දීම නැවත නැවත n_nodes
වතාවක් කොහෙද ලැබෙන.
135 g_repeat = g.repeat(n_nodes, 1, 1)
g_repeat_interleave
එක් එක් node එකක් මතම ඊට අදාල කාවැද්දීම නැවත නැවත n_nodes
වතාවක් කොහෙද ලැබෙන.
140 g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)
දැන්අපි ලබා ගැනීමට එකඟ වෙමු
148 g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1)
ඒනිසා නැවත g_concat[i, j]
හැඩගස්වා
150 g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)
ගණනයහැඩයෙන් e
යුක්ත වේ [n_nodes, n_nodes, n_heads, 1]
158 e = self.activation(self.attn(g_concat))
ප්රමාණයේඅවසාන මානය ඉවත් කරන්න 1
160 e = e.squeeze(-1)
මෙමadjacency න්යාසය හැඩය [n_nodes, n_nodes, n_heads]
හෝ තිබිය යුතුය[n_nodes, n_nodes, 1]
164 assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
165 assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
166 assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads
මැස්සිඅනුකෘතිය මත පදනම් වූ මාස්ක්. සිට දාරයක් නොමැති නම් දක්වා සකසා ඇත.
169 e = e.masked_fill(adj_mat == 0, float('-inf'))
ඉන්පසුඅපි අවධානය යොමු කිරීමේ ලකුණු සාමාන්යකරණය කරමු (හෝ සංගුණක)
සම්බන්ධවූ නෝඩ් කට්ටලය කොහේද?
අපිමෙය කරන්නේ අසම්බන්ධිත යුගල සඳහා සම්බන්ධ නොවූ සැකසීමෙනි.
179 a = self.softmax(e)
අතහැරදැමීමේ විධිමත් කිරීම යොදන්න
182 a = self.dropout(a)
එක්එක් හිස සඳහා අවසාන ප්රතිදානය ගණනය කරන්න
සටහන: කඩදාසි වල අවසාන සක්රිය කිරීම ඇතුළත් වේ අපි මෙය ප්රස්තාරය අවධානය යොමු කිරීමේ ස්ථර ක්රියාත්මක කිරීමෙන් මඟ හැර ඇති අතර වෙනත් PyTorch මොඩියුල අර්ථ දක්වා ඇති ආකාරය සමඟ ගැලපෙන පරිදි GAT ආකෘතිය මත එය භාවිතා කරන්න - සක්රිය කිරීම වෙනම ස්ථරයක් ලෙස.
191 attn_res = torch.einsum('ijh,jhf->ihf', a, g)
ප්රධානීන්සංයුක්ත කරන්න
194 if self.is_concat:
196 return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
ප්රධානීන්ගේමධ්යන්යය ගන්න
198 else:
200 return attn_res.mean(dim=1)