mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-26 08:41:23 +08:00
all comments
This commit is contained in:
@ -151,7 +151,7 @@ class AttentionFunc(torch.autograd.Function):
|
||||
# The forward computation will be parallelized along the batch dimension and the queries in blocks of size `BLOCK_Q`
|
||||
grid = lambda meta: (triton.cdiv(q_seq_len, meta["BLOCK_Q"]), batch_size * k_heads * n_groups, 1)
|
||||
_attn_fwd[grid](
|
||||
q, k, v, sm_scale, lse, o,
|
||||
q, k, v, sm_scale * 1.4426950408889634, lse, o,
|
||||
n_groups=n_groups,
|
||||
q_seq_len=q_seq_len,
|
||||
kv_seq_len=kv_seq_len,
|
||||
@ -201,10 +201,8 @@ class AttentionFunc(torch.autograd.Function):
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
|
||||
# $\log_2 e$
|
||||
RCP_LN2 = 1.4426950408889634
|
||||
# Precompute $\sigma (\log_2 e) K_j$
|
||||
k_scaled = k * (sm_scale * RCP_LN2)
|
||||
k_scaled = k * (sm_scale * 1.4426950408889634)
|
||||
# $D_i = P^T_{i:}dP_{i:} = do^T_io_i$
|
||||
pdp = torch.empty_like(lse)
|
||||
# We use fixed `BLOCK_Q` for backward pass on $D$
|
||||
@ -288,7 +286,7 @@ def _get_autotune_configs(inner_loop: str) -> list:
|
||||
@triton.autotune(_get_autotune_configs(inner_loop='key'),
|
||||
key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
|
||||
@triton.jit
|
||||
def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
|
||||
def _attn_fwd(t_q, t_k, t_v, sm_scale_log2e, t_lse, t_o,
|
||||
n_groups: tl.constexpr,
|
||||
q_seq_len: tl.constexpr,
|
||||
kv_seq_len: tl.constexpr,
|
||||
@ -359,11 +357,6 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
|
||||
# Mask for $Q$ for the last block
|
||||
i_mask = offs_i < q_seq_len
|
||||
|
||||
# Precalculate $\frac{\sigma}{\log_2 e}$.
|
||||
#
|
||||
# We will be use this when calculating $S_{ij}$ so `S` will store $S_{ij} \log_2 e$ instead.
|
||||
sm_scale_log2e = sm_scale * 1.44269504
|
||||
|
||||
# Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\inf$ and $l_i$ to $1$. So in the first update,
|
||||
# the effect of initial $l_i$ is $e^{m_i - m_{i}^{\text{new}}} l_i = 0$.
|
||||
#
|
||||
@ -762,9 +755,6 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
|
||||
BLOCK_Q: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
):
|
||||
# $\log_e 2$
|
||||
LN2: tl.constexpr = 0.6931471824645996 # type: ignore
|
||||
|
||||
i = tl.program_id(0) * BLOCK_Q
|
||||
z = tl.program_id(1) // n_groups
|
||||
g = tl.program_id(1) % n_groups # TODO
|
||||
@ -859,7 +849,7 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
|
||||
)
|
||||
|
||||
# `b_dq` stores $(\log_2 e)dQ$ so multiply by $\log_e 2$ to get $dQ$
|
||||
b_dq *= LN2
|
||||
b_dq *= 0.6931471824645996
|
||||
|
||||
# Save $dQ$
|
||||
tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))
|
||||
|
Reference in New Issue
Block a user