diff --git a/docs/sitemap.xml b/docs/sitemap.xml
index 59b8be13..a1eb71f3 100644
--- a/docs/sitemap.xml
+++ b/docs/sitemap.xml
@@ -1086,7 +1086,7 @@
You can compute $o_i$, instead of doing the full softmax, by computing the sum of exponents $l_i = \sum_j e^{s_{ij}}$ and the unnormalized output $\tilde{o}_i = \sum_j e^{s_{ij}} v_j$ while iterating over keys.
Finally you can compute,
$o_i = \frac{\tilde{o}_i}{l_i}$
To make it numerically stable flash attention subtracts the current max of $s_{ij}$ before exponentiating.
So it maintains the following while iterating over keys: the running max $m_i$, the sum of exponents $l_i = \sum_j e^{s_{ij} - m_i}$, and the unnormalized output $\tilde{o}_i$.
For each block of keys $j_1 \dots j_2$ it updates them:
$m_i^{\text{new}} = \max \big( m_i, \max_{j_1 \le j \le j_2} s_{ij} \big)$
$l_i \leftarrow e^{m_i - m_i^{\text{new}}} l_i + \sum_{j=j_1}^{j_2} e^{s_{ij} - m_i^{\text{new}}}$
$\tilde{o}_i \leftarrow e^{m_i - m_i^{\text{new}}} \tilde{o}_i + \sum_{j=j_1}^{j_2} e^{s_{ij} - m_i^{\text{new}}} v_j$
Then finally,
$o_i = \frac{\tilde{o}_i}{l_i}$
For the backward pass, the softmax gradient is $dS_{ij} = \sum_k dP_{ik} P_{ik} (\delta_{jk} - P_{ij})$, where $\delta_{jk}$ is $1$ when $j = k$ and $0$ otherwise.
Flash attention paper introduces $D_i = P_{i:} \cdot dP_{i:} = o_i \cdot do_i$ to simplify computation.
Then,
$dS_{ij} = P_{ij} (dP_{ij} - D_i)$
Note: $P_{i:}$, $dP_{i:}$, $o_i$, $do_i$, etc. are row vectors.
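The same streaming-softmax bookkeeping can be written in a few lines of plain PyTorch. The following is only an illustrative sketch (tensor names, block size, and the missing softmax scale are arbitrary assumptions), not the Triton kernels below:

import torch

def streaming_attention(q, k, v, block=64):
    # q: (n_q, d); k, v: (n_kv, d); computes softmax(q @ k.T) @ v block by block
    n_q, d = q.shape
    m = torch.full((n_q,), float("-inf"))   # running max m_i
    l = torch.zeros(n_q)                     # running sum of exponents l_i
    o = torch.zeros(n_q, d)                  # unnormalized output ~o_i
    for j in range(0, k.shape[0], block):
        s = q @ k[j:j + block].T             # scores for this key block
        m_new = torch.maximum(m, s.max(dim=-1).values)
        p = torch.exp(s - m_new[:, None])
        alpha = torch.exp(m - m_new)         # rescales the previous accumulators
        l = l * alpha + p.sum(dim=-1)
        o = o * alpha[:, None] + p @ v[j:j + block]
        m = m_new
    return o / l[:, None]

q, k, v = torch.randn(3, 128, 32).unbind(0)
assert torch.allclose(streaming_attention(q, k, v), torch.softmax(q @ k.T, dim=-1) @ v, atol=1e-4)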
100from typing import Any, Tuple
-101
-102import torch
-103import triton
-104import triton.language as tl
-105
-106HI_PRES_TL: tl.constexpr = tl.float32
-107HI_PRES_TORCH: torch.dtype = torch.float32
101from typing import Any, Tuple
+102
+103import torch
+104import triton
+105import triton.language as tl
+106
+107HI_PRES_TL: tl.constexpr = tl.float32
+108HI_PRES_TORCH: torch.dtype = torch.float32
110class AttentionFunc(torch.autograd.Function):
111class AttentionFunc(torch.autograd.Function):
111 @staticmethod
-112 def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-113 causal: bool, sm_scale: float) -> torch.Tensor:
112 @staticmethod
+113 def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+114 causal: bool, sm_scale: float) -> torch.Tensor:
125 batch_size, n_heads, q_seq_len, d_head = q.shape
-126 _, k_heads, kv_seq_len, _ = k.shape
-127 assert n_heads % k_heads == 0
-128 n_groups = n_heads // k_heads
126 batch_size, n_heads, q_seq_len, d_head = q.shape
+127 _, k_heads, kv_seq_len, _ = k.shape
+128 assert n_heads % k_heads == 0
+129 n_groups = n_heads // k_heads
131 assert d_head == k.shape[-1] == v.shape[-1]
-132 assert d_head in {16, 32, 64, 128, 256}
132 assert d_head == k.shape[-1] == v.shape[-1]
+133 assert d_head in {16, 32, 64, 128, 256}
135 q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
-136 k = k.view(batch_size * k_heads, kv_seq_len, d_head)
-137 v = v.view(batch_size * k_heads, kv_seq_len, d_head)
136 q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
+137 k = k.view(batch_size * k_heads, kv_seq_len, d_head)
+138 v = v.view(batch_size * k_heads, kv_seq_len, d_head)
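As a side note, the GQA regrouping done by these view calls can be checked with a small self-contained example (the shapes are made up, not the kernel's): query head h of batch b ends up at row b * k_heads + h // n_groups, slot h % n_groups of the grouped view.

import torch

batch_size, n_heads, k_heads, seq_len, d_head = 2, 8, 2, 16, 32
n_groups = n_heads // k_heads
q = torch.randn(batch_size, n_heads, seq_len, d_head)
q_grouped = q.view(batch_size * k_heads, n_groups, seq_len, d_head)
b, h = 1, 5
assert torch.equal(q_grouped[b * k_heads + h // n_groups, h % n_groups], q[b, h])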
140 assert q.is_contiguous()
-141 assert k.is_contiguous()
-142 assert v.is_contiguous()
-143 assert k.stride() == v.stride()
141 assert q.is_contiguous()
+142 assert k.is_contiguous()
+143 assert v.is_contiguous()
+144 assert k.stride() == v.stride()
146 o = torch.empty_like(q)
147 o = torch.empty_like(q)
148 lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)
149 lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)
151 grid = lambda meta: (triton.cdiv(q_seq_len, meta["BLOCK_Q"]), batch_size * k_heads * n_groups, 1)
-152 _attn_fwd[grid](
-153 q, k, v, sm_scale, lse, o,
-154 n_groups=n_groups,
-155 q_seq_len=q_seq_len,
-156 kv_seq_len=kv_seq_len,
-157 d_head=d_head,
-158 is_causal=causal,
-159 )
152 grid = lambda meta: (triton.cdiv(q_seq_len, meta["BLOCK_Q"]), batch_size * k_heads * n_groups, 1)
+153 _attn_fwd[grid](
+154 q, k, v, sm_scale, lse, o,
+155 n_groups=n_groups,
+156 q_seq_len=q_seq_len,
+157 kv_seq_len=kv_seq_len,
+158 d_head=d_head,
+159 is_causal=causal,
+160 )
162 ctx.save_for_backward(q, k, v, o, lse)
-163 ctx.sm_scale = sm_scale
-164 ctx.n_groups = n_groups
-165 ctx.causal = causal
163 ctx.save_for_backward(q, k, v, o, lse)
+164 ctx.sm_scale = sm_scale
+165 ctx.n_groups = n_groups
+166 ctx.causal = causal
168 return o.view(batch_size, n_heads, q_seq_len, d_head)
169 return o.view(batch_size, n_heads, q_seq_len, d_head)
170 @staticmethod
-171 def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
171 @staticmethod
+172 def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
180 n_groups = ctx.n_groups
-181 sm_scale = ctx.sm_scale
-182 causal = ctx.causal
-183 q, k, v, o, lse = ctx.saved_tensors
181 n_groups = ctx.n_groups
+182 sm_scale = ctx.sm_scale
+183 causal = ctx.causal
+184 q, k, v, o, lse = ctx.saved_tensors
186 batch_size, n_heads, q_seq_len, d_head = do.shape
-187 _, kv_seq_len, _ = k.shape
-188 k_heads = n_heads // n_groups
187 batch_size, n_heads, q_seq_len, d_head = do.shape
+188 _, kv_seq_len, _ = k.shape
+189 k_heads = n_heads // n_groups
191 do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
192 do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
194 assert do.is_contiguous()
-195 assert k.stride() == v.stride()
-196 assert q.stride() == o.stride() == do.stride()
195 assert do.is_contiguous()
+196 assert k.stride() == v.stride()
+197 assert q.stride() == o.stride() == do.stride()
199 dq = torch.empty_like(q)
-200 dk = torch.empty_like(k)
-201 dv = torch.empty_like(v)
200 dq = torch.empty_like(q)
+201 dk = torch.empty_like(k)
+202 dv = torch.empty_like(v)
204 RCP_LN2 = 1.4426950408889634
205 RCP_LN2 = 1.4426950408889634
206 k_scaled = k * (sm_scale * RCP_LN2)
207 k_scaled = k * (sm_scale * RCP_LN2)
208 pdp = torch.empty_like(lse)
209 pdp = torch.empty_like(lse)
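The RCP_LN2 factor above is the usual exp-to-exp2 rewrite: since $e^x = 2^{x / \ln 2}$, pre-multiplying the keys by sm_scale / ln 2 lets the kernels use the cheaper base-2 exponential throughout. A quick illustrative check in plain PyTorch:

import math
import torch

x = torch.randn(8)
assert torch.allclose(torch.exp(x), torch.exp2(x * (1.0 / math.log(2.0))), atol=1e-6)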
We use a fixed BLOCK_Q for the backward pass computation of $D_i$.
210 BLOCK_Q = 16
Compute $D_i = o_i \cdot do_i$.
This is parallelized along the batch and queries in blocks of size BLOCK_Q
214 pre_grid = (triton.cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads)
-215 _attn_bwd_d[pre_grid](
-216 o, do,
-217 pdp,
-218 BLOCK_Q=16,
-219 d_head=d_head,
-220 q_seq_len=q_seq_len,
-221 n_groups=n_groups,
-222 num_stages=1,
-223 )
215 BLOCK_Q = 16
+216 pre_grid = (triton.cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads)
+217 _attn_bwd_d[pre_grid](
+218 o, do,
+219 pdp,
+220 BLOCK_Q=16,
+221 d_head=d_head,
+222 q_seq_len=q_seq_len,
+223 n_groups=n_groups,
+224 num_stages=1,
+225 )
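For reference, what this kernel produces can be written in one line of PyTorch. This is only an illustrative sketch with made-up shapes, not the kernel itself: $D_i$ is the row-wise dot product of the output and its gradient.

import torch

o = torch.randn(2, 4, 128, 64)                   # (batch * k_heads, n_groups, q_seq_len, d_head)
do = torch.randn_like(o)
pdp_ref = (o.float() * do.float()).sum(dim=-1)   # D_i = o_i . do_i, shape (2, 4, 128)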
Compute $dk$ and $dv$.
This is parallelized along the batch and keys in blocks of size BLOCK_K
227 grid = lambda meta: (triton.cdiv(kv_seq_len, meta['BLOCK_K']), batch_size * k_heads)
-228 _attn_bwd_dkdv[grid](
-229 q, k_scaled, v, sm_scale, do, dk, dv,
-230 lse, pdp,
-231 q_seq_len, kv_seq_len, n_groups, d_head,
-232 is_causal=causal,
-233
-234 )
230 grid = lambda meta: (triton.cdiv(kv_seq_len, meta['BLOCK_K']), batch_size * k_heads)
+231 _attn_bwd_dkdv[grid](
+232 q, k_scaled, v, sm_scale, do, dk, dv,
+233 lse, pdp,
+234 q_seq_len, kv_seq_len, n_groups, d_head,
+235 is_causal=causal,
+236
+237 )
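The backward identities that both this kernel and the $dq$ kernel rely on can be checked against autograd with a tiny dense example (purely illustrative shapes; S is the score matrix and P = softmax(S)):

import torch

S = torch.randn(8, 8)
v = torch.randn(8, 16)
do = torch.randn(8, 16)
P = torch.softmax(S, dim=-1)
dP = do @ v.T                                   # dP_ij = do_i . v_j
D = (P * dP).sum(dim=-1, keepdim=True)          # D_i = P_i: . dP_i: = o_i . do_i
dS = P * (dP - D)                               # dS_ij = P_ij (dP_ij - D_i)
S_ref = S.clone().requires_grad_(True)
(torch.softmax(S_ref, dim=-1) @ v * do).sum().backward()
assert torch.allclose(dS, S_ref.grad, atol=1e-5)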
Compute $dq$.
This is parallelized along the batch and queries in blocks of size BLOCK_Q
238 grid = lambda meta: (triton.cdiv(q_seq_len, meta['BLOCK_Q']), batch_size * k_heads * n_groups)
-239 _attn_bwd_dq[grid](
-240 q, k_scaled, v, do,
-241 dq,
-242 lse, pdp,
-243 q_seq_len, kv_seq_len, n_groups, d_head,
-244 is_causal=causal,
-245 )
242 grid = lambda meta: (triton.cdiv(q_seq_len, meta['BLOCK_Q']), batch_size * k_heads * n_groups)
+243 _attn_bwd_dq[grid](
+244 q, k_scaled, v, do,
+245 dq,
+246 lse, pdp,
+247 q_seq_len, kv_seq_len, n_groups, d_head,
+248 is_causal=causal,
+249 )
248 dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
-249 dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
-250 dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)
252 dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
+253 dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
+254 dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)
253 return dq, dk, dv, None, None
-254
-255
-256attention = AttentionFunc.apply
257 return dq, dk, dv, None, None
+258
+259
+260attention = AttentionFunc.apply
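A usage sketch for the function defined above (device, shapes, and scale are illustrative; a CUDA device is assumed since the kernels are Triton):

import torch

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(1, 2, 1024, 64, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(1, 2, 1024, 64, device="cuda", dtype=torch.float16, requires_grad=True)

out = attention(q, k, v, True, 1.0 / 64 ** 0.5)   # causal=True, sm_scale=1/sqrt(d_head)
out.sum().backward()                              # gradients flow through the Triton backward kernels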
259def _get_autotune_configs(inner_loop: str) -> list:
263def _get_autotune_configs(inner_loop: str) -> list:
264 configs = []
268 configs = []
267 for bm in [64, 128, 256]:
271 for bm in [64, 128, 256]:
269 for bn in [64, 128, 256]:
-270 if inner_loop == 'key' and bm % bn != 0:
-271 continue
-272 if inner_loop == 'query' and bn % bm != 0:
-273 continue
-274 for s in [2, 3, 4]:
-275 for w in [4, 8]:
-276 if bm * bn < 128 * 128 and w == 8:
-277 continue
-278
-279 configs.append(triton.Config({'BLOCK_Q': bm, 'BLOCK_K': bn}, num_stages=s, num_warps=w))
-280
-281 return configs[:1]
273 for bn in [64, 128, 256]:
+274 if inner_loop == 'key' and bm % bn != 0:
+275 continue
+276 if inner_loop == 'query' and bn % bm != 0:
+277 continue
+278 for s in [2, 3, 4]:
+279 for w in [4, 8]:
+280 if bm * bn < 128 * 128 and w == 8:
+281 continue
+282
+283 configs.append(triton.Config({'BLOCK_Q': bm, 'BLOCK_K': bn}, num_stages=s, num_warps=w))
+284
+285 return configs[:1]
sm_scale — softmax scale
t_lse — $\log_2$ sum of exponents of the attention scores (out)
t_o — output (out)
n_groups — number of GQA groups
@@ -609,18 +610,18 @@
284@triton.autotune(_get_autotune_configs(inner_loop='key'),
-285 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
-286@triton.jit
-287def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
-288 n_groups: tl.constexpr,
-289 q_seq_len: tl.constexpr,
-290 kv_seq_len: tl.constexpr,
-291 d_head: tl.constexpr,
-292 is_causal: tl.constexpr,
-293 BLOCK_Q: tl.constexpr, # q seq len block
-294 BLOCK_K: tl.constexpr, # k seq len block
-295 ):
288@triton.autotune(_get_autotune_configs(inner_loop='key'),
+289 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
+290@triton.jit
+291def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
+292 n_groups: tl.constexpr,
+293 q_seq_len: tl.constexpr,
+294 kv_seq_len: tl.constexpr,
+295 d_head: tl.constexpr,
+296 is_causal: tl.constexpr,
+297 BLOCK_Q: tl.constexpr, # q seq len block
+298 BLOCK_K: tl.constexpr, # k seq len block
+299 ):
316 i = tl.program_id(0)
-317 z = tl.program_id(1) // n_groups
-318 g = tl.program_id(1) % n_groups
320 i = tl.program_id(0)
+321 z = tl.program_id(1) // n_groups
+322 g = tl.program_id(1) % n_groups
321 p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-322 (q_seq_len, d_head),
-323 (d_head, 1),
-324 (i * BLOCK_Q, 0),
-325 (BLOCK_Q, d_head),
-326 (1, 0))
-327 p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
-328 (kv_seq_len, d_head),
-329 (d_head, 1),
-330 (0, 0),
-331 (BLOCK_K, d_head),
-332 (1, 0))
-333 p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
-334 (d_head, kv_seq_len),
-335 (1, d_head),
-336 (0, 0),
-337 (d_head, BLOCK_K),
-338 (0, 1))
-339 p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-340 (q_seq_len, d_head),
-341 (d_head, 1),
-342 (i * BLOCK_Q, 0),
-343 (BLOCK_Q, d_head),
-344 (1, 0))
-345 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
-346 (q_seq_len,),
-347 (1,),
-348 (i * BLOCK_Q,),
-349 (BLOCK_Q,),
-350 (0,))
325 p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+326 (q_seq_len, d_head),
+327 (d_head, 1),
+328 (i * BLOCK_Q, 0),
+329 (BLOCK_Q, d_head),
+330 (1, 0))
+331 p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
+332 (kv_seq_len, d_head),
+333 (d_head, 1),
+334 (0, 0),
+335 (BLOCK_K, d_head),
+336 (1, 0))
+337 p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
+338 (d_head, kv_seq_len),
+339 (1, d_head),
+340 (0, 0),
+341 (d_head, BLOCK_K),
+342 (0, 1))
+343 p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+344 (q_seq_len, d_head),
+345 (d_head, 1),
+346 (i * BLOCK_Q, 0),
+347 (BLOCK_Q, d_head),
+348 (1, 0))
+349 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
+350 (q_seq_len,),
+351 (1,),
+352 (i * BLOCK_Q,),
+353 (BLOCK_Q,),
+354 (0,))
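The base offsets passed to tl.make_block_ptr above assume q is the contiguous (batch * k_heads, n_groups, q_seq_len, d_head) tensor created in forward(). As a plain-Python illustration of that flattening (sizes made up):

import torch

n_groups, q_seq_len, d_head = 3, 4, 5
q = torch.arange(2 * n_groups * q_seq_len * d_head).reshape(2, n_groups, q_seq_len, d_head)
z, g, i, d = 1, 2, 3, 4
flat = z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head + i * d_head + d
assert q.flatten()[flat] == q[z, g, i, d]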
353 offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
-354 i_mask = offs_i < q_seq_len
-355 offs_j = tl.arange(0, BLOCK_K)
357 offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
+358 offs_j = tl.arange(0, BLOCK_K)
358 b_m = tl.where(i_mask, -float("inf"), 0.0)
-359 b_l = tl.where(i_mask, 1.0, 0.0)
360 i_mask = offs_i < q_seq_len
Accumulate
+Precalculate $\mathrm{sm\_scale} \cdot \log_2 e$. We will be using this when calculating $S$, so $S$ will store $\log_2 e \cdot \mathrm{sm\_scale} \cdot (q k^\top)$ instead of the raw scores.
361 b_acc = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
365 sm_scale = sm_scale * 1.44269504
softmax scale / log(2)
+Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\infty$ and $l_i$ to $1$, so in the first update the contribution of the initial $l_i$ is $2^{m_i - m_i^{\text{new}}} \cdot 1 = 0$. b_m will be storing the running maximum of the $\log_2$-scaled scores.
364 sm_scale = sm_scale * 1.44269504
371 b_m = tl.where(i_mask, -float("inf"), 0.0)
+372 b_l = tl.where(i_mask, 1.0, 0.0)
366 b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
-367
-368 if is_causal:
375 b_o = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
Up to the diagonal block
+Load $q_i$ outside the loop since it will be reused throughout the loop over $k_j$.
370 b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q,
-371 p_kT, p_v,
-372 sm_scale,
-373 BLOCK_Q, d_head, BLOCK_K,
-374 offs_i, offs_j,
-375 j=tl.full([], 0, tl.int32), # type: ignore
-376 steps=(i * BLOCK_Q) // BLOCK_K,
-377 MASK=False,
-378 q_seq_len=q_seq_len,
-379 kv_seq_len=kv_seq_len
-380 )
378 b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
+379
+380 if is_causal:
Diagonal block with masking within it
+Inner loop up to the diagonal block
382 b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
-383 sm_scale,
-384 BLOCK_Q, d_head, BLOCK_K,
-385 offs_i, offs_j,
-386 j=i * BLOCK_Q,
-387 steps=BLOCK_Q // BLOCK_K,
-388 MASK=True,
-389 q_seq_len=q_seq_len,
-390 kv_seq_len=kv_seq_len
-391 )
-392 else:
-393 b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
-394 sm_scale,
-395 BLOCK_Q, d_head, BLOCK_K,
-396 offs_i, offs_j,
-397 j=tl.full([], 0, tl.int32), # type: ignore
-398 steps=tl.cdiv(kv_seq_len, BLOCK_K),
-399 MASK=False,
-400 q_seq_len=q_seq_len,
-401 kv_seq_len=kv_seq_len
-402 )
382 b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q,
+383 p_kT, p_v,
+384 sm_scale,
+385 BLOCK_Q, d_head, BLOCK_K,
+386 offs_i, offs_j,
+387 j=tl.full([], 0, tl.int32), # type: ignore
+388 steps=(i * BLOCK_Q) // BLOCK_K,
+389 MASK=False,
+390 q_seq_len=q_seq_len,
+391 kv_seq_len=kv_seq_len
+392 )
405 tl.store(p_lse, b_m + tl.math.log2(b_l), boundary_check=(0,))
-406 tl.store(p_o, (b_acc / b_l[:, None]).to(t_o.type.element_ty), boundary_check=(0,))
394 b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
+395 sm_scale,
+396 BLOCK_Q, d_head, BLOCK_K,
+397 offs_i, offs_j,
+398 j=i * BLOCK_Q,
+399 steps=BLOCK_Q // BLOCK_K,
+400 MASK=True,
+401 q_seq_len=q_seq_len,
+402 kv_seq_len=kv_seq_len
+403 )
+404 else:
Iterate through all keys
+409@triton.jit
-410def _attn_fwd_inner(b_acc, b_l, b_m, b_q,
-411 p_kT, p_v,
-412 scale,
-413 BLOCK_Q: tl.constexpr,
-414 d_head: tl.constexpr,
-415 BLOCK_K: tl.constexpr,
-416 offs_i, offs_j,
-417 j,
-418 steps,
-419 MASK: tl.constexpr,
-420 q_seq_len: tl.constexpr,
-421 kv_seq_len: tl.constexpr
-422 ):
-423 tl.static_assert(BLOCK_Q % BLOCK_K == 0)
-424
-425 p_kT = tl.advance(p_kT, (0, j))
-426 p_v = tl.advance(p_v, (j, 0))
406 b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
+407 sm_scale,
+408 BLOCK_Q, d_head, BLOCK_K,
+409 offs_i, offs_j,
+410 j=tl.full([], 0, tl.int32), # type: ignore
+411 steps=tl.cdiv(kv_seq_len, BLOCK_K),
+412 MASK=False,
+413 q_seq_len=q_seq_len,
+414 kv_seq_len=kv_seq_len
+415 )
429 for _ in range(steps):
-430 current_j = j + offs_j
-431 j_mask = current_j < kv_seq_len
-432
-433 b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
-434 b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
-435
-436 tl.static_assert(b_s.dtype == HI_PRES_TL)
-437 b_s = b_s * scale
-438 if MASK:
-439 causal_mask = offs_i[:, None] >= (j + offs_j[None, :])
-440 b_s = tl.where(causal_mask, b_s, -float("inf"))
418 tl.store(p_lse, b_m + tl.math.log2(b_l), boundary_check=(0,))
442 b_s = tl.where(j_mask[None, :], b_s, -float("inf"))
420 tl.store(p_o, (b_o / b_l[:, None]).to(t_o.type.element_ty), boundary_check=(0,))
445 tl.static_assert(len(b_s.shape) == 2)
-446 b_m_new = tl.maximum(b_m, tl.max(b_s, -1))
423@triton.jit
+424def _attn_fwd_inner(b_o, b_l, b_m, b_q,
+425 p_kT, p_v,
+426 scale,
+427 BLOCK_Q: tl.constexpr,
+428 d_head: tl.constexpr,
+429 BLOCK_K: tl.constexpr,
+430 offs_i, offs_j,
+431 j,
+432 steps,
+433 MASK: tl.constexpr,
+434 q_seq_len: tl.constexpr,
+435 kv_seq_len: tl.constexpr
+436 ):
+437 tl.static_assert(BLOCK_Q % BLOCK_K == 0)
448 b_p = tl.math.exp2(b_s - b_m_new[:, None])
440 p_kT = tl.advance(p_kT, (0, j))
+441 p_v = tl.advance(p_v, (j, 0))
450 b_l_new = tl.sum(b_p, -1)
444 for _ in range(steps):
453 b_m_m_new = tl.math.exp2(b_m - b_m_new)
446 b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
455 b_l = b_l * b_m_m_new + b_l_new
448 b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
+449 b_s = b_s * scale
458 b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
-459 b_acc = b_acc * b_m_m_new[:, None]
-460 b_p = b_p.to(b_q.dtype)
-461 b_acc += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)
452 if MASK:
+453 causal_mask = offs_i[:, None] >= (j + offs_j[None, :])
+454 b_s = tl.where(causal_mask, b_s, -float("inf"))
464 b_m = b_m_new
457 j_mask = (j + offs_j) < kv_seq_len
+458 b_s = tl.where(j_mask[None, :], b_s, -float("inf"))
467 j += BLOCK_K
-468 p_v = tl.advance(p_v, (BLOCK_K, 0))
-469 p_kT = tl.advance(p_kT, (0, BLOCK_K))
-470
-471 tl.static_assert(b_acc.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
-472
-473 return b_acc, b_l, b_m
461 b_m_new = tl.maximum(b_m, tl.max(b_s, -1))
+
476@triton.jit
-477def _attn_bwd_d(t_o, t_do,
-478 t_pdp,
-479 BLOCK_Q: tl.constexpr, d_head: tl.constexpr,
-480 q_seq_len: tl.constexpr,
-481 n_groups: tl.constexpr,
-482 ):
-483 i = tl.program_id(0) * BLOCK_Q
-484 z = tl.program_id(1)
463 b_p = tl.math.exp2(b_s - b_m_new[:, None])
487 p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
-488 (n_groups, q_seq_len, d_head),
-489 (q_seq_len * d_head, d_head, 1),
-490 (0, i, 0),
-491 (n_groups, BLOCK_Q, d_head),
-492 (2, 1, 0))
-493 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
-494 (n_groups, q_seq_len, d_head),
-495 (q_seq_len * d_head, d_head, 1),
-496 (0, i, 0),
-497 (n_groups, BLOCK_Q, d_head),
-498 (2, 1, 0))
-499 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
-500 (n_groups, q_seq_len),
-501 (q_seq_len, 1),
-502 (0, i),
-503 (n_groups, BLOCK_Q),
-504 (1, 0))
-505
-506 o = tl.load(p_o, boundary_check=(1,), padding_option="zero")
-507 do = tl.load(p_do, boundary_check=(1,), padding_option="zero").to(HI_PRES_TL)
-508 d = tl.sum(o * do, axis=-1)
-509 tl.store(p_pdp, d, boundary_check=(1,))
466 b_l_new = tl.sum(b_p, -1)
512@triton.autotune(_get_autotune_configs(inner_loop='query'),
-513 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
-514@triton.jit
-515def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
-516 t_do,
-517 t_dk, t_dv,
-518 t_lse, t_pdp,
-519 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
-520 n_groups: tl.constexpr, d_head: tl.constexpr,
-521 is_causal: tl.constexpr,
-522 BLOCK_Q: tl.constexpr,
-523 BLOCK_K: tl.constexpr,
-524 ):
468 b_m_m_new = tl.math.exp2(b_m - b_m_new)
529 j = tl.program_id(0) * BLOCK_K
-530 z = tl.program_id(1)
-531
-532 p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
-533 (kv_seq_len, d_head),
-534 (d_head, 1),
-535 (j, 0),
-536 (BLOCK_K, d_head),
-537 (1, 0))
-538 p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
-539 (kv_seq_len, d_head),
-540 (d_head, 1),
-541 (j, 0),
-542 (BLOCK_K, d_head),
-543 (1, 0))
-544 p_dk = tl.make_block_ptr(t_dk + z * kv_seq_len * d_head,
-545 (kv_seq_len, d_head),
-546 (d_head, 1),
-547 (j, 0),
-548 (BLOCK_K, d_head),
-549 (1, 0))
-550 p_dv = tl.make_block_ptr(t_dv + z * kv_seq_len * d_head,
-551 (kv_seq_len, d_head),
-552 (d_head, 1),
-553 (j, 0),
-554 (BLOCK_K, d_head),
-555 (1, 0))
-556
-557 b_dv = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
-558 b_dk = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
470 b_l = b_l * b_m_m_new + b_l_new
561 b_k = tl.load(p_k, boundary_check=(0,), padding_option="zero")
-562 b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
473 b_o = b_o * b_m_m_new[:, None]
+474 b_p = b_p.to(b_q.dtype) # TODO
+475 b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
+476 b_o += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)
565 for g in range(n_groups):
479 b_m = b_m_new
567 p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-568 (d_head, q_seq_len),
-569 (1, d_head),
-570 (0, 0),
-571 (d_head, BLOCK_Q),
-572 (0, 1))
-573
-574 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-575 (q_seq_len, d_head),
-576 (d_head, 1),
-577 (0, 0),
-578 (BLOCK_Q, d_head),
-579 (1, 0))
-580 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
-581 (q_seq_len,),
-582 (1,),
-583 (0,),
-584 (BLOCK_Q,),
-585 (0,))
-586 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
-587 (q_seq_len,),
-588 (1,),
-589 (0,),
-590 (BLOCK_Q,),
-591 (0,))
482 j += BLOCK_K
+483 p_v = tl.advance(p_v, (BLOCK_K, 0))
+484 p_kT = tl.advance(p_kT, (0, BLOCK_K))
+485
+486 tl.static_assert(b_o.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
+487
+488 return b_o, b_l, b_m
491@triton.jit
+492def _attn_bwd_d(t_o, t_do,
+493 t_pdp,
+494 BLOCK_Q: tl.constexpr, d_head: tl.constexpr,
+495 q_seq_len: tl.constexpr,
+496 n_groups: tl.constexpr,
+497 ):
+498 i = tl.program_id(0) * BLOCK_Q
+499 z = tl.program_id(1)
Compute $dk$ and $dv$ along the masked blocks near the diagonal. A smaller block size (MASK_BLOCK_Q) could be used here, since the masking adds a little extra computation.
+Create block pointers
599 if is_causal:
502 p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
+503 (n_groups, q_seq_len, d_head),
+504 (q_seq_len * d_head, d_head, 1),
+505 (0, i, 0),
+506 (n_groups, BLOCK_Q, d_head),
+507 (2, 1, 0))
+508 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
+509 (n_groups, q_seq_len, d_head),
+510 (q_seq_len * d_head, d_head, 1),
+511 (0, i, 0),
+512 (n_groups, BLOCK_Q, d_head),
+513 (2, 1, 0))
+514 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
+515 (n_groups, q_seq_len),
+516 (q_seq_len, 1),
+517 (0, i),
+518 (n_groups, BLOCK_Q),
+519 (1, 0))
601 b_dk, b_dv = _attn_bwd_dkdv_inner(
-602 b_dk, b_dv,
-603 p_qT, b_k, b_v, p_do,
-604 p_lse, p_pdp,
522 o = tl.load(p_o, boundary_check=(1,), padding_option="zero")
You can use a smaller BLOCK_Q if BLOCK_K is not divisible by BLOCK_Q
+Load
606 BLOCK_Q, BLOCK_K,
-607 d_head,
-608 j=j, i=j,
-609 steps=BLOCK_K // BLOCK_Q,
-610 MASK=True,
-611 q_seq_len=q_seq_len,
-612 kv_seq_len=kv_seq_len,
-613 )
524 do = tl.load(p_do, boundary_check=(1,), padding_option="zero").to(HI_PRES_TL)
526 d = tl.sum(o * do, axis=-1)
Save
+528 tl.store(p_pdp, d, boundary_check=(1,))
Compute $dk_j$ and $dv_j$ for a block of keys by iterating over the query blocks
+531@triton.autotune(_get_autotune_configs(inner_loop='query'),
+532 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
+533@triton.jit
+534def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
+535 t_do,
+536 t_dk, t_dv,
+537 t_lse, t_pdp,
+538 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
+539 n_groups: tl.constexpr, d_head: tl.constexpr,
+540 is_causal: tl.constexpr,
+541 BLOCK_Q: tl.constexpr,
+542 BLOCK_K: tl.constexpr,
+543 ):
548 j = tl.program_id(0) * BLOCK_K
+549 z = tl.program_id(1)
Create block pointers
+552 p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
+553 (kv_seq_len, d_head),
+554 (d_head, 1),
+555 (j, 0),
+556 (BLOCK_K, d_head),
+557 (1, 0))
+558 p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
+559 (kv_seq_len, d_head),
+560 (d_head, 1),
+561 (j, 0),
+562 (BLOCK_K, d_head),
+563 (1, 0))
+564 p_dk = tl.make_block_ptr(t_dk + z * kv_seq_len * d_head,
+565 (kv_seq_len, d_head),
+566 (d_head, 1),
+567 (j, 0),
+568 (BLOCK_K, d_head),
+569 (1, 0))
+570 p_dv = tl.make_block_ptr(t_dv + z * kv_seq_len * d_head,
+571 (kv_seq_len, d_head),
+572 (d_head, 1),
+573 (j, 0),
+574 (BLOCK_K, d_head),
+575 (1, 0))
Initialize $dk_j$ and $dv_j$
+578 b_dk = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
+579 b_dv = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
Load $k_j$ and $v_j$ outside the loop since they are reused across the query blocks.
+582 b_k = tl.load(p_k, boundary_check=(0,), padding_option="zero")
+583 b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
Iterate through the query heads in the GQA group
+586 for g in range(n_groups):
Create block pointers
+588 p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+589 (d_head, q_seq_len),
+590 (1, d_head),
+591 (0, 0),
+592 (d_head, BLOCK_Q),
+593 (0, 1))
+594
+595 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+596 (q_seq_len, d_head),
+597 (d_head, 1),
+598 (0, 0),
+599 (BLOCK_Q, d_head),
+600 (1, 0))
+601 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
+602 (q_seq_len,),
+603 (1,),
+604 (0,),
+605 (BLOCK_Q,),
+606 (0,))
+607 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
+608 (q_seq_len,),
+609 (1,),
+610 (0,),
+611 (BLOCK_Q,),
+612 (0,))
+613
+614 if is_causal:
Inner loop at the diagonal block
Save
-643 tl.store(p_dv, b_dv.to(t_dv.type.element_ty), boundary_check=(0,))
Since we used $\hat{k} = \frac{\mathrm{sm\_scale}}{\ln 2} k$, where $k$ are the original keys, we multiply by the scale again to get the gradient on the original keys.
-647 b_dk *= sm_scale
Save
-650 tl.store(p_dk, b_dk.to(t_dk.type.element_ty), boundary_check=(0,))
Inner loop over the query blocks
-653@triton.jit
-654def _attn_bwd_dkdv_inner(b_dk, b_dv,
-655 p_qT, b_k, b_v, p_do,
-656 p_lse, p_pdp,
-657 BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
-658 d_head: tl.constexpr,
-659 j, i, steps,
-660 MASK: tl.constexpr,
-661 q_seq_len: tl.constexpr,
-662 kv_seq_len: tl.constexpr):
To apply the mask
-666 tl.static_assert(BLOCK_K % BLOCK_Q == 0)
Offsets for mask computation
-669 offs_i = i + tl.arange(0, BLOCK_Q)
-670 i_mask = offs_i < q_seq_len
-671 offs_j = j + tl.arange(0, BLOCK_K)
Pointers
-674 p_qT = tl.advance(p_qT, (0, i))
-675 p_do = tl.advance(p_do, (i, 0))
-676 p_lse = tl.advance(p_lse, (i,))
-677 p_pdp = tl.advance(p_pdp, (i,))
Loop
-680 for _ in range(steps):
Load
-682 b_qT = tl.load(p_qT, boundary_check=(1,), padding_option="zero")
685 b_m = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
630 b_dk, b_dv = _attn_bwd_dkdv_inner(
+631 b_dk, b_dv,
+632 p_qT, b_k, b_v, p_do,
+633 p_lse, p_pdp,
+634 BLOCK_Q, BLOCK_K,
+635 d_head,
+636 j=j, i=j + BLOCK_K,
+637 steps=tl.cdiv((q_seq_len - (j + BLOCK_K)), BLOCK_Q),
+638 MASK=False,
+639 q_seq_len=q_seq_len,
+640 kv_seq_len=kv_seq_len
+641 )
+642 else:
Note that $\hat{k}$ is already multiplied by the softmax scale. It is also divided by $\ln 2$ so we can use $2^{(\cdot)}$ instead of $e^{(\cdot)}$.
+Iterate through all queries
690 b_qkT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)
-691 b_pT = tl.math.exp2(b_qkT - b_m[None, :])
644 b_dk, b_dv = _attn_bwd_dkdv_inner(
+645 b_dk, b_dv,
+646 p_qT, b_k, b_v, p_do,
+647 p_lse, p_pdp,
+648 BLOCK_Q, BLOCK_K,
+649 d_head,
+650 j=j, i=tl.full([], 0, tl.int32),
+651 steps=tl.cdiv(q_seq_len, BLOCK_Q),
+652 MASK=False,
+653 q_seq_len=q_seq_len,
+654 kv_seq_len=kv_seq_len
+655 )
694 if MASK:
-695 mask = (offs_i[None, :] >= offs_j[:, None])
-696 b_pT = tl.where(mask, b_pT, 0.0)
-697
-698 b_pT = tl.where(i_mask[None, :], b_pT, 0.0)
658 tl.store(p_dv, b_dv.to(t_dv.type.element_ty), boundary_check=(0,))
701 b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
-702 b_dv += tl.dot(b_pT.to(b_do.dtype), b_do, out_dtype=HI_PRES_TL)
661 b_dk *= sm_scale
705 b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
664 tl.store(p_dk, b_dk.to(t_dk.type.element_ty), boundary_check=(0,))
707 b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)
667@triton.jit
+668def _attn_bwd_dkdv_inner(b_dk, b_dv,
+669 p_qT, b_k, b_v, p_do,
+670 p_lse, p_pdp,
+671 BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
+672 d_head: tl.constexpr,
+673 j, i, steps,
+674 MASK: tl.constexpr,
+675 q_seq_len: tl.constexpr,
+676 kv_seq_len: tl.constexpr):
709 b_dsT = b_pT * (b_dpT - b_pdp[None, :])
680 tl.static_assert(BLOCK_K % BLOCK_Q == 0)
711 b_dk += tl.dot(b_dsT.to(b_qT.dtype), tl.trans(b_qT), out_dtype=HI_PRES_TL)
683 offs_i = i + tl.arange(0, BLOCK_Q)
+684 i_mask = offs_i < q_seq_len
+685 offs_j = j + tl.arange(0, BLOCK_K)
714 offs_i += BLOCK_Q
-715 p_lse = tl.advance(p_lse, (BLOCK_Q,))
-716 p_pdp = tl.advance(p_pdp, (BLOCK_Q,))
-717 p_qT = tl.advance(p_qT, (0, BLOCK_Q))
-718 p_do = tl.advance(p_do, (BLOCK_Q, 0))
688 p_qT = tl.advance(p_qT, (0, i))
+689 p_do = tl.advance(p_do, (i, 0))
+690 p_lse = tl.advance(p_lse, (i,))
+691 p_pdp = tl.advance(p_pdp, (i,))
721 return b_dk, b_dv
694 for _ in range(steps):
Load
+724@triton.autotune(_get_autotune_configs(inner_loop='key'),
-725 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
-726@triton.jit
-727def _attn_bwd_dq(t_q, t_k, t_v, t_do,
-728 t_dq,
-729 t_lse, t_pdp,
-730 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
-731 n_groups: tl.constexpr, d_head: tl.constexpr,
-732 is_causal: tl.constexpr,
-733 BLOCK_Q: tl.constexpr,
-734 BLOCK_K: tl.constexpr,
-735 ):
696 b_qT = tl.load(p_qT, boundary_check=(1,), padding_option="zero")
737 LN2: tl.constexpr = 0.6931471824645996 # type: ignore
-738
-739 i = tl.program_id(0) * BLOCK_Q
-740 z = tl.program_id(1) // n_groups
-741 g = tl.program_id(1) % n_groups
699 b_l = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
744 p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-745 (q_seq_len, d_head),
-746 (d_head, 1),
-747 (i, 0),
-748 (BLOCK_Q, d_head),
-749 (1, 0))
-750 p_dq = tl.make_block_ptr(t_dq + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-751 (q_seq_len, d_head),
-752 (d_head, 1),
-753 (i, 0),
-754 (BLOCK_Q, d_head),
-755 (1, 0))
-756 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-757 (q_seq_len, d_head),
-758 (d_head, 1),
-759 (i, 0),
-760 (BLOCK_Q, d_head),
-761 (1, 0))
-762 p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
-763 (d_head, kv_seq_len),
-764 (1, d_head),
-765 (0, 0),
-766 (d_head, BLOCK_K),
-767 (0, 1))
-768 p_vT = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
-769 (d_head, kv_seq_len),
-770 (1, d_head),
-771 (0, 0),
-772 (d_head, BLOCK_K),
-773 (0, 1))
-774 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
-775 (q_seq_len,),
-776 (1,),
-777 (i,),
-778 (BLOCK_Q,),
-779 (0,))
-780 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
-781 (q_seq_len,),
-782 (1,),
-783 (i,),
-784 (BLOCK_Q,),
-785 (0,))
-786
-787 b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
-788 b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
-789 b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
-790
-791 b_dq = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
-792
-793 b_lse = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
702 b_sT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)
797 if is_causal:
711 b_pT = tl.math.exp2(b_sT - b_l[None, :])
799 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
-800 b_do, b_lse, b_pdp,
-801 BLOCK_Q, BLOCK_K,
-802 i=i, j=i,
-803 steps=BLOCK_Q // BLOCK_K,
-804 MASK=True,
-805 q_seq_len=q_seq_len,
-806 kv_seq_len=kv_seq_len
-807 )
714 if MASK:
+715 mask = (offs_i[None, :] >= offs_j[:, None])
+716 b_pT = tl.where(mask, b_pT, 0.0)
810 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
-811 b_do, b_lse, b_pdp,
-812 BLOCK_Q, BLOCK_K,
-813 i=i, j=tl.full([], 0, tl.int32), # type: ignore
-814 steps=i // BLOCK_K,
-815 MASK=False,
-816 q_seq_len=q_seq_len,
-817 kv_seq_len=kv_seq_len
-818 )
-819 else:
-820 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
-821 b_do, b_lse, b_pdp,
-822 BLOCK_Q, BLOCK_K,
-823 i=i, j=tl.full([], 0, tl.int32), # type: ignore
-824 steps=tl.cdiv(kv_seq_len, BLOCK_K),
-825 MASK=False,
-826 q_seq_len=q_seq_len,
-827 kv_seq_len=kv_seq_len
-828 )
719 b_pT = tl.where(i_mask[None, :], b_pT, 0.0)
Since $\hat{k}$ was scaled by $\frac{1}{\ln 2}$, the computed $dq$ picked up this factor too, so we need to reverse it.
+
832 b_dq *= LN2
722 b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
+723 b_dv += tl.dot(b_pT.to(b_do.dtype), b_do, out_dtype=HI_PRES_TL)
835 tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))
726 b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
838@triton.jit
-839def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
-840 b_do, b_lse, b_pdp,
-841 BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
-842 i, j, steps,
-843 MASK: tl.constexpr,
-844 q_seq_len: tl.constexpr,
-845 kv_seq_len: tl.constexpr):
728 b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)
+
847 offs_i = i + tl.arange(0, BLOCK_Q)
-848 offs_j = tl.arange(0, BLOCK_K)
-849
-850 p_kT = tl.advance(p_kT, (0, j))
-851 p_vT = tl.advance(p_vT, (0, j))
-852
-853 tl.static_assert(BLOCK_Q % BLOCK_K == 0, 'BLOCK_Q must be divisible by BLOCK_K')
-854
-855 for _ in range(steps):
-856 current_j = j + offs_j
-857 j_mask = current_j < kv_seq_len
730 b_dsT = b_pT * (b_dpT - b_pdp[None, :])
Note that $\hat{k}$ is already multiplied by the softmax scale. It is also divided by $\ln 2$ so we can use $2^{(\cdot)}$ instead of $e^{(\cdot)}$.
+
862 b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
-863 b_vT = tl.load(p_vT, boundary_check=(1,), padding_option="zero")
-864 b_qk = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
-865 b_p = tl.math.exp2(b_qk - b_lse[:, None])
732 b_dk += tl.dot(b_dsT.to(b_qT.dtype), tl.trans(b_qT), out_dtype=HI_PRES_TL)
868 if MASK:
-869 causal_mask = (offs_i[:, None] >= current_j[None, :])
-870 b_p = tl.where(causal_mask, b_p, 0.0)
-871
-872 b_p = tl.where(j_mask[None, :], b_p, 0.0)
735 offs_i += BLOCK_Q
+736 p_lse = tl.advance(p_lse, (BLOCK_Q,))
+737 p_pdp = tl.advance(p_pdp, (BLOCK_Q,))
+738 p_qT = tl.advance(p_qT, (0, BLOCK_Q))
+739 p_do = tl.advance(p_do, (BLOCK_Q, 0))
742 return b_dk, b_dv
877 b_dp = tl.dot(b_do, b_vT, out_dtype=HI_PRES_TL).to(HI_PRES_TL)
745@triton.autotune(_get_autotune_configs(inner_loop='key'),
+746 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
+747@triton.jit
+748def _attn_bwd_dq(t_q, t_k, t_v, t_do,
+749 t_dq,
+750 t_lse, t_pdp,
+751 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
+752 n_groups: tl.constexpr, d_head: tl.constexpr,
+753 is_causal: tl.constexpr,
+754 BLOCK_Q: tl.constexpr,
+755 BLOCK_K: tl.constexpr,
+756 ):
879 b_ds = b_p * (b_dp - b_pdp[:, None])
758 LN2: tl.constexpr = 0.6931471824645996 # type: ignore
+759
+760 i = tl.program_id(0) * BLOCK_Q
+761 z = tl.program_id(1) // n_groups
+762 g = tl.program_id(1) % n_groups
881 b_dq += tl.dot(b_ds.to(b_kT.dtype),
-882 tl.trans(b_kT),
-883 out_dtype=HI_PRES_TL)
765 p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+766 (q_seq_len, d_head),
+767 (d_head, 1),
+768 (i, 0),
+769 (BLOCK_Q, d_head),
+770 (1, 0))
+771 p_dq = tl.make_block_ptr(t_dq + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+772 (q_seq_len, d_head),
+773 (d_head, 1),
+774 (i, 0),
+775 (BLOCK_Q, d_head),
+776 (1, 0))
+777 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+778 (q_seq_len, d_head),
+779 (d_head, 1),
+780 (i, 0),
+781 (BLOCK_Q, d_head),
+782 (1, 0))
+783 p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
+784 (d_head, kv_seq_len),
+785 (1, d_head),
+786 (0, 0),
+787 (d_head, BLOCK_K),
+788 (0, 1))
+789 p_vT = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
+790 (d_head, kv_seq_len),
+791 (1, d_head),
+792 (0, 0),
+793 (d_head, BLOCK_K),
+794 (0, 1))
+795 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
+796 (q_seq_len,),
+797 (1,),
+798 (i,),
+799 (BLOCK_Q,),
+800 (0,))
+801 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
+802 (q_seq_len,),
+803 (1,),
+804 (i,),
+805 (BLOCK_Q,),
+806 (0,))
886 j += BLOCK_K
-887 p_kT = tl.advance(p_kT, (0, BLOCK_K))
-888 p_vT = tl.advance(p_vT, (0, BLOCK_K))
809 b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
+810 b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
+811 b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
+812 b_lse = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
891 return b_dq
815 b_dq = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
819 if is_causal:
Compute $dq$ for the masked (diagonal) blocks.
+821 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
+822 b_do, b_lse, b_pdp,
+823 BLOCK_Q, BLOCK_K,
+824 i=i, j=i,
+825 steps=BLOCK_Q // BLOCK_K,
+826 MASK=True,
+827 q_seq_len=q_seq_len,
+828 kv_seq_len=kv_seq_len
+829 )
Compute $dq$ for the other (unmasked) blocks
+832 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
+833 b_do, b_lse, b_pdp,
+834 BLOCK_Q, BLOCK_K,
+835 i=i, j=tl.full([], 0, tl.int32), # type: ignore
+836 steps=i // BLOCK_K,
+837 MASK=False,
+838 q_seq_len=q_seq_len,
+839 kv_seq_len=kv_seq_len
+840 )
+841 else:
Iterate through all keys
+843 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
+844 b_do, b_lse, b_pdp,
+845 BLOCK_Q, BLOCK_K,
+846 i=i, j=tl.full([], 0, tl.int32), # type: ignore
+847 steps=tl.cdiv(kv_seq_len, BLOCK_K),
+848 MASK=False,
+849 q_seq_len=q_seq_len,
+850 kv_seq_len=kv_seq_len
+851 )
b_dq stores $\frac{dq}{\ln 2}$, so multiply by $\ln 2$ to get $dq$.
854 b_dq *= LN2
Save
+857 tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))
Inner loop over key blocks
+860@triton.jit
+861def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
+862 b_do, b_lse, b_pdp,
+863 BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
+864 i, j, steps,
+865 MASK: tl.constexpr,
+866 q_seq_len: tl.constexpr,
+867 kv_seq_len: tl.constexpr):
869 offs_i = i + tl.arange(0, BLOCK_Q)
+870 offs_j = tl.arange(0, BLOCK_K)
+871
+872 p_kT = tl.advance(p_kT, (0, j))
+873 p_vT = tl.advance(p_vT, (0, j))
+874
+875 tl.static_assert(BLOCK_Q % BLOCK_K == 0, 'BLOCK_Q must be divisible by BLOCK_K')
+876
+877 for _ in range(steps):
+878 current_j = j + offs_j
+879 j_mask = current_j < kv_seq_len
Note that $\hat{k}$ is already multiplied by the softmax scale. It is also divided by $\ln 2$ so we can use $2^{(\cdot)}$ instead of $e^{(\cdot)}$.
+884 b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
+885 b_vT = tl.load(p_vT, boundary_check=(1,), padding_option="zero")
+886 b_qk = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
+887 b_p = tl.math.exp2(b_qk - b_lse[:, None])
Autoregressive masking.
+890 if MASK:
+891 causal_mask = (offs_i[:, None] >= current_j[None, :])
+892 b_p = tl.where(causal_mask, b_p, 0.0)
+893
+894 b_p = tl.where(j_mask[None, :], b_p, 0.0)
Compute $dP_{ij} = do_i \cdot v_j$
899 b_dp = tl.dot(b_do, b_vT, out_dtype=HI_PRES_TL).to(HI_PRES_TL)
Compute $dS_{ij} = P_{ij} (dP_{ij} - D_i)$
901 b_ds = b_p * (b_dp - b_pdp[:, None])
Accumulate $dq_i \mathrel{+}= dS_{ij} \hat{k}_j$
903 b_dq += tl.dot(b_ds.to(b_kT.dtype),
+904 tl.trans(b_kT),
+905 out_dtype=HI_PRES_TL)
Increment pointers.
+908 j += BLOCK_K
+909 p_kT = tl.advance(p_kT, (0, BLOCK_K))
+910 p_vT = tl.advance(p_vT, (0, BLOCK_K))
Return the accumulated $dq$
+913 return b_dq