这是 P yTorch 中论文 “注意力就是你所需要的” 多头注意力的教程/实现。该实现的灵感来自带注释的变形金刚。
以下是使用带有 MHA 的基本转换器进行 NLP 自动回归的训练代码。
24import math
25from typing import Optional, List
26
27import torch
28from torch import nn
29
30from labml import tracker33class PrepareForMultiHeadAttention(nn.Module):44    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
45        super().__init__()线性变换的线性层
47        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)头数
49        self.heads = heads每个头部中以向量表示的维度数
51        self.d_k = d_k53    def forward(self, x: torch.Tensor):输入的形状[seq_len, batch_size, d_model]
或[batch_size, d_model]
。我们将线性变换应用于最后一个维度,然后将其拆分为头部。
57        head_shape = x.shape[:-1]线性变换
60        x = self.linear(x)将最后一个维度拆分成头部
63        x = x.view(*head_shape, self.heads, self.d_k)输出具有形状[seq_len, batch_size, heads, d_k]
或[batch_size, heads, d_model]
66        return x这将计算给定key
和value
向量的缩放多头注意query
力。
简单来说,它会找到与查询匹配的键,并获取这些键的值。
它使用查询和键的点积作为它们匹配程度的指标。在服用点产品之前,先按比例缩放。这样做是为了避免较大的点积值导致 softmax 在较大时给出非常小的梯度。
Softmax 是沿序列(或时间)的轴计算的。
69class MultiHeadAttention(nn.Module):heads
是头的数量。d_model
是query
、key
和value
向量中的要素数。90    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):96        super().__init__()每头特征数
99        self.d_k = d_model // heads头数
101        self.heads = heads这些变换了多头注意力的query
、key
和value
向量。
104        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
105        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
106        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)Softmax 在时间维度上引起人们的注意key
109        self.softmax = nn.Softmax(dim=1)输出层
112        self.output = nn.Linear(d_model, d_model)辍学
114        self.dropout = nn.Dropout(dropout_prob)softmax 之前的缩放系数
116        self.scale = 1 / math.sqrt(self.d_k)我们存储注意事项,以便在需要时将其用于日志记录或进行其他计算
119        self.attn = None121    def get_scores(self, query: torch.Tensor, key: torch.Tensor):计算或
129        return torch.einsum('ibhd,jbhd->ijbh', query, key)mask
有形状[seq_len_q, seq_len_k, batch_size]
,其中第一个维度是查询维度。如果查询维度等于它将被广播。
131    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):137        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
138        assert mask.shape[1] == key_shape[0]
139        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]所有头部都使用相同的面具。
142        mask = mask.unsqueeze(-1)生成的遮罩有形状[seq_len_q, seq_len_k, batch_size, heads]
145        return maskquery
key
和value
是存储查询、键和值向量集合的张量。它们有形状[seq_len, batch_size, d_model]
。
mask
有形状[seq_len, seq_len, batch_size]
并mask[i, j, b]
指示是否为批量查询b
,位置处的查询i
有权访问位置处的键值j
。
147    def forward(self, *,
148                query: torch.Tensor,
149                key: torch.Tensor,
150                value: torch.Tensor,
151                mask: Optional[torch.Tensor] = None):query
,key
并且value
有形状[seq_len, batch_size, d_model]
163        seq_len, batch_size, _ = query.shape
164
165        if mask is not None:
166            mask = self.prepare_mask(mask, query.shape, key.shape)准备query
,key
并value
进行注意力计算。然后这些就会有形状[seq_len, batch_size, heads, d_k]
。
170        query = self.query(query)
171        key = self.key(key)
172        value = self.value(value)计算注意力分数。这给出了形状的张量[seq_len, seq_len, batch_size, heads]
。
176        scores = self.get_scores(query, key)音阶分数
179        scores *= self.scale涂抹面膜
182        if mask is not None:
183            scores = scores.masked_fill(mask == 0, float('-inf'))关注按键序列维度
187        attn = self.softmax(scores)调试时省去注意力
190        tracker.debug('attn', attn)申请退学
193        attn = self.dropout(attn)乘以值
197        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)保存任何其他计算的注意力
200        self.attn = attn.detach()连接多个头
203        x = x.reshape(seq_len, batch_size, -1)输出层
206        return self.output(x)