From 9262c57f181a52130a64f65bc204fb5b3470f0fd Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Fri, 8 Aug 2025 19:57:57 +0530 Subject: [PATCH] flash attention --- docs/index.html | 1 + docs/transformers/flash/index.html | 1574 ++++++++++++----------- docs/transformers/flash/test.html | 375 +++--- labml_nn/__init__.py | 1 + labml_nn/transformers/flash/__init__.py | 141 +- labml_nn/transformers/flash/test.py | 52 +- readme.md | 1 + setup.py | 4 +- 8 files changed, 1212 insertions(+), 937 deletions(-) diff --git a/docs/index.html b/docs/index.html index 16647748..f859361e 100644 --- a/docs/index.html +++ b/docs/index.html @@ -80,6 +80,7 @@

Paper Implementations

Transformers

-
112    @staticmethod
-113    def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-114                causal: bool, sm_scale: float) -> torch.Tensor:
+
159    @staticmethod
+160    def forward(ctx: Any,
+161                q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+162                causal: bool, sm_scale: float) -> torch.Tensor:
@@ -154,10 +168,10 @@
-
126        batch_size, n_heads, q_seq_len, d_head = q.shape
-127        _, k_heads, kv_seq_len, _ = k.shape
-128        assert n_heads % k_heads == 0
-129        n_groups = n_heads // k_heads
+
176        batch_size, n_heads, q_seq_len, d_head = q.shape
+177        _, k_heads, kv_seq_len, _ = k.shape
+178        assert n_heads % k_heads == 0
+179        n_groups = n_heads // k_heads
@@ -169,8 +183,8 @@
-
132        assert d_head == k.shape[-1] == v.shape[-1]
-133        assert d_head in {16, 32, 64, 128, 256}
+
182        assert d_head == k.shape[-1] == v.shape[-1]
+183        assert d_head in {16, 32, 64, 128, 256}
@@ -182,9 +196,9 @@
-
136        q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
-137        k = k.view(batch_size * k_heads, kv_seq_len, d_head)
-138        v = v.view(batch_size * k_heads, kv_seq_len, d_head)
+
186        q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
+187        k = k.view(batch_size * k_heads, kv_seq_len, d_head)
+188        v = v.view(batch_size * k_heads, kv_seq_len, d_head)
@@ -196,10 +210,10 @@
-
141        assert q.is_contiguous()
-142        assert k.is_contiguous()
-143        assert v.is_contiguous()
-144        assert k.stride() == v.stride()
+
191        assert q.is_contiguous()
+192        assert k.is_contiguous()
+193        assert v.is_contiguous()
+194        assert k.stride() == v.stride()
@@ -211,7 +225,7 @@
-
147        o = torch.empty_like(q)
+
197        o = torch.empty_like(q)
@@ -219,11 +233,11 @@ -

Tensor for log of sum of exponentials

+

Tensor for log of sum of exponentials

-
149        lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)
+
199        lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)
@@ -236,15 +250,15 @@
-
152        grid = lambda meta: (triton.cdiv(q_seq_len, meta["BLOCK_Q"]), batch_size * k_heads * n_groups, 1)
-153        _attn_fwd[grid](
-154            q, k, v, sm_scale * 1.4426950408889634, lse, o,
-155            n_groups=n_groups,
-156            q_seq_len=q_seq_len,
-157            kv_seq_len=kv_seq_len,
-158            d_head=d_head,
-159            is_causal=causal,
-160        )
+
202        grid = lambda meta: (triton.cdiv(q_seq_len, meta["BLOCK_Q"]), batch_size * k_heads * n_groups, 1)
+203        _attn_fwd[grid](
+204            q, k, v, sm_scale * 1.4426950408889634, lse, o,
+205            n_groups=n_groups,
+206            q_seq_len=q_seq_len,
+207            kv_seq_len=kv_seq_len,
+208            d_head=d_head,
+209            is_causal=causal,
+210        )
@@ -256,10 +270,10 @@
-
163        ctx.save_for_backward(q, k, v, o, lse)
-164        ctx.sm_scale = sm_scale
-165        ctx.n_groups = n_groups
-166        ctx.causal = causal
+
213        ctx.save_for_backward(q, k, v, o, lse)
+214        ctx.sm_scale = sm_scale
+215        ctx.n_groups = n_groups
+216        ctx.causal = causal
@@ -272,7 +286,7 @@
-
169        return o.view(batch_size, n_heads, q_seq_len, d_head)
+
219        return o.view(batch_size, n_heads, q_seq_len, d_head)
@@ -280,7 +294,8 @@ -

The backward pass computes the gradients of the input tensors.

+

Backward pass

+

The backward pass computes the gradients of the input tensors.

-
171    @staticmethod
-172    def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
+
221    @staticmethod
+222    def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
@@ -302,10 +317,10 @@
-
181        n_groups = ctx.n_groups
-182        sm_scale = ctx.sm_scale
-183        causal = ctx.causal
-184        q, k, v, o, lse = ctx.saved_tensors
+
233        n_groups = ctx.n_groups
+234        sm_scale = ctx.sm_scale
+235        causal = ctx.causal
+236        q, k, v, o, lse = ctx.saved_tensors
@@ -317,9 +332,9 @@
-
187        batch_size, n_heads, q_seq_len, d_head = do.shape
-188        _, kv_seq_len, _ = k.shape
-189        k_heads = n_heads // n_groups
+
239        batch_size, n_heads, q_seq_len, d_head = do.shape
+240        _, kv_seq_len, _ = k.shape
+241        k_heads = n_heads // n_groups
@@ -331,7 +346,7 @@
-
192        do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
+
244        do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
@@ -343,9 +358,9 @@
-
195        assert do.is_contiguous()
-196        assert k.stride() == v.stride()
-197        assert q.stride() == o.stride() == do.stride()
+
247        assert do.is_contiguous()
+248        assert k.stride() == v.stride()
+249        assert q.stride() == o.stride() == do.stride()
@@ -357,9 +372,9 @@
-
200        dq = torch.empty_like(q)
-201        dk = torch.empty_like(k)
-202        dv = torch.empty_like(v)
+
252        dq = torch.empty_like(q)
+253        dk = torch.empty_like(k)
+254        dv = torch.empty_like(v)
@@ -367,11 +382,11 @@ -

Precompute

+

Precompute

-
205        k_scaled = k * (sm_scale * 1.4426950408889634)
+
257        k_scaled = k * (sm_scale * 1.4426950408889634)
@@ -379,11 +394,11 @@ -

+

-
207        pdp = torch.empty_like(lse)
+
259        pdp = torch.empty_like(lse)
@@ -392,7 +407,7 @@ #

We use fixed BLOCK_Q - for backward pass on

+ for backward pass on

@@ -404,23 +419,23 @@ -

Compute

+

Compute

This is parallelized along the batch and query in blocks of size BLOCK_Q

-
213        BLOCK_Q = 16
-214        pre_grid = (triton.cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads)
-215        _attn_bwd_d[pre_grid](
-216            o, do,
-217            pdp,
-218            BLOCK_Q=16,
-219            d_head=d_head,
-220            q_seq_len=q_seq_len,
-221            n_groups=n_groups,
-222            num_stages=1,
-223        )
+
265        BLOCK_Q = 16
+266        pre_grid = (triton.cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads)
+267        _attn_bwd_d[pre_grid](
+268            o, do,
+269            pdp,
+270            BLOCK_Q=16,
+271            d_head=d_head,
+272            q_seq_len=q_seq_len,
+273            n_groups=n_groups,
+274            num_stages=1,
+275        )
@@ -428,20 +443,20 @@ -

Compute and

+

Compute and

This is parallelized along the batch and keys in blocks of size BLOCK_K

-
228        grid = lambda meta: (triton.cdiv(kv_seq_len, meta['BLOCK_K']), batch_size * k_heads)
-229        _attn_bwd_dkdv[grid](
-230            q, k_scaled, v, sm_scale, do, dk, dv,
-231            lse, pdp,
-232            q_seq_len, kv_seq_len, n_groups, d_head,
-233            is_causal=causal,
-234
-235        )
+
280        grid = lambda meta: (triton.cdiv(kv_seq_len, meta['BLOCK_K']), batch_size * k_heads)
+281        _attn_bwd_dkdv[grid](
+282            q, k_scaled, v, sm_scale, do, dk, dv,
+283            lse, pdp,
+284            q_seq_len, kv_seq_len, n_groups, d_head,
+285            is_causal=causal,
+286
+287        )
@@ -449,20 +464,20 @@ -

Compute

+

Compute

This is parallelized along the batch and queries in blocks of size BLOCK_Q

-
240        grid = lambda meta: (triton.cdiv(q_seq_len, meta['BLOCK_Q']), batch_size * k_heads * n_groups)
-241        _attn_bwd_dq[grid](
-242            q, k_scaled, v, do,
-243            dq,
-244            lse, pdp,
-245            q_seq_len, kv_seq_len, n_groups, d_head,
-246            is_causal=causal,
-247        )
+
292        grid = lambda meta: (triton.cdiv(q_seq_len, meta['BLOCK_Q']), batch_size * k_heads * n_groups)
+293        _attn_bwd_dq[grid](
+294            q, k_scaled, v, do,
+295            dq,
+296            lse, pdp,
+297            q_seq_len, kv_seq_len, n_groups, d_head,
+298            is_causal=causal,
+299        )
@@ -474,9 +489,9 @@
-
250        dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
-251        dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
-252        dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)
+
302        dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
+303        dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
+304        dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)
@@ -488,10 +503,10 @@
-
255        return dq, dk, dv, None, None
-256
-257
-258attention = AttentionFunc.apply
+
307        return dq, dk, dv, None, None
+308
+309
+310attention = AttentionFunc.apply
@@ -503,7 +518,7 @@
-
261def _get_autotune_configs(inner_loop: str) -> list:
+
313def _get_autotune_configs(inner_loop: str) -> list:
@@ -514,7 +529,7 @@
-
266    configs = []
+
318    configs = []
@@ -522,11 +537,12 @@ -

List possible BLOCK_Q and BLOCK_K that satisfy BLOCK_Q divisible by BLOCK_K and also try to cover a wide range

+

Possible options for BLOCK_Q +

-
269    for bm in [64, 128, 256]:
+
321    for bq in [64, 128, 256]:
@@ -534,50 +550,98 @@ -

We'll try bn in 16, 32, 64, 128 that are divisors and <= bm

+

Possible options for BLOCK_K +

-
271        for bn in [64, 128, 256]:
-272            if inner_loop == 'key' and bm % bn != 0:
-273                continue
-274            if inner_loop == 'query' and bn % bm != 0:
-275                continue
-276            for s in [2, 3, 4]:
-277                for w in [4, 8]:
-278                    if bm * bn < 128 * 128 and w == 8:
-279                        continue
-280
-281                    configs.append(triton.Config({'BLOCK_Q': bm, 'BLOCK_K': bn}, num_stages=s, num_warps=w))
-282
-283    return configs[:1]
+
323        for bk in [64, 128, 256]:
-
+
-
  • t_q - query
  • +

    If the inner loop is along keys the BLOCK_Q + must be a multiple of BLOCK_K + for causal masking

    + +
+
+
325            if inner_loop == 'key' and bq % bk != 0:
+326                continue
+
+
+
+
+ +

Similarly when the inner loop is along queries

+ +
+
+
328            if inner_loop == 'query' and bk % bq != 0:
+329                continue
+
+
+
+
+ +

Number of stages and warps

+ +
+
+
332            for s in [2, 3, 4]:
+333                for w in [4, 8]:
+334                    if bq * bk < 128 * 128 and w == 8:
+335                        continue
+336
+337                    configs.append(triton.Config({'BLOCK_Q': bq, 'BLOCK_K': bk}, num_stages=s, num_warps=w))
+
+
+
+
+ +

Use return configs + to autotune. Trying all combinations is slow for testing.

+ +
+
+
340    return configs[:1]
+
+
+
+
+ +

Triton kernel for Flash attention forward pass

+
  • t_q + queries
  • t_k - keys
  • + keys
  • t_v - values
  • -
  • sm_scale - softmax scale
  • + values +
  • sm_scale_log2e + softmax scale multiplied by
  • t_lse - (out)
  • + (out)
  • t_o - output (out)
  • + output
  • n_groups - number of groups
  • + number of groups in GQA
  • q_seq_len query sequence length
  • kv_seq_len key/value sequence length
  • d_head - size of a head
  • + number of dimensions in a head
  • BLOCK_Q block size for query sequence length
  • BLOCK_K @@ -590,105 +654,26 @@ and d denote the stride of the corresponding dimensions (batch_size , n_heads -, seq_len +, q_seq_len , d_head ) in the query. Stride n - denote the stride on seq_len + denote the stride on kv_seq_len of key.

-
286@triton.autotune(_get_autotune_configs(inner_loop='key'),
-287                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
-288@triton.jit
-289def _attn_fwd(t_q, t_k, t_v, sm_scale_log2e, t_lse, t_o,
-290              n_groups: tl.constexpr,
-291              q_seq_len: tl.constexpr,
-292              kv_seq_len: tl.constexpr,
-293              d_head: tl.constexpr,
-294              is_causal: tl.constexpr,
-295              BLOCK_Q: tl.constexpr,  # q seq len block
-296              BLOCK_K: tl.constexpr,  # k seq len block
-297              ):
-
-
-
-
- - -
-
-
318    i = tl.program_id(0)
-319    z = tl.program_id(1) // n_groups
-320    g = tl.program_id(1) % n_groups  # TODO
-
-
-
-
- -

Create block pointers

- -
-
-
323    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-324                            (q_seq_len, d_head),
-325                            (d_head, 1),
-326                            (i * BLOCK_Q, 0),
-327                            (BLOCK_Q, d_head),
-328                            (1, 0))
-329    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
-330                            (kv_seq_len, d_head),
-331                            (d_head, 1),
-332                            (0, 0),
-333                            (BLOCK_K, d_head),
-334                            (1, 0))
-335    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
-336                             (d_head, kv_seq_len),
-337                             (1, d_head),
-338                             (0, 0),
-339                             (d_head, BLOCK_K),
-340                             (0, 1))
-341    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-342                            (q_seq_len, d_head),
-343                            (d_head, 1),
-344                            (i * BLOCK_Q, 0),
-345                            (BLOCK_Q, d_head),
-346                            (1, 0))
-347    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
-348                              (q_seq_len,),
-349                              (1,),
-350                              (i * BLOCK_Q,),
-351                              (BLOCK_Q,),
-352                              (0,))
-
-
-
-
- -

Initialize offsets

- -
-
-
355    offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
-356    offs_j = tl.arange(0, BLOCK_K)
-
-
-
-
- -

Mask for for the last block

- -
-
-
358    i_mask = offs_i < q_seq_len
+
343@triton.autotune(_get_autotune_configs(inner_loop='key'),
+344                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
+345@triton.jit
+346def _attn_fwd(t_q, t_k, t_v, sm_scale_log2e, t_lse, t_o,
+347              n_groups: tl.constexpr,
+348              q_seq_len: tl.constexpr,
+349              kv_seq_len: tl.constexpr,
+350              d_head: tl.constexpr,
+351              is_causal: tl.constexpr,
+352              BLOCK_Q: tl.constexpr,
+353              BLOCK_K: tl.constexpr,
+354              ):
@@ -696,14 +681,14 @@ -

Initialize and . is initialized to and to . So in the first update, the effect of initial is .

-

b_m - will be storing

+

We are computing the attention for for i + ... `i + BLOCK_Q' in batch/head combination .

-
364    b_m = tl.where(i_mask, -float("inf"), 0.0)
-365    b_l = tl.where(i_mask, 1.0, 0.0)
+
378    i = tl.program_id(0)
+379    z = tl.program_id(1) // n_groups
+380    g = tl.program_id(1) % n_groups
@@ -711,11 +696,40 @@ -

+

Create block pointers

-
368    b_o = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
+
383    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+384                            (q_seq_len, d_head),
+385                            (d_head, 1),
+386                            (i * BLOCK_Q, 0),
+387                            (BLOCK_Q, d_head),
+388                            (1, 0))
+389    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
+390                            (kv_seq_len, d_head),
+391                            (d_head, 1),
+392                            (0, 0),
+393                            (BLOCK_K, d_head),
+394                            (1, 0))
+395    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
+396                             (d_head, kv_seq_len),
+397                             (1, d_head),
+398                             (0, 0),
+399                             (d_head, BLOCK_K),
+400                             (0, 1))
+401    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+402                            (q_seq_len, d_head),
+403                            (d_head, 1),
+404                            (i * BLOCK_Q, 0),
+405                            (BLOCK_Q, d_head),
+406                            (1, 0))
+407    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
+408                              (q_seq_len,),
+409                              (1,),
+410                              (i * BLOCK_Q,),
+411                              (BLOCK_Q,),
+412                              (0,))
@@ -723,13 +737,12 @@ -

Load outside the loop since it will be reused through out the loop over .

+

Initialize offsets

-
371    b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
-372
-373    if is_causal:
+
415    offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
+416    offs_j = tl.arange(0, BLOCK_K)
@@ -737,21 +750,11 @@ -

Inner loop upto the diagonal block

+

Mask for for the last block

-
375        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q,
-376                                        p_kT, p_v,
-377                                        sm_scale_log2e,
-378                                        BLOCK_Q, d_head, BLOCK_K,
-379                                        offs_i, offs_j,
-380                                        j=tl.full([], 0, tl.int32),  # type: ignore
-381                                        steps=(i * BLOCK_Q) // BLOCK_K,
-382                                        MASK=False,
-383                                        q_seq_len=q_seq_len,
-384                                        kv_seq_len=kv_seq_len
-385                                        )
+
419    i_mask = offs_i < q_seq_len
@@ -759,21 +762,14 @@ -

Diagonal block with masking within it

+

Initialize and . is initialized to and to . So in the first update, the effect of initial is .

+

b_m + will be storing

-
387        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
-388                                        sm_scale_log2e,
-389                                        BLOCK_Q, d_head, BLOCK_K,
-390                                        offs_i, offs_j,
-391                                        j=i * BLOCK_Q,
-392                                        steps=BLOCK_Q // BLOCK_K,
-393                                        MASK=True,
-394                                        q_seq_len=q_seq_len,
-395                                        kv_seq_len=kv_seq_len
-396                                        )
-397    else:
+
425    b_m = tl.where(i_mask, -float("inf"), 0.0)
+426    b_l = tl.where(i_mask, 1.0, 0.0)
@@ -781,20 +777,11 @@ -

Iterate through all

+

-
399        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
-400                                        sm_scale_log2e,
-401                                        BLOCK_Q, d_head, BLOCK_K,
-402                                        offs_i, offs_j,
-403                                        j=tl.full([], 0, tl.int32),  # type: ignore
-404                                        steps=tl.cdiv(kv_seq_len, BLOCK_K),
-405                                        MASK=False,
-406                                        q_seq_len=q_seq_len,
-407                                        kv_seq_len=kv_seq_len
-408                                        )
+
429    b_o = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
@@ -802,11 +789,13 @@ -

Store LSE

+

Load outside the loop since it will be reused through out the loop over .

-
411    tl.store(p_lse, b_m + tl.math.log2(b_l), boundary_check=(0,))
+
432    b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
+433
+434    if is_causal:
@@ -814,11 +803,21 @@ -

Store

+

Inner loop upto the diagonal block

-
413    tl.store(p_o, (b_o / b_l[:, None]).to(t_o.type.element_ty), boundary_check=(0,))
+
436        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q,
+437                                        p_kT, p_v,
+438                                        sm_scale_log2e,
+439                                        BLOCK_Q, d_head, BLOCK_K,
+440                                        offs_i, offs_j,
+441                                        j=tl.full([], 0, tl.int32),  # type: ignore
+442                                        steps=(i * BLOCK_Q) // BLOCK_K,
+443                                        MASK=False,
+444                                        q_seq_len=q_seq_len,
+445                                        kv_seq_len=kv_seq_len
+446                                        )
@@ -826,24 +825,21 @@ - +

Diagonal block with masking within it

+
-
416@triton.jit
-417def _attn_fwd_inner(b_o, b_l, b_m, b_q,
-418                    p_kT, p_v,
-419                    sm_scale_log2e,
-420                    BLOCK_Q: tl.constexpr,
-421                    d_head: tl.constexpr,
-422                    BLOCK_K: tl.constexpr,
-423                    offs_i, offs_j,
-424                    j,
-425                    steps,
-426                    MASK: tl.constexpr,
-427                    q_seq_len: tl.constexpr,
-428                    kv_seq_len: tl.constexpr
-429                    ):
-430    tl.static_assert(BLOCK_Q % BLOCK_K == 0)
+
448        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
+449                                        sm_scale_log2e,
+450                                        BLOCK_Q, d_head, BLOCK_K,
+451                                        offs_i, offs_j,
+452                                        j=i * BLOCK_Q,
+453                                        steps=BLOCK_Q // BLOCK_K,
+454                                        MASK=True,
+455                                        q_seq_len=q_seq_len,
+456                                        kv_seq_len=kv_seq_len
+457                                        )
+458    else:
@@ -851,12 +847,20 @@ -

Move and pointers

+

Iterate through all

-
433    p_kT = tl.advance(p_kT, (0, j))
-434    p_v = tl.advance(p_v, (j, 0))
+
460        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
+461                                        sm_scale_log2e,
+462                                        BLOCK_Q, d_head, BLOCK_K,
+463                                        offs_i, offs_j,
+464                                        j=tl.full([], 0, tl.int32),  # type: ignore
+465                                        steps=tl.cdiv(kv_seq_len, BLOCK_K),
+466                                        MASK=False,
+467                                        q_seq_len=q_seq_len,
+468                                        kv_seq_len=kv_seq_len
+469                                        )
@@ -864,11 +868,11 @@ -

Iterate over , and update and

+

Store LSE

-
437    for _ in range(steps):
+
472    tl.store(p_lse, b_m + tl.math.log2(b_l), boundary_check=(0,))
@@ -876,24 +880,40 @@ -

Load

+

Store

-
439        b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
+
474    tl.store(p_o, (b_o / b_l[:, None]).to(t_o.type.element_ty), boundary_check=(0,))
-
+
-

Compute

+

Inner loop to calculate

+

This iterates through keys and values starting from j + for steps + number of steps. In each step it processes BLOCK_K + entries of keys/values.

-
441        b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
-442        b_s = b_s * sm_scale_log2e
+
477@triton.jit
+478def _attn_fwd_inner(b_o, b_l, b_m, b_q,
+479                    p_kT, p_v,
+480                    sm_scale_log2e,
+481                    BLOCK_Q: tl.constexpr,
+482                    d_head: tl.constexpr,
+483                    BLOCK_K: tl.constexpr,
+484                    offs_i, offs_j,
+485                    j,
+486                    steps,
+487                    MASK: tl.constexpr,
+488                    q_seq_len: tl.constexpr,
+489                    kv_seq_len: tl.constexpr
+490                    ):
@@ -901,13 +921,10 @@ -

Apply causal mask

- +
-
445        if MASK:
-446            causal_mask = offs_i[:, None] >= (j + offs_j[None, :])
-447            b_s = tl.where(causal_mask, b_s, -float("inf"))
+
497    tl.static_assert(BLOCK_Q % BLOCK_K == 0)
@@ -915,12 +932,12 @@ -

Mask out if the block is beyond the end of

+

Move and pointers

-
450        j_mask = (j + offs_j) < kv_seq_len
-451        b_s = tl.where(j_mask[None, :], b_s, -float("inf"))
+
500    p_kT = tl.advance(p_kT, (0, j))
+501    p_v = tl.advance(p_v, (j, 0))
@@ -928,11 +945,11 @@ -

+

Iterate over , and update and

-
454        b_m_new = tl.maximum(b_m, tl.max(b_s, -1))
+
504    for _ in range(steps):
@@ -940,11 +957,11 @@ -

+

Load

-
460        b_p = tl.math.exp2(b_s - b_m_new[:, None])
+
506        b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
@@ -952,11 +969,12 @@ -

+

Compute

-
463        b_l_new = tl.sum(b_p, -1)
+
508        b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
+509        b_s = b_s * sm_scale_log2e
@@ -964,11 +982,13 @@ -

+

Apply causal mask

-
465        b_m_m_new = tl.math.exp2(b_m - b_m_new)
+
512        if MASK:
+513            causal_mask = offs_i[:, None] >= (j + offs_j[None, :])
+514            b_s = tl.where(causal_mask, b_s, -float("inf"))
@@ -976,11 +996,12 @@ -

+

Mask out if the block is beyond the end of

-
467        b_l = b_l * b_m_m_new + b_l_new
+
517        j_mask = (j + offs_j) < kv_seq_len
+518        b_s = tl.where(j_mask[None, :], b_s, -float("inf"))
@@ -988,14 +1009,11 @@ -

+

-
470        b_o = b_o * b_m_m_new[:, None]
-471        b_p = b_p.to(b_q.dtype)  # TODO
-472        b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
-473        b_o += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)
+
521        b_m_new = tl.maximum(b_m, tl.max(b_s, -1))
@@ -1003,11 +1021,11 @@ -

+

-
476        b_m = b_m_new
+
527        b_p = tl.math.exp2(b_s - b_m_new[:, None])
@@ -1015,17 +1033,11 @@ -

Move pointers

+

-
479        j += BLOCK_K
-480        p_v = tl.advance(p_v, (BLOCK_K, 0))
-481        p_kT = tl.advance(p_kT, (0, BLOCK_K))
-482
-483    tl.static_assert(b_o.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
-484
-485    return b_o, b_l, b_m
+
530        b_l_new = tl.sum(b_p, -1)
@@ -1033,18 +1045,11 @@ - +

+
-
488@triton.jit
-489def _attn_bwd_d(t_o, t_do,
-490                t_pdp,
-491                BLOCK_Q: tl.constexpr, d_head: tl.constexpr,
-492                q_seq_len: tl.constexpr,
-493                n_groups: tl.constexpr,
-494                ):
-495    i = tl.program_id(0) * BLOCK_Q
-496    z = tl.program_id(1)
+
532        b_m_m_new = tl.math.exp2(b_m - b_m_new)
@@ -1052,28 +1057,11 @@ -

Create block pointers

+

-
499    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
-500                            (n_groups, q_seq_len, d_head),
-501                            (q_seq_len * d_head, d_head, 1),
-502                            (0, i, 0),
-503                            (n_groups, BLOCK_Q, d_head),
-504                            (2, 1, 0))
-505    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
-506                             (n_groups, q_seq_len, d_head),
-507                             (q_seq_len * d_head, d_head, 1),
-508                             (0, i, 0),
-509                             (n_groups, BLOCK_Q, d_head),
-510                             (2, 1, 0))
-511    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
-512                              (n_groups, q_seq_len),
-513                              (q_seq_len, 1),
-514                              (0, i),
-515                              (n_groups, BLOCK_Q),
-516                              (1, 0))
+
534        b_l = b_l * b_m_m_new + b_l_new
@@ -1081,11 +1069,14 @@ -

Load

+

-
519    o = tl.load(p_o, boundary_check=(1,), padding_option="zero")
+
537        b_o = b_o * b_m_m_new[:, None]
+538        b_p = b_p.to(b_q.dtype)  # TODO
+539        b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
+540        b_o += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)
@@ -1093,11 +1084,11 @@ -

Load

+

-
521    do = tl.load(p_do, boundary_check=(1,), padding_option="zero").to(HI_PRES_TL)
+
543        b_m = b_m_new
@@ -1105,47 +1096,47 @@ -

Calculate

+

Move pointers

-
523    d = tl.sum(o * do, axis=-1)
+
546        j += BLOCK_K
+547        p_v = tl.advance(p_v, (BLOCK_K, 0))
+548        p_kT = tl.advance(p_kT, (0, BLOCK_K))
+549
+550    tl.static_assert(b_o.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
+551
+552    return b_o, b_l, b_m
-
+
-

Save

+

Triton kernel to compute

-
525    tl.store(p_pdp, d, boundary_check=(1,))
+
555@triton.jit
+556def _attn_bwd_d(t_o, t_do,
+557                t_pdp,
+558                BLOCK_Q: tl.constexpr, d_head: tl.constexpr,
+559                q_seq_len: tl.constexpr,
+560                n_groups: tl.constexpr,
+561                ):
-
+
-

Compute and for by iterating over

- +
-
528@triton.autotune(_get_autotune_configs(inner_loop='query'),
-529                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
-530@triton.jit
-531def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
-532                   t_do,
-533                   t_dk, t_dv,
-534                   t_lse, t_pdp,
-535                   q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
-536                   n_groups: tl.constexpr, d_head: tl.constexpr,
-537                   is_causal: tl.constexpr,
-538                   BLOCK_Q: tl.constexpr,
-539                   BLOCK_K: tl.constexpr,
-540                   ):
+
565    i = tl.program_id(0) * BLOCK_Q
+566    z = tl.program_id(1)
@@ -1153,11 +1144,28 @@ - +

Create block pointers

+
-
545    j = tl.program_id(0) * BLOCK_K
-546    z = tl.program_id(1)
+
569    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
+570                            (n_groups, q_seq_len, d_head),
+571                            (q_seq_len * d_head, d_head, 1),
+572                            (0, i, 0),
+573                            (n_groups, BLOCK_Q, d_head),
+574                            (2, 1, 0))
+575    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
+576                             (n_groups, q_seq_len, d_head),
+577                             (q_seq_len * d_head, d_head, 1),
+578                             (0, i, 0),
+579                             (n_groups, BLOCK_Q, d_head),
+580                             (2, 1, 0))
+581    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
+582                              (n_groups, q_seq_len),
+583                              (q_seq_len, 1),
+584                              (0, i),
+585                              (n_groups, BLOCK_Q),
+586                              (1, 0))
@@ -1165,34 +1173,11 @@ -

Create block pointers

+

Load

-
549    p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
-550                            (kv_seq_len, d_head),
-551                            (d_head, 1),
-552                            (j, 0),
-553                            (BLOCK_K, d_head),
-554                            (1, 0))
-555    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
-556                            (kv_seq_len, d_head),
-557                            (d_head, 1),
-558                            (j, 0),
-559                            (BLOCK_K, d_head),
-560                            (1, 0))
-561    p_dk = tl.make_block_ptr(t_dk + z * kv_seq_len * d_head,
-562                             (kv_seq_len, d_head),
-563                             (d_head, 1),
-564                             (j, 0),
-565                             (BLOCK_K, d_head),
-566                             (1, 0))
-567    p_dv = tl.make_block_ptr(t_dv + z * kv_seq_len * d_head,
-568                             (kv_seq_len, d_head),
-569                             (d_head, 1),
-570                             (j, 0),
-571                             (BLOCK_K, d_head),
-572                             (1, 0))
+
589    o = tl.load(p_o, boundary_check=(1,), padding_option="zero")
@@ -1200,12 +1185,11 @@ -

Initialize and

+

Load

-
575    b_dk = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
-576    b_dv = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
+
591    do = tl.load(p_do, boundary_check=(1,), padding_option="zero").to(HI_PRES_TL)
@@ -1213,12 +1197,11 @@ -

Load and outside the loop.

+

Calculate

-
579    b_k = tl.load(p_k, boundary_check=(0,), padding_option="zero")
-580    b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
+
593    d = tl.sum(o * do, axis=-1)
@@ -1226,49 +1209,35 @@ -

Iterate through queries in GQA

+

Save

-
583    for g in range(n_groups):
+
595    tl.store(p_pdp, d, boundary_check=(1,))
-
+
-

Create block pointers

+

Triton kernel to compute and

-
585        p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-586                                 (d_head, q_seq_len),
-587                                 (1, d_head),
-588                                 (0, 0),
-589                                 (d_head, BLOCK_Q),
-590                                 (0, 1))
-591
-592        p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-593                                 (q_seq_len, d_head),
-594                                 (d_head, 1),
-595                                 (0, 0),
-596                                 (BLOCK_Q, d_head),
-597                                 (1, 0))
-598        p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
-599                                  (q_seq_len,),
-600                                  (1,),
-601                                  (0,),
-602                                  (BLOCK_Q,),
-603                                  (0,))
-604        p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
-605                                  (q_seq_len,),
-606                                  (1,),
-607                                  (0,),
-608                                  (BLOCK_Q,),
-609                                  (0,))
-610
-611        if is_causal:
+
598@triton.autotune(_get_autotune_configs(inner_loop='query'),
+599                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
+600@triton.jit
+601def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
+602                   t_do,
+603                   t_dk, t_dv,
+604                   t_lse, t_pdp,
+605                   q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
+606                   n_groups: tl.constexpr, d_head: tl.constexpr,
+607                   is_causal: tl.constexpr,
+608                   BLOCK_Q: tl.constexpr,
+609                   BLOCK_K: tl.constexpr,
+610                   ):
@@ -1276,22 +1245,14 @@ -

Inner loop at the diagonal block

+

Compute and for j + ... j + BLOCK_K + by iterating over

-
613            b_dk, b_dv = _attn_bwd_dkdv_inner(
-614                b_dk, b_dv,
-615                p_qT, b_k, b_v, p_do,
-616                p_lse, p_pdp,
-617                BLOCK_Q, BLOCK_K,
-618                d_head,
-619                j=j, i=j,
-620                steps=BLOCK_K // BLOCK_Q,
-621                MASK=True,
-622                q_seq_len=q_seq_len,
-623                kv_seq_len=kv_seq_len,
-624            )
+
616    j = tl.program_id(0) * BLOCK_K
+617    z = tl.program_id(1)
@@ -1299,23 +1260,34 @@ -

Innerloop on queries after the diagonal

+

Create block pointers

-
627            b_dk, b_dv = _attn_bwd_dkdv_inner(
-628                b_dk, b_dv,
-629                p_qT, b_k, b_v, p_do,
-630                p_lse, p_pdp,
-631                BLOCK_Q, BLOCK_K,
-632                d_head,
-633                j=j, i=j + BLOCK_K,
-634                steps=tl.cdiv((q_seq_len - (j + BLOCK_K)), BLOCK_Q),
-635                MASK=False,
-636                q_seq_len=q_seq_len,
-637                kv_seq_len=kv_seq_len
-638            )
-639        else:
+
620    p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
+621                            (kv_seq_len, d_head),
+622                            (d_head, 1),
+623                            (j, 0),
+624                            (BLOCK_K, d_head),
+625                            (1, 0))
+626    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
+627                            (kv_seq_len, d_head),
+628                            (d_head, 1),
+629                            (j, 0),
+630                            (BLOCK_K, d_head),
+631                            (1, 0))
+632    p_dk = tl.make_block_ptr(t_dk + z * kv_seq_len * d_head,
+633                             (kv_seq_len, d_head),
+634                             (d_head, 1),
+635                             (j, 0),
+636                             (BLOCK_K, d_head),
+637                             (1, 0))
+638    p_dv = tl.make_block_ptr(t_dv + z * kv_seq_len * d_head,
+639                             (kv_seq_len, d_head),
+640                             (d_head, 1),
+641                             (j, 0),
+642                             (BLOCK_K, d_head),
+643                             (1, 0))
@@ -1323,22 +1295,12 @@ -

Iterate through all queries

+

Initialize and

-
641            b_dk, b_dv = _attn_bwd_dkdv_inner(
-642                b_dk, b_dv,
-643                p_qT, b_k, b_v, p_do,
-644                p_lse, p_pdp,
-645                BLOCK_Q, BLOCK_K,
-646                d_head,
-647                j=j, i=tl.full([], 0, tl.int32),
-648                steps=tl.cdiv(q_seq_len, BLOCK_Q),
-649                MASK=False,
-650                q_seq_len=q_seq_len,
-651                kv_seq_len=kv_seq_len
-652            )
+
646    b_dk = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
+647    b_dv = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
@@ -1346,11 +1308,12 @@ -

Save

+

Load and outside the loop.

-
655    tl.store(p_dv, b_dv.to(t_dv.type.element_ty), boundary_check=(0,))
+
650    b_k = tl.load(p_k, boundary_check=(0,), padding_option="zero")
+651    b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
@@ -1358,12 +1321,11 @@ -

b_dk - had

+

Iterate through queries in GQA

-
658    b_dk *= sm_scale
+
654    for g in range(n_groups):
@@ -1371,32 +1333,60 @@ -

Save

+

Create block pointers

-
661    tl.store(p_dk, b_dk.to(t_dk.type.element_ty), boundary_check=(0,))
+
656        p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+657                                 (d_head, q_seq_len),
+658                                 (1, d_head),
+659                                 (0, 0),
+660                                 (d_head, BLOCK_Q),
+661                                 (0, 1))
+662
+663        p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+664                                 (q_seq_len, d_head),
+665                                 (d_head, 1),
+666                                 (0, 0),
+667                                 (BLOCK_Q, d_head),
+668                                 (1, 0))
+669        p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
+670                                  (q_seq_len,),
+671                                  (1,),
+672                                  (0,),
+673                                  (BLOCK_Q,),
+674                                  (0,))
+675        p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
+676                                  (q_seq_len,),
+677                                  (1,),
+678                                  (0,),
+679                                  (BLOCK_Q,),
+680                                  (0,))
+681
+682        if is_causal:
-
+
-

Inner loop along query

+

Inner loop at the diagonal block

-
664@triton.jit
-665def _attn_bwd_dkdv_inner(b_dk, b_dv,
-666                         p_qT, b_k, b_v, p_do,
-667                         p_lse, p_pdp,
-668                         BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
-669                         d_head: tl.constexpr,
-670                         j, i, steps,
-671                         MASK: tl.constexpr,
-672                         q_seq_len: tl.constexpr,
-673                         kv_seq_len: tl.constexpr):
+
684            b_dk, b_dv = _attn_bwd_dkdv_inner(
+685                b_dk, b_dv,
+686                p_qT, b_k, b_v, p_do,
+687                p_lse, p_pdp,
+688                BLOCK_Q, BLOCK_K,
+689                d_head,
+690                j=j, i=j,
+691                steps=BLOCK_K // BLOCK_Q,
+692                MASK=True,
+693                q_seq_len=q_seq_len,
+694                kv_seq_len=kv_seq_len,
+695            )
@@ -1404,11 +1394,23 @@ -

To apply the mask

+

Inner loop on queries after the diagonal

-
677    tl.static_assert(BLOCK_K % BLOCK_Q == 0)
+
698            b_dk, b_dv = _attn_bwd_dkdv_inner(
+699                b_dk, b_dv,
+700                p_qT, b_k, b_v, p_do,
+701                p_lse, p_pdp,
+702                BLOCK_Q, BLOCK_K,
+703                d_head,
+704                j=j, i=j + BLOCK_K,
+705                steps=tl.cdiv((q_seq_len - (j + BLOCK_K)), BLOCK_Q),
+706                MASK=False,
+707                q_seq_len=q_seq_len,
+708                kv_seq_len=kv_seq_len
+709            )
+710        else:
@@ -1416,12 +1418,22 @@ -

Offsets and mask

+

Iterate through all queries

-
680    offs_i = i + tl.arange(0, BLOCK_Q)
-681    offs_j = j + tl.arange(0, BLOCK_K)
+
712            b_dk, b_dv = _attn_bwd_dkdv_inner(
+713                b_dk, b_dv,
+714                p_qT, b_k, b_v, p_do,
+715                p_lse, p_pdp,
+716                BLOCK_Q, BLOCK_K,
+717                d_head,
+718                j=j, i=tl.full([], 0, tl.int32),
+719                steps=tl.cdiv(q_seq_len, BLOCK_Q),
+720                MASK=False,
+721                q_seq_len=q_seq_len,
+722                kv_seq_len=kv_seq_len
+723            )
@@ -1429,14 +1441,11 @@ -

Move the pointers

+

Save

-
684    p_qT = tl.advance(p_qT, (0, i))
-685    p_do = tl.advance(p_do, (i, 0))
-686    p_lse = tl.advance(p_lse, (i,))
-687    p_pdp = tl.advance(p_pdp, (i,))
+
726    tl.store(p_dv, b_dv.to(t_dv.type.element_ty), boundary_check=(0,))
@@ -1444,11 +1453,12 @@ -

Iterate over

+

b_dk + had

-
690    for _ in range(steps):
+
729    b_dk *= sm_scale
@@ -1456,23 +1466,32 @@ -

Load

+

Save

-
692        b_qT = tl.load(p_qT, boundary_check=(1,), padding_option="zero")
+
732    tl.store(p_dk, b_dk.to(t_dk.type.element_ty), boundary_check=(0,))
-
+
-

+

Inner loop to calculate ,

-
695        b_l = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
+
735@triton.jit
+736def _attn_bwd_dkdv_inner(b_dk, b_dv,
+737                         p_qT, b_k, b_v, p_do,
+738                         p_lse, p_pdp,
+739                         BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
+740                         d_head: tl.constexpr,
+741                         j, i, steps,
+742                         MASK: tl.constexpr,
+743                         q_seq_len: tl.constexpr,
+744                         kv_seq_len: tl.constexpr):
@@ -1480,11 +1499,11 @@ -

+

To apply the mask

-
698        b_sT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)
+
750    tl.static_assert(BLOCK_K % BLOCK_Q == 0)
@@ -1492,11 +1511,12 @@ -

+

Offsets and mask

-
707        b_pT = tl.math.exp2(b_sT - b_l[None, :])
+
753    offs_i = i + tl.arange(0, BLOCK_Q)
+754    offs_j = j + tl.arange(0, BLOCK_K)
@@ -1504,13 +1524,14 @@ -

Autoregressive masking

+

Move the pointers

-
710        if MASK:
-711            mask = (offs_i[None, :] >= offs_j[:, None])
-712            b_pT = tl.where(mask, b_pT, 0.0)
+
757    p_qT = tl.advance(p_qT, (0, i))
+758    p_do = tl.advance(p_do, (i, 0))
+759    p_lse = tl.advance(p_lse, (i,))
+760    p_pdp = tl.advance(p_pdp, (i,))
@@ -1518,13 +1539,11 @@ -

Mask out if the block is beyond the end of

-

Note: No need to mask out based on because the effects on positions outside boundary will not get stored in or Masking by may also not be necessary size the tensors have 0 on loading

+

Iterate over

-
719        i_mask = offs_i < q_seq_len
-720        b_pT = tl.where(i_mask[None, :], b_pT, 0.0)
+
763    for _ in range(steps):
@@ -1532,12 +1551,11 @@ -

+

Load

-
723        b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
-724        b_dv += tl.dot(b_pT.to(b_do.dtype), b_do, out_dtype=HI_PRES_TL)
+
765        b_qT = tl.load(p_qT, boundary_check=(1,), padding_option="zero")
@@ -1545,11 +1563,11 @@ -

+

-
727        b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
+
768        b_l = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
@@ -1557,11 +1575,11 @@ -

+

-
729        b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)
+
771        b_sT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)
@@ -1569,11 +1587,11 @@ -

+

-
731        b_dsT = b_pT * (b_dpT - b_pdp[None, :])
+
780        b_pT = tl.math.exp2(b_sT - b_l[None, :])
@@ -1581,11 +1599,13 @@ -

+

Autoregressive masking

-
733        b_dk += tl.dot(b_dsT.to(b_qT.dtype), tl.trans(b_qT), out_dtype=HI_PRES_TL)
+
783        if MASK:
+784            mask = (offs_i[None, :] >= offs_j[:, None])
+785            b_pT = tl.where(mask, b_pT, 0.0)
@@ -1593,15 +1613,13 @@ -

Increment pointers.

+

Mask out if the block is beyond the end of

+

Note: No need to mask out based on because the effects on positions outside boundary will not get stored in or Masking by may also not be necessary size the tensors have 0 on loading

-
736        offs_i += BLOCK_Q
-737        p_lse = tl.advance(p_lse, (BLOCK_Q,))
-738        p_pdp = tl.advance(p_pdp, (BLOCK_Q,))
-739        p_qT = tl.advance(p_qT, (0, BLOCK_Q))
-740        p_do = tl.advance(p_do, (BLOCK_Q, 0))
+
792        i_mask = offs_i < q_seq_len
+793        b_pT = tl.where(i_mask[None, :], b_pT, 0.0)
@@ -1609,11 +1627,12 @@ -

Return accumulated and

+

-
743    return b_dk, b_dv
+
796        b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
+797        b_dv += tl.dot(b_pT.to(b_do.dtype), b_do, out_dtype=HI_PRES_TL)
@@ -1621,24 +1640,11 @@ - +

+
-
746@triton.autotune(_get_autotune_configs(inner_loop='key'),
-747                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
-748@triton.jit
-749def _attn_bwd_dq(t_q, t_k, t_v, t_do,
-750                 t_dq,
-751                 t_lse, t_pdp,
-752                 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
-753                 n_groups: tl.constexpr, d_head: tl.constexpr,
-754                 is_causal: tl.constexpr,
-755                 BLOCK_Q: tl.constexpr,
-756                 BLOCK_K: tl.constexpr,
-757                 ):
-758    i = tl.program_id(0) * BLOCK_Q
-759    z = tl.program_id(1) // n_groups
-760    g = tl.program_id(1) % n_groups  # TODO
+
800        b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
@@ -1646,52 +1652,11 @@ -

Create block pointers

+

-
763    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-764                            (q_seq_len, d_head),
-765                            (d_head, 1),
-766                            (i, 0),
-767                            (BLOCK_Q, d_head),
-768                            (1, 0))
-769    p_dq = tl.make_block_ptr(t_dq + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-770                             (q_seq_len, d_head),
-771                             (d_head, 1),
-772                             (i, 0),
-773                             (BLOCK_Q, d_head),
-774                             (1, 0))
-775    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-776                             (q_seq_len, d_head),
-777                             (d_head, 1),
-778                             (i, 0),
-779                             (BLOCK_Q, d_head),
-780                             (1, 0))
-781    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
-782                             (d_head, kv_seq_len),
-783                             (1, d_head),
-784                             (0, 0),
-785                             (d_head, BLOCK_K),
-786                             (0, 1))
-787    p_vT = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
-788                             (d_head, kv_seq_len),
-789                             (1, d_head),
-790                             (0, 0),
-791                             (d_head, BLOCK_K),
-792                             (0, 1))
-793    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
-794                              (q_seq_len,),
-795                              (1,),
-796                              (i,),
-797                              (BLOCK_Q,),
-798                              (0,))
-799    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
-800                              (q_seq_len,),
-801                              (1,),
-802                              (i,),
-803                              (BLOCK_Q,),
-804                              (0,))
+
802        b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)
@@ -1699,14 +1664,11 @@ -

Load , , , and outside the loop

+

-
807    b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
-808    b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
-809    b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
-810    b_lse = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
+
804        b_dsT = b_pT * (b_dpT - b_pdp[None, :])
@@ -1714,11 +1676,11 @@ -

Initialize

+

-
813    b_dq = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
+
806        b_dk += tl.dot(b_dsT.to(b_qT.dtype), tl.trans(b_qT), out_dtype=HI_PRES_TL)
@@ -1726,11 +1688,15 @@ -

+

Increment pointers.

-
817    if is_causal:
+
809        offs_i += BLOCK_Q
+810        p_lse = tl.advance(p_lse, (BLOCK_Q,))
+811        p_pdp = tl.advance(p_pdp, (BLOCK_Q,))
+812        p_qT = tl.advance(p_qT, (0, BLOCK_Q))
+813        p_do = tl.advance(p_do, (BLOCK_Q, 0))
@@ -1738,40 +1704,34 @@ -

Compute for masked (diagonal) blocks.

+

Return accumulated and

-
819        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
-820                                  b_do, b_lse, b_pdp,
-821                                  BLOCK_Q, BLOCK_K,
-822                                  i=i, j=i,
-823                                  steps=BLOCK_Q // BLOCK_K,
-824                                  MASK=True,
-825                                  q_seq_len=q_seq_len,
-826                                  kv_seq_len=kv_seq_len
-827                                  )
+
816    return b_dk, b_dv
-
+
-

Compute for other blocks

+

Triton kernel to compute

-
830        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
-831                                  b_do, b_lse, b_pdp,
-832                                  BLOCK_Q, BLOCK_K,
-833                                  i=i, j=tl.full([], 0, tl.int32),  # type: ignore
-834                                  steps=i // BLOCK_K,
-835                                  MASK=False,
-836                                  q_seq_len=q_seq_len,
-837                                  kv_seq_len=kv_seq_len
-838                                  )
-839    else:
+
819@triton.autotune(_get_autotune_configs(inner_loop='key'),
+820                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
+821@triton.jit
+822def _attn_bwd_dq(t_q, t_k, t_v, t_do,
+823                 t_dq,
+824                 t_lse, t_pdp,
+825                 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
+826                 n_groups: tl.constexpr, d_head: tl.constexpr,
+827                 is_causal: tl.constexpr,
+828                 BLOCK_Q: tl.constexpr,
+829                 BLOCK_K: tl.constexpr,
+830                 ):
@@ -1779,19 +1739,12 @@ -

Iterate through all

- +
-
841        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
-842                                  b_do, b_lse, b_pdp,
-843                                  BLOCK_Q, BLOCK_K,
-844                                  i=i, j=tl.full([], 0, tl.int32),  # type: ignore
-845                                  steps=tl.cdiv(kv_seq_len, BLOCK_K),
-846                                  MASK=False,
-847                                  q_seq_len=q_seq_len,
-848                                  kv_seq_len=kv_seq_len
-849                                  )
+
835    i = tl.program_id(0) * BLOCK_Q
+836    z = tl.program_id(1) // n_groups
+837    g = tl.program_id(1) % n_groups  # TODO
@@ -1799,12 +1752,52 @@ -

b_dq - stores so multiply by to get

+

Create block pointers

-
852    b_dq *= 0.6931471824645996
+
840    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+841                            (q_seq_len, d_head),
+842                            (d_head, 1),
+843                            (i, 0),
+844                            (BLOCK_Q, d_head),
+845                            (1, 0))
+846    p_dq = tl.make_block_ptr(t_dq + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+847                             (q_seq_len, d_head),
+848                             (d_head, 1),
+849                             (i, 0),
+850                             (BLOCK_Q, d_head),
+851                             (1, 0))
+852    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+853                             (q_seq_len, d_head),
+854                             (d_head, 1),
+855                             (i, 0),
+856                             (BLOCK_Q, d_head),
+857                             (1, 0))
+858    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
+859                             (d_head, kv_seq_len),
+860                             (1, d_head),
+861                             (0, 0),
+862                             (d_head, BLOCK_K),
+863                             (0, 1))
+864    p_vT = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
+865                             (d_head, kv_seq_len),
+866                             (1, d_head),
+867                             (0, 0),
+868                             (d_head, BLOCK_K),
+869                             (0, 1))
+870    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
+871                              (q_seq_len,),
+872                              (1,),
+873                              (i,),
+874                              (BLOCK_Q,),
+875                              (0,))
+876    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
+877                              (q_seq_len,),
+878                              (1,),
+879                              (i,),
+880                              (BLOCK_Q,),
+881                              (0,))
@@ -1812,30 +1805,26 @@ -

Save

+

Load , , , and outside the loop

-
855    tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))
+
884    b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
+885    b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
+886    b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
+887    b_lse = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
-
+
-

Inner loop over key

+

Initialize

-
858@triton.jit
-859def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
-860                       b_do, b_lse, b_pdp,
-861                       BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
-862                       i, j, steps,
-863                       MASK: tl.constexpr,
-864                       q_seq_len: tl.constexpr,
-865                       kv_seq_len: tl.constexpr):
+
890    b_dq = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
@@ -1843,12 +1832,11 @@ -

Offsets

+

-
869    offs_i = i + tl.arange(0, BLOCK_Q)
-870    offs_j = j + tl.arange(0, BLOCK_K)
+
894    if is_causal:
@@ -1856,14 +1844,19 @@ -

Move the pointers

+

Compute for masked (diagonal) blocks.

-
873    p_kT = tl.advance(p_kT, (0, j))
-874    p_vT = tl.advance(p_vT, (0, j))
-875
-876    tl.static_assert(BLOCK_Q % BLOCK_K == 0, 'BLOCK_Q must be divisible by BLOCK_K')
+
896        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
+897                                  b_do, b_lse, b_pdp,
+898                                  BLOCK_Q, BLOCK_K,
+899                                  i=i, j=i,
+900                                  steps=BLOCK_Q // BLOCK_K,
+901                                  MASK=True,
+902                                  q_seq_len=q_seq_len,
+903                                  kv_seq_len=kv_seq_len
+904                                  )
@@ -1871,11 +1864,20 @@ -

Iterate over

+

Compute for other blocks

-
879    for _ in range(steps):
+
907        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
+908                                  b_do, b_lse, b_pdp,
+909                                  BLOCK_Q, BLOCK_K,
+910                                  i=i, j=tl.full([], 0, tl.int32),  # type: ignore
+911                                  steps=i // BLOCK_K,
+912                                  MASK=False,
+913                                  q_seq_len=q_seq_len,
+914                                  kv_seq_len=kv_seq_len
+915                                  )
+916    else:
@@ -1883,11 +1885,19 @@ -

Load

+

Iterate through all

-
881        b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
+
918        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
+919                                  b_do, b_lse, b_pdp,
+920                                  BLOCK_Q, BLOCK_K,
+921                                  i=i, j=tl.full([], 0, tl.int32),  # type: ignore
+922                                  steps=tl.cdiv(kv_seq_len, BLOCK_K),
+923                                  MASK=False,
+924                                  q_seq_len=q_seq_len,
+925                                  kv_seq_len=kv_seq_len
+926                                  )
@@ -1895,11 +1905,12 @@ -

Load

+

b_dq + stores so multiply by to get

-
883        b_vT = tl.load(p_vT, boundary_check=(1,), padding_option="zero")
+
929    b_dq *= 0.6931471824645996
@@ -1907,23 +1918,30 @@ -

+

Save

-
886        b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
+
932    tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))
-
+
-

+

Inner loop to calculate

-
895        b_p = tl.math.exp2(b_s - b_lse[:, None])
+
935@triton.jit
+936def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
+937                       b_do, b_lse, b_pdp,
+938                       BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
+939                       i, j, steps,
+940                       MASK: tl.constexpr,
+941                       q_seq_len: tl.constexpr,
+942                       kv_seq_len: tl.constexpr):
@@ -1931,13 +1949,12 @@ -

Autoregressive masking

+

Offsets

-
898        if MASK:
-899            causal_mask = (offs_i[:, None] >= offs_j[None, :])
-900            b_p = tl.where(causal_mask, b_p, 0.0)
+
948    offs_i = i + tl.arange(0, BLOCK_Q)
+949    offs_j = j + tl.arange(0, BLOCK_K)
@@ -1945,12 +1962,14 @@ -

Mask out if the block is beyond the end of

+

Move the pointers

-
903        j_mask = offs_j < kv_seq_len
-904        b_p = tl.where(j_mask[None, :], b_p, 0.0)
+
952    p_kT = tl.advance(p_kT, (0, j))
+953    p_vT = tl.advance(p_vT, (0, j))
+954
+955    tl.static_assert(BLOCK_Q % BLOCK_K == 0, 'BLOCK_Q must be divisible by BLOCK_K')
@@ -1958,11 +1977,11 @@ -

+

Iterate over

-
+
958    for _ in range(steps):
@@ -1970,11 +1989,11 @@ -

+

Load

-
909        b_dp = tl.dot(b_do, b_vT, out_dtype=HI_PRES_TL).to(HI_PRES_TL)
+
960        b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
@@ -1982,11 +2001,11 @@ -

+

Load

-
911        b_ds = b_p * (b_dp - b_pdp[:, None])
+
962        b_vT = tl.load(p_vT, boundary_check=(1,), padding_option="zero")
@@ -1994,11 +2013,11 @@ -

+

-
913        b_dq += tl.dot(b_ds.to(b_kT.dtype), tl.trans(b_kT), out_dtype=HI_PRES_TL)
+
965        b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
@@ -2006,13 +2025,11 @@ -

Increment pointers.

+

-
916        offs_j += BLOCK_K
-917        p_kT = tl.advance(p_kT, (0, BLOCK_K))
-918        p_vT = tl.advance(p_vT, (0, BLOCK_K))
+
974        b_p = tl.math.exp2(b_s - b_lse[:, None])
@@ -2020,11 +2037,100 @@ -

Return accumulated

+

Autoregressive masking

-
921    return b_dq
+
977        if MASK:
+978            causal_mask = (offs_i[:, None] >= offs_j[None, :])
+979            b_p = tl.where(causal_mask, b_p, 0.0)
+
+ +
+
+ +

Mask out if the block is beyond the end of

+ +
+
+
982        j_mask = offs_j < kv_seq_len
+983        b_p = tl.where(j_mask[None, :], b_p, 0.0)
+
+
+
+
+ +

+ +
+
+
+
+
+
+
+ +

+ +
+
+
988        b_dp = tl.dot(b_do, b_vT, out_dtype=HI_PRES_TL).to(HI_PRES_TL)
+
+
+
+
+ +

+ +
+
+
990        b_ds = b_p * (b_dp - b_pdp[:, None])
+
+
+
+
+ +

+ +
+
+
992        b_dq += tl.dot(b_ds.to(b_kT.dtype), tl.trans(b_kT), out_dtype=HI_PRES_TL)
+
+
+
+
+ +

Increment pointers.

+ +
+
+
995        offs_j += BLOCK_K
+996        p_kT = tl.advance(p_kT, (0, BLOCK_K))
+997        p_vT = tl.advance(p_vT, (0, BLOCK_K))
+
+
+
+
+ +

Return accumulated

+ +
+
+
1000    return b_dq
-
+
- +

Test Flash Attention Implementation

+

This is the code to test and measure performance of our flash attention implementation

+
-
1import triton
-2
-3import torch
-4from labml import logger, monit
-5from labml_nn.transformers.flash import attention
-6
-7HI_PRES_TORCH = torch.float32
+
7import torch
+8import triton
+9
+10from labml import logger, monit
+11from labml_nn.transformers.flash import attention
+12
+13HI_PRES_TORCH = torch.float32
-
+
- +

Calculate absolute and relative error for reporting

+
-
10@torch.no_grad()
-11def _calc_abs_rel_error(a: torch.Tensor, b: torch.Tensor, atol=1e-2):
-12    d = (a - b).abs()
-13    max_abs = d.max()
-14    d = (d - atol).clamp(min=0)
-15    d = d / b.abs()
-16    max_rel = d.max()
-17
-18    return max_abs.cpu().item(), max_rel.cpu().item()
-19
-20
-21def _test_op(batch_size, n_heads, k_heads, q_seq_len, kv_seq_len, d_head, causal, dtype, device):
-22    with monit.section(f'Init {q_seq_len} {kv_seq_len} {d_head}'):
-23        torch.manual_seed(20)
-24        q = (torch.empty((batch_size, n_heads, q_seq_len, d_head),
-25                         dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
-26        k = (torch.empty((batch_size, k_heads, kv_seq_len, d_head),
-27                         dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
-28        v = (torch.empty((batch_size, k_heads, kv_seq_len, d_head),
-29                         dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
-30        sm_scale = d_head ** -0.5
-31        d_out = torch.randn_like(q)
+
16@torch.no_grad()
+17def _calc_abs_rel_error(a: torch.Tensor, b: torch.Tensor, atol=1e-2):
@@ -120,76 +103,28 @@ -

reference implementation

- +
-
33        mask = torch.tril(torch.ones((q_seq_len, kv_seq_len), device=device, dtype=torch.bool))
-34        torch.cuda.synchronize()
-35
-36    with monit.section('Pytorch'):
-37        p = torch.matmul(q.view(batch_size, k_heads, -1, q_seq_len, d_head),
-38                         k.transpose(2, 3)[:, :, None, :, :]) * sm_scale
-39        if causal:
-40            p[:, :, :, ~mask] = float("-inf")
-41        p = torch.softmax(p.to(HI_PRES_TORCH), dim=-1).to(dtype)
-42        ref_out = torch.matmul(p, v[:, :, None, :, :])
-43        ref_out = ref_out.view(q.shape)
-44        ref_out.backward(d_out)
-45        ref_dv, v.grad = v.grad.clone(), None
-46        ref_dk, k.grad = k.grad.clone(), None
-47        ref_dq, q.grad = q.grad.clone(), None
-48        torch.cuda.synchronize()
-49
-50    with monit.section('Triton'):
-51        assert q.dtype == dtype
-52        tri_out = attention(q, k, v, causal, sm_scale).to(dtype)
-53        monit.progress(0.5)
-54
-55        tri_out.backward(d_out)
-56        monit.progress(0.9)
-57        tri_dv, v.grad = v.grad.clone(), None  # type: ignore
-58        tri_dk, k.grad = k.grad.clone(), None  # type: ignore
-59        tri_dq, q.grad = q.grad.clone(), None  # type: ignore
-60        torch.cuda.synchronize()
-61
-62    with monit.section('Test') as s:
+
21    d = (a - b).abs()
+22    max_abs = d.max()
+23    d = (d - atol).clamp(min=0)
+24    d = d / b.abs()
+25    max_rel = d.max()
+26
+27    return max_abs.cpu().item(), max_rel.cpu().item()
-
+
-

compare

+

Compare our implementation with naive PyTorch attention

-
64        passed = True
-65        if not torch.allclose(tri_out, ref_out, atol=1e-2, rtol=0.):
-66            abs_err, rel_err = _calc_abs_rel_error(ref_out, tri_out)
-67            logger.log(('[FAILED]', logger.Text.danger), f' Out mismatch {abs_err} {rel_err}')
-68            passed = False
-69        rtol = 1e-1
-70        if not torch.allclose(tri_dq, ref_dq, atol=1e-2, rtol=rtol):
-71            abs_err, rel_err = _calc_abs_rel_error(ref_dq, tri_dq)
-72            logger.log(('[FAILED]', logger.Text.danger), f' dQ mismatch {abs_err} {rel_err}')
-73            passed = False
-74        if not torch.allclose(tri_dv, ref_dv, atol=1e-2, rtol=rtol):
-75            abs_err, rel_err = _calc_abs_rel_error(ref_dv, tri_dv)
-76            logger.log(('[FAILED]', logger.Text.danger), f' dV mismatch {abs_err} {rel_err}')
-77            passed = False
-78        if not torch.allclose(tri_dk, ref_dk, atol=1e-2, rtol=rtol):
-79            abs_err, rel_err = _calc_abs_rel_error(ref_dk, tri_dk)
-80            logger.log(('[FAILED]', logger.Text.danger), f' dK mismatch {abs_err} {rel_err}')
-81            passed = False
-82
-83        if passed:
-84            logger.log('[PASSED]', logger.Text.success)
-85            s.success = True
-86        else:
-87            s.success = False
-88        torch.cuda.synchronize()
+
30def test_fwd_bwd(batch_size, n_heads, k_heads, q_seq_len, kv_seq_len, d_head, causal, dtype, device):
@@ -200,12 +135,16 @@
-
91def _perf_triton_fn(*, device, dtype, batch_size, k_heads, n_groups, seq_len, d_head, causal):
-92    q = torch.randn((batch_size, k_heads * n_groups, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
-93    k = torch.randn((batch_size, k_heads, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
-94    v = torch.randn((batch_size, k_heads, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
-95    sm_scale = d_head ** -0.5
-96    return lambda: attention(q, k, v, causal, sm_scale)
+
35    with monit.section(f'Init {q_seq_len} {kv_seq_len} {d_head}'):
+36        torch.manual_seed(20)
+37        q = (torch.empty((batch_size, n_heads, q_seq_len, d_head),
+38                         dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
+39        k = (torch.empty((batch_size, k_heads, kv_seq_len, d_head),
+40                         dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
+41        v = (torch.empty((batch_size, k_heads, kv_seq_len, d_head),
+42                         dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
+43        sm_scale = d_head ** -0.5
+44        d_out = torch.randn_like(q)
@@ -213,15 +152,40 @@ - +

reference implementation

+
-
99def _perf_flash(*, batch_size, k_heads, n_groups, seq_len, d_head, causal, device, dtype):
-100    q = torch.randn((batch_size, seq_len, k_heads * n_groups, d_head), dtype=dtype, device=device, requires_grad=True)
-101    k = torch.randn((batch_size, seq_len, k_heads, d_head), dtype=dtype, device=device, requires_grad=True)
-102    v = torch.randn((batch_size, seq_len, k_heads, d_head), dtype=dtype, device=device, requires_grad=True)
-103    from flash_attn import flash_attn_func
-104    return lambda: flash_attn_func(q, k, v, causal=causal)
+
46        mask = torch.tril(torch.ones((q_seq_len, kv_seq_len), device=device, dtype=torch.bool))
+47        torch.cuda.synchronize()
+48
+49    with monit.section('Pytorch'):
+50        p = torch.matmul(q.view(batch_size, k_heads, -1, q_seq_len, d_head),
+51                         k.transpose(2, 3)[:, :, None, :, :]) * sm_scale
+52        if causal:
+53            p[:, :, :, ~mask] = float("-inf")
+54        p = torch.softmax(p.to(HI_PRES_TORCH), dim=-1).to(dtype)
+55        ref_out = torch.matmul(p, v[:, :, None, :, :])
+56        ref_out = ref_out.view(q.shape)
+57        ref_out.backward(d_out)
+58        ref_dv, v.grad = v.grad.clone(), None
+59        ref_dk, k.grad = k.grad.clone(), None
+60        ref_dq, q.grad = q.grad.clone(), None
+61        torch.cuda.synchronize()
+62
+63    with monit.section('Triton'):
+64        assert q.dtype == dtype
+65        tri_out = attention(q, k, v, causal, sm_scale).to(dtype)
+66        monit.progress(0.5)
+67
+68        tri_out.backward(d_out)
+69        monit.progress(0.9)
+70        tri_dv, v.grad = v.grad.clone(), None  # type: ignore
+71        tri_dk, k.grad = k.grad.clone(), None  # type: ignore
+72        tri_dq, q.grad = q.grad.clone(), None  # type: ignore
+73        torch.cuda.synchronize()
+74
+75    with monit.section('Test') as s:
@@ -229,40 +193,47 @@ - +

compare

+
-
107def _perf_fn(name, fn, *, batch_size, k_heads, n_groups, seq_len, d_head, causal, is_bwd: bool):
-108    if is_bwd:
-109        o = fn()
-110        do = torch.randn_like(o)
-111        fn = lambda: o.backward(do, retain_graph=True)
-112    ms = triton.testing.do_bench(fn)
-113
-114    flops_per_matmul = 2.0 * batch_size * k_heads * n_groups * seq_len * seq_len * d_head
-115    total_flops = 2 * flops_per_matmul
-116    if causal:
-117        total_flops *= 0.5
-118    if is_bwd:
-119        total_flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
-120
-121    tf_ps = total_flops * 1e-12 / (ms * 1e-3)
-122    logger.log((f'{name}', logger.Text.key), ': ', f'{ms :,.1f}ms', ' ', f'{tf_ps :,.2f}TFps')
+
77        passed = True
+78        if not torch.allclose(tri_out, ref_out, atol=1e-2, rtol=0.):
+79            abs_err, rel_err = _calc_abs_rel_error(ref_out, tri_out)
+80            logger.log(('[FAILED]', logger.Text.danger), f' Out mismatch {abs_err} {rel_err}')
+81            passed = False
+82        rtol = 1e-1
+83        if not torch.allclose(tri_dq, ref_dq, atol=1e-2, rtol=rtol):
+84            abs_err, rel_err = _calc_abs_rel_error(ref_dq, tri_dq)
+85            logger.log(('[FAILED]', logger.Text.danger), f' dQ mismatch {abs_err} {rel_err}')
+86            passed = False
+87        if not torch.allclose(tri_dv, ref_dv, atol=1e-2, rtol=rtol):
+88            abs_err, rel_err = _calc_abs_rel_error(ref_dv, tri_dv)
+89            logger.log(('[FAILED]', logger.Text.danger), f' dV mismatch {abs_err} {rel_err}')
+90            passed = False
+91        if not torch.allclose(tri_dk, ref_dk, atol=1e-2, rtol=rtol):
+92            abs_err, rel_err = _calc_abs_rel_error(ref_dk, tri_dk)
+93            logger.log(('[FAILED]', logger.Text.danger), f' dK mismatch {abs_err} {rel_err}')
+94            passed = False
+95
+96        if passed:
+97            logger.log('[PASSED]', logger.Text.success)
+98            s.success = True
+99        else:
+100            s.success = False
+101        torch.cuda.synchronize()
-
+
- +

Get a partial function to test performance of our implementation

+
-
125def _test():
-126    device = torch.device('cuda:0')
-127    torch.cuda.set_device(device)
-128
-129    dtype = torch.float16
+
104def _perf_triton_fn(*, device, dtype, batch_size, k_heads, n_groups, seq_len, d_head, causal):
@@ -270,36 +241,130 @@ + +
+
+
108    q = torch.randn((batch_size, k_heads * n_groups, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
+109    k = torch.randn((batch_size, k_heads, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
+110    v = torch.randn((batch_size, k_heads, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
+111    sm_scale = d_head ** -0.5
+112    return lambda: attention(q, k, v, causal, sm_scale)
+
+
+
+
+ +

Get a partial function to test performance of original flash implementation

+ +
+
+
115def _perf_flash(*, batch_size, k_heads, n_groups, seq_len, d_head, causal, device, dtype):
+
+
+
+
+ + +
+
+
119    q = torch.randn((batch_size, seq_len, k_heads * n_groups, d_head), dtype=dtype, device=device, requires_grad=True)
+120    k = torch.randn((batch_size, seq_len, k_heads, d_head), dtype=dtype, device=device, requires_grad=True)
+121    v = torch.randn((batch_size, seq_len, k_heads, d_head), dtype=dtype, device=device, requires_grad=True)
+122    from flash_attn import flash_attn_func
+123    return lambda: flash_attn_func(q, k, v, causal=causal)
+
+
+
+
+ +

Measure the speed

+ +
+
+
126def measure_performance(name, fn, *, batch_size, k_heads, n_groups, seq_len, d_head, causal, is_bwd: bool):
+
+
+
+
+ + +
+
+
130    if is_bwd:
+131        o = fn()
+132        do = torch.randn_like(o)
+133        fn = lambda: o.backward(do, retain_graph=True)
+134    ms = triton.testing.do_bench(fn)
+135
+136    flops_per_matmul = 2.0 * batch_size * k_heads * n_groups * seq_len * seq_len * d_head
+137    total_flops = 2 * flops_per_matmul
+138    if causal:
+139        total_flops *= 0.5
+140    if is_bwd:
+141        total_flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
+142
+143    tf_ps = total_flops * 1e-12 / (ms * 1e-3)
+144    logger.log((f'{name}', logger.Text.key), ': ', f'{ms :,.1f}ms', ' ', f'{tf_ps :,.2f}TFps')
+
+
+
+
+ + +
+
+
147def main():
+148    device = torch.device('cuda:0')
+149    torch.cuda.set_device(device)
+150
+151    dtype = torch.float16
+
+
+
+
+

only works on post-Ampere GPUs right now

-
132    _test_op(1, 4, 1, 2048, 2048, 128, True, dtype=dtype, device=device)
-133    _test_op(16, 32, 8, 2001, 4001, 128, False, dtype=dtype, device=device)
-134    _test_op(4, 32, 8, 2048, 1024, 128, False, dtype=dtype, device=device)
-135    _test_op(4, 32, 8, 2001, 4001, 128, True, dtype=dtype, device=device)
-136
-137    _conf = {
-138        'batch_size': 16,
-139        'k_heads': 8,
-140        'n_groups': 4,
-141        'seq_len': 2048,
-142        'd_head': 128,
-143    }
-144
-145    for _causal in [False, True]:
-146        for is_bwd in [False, True]:
-147            logger.log(f'{"Causal" if _causal else "Non-causal"} {" Backward" if is_bwd else ""}', logger.Text.title)
-148            _perf_fn(f'flash', _perf_flash(causal=_causal, device=device, dtype=dtype, **_conf),
-149                     is_bwd=is_bwd,
-150                     causal=_causal, **_conf)
-151            _perf_fn(f'triton', _perf_triton_fn(causal=_causal, device=device, dtype=dtype, **_conf),
-152                     is_bwd=is_bwd,
-153                     causal=_causal, **_conf)
-154
-155
-156if __name__ == "__main__":
-157    _test()
+
154    test_fwd_bwd(1, 4, 1, 2048, 2048, 128, True, dtype=dtype, device=device)
+155    test_fwd_bwd(16, 32, 8, 2001, 4001, 128, False, dtype=dtype, device=device)
+156    test_fwd_bwd(4, 32, 8, 2048, 1024, 128, False, dtype=dtype, device=device)
+157    test_fwd_bwd(4, 32, 8, 2001, 4001, 128, True, dtype=dtype, device=device)
+158
+159    _conf = {
+160        'batch_size': 16,
+161        'k_heads': 8,
+162        'n_groups': 4,
+163        'seq_len': 2048,
+164        'd_head': 128,
+165    }
+166
+167    for _causal in [False, True]:
+168        for is_bwd in [False, True]:
+169            logger.log(f'{"Causal" if _causal else "Non-causal"} {" Backward" if is_bwd else ""}', logger.Text.title)
+170            measure_performance(f'flash', _perf_flash(causal=_causal, device=device, dtype=dtype, **_conf),
+171                                is_bwd=is_bwd,
+172                                causal=_causal, **_conf)
+173            measure_performance(f'triton', _perf_triton_fn(causal=_causal, device=device, dtype=dtype, **_conf),
+174                                is_bwd=is_bwd,
+175                                causal=_causal, **_conf)
+176
+177
+178if __name__ == "__main__":
+179    main()