You can compute $O_i$, instead of doing the full softmax, by computing the sum of exponents $l_i$ and the unnormalized output $\tilde{O}_i$ while iterating over keys:

$$l_i \leftarrow l_i + e^{S_{ij}}$$
$$\tilde{O}_i \leftarrow \tilde{O}_i + e^{S_{ij}} V_j$$

Finally you can compute,

$$O_i = \frac{\tilde{O}_i}{l_i}$$

To make it numerically stable flash attention subtracts the current max of $S_{ij}$ before exponentiating.

So it maintains the following while iterating over keys:

* $m_i$, the maximum $S_{ij}$ seen so far
* $l_i$, the sum of exponents $\sum_j e^{S_{ij} - m_i}$
* $\tilde{O}_i$, the unnormalized output

For each block of keys $j_1 \dots j_2$ it updates them:

$$m_i^{\text{new}} = \max\Big(m_i, \max_{j_1 \le j \le j_2} S_{ij}\Big)$$
$$\tilde{P}_{ij} = e^{S_{ij} - m_i^{\text{new}}}$$
$$l_i^{\text{new}} = e^{m_i - m_i^{\text{new}}} l_i + \sum_{j=j_1}^{j_2} \tilde{P}_{ij}$$
$$\tilde{O}_i^{\text{new}} = e^{m_i - m_i^{\text{new}}} \tilde{O}_i + \sum_{j=j_1}^{j_2} \tilde{P}_{ij} V_j$$

Then finally,

$$O_i = \frac{\tilde{O}_i}{l_i}$$

where $S_{ij}$ is $-\infty$ when the position is masked (e.g. $j > i$ for causal attention) and $\sigma Q_i K_j^\top$ otherwise, with $\sigma$ the softmax scale.

The Flash Attention paper introduces the log-sum-exp $L_i = m_i + \log l_i$ to simplify the backward computation.

Then,

$$P_{ij} = \frac{e^{S_{ij}}}{\sum_j e^{S_{ij}}} = e^{S_{ij} - L_i}$$

Note: $Q_i$, $K_j$, $V_j$, $O_i$, etc. are row vectors.
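To make the recurrence concrete, here is a minimal pure-PyTorch sketch of this online softmax accumulation for a single query row; the function name and block size are illustrative only and are not part of the kernels below.

import torch

def online_softmax_attention_row(q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, block_k: int = 4) -> torch.Tensor:
    # Streaming softmax(q K^T) V for one query row, one block of keys at a time
    m = torch.tensor(float("-inf"))     # running max m_i
    l = torch.tensor(0.0)               # running sum of exponents l_i
    o_tilde = torch.zeros(V.shape[-1])  # unnormalized output \tilde{O}_i
    for j in range(0, K.shape[0], block_k):
        s = q @ K[j:j + block_k].T      # scores S_{ij} for this block of keys
        m_new = torch.maximum(m, s.max())
        p = torch.exp(s - m_new)        # \tilde{P}_{ij}
        scale = torch.exp(m - m_new)    # correction for the old running values
        l = l * scale + p.sum()
        o_tilde = o_tilde * scale + p @ V[j:j + block_k]
        m = m_new
    return o_tilde / l

# q, K, V = torch.randn(16), torch.randn(64, 16), torch.randn(64, 16)
# torch.allclose(online_softmax_attention_row(q, K, V), torch.softmax(q @ K.T, -1) @ V)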
from typing import Any, Tuple

import torch
import triton
import triton.language as tl

HI_PRES_TL: tl.constexpr = tl.float32
HI_PRES_TORCH: torch.dtype = torch.float32


class AttentionFunc(torch.autograd.Function):

Grouped query attention forward pass. Returns the output in shape [batch_size, n_heads, q_seq_len, d_head].

* ctx is the context for torch autograd
* q has shape [batch_size, n_heads, q_seq_len, d_head]
* k has shape [batch_size, k_heads, kv_seq_len, d_head]
* v has shape [batch_size, k_heads, kv_seq_len, d_head]
* causal is whether to apply a causal attention mask
* sm_scale is the softmax scale factor

    @staticmethod
    def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                causal: bool, sm_scale: float) -> torch.Tensor:
        batch_size, n_heads, q_seq_len, d_head = q.shape
        _, k_heads, kv_seq_len, _ = k.shape
        assert n_heads % k_heads == 0
        n_groups = n_heads // k_heads

Shape constraints

        assert d_head == k.shape[-1] == v.shape[-1]
        assert d_head in {16, 32, 64, 128, 256}

Reshape the tensors, combining the heads with the batch dimension

        q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
        k = k.view(batch_size * k_heads, kv_seq_len, d_head)
        v = v.view(batch_size * k_heads, kv_seq_len, d_head)

Make sure the tensors are contiguous and the strides are the same

        assert q.is_contiguous()
        assert k.is_contiguous()
        assert v.is_contiguous()
        assert k.stride() == v.stride()

Tensor for the output

        o = torch.empty_like(q)

Tensor for the log of the sum of exponentials $L_i$

        lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)

The forward computation will be parallelized along the batch dimension and the queries in blocks of size BLOCK_Q
        grid = lambda meta: (triton.cdiv(q_seq_len, meta["BLOCK_Q"]), batch_size * k_heads * n_groups, 1)
        _attn_fwd[grid](
            q, k, v, sm_scale * 1.4426950408889634, lse, o,
            n_groups=n_groups,
            q_seq_len=q_seq_len,
            kv_seq_len=kv_seq_len,
            d_head=d_head,
            is_causal=causal,
        )
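A brief note on the sm_scale * 1.4426950408889634 factor: $1.4426950408889634 \approx \log_2 e$, and since

$$e^{\sigma x} = 2^{\sigma x \log_2 e}$$

pre-multiplying the softmax scale $\sigma$ by $\log_2 e$ lets the kernels use exp2 and log2 (which map to fast hardware instructions) instead of exp and log. The stored log-sum-exp is consequently in base 2.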
Save the reshaped inputs and outputs for the backward pass

        ctx.save_for_backward(q, k, v, o, lse)
        ctx.sm_scale = sm_scale
        ctx.n_groups = n_groups
        ctx.causal = causal

Return the output in shape [batch_size, n_heads, q_seq_len, d_head]

        return o.view(batch_size, n_heads, q_seq_len, d_head)

The backward pass computes the gradients of the input tensors.

* ctx is the context for torch autograd
* do is the gradient tensor of the attention output with shape [batch_size, n_heads, q_seq_len, d_head]

    @staticmethod
    def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:

Get saved tensors and attributes

        n_groups = ctx.n_groups
        sm_scale = ctx.sm_scale
        causal = ctx.causal
        q, k, v, o, lse = ctx.saved_tensors

Get shapes

        batch_size, n_heads, q_seq_len, d_head = do.shape
        _, kv_seq_len, _ = k.shape
        k_heads = n_heads // n_groups

Combine the heads with the batch dimension of the output gradients tensor

        do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)

Make sure it's contiguous and the strides are the same

        assert do.is_contiguous()
        assert k.stride() == v.stride()
        assert q.stride() == o.stride() == do.stride()

Create tensors for input gradients

        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)

Precompute $\sigma \log_2 e \, K$ so the kernels can use exp2

        k_scaled = k * (sm_scale * 1.4426950408889634)

Tensor for $D_i = O_i \cdot dO_i$

        pdp = torch.empty_like(lse)

We use a fixed BLOCK_Q for the backward pass on $D_i$

        BLOCK_Q = 16
        pre_grid = (triton.cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads)
        _attn_bwd_d[pre_grid](
            o, do,
            pdp,
            BLOCK_Q=16,
            d_head=d_head,
            q_seq_len=q_seq_len,
            n_groups=n_groups,
            num_stages=1,
        )

Compute $dK$ and $dV$, parallelized over the batch dimension and blocks of keys of size BLOCK_K

        grid = lambda meta: (triton.cdiv(kv_seq_len, meta['BLOCK_K']), batch_size * k_heads)
        _attn_bwd_dkdv[grid](
            q, k_scaled, v, sm_scale, do, dk, dv,
            lse, pdp,
            q_seq_len, kv_seq_len, n_groups, d_head,
            is_causal=causal,
        )

Compute $dQ$, parallelized over the batch dimension and blocks of queries of size BLOCK_Q

        grid = lambda meta: (triton.cdiv(q_seq_len, meta['BLOCK_Q']), batch_size * k_heads * n_groups)
        _attn_bwd_dq[grid](
            q, k_scaled, v, do,
            dq,
            lse, pdp,
            q_seq_len, kv_seq_len, n_groups, d_head,
            is_causal=causal,
        )

Split the combined batch and heads

        dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
        dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
        dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)

        return dq, dk, dv, None, None


attention = AttentionFunc.apply
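For reference, a plain PyTorch sketch of the same grouped query attention computation that these kernels implement might look like the following. The function name is illustrative; it materializes the full attention matrix and is only meant for checking correctness on small inputs.

import torch

def reference_gqa_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                            causal: bool, sm_scale: float) -> torch.Tensor:
    # q: [batch, n_heads, q_len, d_head]; k, v: [batch, k_heads, kv_len, d_head]
    n_heads, k_heads = q.shape[1], k.shape[1]
    n_groups = n_heads // k_heads
    # Query head h attends to key/value head h // n_groups, matching the
    # [batch * k_heads, n_groups, ...] layout used by the kernels
    k = k.repeat_interleave(n_groups, dim=1)
    v = v.repeat_interleave(n_groups, dim=1)
    s = (q @ k.transpose(-2, -1)) * sm_scale
    if causal:
        i = torch.arange(q.shape[-2], device=q.device)[:, None]
        j = torch.arange(k.shape[-2], device=q.device)[None, :]
        s = s.masked_fill(j > i, float("-inf"))
    return torch.softmax(s, dim=-1) @ v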
Autotune configurations for the kernels. inner_loop denotes whether the kernel's inner loop iterates over the keys or the queries, which determines the block-size divisibility constraint.

def _get_autotune_configs(inner_loop: str) -> list:
    configs = []

List possible BLOCK_Q and BLOCK_K that satisfy the divisibility constraint and also try to cover a wide range

    for bm in [64, 128, 256]:

We'll try bn in 64, 128 and 256, keeping only the combinations where the inner-loop block size divides the outer one

        for bn in [64, 128, 256]:
            if inner_loop == 'key' and bm % bn != 0:
                continue
            if inner_loop == 'query' and bn % bm != 0:
                continue
            for s in [2, 3, 4]:
                for w in [4, 8]:
                    if bm * bn < 128 * 128 and w == 8:
                        continue

                    configs.append(triton.Config({'BLOCK_Q': bm, 'BLOCK_K': bn}, num_stages=s, num_warps=w))

Only the first configuration is used

    return configs[:1]
* t_q: queries
* t_k: keys
* t_v: values
* sm_scale_log2e: softmax scale multiplied by $\log_2 e$
* t_lse: $\log_2 \sum_j e^{S_{ij}}$ (out)
* t_o: output (out)
* n_groups: number of groups in grouped query attention
* q_seq_len: query sequence length
* kv_seq_len: key/value sequence length
* d_head: size of a head
* BLOCK_Q: block size for query sequence length
* BLOCK_K: block size for key sequence length
* is_causal: whether causal attention

The tensors are asserted to be contiguous in the wrapper, so the kernel computes offsets directly from q_seq_len, kv_seq_len, n_groups and d_head instead of taking separate stride arguments.
@triton.autotune(_get_autotune_configs(inner_loop='key'),
                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
@triton.jit
def _attn_fwd(t_q, t_k, t_v, sm_scale_log2e, t_lse, t_o,
              n_groups: tl.constexpr,
              q_seq_len: tl.constexpr,
              kv_seq_len: tl.constexpr,
              d_head: tl.constexpr,
              is_causal: tl.constexpr,
              BLOCK_Q: tl.constexpr,  # q seq len block
              BLOCK_K: tl.constexpr,  # k seq len block
              ):
    i = tl.program_id(0)
    z = tl.program_id(1) // n_groups
    g = tl.program_id(1) % n_groups  # TODO

Create block pointers
    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                            (q_seq_len, d_head),
                            (d_head, 1),
                            (i * BLOCK_Q, 0),
                            (BLOCK_Q, d_head),
                            (1, 0))
    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
                            (kv_seq_len, d_head),
                            (d_head, 1),
                            (0, 0),
                            (BLOCK_K, d_head),
                            (1, 0))
    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
                             (d_head, kv_seq_len),
                             (1, d_head),
                             (0, 0),
                             (d_head, BLOCK_K),
                             (0, 1))
    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                            (q_seq_len, d_head),
                            (d_head, 1),
                            (i * BLOCK_Q, 0),
                            (BLOCK_Q, d_head),
                            (1, 0))
    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
                              (q_seq_len,),
                              (1,),
                              (i * BLOCK_Q,),
                              (BLOCK_Q,),
                              (0,))

Initialize offsets

    offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
    offs_j = tl.arange(0, BLOCK_K)

Mask for the last block of queries

    i_mask = offs_i < q_seq_len

Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\infty$ and $l_i$ to $1$, so in the first update the contribution of the initial $l_i$ is $e^{m_i - m_i^{\text{new}}} \cdot 1 = 0$. b_m will be storing the running max of the scaled scores $S_{ij} \log_2 e$.

    b_m = tl.where(i_mask, -float("inf"), 0.0)
    b_l = tl.where(i_mask, 1.0, 0.0)

Initialize the unnormalized output $\tilde{O}_i$

    b_o = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)

Load $Q_i$ outside the loop since it will be reused throughout the loop over the key blocks.

    b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")

    if is_causal:

Inner loop up to the diagonal block
        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q,
                                        p_kT, p_v,
                                        sm_scale_log2e,
                                        BLOCK_Q, d_head, BLOCK_K,
                                        offs_i, offs_j,
                                        j=tl.full([], 0, tl.int32),  # type: ignore
                                        steps=(i * BLOCK_Q) // BLOCK_K,
                                        MASK=False,
                                        q_seq_len=q_seq_len,
                                        kv_seq_len=kv_seq_len
                                        )

Diagonal block with masking within it

        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
                                        sm_scale_log2e,
                                        BLOCK_Q, d_head, BLOCK_K,
                                        offs_i, offs_j,
                                        j=i * BLOCK_Q,
                                        steps=BLOCK_Q // BLOCK_K,
                                        MASK=True,
                                        q_seq_len=q_seq_len,
                                        kv_seq_len=kv_seq_len
                                        )
    else:

Iterate through all key blocks

        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
                                        sm_scale_log2e,
                                        BLOCK_Q, d_head, BLOCK_K,
                                        offs_i, offs_j,
                                        j=tl.full([], 0, tl.int32),  # type: ignore
                                        steps=tl.cdiv(kv_seq_len, BLOCK_K),
                                        MASK=False,
                                        q_seq_len=q_seq_len,
                                        kv_seq_len=kv_seq_len
                                        )

Store the log-sum-exp $L_i$

    tl.store(p_lse, b_m + tl.math.log2(b_l), boundary_check=(0,))

Store $O_i = \frac{\tilde{O}_i}{l_i}$

    tl.store(p_o, (b_o / b_l[:, None]).to(t_o.type.element_ty), boundary_check=(0,))
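A note on what gets stored in t_lse: the scores inside the kernel are pre-multiplied by $\log_2 e$, so with $\tilde{S}_{ij} = S_{ij} \log_2 e$ the stored value is $L_i = \max_j \tilde{S}_{ij} + \log_2 l_i = \log_2 \sum_j e^{S_{ij}}$. The backward kernels can therefore recover the softmax without re-tracking the running max:

$$P_{ij} = \frac{e^{S_{ij} - m_i}}{l_i} = 2^{\tilde{S}_{ij} - L_i}$$

which is exactly the exp2 of the scaled scores minus the loaded LSE used in _attn_bwd_dkdv_inner and _attn_bwd_dq_inner below.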
@triton.jit
def _attn_fwd_inner(b_o, b_l, b_m, b_q,
                    p_kT, p_v,
                    sm_scale_log2e,
                    BLOCK_Q: tl.constexpr,
                    d_head: tl.constexpr,
                    BLOCK_K: tl.constexpr,
                    offs_i, offs_j,
                    j,
                    steps,
                    MASK: tl.constexpr,
                    q_seq_len: tl.constexpr,
                    kv_seq_len: tl.constexpr
                    ):
    tl.static_assert(BLOCK_Q % BLOCK_K == 0)

Move the $K^\top$ and $V$ pointers

    p_kT = tl.advance(p_kT, (0, j))
    p_v = tl.advance(p_v, (j, 0))

Iterate over the key blocks and update $\tilde{O}_i$, $l_i$ and $m_i$
    for _ in range(steps):

Load $K_j^\top$

        b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")

Compute the scaled scores $S_{ij} \log_2 e = \sigma \log_2 e \, Q_i K_j^\top$

        b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
        b_s = b_s * sm_scale_log2e

Apply causal mask

        if MASK:
            causal_mask = offs_i[:, None] >= (j + offs_j[None, :])
            b_s = tl.where(causal_mask, b_s, -float("inf"))

Mask out the columns if the block is beyond the end of the key sequence

        j_mask = (j + offs_j) < kv_seq_len
        b_s = tl.where(j_mask[None, :], b_s, -float("inf"))

New running maximum $m_i^{\text{new}}$

        b_m_new = tl.maximum(b_m, tl.max(b_s, -1))

Unnormalized softmax of the current block $\tilde{P}_{ij}$

        b_p = tl.math.exp2(b_s - b_m_new[:, None])

Sum of the new exponentials

        b_l_new = tl.sum(b_p, -1)

Correction factor $2^{m_i - m_i^{\text{new}}}$ for the previous running values

        b_m_m_new = tl.math.exp2(b_m - b_m_new)

Update $l_i$

        b_l = b_l * b_m_m_new + b_l_new

Scale the accumulated output by the correction factor and add the contribution of the current block

        b_o = b_o * b_m_m_new[:, None]
        b_p = b_p.to(b_q.dtype)  # TODO
        b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
        b_o += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)

Update $m_i$

        b_m = b_m_new

Move pointers

        j += BLOCK_K
        p_v = tl.advance(p_v, (BLOCK_K, 0))
        p_kT = tl.advance(p_kT, (0, BLOCK_K))

    tl.static_assert(b_o.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")

    return b_o, b_l, b_m
@triton.jit
def _attn_bwd_d(t_o, t_do,
                t_pdp,
                BLOCK_Q: tl.constexpr, d_head: tl.constexpr,
                q_seq_len: tl.constexpr,
                n_groups: tl.constexpr,
                ):
    i = tl.program_id(0) * BLOCK_Q
    z = tl.program_id(1)

Create block pointers

    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
                            (n_groups, q_seq_len, d_head),
                            (q_seq_len * d_head, d_head, 1),
                            (0, i, 0),
                            (n_groups, BLOCK_Q, d_head),
                            (2, 1, 0))
    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
                             (n_groups, q_seq_len, d_head),
                             (q_seq_len * d_head, d_head, 1),
                             (0, i, 0),
                             (n_groups, BLOCK_Q, d_head),
                             (2, 1, 0))
    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
                              (n_groups, q_seq_len),
                              (q_seq_len, 1),
                              (0, i),
                              (n_groups, BLOCK_Q),
                              (1, 0))

Load $O_i$

    o = tl.load(p_o, boundary_check=(1,), padding_option="zero")

Load $dO_i$

    do = tl.load(p_do, boundary_check=(1,), padding_option="zero").to(HI_PRES_TL)

Calculate $D_i = O_i \cdot dO_i$

    d = tl.sum(o * do, axis=-1)

Save $D_i$

    tl.store(p_pdp, d, boundary_check=(1,))
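Why $D_i$ is the right quantity: for a row-wise softmax $P_i = \mathrm{softmax}(S_i)$ with output $O_i = P_i V$, the gradient through the softmax is $dS_{ij} = P_{ij}(dP_{ij} - D_i)$ with $dP_i = dO_i V^\top$ and $D_i = \sum_j P_{ij} dP_{ij} = O_i \cdot dO_i$. A small PyTorch check of this identity (the names here are illustrative only):

import torch

torch.manual_seed(0)
S = torch.randn(4, 6, dtype=torch.float64, requires_grad=True)
V = torch.randn(6, 8, dtype=torch.float64)
dO = torch.randn(4, 8, dtype=torch.float64)

P = torch.softmax(S, dim=-1)
O = P @ V
O.backward(dO)                             # autograd gradient w.r.t. S

dP = dO @ V.T                              # gradient w.r.t. P
D = (O * dO).sum(dim=-1, keepdim=True)     # D_i = O_i . dO_i
dS = P * (dP - D)                          # manual gradient
assert torch.allclose(S.grad, dS)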
Compute $dK_j$ and $dV_j$ for a block of keys by iterating over blocks of queries. Note that t_k passed to this kernel already holds the pre-scaled keys $\sigma \log_2 e \, K$.

@triton.autotune(_get_autotune_configs(inner_loop='query'),
                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
@triton.jit
def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
                   t_do,
                   t_dk, t_dv,
                   t_lse, t_pdp,
                   q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
                   n_groups: tl.constexpr, d_head: tl.constexpr,
                   is_causal: tl.constexpr,
                   BLOCK_Q: tl.constexpr,
                   BLOCK_K: tl.constexpr,
                   ):
    j = tl.program_id(0) * BLOCK_K
    z = tl.program_id(1)

Create block pointers
    p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
                            (kv_seq_len, d_head),
                            (d_head, 1),
                            (j, 0),
                            (BLOCK_K, d_head),
                            (1, 0))
    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
                            (kv_seq_len, d_head),
                            (d_head, 1),
                            (j, 0),
                            (BLOCK_K, d_head),
                            (1, 0))
    p_dk = tl.make_block_ptr(t_dk + z * kv_seq_len * d_head,
                             (kv_seq_len, d_head),
                             (d_head, 1),
                             (j, 0),
                             (BLOCK_K, d_head),
                             (1, 0))
    p_dv = tl.make_block_ptr(t_dv + z * kv_seq_len * d_head,
                             (kv_seq_len, d_head),
                             (d_head, 1),
                             (j, 0),
                             (BLOCK_K, d_head),
                             (1, 0))

Initialize $dK_j$ and $dV_j$

    b_dk = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
    b_dv = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)

Load $K_j$ and $V_j$ outside the loop over query blocks.

    b_k = tl.load(p_k, boundary_check=(0,), padding_option="zero")
    b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")

Iterate through the query head groups (GQA)

    for g in range(n_groups):

Create block pointers
        p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                                 (d_head, q_seq_len),
                                 (1, d_head),
                                 (0, 0),
                                 (d_head, BLOCK_Q),
                                 (0, 1))

        p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                                 (q_seq_len, d_head),
                                 (d_head, 1),
                                 (0, 0),
                                 (BLOCK_Q, d_head),
                                 (1, 0))
        p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
                                  (q_seq_len,),
                                  (1,),
                                  (0,),
                                  (BLOCK_Q,),
                                  (0,))
        p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
                                  (q_seq_len,),
                                  (1,),
                                  (0,),
                                  (BLOCK_Q,),
                                  (0,))

        if is_causal:

Inner loop at the diagonal block
            b_dk, b_dv = _attn_bwd_dkdv_inner(
                b_dk, b_dv,
                p_qT, b_k, b_v, p_do,
                p_lse, p_pdp,
                BLOCK_Q, BLOCK_K,
                d_head,
                j=j, i=j,
                steps=BLOCK_K // BLOCK_Q,
                MASK=True,
                q_seq_len=q_seq_len,
                kv_seq_len=kv_seq_len,
            )

Inner loop on the query blocks after the diagonal

            b_dk, b_dv = _attn_bwd_dkdv_inner(
                b_dk, b_dv,
                p_qT, b_k, b_v, p_do,
                p_lse, p_pdp,
                BLOCK_Q, BLOCK_K,
                d_head,
                j=j, i=j + BLOCK_K,
                steps=tl.cdiv((q_seq_len - (j + BLOCK_K)), BLOCK_Q),
                MASK=False,
                q_seq_len=q_seq_len,
                kv_seq_len=kv_seq_len
            )
        else:

Iterate through all query blocks

            b_dk, b_dv = _attn_bwd_dkdv_inner(
                b_dk, b_dv,
                p_qT, b_k, b_v, p_do,
                p_lse, p_pdp,
                BLOCK_Q, BLOCK_K,
                d_head,
                j=j, i=tl.full([], 0, tl.int32),
                steps=tl.cdiv(q_seq_len, BLOCK_Q),
                MASK=False,
                q_seq_len=q_seq_len,
                kv_seq_len=kv_seq_len
            )

Save $dV_j$

    tl.store(p_dv, b_dv.to(t_dv.type.element_ty), boundary_check=(0,))

b_dk accumulated $\sum_i dS_{ij} Q_i$ without the softmax scale, so multiply by $\sigma$ to get $dK_j$

    b_dk *= sm_scale

Save $dK_j$

    tl.store(p_dk, b_dk.to(t_dk.type.element_ty), boundary_check=(0,))
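For reference, the per-key-block quantities this kernel accumulates are the standard attention gradients, written here to match the code above:

$$dV_j = \sum_i P_{ij}\, dO_i, \qquad dP_{ij} = dO_i \cdot V_j, \qquad dS_{ij} = P_{ij}\,(dP_{ij} - D_i), \qquad dK_j = \sigma \sum_i dS_{ij}\, Q_i.$$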
Inner loop along the query blocks

@triton.jit
def _attn_bwd_dkdv_inner(b_dk, b_dv,
                         p_qT, b_k, b_v, p_do,
                         p_lse, p_pdp,
                         BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
                         d_head: tl.constexpr,
                         j, i, steps,
                         MASK: tl.constexpr,
                         q_seq_len: tl.constexpr,
                         kv_seq_len: tl.constexpr):

To apply the causal mask on the diagonal block, BLOCK_K must be divisible by BLOCK_Q

    tl.static_assert(BLOCK_K % BLOCK_Q == 0)

Offsets and mask

    offs_i = i + tl.arange(0, BLOCK_Q)
    offs_j = j + tl.arange(0, BLOCK_K)

Move the pointers

    p_qT = tl.advance(p_qT, (0, i))
    p_do = tl.advance(p_do, (i, 0))
    p_lse = tl.advance(p_lse, (i,))
    p_pdp = tl.advance(p_pdp, (i,))

Iterate over query blocks
    for _ in range(steps):

Load $Q_i^\top$

        b_qT = tl.load(p_qT, boundary_check=(1,), padding_option="zero")

Load the log-sum-exp $L_i$

        b_l = tl.load(p_lse, boundary_check=(0,), padding_option="zero")

Compute $S_{ij}^\top \log_2 e$ (b_k was pre-scaled by $\sigma \log_2 e$)

        b_sT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)

$P_{ij}^\top = 2^{S_{ij} \log_2 e - L_i}$

        b_pT = tl.math.exp2(b_sT - b_l[None, :])

Autoregressive masking

        if MASK:
            mask = (offs_i[None, :] >= offs_j[:, None])
            b_pT = tl.where(mask, b_pT, 0.0)

Mask out the query positions if the block is beyond the end of the query sequence

Note: There is no need to mask out based on $j$ because the effects of positions outside the key boundary never get stored in $dK_j$ or $dV_j$ (the stores are boundary checked). Masking by $i$ may also not be necessary since the tensors have 0 on loading (zero padding).

        i_mask = offs_i < q_seq_len
        b_pT = tl.where(i_mask[None, :], b_pT, 0.0)

Load $dO_i$ and accumulate $dV_j \mathrel{+}= \sum_i P_{ij} dO_i$

        b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
        b_dv += tl.dot(b_pT.to(b_do.dtype), b_do, out_dtype=HI_PRES_TL)

Load $D_i$

        b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")

$dP_{ij}^\top = V_j \, dO_i^\top$

        b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)

$dS_{ij}^\top = P_{ij}^\top (dP_{ij}^\top - D_i)$

        b_dsT = b_pT * (b_dpT - b_pdp[None, :])

Accumulate $dK_j \mathrel{+}= \sum_i dS_{ij} Q_i$ (the softmax scale is applied later)

        b_dk += tl.dot(b_dsT.to(b_qT.dtype), tl.trans(b_qT), out_dtype=HI_PRES_TL)

Increment pointers.

        offs_i += BLOCK_Q
        p_lse = tl.advance(p_lse, (BLOCK_Q,))
        p_pdp = tl.advance(p_pdp, (BLOCK_Q,))
        p_qT = tl.advance(p_qT, (0, BLOCK_Q))
        p_do = tl.advance(p_do, (BLOCK_Q, 0))

Return the accumulated $dK_j$ and $dV_j$
    return b_dk, b_dv

Compute $dQ_i$ for a block of queries by iterating over blocks of keys

@triton.autotune(_get_autotune_configs(inner_loop='key'),
                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
@triton.jit
def _attn_bwd_dq(t_q, t_k, t_v, t_do,
                 t_dq,
                 t_lse, t_pdp,
                 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
                 n_groups: tl.constexpr, d_head: tl.constexpr,
                 is_causal: tl.constexpr,
                 BLOCK_Q: tl.constexpr,
                 BLOCK_K: tl.constexpr,
                 ):
    i = tl.program_id(0) * BLOCK_Q
    z = tl.program_id(1) // n_groups
    g = tl.program_id(1) % n_groups  # TODO

Create block pointers
    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                            (q_seq_len, d_head),
                            (d_head, 1),
                            (i, 0),
                            (BLOCK_Q, d_head),
                            (1, 0))
    p_dq = tl.make_block_ptr(t_dq + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                             (q_seq_len, d_head),
                             (d_head, 1),
                             (i, 0),
                             (BLOCK_Q, d_head),
                             (1, 0))
    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                             (q_seq_len, d_head),
                             (d_head, 1),
                             (i, 0),
                             (BLOCK_Q, d_head),
                             (1, 0))
    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
                             (d_head, kv_seq_len),
                             (1, d_head),
                             (0, 0),
                             (d_head, BLOCK_K),
                             (0, 1))
    p_vT = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
                             (d_head, kv_seq_len),
                             (1, d_head),
                             (0, 0),
                             (d_head, BLOCK_K),
                             (0, 1))
    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
                              (q_seq_len,),
                              (1,),
                              (i,),
                              (BLOCK_Q,),
                              (0,))
    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
                              (q_seq_len,),
                              (1,),
                              (i,),
                              (BLOCK_Q,),
                              (0,))

Load $Q_i$, $dO_i$, $D_i$ and $L_i$ outside the loop

    b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
    b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
    b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
    b_lse = tl.load(p_lse, boundary_check=(0,), padding_option="zero")

Initialize $dQ_i$

    b_dq = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)

    if is_causal:

Compute $dQ_i$ for the masked (diagonal) blocks.
        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
                                  b_do, b_lse, b_pdp,
                                  BLOCK_Q, BLOCK_K,
                                  i=i, j=i,
                                  steps=BLOCK_Q // BLOCK_K,
                                  MASK=True,
                                  q_seq_len=q_seq_len,
                                  kv_seq_len=kv_seq_len
                                  )

Compute $dQ_i$ for the other blocks

        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
                                  b_do, b_lse, b_pdp,
                                  BLOCK_Q, BLOCK_K,
                                  i=i, j=tl.full([], 0, tl.int32),  # type: ignore
                                  steps=i // BLOCK_K,
                                  MASK=False,
                                  q_seq_len=q_seq_len,
                                  kv_seq_len=kv_seq_len
                                  )
    else:

Iterate through all key blocks

        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
                                  b_do, b_lse, b_pdp,
                                  BLOCK_Q, BLOCK_K,
                                  i=i, j=tl.full([], 0, tl.int32),  # type: ignore
                                  steps=tl.cdiv(kv_seq_len, BLOCK_K),
                                  MASK=False,
                                  q_seq_len=q_seq_len,
                                  kv_seq_len=kv_seq_len
                                  )

b_dq stores $\log_2 e \cdot dQ_i$ because the keys were pre-scaled by $\sigma \log_2 e$, so multiply by $\ln 2$ to get $dQ_i$

    b_dq *= 0.6931471824645996

Save $dQ_i$

    tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))
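The accumulated quantity matches the standard formula $dQ_i = \sigma \sum_j dS_{ij} K_j$, with the same $dS_{ij} = P_{ij}(dP_{ij} - D_i)$ as in the $dK$/$dV$ kernel; the $\ln 2$ factor above cancels the $\log_2 e$ that was folded into the pre-scaled keys.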
Inner loop over the key blocks

@triton.jit
def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
                       b_do, b_lse, b_pdp,
                       BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
                       i, j, steps,
                       MASK: tl.constexpr,
                       q_seq_len: tl.constexpr,
                       kv_seq_len: tl.constexpr):

Offsets

    offs_i = i + tl.arange(0, BLOCK_Q)
    offs_j = j + tl.arange(0, BLOCK_K)

Move the pointers

    p_kT = tl.advance(p_kT, (0, j))
    p_vT = tl.advance(p_vT, (0, j))

    tl.static_assert(BLOCK_Q % BLOCK_K == 0, 'BLOCK_Q must be divisible by BLOCK_K')

Iterate over key blocks

    for _ in range(steps):

Load $K_j^\top$

        b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")

Load $V_j^\top$

        b_vT = tl.load(p_vT, boundary_check=(1,), padding_option="zero")

Compute $S_{ij} \log_2 e$ (b_kT was pre-scaled by $\sigma \log_2 e$)

        b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)

$P_{ij} = 2^{S_{ij} \log_2 e - L_i}$

        b_p = tl.math.exp2(b_s - b_lse[:, None])

Autoregressive masking

        if MASK:
            causal_mask = (offs_i[:, None] >= offs_j[None, :])
            b_p = tl.where(causal_mask, b_p, 0.0)

Mask out the columns if the block is beyond the end of the key sequence

        j_mask = offs_j < kv_seq_len
        b_p = tl.where(j_mask[None, :], b_p, 0.0)

$dP_{ij} = dO_i \cdot V_j$

        b_dp = tl.dot(b_do, b_vT, out_dtype=HI_PRES_TL).to(HI_PRES_TL)

$dS_{ij} = P_{ij} (dP_{ij} - D_i)$

        b_ds = b_p * (b_dp - b_pdp[:, None])

Accumulate $dQ_i \mathrel{+}= \sum_j dS_{ij} K_j$ (still carrying the $\sigma \log_2 e$ factor from the pre-scaled keys)

        b_dq += tl.dot(b_ds.to(b_kT.dtype), tl.trans(b_kT), out_dtype=HI_PRES_TL)

Increment pointers.

        offs_j += BLOCK_K
        p_kT = tl.advance(p_kT, (0, BLOCK_K))
        p_vT = tl.advance(p_vT, (0, BLOCK_K))

Return the accumulated $dQ_i$

    return b_dq
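Finally, a minimal sketch of how the attention function defined above might be used and smoke-tested against PyTorch's scaled_dot_product_attention. The shapes, dtype and tolerance here are illustrative assumptions; a CUDA device is required for Triton, and the scale argument needs a reasonably recent PyTorch.

import torch
import torch.nn.functional as F

batch_size, n_heads, k_heads, seq_len, d_head = 2, 8, 2, 128, 64
q = torch.randn(batch_size, n_heads, seq_len, d_head, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(batch_size, k_heads, seq_len, d_head, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(batch_size, k_heads, seq_len, d_head, device="cuda", dtype=torch.float16, requires_grad=True)
sm_scale = d_head ** -0.5

# Causal grouped query attention with the Triton kernels
out = attention(q, k, v, True, sm_scale)
out.sum().backward()  # exercises the backward kernels

# Reference: repeat each key/value head for the query heads in its group
k_rep = k.repeat_interleave(n_heads // k_heads, dim=1)
v_rep = v.repeat_interleave(n_heads // k_heads, dim=1)
ref = F.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True, scale=sm_scale)
torch.testing.assert_close(out, ref, atol=1e-2, rtol=0.0)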