多头注意 (MHA)

Open In ColabOpen In Comet

这是 PyTorch 中多头注意力的教程/实现,摘自论文 “注意就是你所需要的”。该实现的灵感来自带注释的变压器

以下是使用带有 MHA 的基本转换器进行 NLP 自动回归的训练代码

这是一个训练简单变压器的实验实现

25import math
26from typing import Optional, List
27
28import torch
29from torch import nn
30
31from labml import tracker

为多头注意做好准备

该模块进行线性变换,并将向量拆分为给定数量的头部,以获得多头注意。这用于转换查询向量。

34class PrepareForMultiHeadAttention(nn.Module):
45    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
46        super().__init__()

线性变换的线性层

48        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)

头数

50        self.heads = heads

每个头部中以向量表示的维度数

52        self.d_k = d_k
54    def forward(self, x: torch.Tensor):

输入的形状[seq_len, batch_size, d_model][batch_size, d_model] 。我们将线性变换应用于最后一个维度,然后将其拆分为头部。

58        head_shape = x.shape[:-1]

线性变换

61        x = self.linear(x)

将最后一个维度拆分成头部

64        x = x.view(*head_shape, self.heads, self.d_k)

输出具有形状[seq_len, batch_size, heads, d_k][batch_size, d_model]

67        return x

多头注意模块

这将计算给定keyvalue 向量的缩放多头注意query 力。

简单来说,它会找到与查询匹配的键,并获取这些键的值。

它使用查询和键的点积作为它们匹配程度的指标。在服用点产品之前,先按比例缩放。这样做是为了避免较大的点积值导致 softmax 在较大时给出非常小的梯度。

Softmax 是沿序列(或时间)的轴计算的。

70class MultiHeadAttention(nn.Module):
  • heads 是头的数量。
  • d_modelquerykeyvalue 向量中的要素数。
91    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
97        super().__init__()

每头特征数

100        self.d_k = d_model // heads

头数

102        self.heads = heads

这些变换了多头注意力的querykeyvalue 向量。

105        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
106        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
107        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

Softmax 在时间维度上引起人们的注意key

110        self.softmax = nn.Softmax(dim=1)

输出层

113        self.output = nn.Linear(d_model, d_model)

辍学

115        self.dropout = nn.Dropout(dropout_prob)

softmax 之前的缩放系数

117        self.scale = 1 / math.sqrt(self.d_k)

我们存储注意事项,以便在需要时将其用于日志记录或进行其他计算

120        self.attn = None

计算查询和键之间的分数

对于其他变体,例如相对注意力,可以覆盖此方法。

122    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

计算

130        return torch.einsum('ibhd,jbhd->ijbh', query, key)

mask 有形状[seq_len_q, seq_len_k, batch_size] ,其中第一个维度是查询维度。如果查询维度等于它将被广播。

132    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
138        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
139        assert mask.shape[1] == key_shape[0]
140        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

所有头部都使用相同的面具。

143        mask = mask.unsqueeze(-1)

生成的遮罩有形状[seq_len_q, seq_len_k, batch_size, heads]

146        return mask

query keyvalue 是存储查询向量集合的张量。它们有形状[seq_len, batch_size, d_model]

mask 有形状[seq_len, seq_len, batch_size]mask[i, j, b] 指示是否为批量查询b ,位置处的查询i 有权访问位置处的键值j

148    def forward(self, *,
149                query: torch.Tensor,
150                key: torch.Tensor,
151                value: torch.Tensor,
152                mask: Optional[torch.Tensor] = None):

querykey 并且value 有形状[seq_len, batch_size, d_model]

164        seq_len, batch_size, _ = query.shape
165
166        if mask is not None:
167            mask = self.prepare_mask(mask, query.shape, key.shape)

准备querykeyvalue 进行注意力计算。然后这些就会有形状[seq_len, batch_size, heads, d_k]

171        query = self.query(query)
172        key = self.key(key)
173        value = self.value(value)

计算注意力分数。这给出了形状的张量[seq_len, seq_len, batch_size, heads]

177        scores = self.get_scores(query, key)

音阶分数

180        scores *= self.scale

涂抹面膜

183        if mask is not None:
184            scores = scores.masked_fill(mask == 0, float('-inf'))

关注按键序列维度

188        attn = self.softmax(scores)

调试时省去注意力

191        tracker.debug('attn', attn)

申请退学

194        attn = self.dropout(attn)

乘以值

198        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

保存任何其他计算的注意力

201        self.attn = attn.detach()

连接多个头

204        x = x.reshape(seq_len, batch_size, -1)

输出层

207        return self.output(x)