From 1bc2a6980314af1698bea4ce4dadd1b0d167d0f6 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Thu, 31 Jul 2025 09:53:14 +0530
Subject: [PATCH] flash comments

---
 .gitignore                              |    4 +-
 docs/transformers/flash/index.html      | 1976 +++++++++++++----
 labml_nn/transformers/flash/__init__.py |  126 +-
 3 files changed, 1235 insertions(+), 871 deletions(-)

diff --git a/.gitignore b/.gitignore
index 4336d66b..98c6271e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,4 +16,6 @@ html/
 diagrams/
 .comet.config
 settings.md
-labml_app.log
\ No newline at end of file
+labml_app.log
+/extensions
+/.nb_editor
\ No newline at end of file

diff --git a/docs/transformers/flash/index.html b/docs/transformers/flash/index.html
index 7e726113..b3c28739 100644
--- a/docs/transformers/flash/index.html
+++ b/docs/transformers/flash/index.html
@@ -79,9 +79,10 @@
6import triton.language as tl
7
8import torch
-9
-10HI_PRES_TL: tl.constexpr = tl.float32
-11HI_PRES_TORCH: tl.constexpr = torch.float32
+9from typing import Any, Tuple, Optional
+10
+11HI_PRES_TL: tl.constexpr = tl.float32
+12HI_PRES_TORCH: torch.dtype = torch.float32
@@ -92,19 +93,40 @@
-
14class AttentionFunc(torch.autograd.Function):
+
15class AttentionFunc(torch.autograd.Function):
-
+
- +

Group query attention forward pass. Returns the output in shape [batch_size, n_heads, q_seq_len, d_head] +.

+
  • ctx is the torch autograd context (used to stash tensors for the backward pass)
  • q has shape [batch_size, n_heads, q_seq_len, d_head]
  • k has shape [batch_size, k_heads, kv_seq_len, d_head]
  • v has shape [batch_size, k_heads, kv_seq_len, d_head]
  • causal whether to apply causal attention mask
  • sm_scale softmax scale factor
+
-
15    @staticmethod
-16    def forward(ctx, q, k, v, causal, sm_scale):
+
16    @staticmethod
+17    def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, 
+18                causal: bool, sm_scale: float) -> torch.Tensor:
@@ -112,15 +134,13 @@ -

Shape batch size, n_heads, seq, d

- +
-
18        batch_size, n_heads, q_seq_len, d_head = q.shape
-19        k_heads = k.shape[1]
-20        kv_seq_len = k.shape[2]
-21        assert n_heads % k_heads == 0
-22        n_groups = n_heads // k_heads
+
30        batch_size, n_heads, q_seq_len, d_head = q.shape
+31        _, k_heads, kv_seq_len, _ = k.shape
+32        assert n_heads % k_heads == 0
+33        n_groups = n_heads // k_heads
@@ -128,43 +148,12 @@ -

shape constraints

+

Shape constraints

-
25        assert d_head == k.shape[-1] == v.shape[-1]
-26        assert d_head in {16, 32, 64, 128, 256}
-27
-28        q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
-29        k = k.view(batch_size * k_heads, kv_seq_len, d_head)
-30        v = v.view(batch_size * k_heads, kv_seq_len, d_head)
-31
-32        assert q.is_contiguous()
-33        assert k.is_contiguous()
-34        assert v.is_contiguous()
-35
-36        o = torch.empty_like(q)
-37
-38        lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)
-39
-40        grid = lambda args: (triton.cdiv(q_seq_len, args["BLOCK_M"]), batch_size * k_heads * n_groups, 1)
-41        ctx.grid = grid
-42        _attn_fwd[grid](
-43            q, k, v, sm_scale, lse, o,
-44            n_groups=n_groups,
-45            q_seq_len=q_seq_len,
-46            kv_seq_len=kv_seq_len,
-47            d_head=d_head,
-48            is_causal=causal,
-49        )
-50
-51        ctx.save_for_backward(q, k, v, o, lse)
-52        ctx.sm_scale = sm_scale
-53        ctx.n_groups = n_groups
-54        ctx.d_head = d_head
-55        ctx.causal = causal
-56
-57        return o.view(batch_size, n_heads, q_seq_len, d_head)
+
36        assert d_head == k.shape[-1] == v.shape[-1]
+37        assert d_head in {16, 32, 64, 128, 256}
@@ -172,34 +161,13 @@ - +

Reshape the tensors, combining the key-value heads with the batch dimension

+
-
59    @staticmethod
-60    def backward(ctx, do):
-61        n_groups = ctx.n_groups
-62        sm_scale = ctx.sm_scale
-63        causal = ctx.causal
-64        q, k, v, o, lse = ctx.saved_tensors
-65        batch_size, n_heads, q_seq_len, d_head = do.shape
-66        _, kv_seq_len, _ = k.shape
-67        k_heads = n_heads // n_groups
-68
-69        do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
-70
-71        assert do.is_contiguous()
-72        assert k.stride() == v.stride()
-73        assert q.stride() == o.stride() == do.stride()
-74
-75        dq = torch.empty_like(q)
-76        dk = torch.empty_like(k)
-77        dv = torch.empty_like(v)
-78
-79        RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
-80        arg_k = k * (sm_scale * RCP_LN2)
-81        BLOCK_M = 16
-82        assert q_seq_len % BLOCK_M == 0
-83        pre_grid = (q_seq_len // BLOCK_M, batch_size * k_heads)
+
40        q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
+41        k = k.view(batch_size * k_heads, kv_seq_len, d_head)
+42        v = v.view(batch_size * k_heads, kv_seq_len, d_head)
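As a rough illustration (a standalone sketch with made-up shapes, not part of the kernel code), the views above group the query heads that share a key-value head:

import torch

batch_size, k_heads, n_groups, seq_len, d_head = 2, 2, 4, 16, 64
n_heads = k_heads * n_groups

q = torch.randn(batch_size, n_heads, seq_len, d_head)
k = torch.randn(batch_size, k_heads, seq_len, d_head)

# One index of the combined dimension selects a (batch, kv-head) pair;
# the n_groups axis holds the query heads attending to that kv head.
q_grouped = q.view(batch_size * k_heads, n_groups, seq_len, d_head)
k_grouped = k.view(batch_size * k_heads, seq_len, d_head)

# Query head h uses key-value head h // n_groups:
assert torch.equal(q_grouped[1, 2], q[0, 1 * n_groups + 2])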
@@ -207,57 +175,26 @@ -

+

Make sure the tensors are contiguous and the strides are the same

-
85        pdp = torch.empty_like(lse)
-86        _attn_bwd_d[pre_grid](
-87            o, do,
-88            pdp,
-89            BLOCK_M=16,
-90            d_head=d_head,
-91            q_seq_len=q_seq_len,
-92            n_groups=n_groups,
-93            num_stages=1,
-94        )
-95        grid = lambda args: (triton.cdiv(kv_seq_len, args['BLOCK_N']), batch_size * k_heads)
-96        _attn_bwd_dkdv[grid](
-97            q, arg_k, v, sm_scale, do, dk, dv,
-98            lse, pdp,
-99            q_seq_len, kv_seq_len, n_groups, d_head,
-100            is_causal=causal,
-101
-102        )
-103        grid = lambda args: (triton.cdiv(q_seq_len, args["BLOCK_M"]), batch_size * k_heads * n_groups)
-104        _attn_bwd_dq[grid](
-105            q, arg_k, v, do,
-106            dq,
-107            lse, pdp,
-108            q_seq_len, kv_seq_len, n_groups, d_head,
-109            is_causal=causal,
-110        )
-111
-112        dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
-113        dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
-114        dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)
-115
-116        return dq, dk, dv, None, None
-117
-118
-119attention = AttentionFunc.apply
+
45        assert q.is_contiguous()
+46        assert k.is_contiguous()
+47        assert v.is_contiguous()
+48        assert k.stride() == v.stride()
-
+
-

Configs for auto-tuning

+

Tensor for the output

-
122def _get_autotune_configs(inner_loop: str):
+
51        o = torch.empty_like(q)
@@ -265,10 +202,11 @@ - +

Tensor for the log-sum-exp $L_i = \log_2 \sum_j e^{S_{ij}}$

+
-
127    configs = []
+
53        lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)
@@ -276,11 +214,20 @@ -

List possible BLOCK_M and BLOCK_N that satisfy BLOCK_M divisible by BLOCK_N and also try to cover a wide range

+

The forward computation will be parallelized along the batch dimension and the queries in blocks of size BLOCK_M +

-
130    for bm in [64, 128, 256]:
+
56        grid = lambda args: (triton.cdiv(q_seq_len, args["BLOCK_M"]), batch_size * k_heads * n_groups, 1)
+57        _attn_fwd[grid](
+58            q, k, v, sm_scale, lse, o,
+59            n_groups=n_groups,
+60            q_seq_len=q_seq_len,
+61            kv_seq_len=kv_seq_len,
+62            d_head=d_head,
+63            is_causal=causal,
+64        )
@@ -288,29 +235,323 @@ +

Save the reshaped inputs and outputs for the backward pass

+ +
+
+
67        ctx.save_for_backward(q, k, v, o, lse)
+68        ctx.sm_scale = sm_scale
+69        ctx.n_groups = n_groups
+70        ctx.causal = causal
+
+ +
+
+ +

Return the output in shape [batch_size, n_heads, q_seq_len, d_head] +

+ +
+
+
73        return o.view(batch_size, n_heads, q_seq_len, d_head)
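For intuition (and as something to test against), a plain PyTorch sketch of what this forward pass computes is given below. It is an assumed equivalent written for clarity, not the Triton path, and it assumes q_seq_len == kv_seq_len for the causal mask.

def reference_attention(q, k, v, causal, sm_scale):
    # Repeat each key-value head so every query head has a matching key/value
    n_groups = q.shape[1] // k.shape[1]
    k = k.repeat_interleave(n_groups, dim=1)
    v = v.repeat_interleave(n_groups, dim=1)
    # Scaled dot-product scores, computed in float32
    s = torch.einsum('bhid,bhjd->bhij', q.float(), k.float()) * sm_scale
    if causal:
        mask = torch.ones(s.shape[-2], s.shape[-1], dtype=torch.bool, device=s.device).tril()
        s = s.masked_fill(~mask, float('-inf'))
    return (torch.softmax(s, dim=-1) @ v.float()).to(q.dtype)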
+
+
+
+
+ +

The backward pass computes the gradients of the input tensors.

+
  • ctx is the torch autograd context
  • +
  • do + is the gradient tensor of the attention output with shape [batch_size, n_heads, q_seq_len, d_head] +
+ +
+
+
75    @staticmethod
+76    def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
+
+
+
+
+ +

Get saved tensors and attributes

+ +
+
+
85        n_groups = ctx.n_groups
+86        sm_scale = ctx.sm_scale
+87        causal = ctx.causal
+88        q, k, v, o, lse = ctx.saved_tensors
+
+
+
+
+ +

Get shapes

+ +
+
+
91        batch_size, n_heads, q_seq_len, d_head = do.shape
+92        _, kv_seq_len, _ = k.shape
+93        k_heads = n_heads // n_groups
+
+
+
+
+ +

Combine the heads with the batch dimension of the output gradients tensor

+ +
+
+
96        do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
+
+
+
+
+ +

Make sure it's contiguous and the strides are the same

+ +
+
+
99        assert do.is_contiguous()
+100        assert k.stride() == v.stride()
+101        assert q.stride() == o.stride() == do.stride()
+
+
+
+
+ +

Create tensors for input gradients

+ +
+
+
104        dq = torch.empty_like(q)
+105        dk = torch.empty_like(k)
+106        dv = torch.empty_like(v)
+
+
+
+
+ +

$\frac{1}{\ln 2} = \log_2 e$, used to switch from $e^x$ to $2^x$

+ +
+
+
109        RCP_LN2 = 1.4426950408889634
+
+
+
+
+ +

Pre-multiply the keys by the softmax scale and $\frac{1}{\ln 2}$ so the kernels can use $2^x$ in place of $e^x$

+ +
+
+
111        k_scaled = k * (sm_scale * RCP_LN2)
+
+
+
+
+ +

Tensor for the precomputed $D_i = dO_{i:}^\top O_{i:}$

+ +
+
+
113        pdp = torch.empty_like(lse)
+
+
+
+
+ +

We use a fixed BLOCK_M for the $D_i$ pre-computation of the backward pass

+ +
+
+
115        BLOCK_M = 16
+116        assert q_seq_len % BLOCK_M == 0
+
+
+
+
+ +

Compute $D_i = dO_{i:}^\top O_{i:}$

+

This is parallelized along the batch and query in blocks of size BLOCK_M +

+ +
+
+
120        pre_grid = (q_seq_len // BLOCK_M, batch_size * k_heads)
+121        _attn_bwd_d[pre_grid](
+122            o, do,
+123            pdp,
+124            BLOCK_M=16,
+125            d_head=d_head,
+126            q_seq_len=q_seq_len,
+127            n_groups=n_groups,
+128            num_stages=1,
+129        )
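In plain PyTorch this pre-computation is just a row-wise dot product between the output and its gradient (a sketch using the o and do tensors from above):

# D_i = sum_d O[i, d] * dO[i, d], one scalar per query position
pdp_reference = (o.float() * do.float()).sum(dim=-1)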
+
+
+
+
+ +

Compute $dK$ and $dV$

+

This is parallelized along the batch and keys in blocks of size BLOCK_N +

+ +
+
+
133        grid = lambda args: (triton.cdiv(kv_seq_len, args['BLOCK_N']), batch_size * k_heads)
+134        _attn_bwd_dkdv[grid](
+135            q, k_scaled, v, sm_scale, do, dk, dv,
+136            lse, pdp,
+137            q_seq_len, kv_seq_len, n_groups, d_head,
+138            is_causal=causal,
+139
+140        )
+
+
+
+
+ +

Compute $dQ$

+

This is parallelized along the batch and queries in blocks of size BLOCK_M +

+ +
+
+
144        grid = lambda args: (triton.cdiv(q_seq_len, args["BLOCK_M"]), batch_size * k_heads * n_groups)
+145        _attn_bwd_dq[grid](
+146            q, k_scaled, v, do,
+147            dq,
+148            lse, pdp,
+149            q_seq_len, kv_seq_len, n_groups, d_head,
+150            is_causal=causal,
+151        )
+
+
+
+
+ +

Split the combined batch and heads

+ +
+
+
154        dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
+155        dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
+156        dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)
+
+
+
+
+ +

Return the gradients; None for causal and sm_scale, which take no gradient

+ +
+
+
159        return dq, dk, dv, None, None
+160
+161
+162attention = AttentionFunc.apply
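A usage sketch follows; the shapes, dtype and device are illustrative assumptions (d_head must be one of the supported sizes and q_seq_len a multiple of 16 for the backward pass):

q = torch.randn(2, 8, 1024, 64, dtype=torch.float16, device='cuda', requires_grad=True)
k = torch.randn(2, 2, 1024, 64, dtype=torch.float16, device='cuda', requires_grad=True)
v = torch.randn(2, 2, 1024, 64, dtype=torch.float16, device='cuda', requires_grad=True)

out = attention(q, k, v, True, 1.0 / 64 ** 0.5)  # causal=True, sm_scale = 1/sqrt(d_head)
out.sum().backward()                             # runs the Triton backward kernels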
+
+
+
+
+ +

Configs for auto-tuning

+ +
+
+
165def _get_autotune_configs(inner_loop: str) -> list:
+
+
+
+
+ + +
+
+
170    configs = []
+
+
+
+
+ +

List possible BLOCK_M and BLOCK_N that satisfy BLOCK_M divisible by BLOCK_N and also try to cover a wide range

+ +
+
+
173    for bm in [64, 128, 256]:
+
+
+
+
+

We'll try BLOCK_N in 64, 128, 256 such that one of BLOCK_M, BLOCK_N divides the other (which one depends on the inner loop)

-
132        for bn in [64, 128, 256]:
-133            if inner_loop == 'key' and bm % bn != 0:
-134                continue
-135            if inner_loop == 'query' and bn % bm != 0:
-136                continue
-137            for s in [2, 3, 4]:
-138                for w in [4, 8]:
-139                    if bm * bn < 128 * 128 and w == 8:
-140                        continue
-141
-142                    configs.append(triton.Config({'BLOCK_M': bm, 'BLOCK_N': bn}, num_stages=s, num_warps=w))
-143
-144    return configs
+
175        for bn in [64, 128, 256]:
+176            if inner_loop == 'key' and bm % bn != 0:
+177                continue
+178            if inner_loop == 'query' and bn % bm != 0:
+179                continue
+180            for s in [2, 3, 4]:
+181                for w in [4, 8]:
+182                    if bm * bn < 128 * 128 and w == 8:
+183                        continue
+184
+185                    configs.append(triton.Config({'BLOCK_M': bm, 'BLOCK_N': bn}, num_stages=s, num_warps=w))
+186
+187    return configs
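To see which configurations the autotuner will time, the list can simply be printed (an inspection sketch relying on the kwargs, num_stages and num_warps attributes of triton.Config):

for cfg in _get_autotune_configs(inner_loop='key'):
    print(cfg.kwargs['BLOCK_M'], cfg.kwargs['BLOCK_N'], cfg.num_stages, cfg.num_warps)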
-
+
  • t_q query
  • t_k keys
  • t_v values
  • sm_scale softmax scale
  • t_lse log-sum-exp $\log_2 \sum_j e^{S_{ij}}$ (out)
  • t_o output (out)
  • n_groups number of groups in grouped query attention
-
147@triton.autotune(_get_autotune_configs(inner_loop='key'),
-148                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
-149@triton.jit
-150def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
-151              n_groups: tl.constexpr,
-152              q_seq_len: tl.constexpr,
-153              kv_seq_len: tl.constexpr,
-154              d_head: tl.constexpr,
-155              is_causal: tl.constexpr,
-156              BLOCK_M: tl.constexpr,  # q seq len block
-157              BLOCK_N: tl.constexpr,  # k seq len block
-158              ):
-
-
-
-
- - -
-
-
179    start_m = tl.program_id(0)
-180    z = tl.program_id(1) // n_groups
-181    g = tl.program_id(1) % n_groups
-
-
-
-
- -

block pointers

- -
-
-
184    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-185                            (q_seq_len, d_head),
-186                            (d_head, 1),
-187                            (start_m * BLOCK_M, 0),
-188                            (BLOCK_M, d_head),
-189                            (1, 0))
-190    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
-191                            (kv_seq_len, d_head),
-192                            (d_head, 1),
-193                            (0, 0),
-194                            (BLOCK_N, d_head),
-195                            (1, 0))
-196    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
-197                             (d_head, kv_seq_len),
-198                             (1, d_head),
-199                             (0, 0),
-200                             (d_head, BLOCK_N),
-201                             (0, 1))
-202    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-203                            (q_seq_len, d_head),
-204                            (d_head, 1),
-205                            (start_m * BLOCK_M, 0),
-206                            (BLOCK_M, d_head),
-207                            (1, 0))
-208    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
-209                              (q_seq_len,),
-210                              (1,),
-211                              (start_m * BLOCK_M,),
-212                              (BLOCK_M,),
-213                              (0,))
-
-
-
-
- -

initialize offsets

- -
-
-
216    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-217    offs_n = tl.arange(0, BLOCK_N)
-
-
-
-
- -

Initialize $m_i$ and $l_i$

- -
-
-
220    b_m = tl.zeros([BLOCK_M], dtype=HI_PRES_TL) - float("inf")
-221    b_l = tl.zeros([BLOCK_M], dtype=HI_PRES_TL) + 1.0
-
-
-
-
- -

Accumulate

- -
-
-
223    b_acc = tl.zeros([BLOCK_M, d_head], dtype=HI_PRES_TL)
-
-
-
-
- -

softmax scale / log(2)

- -
-
-
226    sm_scale = sm_scale * 1.44269504
-
-
-
-
- -

Load

- -
-
-
228    b_q = tl.load(p_q)
-229
-230    if is_causal:
-
-
-
-
- -

Run for ranges

- -
-
-
232        b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q,
-233                                          p_kT, p_v,
-234                                          sm_scale,
-235                                          BLOCK_M, d_head, BLOCK_N,
-236                                          offs_m, offs_n,
-237                                          start_n=tl.full([], 0, tl.int32),  # type: ignore
-238                                          steps=(start_m * BLOCK_M) // BLOCK_N,
-239                                          MASK=False,
-240                                          )
-241        b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
-242                                          sm_scale,
-243                                          BLOCK_M, d_head, BLOCK_N,
-244                                          offs_m, offs_n,
-245                                          start_n=start_m * BLOCK_M,
-246                                          steps=BLOCK_M // BLOCK_N,
-247                                          MASK=True,
-248                                          )
-249    else:
-250        b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
-251                                          sm_scale,
-252                                          BLOCK_M, d_head, BLOCK_N,
-253                                          offs_m, offs_n,
-254                                          start_n=tl.full([], 0, tl.int32),  # type: ignore
-255                                          steps=kv_seq_len // BLOCK_N,
-256                                          MASK=False,
-257                                          )
-
-
-
-
- -

Update LSE

- -
-
-
260    tl.store(p_lse, b_m + tl.math.log2(b_l))
-261    tl.store(p_o, (b_acc / b_l[:, None]).to(t_o.type.element_ty))
-
-
-
-
- - -
-
-
264@triton.jit
-265def _attn_fwd_inner(b_acc, b_l, b_m, b_q,
-266                    p_kT, p_v,
-267                    scale,
-268                    BLOCK_M: tl.constexpr,
-269                    d_head: tl.constexpr,
-270                    BLOCK_N: tl.constexpr,
-271                    offs_m, offs_n,
-272                    start_n,
-273                    steps,
-274                    MASK: tl.constexpr,
-275                    ):
-276    tl.static_assert(BLOCK_M % BLOCK_N == 0)
-277
-278    p_kT = tl.advance(p_kT, (0, start_n))
-279    p_v = tl.advance(p_v, (start_n, 0))
-
-
-
-
- -

loop over k, v and update accumulator

- -
-
-
282    for _ in range(steps):
-283        b_kT = tl.load(p_kT)
-284        b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
-285
-286        tl.static_assert(b_s.dtype == HI_PRES_TL)
-287        b_s = b_s * scale
-288        if MASK:
-289            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
-290            b_s = b_s + tl.where(mask, 0, -1.0e6)
-
-
-
-
- -

- -
-
-
293        tl.static_assert(len(b_s.shape) == 2)
-294        b_m_new = tl.maximum(b_m, tl.max(b_s, -1))
-
-
-
-
- -

- -
-
-
296        b_p = tl.math.exp2(b_s - b_m_new[:, None])
-
-
-
-
- -

- -
-
-
298        b_l_new = tl.sum(b_p, -1)
-
-
-
-
- -

- -
-
-
301        b_m_m_new = tl.math.exp2(b_m - b_m_new)
-
-
-
-
- -

- -
-
-
303        b_l = b_l * b_m_m_new + b_l_new
-
-
-
-
- -

- -
-
-
306        b_v = tl.load(p_v)
-307        b_acc = b_acc * b_m_m_new[:, None]
-308        b_p = b_p.to(b_q.dtype)
-309        b_acc += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)
-
-
-
-
- -

update m_i and l_i

- -
-
-
312        b_m = b_m_new
-313
-314        start_n += BLOCK_N
-315        p_v = tl.advance(p_v, (BLOCK_N, 0))
-316        p_kT = tl.advance(p_kT, (0, BLOCK_N))
-317
-318    tl.static_assert(b_acc.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
-319
-320    return b_acc, b_l, b_m
-
-
-
-
- -

Loop along m query; n % m == 0

- -
-
-
323@triton.jit
-324def _attn_bwd_d(t_o, t_do,
-325                t_pdp,
-326                BLOCK_M: tl.constexpr, d_head: tl.constexpr,
-327                q_seq_len: tl.constexpr,
-328                n_groups: tl.constexpr,
-329                ):
-330    m = tl.program_id(0) * BLOCK_M
-331    z = tl.program_id(1)
-332    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
-333                            (n_groups, q_seq_len, d_head),
-334                            (q_seq_len * d_head, d_head, 1),
-335                            (0, m, 0),
-336                            (n_groups, BLOCK_M, d_head),
-337                            (2, 1, 0))
-338    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
-339                             (n_groups, q_seq_len, d_head),
-340                             (q_seq_len * d_head, d_head, 1),
-341                             (0, m, 0),
-342                             (n_groups, BLOCK_M, d_head),
-343                             (2, 1, 0))
-344    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
-345                              (n_groups, q_seq_len),
-346                              (q_seq_len, 1),
-347                              (0, m),
-348                              (n_groups, BLOCK_M),
-349                              (1, 0))
-350
-351    o = tl.load(p_o)
-352    do = tl.load(p_do).to(HI_PRES_TL)
-353    d = tl.sum(o * do, axis=-1)
-354    tl.store(p_pdp, d)
-355
-356
-357@triton.autotune(_get_autotune_configs(inner_loop='query'),
-358                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
-359@triton.jit
-360def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
-361                   t_do,
-362                   t_dk, t_dv,
-363                   t_lse, t_pdp,
-364                   q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
-365                   n_groups: tl.constexpr, d_head: tl.constexpr,
-366                   is_causal: tl.constexpr,
-367                   BLOCK_M: tl.constexpr,
-368                   BLOCK_N: tl.constexpr,
-369                   ):
-
-
-
-
- -

K is already multiplied by scale

- -
-
-
374    n = tl.program_id(0)
-375    z = tl.program_id(1)
-376
-377    p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
-378                            (kv_seq_len, d_head),
-379                            (d_head, 1),
-380                            (n * BLOCK_N, 0),
-381                            (BLOCK_N, d_head),
-382                            (1, 0))
-383    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
-384                            (kv_seq_len, d_head),
-385                            (d_head, 1),
-386                            (n * BLOCK_N, 0),
-387                            (BLOCK_N, d_head),
-388                            (1, 0))
-389    p_dk = tl.make_block_ptr(t_dk + z * kv_seq_len * d_head,
-390                             (kv_seq_len, d_head),
-391                             (d_head, 1),
-392                             (n * BLOCK_N, 0),
-393                             (BLOCK_N, d_head),
-394                             (1, 0))
-395    p_dv = tl.make_block_ptr(t_dv + z * kv_seq_len * d_head,
-396                             (kv_seq_len, d_head),
-397                             (d_head, 1),
-398                             (n * BLOCK_N, 0),
-399                             (BLOCK_N, d_head),
-400                             (1, 0))
-401
-402    b_dv = tl.zeros([BLOCK_N, d_head], dtype=HI_PRES_TL)
-403    b_dk = tl.zeros([BLOCK_N, d_head], dtype=HI_PRES_TL)
+
190@triton.autotune(_get_autotune_configs(inner_loop='key'),
+191                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
+192@triton.jit
+193def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
+194              n_groups: tl.constexpr,
+195              q_seq_len: tl.constexpr,
+196              kv_seq_len: tl.constexpr,
+197              d_head: tl.constexpr,
+198              is_causal: tl.constexpr,
+199              BLOCK_M: tl.constexpr,  # q seq len block
+200              BLOCK_N: tl.constexpr,  # k seq len block
+201              ):
@@ -780,39 +612,12 @@ -

load K and V: they stay in SRAM throughout the inner loop.

- +
-
406    b_k = tl.load(p_k)
-407    b_v = tl.load(p_v)
-408
-409    for g in range(n_groups):
-410        p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-411                                 (d_head, q_seq_len),
-412                                 (1, d_head),
-413                                 (0, 0),
-414                                 (d_head, BLOCK_M),
-415                                 (0, 1))
-416
-417        p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-418                                 (q_seq_len, d_head),
-419                                 (d_head, 1),
-420                                 (0, 0),
-421                                 (BLOCK_M, d_head),
-422                                 (1, 0))
-423        p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
-424                                  (q_seq_len,),
-425                                  (1,),
-426                                  (0,),
-427                                  (BLOCK_M,),
-428                                  (0,))
-429        p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
-430                                  (q_seq_len,),
-431                                  (1,),
-432                                  (0,),
-433                                  (BLOCK_M,),
-434                                  (0,))
+
222    i = tl.program_id(0)
+223    z = tl.program_id(1) // n_groups
+224    g = tl.program_id(1) % n_groups
@@ -820,11 +625,40 @@ -

+

Create block pointers

-
+
227    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+228                            (q_seq_len, d_head),
+229                            (d_head, 1),
+230                            (i * BLOCK_M, 0),
+231                            (BLOCK_M, d_head),
+232                            (1, 0))
+233    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
+234                            (kv_seq_len, d_head),
+235                            (d_head, 1),
+236                            (0, 0),
+237                            (BLOCK_N, d_head),
+238                            (1, 0))
+239    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
+240                             (d_head, kv_seq_len),
+241                             (1, d_head),
+242                             (0, 0),
+243                             (d_head, BLOCK_N),
+244                             (0, 1))
+245    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+246                            (q_seq_len, d_head),
+247                            (d_head, 1),
+248                            (i * BLOCK_M, 0),
+249                            (BLOCK_M, d_head),
+250                            (1, 0))
+251    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
+252                              (q_seq_len,),
+253                              (1,),
+254                              (i * BLOCK_M,),
+255                              (BLOCK_M,),
+256                              (0,))
@@ -832,11 +666,12 @@ -

Compute $dK_j$ and $dV_j$ along the masked blocks near the diagonal. (A smaller block size could be used here, since these blocks do a little extra masking work.)

+

Initialize offsets

-
442        if is_causal:
+
259    offs_i = i * BLOCK_M + tl.arange(0, BLOCK_M)
+260    offs_j = tl.arange(0, BLOCK_N)
@@ -844,14 +679,12 @@ -

loop along m

+

Initialize $m_i$ and $l_i$

-
444            b_dk, b_dv = _attn_bwd_dkdv_inner(
-445                b_dk, b_dv,
-446                p_qT, b_k, b_v, p_do,
-447                p_lse, p_pdp,
+
263    b_m = tl.zeros([BLOCK_M], dtype=HI_PRES_TL) - float("inf")
+264    b_l = tl.zeros([BLOCK_M], dtype=HI_PRES_TL) + 1.0
@@ -859,16 +692,11 @@ -

You can use a smaller BLOCK_M if BLOCK_N is not divisible by BLOCK_M

+

Accumulator for the unnormalized output $\sum_j \tilde{P}_{ij} V_j$

-
449                BLOCK_M, BLOCK_N,
-450                d_head,
-451                n=n * BLOCK_N, start_m=n * BLOCK_N,
-452                steps=BLOCK_N // BLOCK_M,
-453                MASK=True
-454            )
+
266    b_acc = tl.zeros([BLOCK_M, d_head], dtype=HI_PRES_TL)
@@ -876,31 +704,11 @@ -

Compute $dK_j$ and $dV_j$ for non-masked blocks.

+

softmax scale / log(2)

-
457            b_dk, b_dv = _attn_bwd_dkdv_inner(
-458                b_dk, b_dv,
-459                p_qT, b_k, b_v, p_do,
-460                p_lse, p_pdp,
-461                BLOCK_M, BLOCK_N,
-462                d_head,
-463                n=n * BLOCK_N, start_m=(n + 1) * BLOCK_N,
-464                steps=(q_seq_len - (n + 1) * BLOCK_N) // BLOCK_M,
-465                MASK=False,
-466            )
-467        else:
-468            b_dk, b_dv = _attn_bwd_dkdv_inner(
-469                b_dk, b_dv,
-470                p_qT, b_k, b_v, p_do,
-471                p_lse, p_pdp,
-472                BLOCK_M, BLOCK_N,
-473                d_head,
-474                n=n * BLOCK_N, start_m=tl.full([], 0, tl.int32),
-475                steps=q_seq_len // BLOCK_M,
-476                MASK=False,
-477            )
+
269    sm_scale = sm_scale * 1.44269504
@@ -908,11 +716,13 @@ -

Save

+

Load $Q_i$

-
480    tl.store(p_dv, b_dv.to(t_dv.type.element_ty))
+
271    b_q = tl.load(p_q)
+272
+273    if is_causal:
@@ -920,11 +730,19 @@ -

Since we used $\hat{K} = \sigma K$, where $K$ are the original keys and $\sigma$ is the softmax scale, we multiply by the scale again to get the gradient on the original keys.

+

Up to the diagonal block

-
484    b_dk *= sm_scale
+
275        b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q,
+276                                          p_kT, p_v,
+277                                          sm_scale,
+278                                          BLOCK_M, d_head, BLOCK_N,
+279                                          offs_i, offs_j,
+280                                          start_n=tl.full([], 0, tl.int32),  # type: ignore
+281                                          steps=(i * BLOCK_M) // BLOCK_N,
+282                                          MASK=False,
+283                                          )
@@ -932,30 +750,40 @@ -

Save

+

Diagonal block with masking within it

-
487    tl.store(p_dk, b_dk.to(t_dk.type.element_ty))
+
285        b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
+286                                          sm_scale,
+287                                          BLOCK_M, d_head, BLOCK_N,
+288                                          offs_i, offs_j,
+289                                          start_n=i * BLOCK_M,
+290                                          steps=BLOCK_M // BLOCK_N,
+291                                          MASK=True,
+292                                          )
+293    else:
+294        b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
+295                                          sm_scale,
+296                                          BLOCK_M, d_head, BLOCK_N,
+297                                          offs_i, offs_j,
+298                                          start_n=tl.full([], 0, tl.int32),  # type: ignore
+299                                          steps=kv_seq_len // BLOCK_N,
+300                                          MASK=False,
+301                                          )
-
+
-

Inner loop along m query

+

Store the log-sum-exp $L_i = m_i + \log_2 l_i$ and the normalized output $O_i$

-
490@triton.jit
-491def _attn_bwd_dkdv_inner(b_dk, b_dv,
-492                         p_qT, b_k, b_v, p_do,
-493                         p_lse, p_pdp,
-494                         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
-495                         d_head: tl.constexpr,
-496                         n, start_m, steps,
-497                         MASK: tl.constexpr):
+
304    tl.store(p_lse, b_m + tl.math.log2(b_l))
+305    tl.store(p_o, (b_acc / b_l[:, None]).to(t_o.type.element_ty))
@@ -963,11 +791,25 @@ -

To apply the mask

- +
-
501    tl.static_assert(BLOCK_N % BLOCK_M == 0)
+
308@triton.jit
+309def _attn_fwd_inner(b_acc, b_l, b_m, b_q,
+310                    p_kT, p_v,
+311                    scale,
+312                    BLOCK_M: tl.constexpr,
+313                    d_head: tl.constexpr,
+314                    BLOCK_N: tl.constexpr,
+315                    offs_m, offs_n,
+316                    start_n,
+317                    steps,
+318                    MASK: tl.constexpr,
+319                    ):
+320    tl.static_assert(BLOCK_M % BLOCK_N == 0)
+321
+322    p_kT = tl.advance(p_kT, (0, start_n))
+323    p_v = tl.advance(p_v, (start_n, 0))
@@ -975,12 +817,19 @@ -

Offsets for mask computation

+

loop over k, v and update accumulator

-
504    offs_m = start_m + tl.arange(0, BLOCK_M)
-505    offs_n = n + tl.arange(0, BLOCK_N)
+
326    for _ in range(steps):
+327        b_kT = tl.load(p_kT)
+328        b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
+329
+330        tl.static_assert(b_s.dtype == HI_PRES_TL)
+331        b_s = b_s * scale
+332        if MASK:
+333            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
+334            b_s = b_s + tl.where(mask, 0, -1.0e6)
@@ -988,14 +837,12 @@ -

Pointers

+

-
508    p_qT = tl.advance(p_qT, (0, start_m))
-509    p_do = tl.advance(p_do, (start_m, 0))
-510    p_lse = tl.advance(p_lse, (start_m,))
-511    p_pdp = tl.advance(p_pdp, (start_m,))
+
337        tl.static_assert(len(b_s.shape) == 2)
+338        b_m_new = tl.maximum(b_m, tl.max(b_s, -1))
@@ -1003,11 +850,11 @@ -

Loop

+

-
514    for _ in range(steps):
+
340        b_p = tl.math.exp2(b_s - b_m_new[:, None])
@@ -1015,11 +862,11 @@ -

Load

+

-
516        b_qT = tl.load(p_qT)
+
342        b_l_new = tl.sum(b_p, -1)
@@ -1027,11 +874,11 @@ -

+

-
519        b_m = tl.load(p_lse)
+
345        b_m_m_new = tl.math.exp2(b_m - b_m_new)
@@ -1039,12 +886,11 @@ -

Note that $K$ is already multiplied by the softmax scale. It is also divided by $\ln 2$, so we can use $2^x$ instead of $e^x$.

+

-
524        b_qkT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)
-525        b_pT = tl.math.exp2(b_qkT - b_m[None, :])
+
347        b_l = b_l * b_m_m_new + b_l_new
@@ -1052,13 +898,14 @@ -

Autoregressive masking.

+

-
528        if MASK:
-529            mask = (offs_m[None, :] >= offs_n[:, None])
-530            b_pT = tl.where(mask, b_pT, 0.0)
+
350        b_v = tl.load(p_v)
+351        b_acc = b_acc * b_m_m_new[:, None]
+352        b_p = b_p.to(b_q.dtype)
+353        b_acc += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)
@@ -1066,14 +913,11 @@ -

+

update $m_i$

-
533        b_do = tl.load(p_do)
-534        b_dv += tl.dot(b_pT.to(b_do.dtype),
-535                       b_do,
-536                       out_dtype=HI_PRES_TL)
+
356        b_m = b_m_new
@@ -1081,11 +925,17 @@ -

+

Move pointers

-
539        b_pdp = tl.load(p_pdp)
+
359        start_n += BLOCK_N
+360        p_v = tl.advance(p_v, (BLOCK_N, 0))
+361        p_kT = tl.advance(p_kT, (0, BLOCK_N))
+362
+363    tl.static_assert(b_acc.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
+364
+365    return b_acc, b_l, b_m
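The loop above is the standard online-softmax recurrence. A plain PyTorch sketch of one step is shown below (math only, not the Triton code; s is assumed to be already multiplied by sm_scale * log2(e), so base-2 exponentials are used):

def online_softmax_step(acc, l, m, s, v_block):
    m_new = torch.maximum(m, s.max(dim=-1).values)   # new running maximum
    p = torch.exp2(s - m_new[:, None])               # unnormalized probabilities
    correction = torch.exp2(m - m_new)               # rescale factor for the old state
    l = l * correction + p.sum(dim=-1)               # running denominator
    acc = acc * correction[:, None] + p @ v_block    # running (unnormalized) output
    return acc, l, m_new

# After the last block: output = acc / l[:, None] and lse = m + log2(l)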
@@ -1093,11 +943,18 @@ -

- +
-
541        b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)
+
368@triton.jit
+369def _attn_bwd_d(t_o, t_do,
+370                t_pdp,
+371                BLOCK_M: tl.constexpr, d_head: tl.constexpr,
+372                q_seq_len: tl.constexpr,
+373                n_groups: tl.constexpr,
+374                ):
+375    i = tl.program_id(0) * BLOCK_M
+376    z = tl.program_id(1)
@@ -1105,24 +962,57 @@ -

+

Create block pointers

-
543        b_dsT = b_pT * (b_dpT - b_pdp[None, :])
+
379    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
+380                            (n_groups, q_seq_len, d_head),
+381                            (q_seq_len * d_head, d_head, 1),
+382                            (0, i, 0),
+383                            (n_groups, BLOCK_M, d_head),
+384                            (2, 1, 0))
+385    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
+386                             (n_groups, q_seq_len, d_head),
+387                             (q_seq_len * d_head, d_head, 1),
+388                             (0, i, 0),
+389                             (n_groups, BLOCK_M, d_head),
+390                             (2, 1, 0))
+391    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
+392                              (n_groups, q_seq_len),
+393                              (q_seq_len, 1),
+394                              (0, i),
+395                              (n_groups, BLOCK_M),
+396                              (1, 0))
+397
+398    o = tl.load(p_o)
+399    do = tl.load(p_do).to(HI_PRES_TL)
+400    d = tl.sum(o * do, axis=-1)
+401    tl.store(p_pdp, d)
-
+
-

+

Loop along m query; n % m == 0

-
545        b_dk += tl.dot(b_dsT.to(b_qT.dtype),
-546                       tl.trans(b_qT), out_dtype=HI_PRES_TL)
+
404@triton.autotune(_get_autotune_configs(inner_loop='query'),
+405                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
+406@triton.jit
+407def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
+408                   t_do,
+409                   t_dk, t_dv,
+410                   t_lse, t_pdp,
+411                   q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
+412                   n_groups: tl.constexpr, d_head: tl.constexpr,
+413                   is_causal: tl.constexpr,
+414                   BLOCK_M: tl.constexpr,
+415                   BLOCK_N: tl.constexpr,
+416                   ):
@@ -1130,15 +1020,40 @@ -

Increment pointers.

+

K is already multiplied by scale

-
549        offs_m += BLOCK_M
-550        p_lse = tl.advance(p_lse, (BLOCK_M,))
-551        p_pdp = tl.advance(p_pdp, (BLOCK_M,))
-552        p_qT = tl.advance(p_qT, (0, BLOCK_M))
-553        p_do = tl.advance(p_do, (BLOCK_M, 0))
+
421    n = tl.program_id(0)
+422    z = tl.program_id(1)
+423
+424    p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
+425                            (kv_seq_len, d_head),
+426                            (d_head, 1),
+427                            (n * BLOCK_N, 0),
+428                            (BLOCK_N, d_head),
+429                            (1, 0))
+430    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
+431                            (kv_seq_len, d_head),
+432                            (d_head, 1),
+433                            (n * BLOCK_N, 0),
+434                            (BLOCK_N, d_head),
+435                            (1, 0))
+436    p_dk = tl.make_block_ptr(t_dk + z * kv_seq_len * d_head,
+437                             (kv_seq_len, d_head),
+438                             (d_head, 1),
+439                             (n * BLOCK_N, 0),
+440                             (BLOCK_N, d_head),
+441                             (1, 0))
+442    p_dv = tl.make_block_ptr(t_dv + z * kv_seq_len * d_head,
+443                             (kv_seq_len, d_head),
+444                             (d_head, 1),
+445                             (n * BLOCK_N, 0),
+446                             (BLOCK_N, d_head),
+447                             (1, 0))
+448
+449    b_dv = tl.zeros([BLOCK_N, d_head], dtype=HI_PRES_TL)
+450    b_dk = tl.zeros([BLOCK_N, d_head], dtype=HI_PRES_TL)
@@ -1146,11 +1061,12 @@ -

Return accumulated $dK_j$ and $dV_j$

+

load K and V: they stay in SRAM throughout the inner loop.

-
556    return b_dk, b_dv
+
453    b_k = tl.load(p_k)
+454    b_v = tl.load(p_v)
@@ -1158,21 +1074,11 @@ - +

Iterate through the query head groups that attend to these saved keys

+
-
559@triton.autotune(_get_autotune_configs(inner_loop='key'),
-560                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
-561@triton.jit
-562def _attn_bwd_dq(t_q, t_k, t_v, t_do,
-563                 t_dq,
-564                 t_lse, t_pdp,
-565                 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
-566                 n_groups: tl.constexpr, d_head: tl.constexpr,
-567                 is_causal: tl.constexpr,
-568                 BLOCK_M: tl.constexpr,
-569                 BLOCK_N: tl.constexpr,
-570                 ):
+
457    for g in range(n_groups):
@@ -1180,66 +1086,35 @@ -

+

Create block pointers

-
572    LN2: tl.constexpr = 0.6931471824645996  # type: ignore
-573
-574    m = tl.program_id(0)
-575    z = tl.program_id(1) // n_groups
-576    g = tl.program_id(1) % n_groups
-577
-578    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-579                            (q_seq_len, d_head),
-580                            (d_head, 1),
-581                            (m * BLOCK_M, 0),
-582                            (BLOCK_M, d_head),
-583                            (1, 0))
-584    p_dq = tl.make_block_ptr(t_dq + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-585                             (q_seq_len, d_head),
-586                             (d_head, 1),
-587                             (m * BLOCK_M, 0),
-588                             (BLOCK_M, d_head),
-589                             (1, 0))
-590    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-591                             (q_seq_len, d_head),
-592                             (d_head, 1),
-593                             (m * BLOCK_M, 0),
-594                             (BLOCK_M, d_head),
-595                             (1, 0))
-596    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
-597                             (d_head, kv_seq_len),
-598                             (1, d_head),
-599                             (0, 0),
-600                             (d_head, BLOCK_N),
-601                             (0, 1))
-602    p_vT = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
-603                             (d_head, kv_seq_len),
-604                             (1, d_head),
-605                             (0, 0),
-606                             (d_head, BLOCK_N),
-607                             (0, 1))
-608    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
-609                              (q_seq_len,),
-610                              (1,),
-611                              (m * BLOCK_M,),
-612                              (BLOCK_M,),
-613                              (0,))
-614    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
-615                              (q_seq_len,),
-616                              (1,),
-617                              (m * BLOCK_M,),
-618                              (BLOCK_M,),
-619                              (0,))
-620
-621    b_q = tl.load(p_q)
-622    b_do = tl.load(p_do)
-623    b_pdp = tl.load(p_pdp)
-624
-625    b_dq = tl.zeros([BLOCK_M, d_head], dtype=HI_PRES_TL)
-626
-627    b_lse = tl.load(p_lse)
+
459        p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+460                                 (d_head, q_seq_len),
+461                                 (1, d_head),
+462                                 (0, 0),
+463                                 (d_head, BLOCK_M),
+464                                 (0, 1))
+465
+466        p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+467                                 (q_seq_len, d_head),
+468                                 (d_head, 1),
+469                                 (0, 0),
+470                                 (BLOCK_M, d_head),
+471                                 (1, 0))
+472        p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
+473                                  (q_seq_len,),
+474                                  (1,),
+475                                  (0,),
+476                                  (BLOCK_M,),
+477                                  (0,))
+478        p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
+479                                  (q_seq_len,),
+480                                  (1,),
+481                                  (0,),
+482                                  (BLOCK_M,),
+483                                  (0,))
@@ -1247,11 +1122,11 @@ -

+

-
631    if is_causal:
+
@@ -1259,17 +1134,11 @@ -

Compute $dQ_i$ for masked (diagonal) blocks.

+

Compute $dK_j$ and $dV_j$ along the masked blocks near the diagonal. (A smaller block size could be used here, since these blocks do a little extra masking work.)

-
633        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
-634                                  b_do, b_lse, b_pdp,
-635                                  BLOCK_M, BLOCK_N,
-636                                  m=m * BLOCK_M, start_n=m * BLOCK_M,
-637                                  steps=BLOCK_M // BLOCK_N,
-638                                  MASK=True
-639                                  )
+
491        if is_causal:
@@ -1277,25 +1146,14 @@ -

Other blocks

+

loop along m

-
642        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
-643                                  b_do, b_lse, b_pdp,
-644                                  BLOCK_M, BLOCK_N,
-645                                  m=m * BLOCK_M, start_n=tl.full([], 0, tl.int32),  # type: ignore
-646                                  steps=(m * BLOCK_M) // BLOCK_N,
-647                                  MASK=False
-648                                  )
-649    else:
-650        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
-651                                  b_do, b_lse, b_pdp,
-652                                  BLOCK_M, BLOCK_N,
-653                                  m=m * BLOCK_M, start_n=tl.full([], 0, tl.int32),  # type: ignore
-654                                  steps=kv_seq_len // BLOCK_N,
-655                                  MASK=False
-656                                  )
+
493            b_dk, b_dv = _attn_bwd_dkdv_inner(
+494                b_dk, b_dv,
+495                p_qT, b_k, b_v, p_do,
+496                p_lse, p_pdp,
@@ -1303,11 +1161,16 @@ -

Since $\hat{K}$ was scaled by $\frac{1}{\ln 2}$, and $dQ$ picked up that factor through the computed $dS \, \hat{K}$, we multiply by $\ln 2$ to reverse it.

+

You can use a smaller BLOCK_M if BLOCK_N is not divisible by BLOCK_M

-
660    b_dq *= LN2
+
498                BLOCK_M, BLOCK_N,
+499                d_head,
+500                n=n * BLOCK_N, start_m=n * BLOCK_N,
+501                steps=BLOCK_N // BLOCK_M,
+502                MASK=True
+503            )
@@ -1315,28 +1178,43 @@ -

Save

+

Compute $dK_j$ and $dV_j$ for non-masked blocks.

-
663    tl.store(p_dq, b_dq.to(t_dq.type.element_ty))
+
506            b_dk, b_dv = _attn_bwd_dkdv_inner(
+507                b_dk, b_dv,
+508                p_qT, b_k, b_v, p_do,
+509                p_lse, p_pdp,
+510                BLOCK_M, BLOCK_N,
+511                d_head,
+512                n=n * BLOCK_N, start_m=(n + 1) * BLOCK_N,
+513                steps=(q_seq_len - (n + 1) * BLOCK_N) // BLOCK_M,
+514                MASK=False,
+515            )
+516        else:
+517            b_dk, b_dv = _attn_bwd_dkdv_inner(
+518                b_dk, b_dv,
+519                p_qT, b_k, b_v, p_do,
+520                p_lse, p_pdp,
+521                BLOCK_M, BLOCK_N,
+522                d_head,
+523                n=n * BLOCK_N, start_m=tl.full([], 0, tl.int32),
+524                steps=q_seq_len // BLOCK_M,
+525                MASK=False,
+526            )
-
+
-

Inner loop over n key

+

Save

-
666@triton.jit
-667def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
-668                       b_do, b_lse, b_pdp,
-669                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
-670                       m, start_n, steps,
-671                       MASK: tl.constexpr):
+
529    tl.store(p_dv, b_dv.to(t_dv.type.element_ty))
@@ -1344,17 +1222,11 @@ - +

Since we used $\hat{K} = \sigma K$, where $K$ are the original keys and $\sigma$ is the softmax scale, we multiply by the scale again to get the gradient on the original keys.

+
-
673    offs_m = m + tl.arange(0, BLOCK_M)
-674
-675    p_kT = tl.advance(p_kT, (0, start_n))
-676    p_vT = tl.advance(p_vT, (0, start_n))
-677
-678    tl.static_assert(BLOCK_M % BLOCK_N == 0, 'BLOCK_M must be divisible by BLOCK_N')
-679
-680    for _ in range(steps):
+
533    b_dk *= sm_scale
@@ -1362,29 +1234,30 @@ -

Note that $K$ is already multiplied by the softmax scale. It is also divided by $\ln 2$, so we can use $2^x$ instead of $e^x$.

+

Save

-
684        b_kT = tl.load(p_kT)
-685        b_vT = tl.load(p_vT)
-686        b_qk = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
-687        b_p = tl.math.exp2(b_qk - b_lse[:, None])
+
536    tl.store(p_dk, b_dk.to(t_dk.type.element_ty))
-
+
-

Autoregressive masking.

+

Inner loop along m query

-
690        if MASK:
-691            offs_n = start_n + tl.arange(0, BLOCK_N)
-692            mask = (offs_m[:, None] >= offs_n[None, :])
-693            b_p = tl.where(mask, b_p, 0.0)
+
539@triton.jit
+540def _attn_bwd_dkdv_inner(b_dk, b_dv,
+541                         p_qT, b_k, b_v, p_do,
+542                         p_lse, p_pdp,
+543                         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
+544                         d_head: tl.constexpr,
+545                         n, start_m, steps,
+546                         MASK: tl.constexpr):
@@ -1392,11 +1265,11 @@ -

+

To apply the mask

-
+
550    tl.static_assert(BLOCK_N % BLOCK_M == 0)
@@ -1404,11 +1277,12 @@ -

+

Offsets for mask computation

-
698        b_dp = tl.dot(b_do, b_vT, out_dtype=HI_PRES_TL).to(HI_PRES_TL)
+
553    offs_m = start_m + tl.arange(0, BLOCK_M)
+554    offs_n = n + tl.arange(0, BLOCK_N)
@@ -1416,11 +1290,14 @@ -

+

Pointers

-
700        b_ds = b_p * (b_dp - b_pdp[:, None])
+
557    p_qT = tl.advance(p_qT, (0, start_m))
+558    p_do = tl.advance(p_do, (start_m, 0))
+559    p_lse = tl.advance(p_lse, (start_m,))
+560    p_pdp = tl.advance(p_pdp, (start_m,))
@@ -1428,13 +1305,11 @@ -

+

Loop

-
702        b_dq += tl.dot(b_ds.to(b_kT.dtype),
-703                       tl.trans(b_kT),
-704                       out_dtype=HI_PRES_TL)
+
563    for _ in range(steps):
@@ -1442,13 +1317,11 @@ -

Increment pointers.

+

Load

-
707        start_n += BLOCK_N
-708        p_kT = tl.advance(p_kT, (0, BLOCK_N))
-709        p_vT = tl.advance(p_vT, (0, BLOCK_N))
+
565        b_qT = tl.load(p_qT)
@@ -1456,11 +1329,450 @@ -

Return accumulated $dQ_i$

+

-
712    return b_dq
+
568        b_m = tl.load(p_lse)
+
+ +
+
+ +

Note that $K$ is already multiplied by the softmax scale. It is also divided by $\ln 2$, so we can use $2^x$ instead of $e^x$.

+ +
+
+
573        b_qkT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)
+574        b_pT = tl.math.exp2(b_qkT - b_m[None, :])
+
+
+
+
+ +

Autoregressive masking.

+ +
+
+
577        if MASK:
+578            mask = (offs_m[None, :] >= offs_n[:, None])
+579            b_pT = tl.where(mask, b_pT, 0.0)
+
+
+
+
+ +

$dV_j \leftarrow dV_j + \tilde{P}_{ij}^\top dO_i$

+ +
+
+
582        b_do = tl.load(p_do)
+583        b_dv += tl.dot(b_pT.to(b_do.dtype),
+584                       b_do,
+585                       out_dtype=HI_PRES_TL)
+
+
+
+
+ +

Load the precomputed $D_i$

+ +
+
+
588        b_pdp = tl.load(p_pdp)
+
+
+
+
+ +

$dP_{ij}^\top = V_j \, dO_i^\top$

+ +
+
+
590        b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)
+
+
+
+
+ +

$dS_{ij}^\top = \tilde{P}_{ij}^\top \circ (dP_{ij}^\top - D_i)$

+ +
+
+
592        b_dsT = b_pT * (b_dpT - b_pdp[None, :])
+
+
+
+
+ +

$dK_j \leftarrow dK_j + dS_{ij}^\top Q_i$

+ +
+
+
594        b_dk += tl.dot(b_dsT.to(b_qT.dtype),
+595                       tl.trans(b_qT), out_dtype=HI_PRES_TL)
+
+
+
+
+ +

Increment pointers.

+ +
+
+
598        offs_m += BLOCK_M
+599        p_lse = tl.advance(p_lse, (BLOCK_M,))
+600        p_pdp = tl.advance(p_pdp, (BLOCK_M,))
+601        p_qT = tl.advance(p_qT, (0, BLOCK_M))
+602        p_do = tl.advance(p_do, (BLOCK_M, 0))
+
+
+
+
+ +

Return accumulated $dK_j$ and $dV_j$

+ +
+
+
605    return b_dk, b_dv
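The per-block update above, written as dense PyTorch for reference (a sketch; pT is the [BLOCK_N, BLOCK_M] transposed probability block, qT_block the transposed queries and d_block the precomputed D values):

def dkdv_block_step(dk, dv, pT, qT_block, do_block, v_block, d_block):
    dv = dv + pT @ do_block                       # dV_j += P^T dO_i
    dpT = v_block @ do_block.transpose(-1, -2)    # dP^T = V_j dO_i^T
    dsT = pT * (dpT - d_block[None, :])           # dS^T = P^T * (dP^T - D_i)
    dk = dk + dsT @ qT_block.transpose(-1, -2)    # dK_j += dS^T Q_i
    return dk, dv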
+
+
+
+
+ + +
+
+
608@triton.autotune(_get_autotune_configs(inner_loop='key'),
+609                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
+610@triton.jit
+611def _attn_bwd_dq(t_q, t_k, t_v, t_do,
+612                 t_dq,
+613                 t_lse, t_pdp,
+614                 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
+615                 n_groups: tl.constexpr, d_head: tl.constexpr,
+616                 is_causal: tl.constexpr,
+617                 BLOCK_M: tl.constexpr,
+618                 BLOCK_N: tl.constexpr,
+619                 ):
+
+
+
+
+ +

+ +
+
+
621    LN2: tl.constexpr = 0.6931471824645996  # type: ignore
+622
+623    m = tl.program_id(0)
+624    z = tl.program_id(1) // n_groups
+625    g = tl.program_id(1) % n_groups
+
+
+
+
+ +

Create block pointers

+ +
+
+
628    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+629                            (q_seq_len, d_head),
+630                            (d_head, 1),
+631                            (m * BLOCK_M, 0),
+632                            (BLOCK_M, d_head),
+633                            (1, 0))
+634    p_dq = tl.make_block_ptr(t_dq + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+635                             (q_seq_len, d_head),
+636                             (d_head, 1),
+637                             (m * BLOCK_M, 0),
+638                             (BLOCK_M, d_head),
+639                             (1, 0))
+640    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+641                             (q_seq_len, d_head),
+642                             (d_head, 1),
+643                             (m * BLOCK_M, 0),
+644                             (BLOCK_M, d_head),
+645                             (1, 0))
+646    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
+647                             (d_head, kv_seq_len),
+648                             (1, d_head),
+649                             (0, 0),
+650                             (d_head, BLOCK_N),
+651                             (0, 1))
+652    p_vT = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
+653                             (d_head, kv_seq_len),
+654                             (1, d_head),
+655                             (0, 0),
+656                             (d_head, BLOCK_N),
+657                             (0, 1))
+658    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
+659                              (q_seq_len,),
+660                              (1,),
+661                              (m * BLOCK_M,),
+662                              (BLOCK_M,),
+663                              (0,))
+664    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
+665                              (q_seq_len,),
+666                              (1,),
+667                              (m * BLOCK_M,),
+668                              (BLOCK_M,),
+669                              (0,))
+670
+671    b_q = tl.load(p_q)
+672    b_do = tl.load(p_do)
+673    b_pdp = tl.load(p_pdp)
+674
+675    b_dq = tl.zeros([BLOCK_M, d_head], dtype=HI_PRES_TL)
+676
+677    b_lse = tl.load(p_lse)
+
+
+
+
+ +

+ +
+
+
681    if is_causal:
+
+
+
+
+ +

Compute $dQ_i$ for masked (diagonal) blocks.

+ +
+
+
683        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
+684                                  b_do, b_lse, b_pdp,
+685                                  BLOCK_M, BLOCK_N,
+686                                  m=m * BLOCK_M, start_n=m * BLOCK_M,
+687                                  steps=BLOCK_M // BLOCK_N,
+688                                  MASK=True
+689                                  )
+
+
+
+
+ +

Other blocks

+ +
+
+
692        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
+693                                  b_do, b_lse, b_pdp,
+694                                  BLOCK_M, BLOCK_N,
+695                                  m=m * BLOCK_M, start_n=tl.full([], 0, tl.int32),  # type: ignore
+696                                  steps=(m * BLOCK_M) // BLOCK_N,
+697                                  MASK=False
+698                                  )
+699    else:
+700        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
+701                                  b_do, b_lse, b_pdp,
+702                                  BLOCK_M, BLOCK_N,
+703                                  m=m * BLOCK_M, start_n=tl.full([], 0, tl.int32),  # type: ignore
+704                                  steps=kv_seq_len // BLOCK_N,
+705                                  MASK=False
+706                                  )
+
+
+
+
+ +

Since $\hat{K}$ was scaled by $\frac{1}{\ln 2}$, and $dQ$ picked up that factor through the computed $dS \, \hat{K}$, we multiply by $\ln 2$ to reverse it.

+ +
+
+
710    b_dq *= LN2
+
+
+
+
+ +

Save

+ +
+
+
713    tl.store(p_dq, b_dq.to(t_dq.type.element_ty))
+
+
+
+
+ +

Inner loop over n key

+ +
+
+
716@triton.jit
+717def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
+718                       b_do, b_lse, b_pdp,
+719                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
+720                       m, start_n, steps,
+721                       MASK: tl.constexpr):
+
+
+
+
+ + +
+
+
723    offs_m = m + tl.arange(0, BLOCK_M)
+724
+725    p_kT = tl.advance(p_kT, (0, start_n))
+726    p_vT = tl.advance(p_vT, (0, start_n))
+727
+728    tl.static_assert(BLOCK_M % BLOCK_N == 0, 'BLOCK_M must be divisible by BLOCK_N')
+729
+730    for _ in range(steps):
+
+
+
+
+ +

Note that $K$ is already multiplied by the softmax scale. It is also divided by $\ln 2$, so we can use $2^x$ instead of $e^x$.

+ +
+
+
734        b_kT = tl.load(p_kT)
+735        b_vT = tl.load(p_vT)
+736        b_qk = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
+737        b_p = tl.math.exp2(b_qk - b_lse[:, None])
+
+
+
+
+ +

Autoregressive masking.

+ +
+
+
740        if MASK:
+741            offs_n = start_n + tl.arange(0, BLOCK_N)
+742            mask = (offs_m[:, None] >= offs_n[None, :])
+743            b_p = tl.where(mask, b_p, 0.0)
+
+
+
+
+ +

+ +
+
+
+
+
+
+
+ +

$dP_{ij} = dO_i V_j^\top$

+ +
+
+
748        b_dp = tl.dot(b_do, b_vT, out_dtype=HI_PRES_TL).to(HI_PRES_TL)
+
+
+
+
+ +

$dS_{ij} = \tilde{P}_{ij} \circ (dP_{ij} - D_i)$

+ +
+
+
750        b_ds = b_p * (b_dp - b_pdp[:, None])
+
+
+
+
+ +

$dQ_i \leftarrow dQ_i + dS_{ij} \hat{K}_j$

+ +
+
+
752        b_dq += tl.dot(b_ds.to(b_kT.dtype),
+753                       tl.trans(b_kT),
+754                       out_dtype=HI_PRES_TL)
+
+
+
+
+ +

Increment pointers.

+ +
+
+
757        start_n += BLOCK_N
+758        p_kT = tl.advance(p_kT, (0, BLOCK_N))
+759        p_vT = tl.advance(p_vT, (0, BLOCK_N))
+
+
+
+
+ +

Return accumulated $dQ_i$

+ +
+
+
762    return b_dq
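And the corresponding dQ update for one key block, again as a dense reference sketch (p is the [BLOCK_M, BLOCK_N] probability block, kT_block/vT_block the transposed scaled keys and values, d_block the precomputed D values):

def dq_block_step(dq, p, do_block, vT_block, kT_block, d_block):
    dp = do_block @ vT_block                      # dP = dO_i V_j^T
    ds = p * (dp - d_block[:, None])              # dS = P * (dP - D_i)
    dq = dq + ds @ kT_block.transpose(-1, -2)     # dQ_i += dS K_j (scaled keys)
    return dq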