This is a tutorial/implementation of multi-headed attention (MHA) from the paper Attention Is All You Need, in PyTorch. The implementation is inspired by the Annotated Transformer.
Here is the training code that uses a basic transformer with MHA for NLP auto-regression.
```python
import math
from typing import Optional, List

import torch
from torch import nn as nn

from labml import tracker
from labml_helpers.module import Module
```

This module does a linear transformation and splits the vector into the given number of heads for multi-head attention. It is used to transform the key, query, and value vectors.
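The split can be seen in isolation with a plain `nn.Linear` followed by `view`. The sizes below (`d_model = 16`, `heads = 4`, `d_k = 4`) and the helper name `split_heads` are illustrative assumptions. The sketch shows that head `h` simply receives the slice `h * d_k : (h + 1) * d_k` of the projected vector, for both 3-D and 2-D inputs.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes (assumptions, not values from the tutorial)
d_model, heads, d_k = 16, 4, 4
seq_len, batch_size = 5, 2

linear = nn.Linear(d_model, heads * d_k)

def split_heads(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: keep every leading dimension, project the last one,
    # then split it into heads
    head_shape = x.shape[:-1]
    y = linear(x)
    return y.view(*head_shape, heads, d_k)

x3 = torch.randn(seq_len, batch_size, d_model)  # [seq_len, batch_size, d_model]
x2 = torch.randn(batch_size, d_model)           # [batch_size, d_model]

out3 = split_heads(x3)
out2 = split_heads(x2)
print(out3.shape)  # torch.Size([5, 2, 4, 4])
print(out2.shape)  # torch.Size([2, 4, 4])

# Head h is just a contiguous slice of the projected vector
flat = linear(x3)
h = 2
print(torch.allclose(out3[:, :, h], flat[:, :, h * d_k:(h + 1) * d_k]))  # True
```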
```python
class PrepareForMultiHeadAttention(Module):
    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        # Linear layer for linear transform
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        # Number of heads
        self.heads = heads
        # Number of dimensions in vectors in each head
        self.d_k = d_k

    def forward(self, x: torch.Tensor):
        # Input has shape [seq_len, batch_size, d_model] or [batch_size, d_model].
        # We apply the linear transformation to the last dimension and split that
        # into the heads.
        head_shape = x.shape[:-1]

        # Linear transform
        x = self.linear(x)

        # Split last dimension into heads
        x = x.view(*head_shape, self.heads, self.d_k)

        # Output has shape [seq_len, batch_size, heads, d_k] or [batch_size, heads, d_k]
        return x
```

This computes scaled multi-headed attention for the given query, key and value vectors.

In simple terms, it finds the keys that match the query and gets the values of those keys.

It uses the dot product of the query and key as an indicator of how well they match. Before taking the $softmax$, the dot products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing the softmax to give very small gradients when $d_k$ is large.

The softmax is calculated along the axis of the key sequence (or time). Putting it together, $\mathrm{Attention}(Q, K, V) = \underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$.
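The $\frac{1}{\sqrt{d_k}}$ factor can be motivated numerically. If the components of $q$ and $k$ are independent with mean $0$ and variance $1$, then $q \cdot k = \sum_{d=1}^{d_k} q_d k_d$ has variance $d_k$, so its typical magnitude grows like $\sqrt{d_k}$. Dividing by $\sqrt{d_k}$ brings the variance back to $1$. The sketch below checks this with standard-normal samples; that distribution is an assumption made for illustration, and real query and key vectors need not follow it.

```python
import math
import random
import statistics

random.seed(0)

def dot_product_variance(d_k: int, n_samples: int = 2000, scaled: bool = False) -> float:
    """Empirical variance of q . k for random standard-normal q and k of size d_k."""
    values = []
    for _ in range(n_samples):
        q = [random.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [random.gauss(0.0, 1.0) for _ in range(d_k)]
        s = sum(qi * ki for qi, ki in zip(q, k))
        if scaled:
            s /= math.sqrt(d_k)  # the 1 / sqrt(d_k) scaling applied before the softmax
        values.append(s)
    return statistics.pvariance(values)

unscaled = dot_product_variance(64)
scaled = dot_product_variance(64, scaled=True)
print(f"unscaled variance ~ {unscaled:.1f}")  # close to d_k = 64
print(f"scaled variance   ~ {scaled:.2f}")    # close to 1
```

With variance around 64, unscaled scores often differ by tens of units, which pushes the softmax toward a one-hot output where gradients are small.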
In the constructor, `heads` is the number of heads and `d_model` is the number of features in the query, key and value vectors. `dropout_prob` is the dropout probability applied to the attention weights, and `bias` controls whether the query and key projections have a bias term (the value projection always has one here).

```python
class MultiHeadAttention(Module):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        super().__init__()
        # Number of features per head
        self.d_k = d_model // heads
        # Number of heads
        self.heads = heads

        # These transform the query, key and value vectors for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        # Note: the value projection always uses a bias, regardless of `bias`
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

        # Softmax for attention along the time dimension of key
        self.softmax = nn.Softmax(dim=1)

        # Output layer
        self.output = nn.Linear(d_model, d_model)
        # Dropout
        self.dropout = nn.Dropout(dropout_prob)
        # Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)

        # We store attentions so that they can be used for logging, or other computations if needed
        self.attn = None
```

The `get_scores` method below can be overridden for other variations, like relative attention.
`get_scores` calculates $Q K^\top$, or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$:

```python
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        # S[i, j, b, h] = sum over d of Q[i, b, h, d] * K[j, b, h, d]
        return torch.einsum('ibhd,jbhd->ijbh', query, key)
```

`mask` has shape `[seq_len_q, seq_len_k, batch_size]`, where the first dimension is the query dimension. If the query dimension is equal to $1$, the mask is broadcast across all queries; the batch dimension may likewise be $1$.
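The broadcasting described above can be checked on shapes alone. The sketch below (sizes are illustrative assumptions) builds a padding-style mask with a query dimension of $1$, adds a trailing head dimension with `unsqueeze(-1)` as `prepare_mask` does, and applies it to a score tensor of shape `[seq_len_q, seq_len_k, batch_size, heads]`.

```python
import torch

torch.manual_seed(0)

seq_len_q, seq_len_k, batch_size, heads = 3, 4, 2, 5
scores = torch.randn(seq_len_q, seq_len_k, batch_size, heads)

# Padding mask: 1 = attend, 0 = ignore. Batch 0 uses all 4 keys, batch 1 only the first 2.
# The query dimension is 1, so the same mask applies to every query position.
mask = torch.ones(1, seq_len_k, batch_size)
mask[:, 2:, 1] = 0

# Same mask for every head: [1, seq_len_k, batch_size, 1]
mask = mask.unsqueeze(-1)

masked = scores.masked_fill(mask == 0, float('-inf'))
attn = torch.softmax(masked, dim=1)  # softmax over the key dimension

print(mask.shape)  # torch.Size([1, 4, 2, 1])
print(attn.shape)  # torch.Size([3, 4, 2, 5])
# Masked keys of batch 1 receive zero weight for every query and head
print(attn[:, 2:, 1].abs().max().item())  # 0.0
```

Every (query, batch, head) row keeps at least one unmasked key here; a row that is fully masked would turn into `NaN` after the softmax.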
```python
    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

        # Same mask applied to all heads.
        mask = mask.unsqueeze(-1)

        # Resulting mask has shape [seq_len_q, seq_len_k, batch_size, 1],
        # which broadcasts across the heads dimension of the scores
        return mask
```

`query`, `key` and `value` are the tensors that store the collections of query, key and value vectors. They have shape `[seq_len, batch_size, d_model]`.

`mask` has shape `[seq_len, seq_len, batch_size]`, and `mask[i, j, b]` indicates whether, for batch `b`, the query at position `i` has access to the key-value at position `j`.
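For auto-regressive models, a common choice (an assumption here; this section does not build any mask itself) is a causal mask with `mask[i, j, b] = 1` exactly when `j <= i`. A minimal sketch with `torch.tril`, where `causal_mask` is a hypothetical helper and the batch dimension is $1$ so that it broadcasts over the batch and passes the shape checks in `prepare_mask`:

```python
import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    """Hypothetical helper: [seq_len, seq_len, 1] mask, 1 where key j <= query i."""
    return torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1)

mask = causal_mask(4)
print(mask[:, :, 0])  # lower-triangular matrix of ones

# With uniform scores, query i spreads its attention evenly over keys 0..i
scores = torch.zeros(4, 4, 1)
attn = torch.softmax(scores.masked_fill(mask == 0, float('-inf')), dim=1)
print(attn[:, :, 0])
```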
```python
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        # query, key and value have shape [seq_len, batch_size, d_model]
        seq_len, batch_size, _ = query.shape

        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)

        # Prepare query, key and value for attention computation.
        # These will then have shape [seq_len, batch_size, heads, d_k].
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Compute attention scores Q K^T.
        # This gives a tensor of shape [seq_len, seq_len, batch_size, heads].
        scores = self.get_scores(query, key)

        # Scale scores: Q K^T / sqrt(d_k)
        scores *= self.scale

        # Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax attention along the key sequence dimension: softmax(Q K^T / sqrt(d_k))
        attn = self.softmax(scores)

        # Save attentions if debugging
        tracker.debug('attn', attn)

        # Apply dropout
        attn = self.dropout(attn)

        # Multiply by values
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

        # Save attentions for any other calculations
        self.attn = attn.detach()

        # Concatenate multiple heads
        x = x.reshape(seq_len, batch_size, -1)

        # Output layer
        return self.output(x)
```
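As a sanity check, the same sequence of operations can be compared against PyTorch's built-in `torch.nn.MultiheadAttention`, which also uses the `[seq_len, batch_size, d_model]` layout by default. The sketch below is a hypothetical standalone function (`einsum_mha`, without the labml dependencies, masking, or dropout) that reuses the weights of the built-in module. The sizes are illustrative assumptions.

```python
import math

import torch
from torch import nn

torch.manual_seed(0)

d_model, heads = 16, 4
d_k = d_model // heads
seq_len, batch_size = 5, 3

# Reference module; dropout=0 so the comparison is deterministic
ref = nn.MultiheadAttention(d_model, heads, dropout=0.0)
# Biases are zero-initialized by default; randomize them to make the check stronger
nn.init.normal_(ref.in_proj_bias)
nn.init.normal_(ref.out_proj.bias)
ref.eval()

def einsum_mha(query, key, value):
    """Hypothetical re-implementation of the forward pass above using ref's weights."""
    # in_proj_weight packs W_q, W_k, W_v along dimension 0
    w_q, w_k, w_v = ref.in_proj_weight.chunk(3)
    b_q, b_k, b_v = ref.in_proj_bias.chunk(3)
    # Linear projections, then split the last dimension into heads
    q = nn.functional.linear(query, w_q, b_q).view(*query.shape[:-1], heads, d_k)
    k = nn.functional.linear(key, w_k, b_k).view(*key.shape[:-1], heads, d_k)
    v = nn.functional.linear(value, w_v, b_v).view(*value.shape[:-1], heads, d_k)
    # Scaled scores with shape [seq_len_q, seq_len_k, batch_size, heads]
    scores = torch.einsum('ibhd,jbhd->ijbh', q, k) / math.sqrt(d_k)
    attn = scores.softmax(dim=1)  # softmax over keys
    x = torch.einsum('ijbh,jbhd->ibhd', attn, v)
    x = x.reshape(query.shape[0], query.shape[1], -1)  # concatenate heads
    return ref.out_proj(x)

q = torch.randn(seq_len, batch_size, d_model)
k = torch.randn(seq_len, batch_size, d_model)
v = torch.randn(seq_len, batch_size, d_model)

with torch.no_grad():
    ours = einsum_mha(q, k, v)
    expected, _ = ref(q, k, v, need_weights=False)

print(ours.shape)                                # torch.Size([5, 3, 16])
print(torch.allclose(ours, expected, atol=1e-5))  # True
```

`torch.nn.MultiheadAttention` packs the three input projections into a single `in_proj_weight`, while this tutorial keeps three separate `PrepareForMultiHeadAttention` modules. The head split order is the same in both (head `h` uses features `h * d_k` to `(h + 1) * d_k`), which is why the weights can be shared directly.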