mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-30 10:18:50 +08:00
flash comments
.gitignore (vendored): 2 changes
@@ -17,3 +17,5 @@ diagrams/
.comet.config
settings.md
labml_app.log
/extensions
/.nb_editor
File diff suppressed because one or more lines are too long
@@ -6,39 +6,54 @@ import triton
import triton.language as tl

import torch
from typing import Any, Tuple, Optional

HI_PRES_TL: tl.constexpr = tl.float32
HI_PRES_TORCH: tl.constexpr = torch.float32
HI_PRES_TORCH: torch.dtype = torch.float32


class AttentionFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, sm_scale):
# Shape batch size, n_heads, seq, d
def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
causal: bool, sm_scale: float) -> torch.Tensor:
"""
Grouped query attention forward pass. Returns the output in shape `[batch_size, n_heads, q_seq_len, d_head]`.

:param ctx: is the context for torch autograd
:param q: has shape `[batch_size, n_heads, q_seq_len, d_head]`
:param k: has shape `[batch_size, k_heads, kv_seq_len, d_head]`
:param v: has shape `[batch_size, k_heads, kv_seq_len, d_head]`
:param causal: whether to apply causal attention mask
:param sm_scale: softmax scale factor
"""
batch_size, n_heads, q_seq_len, d_head = q.shape
k_heads = k.shape[1]
kv_seq_len = k.shape[2]
_, k_heads, kv_seq_len, _ = k.shape
assert n_heads % k_heads == 0
n_groups = n_heads // k_heads

# shape constraints
# Shape constraints
assert d_head == k.shape[-1] == v.shape[-1]
assert d_head in {16, 32, 64, 128, 256}

# Reshape the tensors, combining the key/value heads with the batch dimension
q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
k = k.view(batch_size * k_heads, kv_seq_len, d_head)
v = v.view(batch_size * k_heads, kv_seq_len, d_head)

# Make sure the tensors are contiguous and the strides are the same
assert q.is_contiguous()
assert k.is_contiguous()
assert v.is_contiguous()
assert k.stride() == v.stride()

# Tensor for the output
o = torch.empty_like(q)

# Tensor for $\log_2 \sum_j e^{S_{ij}}$
lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)

# The forward computation will be parallelized along the batch dimension and the queries in blocks of size `BLOCK_M`
grid = lambda args: (triton.cdiv(q_seq_len, args["BLOCK_M"]), batch_size * k_heads * n_groups, 1)
ctx.grid = grid
_attn_fwd[grid](
q, k, v, sm_scale, lse, o,
n_groups=n_groups,
@@ -48,41 +63,61 @@ class AttentionFunc(torch.autograd.Function):
is_causal=causal,
)

# Save the reshaped inputs and outputs for the backward pass
ctx.save_for_backward(q, k, v, o, lse)
ctx.sm_scale = sm_scale
ctx.n_groups = n_groups
ctx.d_head = d_head
ctx.causal = causal

# Return the output in shape `[batch_size, n_heads, q_seq_len, d_head]`
return o.view(batch_size, n_heads, q_seq_len, d_head)

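# Not part of the commit: a minimal plain-PyTorch sketch of the grouped query attention
# forward pass described in the docstring above, handy as a reference check. The name
# `gqa_reference` is hypothetical, and the causal branch assumes `q_seq_len == kv_seq_len`.
def gqa_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                  causal: bool, sm_scale: float) -> torch.Tensor:
    batch_size, n_heads, q_seq_len, d_head = q.shape
    _, k_heads, kv_seq_len, _ = k.shape
    n_groups = n_heads // k_heads
    # Repeat each key/value head across the query heads in its group
    k = k.repeat_interleave(n_groups, dim=1)
    v = v.repeat_interleave(n_groups, dim=1)
    # Scaled attention scores
    s = torch.einsum('bhid,bhjd->bhij', q.float(), k.float()) * sm_scale
    if causal:
        idx_q = torch.arange(q_seq_len, device=q.device)
        idx_k = torch.arange(kv_seq_len, device=q.device)
        s = s.masked_fill(idx_q[:, None] < idx_k[None, :], float('-inf'))
    p = torch.softmax(s, dim=-1)
    return torch.einsum('bhij,bhjd->bhid', p, v.float()).to(q.dtype)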
@staticmethod
def backward(ctx, do):
def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
"""
The backward pass computes the gradients of the input tensors.

:param ctx: is the context for torch autograd
:param do: is the gradient tensor of the attention output with shape `[batch_size, n_heads, q_seq_len, d_head]`
"""

# Get saved tensors and attributes
n_groups = ctx.n_groups
sm_scale = ctx.sm_scale
causal = ctx.causal
q, k, v, o, lse = ctx.saved_tensors

# Get shapes
batch_size, n_heads, q_seq_len, d_head = do.shape
_, kv_seq_len, _ = k.shape
k_heads = n_heads // n_groups

# Combine the heads with the batch dimension of the output gradients tensor
do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)

# Make sure it's contiguous and the strides are the same
assert do.is_contiguous()
assert k.stride() == v.stride()
assert q.stride() == o.stride() == do.stride()

# Create tensors for input gradients
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)

RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
arg_k = k * (sm_scale * RCP_LN2)
BLOCK_M = 16
assert q_seq_len % BLOCK_M == 0
pre_grid = (q_seq_len // BLOCK_M, batch_size * k_heads)
# $\frac{1}{\log_e 2}$
RCP_LN2 = 1.4426950408889634
# Multiply $k$ by the softmax scale
k_scaled = k * (sm_scale * RCP_LN2)
# $D_i = P_{i:}^T dP_{i:} = do_i^T o_i$
pdp = torch.empty_like(lse)
# We use a fixed `BLOCK_M` for the backward pass on $D$
BLOCK_M = 16
assert q_seq_len % BLOCK_M == 0
# Compute $D_i$
#
# This is parallelized along the batch and queries in blocks of size `BLOCK_M`
pre_grid = (q_seq_len // BLOCK_M, batch_size * k_heads)
_attn_bwd_d[pre_grid](
o, do,
pdp,
@@ -92,34 +127,42 @@ class AttentionFunc(torch.autograd.Function):
n_groups=n_groups,
num_stages=1,
)
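# Not part of the commit: `_attn_bwd_d` fills `pdp` with $D_i = do_i^T o_i$, the row-wise dot
# product of the output and its gradient. A plain PyTorch equivalent (illustrative only):
pdp_ref = (o.to(HI_PRES_TORCH) * do.to(HI_PRES_TORCH)).sum(dim=-1)
# `pdp_ref` has shape `[batch_size * k_heads, n_groups, q_seq_len]`, matching `lse`.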
# Compute $dK$ and $dV$
#
# This is parallelized along the batch and keys in blocks of size `BLOCK_N`
grid = lambda args: (triton.cdiv(kv_seq_len, args['BLOCK_N']), batch_size * k_heads)
_attn_bwd_dkdv[grid](
q, arg_k, v, sm_scale, do, dk, dv,
q, k_scaled, v, sm_scale, do, dk, dv,
lse, pdp,
q_seq_len, kv_seq_len, n_groups, d_head,
is_causal=causal,

)
# Compute $dQ$
#
# This is parallelized along the batch and queries in blocks of size `BLOCK_M`
grid = lambda args: (triton.cdiv(q_seq_len, args["BLOCK_M"]), batch_size * k_heads * n_groups)
_attn_bwd_dq[grid](
q, arg_k, v, do,
q, k_scaled, v, do,
dq,
lse, pdp,
q_seq_len, kv_seq_len, n_groups, d_head,
is_causal=causal,
)

# Split the combined batch and heads
dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)

#
return dq, dk, dv, None, None

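# Not part of the commit: `RCP_LN2` above is $1 / \ln 2$. The kernels appear to work with
# base-2 exponentials (`lse` stores $\log_2 \sum_j e^{S_{ij}}$), so `k` is pre-scaled by
# `sm_scale * RCP_LN2` and `exp2(x * RCP_LN2)` can stand in for `exp(x)`:
import math
assert abs(1.4426950408889634 - 1.0 / math.log(2.0)) < 1e-15
x = 0.73  # arbitrary test value
assert abs(math.exp(x) - 2.0 ** (x * 1.4426950408889634)) < 1e-12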
attention = AttentionFunc.apply


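# Not part of the commit: a minimal usage sketch; the sizes, dtype and device below are
# assumptions. `attention` takes `q`, `k`, `v` with separate query and key/value head counts,
# a causal flag and the softmax scale, and supports autograd through the custom backward above.
if __name__ == '__main__':
    import math
    batch_size, n_heads, k_heads, seq_len, d_head = 2, 8, 2, 128, 64
    q = torch.randn(batch_size, n_heads, seq_len, d_head, device='cuda', dtype=torch.float16, requires_grad=True)
    k = torch.randn(batch_size, k_heads, seq_len, d_head, device='cuda', dtype=torch.float16, requires_grad=True)
    v = torch.randn(batch_size, k_heads, seq_len, d_head, device='cuda', dtype=torch.float16, requires_grad=True)
    out = attention(q, k, v, True, 1.0 / math.sqrt(d_head))  # [batch_size, n_heads, seq_len, d_head]
    out.sum().backward()  # fills q.grad, k.grad and v.grad via the Triton backward kernels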
def _get_autotune_configs(inner_loop: str):
def _get_autotune_configs(inner_loop: str) -> list:
"""
#### Configs for auto-tuning
"""
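# Not part of the commit: for readers unfamiliar with Triton auto-tuning, the list this helper
# returns typically pairs block sizes with launch parameters; the values below are placeholders,
# not the ones used in this file.
example_configs = [
    triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32}, num_warps=4, num_stages=3),
    triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_warps=8, num_stages=3),
]
# `triton.autotune` benchmarks each config on first launch and caches the fastest per `key`.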
@@ -176,15 +219,15 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
Stride `n` denotes the stride along `seq_len` of the key.

"""
start_m = tl.program_id(0)
i = tl.program_id(0)
z = tl.program_id(1) // n_groups
g = tl.program_id(1) % n_groups

# block pointers
# Create block pointers
p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
(q_seq_len, d_head),
(d_head, 1),
(start_m * BLOCK_M, 0),
(i * BLOCK_M, 0),
(BLOCK_M, d_head),
(1, 0))
p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
@@ -202,19 +245,19 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
(q_seq_len, d_head),
(d_head, 1),
(start_m * BLOCK_M, 0),
(i * BLOCK_M, 0),
(BLOCK_M, d_head),
(1, 0))
p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
(q_seq_len,),
(1,),
(start_m * BLOCK_M,),
(i * BLOCK_M,),
(BLOCK_M,),
(0,))

# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# Initialize offsets
offs_i = i * BLOCK_M + tl.arange(0, BLOCK_M)
offs_j = tl.arange(0, BLOCK_N)

# Initialize $m_i$ and $l_i$
b_m = tl.zeros([BLOCK_M], dtype=HI_PRES_TL) - float("inf")
@@ -228,21 +271,22 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
b_q = tl.load(p_q)

if is_causal:
# Run over two ranges of key blocks
# Up to the diagonal block
b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q,
p_kT, p_v,
sm_scale,
BLOCK_M, d_head, BLOCK_N,
offs_m, offs_n,
offs_i, offs_j,
start_n=tl.full([], 0, tl.int32), # type: ignore
steps=(start_m * BLOCK_M) // BLOCK_N,
steps=(i * BLOCK_M) // BLOCK_N,
MASK=False,
)
# Diagonal block with masking within it
b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
sm_scale,
BLOCK_M, d_head, BLOCK_N,
offs_m, offs_n,
start_n=start_m * BLOCK_M,
offs_i, offs_j,
start_n=i * BLOCK_M,
steps=BLOCK_M // BLOCK_N,
MASK=True,
)
@@ -250,7 +294,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
sm_scale,
BLOCK_M, d_head, BLOCK_N,
offs_m, offs_n,
offs_i, offs_j,
start_n=tl.full([], 0, tl.int32), # type: ignore
steps=kv_seq_len // BLOCK_N,
MASK=False,
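# Not part of the commit: an illustration, with made-up sizes, of how the causal branch above
# splits the key range for query block `i` into full blocks before the diagonal (no masking)
# and the diagonal block, which is masked element-wise inside `_attn_fwd_inner`.
BLOCK_M_EX, BLOCK_N_EX = 64, 32  # hypothetical block sizes
i_ex = 3  # query block index; it covers rows [i * BLOCK_M, (i + 1) * BLOCK_M)
unmasked_steps = (i_ex * BLOCK_M_EX) // BLOCK_N_EX  # key blocks entirely below the diagonal
masked_steps = BLOCK_M_EX // BLOCK_N_EX  # key blocks overlapping the diagonal
assert unmasked_steps * BLOCK_N_EX == i_ex * BLOCK_M_EX
assert masked_steps * BLOCK_N_EX == BLOCK_M_EX
# Key positions beyond (i + 1) * BLOCK_M are skipped entirely, since no query in the block attends to them.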
@@ -308,9 +352,10 @@ def _attn_fwd_inner(b_acc, b_l, b_m, b_q,
b_p = b_p.to(b_q.dtype)
b_acc += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)

# update m_i and l_i
# update $m_i$
b_m = b_m_new

# Move pointers
start_n += BLOCK_N
p_v = tl.advance(p_v, (BLOCK_N, 0))
p_kT = tl.advance(p_kT, (0, BLOCK_N))
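# Not part of the commit: `_attn_fwd_inner` is the online-softmax accumulation step of flash
# attention. `b_m` tracks the running row maximum, `b_l` the running normalizer and `b_acc`
# the unnormalized output, rescaled whenever the maximum grows. A plain PyTorch sketch of one
# update (the real kernel works on Triton blocks and uses base-2 exponentials):
import torch

def online_softmax_step(b_acc, b_l, b_m, s_block, v_block):
    b_m_new = torch.maximum(b_m, s_block.max(dim=-1).values)  # updated running row maximum
    scale = torch.exp(b_m - b_m_new)                          # rescale previous accumulators
    p = torch.exp(s_block - b_m_new[:, None])                 # unnormalized block probabilities
    b_l = b_l * scale + p.sum(dim=-1)                         # running softmax normalizer
    b_acc = b_acc * scale[:, None] + p @ v_block              # running unnormalized output
    return b_acc, b_l, b_m_new

# After the last block the output is `b_acc / b_l[:, None]`, and the saved log-sum-exp is
# `b_m + log(b_l)` (stored in base 2 in this implementation).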
@@ -327,24 +372,26 @@ def _attn_bwd_d(t_o, t_do,
q_seq_len: tl.constexpr,
n_groups: tl.constexpr,
):
m = tl.program_id(0) * BLOCK_M
i = tl.program_id(0) * BLOCK_M
z = tl.program_id(1)

# Create block pointers
p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
(n_groups, q_seq_len, d_head),
(q_seq_len * d_head, d_head, 1),
(0, m, 0),
(0, i, 0),
(n_groups, BLOCK_M, d_head),
(2, 1, 0))
p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
(n_groups, q_seq_len, d_head),
(q_seq_len * d_head, d_head, 1),
(0, m, 0),
(0, i, 0),
(n_groups, BLOCK_M, d_head),
(2, 1, 0))
p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
(n_groups, q_seq_len),
(q_seq_len, 1),
(0, m),
(0, i),
(n_groups, BLOCK_M),
(1, 0))

@@ -406,7 +453,9 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
b_k = tl.load(p_k)
b_v = tl.load(p_v)

# Iterate through the query head groups that attend to these keys
for g in range(n_groups):
# Create block pointers
p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
(d_head, q_seq_len),
(1, d_head),
@@ -575,6 +624,7 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
z = tl.program_id(1) // n_groups
g = tl.program_id(1) % n_groups

# Create block pointers
p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
(q_seq_len, d_head),
(d_head, 1),