Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-08-26 16:50:39 +08:00)

Commit: transpose with \top
@@ -28,7 +28,8 @@ This reduces the memory used for caching during prediction.
 
 Here's a notebook for training a feedback transformer on the Tiny Shakespeare dataset.
 
-[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/hypernetworks/experiment.ipynb)
+[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/feedback/experiment.ipynb)
+[](https://web.lab-ml.com/run?uuid=d8eb9416530a11eb8fb50242ac1c0002)
 """
 
 import math
@@ -51,7 +52,7 @@ class FeedbackAttention(Module):
 This module computes recurrent attention similar to the attention in the original transformer
 paper.
 
-$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q^T K}{\sqrt{d_k}}\Bigg)V$$
+$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q^\top K}{\sqrt{d_k}}\Bigg)V$$
 """
 
 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
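For reference, the formula above differs from standard self-attention in that the query is a single step attending over all memory steps. Below is a minimal sketch of that computation with hypothetical toy shapes; it illustrates the formula only and is not the `FeedbackAttention` module itself.

```python
import torch

seq_len, batch, heads, d_k = 5, 2, 4, 16
query = torch.randn(batch, heads, d_k)            # one step:  [batch, heads, d_k]
key = torch.randn(seq_len, batch, heads, d_k)     # memory:    [seq, batch, heads, d_k]
value = torch.randn(seq_len, batch, heads, d_k)

# Q^\top K_j for every memory step j -> scores of shape [seq, batch, heads]
scores = torch.einsum('bhd,jbhd->jbh', query, key) / d_k ** 0.5
# softmax along the (key) sequence dimension
attn = scores.softmax(dim=0)
# weighted sum of the values -> [batch, heads, d_k]
out = torch.einsum('jbh,jbhd->bhd', attn, value)
print(out.shape)  # torch.Size([2, 4, 16])
```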
@@ -99,23 +100,23 @@ class FeedbackAttention(Module):
 ### Get attention scores
 
 \begin{align}
-A_{j} &== Q^T K_j \\
-&= lin_q(\color{cyan}{X^q + P})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
-&= (\color{cyan}{Q^T + V^T})(\color{lightgreen}{K_j + U_j)}
+A_{j} &= Q^\top K_j \\
+&= lin_q(X^q + P_q)^\top lin_k(X^k_j + P_j) \\
+&= (Q + U^Q)^\top(K_j + U^K_j)
 \end{align}
 
-where $\color{cyan}{Q}, \color{lightgreen}{K_j}$, are linear transformations of
-original embeddings $\color{cyan}{X^q}, \color{lightgreen}{X^k_j}$
-and $\color{cyan}{V}, \color{lightgreen}{U_j}$ are linear transformations of
-absolute positional encodings $\color{cyan}{P}, \color{lightgreen}{P_j}$.
+where $Q, K_j$ are linear transformations of
+original embeddings $X^q, X^k_j$
+and $U^Q, U^K_j$ are linear transformations of
+absolute positional encodings $P_q, P_j$.
 """
 
-# $\color{lightgreen}{U_j}$
+# $U^K_j$
 key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
-# $\color{cyan}{V^T}$
+# $U^Q$
 query_pos_bias = self.query_pos_bias[None, :, :]
 
-# $(\color{cyan}{Q^T + V^T})(\color{lightgreen}{K_j + U_j)}$
+# $(Q + U^Q)^\top(K_j + U^K_j)$
 return torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key + key_pos_emb[:, None, :, :])
 
 def __call__(self, *,
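A shape-level sketch of the score computation in this hunk: the tensor sizes and the size of the positional-embedding table are assumptions for illustration; only the slicing, broadcasting, and the einsum mirror the snippet.

```python
import torch

seq_len, batch, heads, d_k = 5, 2, 4, 16
query = torch.randn(batch, heads, d_k)                     # current step: [batch, heads, d_k]
key = torch.randn(seq_len, batch, heads, d_k)              # memory:       [seq, batch, heads, d_k]

key_pos_embeddings = torch.randn(2 * seq_len, heads, d_k)  # assumed size of the U^K table
query_pos_bias = torch.randn(heads, d_k)                   # U^Q

key_pos_emb = key_pos_embeddings[-key.shape[0]:]           # U^K_j: [seq, heads, d_k]
q_bias = query_pos_bias[None, :, :]                        # U^Q broadcast over the batch

# (Q + U^Q)^\top (K_j + U^K_j) -> scores of shape [seq, batch, heads]
scores = torch.einsum('bhd,jbhd->jbh', query + q_bias, key + key_pos_emb[:, None, :, :])
print(scores.shape)  # torch.Size([5, 2, 4])
```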
@@ -64,7 +64,7 @@ class MultiHeadAttention(Module):
 
 This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.
 
-$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
+$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
 
 In simple terms, it finds keys that match the query and gets the values of
 those keys.
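As a reference for the formula above, here is a minimal single-head sketch of scaled dot-product attention. It is illustrative only; the module itself operates on multi-head `[seq_len, batch_size, heads, d_k]` tensors.

```python
import torch

def scaled_dot_product_attention(q, k, v):
    # q, k, v: [seq_len, d_k] for a single head and batch element
    d_k = q.shape[-1]
    scores = q @ k.transpose(0, 1) / d_k ** 0.5   # Q K^\top / sqrt(d_k): [seq, seq]
    attn = scores.softmax(dim=-1)                 # softmax over the key positions
    return attn @ v                               # weighted sum of values: [seq, d_k]

q = torch.randn(7, 16)
k = torch.randn(7, 16)
v = torch.randn(7, 16)
print(scaled_dot_product_attention(q, k, v).shape)  # torch.Size([7, 16])
```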
@@ -115,7 +115,7 @@ class MultiHeadAttention(Module):
 This method can be overridden for other variations like relative attention.
 """
 
-# Calculate $Q K^T$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
+# Calculate $Q K^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
 return torch.einsum('ibhd,jbhd->ijbh', query, key)
 
 def __call__(self, *,
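A quick check, with made-up sizes, that the einsum in `get_scores` is just a batched $Q K^\top$ over the batch and head axes.

```python
import torch

i, j, b, h, d = 7, 9, 2, 4, 16
query = torch.randn(i, b, h, d)
key = torch.randn(j, b, h, d)

scores = torch.einsum('ibhd,jbhd->ijbh', query, key)
print(scores.shape)  # torch.Size([7, 9, 2, 4])

# Equivalent (and less readable) with matmul after moving the batch/head axes first
alt = (query.permute(1, 2, 0, 3) @ key.permute(1, 2, 3, 0)).permute(2, 3, 0, 1)
print(torch.allclose(scores, alt, atol=1e-5))  # True
```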
@@ -151,11 +151,11 @@ class MultiHeadAttention(Module):
 key = self.key(key)
 value = self.value(value)
 
-# Compute attention scores $Q K^T$
+# Compute attention scores $Q K^\top$
 # Results in a tensor of shape `[seq_len, seq_len, batch_size, heads]`
 scores = self.get_scores(query, key)
 
-# Scale scores $\frac{Q K^T}{\sqrt{d_k}}$
+# Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
 scores *= self.scale
 
 # Apply mask
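A small sketch of the scale-and-mask steps with a causal mask. The mask layout here is an assumption chosen so that it broadcasts against scores of shape `[seq_len, seq_len, batch_size, heads]`.

```python
import torch

seq_len, batch, heads, d_k = 6, 2, 4, 16
scores = torch.randn(seq_len, seq_len, batch, heads)   # stand-in for get_scores output
scale = 1 / d_k ** 0.5

scores = scores * scale
# Lower-triangular causal mask: query position i may only attend to key positions j <= i
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))[:, :, None, None]
scores = scores.masked_fill(mask == 0, -1e9)           # blocked positions get ~ -inf
print(scores[0, 1:, 0, 0])                             # masked entries are -1e9
```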
@@ -163,7 +163,7 @@ class MultiHeadAttention(Module):
 scores = scores.masked_fill(mask == 0, -1e9)
 
 # $softmax$ attention along the key sequence dimension
-# $\underset{seq}{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)$
+# $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
 attn = self.softmax(scores)
 
 # Save attentions if debugging
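Given the `[seq_len, seq_len, batch_size, heads]` layout of the scores, the softmax along the key sequence dimension is a softmax over `dim=1`. A toy check:

```python
import torch
import torch.nn as nn

scores = torch.randn(6, 6, 2, 4)           # [query_seq, key_seq, batch, heads]
attn = nn.Softmax(dim=1)(scores)           # normalize over the key positions
print(attn.sum(dim=1)[0, 0, 0].item())     # ~1.0 for every (query, batch, head)
```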
@@ -173,7 +173,7 @@ class MultiHeadAttention(Module):
 attn = self.dropout(attn)
 
 # Multiply by values
-# $$\underset{seq}{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
+# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
 x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
 
 # Save attentions for any other calculations
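A shape check, with toy sizes, for the weighted sum over values: `'ijbh,jbhd->ibhd'` sums the value vectors over the key positions `j`.

```python
import torch

i, j, b, h, d = 7, 9, 2, 4, 16
attn = torch.softmax(torch.randn(i, j, b, h), dim=1)
value = torch.randn(j, b, h, d)
x = torch.einsum('ijbh,jbhd->ibhd', attn, value)
print(x.shape)  # torch.Size([7, 2, 4, 16])
```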
@@ -102,14 +102,14 @@ class RelativeMultiHeadAttention(MultiHeadAttention):
 
 For the second and third terms relative positional encodings are introduced.
 So $\underset{\color{lightgreen}{B}}{Q_i^\top U^K_j}$ is
-replaced with $\underset{\color{lightgreen}{B}}{Q_i^T \color{orange}{R_{i - j}}}$
+replaced with $\underset{\color{lightgreen}{B}}{Q_i^\top \color{orange}{R_{i - j}}}$
 and $\underset{\color{lightgreen}{D}}{{U^Q_i}^\top U^K_j}$
 with $\underset{\color{lightgreen}{D}}{\color{orange}{S_{i-j}}}$.
 
 \begin{align}
-A^{rel}_{i,j} &= \underset{\mathbf{\color{lightgreen}{A}}}{Q_i^T K_j} +
-\underset{\mathbf{\color{lightgreen}{B}}}{Q_i^T \color{orange}{R_{i - j}}} +
-\underset{\mathbf{\color{lightgreen}{C}}}{\color{orange}{v^T} K_j} +
+A^{rel}_{i,j} &= \underset{\mathbf{\color{lightgreen}{A}}}{Q_i^\top K_j} +
+\underset{\mathbf{\color{lightgreen}{B}}}{Q_i^\top \color{orange}{R_{i - j}}} +
+\underset{\mathbf{\color{lightgreen}{C}}}{\color{orange}{v^\top} K_j} +
 \underset{\mathbf{\color{lightgreen}{D}}}{\color{orange}{S_{i-j}}}
 \end{align}
 """
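A naive, per-pair sketch of the four-term decomposition above, with hypothetical toy sizes. Here $R_{i-j}$ and $S_{i-j}$ are tables indexed by the relative distance, shifted so the index stays non-negative; this is for clarity only and is not the vectorized computation the module uses.

```python
import torch

seq_len, d_k = 4, 8
Q = torch.randn(seq_len, d_k)          # queries (one batch element, one head)
K = torch.randn(seq_len, d_k)          # keys
R = torch.randn(2 * seq_len, d_k)      # relative position embeddings R_{i-j}
S = torch.randn(2 * seq_len)           # relative position biases S_{i-j}
v = torch.randn(d_k)                   # global query bias

A_rel = torch.empty(seq_len, seq_len)
for i in range(seq_len):
    for j in range(seq_len):
        r = R[i - j + seq_len]         # shift the index so it is non-negative
        s = S[i - j + seq_len]
        # A            + B          + C          + D
        A_rel[i, j] = Q[i] @ K[j] + Q[i] @ r + v @ K[j] + s
print(A_rel.shape)  # torch.Size([4, 4])
```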
@@ -118,14 +118,14 @@ class RelativeMultiHeadAttention(MultiHeadAttention):
 key_pos_emb = self.key_pos_embeddings[self.P - query.shape[0]:self.P + key.shape[0]]
 # $\color{orange}{S_k}$
 key_pos_bias = self.key_pos_bias[self.P - query.shape[0]:self.P + key.shape[0]]
-# $\color{orange}{v^T}$
+# $\color{orange}{v^\top}$
 query_pos_bias = self.query_pos_bias[None, None, :, :]
 
 # ${(\color{lightgreen}{\mathbf{A + C}})}_{i,j} =
-# Q_i^T K_j +
-# \color{orange}{v^T} K_jZ$
+# Q_i^\top K_j +
+# \color{orange}{v^\top} K_j$
 ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)
-# $\color{lightgreen}{\mathbf{B'}_{i,k}} = \color{cyan}{Q_i^T} \color{orange}{R_k}$
+# $\color{lightgreen}{\mathbf{B'}_{i,k}} = Q_i^\top \color{orange}{R_k}$
 b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
 # $\color{lightgreen}{\mathbf{D'}_{i,k}} = \color{orange}{S_k}$
 d = key_pos_bias[None, :, None, :]
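A shape sketch for the $\mathbf{B'}$ term computed above. The value of `P`, the table size, and the equal query/key lengths are assumptions; only the slicing and the einsum mirror the snippet. Note that `b_prime` is indexed by the table position $k$, not yet by the distance $i - j$.

```python
import torch

q_len, batch, heads, d_k = 5, 2, 4, 16
P = 7                                                    # assumed maximum relative distance
query = torch.randn(q_len, batch, heads, d_k)
key_pos_embeddings = torch.randn(2 * P, heads, d_k)      # R_k table

# slice as in the snippet; here query and key have the same length q_len
key_pos_emb = key_pos_embeddings[P - q_len:P + q_len]    # [2 * q_len, heads, d_k]
b_prime = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
print(b_prime.shape)  # torch.Size([5, 10, 2, 4])
```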
@@ -136,9 +136,9 @@ class RelativeMultiHeadAttention(MultiHeadAttention):
 bd = bd[:, -key.shape[0]:]
 
 # Return the sum $$
-# \underset{\mathbf{\color{lightgreen}{A}}}{Q_i^T K_j} +
-# \underset{\mathbf{\color{lightgreen}{B}}}{Q_i^T \color{orange}{R_{i - j}}} +
-# \underset{\mathbf{\color{lightgreen}{C}}}{\color{orange}{v^T} K_j} +
+# \underset{\mathbf{\color{lightgreen}{A}}}{Q_i^\top K_j} +
+# \underset{\mathbf{\color{lightgreen}{B}}}{Q_i^\top \color{orange}{R_{i - j}}} +
+# \underset{\mathbf{\color{lightgreen}{C}}}{\color{orange}{v^\top} K_j} +
 # \underset{\mathbf{\color{lightgreen}{D}}}{\color{orange}{S_{i-j}}}
 # $$
 return ac + bd
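A shape-only check of the final combination, with toy sizes. The reordering performed by the shift before this hunk is not reproduced here; this only confirms that the sliced `bd` lines up with `ac` for the sum.

```python
import torch

q_len, k_len, batch, heads = 4, 4, 2, 3
ac = torch.randn(q_len, k_len, batch, heads)           # (A + C) terms
b = torch.randn(q_len, q_len + k_len, batch, heads)    # B' over table positions k
key_pos_bias = torch.randn(q_len + k_len, heads)       # S_k
d = key_pos_bias[None, :, None, :]                     # D' broadcast over i and batch

bd = b + d                                             # [q_len, q_len + k_len, batch, heads]
bd = bd[:, -k_len:]                                    # keep the last k_len table positions
print((ac + bd).shape)  # torch.Size([4, 4, 2, 3])
```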