diff --git a/docs/sitemap.xml b/docs/sitemap.xml index 59b8be13..a1eb71f3 100644 --- a/docs/sitemap.xml +++ b/docs/sitemap.xml @@ -1086,7 +1086,7 @@ https://nn.labml.ai/transformers/flash/test.html - 2025-07-30T16:30:00+00:00 + 2025-07-31T16:30:00+00:00 1.00 diff --git a/docs/transformers/flash/index.html b/docs/transformers/flash/index.html index c36039cb..62cbda38 100644 --- a/docs/transformers/flash/index.html +++ b/docs/transformers/flash/index.html @@ -73,33 +73,33 @@

Flash Attention

Forward pass

$$O = softmax \big( \sigma \, Q K^\top \big) V$$

You can compute $O$, instead of doing the full softmax, by computing the sum of exponents $l$ and the unnormalized output $\tilde{O}$ while iterating over keys:

$$l \leftarrow l + \sum_j e^{S_{ij}} \;\;\;\;\;\; \tilde{O} \leftarrow \tilde{O} + \sum_j e^{S_{ij}} V_j$$

Finally you can compute,

$$O_i = \frac{\tilde{O}_i}{l_i}$$

To make it numerically stable flash attention subtracts the current max $m$ of $S_{ij}$ before exponentiating.

So it maintains the following while iterating over keys: the running maximum $m$, the sum of exponents $l$, and the unnormalized output $\tilde{O}$.

For each block of keys $j$ it updates them:

$$m^{new} = \max \big( m, \max_j S_{ij} \big)$$

$$\tilde{P}_{ij} = e^{S_{ij} - m^{new}}$$

$$l \leftarrow e^{m - m^{new}} l + \sum_j \tilde{P}_{ij}$$

$$\tilde{O} \leftarrow e^{m - m^{new}} \tilde{O} + \tilde{P} V$$

$$m \leftarrow m^{new}$$

Then finally,

$$O_i = \frac{\tilde{O}_i}{l_i}$$

$$L_i = m_i + \log l_i$$

where $L$ is the log-sum-exp that gets saved for the backward pass.
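As a sanity check of this update rule, here is a minimal single-query PyTorch sketch of the streaming computation (the function name and block size are illustrative; the kernel below works on 2-D blocks instead):

    import torch

    def online_softmax_attn(q, K, V, block=4):
        # q: (d,), K and V: (n, d). Streams over key blocks like the kernel does.
        m = torch.tensor(float("-inf"))  # running max of scores
        l = torch.tensor(0.0)            # running sum of exponents
        o = torch.zeros(V.shape[1])      # unnormalized output, the O-tilde above
        for j in range(0, K.shape[0], block):
            s = K[j:j + block] @ q                     # scores for this key block
            m_new = torch.maximum(m, s.max())
            p = torch.exp(s - m_new)                   # stabilized exponents
            l = l * torch.exp(m - m_new) + p.sum()     # rescale old sum, add new
            o = o * torch.exp(m - m_new) + p @ V[j:j + block]
            m = m_new
        return o / l                                   # O = O-tilde / l

    q, K, V = torch.randn(8), torch.randn(16, 8), torch.randn(16, 8)
    assert torch.allclose(online_softmax_attn(q, K, V),
                          torch.softmax(K @ q, dim=0) @ V, atol=1e-5)

Here $l$ starts at $0$; the kernel below starts it at $1$, which the first rescale by $2^{-\infty}$ zeroes out anyway.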

Backward pass

$$S = \sigma \, Q K^\top \;\;\;\;\;\; P = softmax(S) \;\;\;\;\;\; O = P V$$

$$dV = P^\top dO \;\;\;\;\;\; dP = dO \, V^\top$$

$$dS_{ij} = \sum_k P_{ik} \big( \delta_{jk} - P_{ij} \big) \, dP_{ik}$$

where $\delta_{jk}$ is $1$ when $j = k$ and $0$ otherwise.

Flash attention paper introduces $D_i = \sum_k P_{ik} \, dP_{ik} = O_i \cdot dO_i$ to simplify computation.

Then,

$$dS_{ij} = P_{ij} \big( dP_{ij} - D_i \big)$$

$$dQ = \sigma \, dS \, K \;\;\;\;\;\; dK = \sigma \, dS^\top Q$$

Note: $Q_i$, $O_i$, $dO_i$, etc. are row vectors.
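The simplified gradient $dS_{ij} = P_{ij}(dP_{ij} - D_i)$ can be checked against autograd with a few lines of PyTorch (a sketch with illustrative shapes):

    import torch

    S = torch.randn(5, 7, requires_grad=True)
    V = torch.randn(7, 4)
    dO = torch.randn(5, 4)

    P = torch.softmax(S, dim=-1)
    O = P @ V
    O.backward(dO)                        # autograd's dS lands in S.grad

    dP = dO @ V.T                         # dP = dO V^T
    D = (O * dO).sum(-1, keepdim=True)    # D_i = O_i . dO_i, one scalar per row
    dS = P * (dP - D)                     # dS = P * (dP - D)
    assert torch.allclose(S.grad, dS, atol=1e-5)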

-
100from typing import Any, Tuple
-101
-102import torch
-103import triton
-104import triton.language as tl
-105
-106HI_PRES_TL: tl.constexpr = tl.float32
-107HI_PRES_TORCH: torch.dtype = torch.float32
+
101from typing import Any, Tuple
+102
+103import torch
+104import triton
+105import triton.language as tl
+106
+107HI_PRES_TL: tl.constexpr = tl.float32
+108HI_PRES_TORCH: torch.dtype = torch.float32
@@ -110,7 +110,7 @@
-
110class AttentionFunc(torch.autograd.Function):
+
111class AttentionFunc(torch.autograd.Function):
@@ -141,9 +141,9 @@
-
111    @staticmethod
-112    def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-113                causal: bool, sm_scale: float) -> torch.Tensor:
+
112    @staticmethod
+113    def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+114                causal: bool, sm_scale: float) -> torch.Tensor:
@@ -154,10 +154,10 @@
-
125        batch_size, n_heads, q_seq_len, d_head = q.shape
-126        _, k_heads, kv_seq_len, _ = k.shape
-127        assert n_heads % k_heads == 0
-128        n_groups = n_heads // k_heads
+
126        batch_size, n_heads, q_seq_len, d_head = q.shape
+127        _, k_heads, kv_seq_len, _ = k.shape
+128        assert n_heads % k_heads == 0
+129        n_groups = n_heads // k_heads
@@ -169,8 +169,8 @@
-
131        assert d_head == k.shape[-1] == v.shape[-1]
-132        assert d_head in {16, 32, 64, 128, 256}
+
132        assert d_head == k.shape[-1] == v.shape[-1]
+133        assert d_head in {16, 32, 64, 128, 256}
@@ -182,9 +182,9 @@
-
135        q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
-136        k = k.view(batch_size * k_heads, kv_seq_len, d_head)
-137        v = v.view(batch_size * k_heads, kv_seq_len, d_head)
+
136        q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
+137        k = k.view(batch_size * k_heads, kv_seq_len, d_head)
+138        v = v.view(batch_size * k_heads, kv_seq_len, d_head)
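For example, with n_heads = 8 query heads sharing k_heads = 2 key-value heads, n_groups = 4 and the reshapes above produce these shapes (a standalone sketch):

    import torch

    batch_size, n_heads, k_heads, seq_len, d_head = 2, 8, 2, 32, 64
    n_groups = n_heads // k_heads  # 4 query heads share each KV head

    q = torch.randn(batch_size, n_heads, seq_len, d_head)
    k = torch.randn(batch_size, k_heads, seq_len, d_head)

    q = q.view(batch_size * k_heads, n_groups, seq_len, d_head)
    k = k.view(batch_size * k_heads, seq_len, d_head)
    print(q.shape, k.shape)  # (4, 4, 32, 64) and (4, 32, 64)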
@@ -196,10 +196,10 @@
-
140        assert q.is_contiguous()
-141        assert k.is_contiguous()
-142        assert v.is_contiguous()
-143        assert k.stride() == v.stride()
+
141        assert q.is_contiguous()
+142        assert k.is_contiguous()
+143        assert v.is_contiguous()
+144        assert k.stride() == v.stride()
@@ -211,7 +211,7 @@
-
146        o = torch.empty_like(q)
+
147        o = torch.empty_like(q)
@@ -219,11 +219,11 @@ -

Tensor for $L$

+

Tensor for log of sum of exponentials

-
148        lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)
+
149        lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)
@@ -236,15 +236,15 @@
-
151        grid = lambda meta: (triton.cdiv(q_seq_len, meta["BLOCK_Q"]), batch_size * k_heads * n_groups, 1)
-152        _attn_fwd[grid](
-153            q, k, v, sm_scale, lse, o,
-154            n_groups=n_groups,
-155            q_seq_len=q_seq_len,
-156            kv_seq_len=kv_seq_len,
-157            d_head=d_head,
-158            is_causal=causal,
-159        )
+
152        grid = lambda meta: (triton.cdiv(q_seq_len, meta["BLOCK_Q"]), batch_size * k_heads * n_groups, 1)
+153        _attn_fwd[grid](
+154            q, k, v, sm_scale, lse, o,
+155            n_groups=n_groups,
+156            q_seq_len=q_seq_len,
+157            kv_seq_len=kv_seq_len,
+158            d_head=d_head,
+159            is_causal=causal,
+160        )
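So the grid has one program per query block along the first axis and one per (batch × KV head × group) along the second; for illustration, assuming the autotuner picked BLOCK_Q = 128:

    import math

    batch_size, k_heads, n_groups = 2, 2, 4
    q_seq_len, BLOCK_Q = 1000, 128  # BLOCK_Q comes from the chosen config

    grid = (math.ceil(q_seq_len / BLOCK_Q), batch_size * k_heads * n_groups, 1)
    print(grid)  # (8, 16, 1): 128 independent kernel instances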
@@ -256,10 +256,10 @@
-
162        ctx.save_for_backward(q, k, v, o, lse)
-163        ctx.sm_scale = sm_scale
-164        ctx.n_groups = n_groups
-165        ctx.causal = causal
+
163        ctx.save_for_backward(q, k, v, o, lse)
+164        ctx.sm_scale = sm_scale
+165        ctx.n_groups = n_groups
+166        ctx.causal = causal
@@ -272,7 +272,7 @@
-
168        return o.view(batch_size, n_heads, q_seq_len, d_head)
+
169        return o.view(batch_size, n_heads, q_seq_len, d_head)
@@ -289,8 +289,8 @@
-
170    @staticmethod
-171    def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
+
171    @staticmethod
+172    def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
@@ -302,10 +302,10 @@
-
180        n_groups = ctx.n_groups
-181        sm_scale = ctx.sm_scale
-182        causal = ctx.causal
-183        q, k, v, o, lse = ctx.saved_tensors
+
181        n_groups = ctx.n_groups
+182        sm_scale = ctx.sm_scale
+183        causal = ctx.causal
+184        q, k, v, o, lse = ctx.saved_tensors
@@ -317,9 +317,9 @@
-
186        batch_size, n_heads, q_seq_len, d_head = do.shape
-187        _, kv_seq_len, _ = k.shape
-188        k_heads = n_heads // n_groups
+
187        batch_size, n_heads, q_seq_len, d_head = do.shape
+188        _, kv_seq_len, _ = k.shape
+189        k_heads = n_heads // n_groups
@@ -331,7 +331,7 @@
-
191        do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
+
192        do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
@@ -343,9 +343,9 @@
-
194        assert do.is_contiguous()
-195        assert k.stride() == v.stride()
-196        assert q.stride() == o.stride() == do.stride()
+
195        assert do.is_contiguous()
+196        assert k.stride() == v.stride()
+197        assert q.stride() == o.stride() == do.stride()
@@ -357,9 +357,9 @@
-
199        dq = torch.empty_like(q)
-200        dk = torch.empty_like(k)
-201        dv = torch.empty_like(v)
+
200        dq = torch.empty_like(q)
+201        dk = torch.empty_like(k)
+202        dv = torch.empty_like(v)
@@ -367,11 +367,11 @@ -

+

-
204        RCP_LN2 = 1.4426950408889634
+
205        RCP_LN2 = 1.4426950408889634
@@ -379,11 +379,11 @@ -

Multiply by softmax scale

+

Precompute $\sigma K \log_2 e$

-
206        k_scaled = k * (sm_scale * RCP_LN2)
+
207        k_scaled = k * (sm_scale * RCP_LN2)
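Folding $\log_2 e = \frac{1}{\ln 2}$ into the keys lets the kernels use the faster exp2 instead of exp, since $e^x = 2^{x \log_2 e}$; a quick check:

    import torch

    RCP_LN2 = 1.4426950408889634  # log2(e) = 1 / ln(2)
    x = torch.randn(1000)
    assert torch.allclose(torch.exp(x), torch.exp2(x * RCP_LN2), atol=1e-5)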
@@ -391,11 +391,11 @@ -

+

-
208        pdp = torch.empty_like(lse)
+
209        pdp = torch.empty_like(lse)
@@ -404,11 +404,11 @@

We use a fixed BLOCK_Q for the backward pass on $D$

-
210        BLOCK_Q = 16
+
@@ -416,22 +416,23 @@ -

Compute $D_i = O_i \cdot dO_i$

+

Compute $D_i = O_i \cdot dO_i$

This is parallelized along the batch and queries in blocks of size BLOCK_Q

-
214        pre_grid = (triton.cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads)
-215        _attn_bwd_d[pre_grid](
-216            o, do,
-217            pdp,
-218            BLOCK_Q=16,
-219            d_head=d_head,
-220            q_seq_len=q_seq_len,
-221            n_groups=n_groups,
-222            num_stages=1,
-223        )
+
215        BLOCK_Q = 16
+216        pre_grid = (triton.cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads)
+217        _attn_bwd_d[pre_grid](
+218            o, do,
+219            pdp,
+220            BLOCK_Q=16,
+221            d_head=d_head,
+222            q_seq_len=q_seq_len,
+223            n_groups=n_groups,
+224            num_stages=1,
+225        )
@@ -439,20 +440,20 @@ -

Compute $dK$ and $dV$

+

Compute $dK$ and $dV$

This is parallelized along the batch and keys in blocks of size BLOCK_K

-
227        grid = lambda meta: (triton.cdiv(kv_seq_len, meta['BLOCK_K']), batch_size * k_heads)
-228        _attn_bwd_dkdv[grid](
-229            q, k_scaled, v, sm_scale, do, dk, dv,
-230            lse, pdp,
-231            q_seq_len, kv_seq_len, n_groups, d_head,
-232            is_causal=causal,
-233
-234        )
+
230        grid = lambda meta: (triton.cdiv(kv_seq_len, meta['BLOCK_K']), batch_size * k_heads)
+231        _attn_bwd_dkdv[grid](
+232            q, k_scaled, v, sm_scale, do, dk, dv,
+233            lse, pdp,
+234            q_seq_len, kv_seq_len, n_groups, d_head,
+235            is_causal=causal,
+236
+237        )
@@ -460,20 +461,20 @@ -

Compute $dQ$

+

Compute $dQ$

This is parallelized along the batch and queries in blocks of size BLOCK_Q

-
238        grid = lambda meta: (triton.cdiv(q_seq_len, meta['BLOCK_Q']), batch_size * k_heads * n_groups)
-239        _attn_bwd_dq[grid](
-240            q, k_scaled, v, do,
-241            dq,
-242            lse, pdp,
-243            q_seq_len, kv_seq_len, n_groups, d_head,
-244            is_causal=causal,
-245        )
+
242        grid = lambda meta: (triton.cdiv(q_seq_len, meta['BLOCK_Q']), batch_size * k_heads * n_groups)
+243        _attn_bwd_dq[grid](
+244            q, k_scaled, v, do,
+245            dq,
+246            lse, pdp,
+247            q_seq_len, kv_seq_len, n_groups, d_head,
+248            is_causal=causal,
+249        )
@@ -485,9 +486,9 @@
-
248        dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
-249        dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
-250        dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)
+
252        dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
+253        dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
+254        dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)
@@ -499,10 +500,10 @@
-
253        return dq, dk, dv, None, None
-254
-255
-256attention = AttentionFunc.apply
+
257        return dq, dk, dv, None, None
+258
+259
+260attention = AttentionFunc.apply
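A hedged usage sketch of the resulting function (shapes follow the assertions in forward; needs a CUDA device with Triton installed):

    import torch

    batch, n_heads, k_heads, seq_len, d_head = 1, 8, 2, 512, 64
    q = torch.randn(batch, n_heads, seq_len, d_head, device='cuda',
                    dtype=torch.float16, requires_grad=True)
    k = torch.randn(batch, k_heads, seq_len, d_head, device='cuda',
                    dtype=torch.float16, requires_grad=True)
    v = torch.randn(batch, k_heads, seq_len, d_head, device='cuda',
                    dtype=torch.float16, requires_grad=True)

    out = attention(q, k, v, True, 1.0 / d_head ** 0.5)  # causal, 1/sqrt(d) scale
    out.sum().backward()   # dq, dk, dv computed by the Triton backward kernels
    print(out.shape)       # torch.Size([1, 8, 512, 64])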
@@ -514,7 +515,7 @@
-
259def _get_autotune_configs(inner_loop: str) -> list:
+
263def _get_autotune_configs(inner_loop: str) -> list:
@@ -525,7 +526,7 @@
-
264    configs = []
+
268    configs = []
@@ -537,7 +538,7 @@
-
267    for bm in [64, 128, 256]:
+
271    for bm in [64, 128, 256]:
@@ -549,19 +550,19 @@
-
269        for bn in [64, 128, 256]:
-270            if inner_loop == 'key' and bm % bn != 0:
-271                continue
-272            if inner_loop == 'query' and bn % bm != 0:
-273                continue
-274            for s in [2, 3, 4]:
-275                for w in [4, 8]:
-276                    if bm * bn < 128 * 128 and w == 8:
-277                        continue
-278
-279                    configs.append(triton.Config({'BLOCK_Q': bm, 'BLOCK_K': bn}, num_stages=s, num_warps=w))
-280
-281    return configs[:1]
+
273        for bn in [64, 128, 256]:
+274            if inner_loop == 'key' and bm % bn != 0:
+275                continue
+276            if inner_loop == 'query' and bn % bm != 0:
+277                continue
+278            for s in [2, 3, 4]:
+279                for w in [4, 8]:
+280                    if bm * bn < 128 * 128 and w == 8:
+281                        continue
+282
+283                    configs.append(triton.Config({'BLOCK_Q': bm, 'BLOCK_K': bn}, num_stages=s, num_warps=w))
+284
+285    return configs[:1]
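Note that the function returns configs[:1], so only the first candidate survives and autotuning is effectively pinned. With the loops above, that first config is equivalent to (a sketch):

    import triton

    # bm=64, bn=64 passes the divisibility check; (s=2, w=4) is the first pair kept
    config = triton.Config({'BLOCK_Q': 64, 'BLOCK_K': 64}, num_stages=2, num_warps=4)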
@@ -578,7 +579,7 @@
  • sm_scale softmax scale
  • t_lse log-sum-exp $L$ (out)
  • t_o output (out)
  • n_groups number of query heads per key/value head
@@ -609,18 +610,18 @@
-
    284@triton.autotune(_get_autotune_configs(inner_loop='key'),
    -285                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
    -286@triton.jit
    -287def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
    -288              n_groups: tl.constexpr,
    -289              q_seq_len: tl.constexpr,
    -290              kv_seq_len: tl.constexpr,
    -291              d_head: tl.constexpr,
    -292              is_causal: tl.constexpr,
    -293              BLOCK_Q: tl.constexpr,  # q seq len block
    -294              BLOCK_K: tl.constexpr,  # k seq len block
    -295              ):
    +
    288@triton.autotune(_get_autotune_configs(inner_loop='key'),
    +289                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
    +290@triton.jit
    +291def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
    +292              n_groups: tl.constexpr,
    +293              q_seq_len: tl.constexpr,
    +294              kv_seq_len: tl.constexpr,
    +295              d_head: tl.constexpr,
    +296              is_causal: tl.constexpr,
    +297              BLOCK_Q: tl.constexpr,  # q seq len block
    +298              BLOCK_K: tl.constexpr,  # k seq len block
    +299              ):
    @@ -631,9 +632,9 @@
    -
    316    i = tl.program_id(0)
    -317    z = tl.program_id(1) // n_groups
    -318    g = tl.program_id(1) % n_groups
    +
    320    i = tl.program_id(0)
    +321    z = tl.program_id(1) // n_groups
    +322    g = tl.program_id(1) % n_groups
    @@ -645,36 +646,36 @@
    -
    321    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
    -322                            (q_seq_len, d_head),
    -323                            (d_head, 1),
    -324                            (i * BLOCK_Q, 0),
    -325                            (BLOCK_Q, d_head),
    -326                            (1, 0))
    -327    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
    -328                            (kv_seq_len, d_head),
    -329                            (d_head, 1),
    -330                            (0, 0),
    -331                            (BLOCK_K, d_head),
    -332                            (1, 0))
    -333    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
    -334                             (d_head, kv_seq_len),
    -335                             (1, d_head),
    -336                             (0, 0),
    -337                             (d_head, BLOCK_K),
    -338                             (0, 1))
    -339    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
    -340                            (q_seq_len, d_head),
    -341                            (d_head, 1),
    -342                            (i * BLOCK_Q, 0),
    -343                            (BLOCK_Q, d_head),
    -344                            (1, 0))
    -345    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
    -346                              (q_seq_len,),
    -347                              (1,),
    -348                              (i * BLOCK_Q,),
    -349                              (BLOCK_Q,),
    -350                              (0,))
    +
    325    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
    +326                            (q_seq_len, d_head),
    +327                            (d_head, 1),
    +328                            (i * BLOCK_Q, 0),
    +329                            (BLOCK_Q, d_head),
    +330                            (1, 0))
    +331    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
    +332                            (kv_seq_len, d_head),
    +333                            (d_head, 1),
    +334                            (0, 0),
    +335                            (BLOCK_K, d_head),
    +336                            (1, 0))
    +337    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
    +338                             (d_head, kv_seq_len),
    +339                             (1, d_head),
    +340                             (0, 0),
    +341                             (d_head, BLOCK_K),
    +342                             (0, 1))
    +343    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
    +344                            (q_seq_len, d_head),
    +345                            (d_head, 1),
    +346                            (i * BLOCK_Q, 0),
    +347                            (BLOCK_Q, d_head),
    +348                            (1, 0))
    +349    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
    +350                              (q_seq_len,),
    +351                              (1,),
    +352                              (i * BLOCK_Q,),
    +353                              (BLOCK_Q,),
    +354                              (0,))
    @@ -686,9 +687,8 @@
    -
    353    offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
    -354    i_mask = offs_i < q_seq_len
    -355    offs_j = tl.arange(0, BLOCK_K)
    +
    357    offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
    +358    offs_j = tl.arange(0, BLOCK_K)
    @@ -696,12 +696,11 @@ -

Initialize $m$ and $l$

+

Mask for $i$ for the last block

    -
    358    b_m = tl.where(i_mask, -float("inf"), 0.0)
    -359    b_l = tl.where(i_mask, 1.0, 0.0)
    +
    360    i_mask = offs_i < q_seq_len
    @@ -709,11 +708,13 @@ -

Accumulate $\tilde{O}$

+

Precalculate $\sigma \log_2 e$.

+

We will use this when calculating $e^{\sigma S}$ as $2^{\sigma S \log_2 e}$, so b_s will store $\sigma S \log_2 e$ instead of $\sigma S$.

    -
    361    b_acc = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
    +
    365    sm_scale = sm_scale * 1.44269504
    @@ -721,11 +722,14 @@ -

    softmax scale / log(2)

    +

Initialize $m$ and $l$. $m$ is initialized to $-\infty$ and $l$ to $1$. So in the first update, the effect of the initial $l$ is $2^{-\infty - m^{new}} = 0$.

+

b_m will be storing the running max of $\sigma S \log_2 e$

    -
    364    sm_scale = sm_scale * 1.44269504
    +
    371    b_m = tl.where(i_mask, -float("inf"), 0.0)
    +372    b_l = tl.where(i_mask, 1.0, 0.0)
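Why starting $l$ at $1$ (rather than $0$) is harmless: the first update rescales it by $2^{-\infty - m^{new}} = 0$ before adding the new sum. A one-line check:

    import torch

    m, l = torch.tensor(float("-inf")), torch.tensor(1.0)
    m_new = torch.tensor(3.0)                 # max of the first score block
    l = l * torch.exp2(m - m_new) + 5.0       # rescale old l, add the block's sum
    print(l)  # tensor(5.) -- the initial l contributed nothing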
    @@ -733,13 +737,11 @@ -

Load $Q$

    +

    -
    366    b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
    -367
    -368    if is_causal:
    +
    375    b_o = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
    @@ -747,21 +749,13 @@ -

Up to the diagonal block

    +

Load $Q$ outside the loop since it will be reused throughout the loop over $K$.

    -
    370        b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q,
    -371                                          p_kT, p_v,
    -372                                          sm_scale,
    -373                                          BLOCK_Q, d_head, BLOCK_K,
    -374                                          offs_i, offs_j,
    -375                                          j=tl.full([], 0, tl.int32),  # type: ignore
    -376                                          steps=(i * BLOCK_Q) // BLOCK_K,
    -377                                          MASK=False,
    -378                                          q_seq_len=q_seq_len,
    -379                                          kv_seq_len=kv_seq_len
    -380                                          )
    +
    378    b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
    +379
    +380    if is_causal:
    @@ -769,31 +763,21 @@ -

    Diagonal block with masking within it

    +

Inner loop up to the diagonal block

    -
    382        b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
    -383                                          sm_scale,
    -384                                          BLOCK_Q, d_head, BLOCK_K,
    -385                                          offs_i, offs_j,
    -386                                          j=i * BLOCK_Q,
    -387                                          steps=BLOCK_Q // BLOCK_K,
    -388                                          MASK=True,
    -389                                          q_seq_len=q_seq_len,
    -390                                          kv_seq_len=kv_seq_len
    -391                                          )
    -392    else:
    -393        b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
    -394                                          sm_scale,
    -395                                          BLOCK_Q, d_head, BLOCK_K,
    -396                                          offs_i, offs_j,
    -397                                          j=tl.full([], 0, tl.int32),  # type: ignore
    -398                                          steps=tl.cdiv(kv_seq_len, BLOCK_K),
    -399                                          MASK=False,
    -400                                          q_seq_len=q_seq_len,
    -401                                          kv_seq_len=kv_seq_len
    -402                                          )
    +
    382        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q,
    +383                                          p_kT, p_v,
    +384                                          sm_scale,
    +385                                          BLOCK_Q, d_head, BLOCK_K,
    +386                                          offs_i, offs_j,
    +387                                          j=tl.full([], 0, tl.int32),  # type: ignore
    +388                                          steps=(i * BLOCK_Q) // BLOCK_K,
    +389                                          MASK=False,
    +390                                          q_seq_len=q_seq_len,
    +391                                          kv_seq_len=kv_seq_len
    +392                                          )
    @@ -801,12 +785,21 @@ -

    Update LSE

    +

    Diagonal block with masking within it

    -
    405    tl.store(p_lse, b_m + tl.math.log2(b_l), boundary_check=(0,))
    -406    tl.store(p_o, (b_acc / b_l[:, None]).to(t_o.type.element_ty), boundary_check=(0,))
    +
    394        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
    +395                                          sm_scale,
    +396                                          BLOCK_Q, d_head, BLOCK_K,
    +397                                          offs_i, offs_j,
    +398                                          j=i * BLOCK_Q,
    +399                                          steps=BLOCK_Q // BLOCK_K,
    +400                                          MASK=True,
    +401                                          q_seq_len=q_seq_len,
    +402                                          kv_seq_len=kv_seq_len
    +403                                          )
    +404    else:
    @@ -814,27 +807,20 @@ - +

Iterate through all keys

    +
    -
    409@triton.jit
    -410def _attn_fwd_inner(b_acc, b_l, b_m, b_q,
    -411                    p_kT, p_v,
    -412                    scale,
    -413                    BLOCK_Q: tl.constexpr,
    -414                    d_head: tl.constexpr,
    -415                    BLOCK_K: tl.constexpr,
    -416                    offs_i, offs_j,
    -417                    j,
    -418                    steps,
    -419                    MASK: tl.constexpr,
    -420                    q_seq_len: tl.constexpr,
    -421                    kv_seq_len: tl.constexpr
    -422                    ):
    -423    tl.static_assert(BLOCK_Q % BLOCK_K == 0)
    -424
    -425    p_kT = tl.advance(p_kT, (0, j))
    -426    p_v = tl.advance(p_v, (j, 0))
    +
    406        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
    +407                                          sm_scale,
    +408                                          BLOCK_Q, d_head, BLOCK_K,
    +409                                          offs_i, offs_j,
    +410                                          j=tl.full([], 0, tl.int32),  # type: ignore
    +411                                          steps=tl.cdiv(kv_seq_len, BLOCK_K),
    +412                                          MASK=False,
    +413                                          q_seq_len=q_seq_len,
    +414                                          kv_seq_len=kv_seq_len
    +415                                          )
    @@ -842,22 +828,11 @@ -

    loop over k, v and update accumulator

    +

    Store LSE

    -
    429    for _ in range(steps):
    -430        current_j = j + offs_j
    -431        j_mask = current_j < kv_seq_len
    -432
    -433        b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
    -434        b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
    -435
    -436        tl.static_assert(b_s.dtype == HI_PRES_TL)
    -437        b_s = b_s * scale
    -438        if MASK:
    -439            causal_mask = offs_i[:, None] >= (j + offs_j[None, :])
    -440            b_s = tl.where(causal_mask, b_s, -float("inf"))
    +
    418    tl.store(p_lse, b_m + tl.math.log2(b_l), boundary_check=(0,))
    @@ -865,11 +840,11 @@ -

    always apply seq mask

    +

Store $O = \frac{\tilde{O}}{l}$

    -
    442        b_s = tl.where(j_mask[None, :], b_s, -float("inf"))
    +
    420    tl.store(p_o, (b_o / b_l[:, None]).to(t_o.type.element_ty), boundary_check=(0,))
    @@ -877,12 +852,24 @@ -

    - +
    -
    445        tl.static_assert(len(b_s.shape) == 2)
    -446        b_m_new = tl.maximum(b_m, tl.max(b_s, -1))
    +
    423@triton.jit
    +424def _attn_fwd_inner(b_o, b_l, b_m, b_q,
    +425                    p_kT, p_v,
    +426                    scale,
    +427                    BLOCK_Q: tl.constexpr,
    +428                    d_head: tl.constexpr,
    +429                    BLOCK_K: tl.constexpr,
    +430                    offs_i, offs_j,
    +431                    j,
    +432                    steps,
    +433                    MASK: tl.constexpr,
    +434                    q_seq_len: tl.constexpr,
    +435                    kv_seq_len: tl.constexpr
    +436                    ):
    +437    tl.static_assert(BLOCK_Q % BLOCK_K == 0)
    @@ -890,11 +877,12 @@ -

    +

Move $K^\top$ and $V$ pointers

    -
    448        b_p = tl.math.exp2(b_s - b_m_new[:, None])
    +
    440    p_kT = tl.advance(p_kT, (0, j))
    +441    p_v = tl.advance(p_v, (j, 0))
    @@ -902,11 +890,11 @@ -

    +

Iterate over $K$, $V$ and update $\tilde{O}$, $l$ and $m$

    -
    450        b_l_new = tl.sum(b_p, -1)
    +
    444    for _ in range(steps):
    @@ -914,11 +902,11 @@ -

    +

Load $K^\top$

    -
    453        b_m_m_new = tl.math.exp2(b_m - b_m_new)
    +
    446        b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
    @@ -926,11 +914,12 @@ -

    +

Compute $\sigma S_{ij} \log_2 e$

    -
    455        b_l = b_l * b_m_m_new + b_l_new
    +
    448        b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
    +449        b_s = b_s * scale
    @@ -938,14 +927,13 @@ -

    +

    Apply causal mask

    -
    458        b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
    -459        b_acc = b_acc * b_m_m_new[:, None]
    -460        b_p = b_p.to(b_q.dtype)
    -461        b_acc += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)
    +
    452        if MASK:
    +453            causal_mask = offs_i[:, None] >= (j + offs_j[None, :])
    +454            b_s = tl.where(causal_mask, b_s, -float("inf"))
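The causal mask built here is the usual lower-triangular comparison between query and key offsets; in plain PyTorch the same predicate looks like (illustrative sizes):

    import torch

    offs_i = torch.arange(4)[:, None]  # query positions within the block
    offs_j = torch.arange(4)[None, :]  # key positions within the block
    print(offs_i >= offs_j)  # lower-triangular True: query i sees keys j <= i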
    @@ -953,11 +941,12 @@ -

update $m$

    +

Mask out if the block is beyond the end of the key sequence

    -
    464        b_m = b_m_new
    +
    457        j_mask = (j + offs_j) < kv_seq_len
    +458        b_s = tl.where(j_mask[None, :], b_s, -float("inf"))
    @@ -965,17 +954,11 @@ -

    Move pointers

    +

    -
    467        j += BLOCK_K
    -468        p_v = tl.advance(p_v, (BLOCK_K, 0))
    -469        p_kT = tl.advance(p_kT, (0, BLOCK_K))
    -470
    -471    tl.static_assert(b_acc.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
    -472
    -473    return b_acc, b_l, b_m
    +
    461        b_m_new = tl.maximum(b_m, tl.max(b_s, -1))
    @@ -983,18 +966,11 @@ - +

    +
    -
    476@triton.jit
    -477def _attn_bwd_d(t_o, t_do,
    -478                t_pdp,
    -479                BLOCK_Q: tl.constexpr, d_head: tl.constexpr,
    -480                q_seq_len: tl.constexpr,
    -481                n_groups: tl.constexpr,
    -482                ):
    -483    i = tl.program_id(0) * BLOCK_Q
    -484    z = tl.program_id(1)
    +
    463        b_p = tl.math.exp2(b_s - b_m_new[:, None])
    @@ -1002,57 +978,23 @@ -

    Create block pointers

    +

    -
    487    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
    -488                            (n_groups, q_seq_len, d_head),
    -489                            (q_seq_len * d_head, d_head, 1),
    -490                            (0, i, 0),
    -491                            (n_groups, BLOCK_Q, d_head),
    -492                            (2, 1, 0))
    -493    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
    -494                             (n_groups, q_seq_len, d_head),
    -495                             (q_seq_len * d_head, d_head, 1),
    -496                             (0, i, 0),
    -497                             (n_groups, BLOCK_Q, d_head),
    -498                             (2, 1, 0))
    -499    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
    -500                              (n_groups, q_seq_len),
    -501                              (q_seq_len, 1),
    -502                              (0, i),
    -503                              (n_groups, BLOCK_Q),
    -504                              (1, 0))
    -505
    -506    o = tl.load(p_o, boundary_check=(1,), padding_option="zero")
    -507    do = tl.load(p_do, boundary_check=(1,), padding_option="zero").to(HI_PRES_TL)
    -508    d = tl.sum(o * do, axis=-1)
    -509    tl.store(p_pdp, d, boundary_check=(1,))
    +
    466        b_l_new = tl.sum(b_p, -1)
    -
    +
    -

    Loop along m query; n % m == 0

    +

    -
    512@triton.autotune(_get_autotune_configs(inner_loop='query'),
    -513                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
    -514@triton.jit
    -515def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
    -516                   t_do,
    -517                   t_dk, t_dv,
    -518                   t_lse, t_pdp,
    -519                   q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
    -520                   n_groups: tl.constexpr, d_head: tl.constexpr,
    -521                   is_causal: tl.constexpr,
    -522                   BLOCK_Q: tl.constexpr,
    -523                   BLOCK_K: tl.constexpr,
    -524                   ):
    +
    468        b_m_m_new = tl.math.exp2(b_m - b_m_new)
    @@ -1060,40 +1002,11 @@ -

    K is already multiplied by scale

    +

    -
    529    j = tl.program_id(0) * BLOCK_K
    -530    z = tl.program_id(1)
    -531
    -532    p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
    -533                            (kv_seq_len, d_head),
    -534                            (d_head, 1),
    -535                            (j, 0),
    -536                            (BLOCK_K, d_head),
    -537                            (1, 0))
    -538    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
    -539                            (kv_seq_len, d_head),
    -540                            (d_head, 1),
    -541                            (j, 0),
    -542                            (BLOCK_K, d_head),
    -543                            (1, 0))
    -544    p_dk = tl.make_block_ptr(t_dk + z * kv_seq_len * d_head,
    -545                             (kv_seq_len, d_head),
    -546                             (d_head, 1),
    -547                             (j, 0),
    -548                             (BLOCK_K, d_head),
    -549                             (1, 0))
    -550    p_dv = tl.make_block_ptr(t_dv + z * kv_seq_len * d_head,
    -551                             (kv_seq_len, d_head),
    -552                             (d_head, 1),
    -553                             (j, 0),
    -554                             (BLOCK_K, d_head),
    -555                             (1, 0))
    -556
    -557    b_dv = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
    -558    b_dk = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
    +
    470        b_l = b_l * b_m_m_new + b_l_new
    @@ -1101,12 +1014,14 @@ -

    load K and V: they stay in SRAM throughout the inner loop.

    +

    -
    561    b_k = tl.load(p_k, boundary_check=(0,), padding_option="zero")
    -562    b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
    +
    473        b_o = b_o * b_m_m_new[:, None]
    +474        b_p = b_p.to(b_q.dtype) # TODO
    +475        b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
    +476        b_o += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)
    @@ -1114,11 +1029,11 @@ -

Iterate through the queries that attend to the same keys

    +

    -
    565    for g in range(n_groups):
    +
    479        b_m = b_m_new
    @@ -1126,35 +1041,17 @@ -

    Create block pointers

    +

    Move pointers

    -
    567        p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
    -568                                 (d_head, q_seq_len),
    -569                                 (1, d_head),
    -570                                 (0, 0),
    -571                                 (d_head, BLOCK_Q),
    -572                                 (0, 1))
    -573
    -574        p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
    -575                                 (q_seq_len, d_head),
    -576                                 (d_head, 1),
    -577                                 (0, 0),
    -578                                 (BLOCK_Q, d_head),
    -579                                 (1, 0))
    -580        p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
    -581                                  (q_seq_len,),
    -582                                  (1,),
    -583                                  (0,),
    -584                                  (BLOCK_Q,),
    -585                                  (0,))
    -586        p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
    -587                                  (q_seq_len,),
    -588                                  (1,),
    -589                                  (0,),
    -590                                  (BLOCK_Q,),
    -591                                  (0,))
    +
    482        j += BLOCK_K
    +483        p_v = tl.advance(p_v, (BLOCK_K, 0))
    +484        p_kT = tl.advance(p_kT, (0, BLOCK_K))
    +485
    +486    tl.static_assert(b_o.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
    +487
    +488    return b_o, b_l, b_m
    @@ -1162,11 +1059,18 @@ -

    - +
    -
    +
    491@triton.jit
    +492def _attn_bwd_d(t_o, t_do,
    +493                t_pdp,
    +494                BLOCK_Q: tl.constexpr, d_head: tl.constexpr,
    +495                q_seq_len: tl.constexpr,
    +496                n_groups: tl.constexpr,
    +497                ):
    +498    i = tl.program_id(0) * BLOCK_Q
    +499    z = tl.program_id(1)
    @@ -1174,11 +1078,28 @@ -

Compute $dK$ and $dV$ along the masked blocks near the diagonal. A smaller block size of MASK_BLOCK_Q could be used because there is a little extra computation.

    +

    Create block pointers

    -
    599        if is_causal:
    +
    502    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
    +503                            (n_groups, q_seq_len, d_head),
    +504                            (q_seq_len * d_head, d_head, 1),
    +505                            (0, i, 0),
    +506                            (n_groups, BLOCK_Q, d_head),
    +507                            (2, 1, 0))
    +508    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
    +509                             (n_groups, q_seq_len, d_head),
    +510                             (q_seq_len * d_head, d_head, 1),
    +511                             (0, i, 0),
    +512                             (n_groups, BLOCK_Q, d_head),
    +513                             (2, 1, 0))
    +514    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
    +515                              (n_groups, q_seq_len),
    +516                              (q_seq_len, 1),
    +517                              (0, i),
    +518                              (n_groups, BLOCK_Q),
    +519                              (1, 0))
    @@ -1186,14 +1107,11 @@ -

    loop along m

    +

Load $O$

    -
    601            b_dk, b_dv = _attn_bwd_dkdv_inner(
    -602                b_dk, b_dv,
    -603                p_qT, b_k, b_v, p_do,
    -604                p_lse, p_pdp,
    +
    522    o = tl.load(p_o, boundary_check=(1,), padding_option="zero")
    @@ -1201,18 +1119,11 @@ -

    You can use a smaller BLOCK_Q if BLOCK_K is not divisible by BLOCK_Q

    +

Load $dO$

    -
    606                BLOCK_Q, BLOCK_K,
    -607                d_head,
    -608                j=j, i=j,
    -609                steps=BLOCK_K // BLOCK_Q,
    -610                MASK=True,
    -611                q_seq_len=q_seq_len,
    -612                kv_seq_len=kv_seq_len,
    -613            )
    +
    524    do = tl.load(p_do, boundary_check=(1,), padding_option="zero").to(HI_PRES_TL)
    @@ -1220,7 +1131,178 @@ -

Compute $dK$ and $dV$ for non-masked blocks.

    +

Calculate $D_i = O_i \cdot dO_i$

    + +
    +
    +
    526    d = tl.sum(o * do, axis=-1)
    +
    + +
    +
    + +

Save $D_i$

    + +
    +
    +
    528    tl.store(p_pdp, d, boundary_check=(1,))
    +
    +
    +
    +
    + +

Compute $dK$ and $dV$ for a block of keys by iterating over the queries

    + +
    +
    +
    531@triton.autotune(_get_autotune_configs(inner_loop='query'),
    +532                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
    +533@triton.jit
    +534def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
    +535                   t_do,
    +536                   t_dk, t_dv,
    +537                   t_lse, t_pdp,
    +538                   q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
    +539                   n_groups: tl.constexpr, d_head: tl.constexpr,
    +540                   is_causal: tl.constexpr,
    +541                   BLOCK_Q: tl.constexpr,
    +542                   BLOCK_K: tl.constexpr,
    +543                   ):
    +
    +
    +
    +
    + + +
    +
    +
    548    j = tl.program_id(0) * BLOCK_K
    +549    z = tl.program_id(1)
    +
    +
    +
    +
    + +

    Create block pointers

    + +
    +
    +
    552    p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
    +553                            (kv_seq_len, d_head),
    +554                            (d_head, 1),
    +555                            (j, 0),
    +556                            (BLOCK_K, d_head),
    +557                            (1, 0))
    +558    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
    +559                            (kv_seq_len, d_head),
    +560                            (d_head, 1),
    +561                            (j, 0),
    +562                            (BLOCK_K, d_head),
    +563                            (1, 0))
    +564    p_dk = tl.make_block_ptr(t_dk + z * kv_seq_len * d_head,
    +565                             (kv_seq_len, d_head),
    +566                             (d_head, 1),
    +567                             (j, 0),
    +568                             (BLOCK_K, d_head),
    +569                             (1, 0))
    +570    p_dv = tl.make_block_ptr(t_dv + z * kv_seq_len * d_head,
    +571                             (kv_seq_len, d_head),
    +572                             (d_head, 1),
    +573                             (j, 0),
    +574                             (BLOCK_K, d_head),
    +575                             (1, 0))
    +
    +
    +
    +
    + +

Initialize $dK$ and $dV$

    + +
    +
    +
    578    b_dk = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
    +579    b_dv = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
    +
    +
    +
    +
    + +

Load $K$ and $V$ outside the loop.

    + +
    +
    +
    582    b_k = tl.load(p_k, boundary_check=(0,), padding_option="zero")
    +583    b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
    +
    +
    +
    +
    + +

    Iterate through queries in GQA

    + +
    +
    +
    586    for g in range(n_groups):
    +
    +
    +
    +
    + +

    Create block pointers

    + +
    +
    +
    588        p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
    +589                                 (d_head, q_seq_len),
    +590                                 (1, d_head),
    +591                                 (0, 0),
    +592                                 (d_head, BLOCK_Q),
    +593                                 (0, 1))
    +594
    +595        p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
    +596                                 (q_seq_len, d_head),
    +597                                 (d_head, 1),
    +598                                 (0, 0),
    +599                                 (BLOCK_Q, d_head),
    +600                                 (1, 0))
    +601        p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
    +602                                  (q_seq_len,),
    +603                                  (1,),
    +604                                  (0,),
    +605                                  (BLOCK_Q,),
    +606                                  (0,))
    +607        p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
    +608                                  (q_seq_len,),
    +609                                  (1,),
    +610                                  (0,),
    +611                                  (BLOCK_Q,),
    +612                                  (0,))
    +613
    +614        if is_causal:
    +
    +
    +
    +
    + +

    Inner loop at the diagonal block

@@ -1230,147 +1312,12 @@
619                 p_lse, p_pdp,
620                 BLOCK_Q, BLOCK_K,
621                 d_head,
-622                j=j, i=j + BLOCK_K,
-623                steps=tl.cdiv((q_seq_len - (j + BLOCK_K)), BLOCK_Q),
-624                MASK=False,
+622                j=j, i=j,
+623                steps=BLOCK_K // BLOCK_Q,
+624                MASK=True,
625                 q_seq_len=q_seq_len,
-626                kv_seq_len=kv_seq_len
-627            )
-628        else:
-629            b_dk, b_dv = _attn_bwd_dkdv_inner(
-630                b_dk, b_dv,
-631                p_qT, b_k, b_v, p_do,
-632                p_lse, p_pdp,
-633                BLOCK_Q, BLOCK_K,
-634                d_head,
-635                j=j, i=tl.full([], 0, tl.int32),
-636                steps=tl.cdiv(q_seq_len, BLOCK_Q),
-637                MASK=False,
-638                q_seq_len=q_seq_len,
-639                kv_seq_len=kv_seq_len
-640            )
    -
    - -
    -
    - -

Save $dV$

    - -
    -
    -
    643    tl.store(p_dv, b_dv.to(t_dv.type.element_ty), boundary_check=(0,))
    -
    -
    -
    -
    - -

Since we used $\hat{K} = \sigma K \log_2 e$ where $K$ are the original keys, we multiply by the scale again to get the gradient on the original keys.

    - -
    -
    -
    647    b_dk *= sm_scale
    -
    -
    -
    -
    - -

Save $dK$

    - -
    -
    -
    650    tl.store(p_dk, b_dk.to(t_dk.type.element_ty), boundary_check=(0,))
    -
    -
    -
    -
    - -

    Inner loop along m query

    - -
    -
    -
    653@triton.jit
    -654def _attn_bwd_dkdv_inner(b_dk, b_dv,
    -655                         p_qT, b_k, b_v, p_do,
    -656                         p_lse, p_pdp,
    -657                         BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
    -658                         d_head: tl.constexpr,
    -659                         j, i, steps,
    -660                         MASK: tl.constexpr,
    -661                         q_seq_len: tl.constexpr,
    -662                         kv_seq_len: tl.constexpr):
    -
    -
    -
    -
    - -

    To apply the mask

    - -
    -
    -
    666    tl.static_assert(BLOCK_K % BLOCK_Q == 0)
    -
    -
    -
    -
    - -

    Offsets for mask computation

    - -
    -
    -
    669    offs_i = i + tl.arange(0, BLOCK_Q)
    -670    i_mask = offs_i < q_seq_len
    -671    offs_j = j + tl.arange(0, BLOCK_K)
    -
    -
    -
    -
    - -

    Pointers

    - -
    -
    -
    674    p_qT = tl.advance(p_qT, (0, i))
    -675    p_do = tl.advance(p_do, (i, 0))
    -676    p_lse = tl.advance(p_lse, (i,))
    -677    p_pdp = tl.advance(p_pdp, (i,))
    -
    -
    -
    -
    - -

    Loop

    - -
    -
    -
    680    for _ in range(steps):
    -
    -
    -
    -
    - -

Load $Q^\top$

    - -
    -
    -
    682        b_qT = tl.load(p_qT, boundary_check=(1,), padding_option="zero")
+626                kv_seq_len=kv_seq_len,
+627            )
    @@ -1378,11 +1325,23 @@ -

    +

Inner loop on queries after the diagonal

    -
    685        b_m = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
    +
    630            b_dk, b_dv = _attn_bwd_dkdv_inner(
    +631                b_dk, b_dv,
    +632                p_qT, b_k, b_v, p_do,
    +633                p_lse, p_pdp,
    +634                BLOCK_Q, BLOCK_K,
    +635                d_head,
    +636                j=j, i=j + BLOCK_K,
    +637                steps=tl.cdiv((q_seq_len - (j + BLOCK_K)), BLOCK_Q),
    +638                MASK=False,
    +639                q_seq_len=q_seq_len,
    +640                kv_seq_len=kv_seq_len
    +641            )
    +642        else:
    @@ -1390,12 +1349,22 @@ -

Note that $K$ is already multiplied by the softmax scale. It is also divided by $\ln 2$ so we can use $2^x$ instead of $e^x$.

    +

    Iterate through all queries

    -
    690        b_qkT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)
    -691        b_pT = tl.math.exp2(b_qkT - b_m[None, :])
    +
    644            b_dk, b_dv = _attn_bwd_dkdv_inner(
    +645                b_dk, b_dv,
    +646                p_qT, b_k, b_v, p_do,
    +647                p_lse, p_pdp,
    +648                BLOCK_Q, BLOCK_K,
    +649                d_head,
    +650                j=j, i=tl.full([], 0, tl.int32),
    +651                steps=tl.cdiv(q_seq_len, BLOCK_Q),
    +652                MASK=False,
    +653                q_seq_len=q_seq_len,
    +654                kv_seq_len=kv_seq_len
    +655            )
    @@ -1403,15 +1372,11 @@ -

    Autoregressive masking.

    +

Save $dV$

    -
    694        if MASK:
    -695            mask = (offs_i[None, :] >= offs_j[:, None])
    -696            b_pT = tl.where(mask, b_pT, 0.0)
    -697
    -698        b_pT = tl.where(i_mask[None, :], b_pT, 0.0)
    +
    658    tl.store(p_dv, b_dv.to(t_dv.type.element_ty), boundary_check=(0,))
    @@ -1419,12 +1384,12 @@ -

    +

b_dk accumulated gradients for the scaled keys, so multiply by the scale again to get the gradient on the original keys

    -
    701        b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
    -702        b_dv += tl.dot(b_pT.to(b_do.dtype), b_do, out_dtype=HI_PRES_TL)
    +
    661    b_dk *= sm_scale
    @@ -1432,23 +1397,32 @@ -

    +

Save $dK$

    -
    705        b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
    +
    664    tl.store(p_dk, b_dk.to(t_dk.type.element_ty), boundary_check=(0,))
    -
    +
    -

    +

    Inner loop along m query

    -
    707        b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)
    +
    667@triton.jit
    +668def _attn_bwd_dkdv_inner(b_dk, b_dv,
    +669                         p_qT, b_k, b_v, p_do,
    +670                         p_lse, p_pdp,
    +671                         BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
    +672                         d_head: tl.constexpr,
    +673                         j, i, steps,
    +674                         MASK: tl.constexpr,
    +675                         q_seq_len: tl.constexpr,
    +676                         kv_seq_len: tl.constexpr):
    @@ -1456,11 +1430,11 @@ -

    +

    To apply the mask

    -
    709        b_dsT = b_pT * (b_dpT - b_pdp[None, :])
    +
    680    tl.static_assert(BLOCK_K % BLOCK_Q == 0)
    @@ -1468,11 +1442,13 @@ -

    +

    Offsets for mask computation

    -
    711        b_dk += tl.dot(b_dsT.to(b_qT.dtype), tl.trans(b_qT), out_dtype=HI_PRES_TL)
    +
    683    offs_i = i + tl.arange(0, BLOCK_Q)
    +684    i_mask = offs_i < q_seq_len
    +685    offs_j = j + tl.arange(0, BLOCK_K)
    @@ -1480,15 +1456,14 @@ -

    Increment pointers.

    +

    Move the pointers

    -
    714        offs_i += BLOCK_Q
    -715        p_lse = tl.advance(p_lse, (BLOCK_Q,))
    -716        p_pdp = tl.advance(p_pdp, (BLOCK_Q,))
    -717        p_qT = tl.advance(p_qT, (0, BLOCK_Q))
    -718        p_do = tl.advance(p_do, (BLOCK_Q, 0))
    +
    688    p_qT = tl.advance(p_qT, (0, i))
    +689    p_do = tl.advance(p_do, (i, 0))
    +690    p_lse = tl.advance(p_lse, (i,))
    +691    p_pdp = tl.advance(p_pdp, (i,))
    @@ -1496,11 +1471,11 @@ -

Return accumulated $dK$ and $dV$

    +

Iterate over the query blocks

    -
    721    return b_dk, b_dv
    +
    694    for _ in range(steps):
    @@ -1508,21 +1483,11 @@ - +

Load $Q^\top$

    +
    -
    724@triton.autotune(_get_autotune_configs(inner_loop='key'),
    -725                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
    -726@triton.jit
    -727def _attn_bwd_dq(t_q, t_k, t_v, t_do,
    -728                 t_dq,
    -729                 t_lse, t_pdp,
    -730                 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
    -731                 n_groups: tl.constexpr, d_head: tl.constexpr,
    -732                 is_causal: tl.constexpr,
    -733                 BLOCK_Q: tl.constexpr,
    -734                 BLOCK_K: tl.constexpr,
    -735                 ):
    +
    696        b_qT = tl.load(p_qT, boundary_check=(1,), padding_option="zero")
    @@ -1530,15 +1495,11 @@ -

    +

    -
    737    LN2: tl.constexpr = 0.6931471824645996  # type: ignore
    -738
    -739    i = tl.program_id(0) * BLOCK_Q
    -740    z = tl.program_id(1) // n_groups
    -741    g = tl.program_id(1) % n_groups
    +
    699        b_l = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
    @@ -1546,60 +1507,11 @@ -

    Create block pointers

    +

    -
    744    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
    -745                            (q_seq_len, d_head),
    -746                            (d_head, 1),
    -747                            (i, 0),
    -748                            (BLOCK_Q, d_head),
    -749                            (1, 0))
    -750    p_dq = tl.make_block_ptr(t_dq + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
    -751                             (q_seq_len, d_head),
    -752                             (d_head, 1),
    -753                             (i, 0),
    -754                             (BLOCK_Q, d_head),
    -755                             (1, 0))
    -756    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
    -757                             (q_seq_len, d_head),
    -758                             (d_head, 1),
    -759                             (i, 0),
    -760                             (BLOCK_Q, d_head),
    -761                             (1, 0))
    -762    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
    -763                             (d_head, kv_seq_len),
    -764                             (1, d_head),
    -765                             (0, 0),
    -766                             (d_head, BLOCK_K),
    -767                             (0, 1))
    -768    p_vT = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
    -769                             (d_head, kv_seq_len),
    -770                             (1, d_head),
    -771                             (0, 0),
    -772                             (d_head, BLOCK_K),
    -773                             (0, 1))
    -774    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
    -775                              (q_seq_len,),
    -776                              (1,),
    -777                              (i,),
    -778                              (BLOCK_Q,),
    -779                              (0,))
    -780    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
    -781                              (q_seq_len,),
    -782                              (1,),
    -783                              (i,),
    -784                              (BLOCK_Q,),
    -785                              (0,))
    -786
    -787    b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
    -788    b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
    -789    b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
    -790
    -791    b_dq = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
    -792
    -793    b_lse = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
    +
    702        b_sT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)
    @@ -1607,11 +1519,11 @@ -

    +

    -
    797    if is_causal:
    +
    711        b_pT = tl.math.exp2(b_sT - b_l[None, :])
    @@ -1619,19 +1531,13 @@ -

Compute $dQ$ for masked (diagonal) blocks.

    +

    Autoregressive masking.

    -
    799        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
    -800                                  b_do, b_lse, b_pdp,
    -801                                  BLOCK_Q, BLOCK_K,
    -802                                  i=i, j=i,
    -803                                  steps=BLOCK_Q // BLOCK_K,
    -804                                  MASK=True,
    -805                                  q_seq_len=q_seq_len,
    -806                                  kv_seq_len=kv_seq_len
    -807                                  )
    +
    714        if MASK:
    +715            mask = (offs_i[None, :] >= offs_j[:, None])
    +716            b_pT = tl.where(mask, b_pT, 0.0)
    @@ -1639,29 +1545,11 @@ -

    Other blocks

    +

Mask out if the block is beyond the end of the key sequence

    -
    810        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
    -811                                  b_do, b_lse, b_pdp,
    -812                                  BLOCK_Q, BLOCK_K,
    -813                                  i=i, j=tl.full([], 0, tl.int32),  # type: ignore
    -814                                  steps=i // BLOCK_K,
    -815                                  MASK=False,
    -816                                  q_seq_len=q_seq_len,
    -817                                  kv_seq_len=kv_seq_len
    -818                                  )
    -819    else:
    -820        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
    -821                                  b_do, b_lse, b_pdp,
    -822                                  BLOCK_Q, BLOCK_K,
    -823                                  i=i, j=tl.full([], 0, tl.int32),  # type: ignore
    -824                                  steps=tl.cdiv(kv_seq_len, BLOCK_K),
    -825                                  MASK=False,
    -826                                  q_seq_len=q_seq_len,
    -827                                  kv_seq_len=kv_seq_len
    -828                                  )
    +
    719        b_pT = tl.where(i_mask[None, :], b_pT, 0.0)
    @@ -1669,11 +1557,12 @@ -

Since $\hat{K}$ was scaled by $\frac{1}{\ln 2}$, and $dQ$ got this factor in through the computed $dS$, we need to reverse it.

    +

    -
    832    b_dq *= LN2
    +
    722        b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
    +723        b_dv += tl.dot(b_pT.to(b_do.dtype), b_do, out_dtype=HI_PRES_TL)
    @@ -1681,30 +1570,23 @@ -

Save $dQ$

    +

    -
    835    tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))
    +
    726        b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
    -
    +
    -

    Inner loop over n key

    +

    -
    838@triton.jit
    -839def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
    -840                       b_do, b_lse, b_pdp,
    -841                       BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
    -842                       i, j, steps,
    -843                       MASK: tl.constexpr,
    -844                       q_seq_len: tl.constexpr,
    -845                       kv_seq_len: tl.constexpr):
    +
    728        b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)
    @@ -1712,20 +1594,11 @@ - +

    +
    -
    847    offs_i = i + tl.arange(0, BLOCK_Q)
    -848    offs_j = tl.arange(0, BLOCK_K)
    -849
    -850    p_kT = tl.advance(p_kT, (0, j))
    -851    p_vT = tl.advance(p_vT, (0, j))
    -852
    -853    tl.static_assert(BLOCK_Q % BLOCK_K == 0, 'BLOCK_Q must be divisible by BLOCK_K')
    -854
    -855    for _ in range(steps):
    -856        current_j = j + offs_j
    -857        j_mask = current_j < kv_seq_len
    +
    730        b_dsT = b_pT * (b_dpT - b_pdp[None, :])
    @@ -1733,14 +1606,11 @@ -

Note that $k$ is already multiplied by the softmax scale. It is also divided by $\ln 2$, so we can use $2^x$ instead of $e^x$.

    +

    -
    862        b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
    -863        b_vT = tl.load(p_vT, boundary_check=(1,), padding_option="zero")
    -864        b_qk = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
    -865        b_p = tl.math.exp2(b_qk - b_lse[:, None])
    +
    732        b_dk += tl.dot(b_dsT.to(b_qT.dtype), tl.trans(b_qT), out_dtype=HI_PRES_TL)
    @@ -1748,15 +1618,15 @@ -

    Autoregressive masking.

    +

    Increment pointers.

    -
    868        if MASK:
    -869            causal_mask = (offs_i[:, None] >= current_j[None, :])
    -870            b_p = tl.where(causal_mask, b_p, 0.0)
    -871
    -872        b_p = tl.where(j_mask[None, :], b_p, 0.0)
    +
    735        offs_i += BLOCK_Q
    +736        p_lse = tl.advance(p_lse, (BLOCK_Q,))
    +737        p_pdp = tl.advance(p_pdp, (BLOCK_Q,))
    +738        p_qT = tl.advance(p_qT, (0, BLOCK_Q))
    +739        p_do = tl.advance(p_do, (BLOCK_Q, 0))
    @@ -1764,11 +1634,11 @@ -

    +

Return accumulated $dk$ and $dv$

    -
    +
    742    return b_dk, b_dv
    @@ -1776,11 +1646,21 @@ -

    - +
    -
    877        b_dp = tl.dot(b_do, b_vT, out_dtype=HI_PRES_TL).to(HI_PRES_TL)
    +
    745@triton.autotune(_get_autotune_configs(inner_loop='key'),
    +746                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
    +747@triton.jit
    +748def _attn_bwd_dq(t_q, t_k, t_v, t_do,
    +749                 t_dq,
    +750                 t_lse, t_pdp,
    +751                 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
    +752                 n_groups: tl.constexpr, d_head: tl.constexpr,
    +753                 is_causal: tl.constexpr,
    +754                 BLOCK_Q: tl.constexpr,
    +755                 BLOCK_K: tl.constexpr,
    +756                 ):
    @@ -1788,11 +1668,15 @@ -

    +

    -
    879        b_ds = b_p * (b_dp - b_pdp[:, None])
    +
    758    LN2: tl.constexpr = 0.6931471824645996  # type: ignore
    +759
    +760    i = tl.program_id(0) * BLOCK_Q
    +761    z = tl.program_id(1) // n_groups
    +762    g = tl.program_id(1) % n_groups
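
The second grid axis is flattened over (batch × kv head, query group) pairs. A minimal plain-Python sketch of the same decoding, with hypothetical sizes:

    # Hypothetical: 2 flattened batches, n_groups = 3 query heads per kv head.
    n_groups = 3
    for pid in range(2 * n_groups):
        z, g = divmod(pid, n_groups)  # z = pid // n_groups, g = pid % n_groups
        print(pid, "-> batch", z, "group", g)
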
    @@ -1800,13 +1684,52 @@ -

    +

    Create block pointers

    -
    881        b_dq += tl.dot(b_ds.to(b_kT.dtype),
    -882                       tl.trans(b_kT),
    -883                       out_dtype=HI_PRES_TL)
    +
    765    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
    +766                            (q_seq_len, d_head),
    +767                            (d_head, 1),
    +768                            (i, 0),
    +769                            (BLOCK_Q, d_head),
    +770                            (1, 0))
    +771    p_dq = tl.make_block_ptr(t_dq + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
    +772                             (q_seq_len, d_head),
    +773                             (d_head, 1),
    +774                             (i, 0),
    +775                             (BLOCK_Q, d_head),
    +776                             (1, 0))
    +777    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
    +778                             (q_seq_len, d_head),
    +779                             (d_head, 1),
    +780                             (i, 0),
    +781                             (BLOCK_Q, d_head),
    +782                             (1, 0))
    +783    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
    +784                             (d_head, kv_seq_len),
    +785                             (1, d_head),
    +786                             (0, 0),
    +787                             (d_head, BLOCK_K),
    +788                             (0, 1))
    +789    p_vT = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
    +790                             (d_head, kv_seq_len),
    +791                             (1, d_head),
    +792                             (0, 0),
    +793                             (d_head, BLOCK_K),
    +794                             (0, 1))
    +795    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
    +796                              (q_seq_len,),
    +797                              (1,),
    +798                              (i,),
    +799                              (BLOCK_Q,),
    +800                              (0,))
    +801    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
    +802                              (q_seq_len,),
    +803                              (1,),
    +804                              (i,),
    +805                              (BLOCK_Q,),
    +806                              (0,))
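
As a reminder of the `tl.make_block_ptr` argument order (base, shape, strides, offsets, block_shape, order), here is a minimal standalone sketch using the same pattern; the kernel and names are illustrative, not part of this module:

    import triton
    import triton.language as tl

    @triton.jit
    def _copy_rows(t_src, t_dst, seq_len: tl.constexpr, d: tl.constexpr, BLOCK: tl.constexpr):
        i = tl.program_id(0) * BLOCK
        # make_block_ptr(base, shape, strides, offsets, block_shape, order)
        p_src = tl.make_block_ptr(t_src, (seq_len, d), (d, 1), (i, 0), (BLOCK, d), (1, 0))
        p_dst = tl.make_block_ptr(t_dst, (seq_len, d), (d, 1), (i, 0), (BLOCK, d), (1, 0))
        b = tl.load(p_src, boundary_check=(0,), padding_option="zero")
        tl.store(p_dst, b, boundary_check=(0,))
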
    @@ -1814,13 +1737,14 @@ -

    Increment pointers.

    +

Load $q$, $dO$, $pdp$ and $lse$ outside the loop

    -
    886        j += BLOCK_K
    -887        p_kT = tl.advance(p_kT, (0, BLOCK_K))
    -888        p_vT = tl.advance(p_vT, (0, BLOCK_K))
    +
    809    b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
    +810    b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
    +811    b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
    +812    b_lse = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
    @@ -1828,11 +1752,256 @@ -

Return accumulated $dq$

    +

Initialize $dq$

    -
    891    return b_dq
    +
    815    b_dq = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
    +
    + +
    +
    + +

    + +
    +
    +
    819    if is_causal:
    +
    +
    +
    +
    + +

Compute $dq$ for masked (diagonal) blocks.

    + +
    +
    +
    821        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
    +822                                  b_do, b_lse, b_pdp,
    +823                                  BLOCK_Q, BLOCK_K,
    +824                                  i=i, j=i,
    +825                                  steps=BLOCK_Q // BLOCK_K,
    +826                                  MASK=True,
    +827                                  q_seq_len=q_seq_len,
    +828                                  kv_seq_len=kv_seq_len
    +829                                  )
    +
    +
    +
    +
    + +

Compute $dq$ for other blocks

    + +
    +
    +
    832        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
    +833                                  b_do, b_lse, b_pdp,
    +834                                  BLOCK_Q, BLOCK_K,
    +835                                  i=i, j=tl.full([], 0, tl.int32),  # type: ignore
    +836                                  steps=i // BLOCK_K,
    +837                                  MASK=False,
    +838                                  q_seq_len=q_seq_len,
    +839                                  kv_seq_len=kv_seq_len
    +840                                  )
    +841    else:
    +
    +
    +
    +
    + +

Iterate through all key blocks

    + +
    +
    +
    843        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
    +844                                  b_do, b_lse, b_pdp,
    +845                                  BLOCK_Q, BLOCK_K,
    +846                                  i=i, j=tl.full([], 0, tl.int32),  # type: ignore
    +847                                  steps=tl.cdiv(kv_seq_len, BLOCK_K),
    +848                                  MASK=False,
    +849                                  q_seq_len=q_seq_len,
    +850                                  kv_seq_len=kv_seq_len
    +851                                  )
    +
    +
    +
    +
    + +

b_dq stores $\log_2 e \cdot dq$, so multiply by $\ln 2$ to get $dq$

    + +
    +
    +
    854    b_dq *= LN2
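
To see where the factor comes from: $k$ carries an extra $\log_2 e$, so the accumulated gradient is scaled by it, and $\ln 2 = 1/\log_2 e$ undoes the scaling (here $K$ already includes the softmax scale):

$$\texttt{b\_dq} = dS\,\big(\log_2 e \cdot K\big) = \log_2 e \cdot dq \quad\Longrightarrow\quad dq = \ln 2 \cdot \texttt{b\_dq}$$
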
    +
    +
    +
    +
    + +

Save $dq$

    + +
    +
    +
    857    tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))
    +
    +
    +
    +
    + +

Inner loop over key blocks to compute $dq$

    + +
    +
    +
    860@triton.jit
    +861def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
    +862                       b_do, b_lse, b_pdp,
    +863                       BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
    +864                       i, j, steps,
    +865                       MASK: tl.constexpr,
    +866                       q_seq_len: tl.constexpr,
    +867                       kv_seq_len: tl.constexpr):
    +
    +
    +
    +
    + + +
    +
    +
    869    offs_i = i + tl.arange(0, BLOCK_Q)
    +870    offs_j = tl.arange(0, BLOCK_K)
    +871
    +872    p_kT = tl.advance(p_kT, (0, j))
    +873    p_vT = tl.advance(p_vT, (0, j))
    +874
    +875    tl.static_assert(BLOCK_Q % BLOCK_K == 0, 'BLOCK_Q must be divisible by BLOCK_K')
    +876
    +877    for _ in range(steps):
    +878        current_j = j + offs_j
    +879        j_mask = current_j < kv_seq_len
    +
    +
    +
    +
    + +

Note that $k$ is already multiplied by the softmax scale. It is also divided by $\ln 2$, so we can use $2^x$ instead of $e^x$.

    + +
    +
    +
    884        b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
    +885        b_vT = tl.load(p_vT, boundary_check=(1,), padding_option="zero")
    +886        b_qk = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
    +887        b_p = tl.math.exp2(b_qk - b_lse[:, None])
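
Since the stored $lse$ is the base-2 log-sum-exp of the pre-scaled scores, this `exp2` reproduces the softmax exactly:

$$P_{ij} = \frac{e^{S_{ij}}}{\sum_k e^{S_{ik}}} = 2^{\,S_{ij}\log_2 e - L_i}, \qquad L_i = \log_2 \sum_k 2^{\,S_{ik}\log_2 e}$$
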
    +
    +
    +
    +
    + +

    Autoregressive masking.

    + +
    +
    +
    890        if MASK:
    +891            causal_mask = (offs_i[:, None] >= current_j[None, :])
    +892            b_p = tl.where(causal_mask, b_p, 0.0)
    +893
    +894        b_p = tl.where(j_mask[None, :], b_p, 0.0)
    +
    +
    +
    +
    + +

    + +
    +
    +
    +
    +
    +
    +
    + +

    + +
    +
    +
    899        b_dp = tl.dot(b_do, b_vT, out_dtype=HI_PRES_TL).to(HI_PRES_TL)
    +
    +
    +
    +
    + +

    + +
    +
    +
    901        b_ds = b_p * (b_dp - b_pdp[:, None])
    +
    +
    +
    +
    + +

    + +
    +
    +
    903        b_dq += tl.dot(b_ds.to(b_kT.dtype),
    +904                       tl.trans(b_kT),
    +905                       out_dtype=HI_PRES_TL)
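
A self-contained PyTorch sketch (hypothetical small sizes, softmax scale omitted) that checks this $dq = \big(P \circ (dP - D)\big)\,K$ update against autograd:

    import torch

    torch.manual_seed(0)
    q = torch.randn(8, 16, dtype=torch.float64, requires_grad=True)
    k = torch.randn(12, 16, dtype=torch.float64)
    v = torch.randn(12, 16, dtype=torch.float64)

    p = torch.softmax(q @ k.T, dim=-1)  # attention weights P
    o = p @ v                           # output O
    do = torch.randn_like(o)            # incoming gradient dO
    o.backward(do)                      # autograd fills q.grad

    dp = do @ v.T                       # dP = dO V^T
    D = (p * dp).sum(-1, keepdim=True)  # D_i = sum_j P_ij dP_ij (the pdp term)
    dq = (p * (dp - D)) @ k             # dq = (P * (dP - D)) K

    assert torch.allclose(q.grad, dq)
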
    +
    +
    +
    +
    + +

    Increment pointers.

    + +
    +
    +
    908        j += BLOCK_K
    +909        p_kT = tl.advance(p_kT, (0, BLOCK_K))
    +910        p_vT = tl.advance(p_vT, (0, BLOCK_K))
    +
    +
    +
    +
    + +

Return accumulated $dq$

    + +
    +
    +
    913    return b_dq