mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-01 03:43:09 +08:00
relative attention notes
@@ -18,38 +18,17 @@ from labml_nn.transformers.models import FeedForward
 from labml_nn.utils import clone_module_list
 
 
-class PrepareQueryForMultiHeadAttention(Module):
-    """
-    ## Prepare query for multi-head attention
-
-    This module does a linear transformation and splits the vector into given
-    number of heads for multi-head attention.
-    This is used to transform **key**, **query**, and **value** vectors.
-    """
-
-    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
-        super().__init__()
-        # Linear layer for linear transform
-        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
-        # Number of heads
-        self.heads = heads
-        # Number of dimensions in vectors in each head
-        self.d_k = d_k
-
-    def __call__(self, x: torch.Tensor):
-        # Input has shape `[seq_len, batch_size, d_model]`
-        batch_size, _ = x.shape
-
-        # Linear transform
-        x = self.linear(x)
-        # Split into heads
-        x = x.view(batch_size, self.heads, self.d_k)
-
-        # Output has shape `[seq_len, batch_size, heads, d_k]`
-        return x
-
-
 class FeedbackAttention(Module):
+    """
+    ## Feedback Attention
+
+    This is very similar to [Relative Multi-Head Attention](../relative_mha.html)
+    but with some modifications.
+
+    📝 Decided not to extend from [Relative Multi-Head Attention](../relative_mha.html)
+    or [Multi-Head Attention](../mha.html) to improve readability.
+    """
+
     def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
         super().__init__()
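Note: although the removed `__call__` carries `[seq_len, batch_size, d_model]` shape comments, its `batch_size, _ = x.shape` unpacking shows the query it prepared is a single step of shape `[batch_size, d_model]`. That is why the class becomes redundant once the shared `PrepareForMultiHeadAttention` (generalised later in this commit) splits heads on the last dimension only. A minimal sketch with assumed toy sizes, not code from the repository, checking that the two reshapes agree for such a query:

import torch

batch_size, heads, d_k = 4, 2, 8                  # assumed toy sizes
x = torch.randn(batch_size, heads * d_k)          # output of the linear projection

old_split = x.view(batch_size, heads, d_k)        # what PrepareQueryForMultiHeadAttention did
new_split = x.view(*x.shape[:-1], heads, d_k)     # what the shared module now does
assert torch.equal(old_split, new_split)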
@@ -57,7 +36,7 @@ class FeedbackAttention(Module):
         self.heads = heads
 
         # These transform the `query`, `key` and `value` vectors for multi-headed attention.
-        self.query = PrepareQueryForMultiHeadAttention(d_model, heads, self.d_k, False)
+        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
         self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
         self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
@@ -68,17 +47,60 @@ class FeedbackAttention(Module):
         # Scaling factor before the softmax
         self.scale = 1 / math.sqrt(self.d_k)
 
+        # Softmax for attention along the time dimension of `key`
+        self.softmax = nn.Softmax(dim=0)
+
+        # Number of relative positions
+        self.P = 2 ** 12
+
+        # Relative positional embeddings for key relative to the query.
+        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
+        # Relative positional embedding bias for key relative to the query.
+        self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
+        # Positional embeddings for the query are independent of the position of the query
+        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
+
         # We store attentions so that they can be used for logging, or other computations if needed
         self.attn = None
 
-        self.P = 2 ** 12
-
-        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
-        self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
-        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
-        self.softmax = nn.Softmax(dim=0)
-
     def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+        r"""
+        ### Get relative attention scores
+
+        With absolute attention
+
+        \begin{align}
+        A^{abs}_{i,j} &= lin_q(\color{cyan}{X^q_i + P_i})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
+                      &= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
+                         \color{cyan}{Q_i^T} \color{lightgreen}{U_j} +
+                         \color{cyan}{V_i^T} \color{lightgreen}{K_j} +
+                         \color{cyan}{V_i^T} \color{lightgreen}{U_j}
+        \end{align}
+
+        where $\color{cyan}{Q_i}, \color{lightgreen}{K_j}$ are linear transformations of
+        the original embeddings $\color{cyan}{X^q_i}, \color{lightgreen}{X^k_j}$,
+        and $\color{cyan}{V_i}, \color{lightgreen}{U_j}$ are linear transformations of
+        the absolute positional encodings $\color{cyan}{P_i}, \color{lightgreen}{P_j}$.
+
+        They reason out that the attention to a given key should be the same regardless of
+        the position of the query. Hence replace $\color{cyan}{V_i^T} \color{lightgreen}{K_j}$
+        with a constant $\color{orange}{v^T} \color{lightgreen}{K_j}$.
+        🤔 It may be worthwhile testing without this assumption.
+
+        For the second and third terms relative positional encodings are introduced.
+        So $\color{cyan}{Q_i^T} \color{lightgreen}{U_j}$ is
+        replaced with $\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}$
+        and $\color{cyan}{V_i^T} \color{lightgreen}{U_j}$ with $\color{orange}{S_{i-j}}$.
+
+        \begin{align}
+        A^{rel}_{i,j} &= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
+                         \color{cyan}{Q_i^T} \color{orange}{R_{i - j}} +
+                         \color{orange}{v^T} \color{lightgreen}{K_j} +
+                         \color{orange}{S_{i-j}}
+        \end{align}
+        """
 
+        # $\color{orange}{R_{i - j}}$
         key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
         key_pos_bias = self.key_pos_bias[-key.shape[0]:]
         query_pos_bias = self.query_pos_bias[None, :, :]
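The rest of `FeedbackAttention.get_scores` is outside this hunk. As a hedged sketch (the toy sizes and the exact combination of terms are assumptions, not the file's code), the slices above could be combined roughly as follows; because the query is always the newest step, the relative offsets are simply the last `key.shape[0]` entries of the tables, no shift is needed, and the softmax later runs over `dim=0`, the only sequence axis, unlike the `dim=1` softmax of `MultiHeadAttention` below.

import torch

seq_len, batch_size, heads, d_k = 6, 2, 4, 8            # assumed toy sizes
query = torch.randn(batch_size, heads, d_k)             # a single query step
key = torch.randn(seq_len, batch_size, heads, d_k)      # memory of earlier steps
key_pos_emb = torch.randn(seq_len, heads, d_k)          # R_{i-j} for the last seq_len offsets
key_pos_bias = torch.randn(seq_len, heads)              # S_{i-j}
query_pos_bias = torch.randn(1, heads, d_k)             # v^T

# (A + C): (Q_i + v)^T K_j, one score per (key position, batch, head)
ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
# (B + D): Q_i^T R_{i-j} + S_{i-j}
bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :]
# Scale and normalise over key positions (dim=0)
attn = torch.softmax((ac + bd) / d_k ** 0.5, dim=0)
print(attn.shape)    # [seq_len, batch_size, heads]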
@@ -18,10 +18,10 @@ import math
 from typing import Optional
 
 import torch
+from torch import nn as nn
 
 from labml import tracker
 from labml_helpers.module import Module
-from torch import nn as nn
-from torch.nn import functional as F
 
 
 class PrepareForMultiHeadAttention(Module):
@@ -43,29 +43,28 @@ class PrepareForMultiHeadAttention(Module):
         self.d_k = d_k
 
     def __call__(self, x: torch.Tensor):
-        # Input has shape `[seq_len, batch_size, d_model]`
-        seq_len, batch_size, _ = x.shape
+        # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
+        # We apply the linear transformation to the last dimension and split that into
+        # the heads.
+        head_shape = x.shape[:-1]
 
         # Linear transform
         x = self.linear(x)
-        # Split into heads
-        x = x.view(seq_len, batch_size, self.heads, self.d_k)
 
-        # Output has shape `[seq_len, batch_size, heads, d_k]`
+        # Split the last dimension into heads
+        x = x.view(*head_shape, self.heads, self.d_k)
+
+        # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_k]`
         return x
 
 
 class MultiHeadAttention(Module):
-    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
-        """
+    r"""
     ## Multi-Head Attention Module
 
-    * `heads` is the number of heads.
-    * `d_model` is the number of features in the `query`, `key` and `value` vectors.
-
     This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.
 
-    $$\mathop{Attention}(Q, K, V) = \mathop{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
+    $$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
 
     In simple terms, it finds keys that match the query and gets the values of
     those keys.
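A toy, hedged illustration of the docstring formula $\underset{seq}{\mathop{softmax}}\Big(\frac{Q K^T}{\sqrt{d_k}}\Big)V$ with the `[seq_len, batch_size, heads, d_k]` layout used in this file; the einsum subscripts mirror the ones that appear later in this diff, while the sizes are assumptions.

import math
import torch

seq_len, batch_size, heads, d_k = 5, 2, 4, 8     # assumed toy sizes
query = torch.randn(seq_len, batch_size, heads, d_k)
key = torch.randn(seq_len, batch_size, heads, d_k)
value = torch.randn(seq_len, batch_size, heads, d_k)

# Q K^T for every (query position i, key position j, batch, head)
scores = torch.einsum('ibhd,jbhd->ijbh', query, key) / math.sqrt(d_k)
# softmax along the key sequence dimension (dim=1)
attn = torch.softmax(scores, dim=1)
# attention-weighted sum of the values
out = torch.einsum('ijbh,jbhd->ibhd', attn, value)
print(out.shape)    # [seq_len, batch_size, heads, d_k]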
@@ -78,8 +77,17 @@ class MultiHeadAttention(Module):
     Softmax is calculated along the axis of the sequence (or time).
     """
 
+    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
+        """
+        * `heads` is the number of heads.
+        * `d_model` is the number of features in the `query`, `key` and `value` vectors.
+        """
+
         super().__init__()
 
+        # Number of features per head
         self.d_k = d_model // heads
+        # Number of heads
         self.heads = heads
 
         # These transform the `query`, `key` and `value` vectors for multi-headed attention.
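A quick sense of the per-head split above (illustrative numbers, not values fixed by this file): with `d_model = 512` and `heads = 8`, each head works with `d_k = 512 // 8 = 64` features, and each `PrepareForMultiHeadAttention` projection maps 512 features to `8 × 64 = 512` outputs before the view into heads.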
@@ -87,6 +95,9 @@ class MultiHeadAttention(Module):
         self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
         self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
 
+        # Softmax for attention along the time dimension of `key`
+        self.softmax = nn.Softmax(dim=1)
+
         # Output layer
         self.output = nn.Linear(d_model, d_model)
         # Dropout
@@ -153,7 +164,7 @@ class MultiHeadAttention(Module):
 
         # $softmax$ attention along the key sequence dimension
         # $\underset{seq}{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)$
-        attn = F.softmax(scores, dim=1)
+        attn = self.softmax(scores)
 
         # Save attentions if debugging
         tracker.debug('attn', attn)
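A quick equivalence check for the swap above (the score shape is an assumption consistent with the `ijbh` einsum layout used in this codebase): with scores laid out as `[query_seq, key_seq, batch_size, heads]`, the registered `nn.Softmax(dim=1)` normalises over key positions exactly as the old functional call did.

import torch
from torch import nn
from torch.nn import functional as F

scores = torch.randn(3, 5, 2, 4)    # assumed [query_seq, key_seq, batch_size, heads]
assert torch.allclose(nn.Softmax(dim=1)(scores), F.softmax(scores, dim=1))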
@@ -51,56 +51,77 @@ class RelativeMultiHeadAttention(MultiHeadAttention):
         # The linear transformations don't need a bias since we take care of it when
         # calculating scores.
         # However having a bias for `value` might make sense.
-        super().__init__(heads, d_model, dropout_prob, False)
+        super().__init__(heads, d_model, dropout_prob, bias=False)
 
+        # Number of relative positions
         self.P = 2 ** 12
 
+        # Relative positional embeddings for key relative to the query.
+        # We need $2P$ embeddings because the keys can be before or after the query.
         self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)
-        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
+        # Relative positional embedding bias for key relative to the query.
         self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)
+        # Positional embeddings for the query are independent of the position of the query
+        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
 
     def get_scores(self, query: torch.Tensor, key: torch.Tensor):
-        """
+        r"""
+        ### Get relative attention scores
+
         With absolute attention
 
         \begin{align}
-        A^{abs}_{i,j} &= lin_q(X^q_i + P_i)^T lin_k(X^k_j + P_j) \\
-                      &= Q_i^T K_j + Q_i^T U_j + V_i^T K_j + V_i^T U_j
+        A^{abs}_{i,j} &= lin_q(\color{cyan}{X^q_i + P_i})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
+                      &= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
+                         \color{cyan}{Q_i^T} \color{lightgreen}{U_j} +
+                         \color{cyan}{V_i^T} \color{lightgreen}{K_j} +
+                         \color{cyan}{V_i^T} \color{lightgreen}{U_j}
         \end{align}
 
-        where $Q_i$, $K_j$, $V_i$, and $U_j$ are linear transformations of
-        orginal embeddings and positional encodings.
+        where $\color{cyan}{Q_i}, \color{lightgreen}{K_j}$ are linear transformations of
+        the original embeddings $\color{cyan}{X^q_i}, \color{lightgreen}{X^k_j}$,
+        and $\color{cyan}{V_i}, \color{lightgreen}{U_j}$ are linear transformations of
+        the absolute positional encodings $\color{cyan}{P_i}, \color{lightgreen}{P_j}$.
 
         They reason out that the attention to a given key should be the same regardless of
-        the position of query. Hence replace $V_i^T K_j$ with a constant $v^T K_j$.
+        the position of the query. Hence replace $\color{cyan}{V_i^T} \color{lightgreen}{K_j}$
+        with a constant $\color{orange}{v^T} \color{lightgreen}{K_j}$.
         🤔 It may be worthwhile testing without this assumption.
 
         For the second and third terms relative positional encodings are introduced.
-        So $Q_i^T U_j$ is replaced with $Q_i^T R_{i - j}$ and $V_i^T U_j$ with $S_{i-j}$.
+        So $\color{cyan}{Q_i^T} \color{lightgreen}{U_j}$ is
+        replaced with $\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}$
+        and $\color{cyan}{V_i^T} \color{lightgreen}{U_j}$ with $\color{orange}{S_{i-j}}$.
 
         \begin{align}
-        A^{rel}_{i,j} &= Q_i^T K_j + Q_i^T R_{i - j} + v^T K_j + S_{i-j}
+        A^{rel}_{i,j} &= \underset{\mathbf{A}}{\color{cyan}{Q_i^T} \color{lightgreen}{K_j}} +
+                         \underset{\mathbf{B}}{\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}} +
+                         \underset{\mathbf{C}}{\color{orange}{v^T} \color{lightgreen}{K_j}} +
+                         \underset{\mathbf{D}}{\color{orange}{S_{i-j}}}
         \end{align}
         """
 
-        # $R_{i-j}$ pre-shift
+        # $\color{orange}{R_k}$
         key_pos_emb = self.key_pos_embeddings[self.P - query.shape[0]:self.P + key.shape[0]]
-        # $S_{i-j}$ pre-shift
+        # $\color{orange}{S_k}$
         key_pos_bias = self.key_pos_bias[self.P - query.shape[0]:self.P + key.shape[0]]
-        # $v^T$
+        # $\color{orange}{v^T}$
         query_pos_bias = self.query_pos_bias[None, None, :, :]
 
-        # $Q_i^T K_j + v^T K_j$
+        # ${(\mathbf{A} + \mathbf{C})}_{i,j} = \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
+        # \color{orange}{v^T} \color{lightgreen}{K_j}$
         ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)
-        # $Q_i^T R_{i - j}$ pre-shift
+        # $\mathbf{B'}_{i,k} = \color{cyan}{Q_i^T} \color{orange}{R_k}$
         b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
-        # $S_{i-j}$ pre-shift
+        # $\mathbf{D'}_{i,k} = \color{orange}{S_k}$
         d = key_pos_bias[None, :, None, :]
-        # $Q_i^T R_{i - j} + S_{i-j}$
+        # Shift the rows of $\mathbf{(B' + D')}_{i,k}$
+        # to get $$\mathbf{(B + D)}_{i,j} = \mathbf{(B' + D')}_{i,i - j}$$
         bd = shift_right(b + d)
+        # Remove extra positions
         bd = bd[:, -key.shape[0]:]
 
+        # Return the sum $\mathbf{(A + B + C + D)}_{i,j}$
         return ac + bd
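`shift_right` is defined elsewhere in this file and is not part of this diff. The following toy walk-through (sizes assumed, and only tracing shapes rather than reproducing the repository's code) shows what each term looks like before the shift, which is usually the confusing part of the skew trick:

import torch

q_len, k_len, batch, heads, d_k = 3, 5, 2, 4, 8   # assumed toy sizes
P = 16                                            # small stand-in for 2 ** 12

query = torch.randn(q_len, batch, heads, d_k)
key = torch.randn(k_len, batch, heads, d_k)
key_pos_embeddings = torch.randn(2 * P, heads, d_k)
key_pos_bias_table = torch.randn(2 * P, heads)

# Slice the 2P tables down to the offsets that can actually occur for this query/key pair
key_pos_emb = key_pos_embeddings[P - q_len:P + k_len]    # [q_len + k_len, heads, d_k]
key_pos_bias = key_pos_bias_table[P - q_len:P + k_len]   # [q_len + k_len, heads]

ac = torch.einsum('ibhd,jbhd->ijbh', query, key)          # [q_len, k_len, batch, heads]
b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)    # [q_len, q_len + k_len, batch, heads]
d = key_pos_bias[None, :, None, :]                        # [1, q_len + k_len, 1, heads]
# shift_right(b + d)[:, -k_len:] realigns the relative axis to match `ac`,
# giving [q_len, k_len, batch, heads]; the method returns ac + bd.
print(ac.shape, (b + d).shape)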