mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-06 15:22:21 +08:00
backward pass formulas
This commit is contained in:
@ -1086,7 +1086,7 @@
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/transformers/flash/test.html</loc>
|
||||
<lastmod>2025-07-30T16:30:00+00:00</lastmod>
|
||||
<lastmod>2025-07-31T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
File diff suppressed because one or more lines are too long
@ -4,7 +4,7 @@
|
||||
## Forward pass
|
||||
|
||||
\begin{align}
|
||||
S_{ij} &= q_i k_j^T
|
||||
S_{ij} &= \sigma Q_i K_j^T
|
||||
\\
|
||||
L_i &= \sum_j e^{S_{ij}}
|
||||
\\
|
||||
@ -20,7 +20,7 @@ by computing the sum of exponents $l_i$ and the unnormalized output $\tilde{O}_i
|
||||
while iterating over keys:
|
||||
|
||||
\begin{align}
|
||||
S_{ij} &= Q_i K_j^T
|
||||
S_{ij} &= \sigma Q_i K_j^T
|
||||
\\
|
||||
l_i &\leftarrow l_i + e^{S_{ij}}
|
||||
\\
|
||||
@ -50,6 +50,7 @@ l_i &\leftarrow e^{m_i - m_{i}^{\text{new}}} l_i + \sum_{j=j1}^{j2} \tilde{P}_{i
|
||||
\\
|
||||
\tilde{O}_i &\leftarrow e^{m_i - m_{i}^{\text{new}}} \tilde{O}_i + \tilde{P}_{ij} * V_j
|
||||
\\
|
||||
m_i &\leftarrow m_{i}^{\text{new}}
|
||||
\end{align}
|
||||
|
||||
Then finally,
|
||||
@ -69,9 +70,9 @@ dS_{ij} &= d\text{softmax}(dP_{ij})
|
||||
\\
|
||||
&= P_{ij} dP_{ij} - P_{ij} \sum_k P_{ik} dP_{ik}
|
||||
\\
|
||||
dQ_i &= \sum_j dS_{ij} K_j
|
||||
dQ_i &= \sigma \sum_j dS_{ij} K_j
|
||||
\\
|
||||
qK_j &= \sum_i dS_{ij} Q_i
|
||||
dK_j &= \sigma \sum_i dS_{ij} Q_i
|
||||
\end{align}
|
||||
|
||||
where $\delta_{jk}$ is $1$ when $j = k$ and $0$ otherwise.
|
||||
@ -144,7 +145,7 @@ class AttentionFunc(torch.autograd.Function):
|
||||
|
||||
# Tensor for the output
|
||||
o = torch.empty_like(q)
|
||||
# Tensor for $\log_2 \sum_j e^{S_{ij}}$
|
||||
# Tensor for log of sum of exponentials $\log_2 L_i = \log_2 \sum_j e^{S_{ij}}$
|
||||
lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)
|
||||
|
||||
# The forward computation will be parallelized along the batch dimension and the queries in blocks of size `BLOCK_Q`
|
||||
@ -200,17 +201,18 @@ class AttentionFunc(torch.autograd.Function):
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
|
||||
# $\frac{1}{\log_e 2}$
|
||||
# $\log_2 e$
|
||||
RCP_LN2 = 1.4426950408889634
|
||||
# Multiply $k$ by softmax scale
|
||||
# Precompute $\sigma (\log_2 e) K_j$
|
||||
k_scaled = k * (sm_scale * RCP_LN2)
|
||||
# $D_i = P^T_{i:}dP_{i:} = do^T_io_i$
|
||||
pdp = torch.empty_like(lse)
|
||||
# We use fixed `BLOCK_Q` for backward pass on $D$
|
||||
BLOCK_Q = 16
|
||||
|
||||
# Compute $D_i$
|
||||
#
|
||||
# This is parallelized along the batch and query in blocks of size `BLOCK_Q`
|
||||
BLOCK_Q = 16
|
||||
pre_grid = (triton.cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads)
|
||||
_attn_bwd_d[pre_grid](
|
||||
o, do,
|
||||
@ -221,6 +223,7 @@ class AttentionFunc(torch.autograd.Function):
|
||||
n_groups=n_groups,
|
||||
num_stages=1,
|
||||
)
|
||||
|
||||
# Compute $dK$ and $dV$
|
||||
#
|
||||
# This is parallelized along the batch and keys in blocks of size `BLOCK_K`
|
||||
@ -232,6 +235,7 @@ class AttentionFunc(torch.autograd.Function):
|
||||
is_causal=causal,
|
||||
|
||||
)
|
||||
|
||||
# Compute $dQ$
|
||||
#
|
||||
# This is parallelized along the batch and queries in blocks of size `BLOCK_Q`
|
||||
@ -351,23 +355,31 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
|
||||
|
||||
# Initialize offsets
|
||||
offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
|
||||
i_mask = offs_i < q_seq_len
|
||||
offs_j = tl.arange(0, BLOCK_K)
|
||||
# Mask for $Q$ for the last block
|
||||
i_mask = offs_i < q_seq_len
|
||||
|
||||
# Initialize $m_i$ and $l_i$
|
||||
# Precalculate $\frac{\sigma}{\log 2}$.
|
||||
#
|
||||
# We will use this when calculating $S_{ij}$, so `S` will store $S_{ij} \log_2 e$ instead.
|
||||
sm_scale = sm_scale * 1.44269504
|
||||
|
||||
# Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\infty$ and $l_i$ to $1$. So in the first update,
|
||||
# the effect of initial $l_i$ is $e^{m_i - m_{i}^{\text{new}}} l_i = 0$.
|
||||
#
|
||||
# `b_m` will be storing $m_i \log_2 e$
|
||||
b_m = tl.where(i_mask, -float("inf"), 0.0)
|
||||
b_l = tl.where(i_mask, 1.0, 0.0)
|
||||
# Accumulate $O$
|
||||
b_acc = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
|
||||
|
||||
# softmax scale / log(2)
|
||||
sm_scale = sm_scale * 1.44269504
|
||||
# Load $Q_i$
|
||||
# $O_i$
|
||||
b_o = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
|
||||
|
||||
# Load $Q_i$ outside the loop since it will be reused throughout the loop over $K_j$.
|
||||
b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
|
||||
|
||||
if is_causal:
|
||||
# Upto the diagonal block
|
||||
b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q,
|
||||
# Inner loop up to the diagonal block
|
||||
b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q,
|
||||
p_kT, p_v,
|
||||
sm_scale,
|
||||
BLOCK_Q, d_head, BLOCK_K,
|
||||
@ -379,7 +391,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
|
||||
kv_seq_len=kv_seq_len
|
||||
)
|
||||
# Diagonal block with masking within it
|
||||
b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
|
||||
b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
|
||||
sm_scale,
|
||||
BLOCK_Q, d_head, BLOCK_K,
|
||||
offs_i, offs_j,
|
||||
@ -390,7 +402,8 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
|
||||
kv_seq_len=kv_seq_len
|
||||
)
|
||||
else:
|
||||
b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
|
||||
# Iterate through all $K_j$
|
||||
b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
|
||||
sm_scale,
|
||||
BLOCK_Q, d_head, BLOCK_K,
|
||||
offs_i, offs_j,
|
||||
@ -401,13 +414,14 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
|
||||
kv_seq_len=kv_seq_len
|
||||
)
|
||||
|
||||
# Update LSE
|
||||
# Store LSE $\log_2 L_i = \log_2 \big( l_i * e^{m_i} \big) = \log_2 l_i + m_i \log_2 e$
|
||||
tl.store(p_lse, b_m + tl.math.log2(b_l), boundary_check=(0,))
|
||||
tl.store(p_o, (b_acc / b_l[:, None]).to(t_o.type.element_ty), boundary_check=(0,))
|
||||
# Store $O_i = \frac{\tilde{O}_i}{l_i}$
|
||||
tl.store(p_o, (b_o / b_l[:, None]).to(t_o.type.element_ty), boundary_check=(0,))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _attn_fwd_inner(b_acc, b_l, b_m, b_q,
|
||||
def _attn_fwd_inner(b_o, b_l, b_m, b_q,
|
||||
p_kT, p_v,
|
||||
scale,
|
||||
BLOCK_Q: tl.constexpr,
|
||||
@ -422,45 +436,46 @@ def _attn_fwd_inner(b_acc, b_l, b_m, b_q,
|
||||
):
|
||||
tl.static_assert(BLOCK_Q % BLOCK_K == 0)
|
||||
|
||||
# Move $K_j$ and $V_j$ pointers
|
||||
p_kT = tl.advance(p_kT, (0, j))
|
||||
p_v = tl.advance(p_v, (j, 0))
|
||||
|
||||
# loop over k, v and update accumulator
|
||||
# Iterate over $K$, $V$ and update $\tilde{O}_i$ and $l_i$
|
||||
for _ in range(steps):
|
||||
current_j = j + offs_j
|
||||
j_mask = current_j < kv_seq_len
|
||||
|
||||
# Load $K_j^T$
|
||||
b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
|
||||
# Compute $(\log_2 e) S_{ij} = (\log_2 e) \sigma Q_i K_j^T$
|
||||
b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
|
||||
|
||||
tl.static_assert(b_s.dtype == HI_PRES_TL)
|
||||
b_s = b_s * scale
|
||||
|
||||
# Apply causal mask
|
||||
if MASK:
|
||||
causal_mask = offs_i[:, None] >= (j + offs_j[None, :])
|
||||
b_s = tl.where(causal_mask, b_s, -float("inf"))
|
||||
# always apply seq mask
|
||||
|
||||
# Mask out if the block is beyond the end of $K_j$
|
||||
j_mask = (j + offs_j) < kv_seq_len
|
||||
b_s = tl.where(j_mask[None, :], b_s, -float("inf"))
|
||||
|
||||
# $m_{i}^{\text{new}} = \max(m_i, \text{rowmax}(S_{ij}))$
|
||||
tl.static_assert(len(b_s.shape) == 2)
|
||||
b_m_new = tl.maximum(b_m, tl.max(b_s, -1))
|
||||
# $\tilde{P}_{ij} = \exp(S_{ij} - m_i^{\text{new}})$
|
||||
b_p = tl.math.exp2(b_s - b_m_new[:, None])
|
||||
# $\tilde{l}_ij = \text{rowsum}(\tilde{P}_{ij})$
|
||||
b_l_new = tl.sum(b_p, -1)
|
||||
|
||||
# $\sum_{j=j1}^{j2} \tilde{P}_{ij}$
|
||||
b_l_new = tl.sum(b_p, -1)
|
||||
# $e^{m_i - m_{i}^{\text{new}}}$
|
||||
b_m_m_new = tl.math.exp2(b_m - b_m_new)
|
||||
# $l_i \leftarrow e^{m_i - m_{i}^{\text{new}}} l_i + \tilde{l}_{ij}$
|
||||
# $l_i \leftarrow e^{m_i - m_{i}^{\text{new}}} l_i + \sum_{j=j1}^{j2} \tilde{P}_{ij}$
|
||||
b_l = b_l * b_m_m_new + b_l_new
|
||||
|
||||
# $O_i \leftarrow e^{m_i - m_{i}^{\text{new}}} O_i + \tilde{P}_{ij} * V_j$
|
||||
# $\tilde{O}_i \leftarrow e^{m_i - m_{i}^{\text{new}}} \tilde{O}_i + \tilde{P}_{ij} V_j$
|
||||
b_o = b_o * b_m_m_new[:, None]
|
||||
b_p = b_p.to(b_q.dtype) # TODO
|
||||
b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
|
||||
b_acc = b_acc * b_m_m_new[:, None]
|
||||
b_p = b_p.to(b_q.dtype)
|
||||
b_acc += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)
|
||||
b_o += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)
|
||||
|
||||
# update $m_i$
|
||||
# $m_i \leftarrow m_{i}^{\text{new}}$
|
||||
b_m = b_m_new
|
||||
|
||||
# Move pointers
|
||||
@ -468,9 +483,9 @@ def _attn_fwd_inner(b_acc, b_l, b_m, b_q,
|
||||
p_v = tl.advance(p_v, (BLOCK_K, 0))
|
||||
p_kT = tl.advance(p_kT, (0, BLOCK_K))
|
||||
|
||||
tl.static_assert(b_acc.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
|
||||
tl.static_assert(b_o.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
|
||||
|
||||
return b_acc, b_l, b_m
|
||||
return b_o, b_l, b_m
|
||||
|
||||
|
||||
@triton.jit
|
||||
@ -503,9 +518,13 @@ def _attn_bwd_d(t_o, t_do,
|
||||
(n_groups, BLOCK_Q),
|
||||
(1, 0))
|
||||
|
||||
# Load $O_i$
|
||||
o = tl.load(p_o, boundary_check=(1,), padding_option="zero")
|
||||
# Load $dO_i$
|
||||
do = tl.load(p_do, boundary_check=(1,), padding_option="zero").to(HI_PRES_TL)
|
||||
# Calculate $D_i = dO_i O_i^T$
|
||||
d = tl.sum(o * do, axis=-1)
|
||||
# Save $D_i$
|
||||
tl.store(p_pdp, d, boundary_check=(1,))
|
||||
|
||||
|
||||
@ -523,12 +542,13 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
|
||||
BLOCK_K: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Loop along m query; n % m == 0
|
||||
Compute $dK_j$ and $dV_j$ for $j1 \dots j2$ by iterating over $Q_i$
|
||||
"""
|
||||
# K is already multiplied by scale
|
||||
|
||||
j = tl.program_id(0) * BLOCK_K
|
||||
z = tl.program_id(1)
|
||||
|
||||
# Create block pointers
|
||||
p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
|
||||
(kv_seq_len, d_head),
|
||||
(d_head, 1),
|
||||
@ -554,14 +574,15 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
|
||||
(BLOCK_K, d_head),
|
||||
(1, 0))
|
||||
|
||||
b_dv = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
|
||||
# Initialize $\frac{1}{\sigma} dK$ and $dV$
|
||||
b_dk = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
|
||||
b_dv = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
|
||||
|
||||
# load K and V: they stay in SRAM throughout the inner loop.
|
||||
# Load $\frac{\sigma}{\log 2} K$ and $V$ outside the loop.
|
||||
b_k = tl.load(p_k, boundary_check=(0,), padding_option="zero")
|
||||
b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
|
||||
|
||||
# Iterate through queries that attend to save keys
|
||||
# Iterate through queries in GQA
|
||||
for g in range(n_groups):
|
||||
# Create block pointers
|
||||
p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
|
||||
@ -590,19 +611,12 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
|
||||
(BLOCK_Q,),
|
||||
(0,))
|
||||
|
||||
# $$dk_j = \sum_i dS_{ij} q_i = \sum_i P_{ij} \big( do_i^T v_j - D_i \big) q_i$$
|
||||
# $$dv_j = \sum_i P_{ij} do_i$$
|
||||
|
||||
# Compute $dk$ $dv$ and $dv$ along the masked blocks near diagonal.
|
||||
# Use smaller block size of MASK_BLOCK_Q
|
||||
# because there is a little extra computation?
|
||||
if is_causal:
|
||||
# loop along m
|
||||
# Inner loop at the diagonal block
|
||||
b_dk, b_dv = _attn_bwd_dkdv_inner(
|
||||
b_dk, b_dv,
|
||||
p_qT, b_k, b_v, p_do,
|
||||
p_lse, p_pdp,
|
||||
# You can use a smaller BLOCK_Q if BLOCK_K is not divisible by BLOCK_Q
|
||||
BLOCK_Q, BLOCK_K,
|
||||
d_head,
|
||||
j=j, i=j,
|
||||
@ -612,7 +626,7 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
|
||||
kv_seq_len=kv_seq_len,
|
||||
)
|
||||
|
||||
# Compute $dk$ and $dv$ for non-masked blocks.
|
||||
# Inner loop on queries after the diagonal
|
||||
b_dk, b_dv = _attn_bwd_dkdv_inner(
|
||||
b_dk, b_dv,
|
||||
p_qT, b_k, b_v, p_do,
|
||||
@ -626,6 +640,7 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
|
||||
kv_seq_len=kv_seq_len
|
||||
)
|
||||
else:
|
||||
# Iterate through all queries
|
||||
b_dk, b_dv = _attn_bwd_dkdv_inner(
|
||||
b_dk, b_dv,
|
||||
p_qT, b_k, b_v, p_do,
|
||||
@ -639,14 +654,13 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
|
||||
kv_seq_len=kv_seq_len
|
||||
)
|
||||
|
||||
# Save $dv$
|
||||
# Save $dV$
|
||||
tl.store(p_dv, b_dv.to(t_dv.type.element_ty), boundary_check=(0,))
|
||||
|
||||
# Since we used $k = \text{scale} * \hat{k}$ where $\hat{k} are the original keys
|
||||
# we multiple by scale again to get gradient on original keys.
|
||||
# `b_dk` had $\frac{1}{\sigma} dK$
|
||||
b_dk *= sm_scale
|
||||
|
||||
# Save $dk$
|
||||
# Save $dK$
|
||||
tl.store(p_dk, b_dk.to(t_dk.type.element_ty), boundary_check=(0,))
|
||||
|
||||
|
||||
@ -670,44 +684,51 @@ def _attn_bwd_dkdv_inner(b_dk, b_dv,
|
||||
i_mask = offs_i < q_seq_len
|
||||
offs_j = j + tl.arange(0, BLOCK_K)
|
||||
|
||||
# Pointers
|
||||
# Move the pointers
|
||||
p_qT = tl.advance(p_qT, (0, i))
|
||||
p_do = tl.advance(p_do, (i, 0))
|
||||
p_lse = tl.advance(p_lse, (i,))
|
||||
p_pdp = tl.advance(p_pdp, (i,))
|
||||
|
||||
# Loop
|
||||
# Iterate over $Q$
|
||||
for _ in range(steps):
|
||||
# Load $$qT$$
|
||||
# Load $Q_i^T$
|
||||
b_qT = tl.load(p_qT, boundary_check=(1,), padding_option="zero")
|
||||
|
||||
# $M_i = log_2 L_i$
|
||||
b_m = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
|
||||
# $\log_2 L_i$
|
||||
b_l = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
|
||||
|
||||
# $$P_{ij} = \frac{e^{q_i^T k_j}}{L_i} = e^{q_i^T k_j - M_i}$$
|
||||
# Not that k is already multiplied by softmax scale.
|
||||
# It is also divided by $log_e 2$ so we can use $2^x$ instead of $e^x$
|
||||
b_qkT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)
|
||||
b_pT = tl.math.exp2(b_qkT - b_m[None, :])
|
||||
# $(\log_2 e) S_{ij}^T = \sigma (\log_2 e) K_j Q_i^T$
|
||||
b_sT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)
|
||||
|
||||
# \begin{align}
|
||||
# P_{ij} &= \frac{e^{S_{ij}}}{L_i}
|
||||
# \\
|
||||
# &= \frac{2^{(\log_2 e) S_{ij}}}{2^{\log_2 L_i}}
|
||||
# \\
|
||||
# &= 2^{(\log_2 e) S_{ij} - \log_2 L_i}
|
||||
# \end{align}
|
||||
b_pT = tl.math.exp2(b_sT - b_l[None, :])
|
||||
|
||||
# Autoregressive masking.
|
||||
if MASK:
|
||||
mask = (offs_i[None, :] >= offs_j[:, None])
|
||||
b_pT = tl.where(mask, b_pT, 0.0)
|
||||
|
||||
# Mask out if the block is beyond the end of $Q_i$
|
||||
b_pT = tl.where(i_mask[None, :], b_pT, 0.0)
|
||||
|
||||
# $$dv_j = \sum_i P_{ij} do_i$$
|
||||
# $dV_j = \sum_i P_{ij} dO_i$
|
||||
b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
|
||||
b_dv += tl.dot(b_pT.to(b_do.dtype), b_do, out_dtype=HI_PRES_TL)
|
||||
|
||||
# $$dk_j = \sum_i dS_{ij} q_i = \sum_i P_{ij} \big( dP^T_{i:} - D_i \big) q_i$$
|
||||
# $D_i$
|
||||
b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
|
||||
# $dP_{ij} = do^T_i v_j$
|
||||
# $dP_{ij} = V_j dO_i^T$
|
||||
b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)
|
||||
# $dS_{ij} = P_{ij} \big( dP_{i:} - D_i \big)$
|
||||
# $dS_{ij} = P_{ij} \big( dP_{ij} - D_i \big)$
|
||||
b_dsT = b_pT * (b_dpT - b_pdp[None, :])
|
||||
# $dk_j = \sum_i dS_{ij} q_i$
|
||||
# $\frac{1}{\sigma} dk_j = \sum_i dS_{ij} Q_i$
|
||||
b_dk += tl.dot(b_dsT.to(b_qT.dtype), tl.trans(b_qT), out_dtype=HI_PRES_TL)
|
||||
|
||||
# Increment pointers.
|
||||
@ -717,7 +738,7 @@ def _attn_bwd_dkdv_inner(b_dk, b_dv,
|
||||
p_qT = tl.advance(p_qT, (0, BLOCK_Q))
|
||||
p_do = tl.advance(p_do, (BLOCK_Q, 0))
|
||||
|
||||
# Return accumulated $dk$ and $dv$
|
||||
# Return accumulated $dK$ and $dV$
|
||||
return b_dk, b_dv
|
||||
|
||||
|
||||
@ -784,14 +805,15 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
|
||||
(BLOCK_Q,),
|
||||
(0,))
|
||||
|
||||
# Load $Q_i$, $dO_i$, $D_i$, and $\log_2 L_i$ outside the loop
|
||||
b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
|
||||
b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
|
||||
b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
|
||||
|
||||
b_dq = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
|
||||
|
||||
b_lse = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
|
||||
|
||||
# Initialize $(\log_2 e)dQ$
|
||||
b_dq = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
|
||||
|
||||
# $$dq_i = \sum_j dS_{ij} k_j = \sum_j P_{ij} \big( dP_{ij} - D_i \big) k_j$$
|
||||
|
||||
if is_causal:
|
||||
@ -806,7 +828,7 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
|
||||
kv_seq_len=kv_seq_len
|
||||
)
|
||||
|
||||
# Other blocks
|
||||
# Compute for other blocks
|
||||
b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
|
||||
b_do, b_lse, b_pdp,
|
||||
BLOCK_Q, BLOCK_K,
|
||||
@ -817,6 +839,7 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
|
||||
kv_seq_len=kv_seq_len
|
||||
)
|
||||
else:
|
||||
# Iterate through all $K$
|
||||
b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
|
||||
b_do, b_lse, b_pdp,
|
||||
BLOCK_Q, BLOCK_K,
|
||||
@ -827,11 +850,10 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
|
||||
kv_seq_len=kv_seq_len
|
||||
)
|
||||
|
||||
# Since $k$ was scaled by $\frac{1}{log_e 2}$, and $dq_j = \sum_j dS_{ij} k_j$
|
||||
# got this factor in to computed $dq$ we need to reverse it.
|
||||
# `b_dq` stores $(\log_2 e)dQ$ so multiply by $\log_e 2$ to get $dQ$
|
||||
b_dq *= LN2
|
||||
|
||||
# Save $dq$
|
||||
# Save $dQ$
|
||||
tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user