sm scale log2
@@ -362,7 +362,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
     # Precalculate $\frac{\sigma}{\log 2}$.
     #
     # We will use this when calculating $S_{ij}$, so `S` will store $\frac{S_{ij}}{\log 2}$ instead.
-    sm_scale = sm_scale * 1.44269504
+    sm_scale_log2 = sm_scale * 1.44269504

     # Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\infty$ and $l_i$ to $1$. So in the first update,
     # the effect of the initial $l_i$ is $e^{m_i - m_{i}^{\text{new}}} l_i = 0$.
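The rename makes explicit that this scale already carries the change of base: multiplying by $1.44269504 \approx \log_2 e = \frac{1}{\log 2}$ turns $e^x$ into $2^x$, so the softmax can be evaluated with the faster `exp2`. A minimal NumPy sketch of that equivalence (illustrative shapes and an assumed $\sigma = 1/\sqrt{d_\text{head}}$, not the Triton kernel itself):

import numpy as np

rng = np.random.default_rng(0)
d_head = 8
q = rng.standard_normal((4, d_head))
kT = rng.standard_normal((d_head, 4))

sm_scale = 1.0 / np.sqrt(d_head)         # assumed sigma = 1 / sqrt(d_head)
sm_scale_log2 = sm_scale * 1.44269504    # sigma / log 2, i.e. sigma * log2(e)

s = (q @ kT) * sm_scale                  # S_ij
s_log2 = (q @ kT) * sm_scale_log2        # S_ij / log 2

# Softmax evaluated with e^x ...
p_e = np.exp(s - s.max(-1, keepdims=True))
p_e /= p_e.sum(-1, keepdims=True)

# ... equals softmax evaluated entirely with 2^x on the rescaled scores.
p_2 = np.exp2(s_log2 - s_log2.max(-1, keepdims=True))
p_2 /= p_2.sum(-1, keepdims=True)

assert np.allclose(p_e, p_2)

Keeping the rescaled value under its own name (`sm_scale_log2`) also avoids reusing `sm_scale` for a quantity in different units, which is what the remaining hunks propagate through the call sites.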
@@ -381,7 +381,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
         # Inner loop up to the diagonal block
         b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q,
                                         p_kT, p_v,
-                                        sm_scale,
+                                        sm_scale_log2,
                                         BLOCK_Q, d_head, BLOCK_K,
                                         offs_i, offs_j,
                                         j=tl.full([], 0, tl.int32),  # type: ignore
@@ -392,7 +392,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
                                         )
         # Diagonal block with masking within it
         b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
-                                        sm_scale,
+                                        sm_scale_log2,
                                         BLOCK_Q, d_head, BLOCK_K,
                                         offs_i, offs_j,
                                         j=i * BLOCK_Q,
@@ -404,7 +404,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
     else:
         # Iterate through all $K_j$
         b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
-                                        sm_scale,
+                                        sm_scale_log2,
                                         BLOCK_Q, d_head, BLOCK_K,
                                         offs_i, offs_j,
                                         j=tl.full([], 0, tl.int32),  # type: ignore
@@ -423,7 +423,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
 @triton.jit
 def _attn_fwd_inner(b_o, b_l, b_m, b_q,
                     p_kT, p_v,
-                    scale,
+                    sm_scale_log2,
                     BLOCK_Q: tl.constexpr,
                     d_head: tl.constexpr,
                     BLOCK_K: tl.constexpr,
@@ -446,7 +446,7 @@ def _attn_fwd_inner(b_o, b_l, b_m, b_q,
         b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
         # Compute $\frac{S_{ij}}{\log 2} = \frac{\sigma}{\log 2} Q_i K_j^T$
         b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
-        b_s = b_s * scale
+        b_s = b_s * sm_scale_log2

         # Apply causal mask
         if MASK:
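For reference, a rough NumPy sketch of the running statistics that the inner loop maintains over these base-2 scores. The names `b_m`, `b_l`, `b_o` mirror the kernel's accumulators, but the loop itself, the shapes, and the block size are illustrative assumptions, and the causal `MASK` branch is omitted. Because the scores are stored as $\frac{S_{ij}}{\log 2}$, the rescale factor is written here as $2^{m_i - m_i^{\text{new}}}$; with the $m_i = -\infty$, $l_i = 1$ initialization from the first hunk it is exactly zero on the first block, so the initial accumulators drop out.

import numpy as np

rng = np.random.default_rng(0)
d_head, BLOCK_K, n_keys = 8, 4, 12
q = rng.standard_normal((1, d_head))
k = rng.standard_normal((n_keys, d_head))
v = rng.standard_normal((n_keys, d_head))
sm_scale_log2 = (1.0 / np.sqrt(d_head)) * 1.44269504  # assumed sigma = 1 / sqrt(d_head)

b_m = -np.inf                # running max of S / log 2
b_l = 1.0                    # running sum of 2^(S / log 2 - m)
b_o = np.zeros((1, d_head))  # unnormalized output accumulator

for j in range(0, n_keys, BLOCK_K):
    b_s = (q @ k[j:j + BLOCK_K].T) * sm_scale_log2  # block of S_ij / log 2
    m_new = max(b_m, b_s.max())
    alpha = np.exp2(b_m - m_new)                    # zero on the first block, since b_m = -inf
    b_p = np.exp2(b_s - m_new)
    b_l = b_l * alpha + b_p.sum()
    b_o = b_o * alpha + b_p @ v[j:j + BLOCK_K]
    b_m = m_new

# The normalized accumulator matches an ordinary softmax-weighted sum over all keys.
s_full = (q @ k.T) * sm_scale_log2
p_full = np.exp2(s_full - s_full.max())
assert np.allclose(b_o / b_l, (p_full / p_full.sum()) @ v)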