mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-30 10:18:50 +08:00
relative attention notes
This commit is contained in:
@ -18,38 +18,17 @@ from labml_nn.transformers.models import FeedForward
|
||||
from labml_nn.utils import clone_module_list
|
||||
|
||||
|
||||
class PrepareQueryForMultiHeadAttention(Module):
|
||||
"""
|
||||
## Prepare query for multi-head attention
|
||||
|
||||
This module does a linear transformation and splits the vector into given
|
||||
number of heads for multi-head attention.
|
||||
This is used to transform **key**, **query**, and **value** vectors.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
|
||||
super().__init__()
|
||||
# Linear layer for linear transform
|
||||
self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
|
||||
# Number of heads
|
||||
self.heads = heads
|
||||
# Number of dimensions in vectors in each head
|
||||
self.d_k = d_k
|
||||
|
||||
def __call__(self, x: torch.Tensor):
|
||||
# Input has shape `[seq_len, batch_size, d_model]`
|
||||
batch_size, _ = x.shape
|
||||
|
||||
# Linear transform
|
||||
x = self.linear(x)
|
||||
# Split into heads
|
||||
x = x.view(batch_size, self.heads, self.d_k)
|
||||
|
||||
# Output has shape `[seq_len, batch_size, heads, d_k]`
|
||||
return x
|
||||
|
||||
|
||||
class FeedbackAttention(Module):
|
||||
"""
|
||||
## Feedback Attention
|
||||
|
||||
This is very similar to [Relative Multi-Head Attention](../relative_mha.html)
|
||||
but with some modifications.
|
||||
|
||||
📝 Decided not to extend from [Relative Multi-Head Attention](../relative_mha.html)
|
||||
or [Multi-Head Attention](../mha.html) to improve readability.
|
||||
"""
|
||||
|
||||
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
|
||||
super().__init__()
|
||||
|
||||
@ -57,7 +36,7 @@ class FeedbackAttention(Module):
|
||||
self.heads = heads
|
||||
|
||||
# These transform the `query`, `key` and `value` vectors for multi-headed attention.
|
||||
self.query = PrepareQueryForMultiHeadAttention(d_model, heads, self.d_k, False)
|
||||
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
|
||||
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
|
||||
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
|
||||
|
||||
@ -68,17 +47,60 @@ class FeedbackAttention(Module):
|
||||
# Scaling factor before the softmax
|
||||
self.scale = 1 / math.sqrt(self.d_k)
|
||||
|
||||
# Softmax for attention along the time dimension of `key`
|
||||
self.softmax = nn.Softmax(dim=0)
|
||||
|
||||
# Number of relative positions
|
||||
self.P = 2 ** 12
|
||||
|
||||
# Relative positional embeddings for key relative to the query.
|
||||
self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
|
||||
# Relative positional embedding bias for key relative to the query.
|
||||
self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
|
||||
# Positional embeddings for the query is independent of the position of the query
|
||||
self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
|
||||
|
||||
# We store attentions so that it can used for logging, or other computations if needed
|
||||
self.attn = None
|
||||
|
||||
self.P = 2 ** 12
|
||||
|
||||
self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
|
||||
self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
|
||||
self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
|
||||
self.softmax = nn.Softmax(dim=0)
|
||||
|
||||
def get_scores(self, query: torch.Tensor, key: torch.Tensor):
|
||||
"""
|
||||
### Get relative attention scores
|
||||
|
||||
With absolute attention
|
||||
|
||||
\begin{align}
|
||||
A^{abs}_{j} &= lin_q(\color{cyan}{X^q_i + P_i})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
|
||||
&= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
|
||||
\color{cyan}{Q_i^T} \color{lightgreen}{U_j} +
|
||||
\color{cyan}{V_i^T} \color{lightgreen}{K_j} +
|
||||
\color{cyan}{V_i^T} \color{lightgreen}{U_j}
|
||||
\end{align}
|
||||
|
||||
where $\color{cyan}{Q_i}, \color{lightgreen}{K_j}$, are linear transformations of
|
||||
original embeddings $\color{cyan}{X^q_i}, \color{lightgreen}{X^k_j}$
|
||||
and $\color{cyan}{V_i}, \color{lightgreen}{U_j}$ are linear transformations of
|
||||
absolute positional encodings $\color{cyan}{P_i}, \color{lightgreen}{P_j}$.
|
||||
|
||||
They reason out that the attention to a given key should be the same regardless of
|
||||
the position of query. Hence replace $\color{cyan}{V_i^T} \color{lightgreen}{K_j}$
|
||||
with a constant $\color{orange}{v^T} \color{lightgreen}{K_j}$.
|
||||
🤔 May be worthwhile testing without this assumption.
|
||||
|
||||
For the second and third terms relative positional encodings are introduced.
|
||||
So $\color{cyan}{Q_i^T} \color{lightgreen}{U_j}$ is
|
||||
replaced with $\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}$
|
||||
and $\color{cyan}{V_i^T} \color{lightgreen}{U_j}$ with $\color{orange}{S_{i-j}}$.
|
||||
|
||||
\begin{align}
|
||||
A^{rel}_{i,j} &= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
|
||||
\color{cyan}{Q_i^T} \color{orange}{R_{i - j}} +
|
||||
\color{orange}{v^T} \color{lightgreen}{K_j} +
|
||||
\color{orange}{S_{i-j}}
|
||||
\end{align}
|
||||
"""
|
||||
|
||||
# $\color{orange}{R_{i - j}}$
|
||||
key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
|
||||
key_pos_bias = self.key_pos_bias[-key.shape[0]:]
|
||||
query_pos_bias = self.query_pos_bias[None, :, :]
|
||||
|
||||
@ -18,10 +18,10 @@ import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
from labml import tracker
|
||||
from labml_helpers.module import Module
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class PrepareForMultiHeadAttention(Module):
|
||||
@ -43,43 +43,51 @@ class PrepareForMultiHeadAttention(Module):
|
||||
self.d_k = d_k
|
||||
|
||||
def __call__(self, x: torch.Tensor):
|
||||
# Input has shape `[seq_len, batch_size, d_model]`
|
||||
seq_len, batch_size, _ = x.shape
|
||||
# Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
|
||||
# We apply the linear transformation of the last dimension and splits that into
|
||||
# the heads
|
||||
head_shape = x.shape[:-1]
|
||||
|
||||
# Linear transform
|
||||
x = self.linear(x)
|
||||
# Split into heads
|
||||
x = x.view(seq_len, batch_size, self.heads, self.d_k)
|
||||
|
||||
# Output has shape `[seq_len, batch_size, heads, d_k]`
|
||||
# Split last dimension into heads
|
||||
x = x.view(*head_shape, self.heads, self.d_k)
|
||||
|
||||
# Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, d_model]`
|
||||
return x
|
||||
|
||||
|
||||
class MultiHeadAttention(Module):
|
||||
r"""
|
||||
## Multi-Head Attention Module
|
||||
|
||||
This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.
|
||||
|
||||
$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
|
||||
|
||||
In simple terms, it finds keys that matches the query, and get the values of
|
||||
those keys.
|
||||
|
||||
It uses dot-product of query and key as the indicator of how matching they are.
|
||||
Before taking the $softmax$ the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$.
|
||||
This is done to avoid large dot-product values causing softmax to
|
||||
give very small gradients when $d_k$ is large.
|
||||
|
||||
Softmax is calculate along the axis of of the sequence (or time).
|
||||
"""
|
||||
|
||||
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
|
||||
"""
|
||||
## Multi-Head Attention Module
|
||||
|
||||
* `heads` is the number of heads.
|
||||
* `d_model` is the number of features in the `query`, `key` and `value` vectors.
|
||||
|
||||
This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.
|
||||
|
||||
$$\mathop{Attention}(Q, K, V) = \mathop{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
|
||||
|
||||
In simple terms, it finds keys that matches the query, and get the values of
|
||||
those keys.
|
||||
|
||||
It uses dot-product of query and key as the indicator of how matching they are.
|
||||
Before taking the $softmax$ the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$.
|
||||
This is done to avoid large dot-product values causing softmax to
|
||||
give very small gradients when $d_k$ is large.
|
||||
|
||||
Softmax is calculate along the axis of of the sequence (or time).
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
# Number of features per head
|
||||
self.d_k = d_model // heads
|
||||
# Number of heads
|
||||
self.heads = heads
|
||||
|
||||
# These transform the `query`, `key` and `value` vectors for multi-headed attention.
|
||||
@ -87,6 +95,9 @@ class MultiHeadAttention(Module):
|
||||
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
|
||||
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
|
||||
|
||||
# Softmax for attention along the time dimension of `key`
|
||||
self.softmax = nn.Softmax(dim=1)
|
||||
|
||||
# Output layer
|
||||
self.output = nn.Linear(d_model, d_model)
|
||||
# Dropout
|
||||
@ -153,7 +164,7 @@ class MultiHeadAttention(Module):
|
||||
|
||||
# $softmax$ attention along the key sequence dimension
|
||||
# $\underset{seq}{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)$
|
||||
attn = F.softmax(scores, dim=1)
|
||||
attn = self.softmax(scores)
|
||||
|
||||
# Save attentions if debugging
|
||||
tracker.debug('attn', attn)
|
||||
|
||||
@ -51,56 +51,77 @@ class RelativeMultiHeadAttention(MultiHeadAttention):
|
||||
# The linear transformations doesn't need a bias since we take care of it when
|
||||
# calculating scores.
|
||||
# However having a bias for `value` might make sense.
|
||||
super().__init__(heads, d_model, dropout_prob, False)
|
||||
super().__init__(heads, d_model, dropout_prob, bias=False)
|
||||
|
||||
# Number of relative positions
|
||||
self.P = 2 ** 12
|
||||
|
||||
# Relative positional embeddings for key relative to the query.
|
||||
# We need $2P$ embeddings because the keys can be before or after the query.
|
||||
self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)
|
||||
self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
|
||||
# Relative positional embedding bias for key relative to the query.
|
||||
self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)
|
||||
# Positional embeddings for the query is independent of the position of the query
|
||||
self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
|
||||
|
||||
def get_scores(self, query: torch.Tensor, key: torch.Tensor):
|
||||
"""
|
||||
r"""
|
||||
### Get relative attention scores
|
||||
|
||||
With absolute attention
|
||||
|
||||
\begin{align}
|
||||
A^{abs}_{i,j} &= lin_q(X^q_i + P_i)^T lin_k(X^k_j + P_j) \\
|
||||
&= Q_i^T K_j + Q_i^T U_j + V_i^T K_j + V_i^T U_j
|
||||
A^{abs}_{j} &= lin_q(\color{cyan}{X^q_i + P_i})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
|
||||
&= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
|
||||
\color{cyan}{Q_i^T} \color{lightgreen}{U_j} +
|
||||
\color{cyan}{V_i^T} \color{lightgreen}{K_j} +
|
||||
\color{cyan}{V_i^T} \color{lightgreen}{U_j}
|
||||
\end{align}
|
||||
|
||||
where $Q_i$, $K_j$, $V_i$, and $U_j$ are linear transformations of
|
||||
orginal embeddings and positional encodings.
|
||||
where $\color{cyan}{Q_i}, \color{lightgreen}{K_j}$, are linear transformations of
|
||||
original embeddings $\color{cyan}{X^q_i}, \color{lightgreen}{X^k_j}$
|
||||
and $\color{cyan}{V_i}, \color{lightgreen}{U_j}$ are linear transformations of
|
||||
absolute positional encodings $\color{cyan}{P_i}, \color{lightgreen}{P_j}$.
|
||||
|
||||
They reason out that the attention to a given key should be the same regardless of
|
||||
the position of query. Hence replace $V_i^T K_j$ with a constant $v^T K_j$.
|
||||
They reason out that the attention to a given key should be the same regardless of
|
||||
the position of query. Hence replace $\color{cyan}{V_i^T} \color{lightgreen}{K_j}$
|
||||
with a constant $\color{orange}{v^T} \color{lightgreen}{K_j}$.
|
||||
🤔 May be worthwhile testing without this assumption.
|
||||
|
||||
For the second and third terms relative positional encodings are introduced.
|
||||
So $Q_i^T U_j$ is replaced with $Q_i^T R_{i - j}$ and $V_i^T U_j$ with $S_{i-j}$.
|
||||
So $\color{cyan}{Q_i^T} \color{lightgreen}{U_j}$ is
|
||||
replaced with $\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}$
|
||||
and $\color{cyan}{V_i^T} \color{lightgreen}{U_j}$ with $\color{orange}{S_{i-j}}$.
|
||||
|
||||
\begin{align}
|
||||
A^{rel}_{i,j} &= Q_i^T K_j + Q_i^T R_{i - j} + v^T K_j + S_{i-j}
|
||||
A^{rel}_{i,j} &= \underset{\mathbf{A}}{\color{cyan}{Q_i^T} \color{lightgreen}{K_j}} +
|
||||
\underset{\mathbf{B}}{\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}} +
|
||||
\underset{\mathbf{C}}{\color{orange}{v^T} \color{lightgreen}{K_j}} +
|
||||
\underset{\mathbf{D}}{\color{orange}{S_{i-j}}}
|
||||
\end{align}
|
||||
|
||||
"""
|
||||
|
||||
# $R_{i-j}$ pre-shift
|
||||
# $\color{orange}{R_k}$
|
||||
key_pos_emb = self.key_pos_embeddings[self.P - query.shape[0]:self.P + key.shape[0]]
|
||||
# $S_{i-j}$ pre-shift
|
||||
# $\color{orange}{S_k}$
|
||||
key_pos_bias = self.key_pos_bias[self.P - query.shape[0]:self.P + key.shape[0]]
|
||||
# $v^T$
|
||||
# $\color{orange}{v^T}$
|
||||
query_pos_bias = self.query_pos_bias[None, None, :, :]
|
||||
|
||||
# $Q_i^T K_j + v^T K_j$
|
||||
# ${(\mathbf{A} + \mathbf{C})}_{i,j} = \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
|
||||
# \color{orange}{v^T} \color{lightgreen}{K_j}$
|
||||
ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)
|
||||
# $Q_i^T R_{i - j}$ pre-shift
|
||||
# $\mathbf{B'}_{i,k} = \color{cyan}{Q_i^T} \color{orange}{R_k}$
|
||||
b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
|
||||
# $S_{i-j}$ pre-shift
|
||||
# $\mathbf{D'}_{i,k} = \color{orange}{S_k}$
|
||||
d = key_pos_bias[None, :, None, :]
|
||||
# $Q_i^T R_{i - j} + S_{i-j}$
|
||||
# Shift the rows of $\mathbf{(B' + D')}_{i,k}$
|
||||
# to get $$\mathbf{(B + D)}_{i,j} = \mathbf{(B' + D')}_{i,i - j}$$
|
||||
bd = shift_right(b + d)
|
||||
# Remove extra positions
|
||||
bd = bd[:, -key.shape[0]:]
|
||||
|
||||
# Return the sum $\mathbf{(A + B + C + D)}_{i,j}$
|
||||
return ac + bd
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user