diff --git a/labml_nn/transformers/flash/__init__.py b/labml_nn/transformers/flash/__init__.py
index 1c51a0cc..7f3607b8 100644
--- a/labml_nn/transformers/flash/__init__.py
+++ b/labml_nn/transformers/flash/__init__.py
@@ -362,7 +362,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
     # Precalculate $\frac{\sigma}{\log 2}$.
     #
     # We will be use this when calculating $S_{ij}$ so `S` will store $S_{ij} \log 2$ instead.
-    sm_scale = sm_scale * 1.44269504
+    sm_scale_log2 = sm_scale * 1.44269504
 
     # Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\inf$ and $l_i$ to $1$. So in the first update,
     # the effect of initial $l_i$ is $e^{m_i - m_{i}^{\text{new}}} l_i = 0$.
@@ -381,7 +381,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
 
         # Inner loop upto the diagonal block
         b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
-                                        sm_scale,
+                                        sm_scale_log2,
                                         BLOCK_Q, d_head, BLOCK_K,
                                         offs_i, offs_j,
                                         j=tl.full([], 0, tl.int32),  # type: ignore
@@ -392,7 +392,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
                                         )
         # Diagonal block with masking within it
         b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
-                                        sm_scale,
+                                        sm_scale_log2,
                                         BLOCK_Q, d_head, BLOCK_K,
                                         offs_i, offs_j,
                                         j=i * BLOCK_Q,
@@ -404,7 +404,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
     else:
         # Iterate through all $K_j$
         b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
-                                        sm_scale,
+                                        sm_scale_log2,
                                         BLOCK_Q, d_head, BLOCK_K,
                                         offs_i, offs_j,
                                         j=tl.full([], 0, tl.int32),  # type: ignore
@@ -423,7 +423,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
 @triton.jit
 def _attn_fwd_inner(b_o, b_l, b_m, b_q,
                     p_kT, p_v,
-                    scale,
+                    sm_scale_log2,
                     BLOCK_Q: tl.constexpr,
                     d_head: tl.constexpr,
                     BLOCK_K: tl.constexpr,
@@ -446,7 +446,7 @@ def _attn_fwd_inner(b_o, b_l, b_m, b_q,
         b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
         # Compute $(\log 2) S_ij = (\log 2) \sigma Q_i K_j^T$
         b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
-        b_s = b_s * scale
+        b_s = b_s * sm_scale_log2
         # Apply causal mask
         if MASK:
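For reference, the constant `1.44269504` is $\log_2 e$: multiplying the softmax scale $\sigma$ by it lets the kernel evaluate $e^{\sigma S_{ij}}$ as $2^{\texttt{sm\_scale\_log2} \cdot S_{ij}}$ with `exp2`, which typically maps to a native GPU instruction. A minimal standalone sketch of that identity, outside the kernel (the `d_head` and score values below are made-up examples, not taken from the diff):

    import math

    # Illustrative only: scaling by log2(e) lets exp be replaced with exp2,
    # since e**x == 2**(x * log2(e)).
    sm_scale = 1.0 / math.sqrt(64)         # example softmax scale, assuming d_head = 64
    sm_scale_log2 = sm_scale * 1.44269504  # same constant as in the kernel above

    s = 3.7                                # an arbitrary example score Q_i K_j^T
    assert abs(math.exp(sm_scale * s) - 2.0 ** (sm_scale_log2 * s)) < 1e-6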