这是 PyTorch 中多头注意力的教程/实现,摘自论文 “注意就是你所需要的”。该实现的灵感来自带注释的变压器。
以下是使用带有 MHA 的基本转换器进行 NLP 自动回归的训练代码。
25import math
26from typing import Optional, List
27
28import torch
29from torch import nn
30
31from labml import tracker34class PrepareForMultiHeadAttention(nn.Module):45    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
46        super().__init__()线性变换的线性层
48        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)头数
50        self.heads = heads每个头部中以向量表示的维度数
52        self.d_k = d_k54    def forward(self, x: torch.Tensor):输入的形状[seq_len, batch_size, d_model]
或[batch_size, d_model]
。我们将线性变换应用于最后一个维度,然后将其拆分为头部。
58        head_shape = x.shape[:-1]线性变换
61        x = self.linear(x)将最后一个维度拆分成头部
64        x = x.view(*head_shape, self.heads, self.d_k)输出具有形状[seq_len, batch_size, heads, d_k]
或[batch_size, d_model]
67        return x这将计算给定key
和value
向量的缩放多头注意query
力。
简单来说,它会找到与查询匹配的键,并获取这些键的值。
它使用查询和键的点积作为它们匹配程度的指标。在服用点产品之前,先按比例缩放。这样做是为了避免较大的点积值导致 softmax 在较大时给出非常小的梯度。
Softmax 是沿序列(或时间)的轴计算的。
70class MultiHeadAttention(nn.Module):heads
是头的数量。d_model
是query
、key
和value
向量中的要素数。91    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):97        super().__init__()每头特征数
100        self.d_k = d_model // heads头数
102        self.heads = heads这些变换了多头注意力的query
、key
和value
向量。
105        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
106        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
107        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)Softmax 在时间维度上引起人们的注意key
110        self.softmax = nn.Softmax(dim=1)输出层
113        self.output = nn.Linear(d_model, d_model)辍学
115        self.dropout = nn.Dropout(dropout_prob)softmax 之前的缩放系数
117        self.scale = 1 / math.sqrt(self.d_k)我们存储注意事项,以便在需要时将其用于日志记录或进行其他计算
120        self.attn = None122    def get_scores(self, query: torch.Tensor, key: torch.Tensor):计算或
130        return torch.einsum('ibhd,jbhd->ijbh', query, key)mask
有形状[seq_len_q, seq_len_k, batch_size]
,其中第一个维度是查询维度。如果查询维度等于它将被广播。
132    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):138        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
139        assert mask.shape[1] == key_shape[0]
140        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]所有头部都使用相同的面具。
143        mask = mask.unsqueeze(-1)生成的遮罩有形状[seq_len_q, seq_len_k, batch_size, heads]
146        return maskquery
key
和value
是存储查询、键和值向量集合的张量。它们有形状[seq_len, batch_size, d_model]
。
mask
有形状[seq_len, seq_len, batch_size]
并mask[i, j, b]
指示是否为批量查询b
,位置处的查询i
有权访问位置处的键值j
。
148    def forward(self, *,
149                query: torch.Tensor,
150                key: torch.Tensor,
151                value: torch.Tensor,
152                mask: Optional[torch.Tensor] = None):query
,key
并且value
有形状[seq_len, batch_size, d_model]
164        seq_len, batch_size, _ = query.shape
165
166        if mask is not None:
167            mask = self.prepare_mask(mask, query.shape, key.shape)准备query
,key
并value
进行注意力计算。然后这些就会有形状[seq_len, batch_size, heads, d_k]
。
171        query = self.query(query)
172        key = self.key(key)
173        value = self.value(value)计算注意力分数。这给出了形状的张量[seq_len, seq_len, batch_size, heads]
。
177        scores = self.get_scores(query, key)音阶分数
180        scores *= self.scale涂抹面膜
183        if mask is not None:
184            scores = scores.masked_fill(mask == 0, float('-inf'))关注按键序列维度
188        attn = self.softmax(scores)调试时省去注意力
191        tracker.debug('attn', attn)申请退学
194        attn = self.dropout(attn)乘以值
198        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)保存任何其他计算的注意力
201        self.attn = attn.detach()连接多个头
204        x = x.reshape(seq_len, batch_size, -1)输出层
207        return self.output(x)