sm scale log2

This commit is contained in:
Varuna Jayasiri
2025-08-01 13:26:48 +05:30
parent 5a8182d21b
commit eb5c004fac

View File

@ -362,7 +362,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
# Precalculate $\frac{\sigma}{\log 2}$.
#
# We will use this when calculating $S_{ij}$, so `S` will store $\frac{S_{ij}}{\log 2}$ instead.
sm_scale = sm_scale * 1.44269504
sm_scale_log2 = sm_scale * 1.44269504
# Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\infty$ and $l_i$ to $1$. So in the first update,
# the effect of initial $l_i$ is $e^{m_i - m_{i}^{\text{new}}} l_i = 0$.
@ -381,7 +381,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
# Inner loop up to the diagonal block
b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q,
p_kT, p_v,
sm_scale,
sm_scale_log2,
BLOCK_Q, d_head, BLOCK_K,
offs_i, offs_j,
j=tl.full([], 0, tl.int32), # type: ignore
@ -392,7 +392,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
)
# Diagonal block with masking within it
b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
sm_scale,
sm_scale_log2,
BLOCK_Q, d_head, BLOCK_K,
offs_i, offs_j,
j=i * BLOCK_Q,
@ -404,7 +404,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
else:
# Iterate through all $K_j$
b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
sm_scale,
sm_scale_log2,
BLOCK_Q, d_head, BLOCK_K,
offs_i, offs_j,
j=tl.full([], 0, tl.int32), # type: ignore
@ -423,7 +423,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
@triton.jit
def _attn_fwd_inner(b_o, b_l, b_m, b_q,
p_kT, p_v,
scale,
sm_scale_log2,
BLOCK_Q: tl.constexpr,
d_head: tl.constexpr,
BLOCK_K: tl.constexpr,
@ -446,7 +446,7 @@ def _attn_fwd_inner(b_o, b_l, b_m, b_q,
b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
# Compute $\frac{S_{ij}}{\log 2} = \frac{\sigma}{\log 2} Q_i K_j^T$
b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
b_s = b_s * scale
b_s = b_s * sm_scale_log2
# Apply causal mask
if MASK: