relative attention notes

This commit is contained in:
Varuna Jayasiri
2021-01-09 10:41:25 +05:30
parent 9f4b494bf2
commit 809a54d6aa
3 changed files with 136 additions and 82 deletions

View File

@ -18,38 +18,17 @@ from labml_nn.transformers.models import FeedForward
from labml_nn.utils import clone_module_list
class PrepareQueryForMultiHeadAttention(Module):
"""
## Prepare query for multi-head attention
This module does a linear transformation and splits the vector into given
number of heads for multi-head attention.
This is used to transform **key**, **query**, and **value** vectors.
"""
def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
super().__init__()
# Linear layer for linear transform
self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
# Number of heads
self.heads = heads
# Number of dimensions in vectors in each head
self.d_k = d_k
def __call__(self, x: torch.Tensor):
# Input has shape `[seq_len, batch_size, d_model]`
batch_size, _ = x.shape
# Linear transform
x = self.linear(x)
# Split into heads
x = x.view(batch_size, self.heads, self.d_k)
# Output has shape `[seq_len, batch_size, heads, d_k]`
return x
class FeedbackAttention(Module):
"""
## Feedback Attention
This is very similar to [Relative Multi-Head Attention](../relative_mha.html)
but with some modifications.
📝 Decided not to extend from [Relative Multi-Head Attention](../relative_mha.html)
or [Multi-Head Attention](../mha.html) to improve readability.
"""
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
super().__init__()
@ -57,7 +36,7 @@ class FeedbackAttention(Module):
self.heads = heads
# These transform the `query`, `key` and `value` vectors for multi-headed attention.
self.query = PrepareQueryForMultiHeadAttention(d_model, heads, self.d_k, False)
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, False)
@ -68,17 +47,60 @@ class FeedbackAttention(Module):
# Scaling factor before the softmax
self.scale = 1 / math.sqrt(self.d_k)
# Softmax for attention along the time dimension of `key`
self.softmax = nn.Softmax(dim=0)
# Number of relative positions
self.P = 2 ** 12
# Relative positional embeddings for key relative to the query.
self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
# Relative positional embedding bias for key relative to the query.
self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
# Positional embeddings for the query is independent of the position of the query
self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
# We store attentions so that it can used for logging, or other computations if needed
self.attn = None
self.P = 2 ** 12
self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
self.softmax = nn.Softmax(dim=0)
def get_scores(self, query: torch.Tensor, key: torch.Tensor):
"""
### Get relative attention scores
With absolute attention
\begin{align}
A^{abs}_{j} &= lin_q(\color{cyan}{X^q_i + P_i})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
&= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
\color{cyan}{Q_i^T} \color{lightgreen}{U_j} +
\color{cyan}{V_i^T} \color{lightgreen}{K_j} +
\color{cyan}{V_i^T} \color{lightgreen}{U_j}
\end{align}
where $\color{cyan}{Q_i}, \color{lightgreen}{K_j}$, are linear transformations of
original embeddings $\color{cyan}{X^q_i}, \color{lightgreen}{X^k_j}$
and $\color{cyan}{V_i}, \color{lightgreen}{U_j}$ are linear transformations of
absolute positional encodings $\color{cyan}{P_i}, \color{lightgreen}{P_j}$.
They reason out that the attention to a given key should be the same regardless of
the position of query. Hence replace $\color{cyan}{V_i^T} \color{lightgreen}{K_j}$
with a constant $\color{orange}{v^T} \color{lightgreen}{K_j}$.
🤔 May be worthwhile testing without this assumption.
For the second and third terms relative positional encodings are introduced.
So $\color{cyan}{Q_i^T} \color{lightgreen}{U_j}$ is
replaced with $\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}$
and $\color{cyan}{V_i^T} \color{lightgreen}{U_j}$ with $\color{orange}{S_{i-j}}$.
\begin{align}
A^{rel}_{i,j} &= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
\color{cyan}{Q_i^T} \color{orange}{R_{i - j}} +
\color{orange}{v^T} \color{lightgreen}{K_j} +
\color{orange}{S_{i-j}}
\end{align}
"""
# $\color{orange}{R_{i - j}}$
key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
key_pos_bias = self.key_pos_bias[-key.shape[0]:]
query_pos_bias = self.query_pos_bias[None, :, :]

View File

@ -18,10 +18,10 @@ import math
from typing import Optional
import torch
from torch import nn as nn
from labml import tracker
from labml_helpers.module import Module
from torch import nn as nn
from torch.nn import functional as F
class PrepareForMultiHeadAttention(Module):
@ -43,43 +43,51 @@ class PrepareForMultiHeadAttention(Module):
self.d_k = d_k
def __call__(self, x: torch.Tensor):
# Input has shape `[seq_len, batch_size, d_model]`
seq_len, batch_size, _ = x.shape
# Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
# We apply the linear transformation of the last dimension and splits that into
# the heads
head_shape = x.shape[:-1]
# Linear transform
x = self.linear(x)
# Split into heads
x = x.view(seq_len, batch_size, self.heads, self.d_k)
# Output has shape `[seq_len, batch_size, heads, d_k]`
# Split last dimension into heads
x = x.view(*head_shape, self.heads, self.d_k)
# Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, d_model]`
return x
class MultiHeadAttention(Module):
r"""
## Multi-Head Attention Module
This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.
$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
In simple terms, it finds keys that matches the query, and get the values of
those keys.
It uses dot-product of query and key as the indicator of how matching they are.
Before taking the $softmax$ the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$.
This is done to avoid large dot-product values causing softmax to
give very small gradients when $d_k$ is large.
Softmax is calculate along the axis of of the sequence (or time).
"""
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
"""
## Multi-Head Attention Module
* `heads` is the number of heads.
* `d_model` is the number of features in the `query`, `key` and `value` vectors.
This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.
$$\mathop{Attention}(Q, K, V) = \mathop{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
In simple terms, it finds keys that matches the query, and get the values of
those keys.
It uses dot-product of query and key as the indicator of how matching they are.
Before taking the $softmax$ the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$.
This is done to avoid large dot-product values causing softmax to
give very small gradients when $d_k$ is large.
Softmax is calculate along the axis of of the sequence (or time).
"""
super().__init__()
# Number of features per head
self.d_k = d_model // heads
# Number of heads
self.heads = heads
# These transform the `query`, `key` and `value` vectors for multi-headed attention.
@ -87,6 +95,9 @@ class MultiHeadAttention(Module):
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
# Softmax for attention along the time dimension of `key`
self.softmax = nn.Softmax(dim=1)
# Output layer
self.output = nn.Linear(d_model, d_model)
# Dropout
@ -153,7 +164,7 @@ class MultiHeadAttention(Module):
# $softmax$ attention along the key sequence dimension
# $\underset{seq}{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)$
attn = F.softmax(scores, dim=1)
attn = self.softmax(scores)
# Save attentions if debugging
tracker.debug('attn', attn)

View File

@ -51,56 +51,77 @@ class RelativeMultiHeadAttention(MultiHeadAttention):
# The linear transformations doesn't need a bias since we take care of it when
# calculating scores.
# However having a bias for `value` might make sense.
super().__init__(heads, d_model, dropout_prob, False)
super().__init__(heads, d_model, dropout_prob, bias=False)
# Number of relative positions
self.P = 2 ** 12
# Relative positional embeddings for key relative to the query.
# We need $2P$ embeddings because the keys can be before or after the query.
self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)
self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
# Relative positional embedding bias for key relative to the query.
self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)
# Positional embeddings for the query is independent of the position of the query
self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
def get_scores(self, query: torch.Tensor, key: torch.Tensor):
"""
r"""
### Get relative attention scores
With absolute attention
\begin{align}
A^{abs}_{i,j} &= lin_q(X^q_i + P_i)^T lin_k(X^k_j + P_j) \\
&= Q_i^T K_j + Q_i^T U_j + V_i^T K_j + V_i^T U_j
A^{abs}_{j} &= lin_q(\color{cyan}{X^q_i + P_i})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
&= \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
\color{cyan}{Q_i^T} \color{lightgreen}{U_j} +
\color{cyan}{V_i^T} \color{lightgreen}{K_j} +
\color{cyan}{V_i^T} \color{lightgreen}{U_j}
\end{align}
where $Q_i$, $K_j$, $V_i$, and $U_j$ are linear transformations of
orginal embeddings and positional encodings.
where $\color{cyan}{Q_i}, \color{lightgreen}{K_j}$, are linear transformations of
original embeddings $\color{cyan}{X^q_i}, \color{lightgreen}{X^k_j}$
and $\color{cyan}{V_i}, \color{lightgreen}{U_j}$ are linear transformations of
absolute positional encodings $\color{cyan}{P_i}, \color{lightgreen}{P_j}$.
They reason out that the attention to a given key should be the same regardless of
the position of query. Hence replace $V_i^T K_j$ with a constant $v^T K_j$.
They reason out that the attention to a given key should be the same regardless of
the position of query. Hence replace $\color{cyan}{V_i^T} \color{lightgreen}{K_j}$
with a constant $\color{orange}{v^T} \color{lightgreen}{K_j}$.
🤔 May be worthwhile testing without this assumption.
For the second and third terms relative positional encodings are introduced.
So $Q_i^T U_j$ is replaced with $Q_i^T R_{i - j}$ and $V_i^T U_j$ with $S_{i-j}$.
So $\color{cyan}{Q_i^T} \color{lightgreen}{U_j}$ is
replaced with $\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}$
and $\color{cyan}{V_i^T} \color{lightgreen}{U_j}$ with $\color{orange}{S_{i-j}}$.
\begin{align}
A^{rel}_{i,j} &= Q_i^T K_j + Q_i^T R_{i - j} + v^T K_j + S_{i-j}
A^{rel}_{i,j} &= \underset{\mathbf{A}}{\color{cyan}{Q_i^T} \color{lightgreen}{K_j}} +
\underset{\mathbf{B}}{\color{cyan}{Q_i^T} \color{orange}{R_{i - j}}} +
\underset{\mathbf{C}}{\color{orange}{v^T} \color{lightgreen}{K_j}} +
\underset{\mathbf{D}}{\color{orange}{S_{i-j}}}
\end{align}
"""
# $R_{i-j}$ pre-shift
# $\color{orange}{R_k}$
key_pos_emb = self.key_pos_embeddings[self.P - query.shape[0]:self.P + key.shape[0]]
# $S_{i-j}$ pre-shift
# $\color{orange}{S_k}$
key_pos_bias = self.key_pos_bias[self.P - query.shape[0]:self.P + key.shape[0]]
# $v^T$
# $\color{orange}{v^T}$
query_pos_bias = self.query_pos_bias[None, None, :, :]
# $Q_i^T K_j + v^T K_j$
# ${(\mathbf{A} + \mathbf{C})}_{i,j} = \color{cyan}{Q_i^T} \color{lightgreen}{K_j} +
# \color{orange}{v^T} \color{lightgreen}{K_j}$
ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)
# $Q_i^T R_{i - j}$ pre-shift
# $\mathbf{B'}_{i,k} = \color{cyan}{Q_i^T} \color{orange}{R_k}$
b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
# $S_{i-j}$ pre-shift
# $\mathbf{D'}_{i,k} = \color{orange}{S_k}$
d = key_pos_bias[None, :, None, :]
# $Q_i^T R_{i - j} + S_{i-j}$
# Shift the rows of $\mathbf{(B' + D')}_{i,k}$
# to get $$\mathbf{(B + D)}_{i,j} = \mathbf{(B' + D')}_{i,i - j}$$
bd = shift_right(b + d)
# Remove extra positions
bd = bd[:, -key.shape[0]:]
# Return the sum $\mathbf{(A + B + C + D)}_{i,j}$
return ac + bd