Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-08-26 16:50:39 +08:00)
all comments
@@ -315,11 +315,11 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
     Strides `z`, `h`, `m` and `d` denote the stride of the corresponding dimensions
     (`batch_size`, `n_heads`, `seq_len`, `d_head`) in the query.
     Stride `n` denotes the stride on `seq_len` of key.
     """

     i = tl.program_id(0)
     z = tl.program_id(1) // n_groups
-    g = tl.program_id(1) % n_groups
+    g = tl.program_id(1) % n_groups  # TODO

     # Create block pointers
     p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
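The two program-id axes above split the work so that each kernel instance handles one BLOCK_Q-sized block of query rows for one (batch, query-head-group) pair. A plain-Python sketch of that index decomposition, assuming a launch grid of (cdiv(q_seq_len, BLOCK_Q), batch_size * n_groups), which the pointer arithmetic implies but this hunk does not show:

def decompose_program_ids(pid0, pid1, n_groups, BLOCK_Q):
    # pid0 indexes the block of query rows, pid1 packs (batch, group).
    i = pid0                    # query-row block index
    z = pid1 // n_groups        # batch index
    g = pid1 % n_groups         # query-head group within the batch
    q_row_start = i * BLOCK_Q   # first query row handled by this instance
    return z, g, q_row_start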
@@ -359,15 +359,15 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
     # Mask for $Q$ for the last block
     i_mask = offs_i < q_seq_len

-    # Precalculate $\frac{\sigma}{\log 2}$.
+    # Precalculate $\sigma \log_2 e$.
     #
-    # We will use this when calculating $S_{ij}$ so `S` will store $S_{ij} \log 2$ instead.
-    sm_scale_log2 = sm_scale * 1.44269504
+    # We will use this when calculating $S_{ij}$ so `S` will store $S_{ij} \log_2 e$ instead.
+    sm_scale_log2e = sm_scale * 1.44269504

     # Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\infty$ and $l_i$ to $1$. So in the first update,
     # the effect of initial $l_i$ is $e^{m_i - m_{i}^{\text{new}}} l_i = 0$.
     #
-    # `b_m` will be storing $m_i \log 2$
+    # `b_m` will be storing $m_i \log_2 e$
     b_m = tl.where(i_mask, -float("inf"), 0.0)
     b_l = tl.where(i_mask, 1.0, 0.0)

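The constant 1.44269504 is log2(e). Folding it into the softmax scale lets the kernel evaluate natural exponentials with the cheaper base-2 exponential, because e^x = 2^(x * log2(e)). A minimal NumPy check of that identity (illustrative only, not part of the kernel):

import numpy as np

LOG2_E = 1.44269504                        # log2(e), the factor folded into sm_scale above
s = np.random.randn(4, 8)
sm_scale = 0.125

direct = np.exp(sm_scale * s)              # e^{sigma * S}
via_exp2 = np.exp2(sm_scale * LOG2_E * s)  # 2^{sigma * log2(e) * S}
assert np.allclose(direct, via_exp2)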
@@ -380,39 +380,39 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
     if is_causal:
         # Inner loop up to the diagonal block
         b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q,
-                                        p_kT, p_v,
-                                        sm_scale_log2,
-                                        BLOCK_Q, d_head, BLOCK_K,
-                                        offs_i, offs_j,
-                                        j=tl.full([], 0, tl.int32),  # type: ignore
-                                        steps=(i * BLOCK_Q) // BLOCK_K,
-                                        MASK=False,
-                                        q_seq_len=q_seq_len,
-                                        kv_seq_len=kv_seq_len
-                                        )
+                                        p_kT, p_v,
+                                        sm_scale_log2e,
+                                        BLOCK_Q, d_head, BLOCK_K,
+                                        offs_i, offs_j,
+                                        j=tl.full([], 0, tl.int32),  # type: ignore
+                                        steps=(i * BLOCK_Q) // BLOCK_K,
+                                        MASK=False,
+                                        q_seq_len=q_seq_len,
+                                        kv_seq_len=kv_seq_len
+                                        )
         # Diagonal block with masking within it
         b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
-                                        sm_scale_log2,
-                                        BLOCK_Q, d_head, BLOCK_K,
-                                        offs_i, offs_j,
-                                        j=i * BLOCK_Q,
-                                        steps=BLOCK_Q // BLOCK_K,
-                                        MASK=True,
-                                        q_seq_len=q_seq_len,
-                                        kv_seq_len=kv_seq_len
-                                        )
+                                        sm_scale_log2e,
+                                        BLOCK_Q, d_head, BLOCK_K,
+                                        offs_i, offs_j,
+                                        j=i * BLOCK_Q,
+                                        steps=BLOCK_Q // BLOCK_K,
+                                        MASK=True,
+                                        q_seq_len=q_seq_len,
+                                        kv_seq_len=kv_seq_len
+                                        )
     else:
         # Iterate through all $K_j$
         b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
-                                        sm_scale_log2,
-                                        BLOCK_Q, d_head, BLOCK_K,
-                                        offs_i, offs_j,
-                                        j=tl.full([], 0, tl.int32),  # type: ignore
-                                        steps=tl.cdiv(kv_seq_len, BLOCK_K),
-                                        MASK=False,
-                                        q_seq_len=q_seq_len,
-                                        kv_seq_len=kv_seq_len
-                                        )
+                                        sm_scale_log2e,
+                                        BLOCK_Q, d_head, BLOCK_K,
+                                        offs_i, offs_j,
+                                        j=tl.full([], 0, tl.int32),  # type: ignore
+                                        steps=tl.cdiv(kv_seq_len, BLOCK_K),
+                                        MASK=False,
+                                        q_seq_len=q_seq_len,
+                                        kv_seq_len=kv_seq_len
+                                        )

     # Store LSE $\log_2 L_i = \log_2 \big( l_i e^{m_i} \big) = \log_2 l_i + m_i \log_2 e$
     tl.store(p_lse, b_m + tl.math.log2(b_l), boundary_check=(0,))
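Because `b_m` holds m_i * log2(e) and `b_l` holds the sum of the rescaled exponentials, the stored quantity b_m + log2(b_l) is exactly the base-2 log-sum-exp of the (scaled) score row. A hedged NumPy sketch of that identity for a single row (names are illustrative):

import numpy as np

LOG2_E = 1.44269504
s = np.random.randn(16)                    # one row of scores, already multiplied by sm_scale

b_m = np.max(s) * LOG2_E                   # running max in the log2(e)-scaled domain
b_l = np.sum(np.exp2(s * LOG2_E - b_m))    # sum of rescaled exponentials

lse_base2 = b_m + np.log2(b_l)             # what the kernel stores via tl.store(p_lse, ...)
assert np.allclose(lse_base2, np.log2(np.sum(np.exp(s))))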
@@ -423,7 +423,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
 @triton.jit
 def _attn_fwd_inner(b_o, b_l, b_m, b_q,
                     p_kT, p_v,
-                    sm_scale_log2,
+                    sm_scale_log2e,
                     BLOCK_Q: tl.constexpr,
                     d_head: tl.constexpr,
                     BLOCK_K: tl.constexpr,
@@ -446,7 +446,7 @@ def _attn_fwd_inner(b_o, b_l, b_m, b_q,
         b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
         # Compute $(\log_2 e) S_{ij} = (\log_2 e) \sigma Q_i K_j^T$
         b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
-        b_s = b_s * sm_scale_log2
+        b_s = b_s * sm_scale_log2e

         # Apply causal mask
         if MASK:
@@ -457,9 +457,13 @@ def _attn_fwd_inner(b_o, b_l, b_m, b_q,
         j_mask = (j + offs_j) < kv_seq_len
         b_s = tl.where(j_mask[None, :], b_s, -float("inf"))

-        # $m_{i}^{\text{new}} = \max(m_i, \text{rowmax}(S_{ij}))$
+        # $(\log_2 e) m_{i}^{\text{new}} = \max((\log_2 e) m_i, \max_{j=j1}^{j2} (\log_2 e) S_{ij})$
         b_m_new = tl.maximum(b_m, tl.max(b_s, -1))
-        # $\tilde{P}_{ij} = \exp(S_{ij} - m_i^{\text{new}})$
+        # \begin{align}
+        # \tilde{P}_{ij} &= e^{S_{ij} - m_i^{\text{new}}}
+        # \\
+        # &= 2^{(\log_2 e) S_{ij} - (\log_2 e) m_i^{\text{new}}}
+        # \end{align}
         b_p = tl.math.exp2(b_s - b_m_new[:, None])

         # $\sum_{j=j1}^{j2} \tilde{P}_{ij}$
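The block above is one step of the standard online-softmax recurrence, carried out in the base-2 domain: the running maximum and normalizer are updated, previous partial results are rescaled, and the new block's contribution is added. A minimal NumPy sketch of one such step (the `alpha` factor corresponds to the $e^{m_i - m_i^{\text{new}}}$ rescaling the kernel applies to `b_o`; all names are illustrative):

import numpy as np

def online_softmax_step(b_m, b_l, b_o, b_s, b_v):
    """One K-block update of the forward inner loop, everything scaled by log2(e).

    b_m: (M,)   running row max times log2(e)         b_l: (M,) running normalizer
    b_o: (M, D) unnormalized output accumulator
    b_s: (M, N) scores for this block, times log2(e)  b_v: (N, D) value block
    """
    b_m_new = np.maximum(b_m, b_s.max(axis=-1))
    b_p = np.exp2(b_s - b_m_new[:, None])      # ~P_ij
    alpha = np.exp2(b_m - b_m_new)             # = e^{m_i - m_i_new}
    b_l = b_l * alpha + b_p.sum(axis=-1)       # l_i <- rescaled l_i + row sums of ~P
    b_o = b_o * alpha[:, None] + b_p @ b_v     # O_i <- rescaled O_i + ~P V_j
    return b_m_new, b_l, b_o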
@@ -471,11 +475,11 @@ def _attn_fwd_inner(b_o, b_l, b_m, b_q,

         # $O_i \leftarrow e^{m_i - m_{i}^{\text{new}}} O_i + \tilde{P}_{ij} V_j$
         b_o = b_o * b_m_m_new[:, None]
-        b_p = b_p.to(b_q.dtype)  # TODO
+        b_p = b_p.to(b_q.dtype)  # TODO
         b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
         b_o += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)

-        # $m_i \leftarrow m_{i}^{\text{new}}$
+        # $(\log_2 e) m_i \leftarrow (\log_2 e) m_{i}^{\text{new}}$
         b_m = b_m_new

         # Move pointers
@@ -674,14 +678,13 @@ def _attn_bwd_dkdv_inner(b_dk, b_dv,
                          MASK: tl.constexpr,
                          q_seq_len: tl.constexpr,
                          kv_seq_len: tl.constexpr):
-    """Inner loop along m query"""
+    """Inner loop along query"""

     # To apply the mask
     tl.static_assert(BLOCK_K % BLOCK_Q == 0)

-    # Offsets for mask computation
+    # Offsets and mask
     offs_i = i + tl.arange(0, BLOCK_Q)
+    i_mask = offs_i < q_seq_len
     offs_j = j + tl.arange(0, BLOCK_K)

     # Move the pointers
@@ -710,12 +713,17 @@ def _attn_bwd_dkdv_inner(b_dk, b_dv,
         # \end{align}
         b_pT = tl.math.exp2(b_sT - b_l[None, :])

-        # Autoregressive masking.
+        # Autoregressive masking
         if MASK:
             mask = (offs_i[None, :] >= offs_j[:, None])
             b_pT = tl.where(mask, b_pT, 0.0)

+        # Mask out if the block is beyond the end of $Q_i$
+        #
+        # Note: No need to mask out based on $j$,
+        # because the effects on positions outside the boundary will not get stored in $dK$ or $dV$.
+        # Masking by $i$ may also not be necessary since the tensors have 0 on loading.
         i_mask = offs_i < q_seq_len
         b_pT = tl.where(i_mask[None, :], b_pT, 0.0)

         # $dV_j = \sum_i P_{ij} dO_i$
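In the backward pass nothing is re-normalized on the fly; the probabilities are reconstructed exactly from the saved base-2 LSE, $P_{ij} = 2^{(\log_2 e) S_{ij} - \log_2 L_i}$, and then used to accumulate dV. A hedged NumPy sketch of that per-block dV update (illustrative names, not the kernel API):

import numpy as np

def dv_block(b_dv, b_sT, b_l, b_do, causal_mask=None):
    """Recompute P^T from the saved LSE and accumulate dV for one block.

    b_sT: (N, M) transposed scores, already scaled by sigma * log2(e)
    b_l:  (M,)   saved base-2 log-sum-exp per query row
    b_do: (M, D) output-gradient block
    """
    b_pT = np.exp2(b_sT - b_l[None, :])          # P^T_ji = 2^{(log2 e) S_ij - log2 L_i}
    if causal_mask is not None:                  # optional autoregressive mask (True where j <= i)
        b_pT = np.where(causal_mask, b_pT, 0.0)
    return b_dv + b_pT @ b_do                    # dV_j += sum_i P_ij dO_i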
@@ -728,7 +736,7 @@ def _attn_bwd_dkdv_inner(b_dk, b_dv,
         b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)
         # $dS_{ij} = P_{ij} \big( dP_{ij} - D_i \big)$
         b_dsT = b_pT * (b_dpT - b_pdp[None, :])
-        # $\frac{1}{\sigma} dk_j = \sum_i dS_{ij} Q_i$
+        # $\frac{1}{\sigma} dK_j = \sum_i dS_{ij} Q_i$
         b_dk += tl.dot(b_dsT.to(b_qT.dtype), tl.trans(b_qT), out_dtype=HI_PRES_TL)

         # Increment pointers.
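The dS expression used here is the softmax Jacobian applied to dP, combined with the observation that the row-wise correction term can be precomputed from dO and O (this is the quantity `b_pdp`, the D_i of the FlashAttention backward pass). A short derivation in the same notation, added here for clarity:

\begin{align}
dS_{ij} &= \sum_k \frac{\partial P_{ik}}{\partial S_{ij}} \, dP_{ik}
         = \sum_k P_{ik} (\delta_{jk} - P_{ij}) \, dP_{ik}
         = P_{ij} \Big( dP_{ij} - \sum_k P_{ik} \, dP_{ik} \Big) \\
D_i &= \sum_k P_{ik} \, dP_{ik}
     = \sum_k P_{ik} \, dO_i \cdot V_k
     = dO_i \cdot \sum_k P_{ik} V_k
     = dO_i \cdot O_i
\end{align}

So $dS_{ij} = P_{ij}(dP_{ij} - D_i)$, which is exactly what `b_dsT = b_pT * (b_dpT - b_pdp[None, :])` computes, transposed.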
@@ -759,7 +767,7 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,

     i = tl.program_id(0) * BLOCK_Q
     z = tl.program_id(1) // n_groups
-    g = tl.program_id(1) % n_groups
+    g = tl.program_id(1) % n_groups  # TODO

     # Create block pointers
     p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
@@ -865,49 +873,59 @@ def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
                        MASK: tl.constexpr,
                        q_seq_len: tl.constexpr,
                        kv_seq_len: tl.constexpr):
-    """Inner loop over n key"""
-    offs_i = i + tl.arange(0, BLOCK_Q)
-    offs_j = tl.arange(0, BLOCK_K)
+    """Inner loop over key"""
+
+    # Offsets
+    offs_i = i + tl.arange(0, BLOCK_Q)
+    offs_j = j + tl.arange(0, BLOCK_K)

+    # Move the pointers
     p_kT = tl.advance(p_kT, (0, j))
     p_vT = tl.advance(p_vT, (0, j))

     tl.static_assert(BLOCK_Q % BLOCK_K == 0, 'BLOCK_Q must be divisible by BLOCK_K')

+    # Iterate over $K$
     for _ in range(steps):
-        current_j = j + offs_j
-        j_mask = current_j < kv_seq_len
-
-        # $$P_{ij} = \frac{e^{q_i^T k_j}}{L_i} = e^{q_i^T k_j - M_i}$$
-        # Not that k is already multiplied by softmax scale.
-        # It is also divided by $log_e 2$ so we can use $2^x$ instead of $e^x$
+        # Load $K_j^T$
         b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
+        # Load $V_j^T$
         b_vT = tl.load(p_vT, boundary_check=(1,), padding_option="zero")
-        b_qk = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
-        b_p = tl.math.exp2(b_qk - b_lse[:, None])
-
-        # Autoregressive masking.
+        # $(\log_2 e) S_{ij} = \sigma (\log_2 e) Q_i K_j^T$
+        b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
+
+        # \begin{align}
+        # P_{ij} &= \frac{e^{S_{ij}}}{L_i}
+        # \\
+        # &= \frac{2^{(\log_2 e) S_{ij}}}{2^{\log_2 L_i}}
+        # \\
+        # &= 2^{(\log_2 e) S_{ij} - \log_2 L_i}
+        # \end{align}
+        b_p = tl.math.exp2(b_s - b_lse[:, None])
+
+        # Autoregressive masking
         if MASK:
-            causal_mask = (offs_i[:, None] >= current_j[None, :])
+            causal_mask = (offs_i[:, None] >= offs_j[None, :])
             b_p = tl.where(causal_mask, b_p, 0.0)

+        # Mask out if the block is beyond the end of $Q_i$
+        j_mask = offs_j < kv_seq_len
         b_p = tl.where(j_mask[None, :], b_p, 0.0)

         # $$dq_i = \sum_j dS_{ij} k_j = \sum_j P_{ij} \big( dP_{ij} - D_i \big) k_j$$

-        # $dP_{ij} = do^T_i v_j$
+        # $dP_{ij} = dO_i V_j^T$
         b_dp = tl.dot(b_do, b_vT, out_dtype=HI_PRES_TL).to(HI_PRES_TL)
-        # $dS_{ij} = P_{ij} \big( dP_{i:} - D_i \big)$
+        # $dS_{ij} = P_{ij} \big( dP_{ij} - D_i \big)$
         b_ds = b_p * (b_dp - b_pdp[:, None])
-        # $dq_j = \sum_j dS_{ij} k_j$
-        b_dq += tl.dot(b_ds.to(b_kT.dtype),
-                       tl.trans(b_kT),
-                       out_dtype=HI_PRES_TL)
+        # $(\log_2 e) dQ_i = \sum_j dS_{ij} \sigma (\log_2 e) K_j$
+        b_dq += tl.dot(b_ds.to(b_kT.dtype), tl.trans(b_kT), out_dtype=HI_PRES_TL)

         # Increment pointers.
-        j += BLOCK_K
+        offs_j += BLOCK_K
         p_kT = tl.advance(p_kT, (0, BLOCK_K))
         p_vT = tl.advance(p_vT, (0, BLOCK_K))

-    # Return accumulated $dq$
+    # Return accumulated $dQ$
     return b_dq
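Putting the dQ inner loop together: each block recomputes P from the saved LSE, forms dP and dS, and accumulates into dQ. A minimal NumPy sketch of one block of that loop, assuming (as the replaced comment stated) that the keys were pre-scaled by sigma * log2(e) before the loop; all names are illustrative:

import numpy as np

def dq_block(b_dq, b_q, b_kT, b_vT, b_do, b_lse, b_pdp, causal_mask=None):
    """One K-block of the dQ inner loop.

    b_q:   (M, D) query block          b_kT:  (D, N) keys, pre-scaled by sigma * log2(e)
    b_vT:  (D, N) transposed values    b_do:  (M, D) output-gradient block
    b_lse: (M,)   saved base-2 LSE     b_pdp: (M,)   D_i = sum_d dO_id O_id
    """
    b_s = b_q @ b_kT                             # (log2 e) S_ij
    b_p = np.exp2(b_s - b_lse[:, None])          # P_ij recomputed from the LSE
    if causal_mask is not None:
        b_p = np.where(causal_mask, b_p, 0.0)    # autoregressive mask
    b_dp = b_do @ b_vT                           # dP_ij = dO_i . V_j
    b_ds = b_p * (b_dp - b_pdp[:, None])         # dS_ij = P_ij (dP_ij - D_i)
    return b_dq + b_ds @ b_kT.T                  # accumulate sum_j dS_ij K_j (still carrying the scale)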