これは、PyTorchの論文「注意さえあれば十分」の「多面的な注意」のチュートリアル/実装です。実装は注釈付きトランスフォーマーから着想を得ています
。24import math
25from typing import Optional, List
26
27import torch
28from torch import nn
29
30from labml import trackerこのモジュールは線形変換を行い、ベクトルを指定された数のヘッドに分割してマルチヘッドアテンションを行います。これは、キー、クエリ、および値のベクトルを変換するために使用されます。
33class PrepareForMultiHeadAttention(nn.Module):44    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
45        super().__init__()線形変換用の線形層
47        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)ヘッド数
49        self.heads = heads各ヘッドのベクトルの次元数
51        self.d_k = d_k53    def forward(self, x: torch.Tensor):[seq_len, batch_size, d_model]
[batch_size, d_model]
入力の形状はまたはです。線形変換を最後の次元に適用し、それを頭に分割します。
57        head_shape = x.shape[:-1]線形変換
60        x = self.linear(x)最後のディメンションをヘッドに分割
63        x = x.view(*head_shape, self.heads, self.d_k)[seq_len, batch_size, heads, d_k]
出力の形状があるか [batch_size, heads, d_model]
66        return xquery
与えられたベクトルやベクトルに対して、スケーリングされたマルチヘッド・アテンションを計算します。key
 value
簡単に言うと、クエリに一致するキーを見つけ、それらのキーの値を取得します。
クエリとキーのドット積がどの程度一致しているかを示す指標として使用します。撮影前にドットプロダクトをスケーリングします。これは、ドット積値が大きい場合に softmax のグラデーションが非常に小さくなる原因とならないようにするためです
。Softmax は、シーケンス (または時間) の軸に沿って計算されます。
69class MultiHeadAttention(nn.Module):heads
は頭の数です。d_model
はquery
、key
value
およびベクトル内の特徴の数です。90    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):96        super().__init__()ヘッドあたりの機能数
99        self.d_k = d_model // headsヘッド数
101        self.heads = headsこれらはquery
、、key
value
のベクトルを変えて、多面的な注意を促します。
104        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
105        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
106        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)時間軸に沿った注目のソフトマックス key
109        self.softmax = nn.Softmax(dim=1)出力レイヤー
112        self.output = nn.Linear(d_model, d_model)ドロップアウト
114        self.dropout = nn.Dropout(dropout_prob)ソフトマックス前のスケーリングファクター
116        self.scale = 1 / math.sqrt(self.d_k)必要に応じてロギングやその他の計算に使用できるように、アテンションを保存します
119        self.attn = None121    def get_scores(self, query: torch.Tensor, key: torch.Tensor):計算または
129        return torch.einsum('ibhd,jbhd->ijbh', query, key)mask
には形状があり[seq_len_q, seq_len_k, batch_size]
、最初の次元はクエリ次元です。クエリディメンションがそれと等しい場合はブロードキャストされます
131    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):137        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
138        assert mask.shape[1] == key_shape[0]
139        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]すべての頭に同じマスクをかけました。
142        mask = mask.unsqueeze(-1)生成されるマスクには形状があります [seq_len_q, seq_len_k, batch_size, heads]
145        return maskquery
、key
value
およびは、クエリ、キー、および値のベクトルのコレクションを格納するテンソルです。形があります[seq_len, batch_size, d_model]
。
mask
[seq_len, seq_len, batch_size]
形状があり、バッチの場合b
、mask[i, j, b]
i
その位置のクエリがその位置のキー値にアクセスできるかどうかを示します。j
147    def forward(self, *,
148                query: torch.Tensor,
149                key: torch.Tensor,
150                value: torch.Tensor,
151                mask: Optional[torch.Tensor] = None):query
、key
value
そして形がある [seq_len, batch_size, d_model]
163        seq_len, batch_size, _ = query.shape
164
165        if mask is not None:
166            mask = self.prepare_mask(mask, query.shape, key.shape)query
key
value
注意力計算の準備をして[seq_len, batch_size, heads, d_k]
これで形ができあがります。
170        query = self.query(query)
171        key = self.key(key)
172        value = self.value(value)アテンションスコアを計算します。[seq_len, seq_len, batch_size, heads]
これにより形状のテンソルが得られます
176        scores = self.get_scores(query, key)スケールスコア
179        scores *= self.scaleマスクを適用
182        if mask is not None:
183            scores = scores.masked_fill(mask == 0, float('-inf'))キーシーケンス次元に沿って注目
187        attn = self.softmax(scores)デバッグ時の注意事項を保存
190        tracker.debug('attn', attn)ドロップアウトを適用
193        attn = self.dropout(attn)値による乗算
197        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)他の計算に注意を向けておく
200        self.attn = attn.detach()複数のヘッドを連結
203        x = x.reshape(seq_len, batch_size, -1)出力レイヤー
206        return self.output(x)