mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-26 16:50:39 +08:00
transpose with \top
This commit is contained in:
@ -28,7 +28,8 @@ This reduces the memory used for caching during prediction.
|
||||
|
||||
Here's a notebook for training a feedback transformer on Tiny Shakespeare dataset.
|
||||
|
||||
[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/hypernetworks/experiment.ipynb)
|
||||
[](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/feedback/experiment.ipynb)
|
||||
[](https://web.lab-ml.com/run?uuid=d8eb9416530a11eb8fb50242ac1c0002)
|
||||
"""
|
||||
|
||||
import math
|
||||
@ -51,7 +52,7 @@ class FeedbackAttention(Module):
|
||||
This module computes recurrent attention similar to attention from original transformers
|
||||
paper.
|
||||
|
||||
$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q^T K}{\sqrt{d_k}}\Bigg)V$$
|
||||
$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q^\top K}{\sqrt{d_k}}\Bigg)V$$
|
||||
"""
|
||||
|
||||
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
|
||||
@ -99,23 +100,23 @@ class FeedbackAttention(Module):
|
||||
### Get attention scores
|
||||
|
||||
\begin{align}
|
||||
A_{j} &== Q^T K_j \\
|
||||
&= lin_q(\color{cyan}{X^q + P})^T lin_k(\color{lightgreen}{X^k_j + P_j}) \\
|
||||
&= (\color{cyan}{Q^T + V^T})(\color{lightgreen}{K_j + U_j)}
|
||||
A_{j} &= Q^\top K_j \\
|
||||
&= lin_q(X^q + P_q)^\top lin_k(X^k_j + P_j) \\
|
||||
&= (Q + U^Q)^\top(K_j + U^K_j)
|
||||
\end{align}
|
||||
|
||||
where $\color{cyan}{Q}, \color{lightgreen}{K_j}$, are linear transformations of
|
||||
original embeddings $\color{cyan}{X^q}, \color{lightgreen}{X^k_j}$
|
||||
and $\color{cyan}{V}, \color{lightgreen}{U_j}$ are linear transformations of
|
||||
absolute positional encodings $\color{cyan}{P}, \color{lightgreen}{P_j}$.
|
||||
where $Q, K_j$, are linear transformations of
|
||||
original embeddings $X^q, X^k_j$
|
||||
and $U^Q, U^K_j$ are linear transformations of
|
||||
absolute positional encodings $P_q, P_j$.
|
||||
"""
|
||||
|
||||
# $\color{lightgreen}{U_j}$
|
||||
# $U^K_j$
|
||||
key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
|
||||
# $\color{cyan}{V^T}$
|
||||
# $U^Q$
|
||||
query_pos_bias = self.query_pos_bias[None, :, :]
|
||||
|
||||
# $(\color{cyan}{Q^T + V^T})(\color{lightgreen}{K_j + U_j)}$
|
||||
# $(Q + U^Q)^\top(K_j + U^K_j)$
|
||||
return torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key + key_pos_emb[:, None, :, :])
|
||||
|
||||
def __call__(self, *,
|
||||
|
@ -64,7 +64,7 @@ class MultiHeadAttention(Module):
|
||||
|
||||
This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.
|
||||
|
||||
$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
|
||||
$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
|
||||
|
||||
In simple terms, it finds keys that matches the query, and get the values of
|
||||
those keys.
|
||||
@ -115,7 +115,7 @@ class MultiHeadAttention(Module):
|
||||
This method can be overridden for other variations like relative attention.
|
||||
"""
|
||||
|
||||
# Calculate $Q K^T$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
|
||||
# Calculate $Q K^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
|
||||
return torch.einsum('ibhd,jbhd->ijbh', query, key)
|
||||
|
||||
def __call__(self, *,
|
||||
@ -151,11 +151,11 @@ class MultiHeadAttention(Module):
|
||||
key = self.key(key)
|
||||
value = self.value(value)
|
||||
|
||||
# Compute attention scores $Q K^T$
|
||||
# Compute attention scores $Q K^\top$
|
||||
# Results in a tensor of shape `[seq_len, seq_len, batch_size, heads]`
|
||||
scores = self.get_scores(query, key)
|
||||
|
||||
# Scale scores $\frac{Q K^T}{\sqrt{d_k}}$
|
||||
# Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
|
||||
scores *= self.scale
|
||||
|
||||
# Apply mask
|
||||
@ -163,7 +163,7 @@ class MultiHeadAttention(Module):
|
||||
scores = scores.masked_fill(mask == 0, -1e9)
|
||||
|
||||
# $softmax$ attention along the key sequence dimension
|
||||
# $\underset{seq}{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)$
|
||||
# $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
|
||||
attn = self.softmax(scores)
|
||||
|
||||
# Save attentions if debugging
|
||||
@ -173,7 +173,7 @@ class MultiHeadAttention(Module):
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# Multiply by values
|
||||
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
|
||||
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
|
||||
x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
|
||||
|
||||
# Save attentions for any other calculations
|
||||
|
@ -102,14 +102,14 @@ class RelativeMultiHeadAttention(MultiHeadAttention):
|
||||
|
||||
For the second and third terms relative positional encodings are introduced.
|
||||
So $\underset{\color{lightgreen}{B}}{Q_i^\top U^K_j}$ is
|
||||
replaced with $\underset{\color{lightgreen}{B}}{Q_i^T \color{orange}{R_{i - j}}}$
|
||||
replaced with $\underset{\color{lightgreen}{B}}{Q_i^\top \color{orange}{R_{i - j}}}$
|
||||
and $\underset{\color{lightgreen}{D}}{{U^Q_i}^\top U^K_j}$
|
||||
with $\underset{\color{lightgreen}{D}}{\color{orange}{S_{i-j}}}$.
|
||||
|
||||
\begin{align}
|
||||
A^{rel}_{i,j} &= \underset{\mathbf{\color{lightgreen}{A}}}{Q_i^T K_j} +
|
||||
\underset{\mathbf{\color{lightgreen}{B}}}{Q_i^T \color{orange}{R_{i - j}}} +
|
||||
\underset{\mathbf{\color{lightgreen}{C}}}{\color{orange}{v^T} K_j} +
|
||||
A^{rel}_{i,j} &= \underset{\mathbf{\color{lightgreen}{A}}}{Q_i^\top K_j} +
|
||||
\underset{\mathbf{\color{lightgreen}{B}}}{Q_i^\top \color{orange}{R_{i - j}}} +
|
||||
\underset{\mathbf{\color{lightgreen}{C}}}{\color{orange}{v^\top} K_j} +
|
||||
\underset{\mathbf{\color{lightgreen}{D}}}{\color{orange}{S_{i-j}}}
|
||||
\end{align}
|
||||
"""
|
||||
@ -118,14 +118,14 @@ class RelativeMultiHeadAttention(MultiHeadAttention):
|
||||
key_pos_emb = self.key_pos_embeddings[self.P - query.shape[0]:self.P + key.shape[0]]
|
||||
# $\color{orange}{S_k}$
|
||||
key_pos_bias = self.key_pos_bias[self.P - query.shape[0]:self.P + key.shape[0]]
|
||||
# $\color{orange}{v^T}$
|
||||
# $\color{orange}{v^\top}$
|
||||
query_pos_bias = self.query_pos_bias[None, None, :, :]
|
||||
|
||||
# ${(\color{lightgreen}{\mathbf{A + C}})}_{i,j} =
|
||||
# Q_i^T K_j +
|
||||
# \color{orange}{v^T} K_jZ$
|
||||
# Q_i^\top K_j +
|
||||
# \color{orange}{v^\top} K_jZ$
|
||||
ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)
|
||||
# $\color{lightgreen}{\mathbf{B'}_{i,k}} = \color{cyan}{Q_i^T} \color{orange}{R_k}$
|
||||
# $\color{lightgreen}{\mathbf{B'}_{i,k}} = Q_i^\top \color{orange}{R_k}$
|
||||
b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
|
||||
# $\color{lightgreen}{\mathbf{D'}_{i,k}} = \color{orange}{S_k}$
|
||||
d = key_pos_bias[None, :, None, :]
|
||||
@ -136,9 +136,9 @@ class RelativeMultiHeadAttention(MultiHeadAttention):
|
||||
bd = bd[:, -key.shape[0]:]
|
||||
|
||||
# Return the sum $$
|
||||
# \underset{\mathbf{\color{lightgreen}{A}}}{Q_i^T K_j} +
|
||||
# \underset{\mathbf{\color{lightgreen}{B}}}{Q_i^T \color{orange}{R_{i - j}}} +
|
||||
# \underset{\mathbf{\color{lightgreen}{C}}}{\color{orange}{v^T} K_j} +
|
||||
# \underset{\mathbf{\color{lightgreen}{A}}}{Q_i^\top K_j} +
|
||||
# \underset{\mathbf{\color{lightgreen}{B}}}{Q_i^\top \color{orange}{R_{i - j}}} +
|
||||
# \underset{\mathbf{\color{lightgreen}{C}}}{\color{orange}{v^\top} K_j} +
|
||||
# \underset{\mathbf{\color{lightgreen}{D}}}{\color{orange}{S_{i-j}}}
|
||||
# $$
|
||||
return ac + bd
|
||||
|
Reference in New Issue
Block a user