This is a tutorial/implementation of multi-headed attention from the paper Attention Is All You Need in PyTorch. The implementation is inspired by the Annotated Transformer.
import math
from typing import Optional

import torch
from torch import nn

from labml import tracker
from labml_helpers.module import Module

This module does a linear transformation and splits the vector into a given number of heads for multi-head attention. This is used to transform the key, query, and value vectors.
class PrepareForMultiHeadAttention(Module):
    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()

Linear layer for the linear transform

        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)

Number of heads

        self.heads = heads

Number of dimensions in vectors in each head

        self.d_k = d_k

    def forward(self, x: torch.Tensor):

Input has shape [seq_len, batch_size, d_model] or [batch_size, d_model]. We apply the linear transformation to the last dimension and split that into the heads.

        head_shape = x.shape[:-1]

Linear transform

        x = self.linear(x)

Split the last dimension into heads

        x = x.view(*head_shape, self.heads, self.d_k)

Output has shape [seq_len, batch_size, heads, d_k] or [batch_size, heads, d_k]

        return x
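As a quick sanity check (a sketch, not part of the original code), with the illustrative sizes d_model=512 and heads=8 used in the paper, so that d_k=64, the shapes work out like this:

# Illustrative sizes, not from the original file: d_model=512, heads=8, d_k=64
d_model, heads, d_k = 512, 8, 64
prep = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=True)

x = torch.randn(10, 4, d_model)   # [seq_len=10, batch_size=4, d_model]
print(prep(x).shape)              # torch.Size([10, 4, 8, 64]) -> [seq_len, batch_size, heads, d_k]

x = torch.randn(4, d_model)       # [batch_size=4, d_model]
print(prep(x).shape)              # torch.Size([4, 8, 64]) -> [batch_size, heads, d_k]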
This computes scaled multi-headed attention for the given query, key and value vectors.

In simple terms, it finds keys that match the query, and gets the values of those keys.

It uses the dot-product of query and key as the indicator of how well they match. Before taking the $softmax$ the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing softmax to give very small gradients when $d_k$ is large.

Softmax is calculated along the axis of the sequence (or time).
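In the notation used later in this file, the whole computation is the standard scaled dot-product attention from the paper,

$$Attention(Q, K, V) = \underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$

computed independently for each head, with the $softmax$ taken along the key sequence dimension.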
class MultiHeadAttention(Module):

heads is the number of heads. d_model is the number of features in the query, key and value vectors.

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        super().__init__()

Number of features per head

        self.d_k = d_model // heads

Number of heads

        self.heads = heads

These transform the query, key and value vectors for multi-headed attention.

        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

Softmax for attention along the time dimension of key

        self.softmax = nn.Softmax(dim=1)

Output layer

        self.output = nn.Linear(d_model, d_model)

Dropout

        self.dropout = nn.Dropout(dropout_prob)

Scaling factor before the softmax

        self.scale = 1 / math.sqrt(self.d_k)

We store attentions so that they can be used for logging, or other computations if needed

        self.attn = None
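The choice of dim=1 for the softmax matters: the attention scores computed in forward below have shape [seq_len_q, seq_len_k, batch_size, heads], so normalizing along dim=1 makes each query's weights sum to one over the keys. A minimal sketch (illustrative shapes, not part of the original code):

# Illustrative shapes only: [seq_len_q=3, seq_len_k=5, batch_size=2, heads=4]
scores = torch.randn(3, 5, 2, 4)
attn = nn.Softmax(dim=1)(scores)
print(attn.sum(dim=1))  # all ones: each query position holds a distribution over the 5 keys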
This method can be overridden for other variations like relative attention.

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

Calculate $Q K^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$

        return torch.einsum('ibhd,jbhd->ijbh', query, key)
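As a sanity check (a sketch with made-up shapes, not part of the original code), this einsum is equivalent to a batched matrix multiplication with the batch and head dimensions moved to the front:

# Illustrative shapes: query [seq_q=7, batch=2, heads=4, d_k=16], key [seq_k=9, batch=2, heads=4, d_k=16]
query = torch.randn(7, 2, 4, 16)
key = torch.randn(9, 2, 4, 16)

scores = torch.einsum('ibhd,jbhd->ijbh', query, key)    # [7, 9, 2, 4]

# Same result via matmul: move (batch, heads) to the front, multiply, move them back
q = query.permute(1, 2, 0, 3)                           # [batch, heads, seq_q, d_k]
k = key.permute(1, 2, 3, 0)                             # [batch, heads, d_k, seq_k]
scores_matmul = torch.matmul(q, k).permute(2, 3, 0, 1)  # [seq_q, seq_k, batch, heads]

assert torch.allclose(scores, scores_matmul, atol=1e-5)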
query, key and value are the tensors that store a collection of query, key and value vectors. They have shape [seq_len, batch_size, d_model].

mask has shape [seq_len, seq_len, batch_size] and mask[i, j, b] indicates whether, for batch b, the query at position i has access to the key-value at position j.
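For example (a sketch, not part of the original code), a causal mask that lets each query attend only to the current and earlier positions, shared across the whole batch, can be built like this:

# Illustrative only: causal (subsequent) mask
seq_len = 5
# Lower triangular matrix of ones: entry [i, j] is 1 when query i may attend to key j (j <= i)
mask = torch.tril(torch.ones(seq_len, seq_len))
# Add a batch dimension of size 1 so it broadcasts over the batch: shape [seq_len, seq_len, 1]
mask = mask.unsqueeze(-1)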
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):

query, key and value have shape [seq_len, batch_size, d_model]

        seq_len, batch_size, _ = query.shape

        if mask is not None:

mask has shape [seq_len_q, seq_len_k, batch_size], where the first dimension is the query dimension. If the query dimension is equal to $1$ it will be broadcast.

            assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
            assert mask.shape[1] == key.shape[0]
            assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]

Same mask applied to all heads.
            mask = mask.unsqueeze(-1)

Prepare query, key and value for attention computation. These will then have shape [seq_len, batch_size, heads, d_k].

        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

Compute attention scores $Q K^\top$. This gives a tensor of shape [seq_len, seq_len, batch_size, heads].

        scores = self.get_scores(query, key)

Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$

        scores *= self.scale

Apply mask

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

$softmax$ attention along the key sequence dimension $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$

        attn = self.softmax(scores)

Save attentions if debugging

        tracker.debug('attn', attn)

Apply dropout

        attn = self.dropout(attn)

Multiply by values

        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

Save attentions for any other calculations

        self.attn = attn.detach()

Concatenate multiple heads

        x = x.reshape(seq_len, batch_size, -1)

Output layer

        return self.output(x)
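Finally, a minimal end-to-end sketch of self-attention with this module (illustrative sizes, not part of the original code):

# Illustrative sizes, not from the original file
seq_len, batch_size, d_model, heads = 5, 2, 512, 8
mha = MultiHeadAttention(heads=heads, d_model=d_model)

x = torch.randn(seq_len, batch_size, d_model)
# Causal mask of shape [seq_len, seq_len, 1], broadcast over the batch
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1)

out = mha(query=x, key=x, value=x, mask=mask)
print(out.shape)       # torch.Size([5, 2, 512])  -> [seq_len, batch_size, d_model]
print(mha.attn.shape)  # torch.Size([5, 5, 2, 8]) -> [seq_len_q, seq_len_k, batch_size, heads]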