Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-08-26 16:50:39 +08:00)
@@ -151,7 +151,7 @@ class AttentionFunc(torch.autograd.Function):
 # The forward computation will be parallelized along the batch dimension and the queries in blocks of size `BLOCK_Q`
 grid = lambda meta: (triton.cdiv(q_seq_len, meta["BLOCK_Q"]), batch_size * k_heads * n_groups, 1)
 _attn_fwd[grid](
-q, k, v, sm_scale, lse, o,
+q, k, v, sm_scale * 1.4426950408889634, lse, o,
 n_groups=n_groups,
 q_seq_len=q_seq_len,
 kv_seq_len=kv_seq_len,
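The forward kernel now receives the softmax scale pre-multiplied by log2(e) = 1.4426950408889634, so it can compute the softmax with base-2 exponentials instead of base e. A minimal PyTorch sketch of the identity being exploited; the tensor names and sizes here are illustrative, not taken from the repository:

    import torch

    RCP_LN2 = 1.4426950408889634          # log2(e)
    sm_scale = 0.125                      # illustrative softmax scale
    s = torch.randn(4, 8)                 # stand-in for the attention scores Q K^T

    # Reference softmax in base e.
    p_base_e = torch.softmax(sm_scale * s, dim=-1)

    # Same softmax computed with base-2 exponentials and the pre-multiplied scale.
    z = sm_scale * RCP_LN2 * s
    p_base_2 = torch.exp2(z - z.max(dim=-1, keepdim=True).values)
    p_base_2 = p_base_2 / p_base_2.sum(dim=-1, keepdim=True)

    assert torch.allclose(p_base_e, p_base_2, atol=1e-6)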
@@ -201,10 +201,8 @@ class AttentionFunc(torch.autograd.Function):
 dk = torch.empty_like(k)
 dv = torch.empty_like(v)

-# $\log_2 e$
-RCP_LN2 = 1.4426950408889634
 # Precompute $\sigma (\log_2 e) K_j$
-k_scaled = k * (sm_scale * RCP_LN2)
+k_scaled = k * (sm_scale * 1.4426950408889634)
 # $D_i = P^T_{i:}dP_{i:} = do^T_io_i$
 pdp = torch.empty_like(lse)
 # We use fixed `BLOCK_Q` for backward pass on $D$
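On the backward side the host folds the same sigma * log2(e) factor into K once, and allocates a buffer for D_i = dO_i^T O_i. A hedged reference sketch of those two precomputes in plain PyTorch; the (batch, heads, seq, d_head) layout and the concrete sizes are assumptions for illustration only:

    import torch

    batch, heads, seq, d_head = 2, 4, 128, 64      # made-up sizes
    sm_scale = d_head ** -0.5                      # illustrative scale
    RCP_LN2 = 1.4426950408889634                   # log2(e)

    k = torch.randn(batch, heads, seq, d_head)
    o = torch.randn(batch, heads, seq, d_head)
    do = torch.randn_like(o)

    # Fold sigma * log2(e) into K once on the host, as the diff now does inline.
    k_scaled = k * (sm_scale * RCP_LN2)

    # Reference value of D_i = dO_i^T O_i: one scalar per query row, same shape as lse.
    pdp_ref = (o * do).sum(dim=-1)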
@@ -288,7 +286,7 @@ def _get_autotune_configs(inner_loop: str) -> list:
 @triton.autotune(_get_autotune_configs(inner_loop='key'),
 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
 @triton.jit
-def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
+def _attn_fwd(t_q, t_k, t_v, sm_scale_log2e, t_lse, t_o,
 n_groups: tl.constexpr,
 q_seq_len: tl.constexpr,
 kv_seq_len: tl.constexpr,
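With the rename, `_attn_fwd` takes the already-multiplied scale, so no program recomputes sigma * log2(e). A rough PyTorch analogue of what one score tile looks like with the pre-scaled value; this is a sketch of the idea, not the kernel's code, and the names `b_q`, `b_k`, `m`, `p` are made up:

    import torch

    def block_scores_base2(b_q, b_k, sm_scale_log2e):
        """One (query block, key block) score tile, kept as S * log2(e) so exp2 can be used."""
        s_log2e = (b_q @ b_k.transpose(-1, -2)) * sm_scale_log2e
        m = s_log2e.max(dim=-1, keepdim=True).values   # per-query block max, in base-2 space
        p = torch.exp2(s_log2e - m)                    # unnormalized weights: exp(sigma * S) scaled by 2^(-m)
        return p, m

    b_q = torch.randn(16, 64)                          # BLOCK_Q x d_head (illustrative)
    b_k = torch.randn(32, 64)                          # BLOCK_K x d_head (illustrative)
    p, m = block_scores_base2(b_q, b_k, 0.125 * 1.4426950408889634)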
@@ -359,11 +357,6 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
 # Mask for $Q$ for the last block
 i_mask = offs_i < q_seq_len

-# Precalculate $\frac{\sigma}{\log_2 e}$.
-#
-# We will be use this when calculating $S_{ij}$ so `S` will store $S_{ij} \log_2 e$ instead.
-sm_scale_log2e = sm_scale * 1.44269504
-
 # Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\inf$ and $l_i$ to $1$. So in the first update,
 # the effect of initial $l_i$ is $e^{m_i - m_{i}^{\text{new}}} l_i = 0$.
 #
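The deleted lines recomputed sigma * log2(e) inside every program; after this change the scale arrives precomputed from the host. The surviving comment explains why m_i starts at -inf and l_i at 1: on the first online-softmax update the correction factor e^(m_i - m_new) is zero, so the placeholder l_i never contributes. A tiny standalone check of that step, with an illustrative value for the first block's max:

    import math

    m_i = float("-inf")                  # initial running max
    l_i = 1.0                            # initial running sum (placeholder value)
    m_new = 0.37                         # whatever the first block's max happens to be

    correction = math.exp(m_i - m_new)   # exp(-inf) == 0.0
    l_new = correction * l_i             # the placeholder l_i contributes nothing
    print(correction, l_new)             # 0.0 0.0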
@@ -762,9 +755,6 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
 BLOCK_Q: tl.constexpr,
 BLOCK_K: tl.constexpr,
 ):
-# $\log_e 2$
-LN2: tl.constexpr = 0.6931471824645996  # type: ignore
-
 i = tl.program_id(0) * BLOCK_Q
 z = tl.program_id(1) // n_groups
 g = tl.program_id(1) % n_groups  # TODO
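For context, the kept index arithmetic below the deletion splits the second grid axis (launched as batch_size * k_heads * n_groups in the forward pass) into a combined batch/KV-head index z and a query-group index g. A toy sketch of that decomposition, with a made-up n_groups and grid size:

    n_groups = 4
    for pid1 in range(12):               # stand-in for tl.program_id(1)
        z = pid1 // n_groups             # combined batch / KV-head index
        g = pid1 % n_groups              # query-head group within that KV head
        assert pid1 == z * n_groups + g  # the packing is lossless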
@@ -859,7 +849,7 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
 )

 # `b_dq` stores $(\log_2 e)dQ$ so multiply by $\log_e 2$ to get $dQ$
-b_dq *= LN2
+b_dq *= 0.6931471824645996
 
 # Save $dQ$
 tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))
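Because the scores carry a factor of log2(e) on the way into the kernel, the accumulated dQ carries the same factor; multiplying by ln 2 = 1 / log2(e) removes it, and the commit simply inlines the float32 value of ln 2 where the named LN2 constant used to be. A quick check that the two constants are reciprocals up to float32 rounding:

    RCP_LN2 = 1.4426950408889634                  # log2(e), folded into the scores on the way in
    LN2_F32 = 0.6931471824645996                  # float32 rounding of ln(2), inlined in the kernel

    print(RCP_LN2 * LN2_F32)                      # ~1.0
    print(abs(RCP_LN2 * LN2_F32 - 1.0) < 1e-7)    # True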