මෙය කඩදාසි වලින් බහු-ශීර්ෂ අවධානය යොමු කිරීමේ නිබන්ධනයක්/ක්රියාත්මක කිරීමකි අවධානය PyTorch හි ඔබට අවශ්ය සියල්ල වේ. ක්රියාත්මක කිරීම ආනුභාව ලත් ට්රාන්ස්ෆෝමර් වෙතින් දේවානුභාවයෙන් ය.
NLP ස්වයංක්රීය-ප්රතිගාමී සඳහා MHA සමඟ මූලික ට්රාන්ස්ෆෝමරයක් භාවිතා කරන පුහුණු කේතය මෙන්න.
සරල ට්රාන්ස්ෆෝමරයක් පුහුණු කරන අත්හදා බැලීමේ ක්රියාත්මක කිරීමක් මෙන්න.
24import math
25from typing import Optional, List
26
27import torch
28from torch import nn
29
30from labml import tracker
මෙමමොඩියුලය රේඛීය පරිවර්තනයක් සිදු කරන අතර බහු හිස අවධානය සඳහා දෛශිකය ලබා දී ඇති හිස් ගණනට බෙදේ. යතුර, විමසුමසහ අගයදෛශික පරිවර්තනය කිරීමට මෙය භාවිතා කරයි.
33class PrepareForMultiHeadAttention(nn.Module):
44 def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
45 super().__init__()
රේඛීයපරිණාමනය සඳහා රේඛීය ස්ථරය
47 self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
හිස්ගණන
49 self.heads = heads
එක්එක් හිසෙහි දෛශිකවල මානයන් ගණන
51 self.d_k = d_k
53 def forward(self, x: torch.Tensor):
ආදානයටහැඩය [seq_len, batch_size, d_model]
හෝ [batch_size, d_model]
ඇත. අපි රේඛීය පරිවර්තනය අවසාන මානයට අදාළ වන අතර එය හිස් බවට බෙදී යයි.
57 head_shape = x.shape[:-1]
රේඛීයපරිණාමනය
60 x = self.linear(x)
අවසානමානය හිස් බවට බෙදන්න
63 x = x.view(*head_shape, self.heads, self.d_k)
නිමැවුමේහැඩය [seq_len, batch_size, heads, d_k]
හෝ [batch_size, heads, d_model]
66 return x
මෙයලබා දී query
ඇති key
සහ value
දෛශික සඳහා බහු-ශීර්ෂ අවධානය පරිමාණය කරයි.
සරළවකිවහොත්, එය විමසුමට ගැලපෙන යතුරු සොයා ගන්නා අතර එම යතුරු වල අගයන් ලබා ගනී.
එයඔවුන් කෙතරම් ගැලපෙන දර්ශකයක් ලෙස විමසුම හා ප්රධාන තිත්-නිෂ්පාදන භාවිතා කරයි. තිත්-නිෂ්පාදන ගැනීමට පෙර පරිමාණය කරනු ලැබේ . මෙය සිදු කරනු ලබන්නේ විශාල තිත් නිෂ්පාදන අගයන් වළක්වා ගැනීම සඳහා වන අතර එමඟින් සොෆ්ට්මැක්ස් විශාල වන විට ඉතා කුඩා අනුක්රමික ප්රමාණයක් ලබා දේ.
Softmaxඅනුක්රමය (හෝ කාලය) අක්ෂය ඔස්සේ ගණනය කරනු ලැබේ.
69class MultiHeadAttention(nn.Module):
heads
හිස් ගණන වේ. d_model
යනු query
, key
සහ value
දෛශිකවල ඇති ලක්ෂණ ගණන වේ. 90 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
96 super().__init__()
හිසකටවිශේෂාංග ගණන
99 self.d_k = d_model // heads
හිස්ගණන
101 self.heads = heads
මේවාබහු-ශීර්ෂ අවධානය සඳහා key
සහ value
දෛශික පරිවර්තනය කරයි. query
104 self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
105 self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
106 self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
කාලමානය ඔස්සේ අවධානය යොමු කිරීම සඳහා සොෆ්ට්මැක්ස් key
109 self.softmax = nn.Softmax(dim=1)
ප්රතිදානස්ථරය
112 self.output = nn.Linear(d_model, d_model)
හැලීම
114 self.dropout = nn.Dropout(dropout_prob)
සොෆ්ට්මැක්ස්වලට පෙර පරිමාණ සාධකය
116 self.scale = 1 / math.sqrt(self.d_k)
අවශ්යනම් ලොග් වීම හෝ වෙනත් ගණනය කිරීම් සඳහා භාවිතා කළ හැකි වන පරිදි අපි අවධානය ගබඩා කරමු
119 self.attn = None
සාපේක්ෂඅවධානය වැනි වෙනත් වෙනස්කම් සඳහා මෙම ක්රමය ඉක්මවා යා හැකිය.
121 def get_scores(self, query: torch.Tensor, key: torch.Tensor):
ගණනයකරන්න හෝ
129 return torch.einsum('ibhd,jbhd->ijbh', query, key)
mask
හැඩය ඇත [seq_len_q, seq_len_k, batch_size]
, එහිදී පළමු මානය විමසුම් මානයක් වේ. විමසුම මානයක් සමාන වේ නම් එය විකාශනය කරනු ඇත.
131 def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
137 assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
138 assert mask.shape[1] == key_shape[0]
139 assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
සියලුමහිස් සඳහා එකම ආවරණ යොදනු ලැබේ.
142 mask = mask.unsqueeze(-1)
එහිප්රති ing ලයක් ලෙස වෙස් [seq_len_q, seq_len_k, batch_size, heads]
145 return mask
query
, key
සහ value
විමසුම, යතුරසහ අගයදෛශික එකතු කිරීම ගබඩා කරන ආතන්ය වේ. ඒවායේ හැඩය ඇත [seq_len, batch_size, d_model]
.
mask
හැඩය ඇති [seq_len, seq_len, batch_size]
අතර කණ්ඩායම සඳහා b
, ස්ථානයේ විමසුමට ප්රවේශය i
තිබේද යන්න mask[i, j, b]
දක්වයි ස්ථානයේ ප්රධාන-අගය j
.
147 def forward(self, *,
148 query: torch.Tensor,
149 key: torch.Tensor,
150 value: torch.Tensor,
151 mask: Optional[torch.Tensor] = None):
query
, key
value
සහ හැඩය [seq_len, batch_size, d_model]
163 seq_len, batch_size, _ = query.shape
164
165 if mask is not None:
166 mask = self.prepare_mask(mask, query.shape, key.shape)
සූදානම්වන්න query
, key
සහ අවධානය ගණනය කිරීම value
සඳහා. මේවාට පසුව හැඩය ඇත [seq_len, batch_size, heads, d_k]
.
170 query = self.query(query)
171 key = self.key(key)
172 value = self.value(value)
අවධානයලකුණු ගණනය කරන්න . මෙය හැඩයේ ආතතිකයක් ලබා දෙයි [seq_len, seq_len, batch_size, heads]
.
176 scores = self.get_scores(query, key)
පරිමාණලකුණු
179 scores *= self.scale
වෙස්යොදන්න
182 if mask is not None:
183 scores = scores.masked_fill(mask == 0, float('-inf'))
ප්රධාන අනුක්රමය මානයක් ඔස්සේ අවධානය
187 attn = self.softmax(scores)
නිදොස්කරණයනම් අවධානය සුරකින්න
190 tracker.debug('attn', attn)
අතහැරදැමීම යොදන්න
193 attn = self.dropout(attn)
අගයන්අනුව ගුණ කරන්න
197 x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
වෙනත්ගණනය කිරීම් සඳහා අවධානය සුරකින්න
200 self.attn = attn.detach()
බහුහිස් සංයුක්ත කරන්න
203 x = x.reshape(seq_len, batch_size, -1)
ප්රතිදානස්ථරය
206 return self.output(x)