这是论文《 Attention is All You Need 》中多头注意力的PyTorch教程/实现。该实现的灵感来自《带注释的 Transformer 》。
这是使用基础 Transformer 和 MHA 进行 NLP 自回归的训练代码。
这是一个训练简单 Transformer 的代码实现。
24import math
25from typing import Optional, List
26
27import torch
28from torch import nn
29
30from labml import tracker33class PrepareForMultiHeadAttention(nn.Module):44 def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
45 super().__init__()线性层用于线性变换
47 self.linear = nn.Linear(d_model, heads * d_k, bias=bias)注意力头数
49 self.heads = heads每个头部中向量的维度数量
51 self.d_k = d_k53 def forward(self, x: torch.Tensor):输入的形状为[seq_len, batch_size, d_model]
或[batch_size, d_model]
。我们对最后一维应用线性变换,并将其分为多个头。
57 head_shape = x.shape[:-1]线性变换
60 x = self.linear(x)将最后一个维度分成多个头部
63 x = x.view(*head_shape, self.heads, self.d_k)输出具有形状[seq_len, batch_size, heads, d_k]
或[batch_size, heads, d_model]
66 return x这将计算给出的key
、value
和query
向量缩放后的多头注意力。
简单来说,它会找到与查询 (Query) 匹配的键 (key),并获取这些键 (Key) 的值 (Value) 。
它使用查询和键的点积作为衡量它们之间匹配程度的指标。在进行之前,点积会乘以。这样做是为了避免当较大时,大的点积值导致 Softmax 操作输出非常小的梯度。
Softmax 是沿序列(或时间)轴计算的。
69class MultiHeadAttention(nn.Module):heads
是注意力头的数量。d_model
是向量query
、key
和value
中的特征数量。90 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):96 super().__init__()每个头部的特征数量
99 self.d_k = d_model // heads注意力头数
101 self.heads = heads这些将对多头注意力的向量query
、key
和value
进行转换。
104 self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
105 self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
106 self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)在键( Key )的时间维度上进行注意力 Softmaxkey
109 self.softmax = nn.Softmax(dim=1)输出层
112 self.output = nn.Linear(d_model, d_model)Dropout
114 self.dropout = nn.Dropout(dropout_prob)Softmax 之前的缩放系数
116 self.scale = 1 / math.sqrt(self.d_k)存储注意力信息,以便在需要时用于记录或其他计算。
119 self.attn = None121 def get_scores(self, query: torch.Tensor, key: torch.Tensor):计算或
129 return torch.einsum('ibhd,jbhd->ijbh', query, key)mask
的形状为[seq_len_q, seq_len_k, batch_size]
,其中第一维是查询维度。如果查询维度等于,则会进行广播。
131 def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):137 assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
138 assert mask.shape[1] == key_shape[0]
139 assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]所有的头部使用相同的掩码。
142 mask = mask.unsqueeze(-1)生成的掩码形状为[seq_len_q, seq_len_k, batch_size, heads]
145 return maskquery
、key
和value
是存储查询、键和值向量集合的张量。它们的形状为[seq_len, batch_size, d_model]
。
mask
的形状为[seq_len, seq_len, batch_size]
,mask[i, j, b]
表示批次b
,在位置i
处查询是否有权访问位置j
处的键值对。
147 def forward(self, *,
148 query: torch.Tensor,
149 key: torch.Tensor,
150 value: torch.Tensor,
151 mask: Optional[torch.Tensor] = None):query
,key
和value
的形状为[seq_len, batch_size, d_model]
163 seq_len, batch_size, _ = query.shape
164
165 if mask is not None:
166 mask = self.prepare_mask(mask, query.shape, key.shape)为注意力计算准备向量query
,key
并value
它们的形状将变为[seq_len, batch_size, heads, d_k]
。
170 query = self.query(query)
171 key = self.key(key)
172 value = self.value(value)计算注意力分数,这将得到一个形状为[seq_len, seq_len, batch_size, heads]
的张量。
176 scores = self.get_scores(query, key)缩放分数
179 scores *= self.scale应用掩码
182 if mask is not None:
183 scores = scores.masked_fill(mask == 0, float('-inf'))对 Key 序列维度上的注意力进行操作,
187 attn = self.softmax(scores)调试时保存注意力信息
190 tracker.debug('attn', attn)应用 Dropout
193 attn = self.dropout(attn)乘以数值
197 x = torch.einsum("ijbh,jbhd->ibhd", attn, value)为其他计算保存注意力信息
200 self.attn = attn.detach()连接多个头
203 x = x.reshape(seq_len, batch_size, -1)输出层
206 return self.output(x)