relative attention notes

This commit is contained in:
Varuna Jayasiri
2021-01-09 10:41:25 +05:30
parent 9f4b494bf2
commit 809a54d6aa
3 changed files with 136 additions and 82 deletions


@@ -18,38 +18,17 @@ from labml_nn.transformers.models import FeedForward
from labml_nn.utils import clone_module_list
class PrepareQueryForMultiHeadAttention(Module):
"""
## Prepare query for multi-head attention
This module does a linear transformation and splits the vector into given
number of heads for multi-head attention.
This is used to transform the **query** vector, which is a single step and has no sequence dimension.
"""
def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
super().__init__()
# Linear layer for linear transform
self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
# Number of heads
self.heads = heads
# Number of dimensions in vectors in each head
self.d_k = d_k
def __call__(self, x: torch.Tensor):
# Input has shape `[batch_size, d_model]`
batch_size, _ = x.shape
# Linear transform
x = self.linear(x)
# Split into heads
x = x.view(batch_size, self.heads, self.d_k)
# Output has shape `[batch_size, heads, d_k]`
return x
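A minimal shape sketch of what this (now removed) helper does, with made-up sizes; the point is that the feedback query is a single step, so there is no sequence dimension to carry around:

```python
import torch
from torch import nn

# Hypothetical sizes, for illustration only
batch_size, d_model, heads, d_k = 32, 512, 8, 64

linear = nn.Linear(d_model, heads * d_k, bias=False)

# A single-step query: no `seq_len` dimension
x = torch.randn(batch_size, d_model)
x = linear(x)                        # [batch_size, heads * d_k]
x = x.view(batch_size, heads, d_k)   # [batch_size, heads, d_k]
assert x.shape == (batch_size, heads, d_k)
```

The commit drops this class because the generalized `PrepareForMultiHeadAttention` (second file below) handles both the `[seq_len, batch_size, d_model]` and `[batch_size, d_model]` cases.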
class FeedbackAttention(Module):
"""
## Feedback Attention
This is very similar to [Relative Multi-Head Attention](../relative_mha.html)
but with some modifications.
📝 Decided not to extend from [Relative Multi-Head Attention](../relative_mha.html)
or [Multi-Head Attention](../mha.html) to improve readability.
"""
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
super().__init__()
@@ -57,7 +36,7 @@ class FeedbackAttention(Module):
self.heads = heads
# These transform the `query`, `key` and `value` vectors for multi-headed attention.
self.query = PrepareQueryForMultiHeadAttention(d_model, heads, self.d_k, False)
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
@@ -68,17 +47,60 @@ class FeedbackAttention(Module):
# Scaling factor before the softmax
self.scale = 1 / math.sqrt(self.d_k)
# Softmax for attention along the time dimension of `key`
self.softmax = nn.Softmax(dim=0)
# Number of relative positions
self.P = 2 ** 12
# Relative positional embeddings for key relative to the query.
self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
# Relative positional embedding bias for key relative to the query.
self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
# The positional bias for the query is independent of the position of the query
self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
# We store attentions so that they can be used for logging, or other computations if needed
self.attn = None
self.P = 2 ** 12
self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
self.softmax = nn.Softmax(dim=0)
def get_scores(self, query: torch.Tensor, key: torch.Tensor):
"""
### Get relative attention scores
With absolute attention
\begin{align}
A^{abs}_{i,j} &= lin_q(\color{cyan}{X^q_i + P_i})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
&= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
\color{cyan}{Q_i^T} \color{lightgreen}{U_j} +
\color{cyan}{V_i^T} \color{lightgreen}{K_j} +
\color{cyan}{V_i^T} \color{lightgreen}{U_j}
\end{align}
where $\color{cyan}{Q_i}, \color{lightgreen}{K_j}$ are linear transformations of
original embeddings $\color{cyan}{X^q_i}, \color{lightgreen}{X^k_j}$
and $\color{cyan}{V_i}, \color{lightgreen}{U_j}$ are linear transformations of
absolute positional encodings $\color{cyan}{P_i}, \color{lightgreen}{P_j}$.
They reason that the attention to a given key should be the same regardless of
the position of the query. Hence they replace $\color{cyan}{V_i^T} \color{lightgreen}{K_j}$
with a constant $\color{orange}{v^T} \color{lightgreen}{K_j}$.
🤔 May be worthwhile testing without this assumption.
For the second and fourth terms relative positional encodings are introduced.
So $\color{cyan}{Q_i^T} \color{lightgreen}{U_j}$ is
replaced with $\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}$
and $\color{cyan}{V_i^T} \color{lightgreen}{U_j}$ with $\color{orange}{S_{i-j}}$.
\begin{align}
A^{rel}_{i,j} &= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
\color{cyan}{Q_i^T} \color{orange}{R_{i - j}} +
\color{orange}{v^T} \color{lightgreen}{K_j} +
\color{orange}{S_{i-j}}
\end{align}
"""
# $\color{orange}{R_{i - j}}$
key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
key_pos_bias = self.key_pos_bias[-key.shape[0]:]
query_pos_bias = self.query_pos_bias[None, :, :]
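The hunk ends before the score computation itself, so here is a rough sketch of how these sliced tensors can combine into scores for a single-step query. Sizes are hypothetical and the einsum spelling is only an illustration of the $A^{rel}$ formula above, not necessarily the exact code in the file:

```python
import torch

# Hypothetical sizes, for illustration only
key_len, batch_size, heads, d_k, P = 10, 4, 8, 64, 2 ** 12

key_pos_embeddings = torch.zeros(P, heads, d_k)
key_pos_bias = torch.zeros(P, heads)
query_pos_bias = torch.zeros(heads, d_k)

query = torch.randn(batch_size, heads, d_k)          # single step, no sequence dimension
key = torch.randn(key_len, batch_size, heads, d_k)   # one entry per remembered step

# Every key is at or before the current step, so only the last `key_len`
# relative positions of the table are needed
key_pos_emb = key_pos_embeddings[-key_len:]          # [key_len, heads, d_k]

# $Q^T K_j + v^T K_j$: content score plus the constant query bias
ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
# $Q^T R_j + S_j$: relative-position score plus the per-position bias
bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[-key_len:, None, :]

scores = ac + bd                                      # [key_len, batch_size, heads]
attn = scores.softmax(dim=0)                          # softmax over the key (time) dimension
```

The final softmax over `dim=0` matches the `nn.Softmax(dim=0)` member added above, since the key (time) steps are the first dimension of the scores here.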


@@ -18,10 +18,10 @@ import math
from typing import Optional
import torch
from torch import nn as nn
from labml import tracker
from labml_helpers.module import Module
from torch import nn as nn
from torch.nn import functional as F
class PrepareForMultiHeadAttention(Module):
@@ -43,29 +43,28 @@ class PrepareForMultiHeadAttention(Module):
self.d_k = d_k
def __call__(self, x: torch.Tensor):
# Input has shape `[seq_len, batch_size, d_model]`
# Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
seq_len, batch_size, _ = x.shape
# We apply the linear transformation to the last dimension and split that into
# the heads
head_shape = x.shape[:-1]
# Linear transform
x = self.linear(x)
# Split into heads
x = x.view(seq_len, batch_size, self.heads, self.d_k)
# Output has shape `[seq_len, batch_size, heads, d_k]`
# Split last dimension into heads
x = x.view(*head_shape, self.heads, self.d_k)
# Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_k]`
return x
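A small standalone sketch (made-up sizes) of the reshape the rewritten `__call__` performs; keeping `head_shape = x.shape[:-1]` is what lets the same class serve both a full sequence and a single step:

```python
import torch
from torch import nn

# Hypothetical sizes, for illustration only
seq_len, batch_size, d_model, heads = 6, 4, 512, 8
d_k = d_model // heads

linear = nn.Linear(d_model, heads * d_k, bias=True)

def split_heads(x: torch.Tensor) -> torch.Tensor:
    # Keep every dimension except the last, transform the last, then split it into heads
    head_shape = x.shape[:-1]
    x = linear(x)
    return x.view(*head_shape, heads, d_k)

# A full sequence: [seq_len, batch_size, d_model] -> [seq_len, batch_size, heads, d_k]
assert split_heads(torch.randn(seq_len, batch_size, d_model)).shape == (seq_len, batch_size, heads, d_k)
# A single step (as in feedback attention): [batch_size, d_model] -> [batch_size, heads, d_k]
assert split_heads(torch.randn(batch_size, d_model)).shape == (batch_size, heads, d_k)
```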
class MultiHeadAttention(Module):
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
r"""
"""
## Multi-Head Attention Module
* `heads` is the number of heads.
* `d_model` is the number of features in the `query`, `key` and `value` vectors.
This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.
$$\mathop{Attention}(Q, K, V) = \mathop{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
In simple terms, it finds keys that match the query and gets the values of
those keys.
@@ -78,8 +77,17 @@ class MultiHeadAttention(Module):
Softmax is calculated along the axis of the sequence (or time).
"""
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
"""
* `heads` is the number of heads.
* `d_model` is the number of features in the `query`, `key` and `value` vectors.
"""
super().__init__()
# Number of features per head
self.d_k = d_model // heads
# Number of heads
self.heads = heads
# These transform the `query`, `key` and `value` vectors for multi-headed attention.
@@ -87,6 +95,9 @@ class MultiHeadAttention(Module):
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
# Softmax for attention along the time dimension of `key`
self.softmax = nn.Softmax(dim=1)
# Output layer
self.output = nn.Linear(d_model, d_model)
# Dropout
@@ -153,7 +164,7 @@ class MultiHeadAttention(Module):
# $softmax$ attention along the key sequence dimension
# $\underset{seq}{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)$
attn = F.softmax(scores, dim=1)
attn = self.softmax(scores)
# Save attentions if debugging
tracker.debug('attn', attn)
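A quick sanity check (made-up sizes) that the new `nn.Softmax(dim=1)` member behaves the same as the `F.softmax(scores, dim=1)` call it replaces, normalizing over the key positions of a `[query_len, key_len, batch_size, heads]` score tensor:

```python
import torch
from torch import nn
from torch.nn import functional as F

# Hypothetical sizes, for illustration only
query_len, key_len, batch_size, heads = 5, 7, 2, 4

# `get_scores` produces one score per (query step, key step, batch, head)
scores = torch.randn(query_len, key_len, batch_size, heads)

softmax = nn.Softmax(dim=1)   # dim 1 is the key sequence dimension
attn = softmax(scores)

# Same result as the functional form used before this change
assert torch.allclose(attn, F.softmax(scores, dim=1))
# Each query position's weights over the keys sum to one
assert torch.allclose(attn.sum(dim=1), torch.ones(query_len, batch_size, heads))
```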


@@ -51,56 +51,77 @@ class RelativeMultiHeadAttention(MultiHeadAttention):
# The linear transformations don't need a bias since we take care of it when
# calculating scores.
# However, having a bias for `value` might make sense.
super().__init__(heads, d_model, dropout_prob, False)
super().__init__(heads, d_model, dropout_prob, bias=False)
# Number of relative positions
self.P = 2 ** 12
# Relative positional embeddings for key relative to the query.
# We need $2P$ embeddings because the keys can be before or after the query.
self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)
self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
# Relative positional embedding bias for key relative to the query.
self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)
# The positional bias for the query is independent of the position of the query
self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
def get_scores(self, query: torch.Tensor, key: torch.Tensor):
""" r"""
### Get relative attention scores
With absolute attention
\begin{align}
A^{abs}_{i,j} &= lin_q(X^q_i + P_i)^T lin_k(X^k_j + P_j) \\
&= Q_i^T K_j + Q_i^T U_j + V_i^T K_j + V_i^T U_j
A^{abs}_{i,j} &= lin_q(\color{cyan}{X^q_i + P_i})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
&= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
\color{cyan}{Q_i^T} \color{lightgreen}{U_j} +
\color{cyan}{V_i^T} \color{lightgreen}{K_j} +
\color{cyan}{V_i^T} \color{lightgreen}{U_j}
\end{align}
where $Q_i$, $K_j$, $V_i$, and $U_j$ are linear transformations of
original embeddings and positional encodings.
where $\color{cyan}{Q_i}, \color{lightgreen}{K_j}$ are linear transformations of
original embeddings $\color{cyan}{X^q_i}, \color{lightgreen}{X^k_j}$
and $\color{cyan}{V_i}, \color{lightgreen}{U_j}$ are linear transformations of
absolute positional encodings $\color{cyan}{P_i}, \color{lightgreen}{P_j}$.
They reason that the attention to a given key should be the same regardless of
the position of the query. Hence they replace $V_i^T K_j$ with a constant $v^T K_j$.
the position of the query. Hence they replace $\color{cyan}{V_i^T} \color{lightgreen}{K_j}$
with a constant $\color{orange}{v^T} \color{lightgreen}{K_j}$.
🤔 May be worthwhile testing without this assumption.
For the second and fourth terms relative positional encodings are introduced.
So $Q_i^T U_j$ is replaced with $Q_i^T R_{i - j}$ and $V_i^T U_j$ with $S_{i-j}$.
So $\color{cyan}{Q_i^T} \color{lightgreen}{U_j}$ is
replaced with $\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}$
and $\color{cyan}{V_i^T} \color{lightgreen}{U_j}$ with $\color{orange}{S_{i-j}}$.
\begin{align}
A^{rel}_{i,j} &= Q_i^T K_j + Q_i^T R_{i - j} + v^T K_j + S_{i-j}
A^{rel}_{i,j} &= \underset{\mathbf{A}}{\color{cyan}{Q_i^T} \color{lightgreen}{K_j}} +
\underset{\mathbf{B}}{\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}} +
\underset{\mathbf{C}}{\color{orange}{v^T} \color{lightgreen}{K_j}} +
\underset{\mathbf{D}}{\color{orange}{S_{i-j}}}
\end{align}
"""
# $R_{i-j}$ pre-shift
# $\color{orange}{R_k}$
key_pos_emb = self.key_pos_embeddings[self.P - query.shape[0]:self.P + key.shape[0]]
# $S_{i-j}$ pre-shift
# $\color{orange}{S_k}$
key_pos_bias = self.key_pos_bias[self.P - query.shape[0]:self.P + key.shape[0]]
# $v^T$
# $\color{orange}{v^T}$
query_pos_bias = self.query_pos_bias[None, None, :, :]
# $Q_i^T K_j + v^T K_j$
# ${(\mathbf{A} + \mathbf{C})}_{i,j} = \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
# \color{orange}{v^T} \color{lightgreen}{K_j}$
ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)
# $Q_i^T R_{i - j}$ pre-shift
# $\mathbf{B'}_{i,k} = \color{cyan}{Q_i^T} \color{orange}{R_k}$
b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
# $S_{i-j}$ pre-shift
# $\mathbf{D'}_{i,k} = \color{orange}{S_k}$
d = key_pos_bias[None, :, None, :]
# $Q_i^T R_{i - j} + S_{i-j}$
# Shift the rows of $\mathbf{(B' + D')}_{i,k}$
# to get $$\mathbf{(B + D)}_{i,j} = \mathbf{(B' + D')}_{i,i - j}$$
bd = shift_right(b + d)
# Remove extra positions
bd = bd[:, -key.shape[0]:]
# Return the sum $\mathbf{(A + B + C + D)}_{i,j}$
return ac + bd
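As a reading aid, here is a deliberately naive reference that spells out $A^{rel}_{i,j}$ term by term. It is not the repository's implementation (sizes and the relative-table layout are made up, and it loops where the code above uses two einsums plus `shift_right`), but it shows what `ac` (terms $\mathbf{A} + \mathbf{C}$) and the shifted `bd` (terms $\mathbf{B} + \mathbf{D}$) are meant to add up to:

```python
import torch

# Hypothetical sizes and a made-up relative-table layout, for illustration only
seq_len, batch_size, heads, d_k = 5, 2, 4, 8

query = torch.randn(seq_len, batch_size, heads, d_k)
key = torch.randn(seq_len, batch_size, heads, d_k)

# Tables indexed by relative distance i - j, from -(seq_len - 1) to seq_len - 1;
# row (i - j) + (seq_len - 1) holds the entry for distance i - j
R = torch.randn(2 * seq_len - 1, heads, d_k)   # relative positional embeddings
S = torch.randn(2 * seq_len - 1, heads)        # relative positional bias
v = torch.randn(heads, d_k)                    # constant that replaces $V_i$

scores = torch.zeros(seq_len, seq_len, batch_size, heads)
for i in range(seq_len):
    for j in range(seq_len):
        r = (i - j) + (seq_len - 1)
        # $Q_i^T K_j + Q_i^T R_{i-j} + v^T K_j + S_{i-j}$
        scores[i, j] = (torch.einsum('bhd,bhd->bh', query[i], key[j])
                        + torch.einsum('bhd,hd->bh', query[i], R[r])
                        + torch.einsum('bhd,hd->bh', key[j], v)
                        + S[r])
```

The vectorized version avoids the per-pair table lookups: `ac` folds the constant $v$ into the query before a single einsum against the keys, and the row shift realigns `b + d` so that column $j$ of row $i$ picks out the entry for relative position $i - j$.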