backward pass formulas

Varuna Jayasiri
2025-08-01 13:24:57 +05:30
parent a9b5c923eb
commit 5a8182d21b
3 changed files with 1021 additions and 830 deletions


@@ -1086,7 +1086,7 @@
<url>
<loc>https://nn.labml.ai/transformers/flash/test.html</loc>
<lastmod>2025-07-31T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>

File diff suppressed because one or more lines are too long


@@ -4,7 +4,7 @@
## Forward pass
\begin{align}
S_{ij} &= \sigma Q_i K_j^T
\\
L_i &= \sum_j e^{S_{ij}}
\\
@@ -20,7 +20,7 @@ by computing the sum of exponents $l_i$ and the unnormalized output $\tilde{O}_i
while iterating over keys:
\begin{align}
S_{ij} &= \sigma Q_i K_j^T
\\
l_i &\leftarrow l_i + e^{S_{ij}}
\\
@@ -50,6 +50,7 @@ l_i &\leftarrow e^{m_i - m_{i}^{\text{new}}} l_i + \sum_{j=j1}^{j2} \tilde{P}_{i
\\
\tilde{O}_i &\leftarrow e^{m_i - m_{i}^{\text{new}}} \tilde{O}_i + \tilde{P}_{ij} V_j
\\
m_i &\leftarrow m_{i}^{\text{new}}
\end{align}
Then finally,
@@ -69,9 +70,9 @@ dS_{ij} &= d\text{softmax}(dP_{ij})
\\
&= P_{ij} dP_{ij} - P_{ij} \sum_k P_{ik} dP_{ik}
\\
dQ_i &= \sigma \sum_j dS_{ij} K_j
\\
dK_j &= \sigma \sum_i dS_{ij} Q_i
\end{align}
where $\delta_{jk}$ is $1$ when $j = k$ and $0$ otherwise.
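As a sanity check on these formulas, here is a minimal PyTorch sketch (not part of this commit) that compares the manual gradients against autograd on a tiny example; the sizes and the scale $\sigma = d^{-1/2}$ are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
n, d = 5, 4
sigma = d ** -0.5  # assumed softmax scale
q = torch.randn(n, d, dtype=torch.float64, requires_grad=True)
k = torch.randn(n, d, dtype=torch.float64, requires_grad=True)
v = torch.randn(n, d, dtype=torch.float64, requires_grad=True)

# Forward: S = sigma Q K^T, P = softmax(S), O = P V
p = torch.softmax(sigma * q @ k.T, dim=-1)
o = p @ v
do = torch.randn_like(o)
o.backward(do)

# Backward, following the formulas above
dp = do @ v.T                                    # dP_ij = dO_i V_j^T
ds = p * (dp - (p * dp).sum(-1, keepdim=True))   # dS_ij = P_ij (dP_ij - sum_k P_ik dP_ik)
assert torch.allclose(sigma * ds @ k, q.grad)    # dQ_i = sigma sum_j dS_ij K_j
assert torch.allclose(sigma * ds.T @ q, k.grad)  # dK_j = sigma sum_i dS_ij Q_i
assert torch.allclose(p.T @ do, v.grad)          # dV_j = sum_i P_ij dO_i
```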
@@ -144,7 +145,7 @@ class AttentionFunc(torch.autograd.Function):
# Tensor for the output
o = torch.empty_like(q)
# Tensor for log of sum of exponentials $\log_2 L_i = \log_2 \sum_j e^{S_{ij}}$
lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)
# The forward computation will be parallelized along the batch dimension and the queries in blocks of size `BLOCK_Q`
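For intuition, a sketch of that launch geometry with made-up sizes (in the actual code `BLOCK_Q` comes from the autotuner):

```python
import math

# Hypothetical sizes, for illustration only
batch_size, k_heads, q_seq_len = 2, 8, 1000
BLOCK_Q = 64  # the real kernel gets this from autotuning

# One program per (query block, batch * key-head) pair,
# mirroring triton.cdiv(q_seq_len, BLOCK_Q) in the launch grid
grid = (math.ceil(q_seq_len / BLOCK_Q), batch_size * k_heads)
print(grid)  # (16, 16)
```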
@@ -200,17 +201,18 @@ class AttentionFunc(torch.autograd.Function):
dk = torch.empty_like(k)
dv = torch.empty_like(v)
# $\log_2 e$
RCP_LN2 = 1.4426950408889634
# Precompute $\sigma (\log_2 e) K_j$
k_scaled = k * (sm_scale * RCP_LN2)
# $D_i = P_{i:}^T dP_{i:} = dO_i^T O_i$
pdp = torch.empty_like(lse)
# We use fixed `BLOCK_Q` for backward pass on $D$
BLOCK_Q = 16
# Compute $D_i$
#
# This is parallelized along the batch and query in blocks of size `BLOCK_Q`
pre_grid = (triton.cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads)
_attn_bwd_d[pre_grid](
o, do,
@@ -221,6 +223,7 @@ class AttentionFunc(torch.autograd.Function):
n_groups=n_groups,
num_stages=1,
)
# Compute $dK$ and $dV$
#
# This is parallelized along the batch and keys in blocks of size `BLOCK_K`
@@ -232,6 +235,7 @@ class AttentionFunc(torch.autograd.Function):
is_causal=causal,
)
# Compute $dQ$
#
# This is parallelized along the batch and queries in blocks of size `BLOCK_Q`
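The base-2 bookkeeping above ($\log_2 e$ folded into $K$) works because $e^x = 2^{x \log_2 e}$ and `exp2` is cheaper than `exp` on GPUs. A small sketch of the same trick in plain PyTorch (illustrative sizes and scale, not from the commit):

```python
import torch

RCP_LN2 = 1.4426950408889634  # log2(e) = 1 / ln(2)
sigma = 0.125                 # hypothetical softmax scale
q = torch.randn(4, 8, dtype=torch.float64)
k = torch.randn(4, 8, dtype=torch.float64)

# Fold sigma * log2(e) into K once, like `k_scaled` above
s2 = q @ (k * (sigma * RCP_LN2)).T    # stores S / ln 2
m = s2.max(dim=-1, keepdim=True).values
p_tilde = torch.exp2(s2 - m)          # e^{S - max S}, computed via exp2
p = p_tilde / p_tilde.sum(dim=-1, keepdim=True)
assert torch.allclose(p, torch.softmax(sigma * q @ k.T, dim=-1))
```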
@@ -351,23 +355,31 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
# Initialize offsets
offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
offs_j = tl.arange(0, BLOCK_K)
# Mask for $Q$ for the last block
i_mask = offs_i < q_seq_len
# Precalculate $\frac{\sigma}{\log 2}$.
#
# We will be using this when calculating $S_{ij}$, so `b_s` will store $\frac{S_{ij}}{\log 2}$ instead.
sm_scale = sm_scale * 1.44269504
# Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\infty$ and $l_i$ to $1$. So in the first update,
# the effect of initial $l_i$ is $e^{m_i - m_{i}^{\text{new}}} l_i = 0$.
#
# `b_m` will be storing $\frac{m_i}{\log 2}$
b_m = tl.where(i_mask, -float("inf"), 0.0)
b_l = tl.where(i_mask, 1.0, 0.0)
# $\tilde{O}_i$
b_o = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
# Load $Q_i$ outside the loop since it will be reused throughout the loop over $K_j$.
b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
if is_causal:
# Inner loop up to the diagonal block
b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q,
p_kT, p_v,
sm_scale,
BLOCK_Q, d_head, BLOCK_K,
@@ -379,7 +391,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
kv_seq_len=kv_seq_len
)
# Diagonal block with masking within it
b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
sm_scale,
BLOCK_Q, d_head, BLOCK_K,
offs_i, offs_j,
@@ -390,7 +402,8 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
kv_seq_len=kv_seq_len
)
else:
# Iterate through all $K_j$
b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
sm_scale,
BLOCK_Q, d_head, BLOCK_K,
offs_i, offs_j,
@@ -401,13 +414,14 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
kv_seq_len=kv_seq_len
)
# Store LSE $\log_2 L_i = \log_2 \big( l_i e^{m_i} \big) = \log_2 l_i + \frac{m_i}{\log 2}$
tl.store(p_lse, b_m + tl.math.log2(b_l), boundary_check=(0,))
# Store $O_i = \frac{\tilde{O}_i}{l_i}$
tl.store(p_o, (b_o / b_l[:, None]).to(t_o.type.element_ty), boundary_check=(0,))
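A plain PyTorch sketch of what this kernel computes per query row: one pass over key blocks with the running rescale, then the final normalization and LSE. Sizes are made up, and the sketch works in base $e$ with a conversion to base 2 at the end, where the kernel keeps everything in base 2 throughout.

```python
import math
import torch

torch.manual_seed(0)
n_q, n_k, d, blk = 3, 12, 4, 4
s = torch.randn(n_q, n_k, dtype=torch.float64)   # stand-in for sigma Q K^T
v = torch.randn(n_k, d, dtype=torch.float64)

m = torch.full((n_q,), -math.inf, dtype=torch.float64)
l = torch.ones(n_q, dtype=torch.float64)         # first rescale zeroes this out
o_t = torch.zeros(n_q, d, dtype=torch.float64)   # unnormalized output O~
for j in range(0, n_k, blk):
    s_blk = s[:, j:j + blk]
    m_new = torch.maximum(m, s_blk.max(dim=-1).values)
    alpha = torch.exp(m - m_new)                 # e^{m_i - m_i^new}
    p_t = torch.exp(s_blk - m_new[:, None])      # P~_ij
    l = l * alpha + p_t.sum(dim=-1)
    o_t = o_t * alpha[:, None] + p_t @ v
    m = m_new

assert torch.allclose(o_t / l[:, None], torch.softmax(s, dim=-1) @ v)
lse2 = (m + l.log()) / math.log(2)               # log2 L_i = log2 l_i + m_i / ln 2
assert torch.allclose(lse2, s.logsumexp(dim=-1) / math.log(2))
```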
@triton.jit
def _attn_fwd_inner(b_o, b_l, b_m, b_q,
p_kT, p_v,
scale,
BLOCK_Q: tl.constexpr,
@@ -422,45 +436,46 @@ def _attn_fwd_inner(b_o, b_l, b_m, b_q,
):
tl.static_assert(BLOCK_Q % BLOCK_K == 0)
# Move $K_j$ and $V_j$ pointers
p_kT = tl.advance(p_kT, (0, j))
p_v = tl.advance(p_v, (j, 0))
# Iterate over $K$, $V$ and update $\tilde{O}_i$ and $l_i$
for _ in range(steps):
# Load $K_j^T$
b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
# Compute $(\log_2 e) S_{ij} = \sigma (\log_2 e) Q_i K_j^T$
b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
tl.static_assert(b_s.dtype == HI_PRES_TL)
b_s = b_s * scale
# Apply causal mask
if MASK:
causal_mask = offs_i[:, None] >= (j + offs_j[None, :])
b_s = tl.where(causal_mask, b_s, -float("inf"))
# Mask out entries beyond the end of the key sequence
j_mask = (j + offs_j) < kv_seq_len
b_s = tl.where(j_mask[None, :], b_s, -float("inf"))
# $m_{i}^{\text{new}} = \max(m_i, \text{rowmax}(S_{ij}))$
tl.static_assert(len(b_s.shape) == 2)
b_m_new = tl.maximum(b_m, tl.max(b_s, -1))
# $\tilde{P}_{ij} = \exp(S_{ij} - m_i^{\text{new}})$
b_p = tl.math.exp2(b_s - b_m_new[:, None])
# $\sum_{j=j1}^{j2} \tilde{P}_{ij}$
b_l_new = tl.sum(b_p, -1)
# $e^{m_i - m_{i}^{\text{new}}}$
b_m_m_new = tl.math.exp2(b_m - b_m_new)
# $l_i \leftarrow e^{m_i - m_{i}^{\text{new}}} l_i + \sum_{j=j1}^{j2} \tilde{P}_{ij}$
b_l = b_l * b_m_m_new + b_l_new
# $\tilde{O}_i \leftarrow e^{m_i - m_{i}^{\text{new}}} \tilde{O}_i + \tilde{P}_{ij} V_j$
b_o = b_o * b_m_m_new[:, None]
b_p = b_p.to(b_q.dtype)  # TODO
b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
b_o += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)
# $m_i \leftarrow m_{i}^{\text{new}}$
b_m = b_m_new
# Move pointers
@@ -468,9 +483,9 @@ def _attn_fwd_inner(b_acc, b_l, b_m, b_q,
p_v = tl.advance(p_v, (BLOCK_K, 0))
p_kT = tl.advance(p_kT, (0, BLOCK_K))
tl.static_assert(b_o.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
return b_o, b_l, b_m
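The reason for the causal split in `_attn_fwd`: key blocks strictly left of the diagonal are fully visible to the whole query block, so only the diagonal block needs the per-element mask. A tiny sketch with hypothetical block offsets:

```python
import torch

# One query block covering rows [4, 8), BLOCK_Q = BLOCK_K = 4 (illustrative)
offs_i = torch.arange(4, 8)
j_before = torch.arange(0, 4)  # key block fully before the diagonal
j_diag = torch.arange(4, 8)    # key block on the diagonal

print((offs_i[:, None] >= j_before[None, :]).all().item())  # True: no mask needed
print(offs_i[:, None] >= j_diag[None, :])                   # mixed: mask inside the block
```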
@triton.jit
@@ -503,9 +518,13 @@ def _attn_bwd_d(t_o, t_do,
(n_groups, BLOCK_Q),
(1, 0))
# Load $O_i$
o = tl.load(p_o, boundary_check=(1,), padding_option="zero")
# Load $dO_i$
do = tl.load(p_do, boundary_check=(1,), padding_option="zero").to(HI_PRES_TL)
# Calculate $D_i = dO_i O_i^T$
d = tl.sum(o * do, axis=-1)
# Save $D_i$
tl.store(p_pdp, d, boundary_check=(1,))
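Why the rowwise dot of $O$ and $dO$ equals $D_i = \sum_j P_{ij} dP_{ij}$: since $O_i = \sum_j P_{ij} V_j$ and $dP_{ij} = dO_i V_j^T$, the sum telescopes to $dO_i \cdot O_i$, so this kernel never has to materialize $P$ or $dP$. A quick check (illustrative sizes, not from the commit):

```python
import torch

torch.manual_seed(0)
n, d = 5, 4
p = torch.softmax(torch.randn(n, n, dtype=torch.float64), dim=-1)
v = torch.randn(n, d, dtype=torch.float64)
do = torch.randn(n, d, dtype=torch.float64)

o = p @ v
dp = do @ v.T
# D_i = sum_j P_ij dP_ij = dO_i . O_i
assert torch.allclose((p * dp).sum(dim=-1), (o * do).sum(dim=-1))
```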
@@ -523,12 +542,13 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
BLOCK_K: tl.constexpr,
):
"""
Compute $dK_j$ and $dV_j$ for $j1 \dots j2$ by iterating over $Q_i$
"""
# $K$ is already multiplied by $\frac{\sigma}{\log 2}$
j = tl.program_id(0) * BLOCK_K
z = tl.program_id(1)
# Create block pointers
p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
(kv_seq_len, d_head),
(d_head, 1),
@@ -554,14 +574,15 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
(BLOCK_K, d_head),
(1, 0))
# Initialize $\frac{1}{\sigma} dK$ and $dV$
b_dk = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
b_dv = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
# Load $\frac{\sigma}{\log 2} K$ and $V$ outside the loop.
b_k = tl.load(p_k, boundary_check=(0,), padding_option="zero")
b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
# Iterate through the query heads in the group (GQA)
for g in range(n_groups):
# Create block pointers
p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
@@ -590,19 +611,12 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
(BLOCK_Q,),
(0,))
if is_causal:
# Inner loop at the diagonal block
b_dk, b_dv = _attn_bwd_dkdv_inner(
b_dk, b_dv,
p_qT, b_k, b_v, p_do,
p_lse, p_pdp,
# You can use a smaller BLOCK_Q if BLOCK_K is not divisible by BLOCK_Q
BLOCK_Q, BLOCK_K,
d_head,
j=j, i=j,
@@ -612,7 +626,7 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
kv_seq_len=kv_seq_len,
)
# Inner loop over queries after the diagonal
b_dk, b_dv = _attn_bwd_dkdv_inner(
b_dk, b_dv,
p_qT, b_k, b_v, p_do,
@@ -626,6 +640,7 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
kv_seq_len=kv_seq_len
)
else:
# Iterate through all queries
b_dk, b_dv = _attn_bwd_dkdv_inner(
b_dk, b_dv,
p_qT, b_k, b_v, p_do,
@@ -639,14 +654,13 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
kv_seq_len=kv_seq_len
)
# Save $dV$
tl.store(p_dv, b_dv.to(t_dv.type.element_ty), boundary_check=(0,))
# `b_dk` holds $\frac{1}{\sigma} dK$, so multiply by $\sigma$ to get $dK$
b_dk *= sm_scale
# Save $dK$
tl.store(p_dk, b_dk.to(t_dk.type.element_ty), boundary_check=(0,))
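Note that this kernel never reads $P$ from memory; it recomputes it from $Q$, $K$ and the $\log_2 L_i$ saved by the forward pass, exactly as in `exp2(b_sT - b_l)` below. A sketch of that recomputation (illustrative sizes, scale assumed to be $d^{-1/2}$):

```python
import torch

torch.manual_seed(0)
n, d = 6, 4
sigma = d ** -0.5
RCP_LN2 = 1.4426950408889634  # log2(e)
q = torch.randn(n, d, dtype=torch.float64)
k = torch.randn(n, d, dtype=torch.float64)

s = sigma * q @ k.T
lse2 = s.logsumexp(dim=-1) * RCP_LN2  # log2 L_i, as stored by the forward pass
# P recomputed from the logits and the saved LSE -- no O(n^2) state between passes
p = torch.exp2(s * RCP_LN2 - lse2[:, None])
assert torch.allclose(p, torch.softmax(s, dim=-1))
```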
@@ -670,44 +684,51 @@ def _attn_bwd_dkdv_inner(b_dk, b_dv,
i_mask = offs_i < q_seq_len
offs_j = j + tl.arange(0, BLOCK_K)
# Move the pointers
p_qT = tl.advance(p_qT, (0, i))
p_do = tl.advance(p_do, (i, 0))
p_lse = tl.advance(p_lse, (i,))
p_pdp = tl.advance(p_pdp, (i,))
# Iterate over $Q$
for _ in range(steps):
# Load $Q_i^T$
b_qT = tl.load(p_qT, boundary_check=(1,), padding_option="zero")
# $\log_2 L_i$
b_l = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
# $(\log_2 e) S_{ij}^T = \sigma (\log_2 e) K_j Q_i^T$
b_sT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)
# \begin{align}
# P_{ij} &= \frac{e^{S_{ij}}}{L_i}
# \\
# &= \frac{2^{(\log_2 e) S_{ij}}}{2^{\log_2 L_i}}
# \\
# &= 2^{(\log_2 e) S_{ij} - \log_2 L_i}
# \end{align}
b_pT = tl.math.exp2(b_sT - b_l[None, :])
# Autoregressive masking.
if MASK:
mask = (offs_i[None, :] >= offs_j[:, None])
b_pT = tl.where(mask, b_pT, 0.0)
# Mask out entries beyond the end of the query sequence
b_pT = tl.where(i_mask[None, :], b_pT, 0.0)
# $dV_j = \sum_i P_{ij} dO_i$
b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
b_dv += tl.dot(b_pT.to(b_do.dtype), b_do, out_dtype=HI_PRES_TL)
# $D_i$
b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
# $dP_{ij} = V_j dO_i^T$
b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)
# $dS_{ij} = P_{ij} \big( dP_{ij} - D_i \big)$
b_dsT = b_pT * (b_dpT - b_pdp[None, :])
# $\frac{1}{\sigma} dK_j = \sum_i dS_{ij} Q_i$
b_dk += tl.dot(b_dsT.to(b_qT.dtype), tl.trans(b_qT), out_dtype=HI_PRES_TL)
# Increment pointers.
@@ -717,7 +738,7 @@ def _attn_bwd_dkdv_inner(b_dk, b_dv,
p_qT = tl.advance(p_qT, (0, BLOCK_Q))
p_do = tl.advance(p_do, (BLOCK_Q, 0))
# Return accumulated $dK$ and $dV$
return b_dk, b_dv
@@ -784,14 +805,15 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
(BLOCK_Q,),
(0,))
# Load $Q_i$, $dO_i$, $D_i$, and $\log_2 L_i$ outside the loop
b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
b_lse = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
# Initialize $(\log_2 e) dQ$
b_dq = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
# $$dq_i = \sum_j dS_{ij} k_j = \sum_j P_{ij} \big( dP_{ij} - D_i \big) k_j$$
if is_causal:
@@ -806,7 +828,7 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
kv_seq_len=kv_seq_len
)
# Compute for the remaining non-masked blocks
b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
b_do, b_lse, b_pdp,
BLOCK_Q, BLOCK_K,
@@ -817,6 +839,7 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
kv_seq_len=kv_seq_len
)
else:
# Iterate through all $K$
b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
b_do, b_lse, b_pdp,
BLOCK_Q, BLOCK_K,
@@ -827,11 +850,10 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
kv_seq_len=kv_seq_len
)
# `b_dq` stores $(\log_2 e) dQ$, so multiply by $\log_e 2$ to get $dQ$
b_dq *= LN2
# Save $dQ$
tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))
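And a final check on the $dQ$ bookkeeping (a sketch with assumed sizes, not from the commit): accumulating $dS$ against the pre-scaled keys `k_scaled` yields $(\log_2 e)\, dQ$, and the trailing multiply by $\ln 2$ recovers $dQ$ exactly:

```python
import torch

torch.manual_seed(0)
n, d = 6, 4
sigma = d ** -0.5
RCP_LN2 = 1.4426950408889634  # log2(e)
LN2 = 0.6931471805599453      # ln(2)
q = torch.randn(n, d, dtype=torch.float64, requires_grad=True)
k = torch.randn(n, d, dtype=torch.float64)
v = torch.randn(n, d, dtype=torch.float64)

p = torch.softmax(sigma * q @ k.T, dim=-1)
o = p @ v
do = torch.randn_like(o)
o.backward(do)

dp = do @ v.T
ds = p * (dp - (p * dp).sum(dim=-1, keepdim=True))
k_scaled = k * (sigma * RCP_LN2)  # what the host code passes to the kernel
dq = (ds @ k_scaled) * LN2        # (log2 e) dQ, then undo the base change
assert torch.allclose(dq, q.grad)
```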