බහු ශීර්ෂ අවධානය (MHA)

Open In Colab

මෙය කඩදාසි වලින් බහු-ශීර්ෂ අවධානය යොමු කිරීමේ නිබන්ධනයක්/ක්රියාත්මක කිරීමකි අවධානය PyTorch හි ඔබට අවශ්ය සියල්ල වේ. ක්රියාත්මක කිරීම ආනුභාව ලත් ට්රාන්ස්ෆෝමර් වෙතින් දේවානුභාවයෙන් ය.

NLP ස්වයංක්රීය-ප්රතිගාමී සඳහා MHA සමඟ මූලික ට්රාන්ස්ෆෝමරයක් භාවිතා කරන පුහුණු කේතය මෙන්න.

සරල ට්රාන්ස්ෆෝමරයක් පුහුණු කරන අත්හදා බැලීමේ ක්රියාත්මක කිරීමක් මෙන්න.

24import math
25from typing import Optional, List
26
27import torch
28from torch import nn
29
30from labml import tracker

බහු-හිසඅවධානය සඳහා සූදානම් වන්න

මෙමමොඩියුලය රේඛීය පරිවර්තනයක් සිදු කරන අතර බහු හිස අවධානය සඳහා දෛශිකය ලබා දී ඇති හිස් ගණනට බෙදේ. යතුර, විමසුමසහ අගයදෛශික පරිවර්තනය කිරීමට මෙය භාවිතා කරයි.

33class PrepareForMultiHeadAttention(nn.Module):
44    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
45        super().__init__()

රේඛීයපරිණාමනය සඳහා රේඛීය ස්ථරය

47        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)

හිස්ගණන

49        self.heads = heads

එක්එක් හිසෙහි දෛශිකවල මානයන් ගණන

51        self.d_k = d_k
53    def forward(self, x: torch.Tensor):

ආදානයටහැඩය [seq_len, batch_size, d_model] හෝ [batch_size, d_model] ඇත. අපි රේඛීය පරිවර්තනය අවසාන මානයට අදාළ වන අතර එය හිස් බවට බෙදී යයි.

57        head_shape = x.shape[:-1]

රේඛීයපරිණාමනය

60        x = self.linear(x)

අවසානමානය හිස් බවට බෙදන්න

63        x = x.view(*head_shape, self.heads, self.d_k)

නිමැවුමේහැඩය [seq_len, batch_size, heads, d_k] හෝ [batch_size, heads, d_model]

66        return x

බහු-ප්රධානඅවධානය මොඩියුලය

මෙයලබා දී query ඇති key සහ value දෛශික සඳහා බහු-ශීර්ෂ අවධානය පරිමාණය කරයි.

සරළවකිවහොත්, එය විමසුමට ගැලපෙන යතුරු සොයා ගන්නා අතර එම යතුරු වල අගයන් ලබා ගනී.

එයඔවුන් කෙතරම් ගැලපෙන දර්ශකයක් ලෙස විමසුම හා ප්රධාන තිත්-නිෂ්පාදන භාවිතා කරයි. තිත්-නිෂ්පාදන ගැනීමට පෙර පරිමාණය කරනු ලැබේ . මෙය සිදු කරනු ලබන්නේ විශාල තිත් නිෂ්පාදන අගයන් වළක්වා ගැනීම සඳහා වන අතර එමඟින් සොෆ්ට්මැක්ස් විශාල වන විට ඉතා කුඩා අනුක්රමික ප්රමාණයක් ලබා දේ.

Softmaxඅනුක්රමය (හෝ කාලය) අක්ෂය ඔස්සේ ගණනය කරනු ලැබේ.

69class MultiHeadAttention(nn.Module):
  • heads හිස් ගණන වේ.
  • d_model යනු query , key සහ value දෛශිකවල ඇති ලක්ෂණ ගණන වේ.
90    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
96        super().__init__()

හිසකටවිශේෂාංග ගණන

99        self.d_k = d_model // heads

හිස්ගණන

101        self.heads = heads

මේවාබහු-ශීර්ෂ අවධානය සඳහා key සහ value දෛශික පරිවර්තනය කරයි. query

104        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
105        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
106        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

කාලමානය ඔස්සේ අවධානය යොමු කිරීම සඳහා සොෆ්ට්මැක්ස් key

109        self.softmax = nn.Softmax(dim=1)

ප්රතිදානස්ථරය

112        self.output = nn.Linear(d_model, d_model)

හැලීම

114        self.dropout = nn.Dropout(dropout_prob)

සොෆ්ට්මැක්ස්වලට පෙර පරිමාණ සාධකය

116        self.scale = 1 / math.sqrt(self.d_k)

අවශ්යනම් ලොග් වීම හෝ වෙනත් ගණනය කිරීම් සඳහා භාවිතා කළ හැකි වන පරිදි අපි අවධානය ගබඩා කරමු

119        self.attn = None

විමසුම්සහ යතුරු අතර ලකුණු ගණනය කරන්න

සාපේක්ෂඅවධානය වැනි වෙනත් වෙනස්කම් සඳහා මෙම ක්රමය ඉක්මවා යා හැකිය.

121    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

ගණනයකරන්න හෝ

129        return torch.einsum('ibhd,jbhd->ijbh', query, key)

mask හැඩය ඇත [seq_len_q, seq_len_k, batch_size] , එහිදී පළමු මානය විමසුම් මානයක් වේ. විමසුම මානයක් සමාන වේ නම් එය විකාශනය කරනු ඇත.

131    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
137        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
138        assert mask.shape[1] == key_shape[0]
139        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

සියලුමහිස් සඳහා එකම ආවරණ යොදනු ලැබේ.

142        mask = mask.unsqueeze(-1)

එහිප්රති ing ලයක් ලෙස වෙස් [seq_len_q, seq_len_k, batch_size, heads]

145        return mask

query , key සහ value විමසුම, යතුරසහ අගයදෛශික එකතු කිරීම ගබඩා කරන ආතන්ය වේ. ඒවායේ හැඩය ඇත [seq_len, batch_size, d_model] .

mask හැඩය ඇති [seq_len, seq_len, batch_size] අතර කණ්ඩායම සඳහා b , ස්ථානයේ විමසුමට ප්රවේශය i තිබේද යන්න mask[i, j, b] දක්වයි ස්ථානයේ ප්රධාන-අගය j .

147    def forward(self, *,
148                query: torch.Tensor,
149                key: torch.Tensor,
150                value: torch.Tensor,
151                mask: Optional[torch.Tensor] = None):

query , key value සහ හැඩය [seq_len, batch_size, d_model]

163        seq_len, batch_size, _ = query.shape
164
165        if mask is not None:
166            mask = self.prepare_mask(mask, query.shape, key.shape)

සූදානම්වන්න query , key සහ අවධානය ගණනය කිරීම value සඳහා. මේවාට පසුව හැඩය ඇත [seq_len, batch_size, heads, d_k] .

170        query = self.query(query)
171        key = self.key(key)
172        value = self.value(value)

අවධානයලකුණු ගණනය කරන්න . මෙය හැඩයේ ආතතිකයක් ලබා දෙයි [seq_len, seq_len, batch_size, heads] .

176        scores = self.get_scores(query, key)

පරිමාණලකුණු

179        scores *= self.scale

වෙස්යොදන්න

182        if mask is not None:
183            scores = scores.masked_fill(mask == 0, float('-inf'))

ප්රධාන අනුක්රමය මානයක් ඔස්සේ අවධානය

187        attn = self.softmax(scores)

නිදොස්කරණයනම් අවධානය සුරකින්න

190        tracker.debug('attn', attn)

අතහැරදැමීම යොදන්න

193        attn = self.dropout(attn)

අගයන්අනුව ගුණ කරන්න

197        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

වෙනත්ගණනය කිරීම් සඳහා අවධානය සුරකින්න

200        self.attn = attn.detach()

බහුහිස් සංයුක්ත කරන්න

203        x = x.reshape(seq_len, batch_size, -1)

ප්රතිදානස්ථරය

206        return self.output(x)