From 9262c57f181a52130a64f65bc204fb5b3470f0fd Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri

Flash attention speeds up the transformer attention mechanism by reducing the number of memory reads/writes between the GPU's high bandwidth memory (HBM) and its on-chip SRAM. It was introduced in the paper FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness and further optimized in FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. The official CUDA implementation can be found at Dao-AILab/flash-attention. Our implementation is based on Triton's example implementation.

Note: You can click on the mathematical symbols or identifiers to highlight them. You can run test.py to check correctness and measure the performance of this implementation.

Here's the attention forward pass. The formulas represent a single attention head. $q_i$ is the query vector (row vector) at position $i$, and $k_j$ and $v_j$ are the key and value row vectors at position $j$. $o_i$ is the output vector at position $i$. $S$ is the attention score matrix before softmax, $l_i$ is the softmax denominator, and $P$ is the attention matrix after softmax.

Instead of doing the full softmax, you can compute $o_i$ by accumulating the sum of exponentials $l_i$ and the unnormalized output $\tilde{o}_i$ while iterating over keys, and finally compute $o_i = \tilde{o}_i / l_i$.

+ To make it numerically stable, flash attention subtracts the current maximum of $S_{ij}$ before exponentiating. So it maintains a running maximum $m_i$, the sum of exponentials $l_i$, and the unnormalized output $\tilde{o}_i$ while iterating over keys. For each block of keys it updates them, rescaling the accumulators by $e^{m_i - m_i^{\text{new}}}$ and adding the contribution of the new block. Then finally, $o_i = \tilde{o}_i / l_i$.

+ This reduces memory usage since we don't have to compute the full $S$ matrix or $P$ matrix. It also speeds things up since we don't have to load these large matrices; instead only blocks of $K$ and $V$ are loaded as the kernel iterates over them.

Here's the standard backward pass. $do_i$ is the gradient vector on the output $o_i$, and $\delta_{ij}$ is $1$ when $i = j$ and $0$ otherwise. The flash attention paper introduces $D_i = \sum_j P_{ij}\,dP_{ij} = do_i \cdot o_i$ to simplify the computation.

Flash attention saves the log-sum-exp $L_i = m_i + \log l_i$ from the forward pass since it doesn't take much memory, so during the backward pass it doesn't have to keep recomputing $m_i$ or $l_i$. It first computes $D_i$. Then it iterates over the queries to compute (accumulate) $dk_j$ and $dv_j$, and finally iterates over the keys to compute (accumulate) $dq_i$. In both the forward and backward passes we calculate base-2 logarithms and exponentials instead of base-$e$ ones, for performance.
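To make the streaming-softmax recurrence above concrete, here is a minimal PyTorch sketch of it for a single query row. The helper name online_attention_row and the toy sizes are illustrative only; the Triton kernels below implement the same idea blockwise and in base 2.

import torch

def online_attention_row(q, K, V, sm_scale, block_k=64):
    # Running maximum of scores, running sum of exponentials, unnormalized output
    m = torch.tensor(float("-inf"))
    l = torch.tensor(0.0)
    o = torch.zeros(q.shape[-1])
    for j in range(0, K.shape[0], block_k):
        kb, vb = K[j:j + block_k], V[j:j + block_k]
        s = (kb @ q) * sm_scale              # scores for this block of keys
        m_new = torch.maximum(m, s.max())
        scale = torch.exp(m - m_new)         # rescale the old accumulators
        p = torch.exp(s - m_new)
        l = l * scale + p.sum()
        o = o * scale + p @ vb
        m = m_new
    return o / l                             # normalize once at the end

# Matches the full softmax without ever materializing the whole score row
torch.manual_seed(0)
q, K, V = torch.randn(64), torch.randn(256, 64), torch.randn(256, 64)
ref = torch.softmax((K @ q) * 64 ** -0.5, dim=0) @ V
assert torch.allclose(online_attention_row(q, K, V, 64 ** -0.5), ref, atol=1e-5)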
Flash Attention
+Forward pass
-Flash Attention Optimization
+
-
+Backward pass
-101from typing import Any, Tuple
-102
-103import torch
-104import triton
-105import triton.language as tl
-106
-107HI_PRES_TL: tl.constexpr = tl.float32
-108HI_PRES_TORCH: torch.dtype = torch.float32
148from typing import Any, Tuple
+149
+150import torch
+151import triton
+152import triton.language as tl
+153
+154HI_PRES_TL: tl.constexpr = tl.float32
+155HI_PRES_TORCH: torch.dtype = torch.float32
111class AttentionFunc(torch.autograd.Function):
158class AttentionFunc(torch.autograd.Function):
Forward pass

Group query attention forward pass. Returns the output in shape [batch_size, n_heads, q_seq_len, d_head].

ctx is the context for torch autograd
causal is whether to apply the causal attention mask
sm_scale is the softmax scale factor

-112 @staticmethod
-113 def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-114 causal: bool, sm_scale: float) -> torch.Tensor:
159 @staticmethod
+160 def forward(ctx: Any,
+161 q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+162 causal: bool, sm_scale: float) -> torch.Tensor:
126 batch_size, n_heads, q_seq_len, d_head = q.shape
-127 _, k_heads, kv_seq_len, _ = k.shape
-128 assert n_heads % k_heads == 0
-129 n_groups = n_heads // k_heads
176 batch_size, n_heads, q_seq_len, d_head = q.shape
+177 _, k_heads, kv_seq_len, _ = k.shape
+178 assert n_heads % k_heads == 0
+179 n_groups = n_heads // k_heads
132 assert d_head == k.shape[-1] == v.shape[-1]
-133 assert d_head in {16, 32, 64, 128, 256}
182 assert d_head == k.shape[-1] == v.shape[-1]
+183 assert d_head in {16, 32, 64, 128, 256}
136 q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
-137 k = k.view(batch_size * k_heads, kv_seq_len, d_head)
-138 v = v.view(batch_size * k_heads, kv_seq_len, d_head)
186 q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
+187 k = k.view(batch_size * k_heads, kv_seq_len, d_head)
+188 v = v.view(batch_size * k_heads, kv_seq_len, d_head)
141 assert q.is_contiguous()
-142 assert k.is_contiguous()
-143 assert v.is_contiguous()
-144 assert k.stride() == v.stride()
191 assert q.is_contiguous()
+192 assert k.is_contiguous()
+193 assert v.is_contiguous()
+194 assert k.stride() == v.stride()
147 o = torch.empty_like(q)
197 o = torch.empty_like(q)
149 lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)
199 lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)
152 grid = lambda meta: (triton.cdiv(q_seq_len, meta["BLOCK_Q"]), batch_size * k_heads * n_groups, 1)
-153 _attn_fwd[grid](
-154 q, k, v, sm_scale * 1.4426950408889634, lse, o,
-155 n_groups=n_groups,
-156 q_seq_len=q_seq_len,
-157 kv_seq_len=kv_seq_len,
-158 d_head=d_head,
-159 is_causal=causal,
-160 )
202 grid = lambda meta: (triton.cdiv(q_seq_len, meta["BLOCK_Q"]), batch_size * k_heads * n_groups, 1)
+203 _attn_fwd[grid](
+204 q, k, v, sm_scale * 1.4426950408889634, lse, o,
+205 n_groups=n_groups,
+206 q_seq_len=q_seq_len,
+207 kv_seq_len=kv_seq_len,
+208 d_head=d_head,
+209 is_causal=causal,
+210 )
163 ctx.save_for_backward(q, k, v, o, lse)
-164 ctx.sm_scale = sm_scale
-165 ctx.n_groups = n_groups
-166 ctx.causal = causal
213 ctx.save_for_backward(q, k, v, o, lse)
+214 ctx.sm_scale = sm_scale
+215 ctx.n_groups = n_groups
+216 ctx.causal = causal
169 return o.view(batch_size, n_heads, q_seq_len, d_head)
219 return o.view(batch_size, n_heads, q_seq_len, d_head)
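For reference, a minimal usage sketch (mirroring the calls in test.py at the end of this page; shapes and sizes are illustrative). The wrapper attention defined later is just AttentionFunc.apply.

import torch
from labml_nn.transformers.flash import attention

# Grouped query attention: 8 query heads sharing 4 key/value heads (illustrative sizes)
batch_size, n_heads, k_heads, seq_len, d_head = 2, 8, 4, 1024, 64
q = torch.randn(batch_size, n_heads, seq_len, d_head, device='cuda', dtype=torch.float16, requires_grad=True)
k = torch.randn(batch_size, k_heads, seq_len, d_head, device='cuda', dtype=torch.float16, requires_grad=True)
v = torch.randn(batch_size, k_heads, seq_len, d_head, device='cuda', dtype=torch.float16, requires_grad=True)

out = attention(q, k, v, True, d_head ** -0.5)  # causal=True, sm_scale = 1/sqrt(d_head)
out.sum().backward()                            # gradients flow through the Triton backward kernels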
Backward pass

The backward pass computes the gradients of the input tensors.

ctx is the context for torch autograd
do is the gradient on the attention output
@@ -289,8 +304,8 @@
171 @staticmethod
-172 def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
221 @staticmethod
+222 def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
181 n_groups = ctx.n_groups
-182 sm_scale = ctx.sm_scale
-183 causal = ctx.causal
-184 q, k, v, o, lse = ctx.saved_tensors
233 n_groups = ctx.n_groups
+234 sm_scale = ctx.sm_scale
+235 causal = ctx.causal
+236 q, k, v, o, lse = ctx.saved_tensors
187 batch_size, n_heads, q_seq_len, d_head = do.shape
-188 _, kv_seq_len, _ = k.shape
-189 k_heads = n_heads // n_groups
239 batch_size, n_heads, q_seq_len, d_head = do.shape
+240 _, kv_seq_len, _ = k.shape
+241 k_heads = n_heads // n_groups
192 do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
244 do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
195 assert do.is_contiguous()
-196 assert k.stride() == v.stride()
-197 assert q.stride() == o.stride() == do.stride()
247 assert do.is_contiguous()
+248 assert k.stride() == v.stride()
+249 assert q.stride() == o.stride() == do.stride()
200 dq = torch.empty_like(q)
-201 dk = torch.empty_like(k)
-202 dv = torch.empty_like(v)
252 dq = torch.empty_like(q)
+253 dk = torch.empty_like(k)
+254 dv = torch.empty_like(v)
205 k_scaled = k * (sm_scale * 1.4426950408889634)
257 k_scaled = k * (sm_scale * 1.4426950408889634)
207 pdp = torch.empty_like(lse)
259 pdp = torch.empty_like(lse)
We use a fixed BLOCK_Q for the backward pass on $D_i$.

Compute $D_i = o_i \cdot do_i$
This is parallelized along the batch and queries in blocks of size BLOCK_Q
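As a quick sanity check of the identity behind this pre-computation (toy sizes, illustrative only): $D_i = \sum_j P_{ij}\,dP_{ij}$ reduces to $o_i \cdot do_i$, which is all the kernel below needs from $O$ and $dO$.

import torch

torch.manual_seed(0)
seq_len, d_head = 8, 16
P = torch.softmax(torch.randn(seq_len, seq_len), dim=-1)   # attention after softmax
V = torch.randn(seq_len, d_head)
O = P @ V                                                  # attention output
dO = torch.randn(seq_len, d_head)                          # gradient on the output
dP = dO @ V.T                                              # dP_ij = do_i . v_j
assert torch.allclose((P * dP).sum(-1), (O * dO).sum(-1), atol=1e-5)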
213 BLOCK_Q = 16
-214 pre_grid = (triton.cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads)
-215 _attn_bwd_d[pre_grid](
-216 o, do,
-217 pdp,
-218 BLOCK_Q=16,
-219 d_head=d_head,
-220 q_seq_len=q_seq_len,
-221 n_groups=n_groups,
-222 num_stages=1,
-223 )265 BLOCK_Q = 16
+266 pre_grid = (triton.cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads)
+267 _attn_bwd_d[pre_grid](
+268 o, do,
+269 pdp,
+270 BLOCK_Q=16,
+271 d_head=d_head,
+272 q_seq_len=q_seq_len,
+273 n_groups=n_groups,
+274 num_stages=1,
+275 )

Compute $dk_j$ and $dv_j$
This is parallelized along the batch and keys in blocks of size BLOCK_K
228 grid = lambda meta: (triton.cdiv(kv_seq_len, meta['BLOCK_K']), batch_size * k_heads)
-229 _attn_bwd_dkdv[grid](
-230 q, k_scaled, v, sm_scale, do, dk, dv,
-231 lse, pdp,
-232 q_seq_len, kv_seq_len, n_groups, d_head,
-233 is_causal=causal,
-234
-235 )280 grid = lambda meta: (triton.cdiv(kv_seq_len, meta['BLOCK_K']), batch_size * k_heads)
+281 _attn_bwd_dkdv[grid](
+282 q, k_scaled, v, sm_scale, do, dk, dv,
+283 lse, pdp,
+284 q_seq_len, kv_seq_len, n_groups, d_head,
+285 is_causal=causal,
+286
+287 )

Compute $dq_i$
This is parallelized along the batch and queries in blocks of size BLOCK_Q
240 grid = lambda meta: (triton.cdiv(q_seq_len, meta['BLOCK_Q']), batch_size * k_heads * n_groups)
-241 _attn_bwd_dq[grid](
-242 q, k_scaled, v, do,
-243 dq,
-244 lse, pdp,
-245 q_seq_len, kv_seq_len, n_groups, d_head,
-246 is_causal=causal,
-247 )292 grid = lambda meta: (triton.cdiv(q_seq_len, meta['BLOCK_Q']), batch_size * k_heads * n_groups)
+293 _attn_bwd_dq[grid](
+294 q, k_scaled, v, do,
+295 dq,
+296 lse, pdp,
+297 q_seq_len, kv_seq_len, n_groups, d_head,
+298 is_causal=causal,
+299 )250 dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
-251 dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
-252 dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)302 dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
+303 dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
+304 dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)255 return dq, dk, dv, None, None
-256
-257
-258attention = AttentionFunc.apply307 return dq, dk, dv, None, None
+308
+309
+310attention = AttentionFunc.apply

261def _get_autotune_configs(inner_loop: str) -> list:
313def _get_autotune_configs(inner_loop: str) -> list:

266 configs = []
318 configs = []

List possible BLOCK_Q and BLOCK_K that satisfy the divisibility constraint (BLOCK_Q divisible by BLOCK_K, or vice versa, depending on the inner loop) and also try to cover a wide range.

Possible options for BLOCK_Q

269 for bm in [64, 128, 256]:
321 for bq in [64, 128, 256]:

We'll try bn in 16, 32, 64, 128 that are divisors and <= bm

Possible options for BLOCK_K
271 for bn in [64, 128, 256]:
-272 if inner_loop == 'key' and bm % bn != 0:
-273 continue
-274 if inner_loop == 'query' and bn % bm != 0:
-275 continue
-276 for s in [2, 3, 4]:
-277 for w in [4, 8]:
-278 if bm * bn < 128 * 128 and w == 8:
-279 continue
-280
-281 configs.append(triton.Config({'BLOCK_Q': bm, 'BLOCK_K': bn}, num_stages=s, num_warps=w))
-282
-283 return configs[:1]

323 for bk in [64, 128, 256]:

If the inner loop is along keys, BLOCK_Q must be a multiple of BLOCK_K for causal masking

325 if inner_loop == 'key' and bq % bk != 0:
+326 continue

Similarly when the inner loop is along queries

+328 if inner_loop == 'query' and bk % bq != 0:
+329 continue

Number of stages and warps

+332 for s in [2, 3, 4]:
+333 for w in [4, 8]:
+334 if bq * bk < 128 * 128 and w == 8:
+335 continue
+336
+337 configs.append(triton.Config({'BLOCK_Q': bq, 'BLOCK_K': bk}, num_stages=s, num_warps=w))

Use return configs to autotune. Trying all combinations is slow for testing.

340 return configs[:1]
t_q queries
t_k keys
t_v values
sm_scale_log2e softmax scale multiplied by $\log_2 e$
t_lse $L$ (out)
t_o output (out)
n_groups number of groups
q_seq_len query sequence length
kv_seq_len key/value sequence length
d_head size of a head
BLOCK_Q block size for the query sequence length
BLOCK_K block size for the key/value sequence length
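A note on the constant 1.4426950408889634 that the host code folds into the scale: it is $\log_2 e$, so the kernels can use exp2/log2 (which are cheaper on the GPU) instead of exp/log. A quick illustrative check:

import math
import torch

s = torch.randn(4, 4)            # attention scores
sm_scale = 0.125
log2e = 1.4426950408889634       # log2(e)
assert math.isclose(log2e, math.log2(math.e))
# exp(s * sm_scale) == 2 ** (s * sm_scale * log2(e))
assert torch.allclose(torch.exp(s * sm_scale), torch.exp2(s * sm_scale * log2e))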
@@ -590,105 +654,26 @@
and d denote the stride of the corresponding dimensions (batch_size, n_heads, q_seq_len, d_head) in the query. Stride n denotes the stride on kv_seq_len of the key.
286@triton.autotune(_get_autotune_configs(inner_loop='key'),
-287 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
-288@triton.jit
-289def _attn_fwd(t_q, t_k, t_v, sm_scale_log2e, t_lse, t_o,
-290 n_groups: tl.constexpr,
-291 q_seq_len: tl.constexpr,
-292 kv_seq_len: tl.constexpr,
-293 d_head: tl.constexpr,
-294 is_causal: tl.constexpr,
-295 BLOCK_Q: tl.constexpr, # q seq len block
-296 BLOCK_K: tl.constexpr, # k seq len block
-297 ):318 i = tl.program_id(0)
-319 z = tl.program_id(1) // n_groups
-320 g = tl.program_id(1) % n_groups # TODOCreate block pointers
- -323 p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-324 (q_seq_len, d_head),
-325 (d_head, 1),
-326 (i * BLOCK_Q, 0),
-327 (BLOCK_Q, d_head),
-328 (1, 0))
-329 p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
-330 (kv_seq_len, d_head),
-331 (d_head, 1),
-332 (0, 0),
-333 (BLOCK_K, d_head),
-334 (1, 0))
-335 p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
-336 (d_head, kv_seq_len),
-337 (1, d_head),
-338 (0, 0),
-339 (d_head, BLOCK_K),
-340 (0, 1))
-341 p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-342 (q_seq_len, d_head),
-343 (d_head, 1),
-344 (i * BLOCK_Q, 0),
-345 (BLOCK_Q, d_head),
-346 (1, 0))
-347 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
-348 (q_seq_len,),
-349 (1,),
-350 (i * BLOCK_Q,),
-351 (BLOCK_Q,),
-352 (0,))Initialize offsets
- -355 offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
-356 offs_j = tl.arange(0, BLOCK_K)Mask for for the last block
- -358 i_mask = offs_i < q_seq_len343@triton.autotune(_get_autotune_configs(inner_loop='key'),
+344 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
+345@triton.jit
+346def _attn_fwd(t_q, t_k, t_v, sm_scale_log2e, t_lse, t_o,
+347 n_groups: tl.constexpr,
+348 q_seq_len: tl.constexpr,
+349 kv_seq_len: tl.constexpr,
+350 d_head: tl.constexpr,
+351 is_causal: tl.constexpr,
+352 BLOCK_Q: tl.constexpr,
+353 BLOCK_K: tl.constexpr,
+354 ):

We are computing the attention for queries i ... i + BLOCK_Q in one batch/head combination.
364 b_m = tl.where(i_mask, -float("inf"), 0.0)
-365 b_l = tl.where(i_mask, 1.0, 0.0)378 i = tl.program_id(0)
+379 z = tl.program_id(1) // n_groups
+380 g = tl.program_id(1) % n_groups368 b_o = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)383 p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+384 (q_seq_len, d_head),
+385 (d_head, 1),
+386 (i * BLOCK_Q, 0),
+387 (BLOCK_Q, d_head),
+388 (1, 0))
+389 p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
+390 (kv_seq_len, d_head),
+391 (d_head, 1),
+392 (0, 0),
+393 (BLOCK_K, d_head),
+394 (1, 0))
+395 p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
+396 (d_head, kv_seq_len),
+397 (1, d_head),
+398 (0, 0),
+399 (d_head, BLOCK_K),
+400 (0, 1))
+401 p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+402 (q_seq_len, d_head),
+403 (d_head, 1),
+404 (i * BLOCK_Q, 0),
+405 (BLOCK_Q, d_head),
+406 (1, 0))
+407 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
+408 (q_seq_len,),
+409 (1,),
+410 (i * BLOCK_Q,),
+411 (BLOCK_Q,),
+412 (0,))

Initialize offsets

Load $q_i$ outside the loop since it will be reused throughout the loop over $j$.
371 b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
-372
-373 if is_causal:415 offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
+416 offs_j = tl.arange(0, BLOCK_K)375 b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q,
-376 p_kT, p_v,
-377 sm_scale_log2e,
-378 BLOCK_Q, d_head, BLOCK_K,
-379 offs_i, offs_j,
-380 j=tl.full([], 0, tl.int32), # type: ignore
-381 steps=(i * BLOCK_Q) // BLOCK_K,
-382 MASK=False,
-383 q_seq_len=q_seq_len,
-384 kv_seq_len=kv_seq_len
-385 )419 i_mask = offs_i < q_seq_lenDiagonal block with masking within it
+Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\infty$ and $l_i$ to $1$, so in the first update the contribution of the initial $l_i$ is $e^{-\infty} = 0$. b_m will be storing $m_i$.
387 b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
-388 sm_scale_log2e,
-389 BLOCK_Q, d_head, BLOCK_K,
-390 offs_i, offs_j,
-391 j=i * BLOCK_Q,
-392 steps=BLOCK_Q // BLOCK_K,
-393 MASK=True,
-394 q_seq_len=q_seq_len,
-395 kv_seq_len=kv_seq_len
-396 )
-397 else:425 b_m = tl.where(i_mask, -float("inf"), 0.0)
+426 b_l = tl.where(i_mask, 1.0, 0.0)399 b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
-400 sm_scale_log2e,
-401 BLOCK_Q, d_head, BLOCK_K,
-402 offs_i, offs_j,
-403 j=tl.full([], 0, tl.int32), # type: ignore
-404 steps=tl.cdiv(kv_seq_len, BLOCK_K),
-405 MASK=False,
-406 q_seq_len=q_seq_len,
-407 kv_seq_len=kv_seq_len
-408 )429 b_o = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)Store LSE
+Load $q_i$ outside the loop since it will be reused throughout the loop over $j$.
411 tl.store(p_lse, b_m + tl.math.log2(b_l), boundary_check=(0,))432 b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
+433
+434 if is_causal:413 tl.store(p_o, (b_o / b_l[:, None]).to(t_o.type.element_ty), boundary_check=(0,))436 b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q,
+437 p_kT, p_v,
+438 sm_scale_log2e,
+439 BLOCK_Q, d_head, BLOCK_K,
+440 offs_i, offs_j,
+441 j=tl.full([], 0, tl.int32), # type: ignore
+442 steps=(i * BLOCK_Q) // BLOCK_K,
+443 MASK=False,
+444 q_seq_len=q_seq_len,
+445 kv_seq_len=kv_seq_len
+446 )Diagonal block with masking within it
+416@triton.jit
-417def _attn_fwd_inner(b_o, b_l, b_m, b_q,
-418 p_kT, p_v,
-419 sm_scale_log2e,
-420 BLOCK_Q: tl.constexpr,
-421 d_head: tl.constexpr,
-422 BLOCK_K: tl.constexpr,
-423 offs_i, offs_j,
-424 j,
-425 steps,
-426 MASK: tl.constexpr,
-427 q_seq_len: tl.constexpr,
-428 kv_seq_len: tl.constexpr
-429 ):
-430 tl.static_assert(BLOCK_Q % BLOCK_K == 0)448 b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
+449 sm_scale_log2e,
+450 BLOCK_Q, d_head, BLOCK_K,
+451 offs_i, offs_j,
+452 j=i * BLOCK_Q,
+453 steps=BLOCK_Q // BLOCK_K,
+454 MASK=True,
+455 q_seq_len=q_seq_len,
+456 kv_seq_len=kv_seq_len
+457 )
+458 else:433 p_kT = tl.advance(p_kT, (0, j))
-434 p_v = tl.advance(p_v, (j, 0))460 b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
+461 sm_scale_log2e,
+462 BLOCK_Q, d_head, BLOCK_K,
+463 offs_i, offs_j,
+464 j=tl.full([], 0, tl.int32), # type: ignore
+465 steps=tl.cdiv(kv_seq_len, BLOCK_K),
+466 MASK=False,
+467 q_seq_len=q_seq_len,
+468 kv_seq_len=kv_seq_len
+469 )437 for _ in range(steps):472 tl.store(p_lse, b_m + tl.math.log2(b_l), boundary_check=(0,))439 b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")474 tl.store(p_o, (b_o / b_l[:, None]).to(t_o.type.element_ty), boundary_check=(0,))Compute
+This iterates through keys and values starting from j for steps number of steps. In each step it processes BLOCK_K entries of keys/values.
441 b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
-442 b_s = b_s * sm_scale_log2e477@triton.jit
+478def _attn_fwd_inner(b_o, b_l, b_m, b_q,
+479 p_kT, p_v,
+480 sm_scale_log2e,
+481 BLOCK_Q: tl.constexpr,
+482 d_head: tl.constexpr,
+483 BLOCK_K: tl.constexpr,
+484 offs_i, offs_j,
+485 j,
+486 steps,
+487 MASK: tl.constexpr,
+488 q_seq_len: tl.constexpr,
+489 kv_seq_len: tl.constexpr
+490 ):Apply causal mask
- +445 if MASK:
-446 causal_mask = offs_i[:, None] >= (j + offs_j[None, :])
-447 b_s = tl.where(causal_mask, b_s, -float("inf"))497 tl.static_assert(BLOCK_Q % BLOCK_K == 0)450 j_mask = (j + offs_j) < kv_seq_len
-451 b_s = tl.where(j_mask[None, :], b_s, -float("inf"))500 p_kT = tl.advance(p_kT, (0, j))
+501 p_v = tl.advance(p_v, (j, 0))454 b_m_new = tl.maximum(b_m, tl.max(b_s, -1))504 for _ in range(steps):460 b_p = tl.math.exp2(b_s - b_m_new[:, None])506 b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")463 b_l_new = tl.sum(b_p, -1)508 b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
+509 b_s = b_s * sm_scale_log2e465 b_m_m_new = tl.math.exp2(b_m - b_m_new)512 if MASK:
+513 causal_mask = offs_i[:, None] >= (j + offs_j[None, :])
+514 b_s = tl.where(causal_mask, b_s, -float("inf"))467 b_l = b_l * b_m_m_new + b_l_new517 j_mask = (j + offs_j) < kv_seq_len
+518 b_s = tl.where(j_mask[None, :], b_s, -float("inf"))470 b_o = b_o * b_m_m_new[:, None]
-471 b_p = b_p.to(b_q.dtype) # TODO
-472 b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
-473 b_o += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)521 b_m_new = tl.maximum(b_m, tl.max(b_s, -1))476 b_m = b_m_new527 b_p = tl.math.exp2(b_s - b_m_new[:, None])479 j += BLOCK_K
-480 p_v = tl.advance(p_v, (BLOCK_K, 0))
-481 p_kT = tl.advance(p_kT, (0, BLOCK_K))
-482
-483 tl.static_assert(b_o.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
-484
-485 return b_o, b_l, b_m530 b_l_new = tl.sum(b_p, -1)+
488@triton.jit
-489def _attn_bwd_d(t_o, t_do,
-490 t_pdp,
-491 BLOCK_Q: tl.constexpr, d_head: tl.constexpr,
-492 q_seq_len: tl.constexpr,
-493 n_groups: tl.constexpr,
-494 ):
-495 i = tl.program_id(0) * BLOCK_Q
-496 z = tl.program_id(1)532 b_m_m_new = tl.math.exp2(b_m - b_m_new)499 p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
-500 (n_groups, q_seq_len, d_head),
-501 (q_seq_len * d_head, d_head, 1),
-502 (0, i, 0),
-503 (n_groups, BLOCK_Q, d_head),
-504 (2, 1, 0))
-505 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
-506 (n_groups, q_seq_len, d_head),
-507 (q_seq_len * d_head, d_head, 1),
-508 (0, i, 0),
-509 (n_groups, BLOCK_Q, d_head),
-510 (2, 1, 0))
-511 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
-512 (n_groups, q_seq_len),
-513 (q_seq_len, 1),
-514 (0, i),
-515 (n_groups, BLOCK_Q),
-516 (1, 0))534 b_l = b_l * b_m_m_new + b_l_new519 o = tl.load(p_o, boundary_check=(1,), padding_option="zero")537 b_o = b_o * b_m_m_new[:, None]
+538 b_p = b_p.to(b_q.dtype) # TODO
+539 b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
+540 b_o += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)521 do = tl.load(p_do, boundary_check=(1,), padding_option="zero").to(HI_PRES_TL)543 b_m = b_m_new523 d = tl.sum(o * do, axis=-1)546 j += BLOCK_K
+547 p_v = tl.advance(p_v, (BLOCK_K, 0))
+548 p_kT = tl.advance(p_kT, (0, BLOCK_K))
+549
+550 tl.static_assert(b_o.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
+551
+552 return b_o, b_l, b_m525 tl.store(p_pdp, d, boundary_check=(1,))555@triton.jit
+556def _attn_bwd_d(t_o, t_do,
+557 t_pdp,
+558 BLOCK_Q: tl.constexpr, d_head: tl.constexpr,
+559 q_seq_len: tl.constexpr,
+560 n_groups: tl.constexpr,
+561 ):Compute and for by iterating over
- +528@triton.autotune(_get_autotune_configs(inner_loop='query'),
-529 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
-530@triton.jit
-531def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
-532 t_do,
-533 t_dk, t_dv,
-534 t_lse, t_pdp,
-535 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
-536 n_groups: tl.constexpr, d_head: tl.constexpr,
-537 is_causal: tl.constexpr,
-538 BLOCK_Q: tl.constexpr,
-539 BLOCK_K: tl.constexpr,
-540 ):565 i = tl.program_id(0) * BLOCK_Q
+566 z = tl.program_id(1)Create block pointers
+545 j = tl.program_id(0) * BLOCK_K
-546 z = tl.program_id(1)569 p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
+570 (n_groups, q_seq_len, d_head),
+571 (q_seq_len * d_head, d_head, 1),
+572 (0, i, 0),
+573 (n_groups, BLOCK_Q, d_head),
+574 (2, 1, 0))
+575 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
+576 (n_groups, q_seq_len, d_head),
+577 (q_seq_len * d_head, d_head, 1),
+578 (0, i, 0),
+579 (n_groups, BLOCK_Q, d_head),
+580 (2, 1, 0))
+581 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
+582 (n_groups, q_seq_len),
+583 (q_seq_len, 1),
+584 (0, i),
+585 (n_groups, BLOCK_Q),
+586 (1, 0))549 p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
-550 (kv_seq_len, d_head),
-551 (d_head, 1),
-552 (j, 0),
-553 (BLOCK_K, d_head),
-554 (1, 0))
-555 p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
-556 (kv_seq_len, d_head),
-557 (d_head, 1),
-558 (j, 0),
-559 (BLOCK_K, d_head),
-560 (1, 0))
-561 p_dk = tl.make_block_ptr(t_dk + z * kv_seq_len * d_head,
-562 (kv_seq_len, d_head),
-563 (d_head, 1),
-564 (j, 0),
-565 (BLOCK_K, d_head),
-566 (1, 0))
-567 p_dv = tl.make_block_ptr(t_dv + z * kv_seq_len * d_head,
-568 (kv_seq_len, d_head),
-569 (d_head, 1),
-570 (j, 0),
-571 (BLOCK_K, d_head),
-572 (1, 0))589 o = tl.load(p_o, boundary_check=(1,), padding_option="zero")575 b_dk = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
-576 b_dv = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)591 do = tl.load(p_do, boundary_check=(1,), padding_option="zero").to(HI_PRES_TL)579 b_k = tl.load(p_k, boundary_check=(0,), padding_option="zero")
-580 b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")593 d = tl.sum(o * do, axis=-1)583 for g in range(n_groups):595 tl.store(p_pdp, d, boundary_check=(1,))585 p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-586 (d_head, q_seq_len),
-587 (1, d_head),
-588 (0, 0),
-589 (d_head, BLOCK_Q),
-590 (0, 1))
-591
-592 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-593 (q_seq_len, d_head),
-594 (d_head, 1),
-595 (0, 0),
-596 (BLOCK_Q, d_head),
-597 (1, 0))
-598 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
-599 (q_seq_len,),
-600 (1,),
-601 (0,),
-602 (BLOCK_Q,),
-603 (0,))
-604 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
-605 (q_seq_len,),
-606 (1,),
-607 (0,),
-608 (BLOCK_Q,),
-609 (0,))
-610
-611 if is_causal:598@triton.autotune(_get_autotune_configs(inner_loop='query'),
+599 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
+600@triton.jit
+601def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
+602 t_do,
+603 t_dk, t_dv,
+604 t_lse, t_pdp,
+605 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
+606 n_groups: tl.constexpr, d_head: tl.constexpr,
+607 is_causal: tl.constexpr,
+608 BLOCK_Q: tl.constexpr,
+609 BLOCK_K: tl.constexpr,
+610 ):

Compute $dk_j$ and $dv_j$ for keys j ... j + BLOCK_K by iterating over the queries.

Inner loop at the diagonal block
613 b_dk, b_dv = _attn_bwd_dkdv_inner(
-614 b_dk, b_dv,
-615 p_qT, b_k, b_v, p_do,
-616 p_lse, p_pdp,
-617 BLOCK_Q, BLOCK_K,
-618 d_head,
-619 j=j, i=j,
-620 steps=BLOCK_K // BLOCK_Q,
-621 MASK=True,
-622 q_seq_len=q_seq_len,
-623 kv_seq_len=kv_seq_len,
-624 )616 j = tl.program_id(0) * BLOCK_K
+617 z = tl.program_id(1)627 b_dk, b_dv = _attn_bwd_dkdv_inner(
-628 b_dk, b_dv,
-629 p_qT, b_k, b_v, p_do,
-630 p_lse, p_pdp,
-631 BLOCK_Q, BLOCK_K,
-632 d_head,
-633 j=j, i=j + BLOCK_K,
-634 steps=tl.cdiv((q_seq_len - (j + BLOCK_K)), BLOCK_Q),
-635 MASK=False,
-636 q_seq_len=q_seq_len,
-637 kv_seq_len=kv_seq_len
-638 )
-639 else:620 p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
+621 (kv_seq_len, d_head),
+622 (d_head, 1),
+623 (j, 0),
+624 (BLOCK_K, d_head),
+625 (1, 0))
+626 p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
+627 (kv_seq_len, d_head),
+628 (d_head, 1),
+629 (j, 0),
+630 (BLOCK_K, d_head),
+631 (1, 0))
+632 p_dk = tl.make_block_ptr(t_dk + z * kv_seq_len * d_head,
+633 (kv_seq_len, d_head),
+634 (d_head, 1),
+635 (j, 0),
+636 (BLOCK_K, d_head),
+637 (1, 0))
+638 p_dv = tl.make_block_ptr(t_dv + z * kv_seq_len * d_head,
+639 (kv_seq_len, d_head),
+640 (d_head, 1),
+641 (j, 0),
+642 (BLOCK_K, d_head),
+643 (1, 0))641 b_dk, b_dv = _attn_bwd_dkdv_inner(
-642 b_dk, b_dv,
-643 p_qT, b_k, b_v, p_do,
-644 p_lse, p_pdp,
-645 BLOCK_Q, BLOCK_K,
-646 d_head,
-647 j=j, i=tl.full([], 0, tl.int32),
-648 steps=tl.cdiv(q_seq_len, BLOCK_Q),
-649 MASK=False,
-650 q_seq_len=q_seq_len,
-651 kv_seq_len=kv_seq_len
-652 )646 b_dk = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
+647 b_dv = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)655 tl.store(p_dv, b_dv.to(t_dv.type.element_ty), boundary_check=(0,))650 b_k = tl.load(p_k, boundary_check=(0,), padding_option="zero")
+651 b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")658 b_dk *= sm_scale654 for g in range(n_groups):661 tl.store(p_dk, b_dk.to(t_dk.type.element_ty), boundary_check=(0,))656 p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+657 (d_head, q_seq_len),
+658 (1, d_head),
+659 (0, 0),
+660 (d_head, BLOCK_Q),
+661 (0, 1))
+662
+663 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+664 (q_seq_len, d_head),
+665 (d_head, 1),
+666 (0, 0),
+667 (BLOCK_Q, d_head),
+668 (1, 0))
+669 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
+670 (q_seq_len,),
+671 (1,),
+672 (0,),
+673 (BLOCK_Q,),
+674 (0,))
+675 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
+676 (q_seq_len,),
+677 (1,),
+678 (0,),
+679 (BLOCK_Q,),
+680 (0,))
+681
+682 if is_causal:664@triton.jit
-665def _attn_bwd_dkdv_inner(b_dk, b_dv,
-666 p_qT, b_k, b_v, p_do,
-667 p_lse, p_pdp,
-668 BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
-669 d_head: tl.constexpr,
-670 j, i, steps,
-671 MASK: tl.constexpr,
-672 q_seq_len: tl.constexpr,
-673 kv_seq_len: tl.constexpr):684 b_dk, b_dv = _attn_bwd_dkdv_inner(
+685 b_dk, b_dv,
+686 p_qT, b_k, b_v, p_do,
+687 p_lse, p_pdp,
+688 BLOCK_Q, BLOCK_K,
+689 d_head,
+690 j=j, i=j,
+691 steps=BLOCK_K // BLOCK_Q,
+692 MASK=True,
+693 q_seq_len=q_seq_len,
+694 kv_seq_len=kv_seq_len,
+695 )677 tl.static_assert(BLOCK_K % BLOCK_Q == 0)698 b_dk, b_dv = _attn_bwd_dkdv_inner(
+699 b_dk, b_dv,
+700 p_qT, b_k, b_v, p_do,
+701 p_lse, p_pdp,
+702 BLOCK_Q, BLOCK_K,
+703 d_head,
+704 j=j, i=j + BLOCK_K,
+705 steps=tl.cdiv((q_seq_len - (j + BLOCK_K)), BLOCK_Q),
+706 MASK=False,
+707 q_seq_len=q_seq_len,
+708 kv_seq_len=kv_seq_len
+709 )
+710 else:680 offs_i = i + tl.arange(0, BLOCK_Q)
-681 offs_j = j + tl.arange(0, BLOCK_K)712 b_dk, b_dv = _attn_bwd_dkdv_inner(
+713 b_dk, b_dv,
+714 p_qT, b_k, b_v, p_do,
+715 p_lse, p_pdp,
+716 BLOCK_Q, BLOCK_K,
+717 d_head,
+718 j=j, i=tl.full([], 0, tl.int32),
+719 steps=tl.cdiv(q_seq_len, BLOCK_Q),
+720 MASK=False,
+721 q_seq_len=q_seq_len,
+722 kv_seq_len=kv_seq_len
+723 )684 p_qT = tl.advance(p_qT, (0, i))
-685 p_do = tl.advance(p_do, (i, 0))
-686 p_lse = tl.advance(p_lse, (i,))
-687 p_pdp = tl.advance(p_pdp, (i,))726 tl.store(p_dv, b_dv.to(t_dv.type.element_ty), boundary_check=(0,))690 for _ in range(steps):729 b_dk *= sm_scale692 b_qT = tl.load(p_qT, boundary_check=(1,), padding_option="zero")732 tl.store(p_dk, b_dk.to(t_dk.type.element_ty), boundary_check=(0,))695 b_l = tl.load(p_lse, boundary_check=(0,), padding_option="zero")735@triton.jit
+736def _attn_bwd_dkdv_inner(b_dk, b_dv,
+737 p_qT, b_k, b_v, p_do,
+738 p_lse, p_pdp,
+739 BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
+740 d_head: tl.constexpr,
+741 j, i, steps,
+742 MASK: tl.constexpr,
+743 q_seq_len: tl.constexpr,
+744 kv_seq_len: tl.constexpr):698 b_sT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)750 tl.static_assert(BLOCK_K % BLOCK_Q == 0)707 b_pT = tl.math.exp2(b_sT - b_l[None, :])753 offs_i = i + tl.arange(0, BLOCK_Q)
+754 offs_j = j + tl.arange(0, BLOCK_K)710 if MASK:
-711 mask = (offs_i[None, :] >= offs_j[:, None])
-712 b_pT = tl.where(mask, b_pT, 0.0)757 p_qT = tl.advance(p_qT, (0, i))
+758 p_do = tl.advance(p_do, (i, 0))
+759 p_lse = tl.advance(p_lse, (i,))
+760 p_pdp = tl.advance(p_pdp, (i,))Mask out if the block is beyond the end of
-Note: No need to mask out based on the key positions because the effects on positions outside the boundary will not get stored in dk or dv. Masking by the query boundary may also not be necessary since the tensors have 0 on loading.
+Iterate over
719 i_mask = offs_i < q_seq_len
-720 b_pT = tl.where(i_mask[None, :], b_pT, 0.0)763 for _ in range(steps):723 b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
-724 b_dv += tl.dot(b_pT.to(b_do.dtype), b_do, out_dtype=HI_PRES_TL)765 b_qT = tl.load(p_qT, boundary_check=(1,), padding_option="zero")727 b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")768 b_l = tl.load(p_lse, boundary_check=(0,), padding_option="zero")729 b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)771 b_sT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)731 b_dsT = b_pT * (b_dpT - b_pdp[None, :])780 b_pT = tl.math.exp2(b_sT - b_l[None, :])733 b_dk += tl.dot(b_dsT.to(b_qT.dtype), tl.trans(b_qT), out_dtype=HI_PRES_TL)783 if MASK:
+784 mask = (offs_i[None, :] >= offs_j[:, None])
+785 b_pT = tl.where(mask, b_pT, 0.0)Increment pointers.
+Mask out if the block is beyond the end of the query sequence
+Note: No need to mask out based on the key positions because the effects on positions outside the boundary will not get stored in dk or dv. Masking by the query boundary may also not be necessary since the tensors have 0 on loading.
736 offs_i += BLOCK_Q
-737 p_lse = tl.advance(p_lse, (BLOCK_Q,))
-738 p_pdp = tl.advance(p_pdp, (BLOCK_Q,))
-739 p_qT = tl.advance(p_qT, (0, BLOCK_Q))
-740 p_do = tl.advance(p_do, (BLOCK_Q, 0))792 i_mask = offs_i < q_seq_len
+793 b_pT = tl.where(i_mask[None, :], b_pT, 0.0)743 return b_dk, b_dv796 b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
+797 b_dv += tl.dot(b_pT.to(b_do.dtype), b_do, out_dtype=HI_PRES_TL)+
746@triton.autotune(_get_autotune_configs(inner_loop='key'),
-747 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
-748@triton.jit
-749def _attn_bwd_dq(t_q, t_k, t_v, t_do,
-750 t_dq,
-751 t_lse, t_pdp,
-752 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
-753 n_groups: tl.constexpr, d_head: tl.constexpr,
-754 is_causal: tl.constexpr,
-755 BLOCK_Q: tl.constexpr,
-756 BLOCK_K: tl.constexpr,
-757 ):
-758 i = tl.program_id(0) * BLOCK_Q
-759 z = tl.program_id(1) // n_groups
-760 g = tl.program_id(1) % n_groups # TODO800 b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")763 p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-764 (q_seq_len, d_head),
-765 (d_head, 1),
-766 (i, 0),
-767 (BLOCK_Q, d_head),
-768 (1, 0))
-769 p_dq = tl.make_block_ptr(t_dq + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-770 (q_seq_len, d_head),
-771 (d_head, 1),
-772 (i, 0),
-773 (BLOCK_Q, d_head),
-774 (1, 0))
-775 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-776 (q_seq_len, d_head),
-777 (d_head, 1),
-778 (i, 0),
-779 (BLOCK_Q, d_head),
-780 (1, 0))
-781 p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
-782 (d_head, kv_seq_len),
-783 (1, d_head),
-784 (0, 0),
-785 (d_head, BLOCK_K),
-786 (0, 1))
-787 p_vT = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
-788 (d_head, kv_seq_len),
-789 (1, d_head),
-790 (0, 0),
-791 (d_head, BLOCK_K),
-792 (0, 1))
-793 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
-794 (q_seq_len,),
-795 (1,),
-796 (i,),
-797 (BLOCK_Q,),
-798 (0,))
-799 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
-800 (q_seq_len,),
-801 (1,),
-802 (i,),
-803 (BLOCK_Q,),
-804 (0,))802 b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)807 b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
-808 b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
-809 b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
-810 b_lse = tl.load(p_lse, boundary_check=(0,), padding_option="zero")804 b_dsT = b_pT * (b_dpT - b_pdp[None, :])813 b_dq = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)806 b_dk += tl.dot(b_dsT.to(b_qT.dtype), tl.trans(b_qT), out_dtype=HI_PRES_TL)817 if is_causal:809 offs_i += BLOCK_Q
+810 p_lse = tl.advance(p_lse, (BLOCK_Q,))
+811 p_pdp = tl.advance(p_pdp, (BLOCK_Q,))
+812 p_qT = tl.advance(p_qT, (0, BLOCK_Q))
+813 p_do = tl.advance(p_do, (BLOCK_Q, 0))819 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
-820 b_do, b_lse, b_pdp,
-821 BLOCK_Q, BLOCK_K,
-822 i=i, j=i,
-823 steps=BLOCK_Q // BLOCK_K,
-824 MASK=True,
-825 q_seq_len=q_seq_len,
-826 kv_seq_len=kv_seq_len
-827 )816 return b_dk, b_dv830 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
-831 b_do, b_lse, b_pdp,
-832 BLOCK_Q, BLOCK_K,
-833 i=i, j=tl.full([], 0, tl.int32), # type: ignore
-834 steps=i // BLOCK_K,
-835 MASK=False,
-836 q_seq_len=q_seq_len,
-837 kv_seq_len=kv_seq_len
-838 )
-839 else:819@triton.autotune(_get_autotune_configs(inner_loop='key'),
+820 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
+821@triton.jit
+822def _attn_bwd_dq(t_q, t_k, t_v, t_do,
+823 t_dq,
+824 t_lse, t_pdp,
+825 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
+826 n_groups: tl.constexpr, d_head: tl.constexpr,
+827 is_causal: tl.constexpr,
+828 BLOCK_Q: tl.constexpr,
+829 BLOCK_K: tl.constexpr,
+830 ):

Iterate through all
- +841 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
-842 b_do, b_lse, b_pdp,
-843 BLOCK_Q, BLOCK_K,
-844 i=i, j=tl.full([], 0, tl.int32), # type: ignore
-845 steps=tl.cdiv(kv_seq_len, BLOCK_K),
-846 MASK=False,
-847 q_seq_len=q_seq_len,
-848 kv_seq_len=kv_seq_len
-849 )835 i = tl.program_id(0) * BLOCK_Q
+836 z = tl.program_id(1) // n_groups
+837 g = tl.program_id(1) % n_groups # TODO852 b_dq *= 0.6931471824645996840 p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+841 (q_seq_len, d_head),
+842 (d_head, 1),
+843 (i, 0),
+844 (BLOCK_Q, d_head),
+845 (1, 0))
+846 p_dq = tl.make_block_ptr(t_dq + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+847 (q_seq_len, d_head),
+848 (d_head, 1),
+849 (i, 0),
+850 (BLOCK_Q, d_head),
+851 (1, 0))
+852 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+853 (q_seq_len, d_head),
+854 (d_head, 1),
+855 (i, 0),
+856 (BLOCK_Q, d_head),
+857 (1, 0))
+858 p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
+859 (d_head, kv_seq_len),
+860 (1, d_head),
+861 (0, 0),
+862 (d_head, BLOCK_K),
+863 (0, 1))
+864 p_vT = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
+865 (d_head, kv_seq_len),
+866 (1, d_head),
+867 (0, 0),
+868 (d_head, BLOCK_K),
+869 (0, 1))
+870 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
+871 (q_seq_len,),
+872 (1,),
+873 (i,),
+874 (BLOCK_Q,),
+875 (0,))
+876 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
+877 (q_seq_len,),
+878 (1,),
+879 (i,),
+880 (BLOCK_Q,),
+881 (0,))855 tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))884 b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
+885 b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
+886 b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
+887 b_lse = tl.load(p_lse, boundary_check=(0,), padding_option="zero")858@triton.jit
-859def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
-860 b_do, b_lse, b_pdp,
-861 BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
-862 i, j, steps,
-863 MASK: tl.constexpr,
-864 q_seq_len: tl.constexpr,
-865 kv_seq_len: tl.constexpr):890 b_dq = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)869 offs_i = i + tl.arange(0, BLOCK_Q)
-870 offs_j = j + tl.arange(0, BLOCK_K)894 if is_causal:873 p_kT = tl.advance(p_kT, (0, j))
-874 p_vT = tl.advance(p_vT, (0, j))
-875
-876 tl.static_assert(BLOCK_Q % BLOCK_K == 0, 'BLOCK_Q must be divisible by BLOCK_K')896 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
+897 b_do, b_lse, b_pdp,
+898 BLOCK_Q, BLOCK_K,
+899 i=i, j=i,
+900 steps=BLOCK_Q // BLOCK_K,
+901 MASK=True,
+902 q_seq_len=q_seq_len,
+903 kv_seq_len=kv_seq_len
+904 )879 for _ in range(steps):907 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
+908 b_do, b_lse, b_pdp,
+909 BLOCK_Q, BLOCK_K,
+910 i=i, j=tl.full([], 0, tl.int32), # type: ignore
+911 steps=i // BLOCK_K,
+912 MASK=False,
+913 q_seq_len=q_seq_len,
+914 kv_seq_len=kv_seq_len
+915 )
+916 else:881 b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")918 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
+919 b_do, b_lse, b_pdp,
+920 BLOCK_Q, BLOCK_K,
+921 i=i, j=tl.full([], 0, tl.int32), # type: ignore
+922 steps=tl.cdiv(kv_seq_len, BLOCK_K),
+923 MASK=False,
+924 q_seq_len=q_seq_len,
+925 kv_seq_len=kv_seq_len
+926 )883 b_vT = tl.load(p_vT, boundary_check=(1,), padding_option="zero")929 b_dq *= 0.6931471824645996886 b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)932 tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))895 b_p = tl.math.exp2(b_s - b_lse[:, None])935@triton.jit
+936def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
+937 b_do, b_lse, b_pdp,
+938 BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
+939 i, j, steps,
+940 MASK: tl.constexpr,
+941 q_seq_len: tl.constexpr,
+942 kv_seq_len: tl.constexpr):898 if MASK:
-899 causal_mask = (offs_i[:, None] >= offs_j[None, :])
-900 b_p = tl.where(causal_mask, b_p, 0.0)948 offs_i = i + tl.arange(0, BLOCK_Q)
+949 offs_j = j + tl.arange(0, BLOCK_K)903 j_mask = offs_j < kv_seq_len
-904 b_p = tl.where(j_mask[None, :], b_p, 0.0)952 p_kT = tl.advance(p_kT, (0, j))
+953 p_vT = tl.advance(p_vT, (0, j))
+954
+955 tl.static_assert(BLOCK_Q % BLOCK_K == 0, 'BLOCK_Q must be divisible by BLOCK_K')958 for _ in range(steps):909 b_dp = tl.dot(b_do, b_vT, out_dtype=HI_PRES_TL).to(HI_PRES_TL)960 b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")911 b_ds = b_p * (b_dp - b_pdp[:, None])962 b_vT = tl.load(p_vT, boundary_check=(1,), padding_option="zero")913 b_dq += tl.dot(b_ds.to(b_kT.dtype), tl.trans(b_kT), out_dtype=HI_PRES_TL)965 b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)916 offs_j += BLOCK_K
-917 p_kT = tl.advance(p_kT, (0, BLOCK_K))
-918 p_vT = tl.advance(p_vT, (0, BLOCK_K))974 b_p = tl.math.exp2(b_s - b_lse[:, None])921 return b_dq977 if MASK:
+978 causal_mask = (offs_i[:, None] >= offs_j[None, :])
+979 b_p = tl.where(causal_mask, b_p, 0.0)Mask out if the block is beyond the end of
+ +982 j_mask = offs_j < kv_seq_len
+983 b_p = tl.where(j_mask[None, :], b_p, 0.0)+ +
+ +
988 b_dp = tl.dot(b_do, b_vT, out_dtype=HI_PRES_TL).to(HI_PRES_TL)+ +
990 b_ds = b_p * (b_dp - b_pdp[:, None])+ +
992 b_dq += tl.dot(b_ds.to(b_kT.dtype), tl.trans(b_kT), out_dtype=HI_PRES_TL)Increment pointers.
+ +995 offs_j += BLOCK_K
+996 p_kT = tl.advance(p_kT, (0, BLOCK_K))
+997 p_vT = tl.advance(p_vT, (0, BLOCK_K))

Return accumulated $dq_i$

+1000 return b_dq

This is the code to test and measure the performance of our flash attention implementation.
+1import triton
-2
-3import torch
-4from labml import logger, monit
-5from labml_nn.transformers.flash import attention
-6
-7HI_PRES_TORCH = torch.float327import torch
+8import triton
+9
+10from labml import logger, monit
+11from labml_nn.transformers.flash import attention
+12
+13HI_PRES_TORCH = torch.float3210@torch.no_grad()
-11def _calc_abs_rel_error(a: torch.Tensor, b: torch.Tensor, atol=1e-2):
-12 d = (a - b).abs()
-13 max_abs = d.max()
-14 d = (d - atol).clamp(min=0)
-15 d = d / b.abs()
-16 max_rel = d.max()
-17
-18 return max_abs.cpu().item(), max_rel.cpu().item()
-19
-20
-21def _test_op(batch_size, n_heads, k_heads, q_seq_len, kv_seq_len, d_head, causal, dtype, device):
-22 with monit.section(f'Init {q_seq_len} {kv_seq_len} {d_head}'):
-23 torch.manual_seed(20)
-24 q = (torch.empty((batch_size, n_heads, q_seq_len, d_head),
-25 dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
-26 k = (torch.empty((batch_size, k_heads, kv_seq_len, d_head),
-27 dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
-28 v = (torch.empty((batch_size, k_heads, kv_seq_len, d_head),
-29 dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
-30 sm_scale = d_head ** -0.5
-31 d_out = torch.randn_like(q)16@torch.no_grad()
+17def _calc_abs_rel_error(a: torch.Tensor, b: torch.Tensor, atol=1e-2):reference implementation
- +33 mask = torch.tril(torch.ones((q_seq_len, kv_seq_len), device=device, dtype=torch.bool))
-34 torch.cuda.synchronize()
-35
-36 with monit.section('Pytorch'):
-37 p = torch.matmul(q.view(batch_size, k_heads, -1, q_seq_len, d_head),
-38 k.transpose(2, 3)[:, :, None, :, :]) * sm_scale
-39 if causal:
-40 p[:, :, :, ~mask] = float("-inf")
-41 p = torch.softmax(p.to(HI_PRES_TORCH), dim=-1).to(dtype)
-42 ref_out = torch.matmul(p, v[:, :, None, :, :])
-43 ref_out = ref_out.view(q.shape)
-44 ref_out.backward(d_out)
-45 ref_dv, v.grad = v.grad.clone(), None
-46 ref_dk, k.grad = k.grad.clone(), None
-47 ref_dq, q.grad = q.grad.clone(), None
-48 torch.cuda.synchronize()
-49
-50 with monit.section('Triton'):
-51 assert q.dtype == dtype
-52 tri_out = attention(q, k, v, causal, sm_scale).to(dtype)
-53 monit.progress(0.5)
-54
-55 tri_out.backward(d_out)
-56 monit.progress(0.9)
-57 tri_dv, v.grad = v.grad.clone(), None # type: ignore
-58 tri_dk, k.grad = k.grad.clone(), None # type: ignore
-59 tri_dq, q.grad = q.grad.clone(), None # type: ignore
-60 torch.cuda.synchronize()
-61
-62 with monit.section('Test') as s:21 d = (a - b).abs()
+22 max_abs = d.max()
+23 d = (d - atol).clamp(min=0)
+24 d = d / b.abs()
+25 max_rel = d.max()
+26
+27 return max_abs.cpu().item(), max_rel.cpu().item()64 passed = True
-65 if not torch.allclose(tri_out, ref_out, atol=1e-2, rtol=0.):
-66 abs_err, rel_err = _calc_abs_rel_error(ref_out, tri_out)
-67 logger.log(('[FAILED]', logger.Text.danger), f' Out mismatch {abs_err} {rel_err}')
-68 passed = False
-69 rtol = 1e-1
-70 if not torch.allclose(tri_dq, ref_dq, atol=1e-2, rtol=rtol):
-71 abs_err, rel_err = _calc_abs_rel_error(ref_dq, tri_dq)
-72 logger.log(('[FAILED]', logger.Text.danger), f' dQ mismatch {abs_err} {rel_err}')
-73 passed = False
-74 if not torch.allclose(tri_dv, ref_dv, atol=1e-2, rtol=rtol):
-75 abs_err, rel_err = _calc_abs_rel_error(ref_dv, tri_dv)
-76 logger.log(('[FAILED]', logger.Text.danger), f' dV mismatch {abs_err} {rel_err}')
-77 passed = False
-78 if not torch.allclose(tri_dk, ref_dk, atol=1e-2, rtol=rtol):
-79 abs_err, rel_err = _calc_abs_rel_error(ref_dk, tri_dk)
-80 logger.log(('[FAILED]', logger.Text.danger), f' dK mismatch {abs_err} {rel_err}')
-81 passed = False
-82
-83 if passed:
-84 logger.log('[PASSED]', logger.Text.success)
-85 s.success = True
-86 else:
-87 s.success = False
-88 torch.cuda.synchronize()30def test_fwd_bwd(batch_size, n_heads, k_heads, q_seq_len, kv_seq_len, d_head, causal, dtype, device):91def _perf_triton_fn(*, device, dtype, batch_size, k_heads, n_groups, seq_len, d_head, causal):
-92 q = torch.randn((batch_size, k_heads * n_groups, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
-93 k = torch.randn((batch_size, k_heads, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
-94 v = torch.randn((batch_size, k_heads, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
-95 sm_scale = d_head ** -0.5
-96 return lambda: attention(q, k, v, causal, sm_scale)35 with monit.section(f'Init {q_seq_len} {kv_seq_len} {d_head}'):
+36 torch.manual_seed(20)
+37 q = (torch.empty((batch_size, n_heads, q_seq_len, d_head),
+38 dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
+39 k = (torch.empty((batch_size, k_heads, kv_seq_len, d_head),
+40 dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
+41 v = (torch.empty((batch_size, k_heads, kv_seq_len, d_head),
+42 dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
+43 sm_scale = d_head ** -0.5
+44 d_out = torch.randn_like(q)reference implementation
+99def _perf_flash(*, batch_size, k_heads, n_groups, seq_len, d_head, causal, device, dtype):
-100 q = torch.randn((batch_size, seq_len, k_heads * n_groups, d_head), dtype=dtype, device=device, requires_grad=True)
-101 k = torch.randn((batch_size, seq_len, k_heads, d_head), dtype=dtype, device=device, requires_grad=True)
-102 v = torch.randn((batch_size, seq_len, k_heads, d_head), dtype=dtype, device=device, requires_grad=True)
-103 from flash_attn import flash_attn_func
-104 return lambda: flash_attn_func(q, k, v, causal=causal)46 mask = torch.tril(torch.ones((q_seq_len, kv_seq_len), device=device, dtype=torch.bool))
+47 torch.cuda.synchronize()
+48
+49 with monit.section('Pytorch'):
+50 p = torch.matmul(q.view(batch_size, k_heads, -1, q_seq_len, d_head),
+51 k.transpose(2, 3)[:, :, None, :, :]) * sm_scale
+52 if causal:
+53 p[:, :, :, ~mask] = float("-inf")
+54 p = torch.softmax(p.to(HI_PRES_TORCH), dim=-1).to(dtype)
+55 ref_out = torch.matmul(p, v[:, :, None, :, :])
+56 ref_out = ref_out.view(q.shape)
+57 ref_out.backward(d_out)
+58 ref_dv, v.grad = v.grad.clone(), None
+59 ref_dk, k.grad = k.grad.clone(), None
+60 ref_dq, q.grad = q.grad.clone(), None
+61 torch.cuda.synchronize()
+62
+63 with monit.section('Triton'):
+64 assert q.dtype == dtype
+65 tri_out = attention(q, k, v, causal, sm_scale).to(dtype)
+66 monit.progress(0.5)
+67
+68 tri_out.backward(d_out)
+69 monit.progress(0.9)
+70 tri_dv, v.grad = v.grad.clone(), None # type: ignore
+71 tri_dk, k.grad = k.grad.clone(), None # type: ignore
+72 tri_dq, q.grad = q.grad.clone(), None # type: ignore
+73 torch.cuda.synchronize()
+74
+75 with monit.section('Test') as s:compare
+107def _perf_fn(name, fn, *, batch_size, k_heads, n_groups, seq_len, d_head, causal, is_bwd: bool):
-108 if is_bwd:
-109 o = fn()
-110 do = torch.randn_like(o)
-111 fn = lambda: o.backward(do, retain_graph=True)
-112 ms = triton.testing.do_bench(fn)
-113
-114 flops_per_matmul = 2.0 * batch_size * k_heads * n_groups * seq_len * seq_len * d_head
-115 total_flops = 2 * flops_per_matmul
-116 if causal:
-117 total_flops *= 0.5
-118 if is_bwd:
-119 total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute)
-120
-121 tf_ps = total_flops * 1e-12 / (ms * 1e-3)
-122 logger.log((f'{name}', logger.Text.key), ': ', f'{ms :,.1f}ms', ' ', f'{tf_ps :,.2f}TFps')77 passed = True
+78 if not torch.allclose(tri_out, ref_out, atol=1e-2, rtol=0.):
+79 abs_err, rel_err = _calc_abs_rel_error(ref_out, tri_out)
+80 logger.log(('[FAILED]', logger.Text.danger), f' Out mismatch {abs_err} {rel_err}')
+81 passed = False
+82 rtol = 1e-1
+83 if not torch.allclose(tri_dq, ref_dq, atol=1e-2, rtol=rtol):
+84 abs_err, rel_err = _calc_abs_rel_error(ref_dq, tri_dq)
+85 logger.log(('[FAILED]', logger.Text.danger), f' dQ mismatch {abs_err} {rel_err}')
+86 passed = False
+87 if not torch.allclose(tri_dv, ref_dv, atol=1e-2, rtol=rtol):
+88 abs_err, rel_err = _calc_abs_rel_error(ref_dv, tri_dv)
+89 logger.log(('[FAILED]', logger.Text.danger), f' dV mismatch {abs_err} {rel_err}')
+90 passed = False
+91 if not torch.allclose(tri_dk, ref_dk, atol=1e-2, rtol=rtol):
+92 abs_err, rel_err = _calc_abs_rel_error(ref_dk, tri_dk)
+93 logger.log(('[FAILED]', logger.Text.danger), f' dK mismatch {abs_err} {rel_err}')
+94 passed = False
+95
+96 if passed:
+97 logger.log('[PASSED]', logger.Text.success)
+98 s.success = True
+99 else:
+100 s.success = False
+101 torch.cuda.synchronize()Get a partial function to test performance of our implementation
+125def _test():
-126 device = torch.device('cuda:0')
-127 torch.cuda.set_device(device)
-128
-129 dtype = torch.float16104def _perf_triton_fn(*, device, dtype, batch_size, k_heads, n_groups, seq_len, d_head, causal):108 q = torch.randn((batch_size, k_heads * n_groups, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
+109 k = torch.randn((batch_size, k_heads, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
+110 v = torch.randn((batch_size, k_heads, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
+111 sm_scale = d_head ** -0.5
+112 return lambda: attention(q, k, v, causal, sm_scale)Get a partial function to test performance of original flash implementation
+ +115def _perf_flash(*, batch_size, k_heads, n_groups, seq_len, d_head, causal, device, dtype):119 q = torch.randn((batch_size, seq_len, k_heads * n_groups, d_head), dtype=dtype, device=device, requires_grad=True)
+120 k = torch.randn((batch_size, seq_len, k_heads, d_head), dtype=dtype, device=device, requires_grad=True)
+121 v = torch.randn((batch_size, seq_len, k_heads, d_head), dtype=dtype, device=device, requires_grad=True)
+122 from flash_attn import flash_attn_func
+123 return lambda: flash_attn_func(q, k, v, causal=causal)126def measure_performance(name, fn, *, batch_size, k_heads, n_groups, seq_len, d_head, causal, is_bwd: bool):130 if is_bwd:
+131 o = fn()
+132 do = torch.randn_like(o)
+133 fn = lambda: o.backward(do, retain_graph=True)
+134 ms = triton.testing.do_bench(fn)
+135
+136 flops_per_matmul = 2.0 * batch_size * k_heads * n_groups * seq_len * seq_len * d_head
+137 total_flops = 2 * flops_per_matmul
+138 if causal:
+139 total_flops *= 0.5
+140 if is_bwd:
+141 total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute)
+142
+143 tf_ps = total_flops * 1e-12 / (ms * 1e-3)
+144 logger.log((f'{name}', logger.Text.key), ': ', f'{ms :,.1f}ms', ' ', f'{tf_ps :,.2f}TFps')147def main():
+148 device = torch.device('cuda:0')
+149 torch.cuda.set_device(device)
+150
+151 dtype = torch.float16

only works on post-Ampere GPUs right now
132 _test_op(1, 4, 1, 2048, 2048, 128, True, dtype=dtype, device=device)
-133 _test_op(16, 32, 8, 2001, 4001, 128, False, dtype=dtype, device=device)
-134 _test_op(4, 32, 8, 2048, 1024, 128, False, dtype=dtype, device=device)
-135 _test_op(4, 32, 8, 2001, 4001, 128, True, dtype=dtype, device=device)
-136
-137 _conf = {
-138 'batch_size': 16,
-139 'k_heads': 8,
-140 'n_groups': 4,
-141 'seq_len': 2048,
-142 'd_head': 128,
-143 }
-144
-145 for _causal in [False, True]:
-146 for is_bwd in [False, True]:
-147 logger.log(f'{"Causal" if _causal else "Non-causal"} {" Backward" if is_bwd else ""}', logger.Text.title)
-148 _perf_fn(f'flash', _perf_flash(causal=_causal, device=device, dtype=dtype, **_conf),
-149 is_bwd=is_bwd,
-150 causal=_causal, **_conf)
-151 _perf_fn(f'triton', _perf_triton_fn(causal=_causal, device=device, dtype=dtype, **_conf),
-152 is_bwd=is_bwd,
-153 causal=_causal, **_conf)
-154
-155
-156if __name__ == "__main__":
-157 _test()154 test_fwd_bwd(1, 4, 1, 2048, 2048, 128, True, dtype=dtype, device=device)
+155 test_fwd_bwd(16, 32, 8, 2001, 4001, 128, False, dtype=dtype, device=device)
+156 test_fwd_bwd(4, 32, 8, 2048, 1024, 128, False, dtype=dtype, device=device)
+157 test_fwd_bwd(4, 32, 8, 2001, 4001, 128, True, dtype=dtype, device=device)
+158
+159 _conf = {
+160 'batch_size': 16,
+161 'k_heads': 8,
+162 'n_groups': 4,
+163 'seq_len': 2048,
+164 'd_head': 128,
+165 }
+166
+167 for _causal in [False, True]:
+168 for is_bwd in [False, True]:
+169 logger.log(f'{"Causal" if _causal else "Non-causal"} {" Backward" if is_bwd else ""}', logger.Text.title)
+170 measure_performance(f'flash', _perf_flash(causal=_causal, device=device, dtype=dtype, **_conf),
+171 is_bwd=is_bwd,
+172 causal=_causal, **_conf)
+173 measure_performance(f'triton', _perf_triton_fn(causal=_causal, device=device, dtype=dtype, **_conf),
+174 is_bwd=is_bwd,
+175 causal=_causal, **_conf)
+176
+177
+178if __name__ == "__main__":
+179 main()