transpose with \top

Varuna Jayasiri
2021-01-10 12:16:50 +05:30
parent 8ee70da198
commit 9102883e1f
3 changed files with 30 additions and 29 deletions

View File

@@ -28,7 +28,8 @@ This reduces the memory used for caching during prediction.
Here's a notebook for training a feedback transformer on Tiny Shakespeare dataset.
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/hypernetworks/experiment.ipynb)
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/feedback/experiment.ipynb)
+[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=d8eb9416530a11eb8fb50242ac1c0002)
"""
import math
@@ -51,7 +52,7 @@ class FeedbackAttention(Module):
This module computes recurrent attention similar to attention from original transformers
paper.
-$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q^T K}{\sqrt{d_k}}\Bigg)V$$
+$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q^\top K}{\sqrt{d_k}}\Bigg)V$$
"""
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
@@ -99,23 +100,23 @@ class FeedbackAttention(Module):
### Get attention scores
\begin{align}
-A_{j} &== Q^T K_j \\
+A_{j} &= Q^\top K_j \\
-&= lin_q(\color{cyan}{X^q + P})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
+&= lin_q(X^q + P_q)^\top lin_k(X^k_j + P_j) \\
-&= (\color{cyan}{Q^T + V^T})(\color{lightgreen}{K_j + U_j)}
+&= (Q + U^Q)^\top(K_j + U^K_j)
\end{align}
-where $\color{cyan}{Q}, \color{lightgreen}{K_j}$, are linear transformations of
+where $Q, K_j$, are linear transformations of
-original embeddings $\color{cyan}{X^q}, \color{lightgreen}{X^k_j}$
+original embeddings $X^q, X^k_j$
-and $\color{cyan}{V}, \color{lightgreen}{U_j}$ are linear transformations of
+and $U^Q, U^K_j$ are linear transformations of
-absolute positional encodings $\color{cyan}{P}, \color{lightgreen}{P_j}$.
+absolute positional encodings $P_q, P_j$.
"""
-# $\color{lightgreen}{U_j}$
+# $U^K_j$
key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
-# $\color{cyan}{V^T}$
+# $U^Q$
query_pos_bias = self.query_pos_bias[None, :, :]
-# $(\color{cyan}{Q^T + V^T})(\color{lightgreen}{K_j + U_j)}$
+# $(Q + U^Q)^\top(K_j + U^K_j)$
return torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key + key_pos_emb[:, None, :, :])
def __call__(self, *,
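The einsum quoted in this hunk computes $(Q + U^Q)^\top(K_j + U^K_j)$ for a single query step against every cached key. A minimal sketch of what `einsum('bhd,jbhd->jbh', ...)` does, assuming shapes inferred from the index letters (these shapes and tensor names are illustrative, not taken from the repository):

```python
import torch

# Assumed shapes, following the subscripts in 'bhd,jbhd->jbh':
#   query:          [batch, heads, d_k]        - a single query step
#   key:            [seq, batch, heads, d_k]   - one key per previous step
#   key_pos_emb:    [seq, heads, d_k]          - positional embeddings U^K_j
#   query_pos_bias: [heads, d_k]               - bias U^Q shared across positions
seq, batch, heads, d_k = 10, 2, 4, 16
query = torch.randn(batch, heads, d_k)
key = torch.randn(seq, batch, heads, d_k)
key_pos_emb = torch.randn(seq, heads, d_k)
query_pos_bias = torch.randn(heads, d_k)

# (Q + U^Q)^\top (K_j + U^K_j) for every step j, batch and head at once
scores = torch.einsum('bhd,jbhd->jbh',
                      query + query_pos_bias[None, :, :],
                      key + key_pos_emb[:, None, :, :])
assert scores.shape == (seq, batch, heads)

# The same value computed as an explicit dot product for one (j, b, h) triple
j, b, h = 3, 1, 2
manual = torch.dot(query[b, h] + query_pos_bias[h], key[j, b, h] + key_pos_emb[j, h])
assert torch.allclose(scores[j, b, h], manual, atol=1e-5)
```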

View File

@@ -64,7 +64,7 @@ class MultiHeadAttention(Module):
This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.
-$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
+$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
In simple terms, it finds keys that matches the query, and get the values of
those keys.
@@ -115,7 +115,7 @@ class MultiHeadAttention(Module):
This method can be overridden for other variations like relative attention.
"""
-# Calculate $Q K^T$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
+# Calculate $Q K^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
return torch.einsum('ibhd,jbhd->ijbh', query, key)
def __call__(self, *,
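The comment touched by this hunk describes `einsum('ibhd,jbhd->ijbh', ...)` as $Q K^\top$. A minimal sketch (assumed shapes, not the repository's test code) checking that the einsum really is a per-batch, per-head matrix product with the keys transposed:

```python
import torch

seq_q, seq_k, batch, heads, d_k = 5, 7, 2, 4, 16
query = torch.randn(seq_q, batch, heads, d_k)
key = torch.randn(seq_k, batch, heads, d_k)

# S_{ijbh} = sum_d Q_{ibhd} K_{jbhd}, shape [seq_q, seq_k, batch, heads]
scores = torch.einsum('ibhd,jbhd->ijbh', query, key)

# Equivalent matmul: move batch and heads to the front, then Q @ K^T
q = query.permute(1, 2, 0, 3)                              # [batch, heads, seq_q, d_k]
k = key.permute(1, 2, 0, 3)                                # [batch, heads, seq_k, d_k]
expected = (q @ k.transpose(-2, -1)).permute(2, 3, 0, 1)   # back to [seq_q, seq_k, batch, heads]
assert torch.allclose(scores, expected, atol=1e-5)
```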
@@ -151,11 +151,11 @@ class MultiHeadAttention(Module):
key = self.key(key)
value = self.value(value)
-# Compute attention scores $Q K^T$
+# Compute attention scores $Q K^\top$
# Results in a tensor of shape `[seq_len, seq_len, batch_size, heads]`
scores = self.get_scores(query, key)
-# Scale scores $\frac{Q K^T}{\sqrt{d_k}}$
+# Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
scores *= self.scale
# Apply mask
@@ -163,7 +163,7 @@ class MultiHeadAttention(Module):
scores = scores.masked_fill(mask == 0, -1e9)
# $softmax$ attention along the key sequence dimension
-# $\underset{seq}{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)$
+# $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
attn = self.softmax(scores)
# Save attentions if debugging
@@ -173,7 +173,7 @@ class MultiHeadAttention(Module):
attn = self.dropout(attn)
# Multiply by values
-# $$\underset{seq}{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
+# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# Save attentions for any other calculations
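Taken together, the hunks in this file walk through the full $\underset{seq}{softmax}\big(\frac{Q K^\top}{\sqrt{d_k}}\big)V$ pipeline. A minimal end-to-end sketch with made-up shapes (the real module also projects `query`/`key`/`value` with learned linear layers and uses an optional mask):

```python
import math
import torch

seq_q, seq_k, batch, heads, d_k = 5, 7, 2, 4, 16
query = torch.randn(seq_q, batch, heads, d_k)
key = torch.randn(seq_k, batch, heads, d_k)
value = torch.randn(seq_k, batch, heads, d_k)
mask = torch.ones(seq_q, seq_k, batch, dtype=torch.bool)   # all positions allowed here

# Q K^\top, shape [seq_q, seq_k, batch, heads]
scores = torch.einsum('ibhd,jbhd->ijbh', query, key)
# Scale by 1 / sqrt(d_k)
scores = scores * (1 / math.sqrt(d_k))
# Mask out disallowed positions before the softmax
scores = scores.masked_fill(mask.unsqueeze(-1) == 0, -1e9)
# Softmax along the key sequence dimension (dim=1)
attn = torch.softmax(scores, dim=1)
# Weighted sum of values, shape [seq_q, batch, heads, d_k]
x = torch.einsum('ijbh,jbhd->ibhd', attn, value)
assert x.shape == (seq_q, batch, heads, d_k)
```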

View File

@@ -102,14 +102,14 @@ class RelativeMultiHeadAttention(MultiHeadAttention):
For the second and third terms relative positional encodings are introduced.
So $\underset{\color{lightgreen}{B}}{Q_i^\top U^K_j}$ is
-replaced with $\underset{\color{lightgreen}{B}}{Q_i^T \color{orange}{R_{i - j}}}$
+replaced with $\underset{\color{lightgreen}{B}}{Q_i^\top \color{orange}{R_{i - j}}}$
and $\underset{\color{lightgreen}{D}}{{U^Q_i}^\top U^K_j}$
with $\underset{\color{lightgreen}{D}}{\color{orange}{S_{i-j}}}$.
\begin{align}
-A^{rel}_{i,j} &= \underset{\mathbf{\color{lightgreen}{A}}}{Q_i^T K_j} +
+A^{rel}_{i,j} &= \underset{\mathbf{\color{lightgreen}{A}}}{Q_i^\top K_j} +
-\underset{\mathbf{\color{lightgreen}{B}}}{Q_i^T \color{orange}{R_{i - j}}} +
+\underset{\mathbf{\color{lightgreen}{B}}}{Q_i^\top \color{orange}{R_{i - j}}} +
-\underset{\mathbf{\color{lightgreen}{C}}}{\color{orange}{v^T} K_j} +
+\underset{\mathbf{\color{lightgreen}{C}}}{\color{orange}{v^\top} K_j} +
\underset{\mathbf{\color{lightgreen}{D}}}{\color{orange}{S_{i-j}}}
\end{align}
"""
@@ -118,14 +118,14 @@ class RelativeMultiHeadAttention(MultiHeadAttention):
key_pos_emb = self.key_pos_embeddings[self.P - query.shape[0]:self.P + key.shape[0]]
# $\color{orange}{S_k}$
key_pos_bias = self.key_pos_bias[self.P - query.shape[0]:self.P + key.shape[0]]
-# $\color{orange}{v^T}$
+# $\color{orange}{v^\top}$
query_pos_bias = self.query_pos_bias[None, None, :, :]
# ${(\color{lightgreen}{\mathbf{A + C}})}_{i,j} =
-# Q_i^T K_j +
+# Q_i^\top K_j +
-# \color{orange}{v^T} K_jZ$
+# \color{orange}{v^\top} K_jZ$
ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)
-# $\color{lightgreen}{\mathbf{B'}_{i,k}} = \color{cyan}{Q_i^T} \color{orange}{R_k}$
+# $\color{lightgreen}{\mathbf{B'}_{i,k}} = Q_i^\top \color{orange}{R_k}$
b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
# $\color{lightgreen}{\mathbf{D'}_{i,k}} = \color{orange}{S_k}$
d = key_pos_bias[None, :, None, :]
@@ -136,9 +136,9 @@ class RelativeMultiHeadAttention(MultiHeadAttention):
bd = bd[:, -key.shape[0]:]
# Return the sum $$
-# \underset{\mathbf{\color{lightgreen}{A}}}{Q_i^T K_j} +
+# \underset{\mathbf{\color{lightgreen}{A}}}{Q_i^\top K_j} +
-# \underset{\mathbf{\color{lightgreen}{B}}}{Q_i^T \color{orange}{R_{i - j}}} +
+# \underset{\mathbf{\color{lightgreen}{B}}}{Q_i^\top \color{orange}{R_{i - j}}} +
-# \underset{\mathbf{\color{lightgreen}{C}}}{\color{orange}{v^T} K_j} +
+# \underset{\mathbf{\color{lightgreen}{C}}}{\color{orange}{v^\top} K_j} +
# \underset{\mathbf{\color{lightgreen}{D}}}{\color{orange}{S_{i-j}}}
# $$
return ac + bd
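A minimal sketch of how the four relative-attention terms combine into `ac + bd`, with invented shapes; the repository's relative-shift step (which lines `bd` up so column $k$ corresponds to distance $i - j$) is deliberately left out, so this only illustrates the two einsums and the broadcast of $S_k$ shown above:

```python
import torch

seq_q, seq_k, batch, heads, d_k = 5, 5, 2, 4, 16
query = torch.randn(seq_q, batch, heads, d_k)
key = torch.randn(seq_k, batch, heads, d_k)
key_pos_emb = torch.randn(seq_k, heads, d_k)    # R_k, assumed already sliced to the window
key_pos_bias = torch.randn(seq_k, heads)        # S_k
query_pos_bias = torch.randn(heads, d_k)        # v

# A + C: (Q_i + v)^\top K_j
ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias[None, None, :, :], key)
# B': Q_i^\top R_k
b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
# D': S_k, broadcast over query positions and batch
d = key_pos_bias[None, :, None, :]
bd = b + d                                      # [seq_q, seq_k, batch, heads]
# (The real module shifts `bd` here before slicing with bd[:, -key.shape[0]:].)
scores = ac + bd
assert scores.shape == (seq_q, seq_k, batch, heads)
```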