This is a tutorial/implementation of multi-headed attention from the paper Attention Is All You Need in PyTorch. The implementation is inspired by the Annotated Transformer.
import math
from typing import Optional

import torch
from torch import nn

from labml import tracker
from labml_helpers.module import Module
This module does a linear transformation and splits the vector into a given number of heads for multi-head attention. This is used to transform the key, query, and value vectors.
class PrepareForMultiHeadAttention(Module):
    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
Linear layer for linear transform
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
Number of heads
        self.heads = heads
Number of dimensions in vectors in each head
        self.d_k = d_k
    def __call__(self, x: torch.Tensor):
Input has shape [seq_len, batch_size, d_model] or [batch_size, d_model]. We apply the linear transformation to the last dimension and split it into the heads.
        head_shape = x.shape[:-1]
Linear transform
        x = self.linear(x)
Split last dimension into heads
        x = x.view(*head_shape, self.heads, self.d_k)
Output has shape [seq_len, batch_size, heads, d_k] or [batch_size, heads, d_k].
        return x
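For example, a quick shape check (a minimal sketch; the sizes below are illustrative and match the base configuration of the paper, d_model = 512, heads = 8, d_k = 64):

prep = PrepareForMultiHeadAttention(d_model=512, heads=8, d_k=64, bias=True)
x = torch.randn(10, 32, 512)  # [seq_len, batch_size, d_model]
assert prep(x).shape == (10, 32, 8, 64)  # [seq_len, batch_size, heads, d_k]
x = torch.randn(32, 512)  # [batch_size, d_model]
assert prep(x).shape == (32, 8, 64)  # [batch_size, heads, d_k]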
This computes scaled multi-headed attention for given query, key and value vectors.
In simple terms, it finds keys that match the query, and gets the values of those keys.
It uses the dot-product of query and key as the indicator of how well they match. Before taking the $softmax$ the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing softmax to give very small gradients when $d_k$ is large.
Softmax is calculated along the axis of the sequence (or time).
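Putting the pieces together, the attention computed below is

$\mathop{Attention}(Q, K, V) = \underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg) V$

where the $softmax$ is taken over the key sequence dimension, and this is done independently for each head.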
class MultiHeadAttention(Module):
heads is the number of heads.
d_model is the number of features in the query, key and value vectors.
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        super().__init__()
Number of features per head
        self.d_k = d_model // heads
Number of heads
        self.heads = heads
These transform the query, key and value vectors for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
Softmax for attention along the time dimension of key
        self.softmax = nn.Softmax(dim=1)
Output layer
        self.output = nn.Linear(d_model, d_model)
Dropout
        self.dropout = nn.Dropout(dropout_prob)
Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)
We store attentions so that they can be used for logging, or other computations if needed.
        self.attn = None
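For example, with the base configuration from the paper, d_model = 512 and heads = 8, each head gets d_k = 512 // 8 = 64 features and the scores are scaled by $\frac{1}{\sqrt{64}} = 0.125$.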
This method can be overridden for other variations like relative attention.
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
Calculate $Q K^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
        return torch.einsum('ibhd,jbhd->ijbh', query, key)
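As a sanity check, this einsum is equivalent to a batched matrix multiplication after moving the batch and head dimensions to the front. A minimal sketch with illustrative sizes (not part of the original implementation):

q = torch.randn(10, 32, 8, 64)  # [seq_q, batch_size, heads, d_k]
k = torch.randn(12, 32, 8, 64)  # [seq_k, batch_size, heads, d_k]
scores = torch.einsum('ibhd,jbhd->ijbh', q, k)  # [seq_q, seq_k, batch_size, heads]
ref = q.permute(1, 2, 0, 3) @ k.permute(1, 2, 3, 0)  # [batch_size, heads, seq_q, seq_k]
assert torch.allclose(scores, ref.permute(2, 3, 0, 1), atol=1e-5)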
query, key and value are the tensors that store the collections of query, key and value vectors. They have shape [seq_len, batch_size, d_model].
mask has shape [seq_len, seq_len, batch_size] and mask[i, j, b] indicates whether, for batch b, the query at position i has access to the key-value at position j.
    def __call__(self, *,
                 query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
                 mask: Optional[torch.Tensor] = None):
query, key and value have shape [seq_len, batch_size, d_model].
        seq_len, batch_size, _ = query.shape

        if mask is not None:
mask has shape [seq_len, seq_len, batch_size], where the first dimension is the query dimension. If the query dimension is equal to $1$ it will be broadcast.
            assert mask.shape[0] == 1 or mask.shape[0] == mask.shape[1]
Same mask applied to all heads.
            mask = mask.unsqueeze(-1)
Prepare query, key and value for attention computation. These will then have shape [seq_len, batch_size, heads, d_k].
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
Compute attention scores $Q K^\top$
Results in a tensor of shape [seq_len, seq_len, batch_size, heads]
        scores = self.get_scores(query, key)
Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
        scores *= self.scale
Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
$softmax$ attention along the key sequence dimension $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
        attn = self.softmax(scores)
Save attentions if debugging
        tracker.debug('attn', attn)
Apply dropout
        attn = self.dropout(attn)
Multiply by values
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
Save attentions for any other calculations
        self.attn = attn.detach()
Concatenate multiple heads
        x = x.reshape(seq_len, batch_size, -1)
Output layer
        return self.output(x)
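A usage sketch (not part of the original module): applying the layer for self-attention with a causal mask so that each position only attends to itself and earlier positions. The sizes are illustrative, the mask's batch dimension of 1 is broadcast across the batch, and it assumes the tracker.debug call above is a no-op when no labml experiment is running.

seq_len, batch_size, d_model = 10, 32, 512
mha = MultiHeadAttention(heads=8, d_model=d_model)
x = torch.randn(seq_len, batch_size, d_model)
# mask[i, j, 0] is 1 when query position i may attend to key position j
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1)  # [seq_len, seq_len, 1]
out = mha(query=x, key=x, value=x, mask=causal_mask)
assert out.shape == (seq_len, batch_size, d_model)
assert mha.attn.shape == (seq_len, seq_len, batch_size, 8)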