transpose with \top

Varuna Jayasiri
2021-01-10 12:16:50 +05:30
parent 8ee70da198
commit 9102883e1f
3 changed files with 30 additions and 29 deletions

File 1 of 3

@@ -28,7 +28,8 @@ This reduces the memory used for caching during prediction.
Here's a notebook for training a feedback transformer on Tiny Shakespeare dataset.
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/hypernetworks/experiment.ipynb)
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/feedback/experiment.ipynb)
+[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=d8eb9416530a11eb8fb50242ac1c0002)
"""
import math
@@ -51,7 +52,7 @@ class FeedbackAttention(Module):
This module computes recurrent attention similar to the attention in the original transformer paper.
-$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q^T K}{\sqrt{d_k}}\Bigg)V$$
+$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q^\top K}{\sqrt{d_k}}\Bigg)V$$
"""
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
@@ -99,23 +100,23 @@ class FeedbackAttention(Module):
### Get attention scores
\begin{align}
-A_{j} &== Q^T K_j \\
-&= lin_q(\color{cyan}{X^q + P})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
-&= (\color{cyan}{Q^T + V^T})(\color{lightgreen}{K_j + U_j)}
+A_{j} &= Q^\top K_j \\
+&= lin_q(X^q + P_q)^\top lin_k(X^k_j + P_j) \\
+&= (Q + U^Q)^\top(K_j + U^K_j)
\end{align}
-where $\color{cyan}{Q}, \color{lightgreen}{K_j}$, are linear transformations of
-original embeddings $\color{cyan}{X^q}, \color{lightgreen}{X^k_j}$
-and $\color{cyan}{V}, \color{lightgreen}{U_j}$ are linear transformations of
-absolute positional encodings $\color{cyan}{P}, \color{lightgreen}{P_j}$.
+where $Q, K_j$, are linear transformations of
+original embeddings $X^q, X^k_j$
+and $U^Q, U^K_j$ are linear transformations of
+absolute positional encodings $P_q, P_j$.
"""
-# $\color{lightgreen}{U_j}$
+# $U^K_j$
key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
-# $\color{cyan}{V^T}$
+# $U^Q$
query_pos_bias = self.query_pos_bias[None, :, :]
-# $(\color{cyan}{Q^T + V^T})(\color{lightgreen}{K_j + U_j)}$
+# $(Q + U^Q)^\top(K_j + U^K_j)$
return torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key + key_pos_emb[:, None, :, :])
def __call__(self, *,
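The einsum above is the core of this hunk: a single query step attends over every stored key step, with the learned bias $U^Q$ added to the query and the absolute positional embedding $U^K_j$ added to each key. Below is a minimal, self-contained sketch of that computation with plain tensors; the sizes (`seq_len`, `batch_size`, `heads`, `d_k`) and the random inputs are illustrative assumptions, not values from the repository.

```python
import torch

# Illustrative sizes only
seq_len, batch_size, heads, d_k = 8, 4, 2, 16

query = torch.randn(batch_size, heads, d_k)           # Q for the current step only
key = torch.randn(seq_len, batch_size, heads, d_k)    # K_j for every stored step j
query_pos_bias = torch.randn(heads, d_k)              # U^Q, shared across the batch
key_pos_emb = torch.randn(seq_len, heads, d_k)        # U^K_j, one embedding per position j

# (Q + U^Q)^\top (K_j + U^K_j): contract over d_k, keep [seq_len, batch_size, heads]
scores = torch.einsum('bhd,jbhd->jbh',
                      query + query_pos_bias[None, :, :],
                      key + key_pos_emb[:, None, :, :])

assert scores.shape == (seq_len, batch_size, heads)
```

Note that, unlike the standard multi-head attention below, the query here has no sequence dimension of its own; scores are produced for a single step against all stored steps.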

File 2 of 3

@@ -64,7 +64,7 @@ class MultiHeadAttention(Module):
This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.
-$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
+$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
In simple terms, it finds keys that match the query and gets the values of those keys.
@@ -115,7 +115,7 @@ class MultiHeadAttention(Module):
This method can be overridden for other variations like relative attention.
"""
-# Calculate $Q K^T$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
+# Calculate $Q K^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
return torch.einsum('ibhd,jbhd->ijbh', query, key)
def __call__(self, *,
@@ -151,11 +151,11 @@ class MultiHeadAttention(Module):
key = self.key(key)
value = self.value(value)
-# Compute attention scores $Q K^T$
+# Compute attention scores $Q K^\top$
# Results in a tensor of shape `[seq_len, seq_len, batch_size, heads]`
scores = self.get_scores(query, key)
-# Scale scores $\frac{Q K^T}{\sqrt{d_k}}$
+# Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
scores *= self.scale
# Apply mask
@@ -163,7 +163,7 @@
scores = scores.masked_fill(mask == 0, -1e9)
# $softmax$ attention along the key sequence dimension
-# $\underset{seq}{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)$
+# $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
attn = self.softmax(scores)
# Save attentions if debugging
@@ -173,7 +173,7 @@ class MultiHeadAttention(Module):
attn = self.dropout(attn)
# Multiply by values
-# $$\underset{seq}{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
+# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# Save attentions for any other calculations
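Taken together, the hunks above walk through the standard scaled dot-product pipeline: score, scale, mask, softmax over the key dimension, and a weighted sum of the values. Here is a minimal sketch of those steps with plain tensors instead of the module's projection layers and dropout. The tensor layout `[seq_len, batch_size, heads, d_k]` and the mask convention (zero means blocked, matching the `masked_fill(mask == 0, -1e9)` line) follow the code shown; the function name and the optional-mask handling are assumptions for illustration.

```python
import math
from typing import Optional

import torch

def scaled_dot_product_attention(query: torch.Tensor,
                                 key: torch.Tensor,
                                 value: torch.Tensor,
                                 mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    # query, key, value: [seq_len, batch_size, heads, d_k]
    d_k = query.shape[-1]
    # Q K^\top, i.e. S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}
    scores = torch.einsum('ibhd,jbhd->ijbh', query, key)
    # Scale by 1 / sqrt(d_k)
    scores = scores * (1 / math.sqrt(d_k))
    # mask: [seq_len_q, seq_len_k, batch_size]; 0 marks positions that may not be attended to
    if mask is not None:
        scores = scores.masked_fill(mask[..., None] == 0, -1e9)
    # softmax along the key sequence dimension (dim=1 in the [i, j, b, h] layout)
    attn = torch.softmax(scores, dim=1)
    # Weighted sum of the values
    return torch.einsum('ijbh,jbhd->ibhd', attn, value)
```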

File 3 of 3

@@ -102,14 +102,14 @@ class RelativeMultiHeadAttention(MultiHeadAttention):
For the second and third terms relative positional encodings are introduced.
So $\underset{\color{lightgreen}{B}}{Q_i^\top U^K_j}$ is
-replaced with $\underset{\color{lightgreen}{B}}{Q_i^T \color{orange}{R_{i - j}}}$
+replaced with $\underset{\color{lightgreen}{B}}{Q_i^\top \color{orange}{R_{i - j}}}$
and $\underset{\color{lightgreen}{D}}{{U^Q_i}^\top U^K_j}$
with $\underset{\color{lightgreen}{D}}{\color{orange}{S_{i-j}}}$.
\begin{align}
-A^{rel}_{i,j} &= \underset{\mathbf{\color{lightgreen}{A}}}{Q_i^T K_j} +
-\underset{\mathbf{\color{lightgreen}{B}}}{Q_i^T \color{orange}{R_{i - j}}} +
-\underset{\mathbf{\color{lightgreen}{C}}}{\color{orange}{v^T} K_j} +
+A^{rel}_{i,j} &= \underset{\mathbf{\color{lightgreen}{A}}}{Q_i^\top K_j} +
+\underset{\mathbf{\color{lightgreen}{B}}}{Q_i^\top \color{orange}{R_{i - j}}} +
+\underset{\mathbf{\color{lightgreen}{C}}}{\color{orange}{v^\top} K_j} +
\underset{\mathbf{\color{lightgreen}{D}}}{\color{orange}{S_{i-j}}}
\end{align}
"""
@@ -118,14 +118,14 @@ class RelativeMultiHeadAttention(MultiHeadAttention):
key_pos_emb = self.key_pos_embeddings[self.P - query.shape[0]:self.P + key.shape[0]]
# $\color{orange}{S_k}$
key_pos_bias = self.key_pos_bias[self.P - query.shape[0]:self.P + key.shape[0]]
-# $\color{orange}{v^T}$
+# $\color{orange}{v^\top}$
query_pos_bias = self.query_pos_bias[None, None, :, :]
# ${(\color{lightgreen}{\mathbf{A + C}})}_{i,j} =
-# Q_i^T K_j +
-# \color{orange}{v^T} K_jZ$
+# Q_i^\top K_j +
+# \color{orange}{v^\top} K_jZ$
ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)
-# $\color{lightgreen}{\mathbf{B'}_{i,k}} = \color{cyan}{Q_i^T} \color{orange}{R_k}$
+# $\color{lightgreen}{\mathbf{B'}_{i,k}} = Q_i^\top \color{orange}{R_k}$
b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
# $\color{lightgreen}{\mathbf{D'}_{i,k}} = \color{orange}{S_k}$
d = key_pos_bias[None, :, None, :]
@@ -136,9 +136,9 @@ class RelativeMultiHeadAttention(MultiHeadAttention):
bd = bd[:, -key.shape[0]:]
# Return the sum $$
-# \underset{\mathbf{\color{lightgreen}{A}}}{Q_i^T K_j} +
-# \underset{\mathbf{\color{lightgreen}{B}}}{Q_i^T \color{orange}{R_{i - j}}} +
-# \underset{\mathbf{\color{lightgreen}{C}}}{\color{orange}{v^T} K_j} +
+# \underset{\mathbf{\color{lightgreen}{A}}}{Q_i^\top K_j} +
+# \underset{\mathbf{\color{lightgreen}{B}}}{Q_i^\top \color{orange}{R_{i - j}}} +
+# \underset{\mathbf{\color{lightgreen}{C}}}{\color{orange}{v^\top} K_j} +
# \underset{\mathbf{\color{lightgreen}{D}}}{\color{orange}{S_{i-j}}}
# $$
return ac + bd
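For readers who want to sanity-check the optimized path above (where `ac` covers terms $\color{lightgreen}{A}$ and $\color{lightgreen}{C}$, and `bd` covers $\color{lightgreen}{B}$ and $\color{lightgreen}{D}$ after the shift), here is a deliberately naive reference that indexes $\color{orange}{R_{i-j}}$ and $\color{orange}{S_{i-j}}$ directly with a double loop. The function name, the equal query/key lengths, and the offset convention (index `seq_len - 1` corresponds to $i - j = 0$) are illustrative assumptions rather than the repository's API.

```python
import torch

def relative_scores_reference(query: torch.Tensor,      # [seq_len, batch_size, heads, d_k]
                              key: torch.Tensor,        # [seq_len, batch_size, heads, d_k]
                              rel_emb: torch.Tensor,    # R: [2 * seq_len - 1, heads, d_k]
                              rel_bias: torch.Tensor,   # S: [2 * seq_len - 1, heads]
                              query_bias: torch.Tensor  # v: [heads, d_k]
                              ) -> torch.Tensor:
    seq_len, batch_size, heads, _ = query.shape
    scores = query.new_zeros(seq_len, seq_len, batch_size, heads)
    for i in range(seq_len):
        for j in range(seq_len):
            r = rel_emb[seq_len - 1 + i - j]    # R_{i-j}
            s = rel_bias[seq_len - 1 + i - j]   # S_{i-j}
            a = torch.einsum('bhd,bhd->bh', query[i], key[j])    # A: Q_i^\top K_j
            b = torch.einsum('bhd,hd->bh', query[i], r)          # B: Q_i^\top R_{i-j}
            c = torch.einsum('hd,bhd->bh', query_bias, key[j])   # C: v^\top K_j
            scores[i, j] = a + b + c + s                          # + D: S_{i-j}
    return scores
```

The einsum-based code computes the same four terms in two batched contractions and then shifts `bd` so that the column index lines up with the relative offset $i - j$, avoiding the explicit double loop.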