all comments

Varuna Jayasiri
2025-08-01 14:14:00 +05:30
parent 73b9892be6
commit 3dd36b80b3


@@ -151,7 +151,7 @@ class AttentionFunc(torch.autograd.Function):
# The forward computation will be parallelized along the batch dimension and the queries in blocks of size `BLOCK_Q`
grid = lambda meta: (triton.cdiv(q_seq_len, meta["BLOCK_Q"]), batch_size * k_heads * n_groups, 1)
_attn_fwd[grid](
-q, k, v, sm_scale, lse, o,
+q, k, v, sm_scale * 1.4426950408889634, lse, o,
n_groups=n_groups,
q_seq_len=q_seq_len,
kv_seq_len=kv_seq_len,
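
The call site now folds $\log_2 e = 1.4426950408889634$ into the softmax scale, so the kernel can exponentiate in base 2 (exp2) rather than base e, using $e^x = 2^{x \log_2 e}$. A minimal NumPy sketch of that identity with made-up shapes and no tiling, only to show the two scalings give the same attention weights:

import numpy as np

# Toy shapes; the real kernel tiles this over BLOCK_Q x BLOCK_K blocks.
q = np.random.randn(4, 8).astype(np.float32)   # queries
k = np.random.randn(6, 8).astype(np.float32)   # keys
sm_scale = 1.0 / np.sqrt(q.shape[-1])          # softmax scale sigma
RCP_LN2 = 1.4426950408889634                   # log2(e) = 1 / ln(2)

# Reference: logits scaled by sigma, exponentiated in base e
s = (q @ k.T) * sm_scale
p_ref = np.exp(s - s.max(axis=-1, keepdims=True))

# Kernel's view: sigma * log2(e) baked in on the host, exponentiated with 2**x
s2 = (q @ k.T) * (sm_scale * RCP_LN2)
p_exp2 = np.exp2(s2 - s2.max(axis=-1, keepdims=True))

assert np.allclose(p_ref, p_exp2, atol=1e-6)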
@@ -201,10 +201,8 @@ class AttentionFunc(torch.autograd.Function):
dk = torch.empty_like(k)
dv = torch.empty_like(v)
-# $\log_2 e$
-RCP_LN2 = 1.4426950408889634
# Precompute $\sigma (\log_2 e) K_j$
-k_scaled = k * (sm_scale * RCP_LN2)
+k_scaled = k * (sm_scale * 1.4426950408889634)
# $D_i = P^T_{i:}dP_{i:} = do^T_io_i$
pdp = torch.empty_like(lse)
# We use fixed `BLOCK_Q` for backward pass on $D$
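
The backward setup mirrors the forward change: $K$ is pre-scaled by $\sigma \log_2 e$ once on the host rather than multiplying by the constant inside the kernel, and `pdp` holds $D_i = do_i^\top o_i$. A rough PyTorch reference for that precompute, assuming a `(batch, heads, q_seq_len, d_head)` layout (the real code fills `pdp` with a small Triton kernel):

import torch

def attn_bwd_d(o: torch.Tensor, do: torch.Tensor) -> torch.Tensor:
    # D_i = do_i . o_i: one scalar per query position, same layout as `lse`.
    return (o * do).sum(dim=-1)

sm_scale = 0.125
RCP_LN2 = 1.4426950408889634
k = torch.randn(2, 4, 16, 32)
k_scaled = k * (sm_scale * RCP_LN2)   # sigma * log2(e) * K, done once on the host

o = torch.randn(2, 4, 16, 32)
do = torch.randn_like(o)
pdp = attn_bwd_d(o, do)               # shape (2, 4, 16)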
@@ -288,7 +286,7 @@ def _get_autotune_configs(inner_loop: str) -> list:
@triton.autotune(_get_autotune_configs(inner_loop='key'),
key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
@triton.jit
-def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
+def _attn_fwd(t_q, t_k, t_v, sm_scale_log2e, t_lse, t_o,
n_groups: tl.constexpr,
q_seq_len: tl.constexpr,
kv_seq_len: tl.constexpr,
@@ -359,11 +357,6 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
# Mask for $Q$ for the last block
i_mask = offs_i < q_seq_len
-# Precalculate $\sigma \log_2 e$.
-#
-# We will use this when calculating $S_{ij}$, so `S` will store $S_{ij} \log_2 e$ instead.
-sm_scale_log2e = sm_scale * 1.44269504
# Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\infty$ and $l_i$ to $1$. So in the first update,
# the effect of initial $l_i$ is $e^{m_i - m_{i}^{\text{new}}} l_i = 0$.
#
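
The $m_i$, $l_i$ comment describes the online-softmax recurrence: a running row maximum and a running (rescaled) row sum, updated once per key block. A simplified pure-NumPy sketch for one query row (base e here for clarity; the kernel does the same in base 2):

import numpy as np

def online_softmax_rowsum(score_blocks):
    # score_blocks: the S_{ij} values of one query row split into key blocks,
    # a stand-in for the BLOCK_K inner loop.
    m = -np.inf   # running max: the first block always replaces it
    l = 1.0       # running sum: exp(m - m_new) * l == 0 on the first update
    for s in score_blocks:
        m_new = max(m, s.max())
        l = np.exp(m - m_new) * l + np.exp(s - m_new).sum()
        m = m_new
    return m, l

scores = np.random.randn(64)
m, l = online_softmax_rowsum(np.split(scores, 4))
assert np.isclose(m, scores.max())
assert np.isclose(l, np.exp(scores - scores.max()).sum())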
@@ -762,9 +755,6 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
):
-# $\log_e 2$
-LN2: tl.constexpr = 0.6931471824645996 # type: ignore
i = tl.program_id(0) * BLOCK_Q
z = tl.program_id(1) // n_groups
g = tl.program_id(1) % n_groups # TODO
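
The program-id arithmetic above unpacks batch, KV head, and query-head group from the second grid axis. A small sketch of that mapping, assuming the backward-$dQ$ launch grid is laid out like the forward grid shown in the first hunk (an assumption, not taken from this diff):

import math

# grid = (cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads * n_groups, 1)
def program_coords(pid0: int, pid1: int, BLOCK_Q: int, n_groups: int):
    i = pid0 * BLOCK_Q        # first query row handled by this program
    z = pid1 // n_groups      # flattened (batch, KV head) index
    g = pid1 % n_groups       # query-head group within the KV head (GQA)
    return i, z, g

batch_size, k_heads, n_groups, q_seq_len, BLOCK_Q = 2, 4, 2, 128, 32
for pid1 in range(batch_size * k_heads * n_groups):
    for pid0 in range(math.ceil(q_seq_len / BLOCK_Q)):
        i, z, g = program_coords(pid0, pid1, BLOCK_Q, n_groups)
        assert 0 <= i < q_seq_len and 0 <= z < batch_size * k_heads and 0 <= g < n_groups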
@@ -859,7 +849,7 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
)
# `b_dq` stores $(\log_2 e)dQ$ so multiply by $\log_e 2$ to get $dQ$
-b_dq *= LN2
+b_dq *= 0.6931471824645996
# Save $dQ$
tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))
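
Because the scores were computed as $S_{ij}\log_2 e$, the accumulated `b_dq` equals $(\log_2 e)\,dQ$, and multiplying by $\ln 2$ at the end cancels the factor. The literal `0.6931471824645996` is $\ln 2$ rounded to float32, the reciprocal of the `1.4426950408889634` now applied on the host. A quick sanity check in plain NumPy:

import numpy as np

RCP_LN2 = 1.4426950408889634   # log2(e), multiplied into sm_scale and K on the host
LN2_F32 = 0.6931471824645996   # ln(2) rounded to float32, used to unscale dQ

assert np.float32(np.log(2.0)) == np.float32(LN2_F32)
assert np.isclose(RCP_LN2 * LN2_F32, 1.0)

# If b_dq accumulated (log2 e) * dQ, multiplying by ln 2 recovers dQ.
dq_true = np.random.randn(8, 16).astype(np.float32)
b_dq = dq_true * np.float32(RCP_LN2)
assert np.allclose(b_dq * np.float32(LN2_F32), dq_true, atol=1e-6)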