Flash Attention

Forward pass

The attention output for a query $q_i$ is

$$o_i = \frac{\sum_j e^{s_{ij}} v_j}{\sum_j e^{s_{ij}}}, \qquad s_{ij} = \sigma \, q_i k_j^T$$

where $\sigma$ is the softmax scale. You can compute $o_i$, instead of doing the full softmax, by computing the sum of exponents $l_i$ and the unnormalized output $\tilde{o}_i$ while iterating over keys:

$$l_i = \sum_j e^{s_{ij}} \qquad \tilde{o}_i = \sum_j e^{s_{ij}} v_j$$

Finally you can compute,

$$o_i = \frac{\tilde{o}_i}{l_i}$$

To make it numerically stable, flash attention subtracts the current max $m_i$ of $s_{ij}$ before exponentiating.

So it maintains the following while iterating over keys:

  • $m_i$, the max of $s_{ij}$,
  • $l_i$, the sum of exponents $e^{s_{ij} - m_i}$, and
  • $\tilde{o}_i$, the unnormalized output.

For each block of keys $j_1 \dots j_2$ it updates them:

$$m_i^{\text{new}} = \max\left(m_i, \max_{j=j_1}^{j_2} s_{ij}\right)$$
$$\tilde{P}_{ij} = e^{s_{ij} - m_i^{\text{new}}}$$
$$l_i \leftarrow e^{m_i - m_i^{\text{new}}} l_i + \sum_{j=j_1}^{j_2} \tilde{P}_{ij}$$
$$\tilde{o}_i \leftarrow e^{m_i - m_i^{\text{new}}} \tilde{o}_i + \sum_{j=j_1}^{j_2} \tilde{P}_{ij} v_j$$
$$m_i \leftarrow m_i^{\text{new}}$$

Then finally,

$$o_i = \frac{\tilde{o}_i}{l_i}$$
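The same streaming computation can be written in a few lines of plain PyTorch. This is only an illustrative sketch (the helper name naive_streaming_attention and the block_size argument are made up for this example; it handles a single head without masking):

import torch

def naive_streaming_attention(q, k, v, sm_scale, block_size=64):
    """Single-head attention computed one key block at a time (no masking)."""
    q_len, d = q.shape
    m = torch.full((q_len,), float('-inf'), dtype=q.dtype, device=q.device)  # running max m_i
    l = torch.zeros(q_len, dtype=q.dtype, device=q.device)                   # running sum of exponents l_i
    o = torch.zeros(q_len, d, dtype=q.dtype, device=q.device)                # unnormalized output ~o_i
    for j in range(0, k.shape[0], block_size):
        kb, vb = k[j:j + block_size], v[j:j + block_size]
        s = sm_scale * (q @ kb.T)                        # s_ij for this block
        m_new = torch.maximum(m, s.max(dim=-1).values)   # m_i_new
        p = torch.exp(s - m_new[:, None])                # ~P_ij
        scale = torch.exp(m - m_new)                     # e^{m_i - m_i_new}
        l = l * scale + p.sum(dim=-1)
        o = o * scale[:, None] + p @ vb
        m = m_new
    return o / l[:, None]                                # o_i = ~o_i / l_i

# Quick check against the full softmax (an assumed test, not part of the kernel code):
q, k, v = (torch.randn(128, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) / 8, dim=-1) @ v
assert torch.allclose(naive_streaming_attention(q, k, v, 1 / 8), ref, atol=1e-4)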

Backward pass

With $S = \sigma Q K^T$, $P = \mathrm{softmax}(S)$ (row-wise) and $O = P V$, the gradients are

$$dv_j = \sum_i P_{ij} \, do_i \qquad dP_{ij} = do_i \, v_j^T \qquad dS_{ij} = \sum_k dP_{ik} \frac{\partial P_{ik}}{\partial S_{ij}}$$

where $\frac{\partial P_{ik}}{\partial S_{ij}}$ is $P_{ij} (1 - P_{ij})$ when $k = j$ and $-P_{ik} P_{ij}$ otherwise.

The flash attention paper introduces $D_i = \sum_k P_{ik} \, dP_{ik} = o_i \cdot do_i$ to simplify the computation.

Then,

$$dS_{ij} = P_{ij} \, dP_{ij} - D_i P_{ij} = P_{ij} \left( dP_{ij} - D_i \right)$$

$$dq_i = \sigma \sum_j dS_{ij} \, k_j \qquad dk_j = \sigma \sum_i dS_{ij} \, q_i$$

Note: $q_i$, $k_j$, $o_i$, $do_i$, etc. are row vectors.
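For reference, the backward identities above can be written directly in PyTorch. A minimal sketch for a single head (the name attention_backward_reference is an assumption for this example):

import torch

def attention_backward_reference(q, k, v, do, sm_scale):
    """Gradients of o = softmax(sm_scale * q @ k.T) @ v with respect to q, k and v."""
    p = torch.softmax(sm_scale * (q @ k.T), dim=-1)   # P = softmax(S)
    dv = p.T @ do                                     # dv_j = sum_i P_ij do_i
    dp = do @ v.T                                     # dP_ij = do_i v_j^T
    d = ((p @ v) * do).sum(dim=-1, keepdim=True)      # D_i = o_i . do_i
    ds = p * (dp - d)                                 # dS_ij = P_ij (dP_ij - D_i)
    dq = sm_scale * (ds @ k)                          # dq_i = sigma sum_j dS_ij k_j
    dk = sm_scale * (ds.T @ q)                        # dk_j = sigma sum_i dS_ij q_i
    return dq, dk, dv

Comparing these against torch.autograd on small random tensors is a quick way to sanity-check the identities before reading the Triton kernels below.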

101from typing import Any, Tuple
102
103import torch
104import triton
105import triton.language as tl
106
107HI_PRES_TL: tl.constexpr = tl.float32
108HI_PRES_TORCH: torch.dtype = torch.float32
111class AttentionFunc(torch.autograd.Function):

Grouped query attention forward pass. Returns the output in shape [batch_size, n_heads, q_seq_len, d_head].

  • ctx is the context for torch autograd
  • q has shape [batch_size, n_heads, q_seq_len, d_head]
  • k has shape [batch_size, k_heads, kv_seq_len, d_head]
  • v has shape [batch_size, k_heads, kv_seq_len, d_head]
  • causal whether to apply causal attention mask
  • sm_scale softmax scale factor
112    @staticmethod
113    def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
114                causal: bool, sm_scale: float) -> torch.Tensor:
126        batch_size, n_heads, q_seq_len, d_head = q.shape
127        _, k_heads, kv_seq_len, _ = k.shape
128        assert n_heads % k_heads == 0
129        n_groups = n_heads // k_heads

Shape constraints

132        assert d_head == k.shape[-1] == v.shape[-1]
133        assert d_head in {16, 32, 64, 128, 256}

Reshape the tensors, combining the key/value heads with the batch dimension

136        q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
137        k = k.view(batch_size * k_heads, kv_seq_len, d_head)
138        v = v.view(batch_size * k_heads, kv_seq_len, d_head)

Make sure the tensors are contiguous and the strides are the same

141        assert q.is_contiguous()
142        assert k.is_contiguous()
143        assert v.is_contiguous()
144        assert k.stride() == v.stride()

Tensor for the output

147        o = torch.empty_like(q)

Tensor for log of sum of exponentials

149        lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)

The forward computation will be parallelized along the batch dimension and the queries in blocks of size BLOCK_Q

152        grid = lambda meta: (triton.cdiv(q_seq_len, meta["BLOCK_Q"]), batch_size * k_heads * n_groups, 1)
153        _attn_fwd[grid](
154            q, k, v, sm_scale * 1.4426950408889634, lse, o,
155            n_groups=n_groups,
156            q_seq_len=q_seq_len,
157            kv_seq_len=kv_seq_len,
158            d_head=d_head,
159            is_causal=causal,
160        )

Save the reshaped inputs and outputs for the backward pass

163        ctx.save_for_backward(q, k, v, o, lse)
164        ctx.sm_scale = sm_scale
165        ctx.n_groups = n_groups
166        ctx.causal = causal

Return the output in shape [batch_size, n_heads, q_seq_len, d_head]

169        return o.view(batch_size, n_heads, q_seq_len, d_head)

The backward pass computes the gradients of the input tensors.

  • ctx is the context for torch autograd
  • do is the gradient tensor of the attention output with shape [batch_size, n_heads, q_seq_len, d_head]
171    @staticmethod
172    def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:

Get saved tensors and attributes

181        n_groups = ctx.n_groups
182        sm_scale = ctx.sm_scale
183        causal = ctx.causal
184        q, k, v, o, lse = ctx.saved_tensors

Get shapes

187        batch_size, n_heads, q_seq_len, d_head = do.shape
188        _, kv_seq_len, _ = k.shape
189        k_heads = n_heads // n_groups

Combine the key/value heads with the batch dimension of the output gradient tensor

192        do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)

Make sure it's contiguous and the strides are the same

195        assert do.is_contiguous()
196        assert k.stride() == v.stride()
197        assert q.stride() == o.stride() == do.stride()

Create tensors for input gradients

200        dq = torch.empty_like(q)
201        dk = torch.empty_like(k)
202        dv = torch.empty_like(v)

Precompute $k \, \sigma \log_2 e$

205        k_scaled = k * (sm_scale * 1.4426950408889634)

207        pdp = torch.empty_like(lse)

We use a fixed BLOCK_Q for the backward pass on $D_i$

Compute $D_i = \sum_j P_{ij} \, dP_{ij} = o_i \cdot do_i$

This is parallelized along the batch and query in blocks of size BLOCK_Q

213        BLOCK_Q = 16
214        pre_grid = (triton.cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads)
215        _attn_bwd_d[pre_grid](
216            o, do,
217            pdp,
218            BLOCK_Q=16,
219            d_head=d_head,
220            q_seq_len=q_seq_len,
221            n_groups=n_groups,
222            num_stages=1,
223        )

Compute $dK$ and $dV$

This is parallelized along the batch and keys in blocks of size BLOCK_K

228        grid = lambda meta: (triton.cdiv(kv_seq_len, meta['BLOCK_K']), batch_size * k_heads)
229        _attn_bwd_dkdv[grid](
230            q, k_scaled, v, sm_scale, do, dk, dv,
231            lse, pdp,
232            q_seq_len, kv_seq_len, n_groups, d_head,
233            is_causal=causal,
234
235        )

Compute $dQ$

This is parallelized along the batch and queries in blocks of size BLOCK_Q

240        grid = lambda meta: (triton.cdiv(q_seq_len, meta['BLOCK_Q']), batch_size * k_heads * n_groups)
241        _attn_bwd_dq[grid](
242            q, k_scaled, v, do,
243            dq,
244            lse, pdp,
245            q_seq_len, kv_seq_len, n_groups, d_head,
246            is_causal=causal,
247        )

Split the combined batch and heads

250        dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
251        dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
252        dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)

255        return dq, dk, dv, None, None
256
257
258attention = AttentionFunc.apply
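A usage sketch for the attention function defined above. The shapes follow the docstrings; the reference computation with repeat_interleave is an assumed way to test GQA, not part of this module, and it requires a CUDA device with Triton installed:

import torch

batch_size, n_heads, k_heads, seq_len, d_head = 2, 8, 2, 256, 64
q = torch.randn(batch_size, n_heads, seq_len, d_head, device='cuda', dtype=torch.float16, requires_grad=True)
k = torch.randn(batch_size, k_heads, seq_len, d_head, device='cuda', dtype=torch.float16, requires_grad=True)
v = torch.randn(batch_size, k_heads, seq_len, d_head, device='cuda', dtype=torch.float16, requires_grad=True)
sm_scale = d_head ** -0.5

out = attention(q, k, v, True, sm_scale)  # causal=True
out.sum().backward()                      # populates q.grad, k.grad, v.grad

# Plain PyTorch reference: repeat the key/value heads to match the query heads.
k_rep = k.repeat_interleave(n_heads // k_heads, dim=1).float()
v_rep = v.repeat_interleave(n_heads // k_heads, dim=1).float()
s = (q.float() @ k_rep.transpose(-1, -2)) * sm_scale
s = s.masked_fill(torch.triu(torch.ones(seq_len, seq_len, device='cuda'), 1).bool(), float('-inf'))
ref = torch.softmax(s, dim=-1) @ v_rep
assert torch.allclose(out.float(), ref, atol=1e-2)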

Configs for auto-tuning

261def _get_autotune_configs(inner_loop: str) -> list:
266    configs = []

List possible BLOCK_Q and BLOCK_K values such that the inner-loop block size divides the outer-loop block size, and try to cover a wide range

269    for bm in [64, 128, 256]:

Try BLOCK_K values and keep only those that satisfy the divisibility constraint for the given inner loop

271        for bn in [64, 128, 256]:
272            if inner_loop == 'key' and bm % bn != 0:
273                continue
274            if inner_loop == 'query' and bn % bm != 0:
275                continue
276            for s in [2, 3, 4]:
277                for w in [4, 8]:
278                    if bm * bn < 128 * 128 and w == 8:
279                        continue
280
281                    configs.append(triton.Config({'BLOCK_Q': bm, 'BLOCK_K': bn}, num_stages=s, num_warps=w))
282
283    return configs[:1]
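For illustration, the surviving block combinations can be listed directly (a hypothetical inspection; note that the function above currently returns only the first config because of configs[:1]):

for loop in ('key', 'query'):
    for cfg in _get_autotune_configs(loop):
        # triton.Config stores the meta-parameters in .kwargs
        print(loop, cfg.kwargs['BLOCK_Q'], cfg.kwargs['BLOCK_K'], cfg.num_stages, cfg.num_warps)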
  • t_q query
  • t_k keys
  • t_v values
  • sm_scale_log2e softmax scale multiplied by $\log_2 e$
  • t_lse log-sum-exp $\log_2 \sum_j e^{s_{ij}}$ (out)
  • t_o output (out)
  • n_groups number of groups
  • q_seq_len query sequence length
  • kv_seq_len key/value sequence length
  • d_head size of a head
  • BLOCK_Q block size for query sequence length
  • BLOCK_K block size for key sequence length
  • is_causal whether causal attention

Strides z, h, m and d denote the strides of the corresponding dimensions (batch_size, n_heads, seq_len, d_head) of the query. Stride n denotes the stride along seq_len of the key.

286@triton.autotune(_get_autotune_configs(inner_loop='key'),
287                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
288@triton.jit
289def _attn_fwd(t_q, t_k, t_v, sm_scale_log2e, t_lse, t_o,
290              n_groups: tl.constexpr,
291              q_seq_len: tl.constexpr,
292              kv_seq_len: tl.constexpr,
293              d_head: tl.constexpr,
294              is_causal: tl.constexpr,
295              BLOCK_Q: tl.constexpr,  # q seq len block
296              BLOCK_K: tl.constexpr,  # k seq len block
297              ):
318    i = tl.program_id(0)
319    z = tl.program_id(1) // n_groups
320    g = tl.program_id(1) % n_groups  # TODO

Create block pointers

323    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
324                            (q_seq_len, d_head),
325                            (d_head, 1),
326                            (i * BLOCK_Q, 0),
327                            (BLOCK_Q, d_head),
328                            (1, 0))
329    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
330                            (kv_seq_len, d_head),
331                            (d_head, 1),
332                            (0, 0),
333                            (BLOCK_K, d_head),
334                            (1, 0))
335    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
336                             (d_head, kv_seq_len),
337                             (1, d_head),
338                             (0, 0),
339                             (d_head, BLOCK_K),
340                             (0, 1))
341    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
342                            (q_seq_len, d_head),
343                            (d_head, 1),
344                            (i * BLOCK_Q, 0),
345                            (BLOCK_Q, d_head),
346                            (1, 0))
347    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
348                              (q_seq_len,),
349                              (1,),
350                              (i * BLOCK_Q,),
351                              (BLOCK_Q,),
352                              (0,))

Initialize offsets

355    offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
356    offs_j = tl.arange(0, BLOCK_K)

Mask for $q_i$ for the last block

358    i_mask = offs_i < q_seq_len

Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\infty$ and $l_i$ to $1$. So in the first update, the effect of the initial $l_i$ is $e^{m_i - m_i^{\text{new}}} l_i = 0$.

b_m will be storing $m_i \log_2 e$

364    b_m = tl.where(i_mask, -float("inf"), 0.0)
365    b_l = tl.where(i_mask, 1.0, 0.0)

368    b_o = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)

Load $q_i$ outside the loop since it will be reused throughout the loop over $k_j$.

371    b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
372
373    if is_causal:

Inner loop up to the diagonal block

375        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q,
376                                        p_kT, p_v,
377                                        sm_scale_log2e,
378                                        BLOCK_Q, d_head, BLOCK_K,
379                                        offs_i, offs_j,
380                                        j=tl.full([], 0, tl.int32),  # type: ignore
381                                        steps=(i * BLOCK_Q) // BLOCK_K,
382                                        MASK=False,
383                                        q_seq_len=q_seq_len,
384                                        kv_seq_len=kv_seq_len
385                                        )

Diagonal block with masking within it

387        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
388                                        sm_scale_log2e,
389                                        BLOCK_Q, d_head, BLOCK_K,
390                                        offs_i, offs_j,
391                                        j=i * BLOCK_Q,
392                                        steps=BLOCK_Q // BLOCK_K,
393                                        MASK=True,
394                                        q_seq_len=q_seq_len,
395                                        kv_seq_len=kv_seq_len
396                                        )
397    else:

Iterate through all $k_j$

399        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
400                                        sm_scale_log2e,
401                                        BLOCK_Q, d_head, BLOCK_K,
402                                        offs_i, offs_j,
403                                        j=tl.full([], 0, tl.int32),  # type: ignore
404                                        steps=tl.cdiv(kv_seq_len, BLOCK_K),
405                                        MASK=False,
406                                        q_seq_len=q_seq_len,
407                                        kv_seq_len=kv_seq_len
408                                        )

Store the LSE $\log_2 \sum_j e^{s_{ij}}$

411    tl.store(p_lse, b_m + tl.math.log2(b_l), boundary_check=(0,))

Store $o_i = \frac{\tilde{o}_i}{l_i}$

413    tl.store(p_o, (b_o / b_l[:, None]).to(t_o.type.element_ty), boundary_check=(0,))
416@triton.jit
417def _attn_fwd_inner(b_o, b_l, b_m, b_q,
418                    p_kT, p_v,
419                    sm_scale_log2e,
420                    BLOCK_Q: tl.constexpr,
421                    d_head: tl.constexpr,
422                    BLOCK_K: tl.constexpr,
423                    offs_i, offs_j,
424                    j,
425                    steps,
426                    MASK: tl.constexpr,
427                    q_seq_len: tl.constexpr,
428                    kv_seq_len: tl.constexpr
429                    ):
430    tl.static_assert(BLOCK_Q % BLOCK_K == 0)

Move the $k^T$ and $v$ pointers

433    p_kT = tl.advance(p_kT, (0, j))
434    p_v = tl.advance(p_v, (j, 0))

Iterate over $k_j$ and $v_j$, and update $\tilde{o}_i$, $l_i$ and $m_i$

437    for _ in range(steps):

Load $k_j^T$

439        b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")

Compute $s_{ij} \log_2 e$

441        b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
442        b_s = b_s * sm_scale_log2e

Apply causal mask

445        if MASK:
446            causal_mask = offs_i[:, None] >= (j + offs_j[None, :])
447            b_s = tl.where(causal_mask, b_s, -float("inf"))

Mask out $s_{ij}$ if the block is beyond the end of $k$ (the last block)

450        j_mask = (j + offs_j) < kv_seq_len
451        b_s = tl.where(j_mask[None, :], b_s, -float("inf"))

454        b_m_new = tl.maximum(b_m, tl.max(b_s, -1))

460        b_p = tl.math.exp2(b_s - b_m_new[:, None])

463        b_l_new = tl.sum(b_p, -1)

465        b_m_m_new = tl.math.exp2(b_m - b_m_new)

467        b_l = b_l * b_m_m_new + b_l_new

470        b_o = b_o * b_m_m_new[:, None]
471        b_p = b_p.to(b_q.dtype)  # TODO
472        b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
473        b_o += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)

476        b_m = b_m_new

Move pointers

479        j += BLOCK_K
480        p_v = tl.advance(p_v, (BLOCK_K, 0))
481        p_kT = tl.advance(p_kT, (0, BLOCK_K))
482
483    tl.static_assert(b_o.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
484
485    return b_o, b_l, b_m
488@triton.jit
489def _attn_bwd_d(t_o, t_do,
490                t_pdp,
491                BLOCK_Q: tl.constexpr, d_head: tl.constexpr,
492                q_seq_len: tl.constexpr,
493                n_groups: tl.constexpr,
494                ):
495    i = tl.program_id(0) * BLOCK_Q
496    z = tl.program_id(1)

Create block pointers

499    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
500                            (n_groups, q_seq_len, d_head),
501                            (q_seq_len * d_head, d_head, 1),
502                            (0, i, 0),
503                            (n_groups, BLOCK_Q, d_head),
504                            (2, 1, 0))
505    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
506                             (n_groups, q_seq_len, d_head),
507                             (q_seq_len * d_head, d_head, 1),
508                             (0, i, 0),
509                             (n_groups, BLOCK_Q, d_head),
510                             (2, 1, 0))
511    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
512                              (n_groups, q_seq_len),
513                              (q_seq_len, 1),
514                              (0, i),
515                              (n_groups, BLOCK_Q),
516                              (1, 0))

Load $o_i$

519    o = tl.load(p_o, boundary_check=(1,), padding_option="zero")

Load $do_i$

521    do = tl.load(p_do, boundary_check=(1,), padding_option="zero").to(HI_PRES_TL)

Calculate $D_i = o_i \cdot do_i$

523    d = tl.sum(o * do, axis=-1)

Save $D_i$

525    tl.store(p_pdp, d, boundary_check=(1,))
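In plain PyTorch this whole kernel reduces to a single elementwise product and row-sum; a sketch assuming o and do are shaped [batch_size * k_heads, n_groups, q_seq_len, d_head] as in backward above:

pdp_ref = (o.float() * do.float()).sum(dim=-1)  # D_i = o_i . do_i, shape [batch_size * k_heads, n_groups, q_seq_len]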

Compute $dk_j$ and $dv_j$ for a block of keys $j_1 \dots j_2$ by iterating over $q_i$

528@triton.autotune(_get_autotune_configs(inner_loop='query'),
529                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
530@triton.jit
531def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
532                   t_do,
533                   t_dk, t_dv,
534                   t_lse, t_pdp,
535                   q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
536                   n_groups: tl.constexpr, d_head: tl.constexpr,
537                   is_causal: tl.constexpr,
538                   BLOCK_Q: tl.constexpr,
539                   BLOCK_K: tl.constexpr,
540                   ):
545    j = tl.program_id(0) * BLOCK_K
546    z = tl.program_id(1)

Create block pointers

549    p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
550                            (kv_seq_len, d_head),
551                            (d_head, 1),
552                            (j, 0),
553                            (BLOCK_K, d_head),
554                            (1, 0))
555    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
556                            (kv_seq_len, d_head),
557                            (d_head, 1),
558                            (j, 0),
559                            (BLOCK_K, d_head),
560                            (1, 0))
561    p_dk = tl.make_block_ptr(t_dk + z * kv_seq_len * d_head,
562                             (kv_seq_len, d_head),
563                             (d_head, 1),
564                             (j, 0),
565                             (BLOCK_K, d_head),
566                             (1, 0))
567    p_dv = tl.make_block_ptr(t_dv + z * kv_seq_len * d_head,
568                             (kv_seq_len, d_head),
569                             (d_head, 1),
570                             (j, 0),
571                             (BLOCK_K, d_head),
572                             (1, 0))

Initialize $dk_j$ and $dv_j$

575    b_dk = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
576    b_dv = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)

Load $k_j$ and $v_j$ outside the loop.

579    b_k = tl.load(p_k, boundary_check=(0,), padding_option="zero")
580    b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")

Iterate through the query heads in the group (GQA)

583    for g in range(n_groups):

Create block pointers

585        p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
586                                 (d_head, q_seq_len),
587                                 (1, d_head),
588                                 (0, 0),
589                                 (d_head, BLOCK_Q),
590                                 (0, 1))
591
592        p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
593                                 (q_seq_len, d_head),
594                                 (d_head, 1),
595                                 (0, 0),
596                                 (BLOCK_Q, d_head),
597                                 (1, 0))
598        p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
599                                  (q_seq_len,),
600                                  (1,),
601                                  (0,),
602                                  (BLOCK_Q,),
603                                  (0,))
604        p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
605                                  (q_seq_len,),
606                                  (1,),
607                                  (0,),
608                                  (BLOCK_Q,),
609                                  (0,))
610
611        if is_causal:

Inner loop at the diagonal block

613            b_dk, b_dv = _attn_bwd_dkdv_inner(
614                b_dk, b_dv,
615                p_qT, b_k, b_v, p_do,
616                p_lse, p_pdp,
617                BLOCK_Q, BLOCK_K,
618                d_head,
619                j=j, i=j,
620                steps=BLOCK_K // BLOCK_Q,
621                MASK=True,
622                q_seq_len=q_seq_len,
623                kv_seq_len=kv_seq_len,
624            )

Inner loop on queries after the diagonal

627            b_dk, b_dv = _attn_bwd_dkdv_inner(
628                b_dk, b_dv,
629                p_qT, b_k, b_v, p_do,
630                p_lse, p_pdp,
631                BLOCK_Q, BLOCK_K,
632                d_head,
633                j=j, i=j + BLOCK_K,
634                steps=tl.cdiv((q_seq_len - (j + BLOCK_K)), BLOCK_Q),
635                MASK=False,
636                q_seq_len=q_seq_len,
637                kv_seq_len=kv_seq_len
638            )
639        else:

Iterate through all queries

641            b_dk, b_dv = _attn_bwd_dkdv_inner(
642                b_dk, b_dv,
643                p_qT, b_k, b_v, p_do,
644                p_lse, p_pdp,
645                BLOCK_Q, BLOCK_K,
646                d_head,
647                j=j, i=tl.full([], 0, tl.int32),
648                steps=tl.cdiv(q_seq_len, BLOCK_Q),
649                MASK=False,
650                q_seq_len=q_seq_len,
651                kv_seq_len=kv_seq_len
652            )

Save $dv_j$

655    tl.store(p_dv, b_dv.to(t_dv.type.element_ty), boundary_check=(0,))

b_dk had $\sum_i dS_{ij} q_i$, so multiply by $\sigma$ to get $dk_j$

658    b_dk *= sm_scale

Save $dk_j$

661    tl.store(p_dk, b_dk.to(t_dk.type.element_ty), boundary_check=(0,))

Inner loop along query

664@triton.jit
665def _attn_bwd_dkdv_inner(b_dk, b_dv,
666                         p_qT, b_k, b_v, p_do,
667                         p_lse, p_pdp,
668                         BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
669                         d_head: tl.constexpr,
670                         j, i, steps,
671                         MASK: tl.constexpr,
672                         q_seq_len: tl.constexpr,
673                         kv_seq_len: tl.constexpr):

To apply the causal mask on the diagonal blocks, BLOCK_K must be divisible by BLOCK_Q

677    tl.static_assert(BLOCK_K % BLOCK_Q == 0)

Offsets and mask

680    offs_i = i + tl.arange(0, BLOCK_Q)
681    offs_j = j + tl.arange(0, BLOCK_K)

Move the pointers

684    p_qT = tl.advance(p_qT, (0, i))
685    p_do = tl.advance(p_do, (i, 0))
686    p_lse = tl.advance(p_lse, (i,))
687    p_pdp = tl.advance(p_pdp, (i,))

Iterate over $q_i$

690    for _ in range(steps):

Load $q_i^T$

692        b_qT = tl.load(p_qT, boundary_check=(1,), padding_option="zero")

695        b_l = tl.load(p_lse, boundary_check=(0,), padding_option="zero")

698        b_sT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)

707        b_pT = tl.math.exp2(b_sT - b_l[None, :])

Autoregressive masking

710        if MASK:
711            mask = (offs_i[None, :] >= offs_j[:, None])
712            b_pT = tl.where(mask, b_pT, 0.0)

Mask out $P_{ij}$ where $q_i$ is beyond the end of $q$ (the last block)

Note: No need to mask out based on $j$ because the effects on positions outside the boundary will not get stored in $dk_j$ or $dv_j$. Masking by $i$ may also not be necessary since the tensors have 0 on loading

719        i_mask = offs_i < q_seq_len
720        b_pT = tl.where(i_mask[None, :], b_pT, 0.0)

723        b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
724        b_dv += tl.dot(b_pT.to(b_do.dtype), b_do, out_dtype=HI_PRES_TL)

727        b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")

729        b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)

731        b_dsT = b_pT * (b_dpT - b_pdp[None, :])

733        b_dk += tl.dot(b_dsT.to(b_qT.dtype), tl.trans(b_qT), out_dtype=HI_PRES_TL)

Increment pointers.

736        offs_i += BLOCK_Q
737        p_lse = tl.advance(p_lse, (BLOCK_Q,))
738        p_pdp = tl.advance(p_pdp, (BLOCK_Q,))
739        p_qT = tl.advance(p_qT, (0, BLOCK_Q))
740        p_do = tl.advance(p_do, (BLOCK_Q, 0))

Return the accumulated $dk_j$ and $dv_j$

743    return b_dk, b_dv
746@triton.autotune(_get_autotune_configs(inner_loop='key'),
747                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
748@triton.jit
749def _attn_bwd_dq(t_q, t_k, t_v, t_do,
750                 t_dq,
751                 t_lse, t_pdp,
752                 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
753                 n_groups: tl.constexpr, d_head: tl.constexpr,
754                 is_causal: tl.constexpr,
755                 BLOCK_Q: tl.constexpr,
756                 BLOCK_K: tl.constexpr,
757                 ):
758    i = tl.program_id(0) * BLOCK_Q
759    z = tl.program_id(1) // n_groups
760    g = tl.program_id(1) % n_groups  # TODO

Create block pointers

763    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
764                            (q_seq_len, d_head),
765                            (d_head, 1),
766                            (i, 0),
767                            (BLOCK_Q, d_head),
768                            (1, 0))
769    p_dq = tl.make_block_ptr(t_dq + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
770                             (q_seq_len, d_head),
771                             (d_head, 1),
772                             (i, 0),
773                             (BLOCK_Q, d_head),
774                             (1, 0))
775    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
776                             (q_seq_len, d_head),
777                             (d_head, 1),
778                             (i, 0),
779                             (BLOCK_Q, d_head),
780                             (1, 0))
781    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
782                             (d_head, kv_seq_len),
783                             (1, d_head),
784                             (0, 0),
785                             (d_head, BLOCK_K),
786                             (0, 1))
787    p_vT = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
788                             (d_head, kv_seq_len),
789                             (1, d_head),
790                             (0, 0),
791                             (d_head, BLOCK_K),
792                             (0, 1))
793    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
794                              (q_seq_len,),
795                              (1,),
796                              (i,),
797                              (BLOCK_Q,),
798                              (0,))
799    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
800                              (q_seq_len,),
801                              (1,),
802                              (i,),
803                              (BLOCK_Q,),
804                              (0,))

Load $q_i$, $do_i$, $D_i$ and the log-sum-exp $L_i$ outside the loop

807    b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
808    b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
809    b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
810    b_lse = tl.load(p_lse, boundary_check=(0,), padding_option="zero")

Initialize $dq_i$

813    b_dq = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)

817    if is_causal:

Compute $dq_i$ for the masked (diagonal) blocks.

819        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
820                                  b_do, b_lse, b_pdp,
821                                  BLOCK_Q, BLOCK_K,
822                                  i=i, j=i,
823                                  steps=BLOCK_Q // BLOCK_K,
824                                  MASK=True,
825                                  q_seq_len=q_seq_len,
826                                  kv_seq_len=kv_seq_len
827                                  )

Compute $dq_i$ for the other (unmasked) blocks

830        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
831                                  b_do, b_lse, b_pdp,
832                                  BLOCK_Q, BLOCK_K,
833                                  i=i, j=tl.full([], 0, tl.int32),  # type: ignore
834                                  steps=i // BLOCK_K,
835                                  MASK=False,
836                                  q_seq_len=q_seq_len,
837                                  kv_seq_len=kv_seq_len
838                                  )
839    else:

Iterate through all $k_j$

841        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
842                                  b_do, b_lse, b_pdp,
843                                  BLOCK_Q, BLOCK_K,
844                                  i=i, j=tl.full([], 0, tl.int32),  # type: ignore
845                                  steps=tl.cdiv(kv_seq_len, BLOCK_K),
846                                  MASK=False,
847                                  q_seq_len=q_seq_len,
848                                  kv_seq_len=kv_seq_len
849                                  )

b_dq stores $\sigma \log_2 e \sum_j dS_{ij} k_j$, so multiply by $\ln 2$ to get $dq_i = \sigma \sum_j dS_{ij} k_j$

852    b_dq *= 0.6931471824645996

Save $dq_i$

855    tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))

Inner loop over key

858@triton.jit
859def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
860                       b_do, b_lse, b_pdp,
861                       BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
862                       i, j, steps,
863                       MASK: tl.constexpr,
864                       q_seq_len: tl.constexpr,
865                       kv_seq_len: tl.constexpr):

Offsets

869    offs_i = i + tl.arange(0, BLOCK_Q)
870    offs_j = j + tl.arange(0, BLOCK_K)

Move the pointers

873    p_kT = tl.advance(p_kT, (0, j))
874    p_vT = tl.advance(p_vT, (0, j))
875
876    tl.static_assert(BLOCK_Q % BLOCK_K == 0, 'BLOCK_Q must be divisible by BLOCK_K')

Iterate over $k_j$

879    for _ in range(steps):

Load $k_j^T$

881        b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")

Load $v_j^T$

883        b_vT = tl.load(p_vT, boundary_check=(1,), padding_option="zero")

886        b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)

895        b_p = tl.math.exp2(b_s - b_lse[:, None])

Autoregressive masking

898        if MASK:
899            causal_mask = (offs_i[:, None] >= offs_j[None, :])
900            b_p = tl.where(causal_mask, b_p, 0.0)

Mask out $P_{ij}$ if the block is beyond the end of $k$ (the last block)

903        j_mask = offs_j < kv_seq_len
904        b_p = tl.where(j_mask[None, :], b_p, 0.0)

909        b_dp = tl.dot(b_do, b_vT, out_dtype=HI_PRES_TL).to(HI_PRES_TL)

911        b_ds = b_p * (b_dp - b_pdp[:, None])

913        b_dq += tl.dot(b_ds.to(b_kT.dtype), tl.trans(b_kT), out_dtype=HI_PRES_TL)

Increment pointers.

916        offs_j += BLOCK_K
917        p_kT = tl.advance(p_kT, (0, BLOCK_K))
918        p_vT = tl.advance(p_vT, (0, BLOCK_K))

Return the accumulated $dq_i$

921    return b_dq