Commit: sm scale log2
Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-08-14 17:41:37 +08:00)
@@ -362,7 +362,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
     # Precalculate $\frac{\sigma}{\log 2}$.
     #
     # We will use this when calculating $S_{ij}$, so `S` will store $\frac{S_{ij}}{\log 2}$ instead.
-    sm_scale = sm_scale * 1.44269504
+    sm_scale_log2 = sm_scale * 1.44269504

     # Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\infty$ and $l_i$ to $1$. So in the first update,
     # the effect of initial $l_i$ is $e^{m_i - m_{i}^{\text{new}}} l_i = 0$.
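The constant 1.44269504 is $\log_2 e = \frac{1}{\log 2}$, so the rename makes it explicit that the kernel now carries a base-2 softmax scale. Below is a minimal NumPy sketch (not the Triton kernel; the value of `sm_scale` is an arbitrary example) of the identity $e^{x} = 2^{x \log_2 e}$ that this precalculation relies on:

import numpy as np

LOG2_E = 1.44269504                    # log2(e) = 1 / ln 2, as in the diff above
sm_scale = 0.125                       # hypothetical value, e.g. 1 / sqrt(d_head)
sm_scale_log2 = sm_scale * LOG2_E      # the value precalculated in the diff above

s = np.random.randn(4, 4).astype(np.float32)   # stand-in for Q_i K_j^T scores
p_base_e = np.exp(s * sm_scale)                # softmax numerator with exp
p_base_2 = np.exp2(s * sm_scale_log2)          # same numbers via the cheaper exp2
assert np.allclose(p_base_e, p_base_2, atol=1e-6)

Folding the factor into the scale once, before the inner loops, means every per-block exponentiation can use the hardware-friendly base-2 exponent instead of base e.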
@@ -381,7 +381,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
         # Inner loop up to the diagonal block
         b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q,
                                         p_kT, p_v,
-                                        sm_scale,
+                                        sm_scale_log2,
                                         BLOCK_Q, d_head, BLOCK_K,
                                         offs_i, offs_j,
                                         j=tl.full([], 0, tl.int32),  # type: ignore
@@ -392,7 +392,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
                                         )
         # Diagonal block with masking within it
         b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
-                                        sm_scale,
+                                        sm_scale_log2,
                                         BLOCK_Q, d_head, BLOCK_K,
                                         offs_i, offs_j,
                                         j=i * BLOCK_Q,
@@ -404,7 +404,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
     else:
         # Iterate through all $K_j$
         b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
-                                        sm_scale,
+                                        sm_scale_log2,
                                         BLOCK_Q, d_head, BLOCK_K,
                                         offs_i, offs_j,
                                         j=tl.full([], 0, tl.int32),  # type: ignore
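For context on the three call sites above: the causal branch makes one unmasked pass over the key blocks strictly below the diagonal and one masked pass over the diagonal block, while the non-causal branch walks every key block unmasked. The following plain NumPy sketch (illustrative only; the block size, sequence length, and variable names here are made up and not taken from the kernel) checks that split, i.e. that blocks below the diagonal need no mask and only the diagonal block must be masked element-wise:

import numpy as np

BLOCK = 4
n = 8                                            # two blocks of 4 queries/keys
q = np.random.randn(n, 16).astype(np.float32)
k = np.random.randn(n, 16).astype(np.float32)

i = 1                                            # second query block
qi = q[i * BLOCK:(i + 1) * BLOCK]

# Reference: causal scores for this query block against the whole sequence.
s_full = qi @ k.T
rows = np.arange(i * BLOCK, (i + 1) * BLOCK)[:, None]
cols = np.arange(n)[None, :]
s_ref = np.where(cols <= rows, s_full, -np.inf)

# Block view: key blocks before the diagonal block are fully visible ...
s_below = qi @ k[:i * BLOCK].T
# ... and only the diagonal block needs an element-wise causal mask.
s_diag = qi @ k[i * BLOCK:(i + 1) * BLOCK].T
diag_mask = np.arange(BLOCK)[None, :] <= np.arange(BLOCK)[:, None]
s_diag = np.where(diag_mask, s_diag, -np.inf)

assert np.allclose(np.hstack([s_below, s_diag]), s_ref)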
@@ -423,7 +423,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
 @triton.jit
 def _attn_fwd_inner(b_o, b_l, b_m, b_q,
                     p_kT, p_v,
-                    scale,
+                    sm_scale_log2,
                     BLOCK_Q: tl.constexpr,
                     d_head: tl.constexpr,
                     BLOCK_K: tl.constexpr,
@@ -446,7 +446,7 @@ def _attn_fwd_inner(b_o, b_l, b_m, b_q,
     b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
     # Compute $\frac{S_{ij}}{\log 2} = \frac{\sigma}{\log 2} Q_i K_j^T$
     b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
-    b_s = b_s * scale
+    b_s = b_s * sm_scale_log2

     # Apply causal mask
     if MASK:
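Since `b_s` now holds $\frac{S_{ij}}{\log 2}$, the running maximum and exponentials downstream can stay in base 2. The small NumPy sketch below (one query row, two hand-rolled "blocks"; an assumption-laden illustration, not the kernel's actual update code) shows a base-2 online softmax matching the ordinary base-e result, with the initial $m_i = -\infty$, $l_i = 1$ contributing $2^{-\infty} \cdot 1 = 0$ to the first update, as the comment in the first hunk notes:

import numpy as np

LOG2_E = 1.44269504                              # log2(e), as in the diff above
sm_scale = 0.125                                 # hypothetical softmax scale
s = np.random.randn(6).astype(np.float32)        # one row of Q_i K_j^T scores
v = np.random.randn(6, 3).astype(np.float32)     # matching value vectors

# Reference: ordinary base-e softmax-weighted values.
p_ref = np.exp(s * sm_scale - np.max(s * sm_scale))
o_ref = (p_ref / p_ref.sum()) @ v

# Online version in base 2, processing two halves as if they were two K blocks.
m, l, o = -np.inf, 1.0, np.zeros(3, dtype=np.float32)
for s_blk, v_blk in ((s[:3], v[:3]), (s[3:], v[3:])):
    s2 = s_blk * sm_scale * LOG2_E               # plays the role of b_s * sm_scale_log2
    m_new = max(m, s2.max())                     # running maximum in base-2 units
    alpha = np.exp2(m - m_new)                   # rescales old l and o; 0 on the first step
    p = np.exp2(s2 - m_new)                      # unnormalised base-2 softmax weights
    l = l * alpha + p.sum()
    o = o * alpha + p @ v_blk
    m = m_new
assert np.allclose(o / l, o_ref, atol=1e-5)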