මෙයකඩදාසි වලින් බහු-ශීර්ෂ අවධානය යොමු කිරීමේ නිබන්ධනයක්/ක්රියාත්මක කිරීමකි අවධානය PyTorch හි ඔබට අවශ්ය සියල්ල වේ. ක්රියාත්මක කිරීම ආනුභාව ලත් ට්රාන්ස්ෆෝමර්වෙතින් දේවානුභාවයෙන් ය.
NLPස්වයංක්රීය-ප්රතිගාමී සඳහා MHA සමඟ මූලික ට්රාන්ස්ෆෝමරයක් භාවිතා කරන පුහුණු කේතය මෙන්න.
සරලට්රාන්ස්ෆෝමරයක් පුහුණු කරන අත්හදා බැලීමේ ක්රියාත්මක කිරීමක් මෙන්න .
25import math
26from typing import Optional, List
27
28import torch
29from torch import nn
30
31from labml import trackerමෙමමොඩියුලය රේඛීය පරිවර්තනයක් සිදු කරන අතර බහු හිස අවධානය සඳහා දෛශිකය ලබා දී ඇති හිස් ගණනට බෙදේ. යතුර, විමසුමසහ අගයදෛශික පරිවර්තනය කිරීමට මෙය භාවිතා කරයි.
34class PrepareForMultiHeadAttention(nn.Module):45 def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
46 super().__init__()රේඛීයපරිණාමනය සඳහා රේඛීය ස්ථරය
48 self.linear = nn.Linear(d_model, heads * d_k, bias=bias)හිස්ගණන
50 self.heads = headsඑක්එක් හිසෙහි දෛශිකවල මානයන් ගණන
52 self.d_k = d_k54 def forward(self, x: torch.Tensor):ආදානයටහැඩය [seq_len, batch_size, d_model]
හෝ [batch_size, d_model]
ඇත. අපි රේඛීය පරිවර්තනය අවසාන මානයට අදාළ වන අතර එය හිස් බවට බෙදී යයි.
58 head_shape = x.shape[:-1]රේඛීයපරිණාමනය
61 x = self.linear(x)අවසානමානය හිස් බවට බෙදන්න
64 x = x.view(*head_shape, self.heads, self.d_k)නිමැවුමේහැඩය [seq_len, batch_size, heads, d_k]
හෝ [batch_size, d_model]
67 return xමෙයලබා දී query
ඇති key
සහ value
දෛශික සඳහා බහු-ශීර්ෂ අවධානය පරිමාණය කරයි.
සරළවකිවහොත්, එය විමසුමට ගැලපෙන යතුරු සොයා ගන්නා අතර එම යතුරු වල අගයන් ලබා ගනී.
එයඔවුන් කෙතරම් ගැලපෙන දර්ශකයක් ලෙස විමසුම හා ප්රධාන තිත්-නිෂ්පාදන භාවිතා කරයි. තිත්-නිෂ්පාදන ගැනීමට පෙර පරිමාණය කරනු ලැබේ . මෙය සිදු කරනු ලබන්නේ විශාල තිත් නිෂ්පාදන අගයන් වළක්වා ගැනීම සඳහා වන අතර එමඟින් සොෆ්ට්මැක්ස් විශාල වන විට ඉතා කුඩා අනුක්රමික ප්රමාණයක් ලබා දේ.
Softmaxඅනුක්රමය (හෝ කාලය) අක්ෂය ඔස්සේ ගණනය කරනු ලැබේ.
70class MultiHeadAttention(nn.Module):heads
හිස් ගණන වේ. d_model
යනු query
, key
සහ value
දෛශිකවල ඇති ලක්ෂණ ගණන වේ. 91 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):97 super().__init__()හිසකටවිශේෂාංග ගණන
100 self.d_k = d_model // headsහිස්ගණන
102 self.heads = headsමේවාබහු-ශීර්ෂ අවධානය සඳහා key
සහ value
දෛශික පරිවර්තනය කරයි. query
105 self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
106 self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
107 self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)කාලමානය ඔස්සේ අවධානය යොමු කිරීම සඳහා සොෆ්ට්මැක්ස් key
110 self.softmax = nn.Softmax(dim=1)ප්රතිදානස්ථරය
113 self.output = nn.Linear(d_model, d_model)හැලීම
115 self.dropout = nn.Dropout(dropout_prob)සොෆ්ට්මැක්ස්වලට පෙර පරිමාණ සාධකය
117 self.scale = 1 / math.sqrt(self.d_k)අවශ්යනම් ලොග් වීම හෝ වෙනත් ගණනය කිරීම් සඳහා භාවිතා කළ හැකි වන පරිදි අපි අවධානය ගබඩා කරමු
120 self.attn = Noneසාපේක්ෂඅවධානය වැනි වෙනත් වෙනස්කම් සඳහා මෙම ක්රමය ඉක්මවා යා හැකිය.
122 def get_scores(self, query: torch.Tensor, key: torch.Tensor):ගණනයකරන්න හෝ
130 return torch.einsum('ibhd,jbhd->ijbh', query, key) mask
හැඩය ඇත [seq_len_q, seq_len_k, batch_size]
, එහිදී පළමු මානය විමසුම් මානයක් වේ. විමසුම මානයක් සමාන වේ නම් එය විකාශනය කරනු ඇත.
132 def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):138 assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
139 assert mask.shape[1] == key_shape[0]
140 assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]සියලුමහිස් සඳහා එකම ආවරණ යොදනු ලැබේ.
143 mask = mask.unsqueeze(-1)එහිප්රති ing ලයක් ලෙස වෙස් [seq_len_q, seq_len_k, batch_size, heads]
146 return mask query
, key
සහ value
විමසුම, යතුරසහ අගයදෛශික එකතු කිරීම ගබඩා කරන ආතන්ය වේ. ඒවායේ හැඩය ඇත [seq_len, batch_size, d_model]
.
mask
හැඩය ඇති [seq_len, seq_len, batch_size]
අතර කණ්ඩායම සඳහා b
, ස්ථානයේ විමසුමට ප්රවේශය i
තිබේද යන්න mask[i, j, b]
දක්වයි ස්ථානයේ ප්රධාන-අගය j
.
148 def forward(self, *,
149 query: torch.Tensor,
150 key: torch.Tensor,
151 value: torch.Tensor,
152 mask: Optional[torch.Tensor] = None):query
, key
value
සහ හැඩය [seq_len, batch_size, d_model]
164 seq_len, batch_size, _ = query.shape
165
166 if mask is not None:
167 mask = self.prepare_mask(mask, query.shape, key.shape)සූදානම්වන්න query
, key
සහ අවධානය ගණනය කිරීම value
සඳහා. මේවාට පසුව හැඩය ඇත [seq_len, batch_size, heads, d_k]
.
171 query = self.query(query)
172 key = self.key(key)
173 value = self.value(value)අවධානයලකුණු ගණනය කරන්න . මෙය හැඩයේ ආතතිකයක් ලබා දෙයි [seq_len, seq_len, batch_size, heads]
.
177 scores = self.get_scores(query, key)පරිමාණලකුණු
180 scores *= self.scaleවෙස්යොදන්න
183 if mask is not None:
184 scores = scores.masked_fill(mask == 0, float('-inf'))ප්රධාන අනුක්රමය මානයක් ඔස්සේ අවධානය
188 attn = self.softmax(scores)නිදොස්කරණයනම් අවධානය සුරකින්න
191 tracker.debug('attn', attn)අතහැරදැමීම යොදන්න
194 attn = self.dropout(attn)අගයන්අනුව ගුණ කරන්න
198 x = torch.einsum("ijbh,jbhd->ibhd", attn, value)වෙනත්ගණනය කිරීම් සඳහා අවධානය සුරකින්න
201 self.attn = attn.detach()බහුහිස් සංයුක්ත කරන්න
204 x = x.reshape(seq_len, batch_size, -1)ප්රතිදානස්ථරය
207 return self.output(x)