mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-06 15:22:21 +08:00
backward pass formulas
@@ -1086,7 +1086,7 @@
|
|||||||
|
|
||||||
<url>
|
<url>
|
||||||
<loc>https://nn.labml.ai/transformers/flash/test.html</loc>
|
<loc>https://nn.labml.ai/transformers/flash/test.html</loc>
|
||||||
<lastmod>2025-07-30T16:30:00+00:00</lastmod>
|
<lastmod>2025-07-31T16:30:00+00:00</lastmod>
|
||||||
<priority>1.00</priority>
|
<priority>1.00</priority>
|
||||||
</url>
|
</url>
|
||||||
|
|
||||||
|
File diff suppressed because one or more lines are too long
@@ -4,7 +4,7 @@
|
|||||||
## Forward pass
|
## Forward pass
|
||||||
|
|
||||||
\begin{align}
|
\begin{align}
|
||||||
S_{ij} &= q_i k_j^T
|
S_{ij} &= \sigma Q_i K_j^T
|
||||||
\\
|
\\
|
||||||
L_i &= \sum_j e^{S_{ij}}
|
L_i &= \sum_j e^{S_{ij}}
|
||||||
\\
|
\\
|
||||||
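As a concrete reference for these definitions, here is a minimal (naive, quadratic-memory) PyTorch sketch of the same forward pass; the function name is ours and `sm_scale` stands in for $\sigma$:

```python
import torch

def naive_attention(q, k, v, sm_scale):
    # S_ij = sigma * Q_i K_j^T
    s = sm_scale * q @ k.transpose(-2, -1)
    # Softmax normalizes each row by L_i = sum_j e^{S_ij}
    p = torch.softmax(s, dim=-1)
    # O_i = sum_j P_ij V_j
    return p @ v
```

Flash attention computes exactly this, but without ever materializing the full `p` matrix.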
@@ -20,7 +20,7 @@ by computing the sum of exponents $l_i$ and the unnormalized output $\tilde{O}_i
|
|||||||
while iterating over keys:
|
while iterating over keys:
|
||||||
|
|
||||||
\begin{align}
|
\begin{align}
|
||||||
S_{ij} &= Q_i K_j^T
|
S_{ij} &= \sigma Q_i K_j^T
|
||||||
\\
|
\\
|
||||||
l_i &\leftarrow l_i + e^{S_{ij}}
|
l_i &\leftarrow l_i + e^{S_{ij}}
|
||||||
\\
|
\\
|
||||||
@@ -50,6 +50,7 @@ l_i &\leftarrow e^{m_i - m_{i}^{\text{new}}} l_i + \sum_{j=j1}^{j2} \tilde{P}_{i
|
|||||||
\\
|
\\
|
||||||
\tilde{O}_i &\leftarrow e^{m_i - m_{i}^{\text{new}}} \tilde{O}_i + \tilde{P}_{ij} * V_j
|
\tilde{O}_i &\leftarrow e^{m_i - m_{i}^{\text{new}}} \tilde{O}_i + \tilde{P}_{ij} * V_j
|
||||||
\\
|
\\
|
||||||
|
m_i &\leftarrow m_{i}^{\text{new}}
|
||||||
\end{align}
|
\end{align}
|
||||||
|
|
||||||
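The update rules above are easier to follow in plain PyTorch. Below is a minimal sketch of the streaming recurrence, written in natural-log space for clarity (the kernel itself works in base 2, as discussed later); the function name and block size are ours:

```python
import torch

def online_softmax_attention(q, k, v, sm_scale, block=64):
    # Running statistics per query row; the kernel initializes l to 1,
    # but with m = -inf the first rescale zeroes its contribution anyway
    m = torch.full(q.shape[:-1], float("-inf"))
    l = torch.zeros(q.shape[:-1])
    o = torch.zeros_like(q)
    for j in range(0, k.shape[-2], block):
        kj, vj = k[..., j:j + block, :], v[..., j:j + block, :]
        s = sm_scale * q @ kj.transpose(-2, -1)   # S_ij
        m_new = torch.maximum(m, s.amax(-1))      # m_i^new
        p = torch.exp(s - m_new[..., None])       # P~_ij
        alpha = torch.exp(m - m_new)              # e^{m_i - m_i^new}
        l = alpha * l + p.sum(-1)                 # rescale and accumulate l_i
        o = alpha[..., None] * o + p @ vj         # unnormalized output O~_i
        m = m_new                                 # m_i <- m_i^new
    return o / l[..., None]                       # O_i = O~_i / l_i
```

For float inputs this matches `torch.softmax(sm_scale * q @ k.mT, -1) @ v` up to numerical error.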
Then finally,
|
Then finally,
|
||||||
@@ -69,9 +70,9 @@ dS_{ij} &= d\text{softmax}(dP_{ij})
|
|||||||
\\
|
\\
|
||||||
&= P_{ij} dP_{ij} - P_{ij} \sum P_{ik} dP_{ik}
|
&= P_{ij} dP_{ij} - P_{ij} \sum P_{ik} dP_{ik}
|
||||||
\\
|
\\
|
||||||
dQ_i &= \sum_j dS_{ij} K_j
|
dQ_i &= \sigma \sum_j dS_{ij} K_j
|
||||||
\\
|
\\
|
||||||
qK_j &= \sum_i dS_{ij} Q_i
|
dK_j &= \sigma \sum_i dS_{ij} Q_i
|
||||||
\end{align}
|
\end{align}
|
||||||
|
|
||||||
where $\delta_{jk}$ is $1$ when $j = k$ and $0$ otherwise.
|
where $\delta_{jk}$ is $1$ when $j = k$ and $0$ otherwise.
|
||||||
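These backward formulas can be sanity-checked against a naive implementation. A sketch under the same notation, with `sm_scale` standing in for $\sigma$ (the helper name is ours):

```python
import torch

def naive_attention_backward(q, k, v, do, sm_scale):
    s = sm_scale * q @ k.transpose(-2, -1)    # S = sigma Q K^T
    p = torch.softmax(s, dim=-1)              # P = softmax(S)
    dv = p.transpose(-2, -1) @ do             # dV_j = sum_i P_ij dO_i
    dp = do @ v.transpose(-2, -1)             # dP_ij = dO_i V_j^T
    # dS_ij = P_ij (dP_ij - sum_k P_ik dP_ik)
    ds = p * (dp - (p * dp).sum(-1, keepdim=True))
    dq = sm_scale * ds @ k                    # dQ_i = sigma sum_j dS_ij K_j
    dk = sm_scale * ds.transpose(-2, -1) @ q  # dK_j = sigma sum_i dS_ij Q_i
    return dq, dk, dv
```

Comparing `dq, dk, dv` against autograd on the naive forward pass confirms the softmax-gradient identity above.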
@@ -144,7 +145,7 @@ class AttentionFunc(torch.autograd.Function):
|
|||||||
|
|
||||||
# Tensor for the output
|
# Tensor for the output
|
||||||
o = torch.empty_like(q)
|
o = torch.empty_like(q)
|
||||||
# Tensor for $\log_2 \sum_j e^{S_{ij}}$
|
# Tensor for log of sum of exponentials $\log_2 L_i = \log_2 \sum_j e^{S_{ij}}$
|
||||||
lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)
|
lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)
|
||||||
|
|
||||||
# The forward computation will be parallelized along the batch dimension and the queries in blocks of size `BLOCK_Q`
|
# The forward computation will be parallelized along the batch dimension and the queries in blocks of size `BLOCK_Q`
|
||||||
@@ -200,17 +201,18 @@ class AttentionFunc(torch.autograd.Function):
|
|||||||
dk = torch.empty_like(k)
|
dk = torch.empty_like(k)
|
||||||
dv = torch.empty_like(v)
|
dv = torch.empty_like(v)
|
||||||
|
|
||||||
# $\frac{1}{\log_e 2}$
|
# $\log_2 e$
|
||||||
RCP_LN2 = 1.4426950408889634
|
RCP_LN2 = 1.4426950408889634
|
||||||
# Multiply $k$ by softmax scale
|
# Precompute $\sigma (\log_2 e) K_j$
|
||||||
k_scaled = k * (sm_scale * RCP_LN2)
|
k_scaled = k * (sm_scale * RCP_LN2)
|
||||||
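The base-2 trick: GPUs expose a fast `exp2`, so the kernel folds $\log_2 e$ into the key scaling once and then uses $2^x$ everywhere in place of $e^x$. A quick standalone check of the identity (example values are ours):

```python
import torch

RCP_LN2 = 1.4426950408889634  # log2(e) = 1 / ln(2)

qk = torch.randn(8)   # stand-in for Q_i K_j^T
sm_scale = 0.125      # stand-in for sigma

# e^{sigma * qk} == 2^{(sigma * log2(e)) * qk}
assert torch.allclose(torch.exp(sm_scale * qk),
                      torch.exp2((sm_scale * RCP_LN2) * qk))
```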
# $D_i = P^T_{i:}dP_{i:} = do^T_io_i$
|
# $D_i = P^T_{i:}dP_{i:} = do^T_io_i$
|
||||||
pdp = torch.empty_like(lse)
|
pdp = torch.empty_like(lse)
|
||||||
# We use fixed `BLOCK_Q` for backward pass on $D$
|
# We use fixed `BLOCK_Q` for backward pass on $D$
|
||||||
BLOCK_Q = 16
|
|
||||||
# Compute $D_i$
|
# Compute $D_i$
|
||||||
#
|
#
|
||||||
# This is parallelized along the batch and query in blocks of size `BLOCK_Q`
|
# This is parallelized along the batch and query in blocks of size `BLOCK_Q`
|
||||||
|
BLOCK_Q = 16
|
||||||
pre_grid = (triton.cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads)
|
pre_grid = (triton.cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads)
|
||||||
_attn_bwd_d[pre_grid](
|
_attn_bwd_d[pre_grid](
|
||||||
o, do,
|
o, do,
|
||||||
@@ -221,6 +223,7 @@ class AttentionFunc(torch.autograd.Function):
|
|||||||
n_groups=n_groups,
|
n_groups=n_groups,
|
||||||
num_stages=1,
|
num_stages=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute $dK$ and $dV$
|
# Compute $dK$ and $dV$
|
||||||
#
|
#
|
||||||
# This is parallelized along the batch and keys in blocks of size `BLOCK_K`
|
# This is parallelized along the batch and keys in blocks of size `BLOCK_K`
|
||||||
@@ -232,6 +235,7 @@ class AttentionFunc(torch.autograd.Function):
|
|||||||
is_causal=causal,
|
is_causal=causal,
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute $dQ$
|
# Compute $dQ$
|
||||||
#
|
#
|
||||||
# This is parallelized along the batch and queries in blocks of size `BLOCK_Q`
|
# This is parallelized along the batch and queries in blocks of size `BLOCK_Q`
|
||||||
@@ -351,23 +355,31 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
|
|||||||
|
|
||||||
# Initialize offsets
|
# Initialize offsets
|
||||||
offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
|
offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
|
||||||
i_mask = offs_i < q_seq_len
|
|
||||||
offs_j = tl.arange(0, BLOCK_K)
|
offs_j = tl.arange(0, BLOCK_K)
|
||||||
|
# Mask for the last, partially filled block of $Q$
|
||||||
|
i_mask = offs_i < q_seq_len
|
||||||
|
|
||||||
# Initialize $m_i$ and $l_i$
|
# Precalculate $\sigma (\log_2 e)$.
|
||||||
|
#
|
||||||
|
# We will use this when calculating $S_{ij}$, so `S` will store $(\log_2 e) S_{ij}$ instead.
|
||||||
|
sm_scale = sm_scale * 1.44269504
|
||||||
|
|
||||||
|
# Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\infty$ and $l_i$ to $1$. So in the first update,
|
||||||
|
# the effect of initial $l_i$ is $e^{m_i - m_{i}^{\text{new}}} l_i = 0$.
|
||||||
|
#
|
||||||
|
# `b_m` will be storing $(\log_2 e) m_i$
|
||||||
b_m = tl.where(i_mask, -float("inf"), 0.0)
|
b_m = tl.where(i_mask, -float("inf"), 0.0)
|
||||||
b_l = tl.where(i_mask, 1.0, 0.0)
|
b_l = tl.where(i_mask, 1.0, 0.0)
|
||||||
# Accumulate $O$
|
|
||||||
b_acc = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
|
|
||||||
|
|
||||||
# softmax scale / log(2)
|
# $O_i$
|
||||||
sm_scale = sm_scale * 1.44269504
|
b_o = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
|
||||||
# Load $Q_i$
|
|
||||||
|
# Load $Q_i$ outside the loop since it will be reused throughout the loop over $K_j$.
|
||||||
b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
|
b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
|
||||||
|
|
||||||
if is_causal:
|
if is_causal:
|
||||||
# Upto the diagonal block
|
# Inner loop up to the diagonal block
|
||||||
b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q,
|
b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q,
|
||||||
p_kT, p_v,
|
p_kT, p_v,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
BLOCK_Q, d_head, BLOCK_K,
|
BLOCK_Q, d_head, BLOCK_K,
|
||||||
@@ -379,7 +391,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
|
|||||||
kv_seq_len=kv_seq_len
|
kv_seq_len=kv_seq_len
|
||||||
)
|
)
|
||||||
# Diagonal block with masking within it
|
# Diagonal block with masking within it
|
||||||
b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
|
b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
BLOCK_Q, d_head, BLOCK_K,
|
BLOCK_Q, d_head, BLOCK_K,
|
||||||
offs_i, offs_j,
|
offs_i, offs_j,
|
||||||
@@ -390,7 +402,8 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
|
|||||||
kv_seq_len=kv_seq_len
|
kv_seq_len=kv_seq_len
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
|
# Iterate through all $K_j$
|
||||||
|
b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
BLOCK_Q, d_head, BLOCK_K,
|
BLOCK_Q, d_head, BLOCK_K,
|
||||||
offs_i, offs_j,
|
offs_i, offs_j,
|
||||||
@@ -401,13 +414,14 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
|
|||||||
kv_seq_len=kv_seq_len
|
kv_seq_len=kv_seq_len
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update LSE
|
# Store LSE $\log_2 L_i = \log_2 \big( l_i e^{m_i} \big) = \log_2 l_i + (\log_2 e) m_i$
|
||||||
tl.store(p_lse, b_m + tl.math.log2(b_l), boundary_check=(0,))
|
tl.store(p_lse, b_m + tl.math.log2(b_l), boundary_check=(0,))
|
||||||
tl.store(p_o, (b_acc / b_l[:, None]).to(t_o.type.element_ty), boundary_check=(0,))
|
# Store $O_i = \frac{\tilde{O}_i}{l_i}$
|
||||||
|
tl.store(p_o, (b_o / b_l[:, None]).to(t_o.type.element_ty), boundary_check=(0,))
|
||||||
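To see why `b_m + tl.math.log2(b_l)` is $\log_2 L_i$: `b_m` holds $(\log_2 e) m_i$ and `b_l` holds $l_i = \sum_j e^{S_{ij} - m_i}$, so $\log_2 l_i + (\log_2 e) m_i = \log_2 (l_i e^{m_i}) = \log_2 L_i$. A small standalone check (not kernel code):

```python
import torch

RCP_LN2 = 1.4426950408889634  # log2(e)

s = torch.randn(5, 7)                  # S_ij in natural units
b_m = (s * RCP_LN2).amax(-1)           # running max, base-2 units
b_l = torch.exp2(s * RCP_LN2 - b_m[:, None]).sum(-1)
lse2 = b_m + torch.log2(b_l)           # what the kernel stores
assert torch.allclose(lse2, torch.logsumexp(s, -1) * RCP_LN2, atol=1e-5)
```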
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _attn_fwd_inner(b_acc, b_l, b_m, b_q,
|
def _attn_fwd_inner(b_o, b_l, b_m, b_q,
|
||||||
p_kT, p_v,
|
p_kT, p_v,
|
||||||
scale,
|
scale,
|
||||||
BLOCK_Q: tl.constexpr,
|
BLOCK_Q: tl.constexpr,
|
||||||
@@ -422,45 +436,46 @@ def _attn_fwd_inner(b_acc, b_l, b_m, b_q,
|
|||||||
):
|
):
|
||||||
tl.static_assert(BLOCK_Q % BLOCK_K == 0)
|
tl.static_assert(BLOCK_Q % BLOCK_K == 0)
|
||||||
|
|
||||||
|
# Move $K_j$ and $V_j$ pointers
|
||||||
p_kT = tl.advance(p_kT, (0, j))
|
p_kT = tl.advance(p_kT, (0, j))
|
||||||
p_v = tl.advance(p_v, (j, 0))
|
p_v = tl.advance(p_v, (j, 0))
|
||||||
|
|
||||||
# loop over k, v and update accumulator
|
# Iterate over $K$, $V$ and update $\tilde{O}_i$ and $l_i$
|
||||||
for _ in range(steps):
|
for _ in range(steps):
|
||||||
current_j = j + offs_j
|
# Load $K_j^T$
|
||||||
j_mask = current_j < kv_seq_len
|
|
||||||
|
|
||||||
b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
|
b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
|
||||||
|
# Compute $(\log_2 e) S_{ij} = (\log_2 e) \sigma Q_i K_j^T$
|
||||||
b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
|
b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
|
||||||
|
|
||||||
tl.static_assert(b_s.dtype == HI_PRES_TL)
|
|
||||||
b_s = b_s * scale
|
b_s = b_s * scale
|
||||||
|
|
||||||
|
# Apply causal mask
|
||||||
if MASK:
|
if MASK:
|
||||||
causal_mask = offs_i[:, None] >= (j + offs_j[None, :])
|
causal_mask = offs_i[:, None] >= (j + offs_j[None, :])
|
||||||
b_s = tl.where(causal_mask, b_s, -float("inf"))
|
b_s = tl.where(causal_mask, b_s, -float("inf"))
|
||||||
# always apply seq mask
|
|
||||||
|
# Mask out keys past the end of the sequence
|
||||||
|
j_mask = (j + offs_j) < kv_seq_len
|
||||||
b_s = tl.where(j_mask[None, :], b_s, -float("inf"))
|
b_s = tl.where(j_mask[None, :], b_s, -float("inf"))
|
||||||
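Setting masked scores to $-\infty$ is safe with the running-max update: after `exp2`, masked entries become exactly $0$ and add nothing to $l_i$ or $\tilde{O}_i$. For instance (example values ours):

```python
import torch

b_s = torch.tensor([[0.3, -float("inf"), 1.2]])  # one masked score
b_m_new = b_s.amax(-1, keepdim=True)             # max safely ignores -inf
b_p = torch.exp2(b_s - b_m_new)
assert b_p[0, 1] == 0.0                          # masked key contributes nothing
```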
|
|
||||||
# $m_{i}^{\text{new}} = \max(m_i, \text{rowmax}(S_{ij}))$
|
# $m_{i}^{\text{new}} = \max(m_i, \text{rowmax}(S_{ij}))$
|
||||||
tl.static_assert(len(b_s.shape) == 2)
|
|
||||||
b_m_new = tl.maximum(b_m, tl.max(b_s, -1))
|
b_m_new = tl.maximum(b_m, tl.max(b_s, -1))
|
||||||
# $\tilde{P}_{ij} = \exp(S_{ij} - m_i^{\text{new}})$
|
# $\tilde{P}_{ij} = \exp(S_{ij} - m_i^{\text{new}})$
|
||||||
b_p = tl.math.exp2(b_s - b_m_new[:, None])
|
b_p = tl.math.exp2(b_s - b_m_new[:, None])
|
||||||
# $\tilde{l}_ij = \text{rowsum}(\tilde{P}_{ij})$
|
|
||||||
b_l_new = tl.sum(b_p, -1)
|
|
||||||
|
|
||||||
|
# $\sum_{j=j_1}^{j_2} \tilde{P}_{ij}$
|
||||||
|
b_l_new = tl.sum(b_p, -1)
|
||||||
# $e^{m_i - m_{i}^{\text{new}}}$
|
# $e^{m_i - m_{i}^{\text{new}}}$
|
||||||
b_m_m_new = tl.math.exp2(b_m - b_m_new)
|
b_m_m_new = tl.math.exp2(b_m - b_m_new)
|
||||||
# $l_i \leftarrow e^{m_i - m_{i}^{\text{new}}} l_i + \tilde{l}_{ij}$
|
# $l_i \leftarrow e^{m_i - m_{i}^{\text{new}}} l_i + \sum_{j=j_1}^{j_2} \tilde{P}_{ij}$
|
||||||
b_l = b_l * b_m_m_new + b_l_new
|
b_l = b_l * b_m_m_new + b_l_new
|
||||||
|
|
||||||
# $O_i \leftarrow e^{m_i - m_{i}^{\text{new}}} O_i + \tilde{P}_{ij} * V_j$
|
# $\tilde{O}_i \leftarrow e^{m_i - m_{i}^{\text{new}}} \tilde{O}_i + \tilde{P}_{ij} V_j$
|
||||||
|
b_o = b_o * b_m_m_new[:, None]
|
||||||
|
b_p = b_p.to(b_q.dtype) # Cast to the input dtype so the dot product runs in lower precision
|
||||||
b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
|
b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
|
||||||
b_acc = b_acc * b_m_m_new[:, None]
|
b_o += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)
|
||||||
b_p = b_p.to(b_q.dtype)
|
|
||||||
b_acc += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)
|
|
||||||
|
|
||||||
# update $m_i$
|
# $m_i \leftarrow m_{i}^{\text{new}}$
|
||||||
b_m = b_m_new
|
b_m = b_m_new
|
||||||
|
|
||||||
# Move pointers
|
# Move pointers
|
||||||
@@ -468,9 +483,9 @@ def _attn_fwd_inner(b_acc, b_l, b_m, b_q,
|
|||||||
p_v = tl.advance(p_v, (BLOCK_K, 0))
|
p_v = tl.advance(p_v, (BLOCK_K, 0))
|
||||||
p_kT = tl.advance(p_kT, (0, BLOCK_K))
|
p_kT = tl.advance(p_kT, (0, BLOCK_K))
|
||||||
|
|
||||||
tl.static_assert(b_acc.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
|
tl.static_assert(b_o.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
|
||||||
|
|
||||||
return b_acc, b_l, b_m
|
return b_o, b_l, b_m
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -503,9 +518,13 @@ def _attn_bwd_d(t_o, t_do,
|
|||||||
(n_groups, BLOCK_Q),
|
(n_groups, BLOCK_Q),
|
||||||
(1, 0))
|
(1, 0))
|
||||||
|
|
||||||
|
# Load $O_i$
|
||||||
o = tl.load(p_o, boundary_check=(1,), padding_option="zero")
|
o = tl.load(p_o, boundary_check=(1,), padding_option="zero")
|
||||||
|
# Load $dO_i$
|
||||||
do = tl.load(p_do, boundary_check=(1,), padding_option="zero").to(HI_PRES_TL)
|
do = tl.load(p_do, boundary_check=(1,), padding_option="zero").to(HI_PRES_TL)
|
||||||
|
# Calculate $D_i = dO_i O_i^T$
|
||||||
d = tl.sum(o * do, axis=-1)
|
d = tl.sum(o * do, axis=-1)
|
||||||
|
# Save $D_i$
|
||||||
tl.store(p_pdp, d, boundary_check=(1,))
|
tl.store(p_pdp, d, boundary_check=(1,))
|
||||||
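The identity $D_i = P^T_{i:} dP_{i:} = dO_i^T O_i$ holds because $O_i = \sum_j P_{ij} V_j$ and $dP_{ij} = dO_i^T V_j$, so $\sum_j P_{ij} dP_{ij} = dO_i^T \sum_j P_{ij} V_j = dO_i^T O_i$. This is what lets the kernel compute $D_i$ from `o` and `do` alone. A quick check (shapes are ours):

```python
import torch

p = torch.softmax(torch.randn(4, 6), dim=-1)
v = torch.randn(6, 8)
do = torch.randn(4, 8)

o = p @ v
dp = do @ v.t()
# D_i two ways: rowwise P . dP, and rowwise dO . O
assert torch.allclose((p * dp).sum(-1), (do * o).sum(-1), atol=1e-5)
```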
|
|
||||||
|
|
||||||
@@ -523,12 +542,13 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
|
|||||||
BLOCK_K: tl.constexpr,
|
BLOCK_K: tl.constexpr,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Loop along m query; n % m == 0
|
Compute $dK_j$ and $dV_j$ for $j_1 \dots j_2$ by iterating over $Q_i$
|
||||||
"""
|
"""
|
||||||
# K is already multiplied by scale
|
|
||||||
j = tl.program_id(0) * BLOCK_K
|
j = tl.program_id(0) * BLOCK_K
|
||||||
z = tl.program_id(1)
|
z = tl.program_id(1)
|
||||||
|
|
||||||
|
# Create block pointers
|
||||||
p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
|
p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
|
||||||
(kv_seq_len, d_head),
|
(kv_seq_len, d_head),
|
||||||
(d_head, 1),
|
(d_head, 1),
|
||||||
@@ -554,14 +574,15 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
|
|||||||
(BLOCK_K, d_head),
|
(BLOCK_K, d_head),
|
||||||
(1, 0))
|
(1, 0))
|
||||||
|
|
||||||
b_dv = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
|
# Initialize $\frac{1}{\sigma} dK$ and $dV$
|
||||||
b_dk = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
|
b_dk = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
|
||||||
|
b_dv = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
|
||||||
|
|
||||||
# load K and V: they stay in SRAM throughout the inner loop.
|
# Load $\sigma (\log_2 e) K$ and $V$ outside the loop.
|
||||||
b_k = tl.load(p_k, boundary_check=(0,), padding_option="zero")
|
b_k = tl.load(p_k, boundary_check=(0,), padding_option="zero")
|
||||||
b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
|
b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
|
||||||
|
|
||||||
# Iterate through queries that attend to save keys
|
# Iterate over the query groups (GQA)
|
||||||
for g in range(n_groups):
|
for g in range(n_groups):
|
||||||
# Create block pointers
|
# Create block pointers
|
||||||
p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
|
p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
|
||||||
@@ -590,19 +611,12 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
|
|||||||
(BLOCK_Q,),
|
(BLOCK_Q,),
|
||||||
(0,))
|
(0,))
|
||||||
|
|
||||||
# $$dk_j = \sum_i dS_{ij} q_i = \sum_i P_{ij} \big( do_i^T v_j - D_i \big) q_i$$
|
|
||||||
# $$dv_j = \sum_i P_{ij} do_i$$
|
|
||||||
|
|
||||||
# Compute $dk$ $dv$ and $dv$ along the masked blocks near diagonal.
|
|
||||||
# Use smaller block size of MASK_BLOCK_Q
|
|
||||||
# because there is a little extra computation?
|
|
||||||
if is_causal:
|
if is_causal:
|
||||||
# loop along m
|
# Inner loop at the diagonal block
|
||||||
b_dk, b_dv = _attn_bwd_dkdv_inner(
|
b_dk, b_dv = _attn_bwd_dkdv_inner(
|
||||||
b_dk, b_dv,
|
b_dk, b_dv,
|
||||||
p_qT, b_k, b_v, p_do,
|
p_qT, b_k, b_v, p_do,
|
||||||
p_lse, p_pdp,
|
p_lse, p_pdp,
|
||||||
# You can use a smaller BLOCK_Q if BLOCK_K is not divisible by BLOCK_Q
|
|
||||||
BLOCK_Q, BLOCK_K,
|
BLOCK_Q, BLOCK_K,
|
||||||
d_head,
|
d_head,
|
||||||
j=j, i=j,
|
j=j, i=j,
|
||||||
@@ -612,7 +626,7 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
|
|||||||
kv_seq_len=kv_seq_len,
|
kv_seq_len=kv_seq_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute $dk$ and $dv$ for non-masked blocks.
|
# Inner loop on queries after the diagonal
|
||||||
b_dk, b_dv = _attn_bwd_dkdv_inner(
|
b_dk, b_dv = _attn_bwd_dkdv_inner(
|
||||||
b_dk, b_dv,
|
b_dk, b_dv,
|
||||||
p_qT, b_k, b_v, p_do,
|
p_qT, b_k, b_v, p_do,
|
||||||
@@ -626,6 +640,7 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
|
|||||||
kv_seq_len=kv_seq_len
|
kv_seq_len=kv_seq_len
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# Iterate through all queries
|
||||||
b_dk, b_dv = _attn_bwd_dkdv_inner(
|
b_dk, b_dv = _attn_bwd_dkdv_inner(
|
||||||
b_dk, b_dv,
|
b_dk, b_dv,
|
||||||
p_qT, b_k, b_v, p_do,
|
p_qT, b_k, b_v, p_do,
|
||||||
@@ -639,14 +654,13 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
|
|||||||
kv_seq_len=kv_seq_len
|
kv_seq_len=kv_seq_len
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save $dv$
|
# Save $dV$
|
||||||
tl.store(p_dv, b_dv.to(t_dv.type.element_ty), boundary_check=(0,))
|
tl.store(p_dv, b_dv.to(t_dv.type.element_ty), boundary_check=(0,))
|
||||||
|
|
||||||
# Since we used $k = \text{scale} * \hat{k}$ where $\hat{k} are the original keys
|
# `b_dk` holds $\frac{1}{\sigma} dK$, so multiply by $\sigma$ to get $dK$
|
||||||
# we multiple by scale again to get gradient on original keys.
|
|
||||||
b_dk *= sm_scale
|
b_dk *= sm_scale
|
||||||
|
|
||||||
# Save $dk$
|
# Save $dK$
|
||||||
tl.store(p_dk, b_dk.to(t_dk.type.element_ty), boundary_check=(0,))
|
tl.store(p_dk, b_dk.to(t_dk.type.element_ty), boundary_check=(0,))
|
||||||
|
|
||||||
|
|
||||||
@@ -670,44 +684,51 @@ def _attn_bwd_dkdv_inner(b_dk, b_dv,
|
|||||||
i_mask = offs_i < q_seq_len
|
i_mask = offs_i < q_seq_len
|
||||||
offs_j = j + tl.arange(0, BLOCK_K)
|
offs_j = j + tl.arange(0, BLOCK_K)
|
||||||
|
|
||||||
# Pointers
|
# Move the pointers
|
||||||
p_qT = tl.advance(p_qT, (0, i))
|
p_qT = tl.advance(p_qT, (0, i))
|
||||||
p_do = tl.advance(p_do, (i, 0))
|
p_do = tl.advance(p_do, (i, 0))
|
||||||
p_lse = tl.advance(p_lse, (i,))
|
p_lse = tl.advance(p_lse, (i,))
|
||||||
p_pdp = tl.advance(p_pdp, (i,))
|
p_pdp = tl.advance(p_pdp, (i,))
|
||||||
|
|
||||||
# Loop
|
# Iterate over $Q$
|
||||||
for _ in range(steps):
|
for _ in range(steps):
|
||||||
# Load $$qT$$
|
# Load $Q_i^T$
|
||||||
b_qT = tl.load(p_qT, boundary_check=(1,), padding_option="zero")
|
b_qT = tl.load(p_qT, boundary_check=(1,), padding_option="zero")
|
||||||
|
|
||||||
# $M_i = log_2 L_i$
|
# $\log_2 L_i$
|
||||||
b_m = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
|
b_l = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
|
||||||
|
|
||||||
# $$P_{ij} = \frac{e^{q_i^T k_j}}{L_i} = e^{q_i^T k_j - M_i}$$
|
# $(\log_2 e) S_{ij}^T = \sigma (\log_2 e) K_j Q_i^T$
|
||||||
# Not that k is already multiplied by softmax scale.
|
b_sT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)
|
||||||
# It is also divided by $log_e 2$ so we can use $2^x$ instead of $e^x$
|
|
||||||
b_qkT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)
|
# \begin{align}
|
||||||
b_pT = tl.math.exp2(b_qkT - b_m[None, :])
|
# P_{ij} &= \frac{e^{S_{ij}}}{L_i}
|
||||||
|
# \\
|
||||||
|
# &= \frac{2^{(\log_2 e) S_{ij}}}{2^{\log_2 L_i}}
|
||||||
|
# \\
|
||||||
|
# &= 2^{(\log_2 e) S_{ij} - \log_2 L_i}
|
||||||
|
# \end{align}
|
||||||
|
b_pT = tl.math.exp2(b_sT - b_l[None, :])
|
||||||
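In other words, $P$ is recomputed from the stored log-sum-exp rather than saved: with scores kept in base-2 units, a single `exp2` of `score - lse` reproduces the softmax. A standalone sketch of the same recomputation:

```python
import torch

RCP_LN2 = 1.4426950408889634  # log2(e)

s = torch.randn(4, 6)                    # S_ij = sigma Q K^T, natural units
lse2 = torch.logsumexp(s, -1) * RCP_LN2  # log2 L_i, as saved by the forward pass
p = torch.exp2(s * RCP_LN2 - lse2[:, None])
assert torch.allclose(p, torch.softmax(s, dim=-1), atol=1e-6)
```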
|
|
||||||
# Autoregressive masking.
|
# Autoregressive masking.
|
||||||
if MASK:
|
if MASK:
|
||||||
mask = (offs_i[None, :] >= offs_j[:, None])
|
mask = (offs_i[None, :] >= offs_j[:, None])
|
||||||
b_pT = tl.where(mask, b_pT, 0.0)
|
b_pT = tl.where(mask, b_pT, 0.0)
|
||||||
|
|
||||||
|
# Mask out queries past the end of the sequence
|
||||||
b_pT = tl.where(i_mask[None, :], b_pT, 0.0)
|
b_pT = tl.where(i_mask[None, :], b_pT, 0.0)
|
||||||
|
|
||||||
# $$dv_j = \sum_i P_{ij} do_i$$
|
# $dV_j = \sum_i P_{ij} dO_i$
|
||||||
b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
|
b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
|
||||||
b_dv += tl.dot(b_pT.to(b_do.dtype), b_do, out_dtype=HI_PRES_TL)
|
b_dv += tl.dot(b_pT.to(b_do.dtype), b_do, out_dtype=HI_PRES_TL)
|
||||||
|
|
||||||
# $$dk_j = \sum_i dS_{ij} q_i = \sum_i P_{ij} \big( dP^T_{i:} - D_i \big) q_i$$
|
# $D_i$
|
||||||
b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
|
b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
|
||||||
# $dP_{ij} = do^T_i v_j$
|
# $dP_{ij} = V_j dO_i^T$
|
||||||
b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)
|
b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)
|
||||||
# $dS_{ij} = P_{ij} \big( dP_{i:} - D_i \big)$
|
# $dS_{ij} = P_{ij} \big( dP_{ij} - D_i \big)$
|
||||||
b_dsT = b_pT * (b_dpT - b_pdp[None, :])
|
b_dsT = b_pT * (b_dpT - b_pdp[None, :])
|
||||||
# $dk_j = \sum_i dS_{ij} q_i$
|
# $\frac{1}{\sigma} dK_j = \sum_i dS_{ij} Q_i$
|
||||||
b_dk += tl.dot(b_dsT.to(b_qT.dtype), tl.trans(b_qT), out_dtype=HI_PRES_TL)
|
b_dk += tl.dot(b_dsT.to(b_qT.dtype), tl.trans(b_qT), out_dtype=HI_PRES_TL)
|
||||||
|
|
||||||
# Increment pointers.
|
# Increment pointers.
|
||||||
@@ -717,7 +738,7 @@ def _attn_bwd_dkdv_inner(b_dk, b_dv,
|
|||||||
p_qT = tl.advance(p_qT, (0, BLOCK_Q))
|
p_qT = tl.advance(p_qT, (0, BLOCK_Q))
|
||||||
p_do = tl.advance(p_do, (BLOCK_Q, 0))
|
p_do = tl.advance(p_do, (BLOCK_Q, 0))
|
||||||
|
|
||||||
# Return accumulated $dk$ and $dv$
|
# Return accumulated $dK$ and $dV$
|
||||||
return b_dk, b_dv
|
return b_dk, b_dv
|
||||||
|
|
||||||
|
|
||||||
@@ -784,14 +805,15 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
|
|||||||
(BLOCK_Q,),
|
(BLOCK_Q,),
|
||||||
(0,))
|
(0,))
|
||||||
|
|
||||||
|
# Load $Q_i$, $dO_i$, $D_i$, and $\log_2 L_i$ outside the loop
|
||||||
b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
|
b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
|
||||||
b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
|
b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
|
||||||
b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
|
b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
|
||||||
|
|
||||||
b_dq = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
|
|
||||||
|
|
||||||
b_lse = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
|
b_lse = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
|
||||||
|
|
||||||
|
# Initialize $(\log_2 e) dQ$
|
||||||
|
b_dq = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
|
||||||
|
|
||||||
# $$dq_i = \sum_j dS_{ij} k_j = \sum_j P_{ij} \big( dP_{ij} - D_i \big) k_j$$
|
# $$dq_i = \sum_j dS_{ij} k_j = \sum_j P_{ij} \big( dP_{ij} - D_i \big) k_j$$
|
||||||
|
|
||||||
if is_causal:
|
if is_causal:
|
||||||
@@ -806,7 +828,7 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
|
|||||||
kv_seq_len=kv_seq_len
|
kv_seq_len=kv_seq_len
|
||||||
)
|
)
|
||||||
|
|
||||||
# Other blocks
|
# Compute for the remaining (unmasked) blocks
|
||||||
b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
|
b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
|
||||||
b_do, b_lse, b_pdp,
|
b_do, b_lse, b_pdp,
|
||||||
BLOCK_Q, BLOCK_K,
|
BLOCK_Q, BLOCK_K,
|
||||||
@@ -817,6 +839,7 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
|
|||||||
kv_seq_len=kv_seq_len
|
kv_seq_len=kv_seq_len
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# Iterate through all $K$
|
||||||
b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
|
b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
|
||||||
b_do, b_lse, b_pdp,
|
b_do, b_lse, b_pdp,
|
||||||
BLOCK_Q, BLOCK_K,
|
BLOCK_Q, BLOCK_K,
|
||||||
@@ -827,11 +850,10 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
|
|||||||
kv_seq_len=kv_seq_len
|
kv_seq_len=kv_seq_len
|
||||||
)
|
)
|
||||||
|
|
||||||
# Since $k$ was scaled by $\frac{1}{log_e 2}$, and $dq_j = \sum_j dS_{ij} k_j$
|
# `b_dq` stores $(\log_2 e) dQ$, so multiply by $\log_e 2$ to get $dQ$
|
||||||
# got this factor in to computed $dq$ we need to reverse it.
|
|
||||||
b_dq *= LN2
|
b_dq *= LN2
|
||||||
|
|
||||||
# Save $dq$
|
# Save $dQ$
|
||||||
tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))
|
tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))
|
||||||
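The final `b_dq *= LN2` undoes the base-2 bookkeeping: $dQ$ was accumulated against keys pre-scaled by $\sigma \log_2 e$, and multiplying by $\ln 2$ cancels the $\log_2 e$, leaving the correct $\sigma$ factor. As a check (example scale ours):

```python
import math

RCP_LN2 = 1.4426950408889634  # log2(e), folded into k_scaled
LN2 = math.log(2.0)           # final correction factor
sigma = 0.125                 # softmax scale (example value)

# (sigma * log2(e)) * ln(2) == sigma
assert math.isclose(sigma * RCP_LN2 * LN2, sigma, rel_tol=1e-12)
```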
|
|
||||||
|
|
||||||
|