diff --git a/.gitignore b/.gitignore
index 4336d66b..98c6271e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,4 +16,6 @@ html/
 diagrams/
 .comet.config
 settings.md
-labml_app.log
\ No newline at end of file
+labml_app.log
+/extensions
+/.nb_editor
\ No newline at end of file
diff --git a/docs/transformers/flash/index.html b/docs/transformers/flash/index.html
index 7e726113..b3c28739 100644
--- a/docs/transformers/flash/index.html
+++ b/docs/transformers/flash/index.html
@@ -79,9 +79,10 @@
 6import triton.language as tl
 7
 8import torch
-9
-10HI_PRES_TL: tl.constexpr = tl.float32
-11HI_PRES_TORCH: tl.constexpr = torch.float32
+9from typing import Any, Tuple, Optional
+10
+11HI_PRES_TL: tl.constexpr = tl.float32
+12HI_PRES_TORCH: torch.dtype = torch.float32
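A small, self-contained sketch (not part of the diff) of why these high-precision constants exist: float16 cannot represent small increments to large running values, so the kernels keep their running statistics (lse, accumulators) in float32 via HI_PRES_TORCH / HI_PRES_TL.

```python
# Hedged illustration, not from the diff: float16 loses small increments to large sums,
# which is why accumulation is done in float32 (HI_PRES_TORCH / HI_PRES_TL).
import torch

a = torch.tensor(2048.0, dtype=torch.float16)
print(a + 1.0)                      # tensor(2048., dtype=torch.float16) -- increment lost
print(a.to(torch.float32) + 1.0)    # tensor(2049.)
```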
-14class AttentionFunc(torch.autograd.Function):
+15class AttentionFunc(torch.autograd.Function):

Group query attention forward pass. Returns the output in shape [batch_size, n_heads, q_seq_len, d_head].

* ctx is the context for torch autograd
* q has shape [batch_size, n_heads, q_seq_len, d_head]
* k has shape [batch_size, k_heads, kv_seq_len, d_head]
* v has shape [batch_size, k_heads, kv_seq_len, d_head]
* causal whether to apply a causal attention mask
* sm_scale softmax scale factor

-15 @staticmethod
-16 def forward(ctx, q, k, v, causal, sm_scale):
+16 @staticmethod
+17 def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+18 causal: bool, sm_scale: float) -> torch.Tensor:

Shapes: batch size, n_heads, q_seq_len, d_head
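A hedged, eager-mode reference of what this function is expected to compute (not from the diff; names are mine). It is useful as a ground truth when testing the Triton kernels, assuming aligned query/key sequences for the causal case.

```python
# Hedged reference implementation of grouped-query attention with the same layout.
import torch

def reference_gqa(q, k, v, causal, sm_scale):
    batch_size, n_heads, q_seq_len, d_head = q.shape
    _, k_heads, kv_seq_len, _ = k.shape
    n_groups = n_heads // k_heads
    # Each KV head serves n_groups consecutive query heads
    k = k.repeat_interleave(n_groups, dim=1)
    v = v.repeat_interleave(n_groups, dim=1)
    s = torch.einsum('bhid,bhjd->bhij', q.float(), k.float()) * sm_scale
    if causal:
        mask = torch.ones(q_seq_len, kv_seq_len, dtype=torch.bool, device=q.device).tril()
        s = s.masked_fill(~mask, float('-inf'))
    p = torch.softmax(s, dim=-1)
    return torch.einsum('bhij,bhjd->bhid', p, v.float()).to(q.dtype)
```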
-18 batch_size, n_heads, q_seq_len, d_head = q.shape
-19 k_heads = k.shape[1]
-20 kv_seq_len = k.shape[2]
-21 assert n_heads % k_heads == 0
-22 n_groups = n_heads // k_heads
+30 batch_size, n_heads, q_seq_len, d_head = q.shape
+31 _, k_heads, kv_seq_len, _ = k.shape
+32 assert n_heads % k_heads == 0
+33 n_groups = n_heads // k_heads
-25 assert d_head == k.shape[-1] == v.shape[-1]
-26 assert d_head in {16, 32, 64, 128, 256}
-27
-28 q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
-29 k = k.view(batch_size * k_heads, kv_seq_len, d_head)
-30 v = v.view(batch_size * k_heads, kv_seq_len, d_head)
-31
-32 assert q.is_contiguous()
-33 assert k.is_contiguous()
-34 assert v.is_contiguous()
-35
-36 o = torch.empty_like(q)
-37
-38 lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)
-39
-40 grid = lambda args: (triton.cdiv(q_seq_len, args["BLOCK_M"]), batch_size * k_heads * n_groups, 1)
-41 ctx.grid = grid
-42 _attn_fwd[grid](
-43 q, k, v, sm_scale, lse, o,
-44 n_groups=n_groups,
-45 q_seq_len=q_seq_len,
-46 kv_seq_len=kv_seq_len,
-47 d_head=d_head,
-48 is_causal=causal,
-49 )
-50
-51 ctx.save_for_backward(q, k, v, o, lse)
-52 ctx.sm_scale = sm_scale
-53 ctx.n_groups = n_groups
-54 ctx.d_head = d_head
-55 ctx.causal = causal
-56
-57 return o.view(batch_size, n_heads, q_seq_len, d_head)
+36 assert d_head == k.shape[-1] == v.shape[-1]
+37 assert d_head in {16, 32, 64, 128, 256}

Change the tensors, combining the heads with the batch dimension
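A hedged illustration (not from the diff) of the reshape this annotation describes, with made-up sizes: 8 query heads sharing 2 KV heads gives n_groups = 4.

```python
# Hedged shape walk-through of combining heads with the batch dimension.
import torch

batch_size, n_heads, k_heads, q_seq_len, kv_seq_len, d_head = 2, 8, 2, 128, 128, 64
n_groups = n_heads // k_heads
q = torch.randn(batch_size, n_heads, q_seq_len, d_head)
k = torch.randn(batch_size, k_heads, kv_seq_len, d_head)

q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)   # (4, 4, 128, 64)
k = k.view(batch_size * k_heads, kv_seq_len, d_head)            # (4, 128, 64)
```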
-59 @staticmethod
-60 def backward(ctx, do):
-61 n_groups = ctx.n_groups
-62 sm_scale = ctx.sm_scale
-63 causal = ctx.causal
-64 q, k, v, o, lse = ctx.saved_tensors
-65 batch_size, n_heads, q_seq_len, d_head = do.shape
-66 _, kv_seq_len, _ = k.shape
-67 k_heads = n_heads // n_groups
-68
-69 do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
-70
-71 assert do.is_contiguous()
-72 assert k.stride() == v.stride()
-73 assert q.stride() == o.stride() == do.stride()
-74
-75 dq = torch.empty_like(q)
-76 dk = torch.empty_like(k)
-77 dv = torch.empty_like(v)
-78
-79 RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
-80 arg_k = k * (sm_scale * RCP_LN2)
-81 BLOCK_M = 16
-82 assert q_seq_len % BLOCK_M == 0
-83 pre_grid = (q_seq_len // BLOCK_M, batch_size * k_heads)
+40 q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
+41 k = k.view(batch_size * k_heads, kv_seq_len, d_head)
+42 v = v.view(batch_size * k_heads, kv_seq_len, d_head)
-85 pdp = torch.empty_like(lse)
-86 _attn_bwd_d[pre_grid](
-87 o, do,
-88 pdp,
-89 BLOCK_M=16,
-90 d_head=d_head,
-91 q_seq_len=q_seq_len,
-92 n_groups=n_groups,
-93 num_stages=1,
-94 )
-95 grid = lambda args: (triton.cdiv(kv_seq_len, args['BLOCK_N']), batch_size * k_heads)
-96 _attn_bwd_dkdv[grid](
-97 q, arg_k, v, sm_scale, do, dk, dv,
-98 lse, pdp,
-99 q_seq_len, kv_seq_len, n_groups, d_head,
-100 is_causal=causal,
-101
-102 )
-103 grid = lambda args: (triton.cdiv(q_seq_len, args["BLOCK_M"]), batch_size * k_heads * n_groups)
-104 _attn_bwd_dq[grid](
-105 q, arg_k, v, do,
-106 dq,
-107 lse, pdp,
-108 q_seq_len, kv_seq_len, n_groups, d_head,
-109 is_causal=causal,
-110 )
-111
-112 dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
-113 dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
-114 dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)
-115
-116 return dq, dk, dv, None, None
-117
-118
-119attention = AttentionFunc.apply
+45 assert q.is_contiguous()
+46 assert k.is_contiguous()
+47 assert v.is_contiguous()
+48 assert k.stride() == v.stride()
-122def _get_autotune_configs(inner_loop: str):
-127 configs = []
-130 for bm in [64, 128, 256]:
+51 o = torch.empty_like(q)

Tensor for the log-sum-exp of the attention scores

+53 lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)

List possible BLOCK_M and BLOCK_N that satisfy BLOCK_M divisible by BLOCK_N, and also try to cover a wide range

The forward computation will be parallelized along the batch dimension and the queries in blocks of size BLOCK_M

+56 grid = lambda args: (triton.cdiv(q_seq_len, args["BLOCK_M"]), batch_size * k_heads * n_groups, 1)
+57 _attn_fwd[grid](
+58 q, k, v, sm_scale, lse, o,
+59 n_groups=n_groups,
+60 q_seq_len=q_seq_len,
+61 kv_seq_len=kv_seq_len,
+62 d_head=d_head,
+63 is_causal=causal,
+64 )

Save the reshaped inputs and outputs for the backward pass
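A hedged sketch (not from the diff) of how the launch grid above is laid out: axis 0 indexes query blocks of BLOCK_M rows, axis 1 indexes (batch, kv head, group) triples.

```python
# Hedged illustration of the grid lambda with made-up sizes; BLOCK_M is normally
# picked by the autotuner.
import triton

q_seq_len, batch_size, k_heads, n_groups = 1024, 2, 2, 4
BLOCK_M = 128
grid = (triton.cdiv(q_seq_len, BLOCK_M), batch_size * k_heads * n_groups, 1)
print(grid)   # (8, 16, 1) -> 8 * 16 program instances of _attn_fwd
```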
+67 ctx.save_for_backward(q, k, v, o, lse)
+68 ctx.sm_scale = sm_scale
+69 ctx.n_groups = n_groups
+70 ctx.causal = causal

Return the output in shape [batch_size, n_heads, q_seq_len, d_head]
+
+73 return o.view(batch_size, n_heads, q_seq_len, d_head)

The backward pass computes the gradients of the input tensors.
* ctx is the context for torch autograd
* do is the gradient tensor of the attention output with shape [batch_size, n_heads, q_seq_len, d_head]
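Not from the diff: a minimal reminder of the torch.autograd.Function contract the class above follows. backward must return one gradient per forward argument, with None for non-tensor arguments such as causal and sm_scale.

```python
# Hedged, minimal example of the autograd.Function contract (hypothetical class name).
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * ctx.factor, None   # gradient for x, None for the float argument

x = torch.randn(4, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
```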
+75 @staticmethod
+76 def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:

Get saved tensors and attributes
+85 n_groups = ctx.n_groups
+86 sm_scale = ctx.sm_scale
+87 causal = ctx.causal
+88 q, k, v, o, lse = ctx.saved_tensors

Get shapes
+91 batch_size, n_heads, q_seq_len, d_head = do.shape
+92 _, kv_seq_len, _ = k.shape
+93 k_heads = n_heads // n_groups

Combine the heads with the batch dimension of the output gradients tensor

+96 do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)

Make sure it's contiguous and the strides are the same
+ +99 assert do.is_contiguous()
+100 assert k.stride() == v.stride()
+101 assert q.stride() == o.stride() == do.stride()

Create tensors for input gradients

+104 dq = torch.empty_like(q)
+105 dk = torch.empty_like(k)
+106 dv = torch.empty_like(v)
+109 RCP_LN2 = 1.4426950408889634

Multiply k by the softmax scale and by $\frac{1}{\ln 2}$ so the kernels can use $2^x$ instead of $e^x$

+111 k_scaled = k * (sm_scale * RCP_LN2)
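A hedged numeric check (not from the diff) of the base-2 trick used here: since $e^x = 2^{x / \ln 2}$, folding RCP_LN2 into k lets the kernels use exp2, which is cheap on GPUs.

```python
# Hedged verification that exp(x) == 2**(x * RCP_LN2).
import torch

RCP_LN2 = 1.4426950408889634   # 1 / ln(2)
x = torch.randn(8)
assert torch.allclose(torch.exp(x), torch.exp2(x * RCP_LN2), atol=1e-6)
```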
+113 pdp = torch.empty_like(lse)

We use a fixed BLOCK_M for the backward pass preprocessing that computes $D_i = o_i \cdot do_i$

+115 BLOCK_M = 16
+116 assert q_seq_len % BLOCK_M == 0
+120 pre_grid = (q_seq_len // BLOCK_M, batch_size * k_heads)
+121 _attn_bwd_d[pre_grid](
+122 o, do,
+123 pdp,
+124 BLOCK_M=16,
+125 d_head=d_head,
+126 q_seq_len=q_seq_len,
+127 n_groups=n_groups,
+128 num_stages=1,
+129 )

Compute $dK$ and $dV$. This is parallelized along the batch and keys in blocks of size BLOCK_N
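Not from the diff: a plain-PyTorch sketch of what the preprocessing launch above produces. For every query position, $D_i = o_i \cdot do_i$, stored in pdp with the same layout as lse.

```python
# Hedged eager-mode equivalent of _attn_bwd_d, with made-up sizes.
import torch

o = torch.randn(4, 2, 128, 64)     # (batch * k_heads, n_groups, q_seq_len, d_head)
do = torch.randn_like(o)
pdp_reference = (o.float() * do.float()).sum(dim=-1)   # (4, 2, 128), same shape as pdp/lse
```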
+133 grid = lambda args: (triton.cdiv(kv_seq_len, args['BLOCK_N']), batch_size * k_heads)
+134 _attn_bwd_dkdv[grid](
+135 q, k_scaled, v, sm_scale, do, dk, dv,
+136 lse, pdp,
+137 q_seq_len, kv_seq_len, n_groups, d_head,
+138 is_causal=causal,
+139
+140 )

Compute $dQ$. This is parallelized along the batch and queries in blocks of size BLOCK_M

+144 grid = lambda args: (triton.cdiv(q_seq_len, args["BLOCK_M"]), batch_size * k_heads * n_groups)
+145 _attn_bwd_dq[grid](
+146 q, k_scaled, v, do,
+147 dq,
+148 lse, pdp,
+149 q_seq_len, kv_seq_len, n_groups, d_head,
+150 is_causal=causal,
+151 )

Split the combined batch and heads

+154 dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
+155 dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
+156 dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)+ +
159 return dq, dk, dv, None, None
+160
+161
+162attention = AttentionFunc.apply

+165def _get_autotune_configs(inner_loop: str) -> list:

+170 configs = []

List possible BLOCK_M and BLOCK_N that satisfy BLOCK_M divisible by BLOCK_N, and also try to cover a wide range

+173 for bm in [64, 128, 256]:

We'll try bn in 64, 128, 256, keeping only the combinations that satisfy the divisibility checks below
-132 for bn in [64, 128, 256]:
-133 if inner_loop == 'key' and bm % bn != 0:
-134 continue
-135 if inner_loop == 'query' and bn % bm != 0:
-136 continue
-137 for s in [2, 3, 4]:
-138 for w in [4, 8]:
-139 if bm * bn < 128 * 128 and w == 8:
-140 continue
-141
-142 configs.append(triton.Config({'BLOCK_M': bm, 'BLOCK_N': bn}, num_stages=s, num_warps=w))
-143
-144 return configs
+175 for bn in [64, 128, 256]:
+176 if inner_loop == 'key' and bm % bn != 0:
+177 continue
+178 if inner_loop == 'query' and bn % bm != 0:
+179 continue
+180 for s in [2, 3, 4]:
+181 for w in [4, 8]:
+182 if bm * bn < 128 * 128 and w == 8:
+183 continue
+184
+185 configs.append(triton.Config({'BLOCK_M': bm, 'BLOCK_N': bn}, num_stages=s, num_warps=w))
+186
+187 return configs

* t_q query
* sm_scale softmax scale
* t_lse $\log_2 \sum_j e^{S_{ij}}$ (out)
* t_o output (out)
* n_groups number of groups
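A hedged sketch (not from the diff) of the quantity t_lse holds per query row: the kernel keeps the softmax statistics in base 2, so lse is $\log_2 \sum_j e^{S_{ij}}$ and the probabilities can be recovered from it.

```python
# Hedged check that softmax can be reconstructed from a base-2 log-sum-exp.
import torch

s = torch.randn(128, 128)                                   # made-up scaled scores S
lse = torch.logsumexp(s, dim=-1) / torch.log(torch.tensor(2.0))   # log2(sum_j exp(S_ij))
p = torch.exp2(s * 1.4426950408889634 - lse[:, None])
assert torch.allclose(p, torch.softmax(s, dim=-1), atol=1e-6)
```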
@@ -352,427 +593,18 @@
-147@triton.autotune(_get_autotune_configs(inner_loop='key'),
-148 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
-149@triton.jit
-150def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
-151 n_groups: tl.constexpr,
-152 q_seq_len: tl.constexpr,
-153 kv_seq_len: tl.constexpr,
-154 d_head: tl.constexpr,
-155 is_causal: tl.constexpr,
-156 BLOCK_M: tl.constexpr, # q seq len block
-157 BLOCK_N: tl.constexpr, # k seq len block
-158 ):
-179 start_m = tl.program_id(0)
-180 z = tl.program_id(1) // n_groups
-181 g = tl.program_id(1) % n_groups

block pointers

-184 p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-185 (q_seq_len, d_head),
-186 (d_head, 1),
-187 (start_m * BLOCK_M, 0),
-188 (BLOCK_M, d_head),
-189 (1, 0))
-190 p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
-191 (kv_seq_len, d_head),
-192 (d_head, 1),
-193 (0, 0),
-194 (BLOCK_N, d_head),
-195 (1, 0))
-196 p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
-197 (d_head, kv_seq_len),
-198 (1, d_head),
-199 (0, 0),
-200 (d_head, BLOCK_N),
-201 (0, 1))
-202 p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-203 (q_seq_len, d_head),
-204 (d_head, 1),
-205 (start_m * BLOCK_M, 0),
-206 (BLOCK_M, d_head),
-207 (1, 0))
-208 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
-209 (q_seq_len,),
-210 (1,),
-211 (start_m * BLOCK_M,),
-212 (BLOCK_M,),
-213 (0,))

initialize offsets

-216 offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-217 offs_n = tl.arange(0, BLOCK_N)

Initialize $m_i$ and $l_i$

-220 b_m = tl.zeros([BLOCK_M], dtype=HI_PRES_TL) - float("inf")
-221 b_l = tl.zeros([BLOCK_M], dtype=HI_PRES_TL) + 1.0

Accumulator for the output

-223 b_acc = tl.zeros([BLOCK_M, d_head], dtype=HI_PRES_TL)

Softmax scale multiplied by $\frac{1}{\log 2}$

-226 sm_scale = sm_scale * 1.44269504

Load the query block

-228 b_q = tl.load(p_q)
-229
-230 if is_causal:

Run the inner loop for the key ranges: unmasked blocks first, then the masked blocks near the diagonal

-232 b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q,
-233 p_kT, p_v,
-234 sm_scale,
-235 BLOCK_M, d_head, BLOCK_N,
-236 offs_m, offs_n,
-237 start_n=tl.full([], 0, tl.int32), # type: ignore
-238 steps=(start_m * BLOCK_M) // BLOCK_N,
-239 MASK=False,
-240 )
-241 b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
-242 sm_scale,
-243 BLOCK_M, d_head, BLOCK_N,
-244 offs_m, offs_n,
-245 start_n=start_m * BLOCK_M,
-246 steps=BLOCK_M // BLOCK_N,
-247 MASK=True,
-248 )
-249 else:
-250 b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
-251 sm_scale,
-252 BLOCK_M, d_head, BLOCK_N,
-253 offs_m, offs_n,
-254 start_n=tl.full([], 0, tl.int32), # type: ignore
-255 steps=kv_seq_len // BLOCK_N,
-256 MASK=False,
-257 )

Update LSE

-260 tl.store(p_lse, b_m + tl.math.log2(b_l))
-261 tl.store(p_o, (b_acc / b_l[:, None]).to(t_o.type.element_ty))
-264@triton.jit
-265def _attn_fwd_inner(b_acc, b_l, b_m, b_q,
-266 p_kT, p_v,
-267 scale,
-268 BLOCK_M: tl.constexpr,
-269 d_head: tl.constexpr,
-270 BLOCK_N: tl.constexpr,
-271 offs_m, offs_n,
-272 start_n,
-273 steps,
-274 MASK: tl.constexpr,
-275 ):
-276 tl.static_assert(BLOCK_M % BLOCK_N == 0)
-277
-278 p_kT = tl.advance(p_kT, (0, start_n))
-279 p_v = tl.advance(p_v, (start_n, 0))

loop over k, v and update accumulator

-282 for _ in range(steps):
-283 b_kT = tl.load(p_kT)
-284 b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
-285
-286 tl.static_assert(b_s.dtype == HI_PRES_TL)
-287 b_s = b_s * scale
-288 if MASK:
-289 mask = offs_m[:, None] >= (start_n + offs_n[None, :])
-290 b_s = b_s + tl.where(mask, 0, -1.0e6)
-293 tl.static_assert(len(b_s.shape) == 2)
-294 b_m_new = tl.maximum(b_m, tl.max(b_s, -1))
-296 b_p = tl.math.exp2(b_s - b_m_new[:, None])
-298 b_l_new = tl.sum(b_p, -1)
-301 b_m_m_new = tl.math.exp2(b_m - b_m_new)
-303 b_l = b_l * b_m_m_new + b_l_new
-306 b_v = tl.load(p_v)
-307 b_acc = b_acc * b_m_m_new[:, None]
-308 b_p = b_p.to(b_q.dtype)
-309 b_acc += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)

update m_i and l_i

-312 b_m = b_m_new
-313
-314 start_n += BLOCK_N
-315 p_v = tl.advance(p_v, (BLOCK_N, 0))
-316 p_kT = tl.advance(p_kT, (0, BLOCK_N))
-317
-318 tl.static_assert(b_acc.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
-319
-320 return b_acc, b_l, b_m

Loop along m (query); n % m == 0

-323@triton.jit
-324def _attn_bwd_d(t_o, t_do,
-325 t_pdp,
-326 BLOCK_M: tl.constexpr, d_head: tl.constexpr,
-327 q_seq_len: tl.constexpr,
-328 n_groups: tl.constexpr,
-329 ):
-330 m = tl.program_id(0) * BLOCK_M
-331 z = tl.program_id(1)
-332 p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
-333 (n_groups, q_seq_len, d_head),
-334 (q_seq_len * d_head, d_head, 1),
-335 (0, m, 0),
-336 (n_groups, BLOCK_M, d_head),
-337 (2, 1, 0))
-338 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
-339 (n_groups, q_seq_len, d_head),
-340 (q_seq_len * d_head, d_head, 1),
-341 (0, m, 0),
-342 (n_groups, BLOCK_M, d_head),
-343 (2, 1, 0))
-344 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
-345 (n_groups, q_seq_len),
-346 (q_seq_len, 1),
-347 (0, m),
-348 (n_groups, BLOCK_M),
-349 (1, 0))
-350
-351 o = tl.load(p_o)
-352 do = tl.load(p_do).to(HI_PRES_TL)
-353 d = tl.sum(o * do, axis=-1)
-354 tl.store(p_pdp, d)
-355
-356
-357@triton.autotune(_get_autotune_configs(inner_loop='query'),
-358 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
-359@triton.jit
-360def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
-361 t_do,
-362 t_dk, t_dv,
-363 t_lse, t_pdp,
-364 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
-365 n_groups: tl.constexpr, d_head: tl.constexpr,
-366 is_causal: tl.constexpr,
-367 BLOCK_M: tl.constexpr,
-368 BLOCK_N: tl.constexpr,
-369 ):

K is already multiplied by scale

-374 n = tl.program_id(0)
-375 z = tl.program_id(1)
-376
-377 p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
-378 (kv_seq_len, d_head),
-379 (d_head, 1),
-380 (n * BLOCK_N, 0),
-381 (BLOCK_N, d_head),
-382 (1, 0))
-383 p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
-384 (kv_seq_len, d_head),
-385 (d_head, 1),
-386 (n * BLOCK_N, 0),
-387 (BLOCK_N, d_head),
-388 (1, 0))
-389 p_dk = tl.make_block_ptr(t_dk + z * kv_seq_len * d_head,
-390 (kv_seq_len, d_head),
-391 (d_head, 1),
-392 (n * BLOCK_N, 0),
-393 (BLOCK_N, d_head),
-394 (1, 0))
-395 p_dv = tl.make_block_ptr(t_dv + z * kv_seq_len * d_head,
-396 (kv_seq_len, d_head),
-397 (d_head, 1),
-398 (n * BLOCK_N, 0),
-399 (BLOCK_N, d_head),
-400 (1, 0))
-401
-402 b_dv = tl.zeros([BLOCK_N, d_head], dtype=HI_PRES_TL)
-403 b_dk = tl.zeros([BLOCK_N, d_head], dtype=HI_PRES_TL)
+190@triton.autotune(_get_autotune_configs(inner_loop='key'),
+191 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
+192@triton.jit
+193def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
+194 n_groups: tl.constexpr,
+195 q_seq_len: tl.constexpr,
+196 kv_seq_len: tl.constexpr,
+197 d_head: tl.constexpr,
+198 is_causal: tl.constexpr,
+199 BLOCK_M: tl.constexpr, # q seq len block
+200 BLOCK_N: tl.constexpr, # k seq len block
+201 ):

load K and V: they stay in SRAM throughout the inner loop.

-406 b_k = tl.load(p_k)
-407 b_v = tl.load(p_v)
-408
-409 for g in range(n_groups):
-410 p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-411 (d_head, q_seq_len),
-412 (1, d_head),
-413 (0, 0),
-414 (d_head, BLOCK_M),
-415 (0, 1))
-416
-417 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-418 (q_seq_len, d_head),
-419 (d_head, 1),
-420 (0, 0),
-421 (BLOCK_M, d_head),
-422 (1, 0))
-423 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
-424 (q_seq_len,),
-425 (1,),
-426 (0,),
-427 (BLOCK_M,),
-428 (0,))
-429 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
-430 (q_seq_len,),
-431 (1,),
-432 (0,),
-433 (BLOCK_M,),
-434 (0,))
+222 i = tl.program_id(0)
+223 z = tl.program_id(1) // n_groups
+224 g = tl.program_id(1) % n_groups
+227 p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+228 (q_seq_len, d_head),
+229 (d_head, 1),
+230 (i * BLOCK_M, 0),
+231 (BLOCK_M, d_head),
+232 (1, 0))
+233 p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
+234 (kv_seq_len, d_head),
+235 (d_head, 1),
+236 (0, 0),
+237 (BLOCK_N, d_head),
+238 (1, 0))
+239 p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
+240 (d_head, kv_seq_len),
+241 (1, d_head),
+242 (0, 0),
+243 (d_head, BLOCK_N),
+244 (0, 1))
+245 p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+246 (q_seq_len, d_head),
+247 (d_head, 1),
+248 (i * BLOCK_M, 0),
+249 (BLOCK_M, d_head),
+250 (1, 0))
+251 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
+252 (q_seq_len,),
+253 (1,),
+254 (i * BLOCK_M,),
+255 (BLOCK_M,),
+256 (0,))

Compute $dK$ and $dV$ along the masked blocks near the diagonal. Use a smaller block size (MASK_BLOCK_M) because there is a little extra computation?

Initialize offsets

-442 if is_causal:
+259 offs_i = i * BLOCK_M + tl.arange(0, BLOCK_M)
+260 offs_j = tl.arange(0, BLOCK_N)
-444 b_dk, b_dv = _attn_bwd_dkdv_inner(
-445 b_dk, b_dv,
-446 p_qT, b_k, b_v, p_do,
-447 p_lse, p_pdp,
+263 b_m = tl.zeros([BLOCK_M], dtype=HI_PRES_TL) - float("inf")
+264 b_l = tl.zeros([BLOCK_M], dtype=HI_PRES_TL) + 1.0

You can use a smaller BLOCK_M if BLOCK_N is not divisible by BLOCK_M

Accumulate

-449 BLOCK_M, BLOCK_N,
-450 d_head,
-451 n=n * BLOCK_N, start_m=n * BLOCK_N,
-452 steps=BLOCK_N // BLOCK_M,
-453 MASK=True
-454 )
+266 b_acc = tl.zeros([BLOCK_M, d_head], dtype=HI_PRES_TL)
-457 b_dk, b_dv = _attn_bwd_dkdv_inner(
-458 b_dk, b_dv,
-459 p_qT, b_k, b_v, p_do,
-460 p_lse, p_pdp,
-461 BLOCK_M, BLOCK_N,
-462 d_head,
-463 n=n * BLOCK_N, start_m=(n + 1) * BLOCK_N,
-464 steps=(q_seq_len - (n + 1) * BLOCK_N) // BLOCK_M,
-465 MASK=False,
-466 )
-467 else:
-468 b_dk, b_dv = _attn_bwd_dkdv_inner(
-469 b_dk, b_dv,
-470 p_qT, b_k, b_v, p_do,
-471 p_lse, p_pdp,
-472 BLOCK_M, BLOCK_N,
-473 d_head,
-474 n=n * BLOCK_N, start_m=tl.full([], 0, tl.int32),
-475 steps=q_seq_len // BLOCK_M,
-476 MASK=False,
-477 )
+269 sm_scale = sm_scale * 1.44269504
-480 tl.store(p_dv, b_dv.to(t_dv.type.element_ty))
+271 b_q = tl.load(p_q)
+272
+273 if is_causal:

Since we used $\hat{k}$ (the keys pre-multiplied by the softmax scale and $\frac{1}{\ln 2}$), where $k$ are the original keys, we multiply by the scale again to get the gradient on the original keys.

-484 b_dk *= sm_scale

Up to the diagonal block

+275 b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q,
+276 p_kT, p_v,
+277 sm_scale,
+278 BLOCK_M, d_head, BLOCK_N,
+279 offs_i, offs_j,
+280 start_n=tl.full([], 0, tl.int32), # type: ignore
+281 steps=(i * BLOCK_M) // BLOCK_N,
+282 MASK=False,
+283 )
-487 tl.store(p_dk, b_dk.to(t_dk.type.element_ty))
+285 b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
+286 sm_scale,
+287 BLOCK_M, d_head, BLOCK_N,
+288 offs_i, offs_j,
+289 start_n=i * BLOCK_M,
+290 steps=BLOCK_M // BLOCK_N,
+291 MASK=True,
+292 )
+293 else:
+294 b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
+295 sm_scale,
+296 BLOCK_M, d_head, BLOCK_N,
+297 offs_i, offs_j,
+298 start_n=tl.full([], 0, tl.int32), # type: ignore
+299 steps=kv_seq_len // BLOCK_N,
+300 MASK=False,
+301 )
-490@triton.jit
-491def _attn_bwd_dkdv_inner(b_dk, b_dv,
-492 p_qT, b_k, b_v, p_do,
-493 p_lse, p_pdp,
-494 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
-495 d_head: tl.constexpr,
-496 n, start_m, steps,
-497 MASK: tl.constexpr):
+304 tl.store(p_lse, b_m + tl.math.log2(b_l))
+305 tl.store(p_o, (b_acc / b_l[:, None]).to(t_o.type.element_ty))

To apply the mask

-501 tl.static_assert(BLOCK_N % BLOCK_M == 0)
+308@triton.jit
+309def _attn_fwd_inner(b_acc, b_l, b_m, b_q,
+310 p_kT, p_v,
+311 scale,
+312 BLOCK_M: tl.constexpr,
+313 d_head: tl.constexpr,
+314 BLOCK_N: tl.constexpr,
+315 offs_m, offs_n,
+316 start_n,
+317 steps,
+318 MASK: tl.constexpr,
+319 ):
+320 tl.static_assert(BLOCK_M % BLOCK_N == 0)
+321
+322 p_kT = tl.advance(p_kT, (0, start_n))
+323 p_v = tl.advance(p_v, (start_n, 0))
-504 offs_m = start_m + tl.arange(0, BLOCK_M)
-505 offs_n = n + tl.arange(0, BLOCK_N)
+326 for _ in range(steps):
+327 b_kT = tl.load(p_kT)
+328 b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
+329
+330 tl.static_assert(b_s.dtype == HI_PRES_TL)
+331 b_s = b_s * scale
+332 if MASK:
+333 mask = offs_m[:, None] >= (start_n + offs_n[None, :])
+334 b_s = b_s + tl.where(mask, 0, -1.0e6)
-508 p_qT = tl.advance(p_qT, (0, start_m))
-509 p_do = tl.advance(p_do, (start_m, 0))
-510 p_lse = tl.advance(p_lse, (start_m,))
-511 p_pdp = tl.advance(p_pdp, (start_m,))
+337 tl.static_assert(len(b_s.shape) == 2)
+338 b_m_new = tl.maximum(b_m, tl.max(b_s, -1))
-514 for _ in range(steps):
+340 b_p = tl.math.exp2(b_s - b_m_new[:, None])
-516 b_qT = tl.load(p_qT)
+342 b_l_new = tl.sum(b_p, -1)
-519 b_m = tl.load(p_lse)
+345 b_m_m_new = tl.math.exp2(b_m - b_m_new)

Note that k is already multiplied by the softmax scale. It is also divided by $\ln 2$ so we can use $2^x$ instead of $e^x$

-524 b_qkT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)
-525 b_pT = tl.math.exp2(b_qkT - b_m[None, :])
+347 b_l = b_l * b_m_m_new + b_l_new
-528 if MASK:
-529 mask = (offs_m[None, :] >= offs_n[:, None])
-530 b_pT = tl.where(mask, b_pT, 0.0)
+350 b_v = tl.load(p_v)
+351 b_acc = b_acc * b_m_m_new[:, None]
+352 b_p = b_p.to(b_q.dtype)
+353 b_acc += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)
-533 b_do = tl.load(p_do)
-534 b_dv += tl.dot(b_pT.to(b_do.dtype),
-535 b_do,
-536 out_dtype=HI_PRES_TL)
+356 b_m = b_m_new
-539 b_pdp = tl.load(p_pdp)
+359 start_n += BLOCK_N
+360 p_v = tl.advance(p_v, (BLOCK_N, 0))
+361 p_kT = tl.advance(p_kT, (0, BLOCK_N))
+362
+363 tl.static_assert(b_acc.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
+364
+365 return b_acc, b_l, b_m
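A hedged eager-mode sketch (not from the diff) of a single step of this inner loop: when a new key block raises the row maximum, the running sum and accumulator are rescaled before the new block's contribution is added.

```python
# Hedged online-softmax step mirroring _attn_fwd_inner; all names are mine.
import torch

def online_softmax_step(b_acc, b_l, b_m, b_s, b_v):
    # b_s: scores for the new key block (already in base-2 scale), b_v: its values
    b_m_new = torch.maximum(b_m, b_s.max(dim=-1).values)
    b_p = torch.exp2(b_s - b_m_new[:, None])
    rescale = torch.exp2(b_m - b_m_new)
    b_l = b_l * rescale + b_p.sum(dim=-1)
    b_acc = b_acc * rescale[:, None] + b_p @ b_v
    return b_acc, b_l, b_m_new
```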
-541 b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)
+368@triton.jit
+369def _attn_bwd_d(t_o, t_do,
+370 t_pdp,
+371 BLOCK_M: tl.constexpr, d_head: tl.constexpr,
+372 q_seq_len: tl.constexpr,
+373 n_groups: tl.constexpr,
+374 ):
+375 i = tl.program_id(0) * BLOCK_M
+376 z = tl.program_id(1)
-543 b_dsT = b_pT * (b_dpT - b_pdp[None, :])
+379 p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
+380 (n_groups, q_seq_len, d_head),
+381 (q_seq_len * d_head, d_head, 1),
+382 (0, i, 0),
+383 (n_groups, BLOCK_M, d_head),
+384 (2, 1, 0))
+385 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
+386 (n_groups, q_seq_len, d_head),
+387 (q_seq_len * d_head, d_head, 1),
+388 (0, i, 0),
+389 (n_groups, BLOCK_M, d_head),
+390 (2, 1, 0))
+391 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
+392 (n_groups, q_seq_len),
+393 (q_seq_len, 1),
+394 (0, i),
+395 (n_groups, BLOCK_M),
+396 (1, 0))
+397
+398 o = tl.load(p_o)
+399 do = tl.load(p_do).to(HI_PRES_TL)
+400 d = tl.sum(o * do, axis=-1)
+401 tl.store(p_pdp, d)
-545 b_dk += tl.dot(b_dsT.to(b_qT.dtype),
-546 tl.trans(b_qT), out_dtype=HI_PRES_TL)
+404@triton.autotune(_get_autotune_configs(inner_loop='query'),
+405 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
+406@triton.jit
+407def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
+408 t_do,
+409 t_dk, t_dv,
+410 t_lse, t_pdp,
+411 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
+412 n_groups: tl.constexpr, d_head: tl.constexpr,
+413 is_causal: tl.constexpr,
+414 BLOCK_M: tl.constexpr,
+415 BLOCK_N: tl.constexpr,
+416 ):
-549 offs_m += BLOCK_M
-550 p_lse = tl.advance(p_lse, (BLOCK_M,))
-551 p_pdp = tl.advance(p_pdp, (BLOCK_M,))
-552 p_qT = tl.advance(p_qT, (0, BLOCK_M))
-553 p_do = tl.advance(p_do, (BLOCK_M, 0))
+421 n = tl.program_id(0)
+422 z = tl.program_id(1)
+423
+424 p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
+425 (kv_seq_len, d_head),
+426 (d_head, 1),
+427 (n * BLOCK_N, 0),
+428 (BLOCK_N, d_head),
+429 (1, 0))
+430 p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
+431 (kv_seq_len, d_head),
+432 (d_head, 1),
+433 (n * BLOCK_N, 0),
+434 (BLOCK_N, d_head),
+435 (1, 0))
+436 p_dk = tl.make_block_ptr(t_dk + z * kv_seq_len * d_head,
+437 (kv_seq_len, d_head),
+438 (d_head, 1),
+439 (n * BLOCK_N, 0),
+440 (BLOCK_N, d_head),
+441 (1, 0))
+442 p_dv = tl.make_block_ptr(t_dv + z * kv_seq_len * d_head,
+443 (kv_seq_len, d_head),
+444 (d_head, 1),
+445 (n * BLOCK_N, 0),
+446 (BLOCK_N, d_head),
+447 (1, 0))
+448
+449 b_dv = tl.zeros([BLOCK_N, d_head], dtype=HI_PRES_TL)
+450 b_dk = tl.zeros([BLOCK_N, d_head], dtype=HI_PRES_TL)

Return accumulated $dK$ and $dV$

-556 return b_dk, b_dv

load K and V: they stay in SRAM throughout the inner loop.

+453 b_k = tl.load(p_k)
+454 b_v = tl.load(p_v)

Iterate through the query groups that attend to the same keys

-559@triton.autotune(_get_autotune_configs(inner_loop='key'),
-560 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
-561@triton.jit
-562def _attn_bwd_dq(t_q, t_k, t_v, t_do,
-563 t_dq,
-564 t_lse, t_pdp,
-565 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
-566 n_groups: tl.constexpr, d_head: tl.constexpr,
-567 is_causal: tl.constexpr,
-568 BLOCK_M: tl.constexpr,
-569 BLOCK_N: tl.constexpr,
-570 ):
+457 for g in range(n_groups):
-572 LN2: tl.constexpr = 0.6931471824645996 # type: ignore
-573
-574 m = tl.program_id(0)
-575 z = tl.program_id(1) // n_groups
-576 g = tl.program_id(1) % n_groups
-577
-578 p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-579 (q_seq_len, d_head),
-580 (d_head, 1),
-581 (m * BLOCK_M, 0),
-582 (BLOCK_M, d_head),
-583 (1, 0))
-584 p_dq = tl.make_block_ptr(t_dq + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-585 (q_seq_len, d_head),
-586 (d_head, 1),
-587 (m * BLOCK_M, 0),
-588 (BLOCK_M, d_head),
-589 (1, 0))
-590 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
-591 (q_seq_len, d_head),
-592 (d_head, 1),
-593 (m * BLOCK_M, 0),
-594 (BLOCK_M, d_head),
-595 (1, 0))
-596 p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
-597 (d_head, kv_seq_len),
-598 (1, d_head),
-599 (0, 0),
-600 (d_head, BLOCK_N),
-601 (0, 1))
-602 p_vT = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
-603 (d_head, kv_seq_len),
-604 (1, d_head),
-605 (0, 0),
-606 (d_head, BLOCK_N),
-607 (0, 1))
-608 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
-609 (q_seq_len,),
-610 (1,),
-611 (m * BLOCK_M,),
-612 (BLOCK_M,),
-613 (0,))
-614 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
-615 (q_seq_len,),
-616 (1,),
-617 (m * BLOCK_M,),
-618 (BLOCK_M,),
-619 (0,))
-620
-621 b_q = tl.load(p_q)
-622 b_do = tl.load(p_do)
-623 b_pdp = tl.load(p_pdp)
-624
-625 b_dq = tl.zeros([BLOCK_M, d_head], dtype=HI_PRES_TL)
-626
-627 b_lse = tl.load(p_lse)
+459 p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+460 (d_head, q_seq_len),
+461 (1, d_head),
+462 (0, 0),
+463 (d_head, BLOCK_M),
+464 (0, 1))
+465
+466 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+467 (q_seq_len, d_head),
+468 (d_head, 1),
+469 (0, 0),
+470 (BLOCK_M, d_head),
+471 (1, 0))
+472 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
+473 (q_seq_len,),
+474 (1,),
+475 (0,),
+476 (BLOCK_M,),
+477 (0,))
+478 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
+479 (q_seq_len,),
+480 (1,),
+481 (0,),
+482 (BLOCK_M,),
+483 (0,))
-631 if is_causal:

Compute $dQ$ for masked (diagonal) blocks.

Compute $dK$ and $dV$ along the masked blocks near the diagonal. Use a smaller block size (MASK_BLOCK_M) because there is a little extra computation?

-633 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
-634 b_do, b_lse, b_pdp,
-635 BLOCK_M, BLOCK_N,
-636 m=m * BLOCK_M, start_n=m * BLOCK_M,
-637 steps=BLOCK_M // BLOCK_N,
-638 MASK=True
-639 )
+491 if is_causal:
-642 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
-643 b_do, b_lse, b_pdp,
-644 BLOCK_M, BLOCK_N,
-645 m=m * BLOCK_M, start_n=tl.full([], 0, tl.int32), # type: ignore
-646 steps=(m * BLOCK_M) // BLOCK_N,
-647 MASK=False
-648 )
-649 else:
-650 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
-651 b_do, b_lse, b_pdp,
-652 BLOCK_M, BLOCK_N,
-653 m=m * BLOCK_M, start_n=tl.full([], 0, tl.int32), # type: ignore
-654 steps=kv_seq_len // BLOCK_N,
-655 MASK=False
-656 )
+493 b_dk, b_dv = _attn_bwd_dkdv_inner(
+494 b_dk, b_dv,
+495 p_qT, b_k, b_v, p_do,
+496 p_lse, p_pdp,

Since $\hat{k}$ was scaled by $\frac{1}{\ln 2}$, and this factor got into the computed $dQ$, we need to reverse it.

-660 b_dq *= LN2

You can use a smaller BLOCK_M if BLOCK_N is not divisible by BLOCK_M

+498 BLOCK_M, BLOCK_N,
+499 d_head,
+500 n=n * BLOCK_N, start_m=n * BLOCK_N,
+501 steps=BLOCK_N // BLOCK_M,
+502 MASK=True
+503 )
-663 tl.store(p_dq, b_dq.to(t_dq.type.element_ty))
+506 b_dk, b_dv = _attn_bwd_dkdv_inner(
+507 b_dk, b_dv,
+508 p_qT, b_k, b_v, p_do,
+509 p_lse, p_pdp,
+510 BLOCK_M, BLOCK_N,
+511 d_head,
+512 n=n * BLOCK_N, start_m=(n + 1) * BLOCK_N,
+513 steps=(q_seq_len - (n + 1) * BLOCK_N) // BLOCK_M,
+514 MASK=False,
+515 )
+516 else:
+517 b_dk, b_dv = _attn_bwd_dkdv_inner(
+518 b_dk, b_dv,
+519 p_qT, b_k, b_v, p_do,
+520 p_lse, p_pdp,
+521 BLOCK_M, BLOCK_N,
+522 d_head,
+523 n=n * BLOCK_N, start_m=tl.full([], 0, tl.int32),
+524 steps=q_seq_len // BLOCK_M,
+525 MASK=False,
+526 )
-666@triton.jit
-667def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
-668 b_do, b_lse, b_pdp,
-669 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
-670 m, start_n, steps,
-671 MASK: tl.constexpr):
+529 tl.store(p_dv, b_dv.to(t_dv.type.element_ty))

Since we used $\hat{k}$ (the keys pre-multiplied by the softmax scale and $\frac{1}{\ln 2}$), where $k$ are the original keys, we multiply by the scale again to get the gradient on the original keys.

-673 offs_m = m + tl.arange(0, BLOCK_M)
-674
-675 p_kT = tl.advance(p_kT, (0, start_n))
-676 p_vT = tl.advance(p_vT, (0, start_n))
-677
-678 tl.static_assert(BLOCK_M % BLOCK_N == 0, 'BLOCK_M must be divisible by BLOCK_N')
-679
-680 for _ in range(steps):
+533 b_dk *= sm_scale

Note that k is already multiplied by the softmax scale. It is also divided by $\ln 2$ so we can use $2^x$ instead of $e^x$

Save $dK$

-684 b_kT = tl.load(p_kT)
-685 b_vT = tl.load(p_vT)
-686 b_qk = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
-687 b_p = tl.math.exp2(b_qk - b_lse[:, None])
+536 tl.store(p_dk, b_dk.to(t_dk.type.element_ty))
-690 if MASK:
-691 offs_n = start_n + tl.arange(0, BLOCK_N)
-692 mask = (offs_m[:, None] >= offs_n[None, :])
-693 b_p = tl.where(mask, b_p, 0.0)
+539@triton.jit
+540def _attn_bwd_dkdv_inner(b_dk, b_dv,
+541 p_qT, b_k, b_v, p_do,
+542 p_lse, p_pdp,
+543 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
+544 d_head: tl.constexpr,
+545 n, start_m, steps,
+546 MASK: tl.constexpr):
+550 tl.static_assert(BLOCK_N % BLOCK_M == 0)
-698 b_dp = tl.dot(b_do, b_vT, out_dtype=HI_PRES_TL).to(HI_PRES_TL)
+553 offs_m = start_m + tl.arange(0, BLOCK_M)
+554 offs_n = n + tl.arange(0, BLOCK_N)
-700 b_ds = b_p * (b_dp - b_pdp[:, None])
+557 p_qT = tl.advance(p_qT, (0, start_m))
+558 p_do = tl.advance(p_do, (start_m, 0))
+559 p_lse = tl.advance(p_lse, (start_m,))
+560 p_pdp = tl.advance(p_pdp, (start_m,))
-702 b_dq += tl.dot(b_ds.to(b_kT.dtype),
-703 tl.trans(b_kT),
-704 out_dtype=HI_PRES_TL)
+563 for _ in range(steps):
-707 start_n += BLOCK_N
-708 p_kT = tl.advance(p_kT, (0, BLOCK_N))
-709 p_vT = tl.advance(p_vT, (0, BLOCK_N))
+565 b_qT = tl.load(p_qT)
-712 return b_dq
+568 b_m = tl.load(p_lse)

Note that k is already multiplied by the softmax scale. It is also divided by $\ln 2$ so we can use $2^x$ instead of $e^x$
+ +573 b_qkT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)
+574 b_pT = tl.math.exp2(b_qkT - b_m[None, :])

Autoregressive masking.
+ +577 if MASK:
+578 mask = (offs_m[None, :] >= offs_n[:, None])
+579 b_pT = tl.where(mask, b_pT, 0.0)
+582 b_do = tl.load(p_do)
+583 b_dv += tl.dot(b_pT.to(b_do.dtype),
+584 b_do,
+585 out_dtype=HI_PRES_TL)
+588 b_pdp = tl.load(p_pdp)
+590 b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)
+592 b_dsT = b_pT * (b_dpT - b_pdp[None, :])
+594 b_dk += tl.dot(b_dsT.to(b_qT.dtype),
+595 tl.trans(b_qT), out_dtype=HI_PRES_TL)

Increment pointers.

+598 offs_m += BLOCK_M
+599 p_lse = tl.advance(p_lse, (BLOCK_M,))
+600 p_pdp = tl.advance(p_pdp, (BLOCK_M,))
+601 p_qT = tl.advance(p_qT, (0, BLOCK_M))
+602 p_do = tl.advance(p_do, (BLOCK_M, 0))

Return accumulated $dK$ and $dV$

+605 return b_dk, b_dv

+608@triton.autotune(_get_autotune_configs(inner_loop='key'),
+609 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
+610@triton.jit
+611def _attn_bwd_dq(t_q, t_k, t_v, t_do,
+612 t_dq,
+613 t_lse, t_pdp,
+614 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
+615 n_groups: tl.constexpr, d_head: tl.constexpr,
+616 is_causal: tl.constexpr,
+617 BLOCK_M: tl.constexpr,
+618 BLOCK_N: tl.constexpr,
+619 ):
+621 LN2: tl.constexpr = 0.6931471824645996 # type: ignore
+622
+623 m = tl.program_id(0)
+624 z = tl.program_id(1) // n_groups
+625 g = tl.program_id(1) % n_groups

Create block pointers

+628 p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+629 (q_seq_len, d_head),
+630 (d_head, 1),
+631 (m * BLOCK_M, 0),
+632 (BLOCK_M, d_head),
+633 (1, 0))
+634 p_dq = tl.make_block_ptr(t_dq + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+635 (q_seq_len, d_head),
+636 (d_head, 1),
+637 (m * BLOCK_M, 0),
+638 (BLOCK_M, d_head),
+639 (1, 0))
+640 p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
+641 (q_seq_len, d_head),
+642 (d_head, 1),
+643 (m * BLOCK_M, 0),
+644 (BLOCK_M, d_head),
+645 (1, 0))
+646 p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
+647 (d_head, kv_seq_len),
+648 (1, d_head),
+649 (0, 0),
+650 (d_head, BLOCK_N),
+651 (0, 1))
+652 p_vT = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
+653 (d_head, kv_seq_len),
+654 (1, d_head),
+655 (0, 0),
+656 (d_head, BLOCK_N),
+657 (0, 1))
+658 p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
+659 (q_seq_len,),
+660 (1,),
+661 (m * BLOCK_M,),
+662 (BLOCK_M,),
+663 (0,))
+664 p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
+665 (q_seq_len,),
+666 (1,),
+667 (m * BLOCK_M,),
+668 (BLOCK_M,),
+669 (0,))
+670
+671 b_q = tl.load(p_q)
+672 b_do = tl.load(p_do)
+673 b_pdp = tl.load(p_pdp)
+674
+675 b_dq = tl.zeros([BLOCK_M, d_head], dtype=HI_PRES_TL)
+676
+677 b_lse = tl.load(p_lse)
+681 if is_causal:

Compute $dQ$ for masked (diagonal) blocks.

+683 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
+684 b_do, b_lse, b_pdp,
+685 BLOCK_M, BLOCK_N,
+686 m=m * BLOCK_M, start_n=m * BLOCK_M,
+687 steps=BLOCK_M // BLOCK_N,
+688 MASK=True
+689 )

Other blocks

+692 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
+693 b_do, b_lse, b_pdp,
+694 BLOCK_M, BLOCK_N,
+695 m=m * BLOCK_M, start_n=tl.full([], 0, tl.int32), # type: ignore
+696 steps=(m * BLOCK_M) // BLOCK_N,
+697 MASK=False
+698 )
+699 else:
+700 b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
+701 b_do, b_lse, b_pdp,
+702 BLOCK_M, BLOCK_N,
+703 m=m * BLOCK_M, start_n=tl.full([], 0, tl.int32), # type: ignore
+704 steps=kv_seq_len // BLOCK_N,
+705 MASK=False
+706 )

Since $\hat{k}$ was scaled by $\frac{1}{\ln 2}$, and this factor got into the computed $dQ$, we need to reverse it.

+710 b_dq *= LN2

Save $dQ$

+713 tl.store(p_dq, b_dq.to(t_dq.type.element_ty))

Inner loop over the keys (n dimension)

+716@triton.jit
+717def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
+718 b_do, b_lse, b_pdp,
+719 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
+720 m, start_n, steps,
+721 MASK: tl.constexpr):
+723 offs_m = m + tl.arange(0, BLOCK_M)
+724
+725 p_kT = tl.advance(p_kT, (0, start_n))
+726 p_vT = tl.advance(p_vT, (0, start_n))
+727
+728 tl.static_assert(BLOCK_M % BLOCK_N == 0, 'BLOCK_M must be divisible by BLOCK_N')
+729
+730 for _ in range(steps):

Note that k is already multiplied by the softmax scale. It is also divided by $\ln 2$ so we can use $2^x$ instead of $e^x$

+734 b_kT = tl.load(p_kT)
+735 b_vT = tl.load(p_vT)
+736 b_qk = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
+737 b_p = tl.math.exp2(b_qk - b_lse[:, None])

Autoregressive masking.

+740 if MASK:
+741 offs_n = start_n + tl.arange(0, BLOCK_N)
+742 mask = (offs_m[:, None] >= offs_n[None, :])
+743 b_p = tl.where(mask, b_p, 0.0)
+748 b_dp = tl.dot(b_do, b_vT, out_dtype=HI_PRES_TL).to(HI_PRES_TL)
+750 b_ds = b_p * (b_dp - b_pdp[:, None])
+752 b_dq += tl.dot(b_ds.to(b_kT.dtype),
+753 tl.trans(b_kT),
+754 out_dtype=HI_PRES_TL)

Increment pointers.

+757 start_n += BLOCK_N
+758 p_kT = tl.advance(p_kT, (0, BLOCK_N))
+759 p_vT = tl.advance(p_vT, (0, BLOCK_N))

Return accumulated $dQ$

+762 return b_dq
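Not from the diff: a hedged end-to-end usage and correctness sketch, assuming a CUDA device with the kernels above compiled, and an eager reference such as the hypothetical reference_gqa sketch shown earlier. Tolerances are loose because the inputs are float16.

```python
# Hedged usage check of the attention entry point against an eager reference.
import torch

q = torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16, requires_grad=True)
k = torch.randn(1, 2, 512, 64, device='cuda', dtype=torch.float16, requires_grad=True)
v = torch.randn(1, 2, 512, 64, device='cuda', dtype=torch.float16, requires_grad=True)

out = attention(q, k, v, True, 0.125)          # causal, sm_scale = 1 / sqrt(d_head)
out.sum().backward()                            # exercises the Triton backward kernels
ref = reference_gqa(q, k, v, True, 0.125)       # hypothetical eager reference
assert torch.allclose(out, ref, atol=1e-2, rtol=0.0)
print(out.shape)                                # torch.Size([1, 8, 512, 64])
```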