"""
|
|
---
|
|
title: Multi-Headed Attention (MHA)
|
|
summary: >
|
|
This implements the Multi-Headed Attention used in transformers
|
|
using PyTorch with explanations.
|
|
---
|
|
|
|
# Multi-Headed Attention (MHA)
|
|
|
|
[](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/basic/autoregressive_experiment.ipynb)
|
|
[](https://comet.ml/labml/transformer/ea8c108c2d94434ca3c2bc2b21015082)
|
|
|
|
This is a tutorial/implementation of multi-headed attention
|
|
from paper [Attention Is All You Need](https://papers.labml.ai/paper/1706.03762)
|
|
in [PyTorch](https://pytorch.org/).
|
|
The implementation is inspired from [Annotated Transformer](https://nlp.seas.harvard.edu/2018/04/03/attention.html).
|
|
|
|
Here is the [training code](basic/autoregressive_experiment.html) that uses a basic transformer
|
|
with MHA for NLP auto-regression.
|
|
|
|
[Here is an experiment implementation](basic/autoregressive_experiment.html) that trains a simple transformer.
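
As a quick orientation, here is a minimal usage sketch of the `MultiHeadAttention`
module defined below. It assumes this file is importable as `labml_nn.transformers.mha`
and the shapes are only illustrative:

```python
import torch

from labml_nn.transformers.mha import MultiHeadAttention

# 8 heads over a 512-dimensional model; each head gets 512 // 8 = 64 features
mha = MultiHeadAttention(heads=8, d_model=512)

# Inputs have shape `[seq_len, batch_size, d_model]`
x = torch.randn(10, 4, 512)

# Self-attention: the same tensor is used for query, key and value
out = mha(query=x, key=x, value=x)
assert out.shape == (10, 4, 512)
```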
"""

import math
from typing import Optional, List

import torch
from torch import nn

from labml import tracker


class PrepareForMultiHeadAttention(nn.Module):
    """
    <a id="PrepareMHA"></a>

    ## Prepare for multi-head attention

    This module does a linear transformation and splits the vector into the given
    number of heads for multi-head attention.
    This is used to transform the **key**, **query**, and **value** vectors.
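
    For instance, a small shape check (a sketch; the sizes are only illustrative and
    it assumes this file is importable as `labml_nn.transformers.mha`):

    ```python
    import torch

    from labml_nn.transformers.mha import PrepareForMultiHeadAttention

    # Project 512 features into 8 heads of 64 features each
    prep = PrepareForMultiHeadAttention(d_model=512, heads=8, d_k=64, bias=True)

    x = torch.randn(10, 4, 512)              # `[seq_len, batch_size, d_model]`
    assert prep(x).shape == (10, 4, 8, 64)   # `[seq_len, batch_size, heads, d_k]`
    ```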
    """

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        # Linear layer for the linear transform
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        # Number of heads
        self.heads = heads
        # Number of dimensions in vectors in each head
        self.d_k = d_k

    def forward(self, x: torch.Tensor):
        # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
        # We apply the linear transformation to the last dimension and split that into
        # the heads.
        head_shape = x.shape[:-1]

        # Linear transform
        x = self.linear(x)

        # Split the last dimension into heads
        x = x.view(*head_shape, self.heads, self.d_k)

        # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_k]`
        return x


class MultiHeadAttention(nn.Module):
    r"""
    <a id="MHA"></a>

    ## Multi-Head Attention Module

    This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.

    $$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$

    In simple terms, it finds keys that match the query, and gets the values of
    those keys.

    It uses the dot-product of query and key as an indicator of how well they match.
    Before taking the $softmax$, the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$.
    This is done to avoid large dot-product values causing the softmax to
    give very small gradients when $d_k$ is large.

    Softmax is calculated along the axis of the sequence (or time).
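
    As a concrete reference for the formula above, here is a minimal sketch that
    computes the same scaled dot-product attention with plain tensor operations
    (the batch and head dimensions are omitted and the sizes are only illustrative):

    ```python
    import math

    import torch

    seq_len, d_k = 5, 64
    q = torch.randn(seq_len, d_k)
    k = torch.randn(seq_len, d_k)
    v = torch.randn(seq_len, d_k)

    # Scaled dot-product scores, shape `[seq_len_q, seq_len_k]`
    scores = q @ k.T / math.sqrt(d_k)
    # Softmax along the key (sequence) dimension
    attn = torch.softmax(scores, dim=-1)
    # Weighted sum of the values, shape `[seq_len, d_k]`
    out = attn @ v
    ```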
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        """
        * `heads` is the number of heads.
        * `d_model` is the number of features in the `query`, `key` and `value` vectors.
        * `dropout_prob` is the dropout probability applied to the attention weights.
        * `bias` decides whether the `query` and `key` transformations have a bias parameter (the `value` transformation always has one).
        """

        super().__init__()

        # Number of features per head
        self.d_k = d_model // heads
        # Number of heads
        self.heads = heads

        # These transform the `query`, `key` and `value` vectors for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

        # Softmax for attention along the time dimension of `key`
        self.softmax = nn.Softmax(dim=1)

        # Output layer
        self.output = nn.Linear(d_model, d_model)
        # Dropout
        self.dropout = nn.Dropout(dropout_prob)
        # Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)

        # We store attentions so that they can be used for logging, or other computations if needed
        self.attn = None

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        """
        ### Calculate scores between queries and keys

        This method can be overridden for other variations like relative attention.
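
        For intuition, the einsum used below is just a batched matrix multiplication
        with the batch and head dimensions moved to the front. The following sketch
        (with illustrative sizes) is not used by the implementation; it only shows
        the equivalence:

        ```python
        import torch

        seq_len, batch_size, heads, d_k = 5, 2, 8, 64
        query = torch.randn(seq_len, batch_size, heads, d_k)
        key = torch.randn(seq_len, batch_size, heads, d_k)

        # The einsum computes scores[i, j, b, h] = sum_d query[i, b, h, d] * key[j, b, h, d]
        scores = torch.einsum('ibhd,jbhd->ijbh', query, key)

        # Equivalent batched matmul with the batch/head dimensions in front
        q = query.permute(1, 2, 0, 3)  # `[batch_size, heads, seq_len_q, d_k]`
        k = key.permute(1, 2, 0, 3)    # `[batch_size, heads, seq_len_k, d_k]`
        alt = (q @ k.transpose(-2, -1)).permute(2, 3, 0, 1)

        assert torch.allclose(scores, alt, atol=1e-4)
        ```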
        """

        # Calculate $Q K^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
        """
        `mask` has shape `[seq_len_q, seq_len_k, batch_size]`, where the first dimension is the query dimension.
        If the query dimension is equal to $1$ it will be broadcasted.
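
        For example, a causal (auto-regressive) mask that lets each query attend only
        to the current and earlier positions could be built like this (a sketch, not
        part of this module; the sizes are only illustrative):

        ```python
        import torch

        seq_len, batch_size = 5, 2

        # Lower-triangular matrix: query `i` may attend to keys `j <= i`
        causal = torch.tril(torch.ones(seq_len, seq_len)).bool()

        # Add the batch dimension; shape `[seq_len_q, seq_len_k, batch_size]`
        mask = causal.unsqueeze(-1).expand(seq_len, seq_len, batch_size)
        ```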
        """

        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

        # Same mask applied to all heads.
        mask = mask.unsqueeze(-1)

        # The resulting mask has shape `[seq_len_q, seq_len_k, batch_size, 1]` and broadcasts across heads
        return mask

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        """
        `query`, `key` and `value` are the tensors that store
        collections of *query*, *key* and *value* vectors.
        They have shape `[seq_len, batch_size, d_model]`.

        `mask` has shape `[seq_len, seq_len, batch_size]` and
        `mask[i, j, b]` indicates whether, for batch `b`,
        the query at position `i` has access to the key-value at position `j`.
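
        For example, a masked self-attention call might look like this (a sketch with
        illustrative shapes, assuming this file is importable as `labml_nn.transformers.mha`):

        ```python
        import torch

        from labml_nn.transformers.mha import MultiHeadAttention

        seq_len, batch_size, d_model = 10, 4, 512
        x = torch.randn(seq_len, batch_size, d_model)

        # Causal mask: each position attends only to itself and earlier positions
        mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
        mask = mask.unsqueeze(-1).expand(seq_len, seq_len, batch_size)

        mha = MultiHeadAttention(heads=8, d_model=d_model)
        out = mha(query=x, key=x, value=x, mask=mask)  # `[seq_len, batch_size, d_model]`
        ```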
        """

        # `query`, `key` and `value` have shape `[seq_len, batch_size, d_model]`
        seq_len, batch_size, _ = query.shape

        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)

        # Prepare `query`, `key` and `value` for attention computation.
        # These will then have shape `[seq_len, batch_size, heads, d_k]`.
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Compute attention scores $Q K^\top$.
        # This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.
        scores = self.get_scores(query, key)

        # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
        scores *= self.scale

        # Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # $softmax$ attention along the key sequence dimension
        # $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
        attn = self.softmax(scores)

        # Save attentions if debugging
        tracker.debug('attn', attn)

        # Apply dropout
        attn = self.dropout(attn)

        # Multiply by values
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

        # Save attentions for any other calculations
        self.attn = attn.detach()

        # Concatenate multiple heads
        x = x.reshape(seq_len, batch_size, -1)

        # Output layer
        return self.output(x)