Multi-Headed Attention

This is a tutorial/implementation of multi-headed attention from the paper Attention Is All You Need, in PyTorch. The implementation is inspired by the Annotated Transformer.

import math
from typing import Optional

import torch
from torch import nn

from labml import tracker
from labml_helpers.module import Module

Prepare for multi-head attention

This module does a linear transformation and splits the vector into a given number of heads for multi-head attention. This is used to transform the key, query, and value vectors.

class PrepareForMultiHeadAttention(Module):
    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()

Linear layer for linear transform

        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)

Number of heads

        self.heads = heads

Number of dimensions in vectors in each head

        self.d_k = d_k

    def forward(self, x: torch.Tensor):

Input has shape [seq_len, batch_size, d_model] or [batch_size, d_model]. We apply the linear transformation to the last dimension and split that into the heads.

        head_shape = x.shape[:-1]

Linear transform

        x = self.linear(x)

Split last dimension into heads

        x = x.view(*head_shape, self.heads, self.d_k)

Output has shape [seq_len, batch_size, heads, d_k] or [batch_size, heads, d_k]

        return x
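As a quick shape check, here is a minimal sketch of using this module on its own (the sizes below are arbitrary assumptions, not anything prescribed by the paper):

prep = PrepareForMultiHeadAttention(d_model=512, heads=8, d_k=64, bias=True)
x = torch.randn(10, 32, 512)  # [seq_len=10, batch_size=32, d_model=512]
print(prep(x).shape)          # torch.Size([10, 32, 8, 64])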

Multi-Head Attention Module

This computes scaled multi-headed attention for given query, key and value vectors.

In simple terms, it finds keys that match the query and gets the values of those keys.

It uses the dot-product of query and key as the indicator of how well they match. Before taking the $softmax$, the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing the softmax to give very small gradients when $d_k$ is large.
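Putting it all together, the attention computed by this module is $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$, i.e. a weighted sum of the value vectors.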

The softmax is calculated along the axis of the sequence (or time).
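A small numerical illustration of why the scaling matters (the numbers here are made up for demonstration and are not part of the implementation):

d_k = 512
q = torch.randn(d_k)
k = torch.randn(10, d_k)
scores = k @ q                                    # unscaled dot-products have std of roughly sqrt(d_k)
print(torch.softmax(scores, dim=0))               # typically close to one-hot, so gradients are tiny
print(torch.softmax(scores / d_k ** 0.5, dim=0))  # scaled: a noticeably smoother distribution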

class MultiHeadAttention(Module):
  • heads is the number of heads.
  • d_model is the number of features in the query, key and value vectors.
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        super().__init__()

Number of features per head

        self.d_k = d_model // heads

Number of heads

        self.heads = heads

These transform the query, key and value vectors for multi-headed attention.

        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

Softmax for attention along the time dimension of the key; this is dim=1 of the [seq_len_q, seq_len_k, batch_size, heads] score tensor

        self.softmax = nn.Softmax(dim=1)

Output layer

        self.output = nn.Linear(d_model, d_model)

Dropout

        self.dropout = nn.Dropout(dropout_prob)

Scaling factor before the softmax

        self.scale = 1 / math.sqrt(self.d_k)

We store the attentions so that they can be used for logging, or for other computations if needed

        self.attn = None

Calculate scores between queries and keys

This method can be overridden for other variations like relative attention.

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

Calculate $Q K^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$

        return torch.einsum('ibhd,jbhd->ijbh', query, key)
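As a sanity check, this einsum is just a batched matrix product with the axes permuted; a sketch assuming query and key tensors of shape [seq_len, batch_size, heads, d_k]:

q = torch.randn(10, 32, 8, 64)                                    # [seq_len_q, batch_size, heads, d_k]
k = torch.randn(12, 32, 8, 64)                                    # [seq_len_k, batch_size, heads, d_k]
scores = torch.einsum('ibhd,jbhd->ijbh', q, k)                    # [10, 12, 32, 8]
bmm = torch.matmul(q.permute(1, 2, 0, 3), k.permute(1, 2, 3, 0))  # [32, 8, 10, 12]
assert torch.allclose(scores, bmm.permute(2, 3, 0, 1), atol=1e-4)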

query, key and value are the tensors that store the collections of query, key and value vectors. They have shape [seq_len, batch_size, d_model].

mask has shape [seq_len, seq_len, batch_size] and mask[i, j, b] indicates whether for batch b, query at position i has access to key-value at position j.
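For example, a causal (autoregressive) mask in this layout could be built as below. subsequent_mask is a hypothetical helper for illustration, not part of this module:

def subsequent_mask(seq_len: int) -> torch.Tensor:
    # mask[i, j, 0] is True where query position i may attend to key position j, i.e. j <= i
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)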

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):

query, key and value have shape [seq_len, batch_size, d_model]

        seq_len, batch_size, _ = query.shape

        if mask is not None:

mask has shape [seq_len_q, seq_len_k, batch_size], where the first dimension is the query dimension. If the query dimension is equal to $1$ it will be broadcast.

            assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
            assert mask.shape[1] == key.shape[0]
            assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]

Same mask applied to all heads.

            mask = mask.unsqueeze(-1)

Prepare query, key and value for attention computation. These will then have shape [seq_len, batch_size, heads, d_k].

        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

Compute attention scores $Q K^\top$. This gives a tensor of shape [seq_len, seq_len, batch_size, heads].

        scores = self.get_scores(query, key)

Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$

        scores *= self.scale

Apply mask

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

$softmax$ attention along the key sequence dimension $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$

        attn = self.softmax(scores)

Save attentions if debugging

        tracker.debug('attn', attn)

Apply dropout

        attn = self.dropout(attn)

Multiply by values

        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

Save attentions for any other calculations

        self.attn = attn.detach()

Concatenate multiple heads

        x = x.reshape(seq_len, batch_size, -1)

Output layer

        return self.output(x)
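Finally, a minimal usage sketch for self-attention (the sizes are arbitrary assumptions, subsequent_mask is the hypothetical helper sketched earlier, and this assumes the tracker.debug call is safe to run outside a labml experiment):

mha = MultiHeadAttention(heads=8, d_model=512)
x = torch.randn(10, 32, 512)                                  # [seq_len, batch_size, d_model]
out = mha(query=x, key=x, value=x, mask=subsequent_mask(10))  # causal self-attention
print(out.shape)                                              # torch.Size([10, 32, 512])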