flash comments

Varuna Jayasiri
2025-07-31 09:53:14 +05:30
parent c4d2e8cd22
commit 1bc2a69803
3 changed files with 1235 additions and 871 deletions

.gitignore

@@ -16,4 +16,6 @@ html/
diagrams/
.comet.config
settings.md
labml_app.log
/extensions
/.nb_editor

File diff suppressed because one or more lines are too long


@@ -6,39 +6,54 @@ import triton
import triton.language as tl
import torch
from typing import Any, Tuple, Optional
HI_PRES_TL: tl.constexpr = tl.float32
HI_PRES_TORCH: tl.constexpr = torch.float32
HI_PRES_TORCH: torch.dtype = torch.float32
class AttentionFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, sm_scale):
# Shape batch size, n_heads, seq, d
def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
causal: bool, sm_scale: float) -> torch.Tensor:
"""
Grouped query attention forward pass. Returns the output in shape `[batch_size, n_heads, q_seq_len, d_head]`.
:param ctx: is the context for torch autograd
:param q: has shape `[batch_size, n_heads, q_seq_len, d_head]`
:param k: has shape `[batch_size, k_heads, kv_seq_len, d_head]`
:param v: has shape `[batch_size, k_heads, kv_seq_len, d_head]`
:param causal: whether to apply causal attention mask
:param sm_scale: softmax scale factor
"""
batch_size, n_heads, q_seq_len, d_head = q.shape
k_heads = k.shape[1]
kv_seq_len = k.shape[2]
_, k_heads, kv_seq_len, _ = k.shape
assert n_heads % k_heads == 0
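# Number of query heads that share each key/value head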
n_groups = n_heads // k_heads
# shape constraints
# Shape constraints
assert d_head == k.shape[-1] == v.shape[-1]
assert d_head in {16, 32, 64, 128, 256}
# Reshape the tensors, combining the key/value heads with the batch dimension
q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
k = k.view(batch_size * k_heads, kv_seq_len, d_head)
v = v.view(batch_size * k_heads, kv_seq_len, d_head)
# Make sure the tensors are contiguous and the strides are the same
assert q.is_contiguous()
assert k.is_contiguous()
assert v.is_contiguous()
assert k.stride() == v.stride()
# Tensor for the output
o = torch.empty_like(q)
# Tensor for $\log_2 \sum_j e^{S_{ij}}$
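# (saved for the backward pass, which uses this log-sum-exp to recompute the softmax
# instead of storing the full attention matrix)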
lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)
# The forward computation will be parallelized along the batch dimension and the queries in blocks of size `BLOCK_M`
grid = lambda args: (triton.cdiv(q_seq_len, args["BLOCK_M"]), batch_size * k_heads * n_groups, 1)
ctx.grid = grid
_attn_fwd[grid](
q, k, v, sm_scale, lse, o,
n_groups=n_groups,
@@ -48,41 +63,61 @@ class AttentionFunc(torch.autograd.Function):
is_causal=causal,
)
# Save the reshaped inputs and outputs for the backward pass
ctx.save_for_backward(q, k, v, o, lse)
ctx.sm_scale = sm_scale
ctx.n_groups = n_groups
ctx.d_head = d_head
ctx.causal = causal
# Return the output in shape `[batch_size, n_heads, q_seq_len, d_head]`
return o.view(batch_size, n_heads, q_seq_len, d_head)
@staticmethod
def backward(ctx, do):
def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
"""
The backward pass computes the gradients of the input tensors.
:param ctx: is the context for torch autograd
:param do: is the gradient tensor of the attention output with shape `[batch_size, n_heads, q_seq_len, d_head]`
"""
# Get saved tensors and attributes
n_groups = ctx.n_groups
sm_scale = ctx.sm_scale
causal = ctx.causal
q, k, v, o, lse = ctx.saved_tensors
# Get shapes
batch_size, n_heads, q_seq_len, d_head = do.shape
_, kv_seq_len, _ = k.shape
k_heads = n_heads // n_groups
# Combine the heads with the batch dimension of the output gradients tensor
do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
# Make sure it's contiguous and the strides are the same
assert do.is_contiguous()
assert k.stride() == v.stride()
assert q.stride() == o.stride() == do.stride()
# Create tensors for input gradients
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
arg_k = k * (sm_scale * RCP_LN2)
BLOCK_M = 16
assert q_seq_len % BLOCK_M == 0
pre_grid = (q_seq_len // BLOCK_M, batch_size * k_heads)
# $\frac{1}{\log_e 2}$
RCP_LN2 = 1.4426950408889634
# Multiply $k$ by softmax scale
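# (the extra $\frac{1}{\log_e 2}$ factor lets the kernels use base-2 exponentials,
# since $e^x = 2^{x / \log_e 2}$, which map to faster hardware instructions)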
k_scaled = k * (sm_scale * RCP_LN2)
# $D_i = P^T_{i:}dP_{i:} = do^T_io_i$
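# ($D_i$ is used by both the $dK$/$dV$ kernel and the $dQ$ kernel to compute
# $dS_{ij} = P_{ij} (dP_{ij} - D_i)$)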
pdp = torch.empty_like(lse)
# We use a fixed `BLOCK_M` for the backward pass on $D$
BLOCK_M = 16
assert q_seq_len % BLOCK_M == 0
# Compute $D_i$
#
# This is parallelized along the batch and queries in blocks of size `BLOCK_M`
pre_grid = (q_seq_len // BLOCK_M, batch_size * k_heads)
_attn_bwd_d[pre_grid](
o, do,
pdp,
@@ -92,34 +127,42 @@ class AttentionFunc(torch.autograd.Function):
n_groups=n_groups,
num_stages=1,
)
# Compute $dK$ and $dV$
#
# This is parallelized along the batch and keys in blocks of size `BLOCK_N`
grid = lambda args: (triton.cdiv(kv_seq_len, args['BLOCK_N']), batch_size * k_heads)
_attn_bwd_dkdv[grid](
q, arg_k, v, sm_scale, do, dk, dv,
q, k_scaled, v, sm_scale, do, dk, dv,
lse, pdp,
q_seq_len, kv_seq_len, n_groups, d_head,
is_causal=causal,
)
# Compute $dQ$
#
# This is parallelized along the batch and queries in blocks of size `BLOCK_M`
grid = lambda args: (triton.cdiv(q_seq_len, args["BLOCK_M"]), batch_size * k_heads * n_groups)
_attn_bwd_dq[grid](
q, arg_k, v, do,
q, k_scaled, v, do,
dq,
lse, pdp,
q_seq_len, kv_seq_len, n_groups, d_head,
is_causal=causal,
)
# Split the combined batch and heads
dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)
#
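# The two `None`s are the gradients for `causal` and `sm_scale`, which are not tensors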
return dq, dk, dv, None, None
attention = AttentionFunc.apply
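# `attention` is called like a regular function, e.g. `attention(q, k, v, causal, sm_scale)`,
# with `q` shaped `[batch_size, n_heads, q_seq_len, d_head]` and `k`, `v` shaped
# `[batch_size, k_heads, kv_seq_len, d_head]`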
def _get_autotune_configs(inner_loop: str):
def _get_autotune_configs(inner_loop: str) -> list:
"""
#### Configs for auto-tuning
"""
@@ -176,15 +219,15 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
Stride `n` denotes the stride along `seq_len` of the key.
"""
start_m = tl.program_id(0)
i = tl.program_id(0)
z = tl.program_id(1) // n_groups
g = tl.program_id(1) % n_groups
# block pointers
# Create block pointers
p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
(q_seq_len, d_head),
(d_head, 1),
(start_m * BLOCK_M, 0),
(i * BLOCK_M, 0),
(BLOCK_M, d_head),
(1, 0))
p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
@@ -202,19 +245,19 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
(q_seq_len, d_head),
(d_head, 1),
(start_m * BLOCK_M, 0),
(i * BLOCK_M, 0),
(BLOCK_M, d_head),
(1, 0))
p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
(q_seq_len,),
(1,),
(start_m * BLOCK_M,),
(i * BLOCK_M,),
(BLOCK_M,),
(0,))
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# Initialize offsets
offs_i = i * BLOCK_M + tl.arange(0, BLOCK_M)
offs_j = tl.arange(0, BLOCK_N)
# Initialize $m_i$ and $l_i$
b_m = tl.zeros([BLOCK_M], dtype=HI_PRES_TL) - float("inf")
@@ -228,21 +271,22 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
b_q = tl.load(p_q)
if is_causal:
# Run for ranges
# Up to the diagonal block
b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q,
p_kT, p_v,
sm_scale,
BLOCK_M, d_head, BLOCK_N,
offs_m, offs_n,
offs_i, offs_j,
start_n=tl.full([], 0, tl.int32), # type: ignore
steps=(start_m * BLOCK_M) // BLOCK_N,
steps=(i * BLOCK_M) // BLOCK_N,
MASK=False,
)
# Diagonal block with masking within it
b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
sm_scale,
BLOCK_M, d_head, BLOCK_N,
offs_m, offs_n,
start_n=start_m * BLOCK_M,
offs_i, offs_j,
start_n=i * BLOCK_M,
steps=BLOCK_M // BLOCK_N,
MASK=True,
)
@@ -250,7 +294,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
sm_scale,
BLOCK_M, d_head, BLOCK_N,
offs_m, offs_n,
offs_i, offs_j,
start_n=tl.full([], 0, tl.int32), # type: ignore
steps=kv_seq_len // BLOCK_N,
MASK=False,
@@ -308,9 +352,10 @@ def _attn_fwd_inner(b_acc, b_l, b_m, b_q,
b_p = b_p.to(b_q.dtype)
b_acc += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)
# update m_i and l_i
# Update $m_i$
b_m = b_m_new
# Move pointers
start_n += BLOCK_N
p_v = tl.advance(p_v, (BLOCK_N, 0))
p_kT = tl.advance(p_kT, (0, BLOCK_N))
@@ -327,24 +372,26 @@ def _attn_bwd_d(t_o, t_do,
q_seq_len: tl.constexpr,
n_groups: tl.constexpr,
):
m = tl.program_id(0) * BLOCK_M
i = tl.program_id(0) * BLOCK_M
z = tl.program_id(1)
# Create block pointers
p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
(n_groups, q_seq_len, d_head),
(q_seq_len * d_head, d_head, 1),
(0, m, 0),
(0, i, 0),
(n_groups, BLOCK_M, d_head),
(2, 1, 0))
p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
(n_groups, q_seq_len, d_head),
(q_seq_len * d_head, d_head, 1),
(0, m, 0),
(0, i, 0),
(n_groups, BLOCK_M, d_head),
(2, 1, 0))
p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
(n_groups, q_seq_len),
(q_seq_len, 1),
(0, m),
(0, i),
(n_groups, BLOCK_M),
(1, 0))
@@ -406,7 +453,9 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
b_k = tl.load(p_k)
b_v = tl.load(p_v)
# Iterate through the query groups that attend to the same keys
for g in range(n_groups):
# Create block pointers
p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
(d_head, q_seq_len),
(1, d_head),
@@ -575,6 +624,7 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
z = tl.program_id(1) // n_groups
g = tl.program_id(1) % n_groups
# Create block pointers
p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
(q_seq_len, d_head),
(d_head, 1),
@@ -709,4 +759,4 @@ def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
p_vT = tl.advance(p_vT, (0, BLOCK_N))
# Return accumulated $dq$
return b_dq
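# A minimal correctness-check sketch (an illustration, not part of this commit; the shapes,
# dtypes and tolerance below are assumptions). The reference repeats the key/value heads to
# match the query heads and uses PyTorch's built-in scaled dot-product attention, assuming a
# PyTorch version that accepts the `scale` argument.
def _reference_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         causal: bool, sm_scale: float) -> torch.Tensor:
    # Repeat each key/value head `n_groups` times so query head `h` uses key head `h // n_groups`
    n_groups = q.shape[1] // k.shape[1]
    k = k.repeat_interleave(n_groups, dim=1)
    v = v.repeat_interleave(n_groups, dim=1)
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=causal, scale=sm_scale)


if __name__ == '__main__':
    # Random CUDA tensors with `n_heads = 8` query heads and `k_heads = 2` key/value heads
    _q = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float16, requires_grad=True)
    _k = torch.randn(1, 2, 1024, 64, device='cuda', dtype=torch.float16, requires_grad=True)
    _v = torch.randn(1, 2, 1024, 64, device='cuda', dtype=torch.float16, requires_grad=True)
    # Causal attention with softmax scale $\frac{1}{\sqrt{d_{head}}}$
    _o = attention(_q, _k, _v, True, 0.125)
    _o_ref = _reference_attention(_q, _k, _v, True, 0.125)
    # Compare against the reference (tolerance chosen loosely for half precision)
    assert torch.allclose(_o, _o_ref, atol=2e-2, rtol=0.0)
    # Gradients also flow through the custom backward pass
    _o.sum().backward()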