mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-30 10:18:50 +08:00

Commit: flash comments
.gitignore (vendored)
@@ -16,4 +16,6 @@ html/
 diagrams/
 .comet.config
 settings.md
 labml_app.log
+/extensions
+/.nb_editor
File diff suppressed because one or more lines are too long
@@ -6,39 +6,54 @@ import triton
 import triton.language as tl

 import torch
+from typing import Any, Tuple, Optional

 HI_PRES_TL: tl.constexpr = tl.float32
-HI_PRES_TORCH: tl.constexpr = torch.float32
+HI_PRES_TORCH: torch.dtype = torch.float32


 class AttentionFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, q, k, v, causal, sm_scale):
-        # Shape batch size, n_heads, seq, d
+    def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                causal: bool, sm_scale: float) -> torch.Tensor:
+        """
+        Group query attention forward pass. Returns the output in shape `[batch_size, n_heads, q_seq_len, d_head]`.
+
+        :param ctx: is the context for torch gradient descent
+        :param q: has shape `[batch_size, n_heads, q_seq_len, d_head]`
+        :param k: has shape `[batch_size, k_heads, kv_seq_len, d_head]`
+        :param v: has shape `[batch_size, k_heads, kv_seq_len, d_head]`
+        :param causal: whether to apply causal attention mask
+        :param sm_scale: softmax scale factor
+        """
         batch_size, n_heads, q_seq_len, d_head = q.shape
-        k_heads = k.shape[1]
-        kv_seq_len = k.shape[2]
+        _, k_heads, kv_seq_len, _ = k.shape
         assert n_heads % k_heads == 0
         n_groups = n_heads // k_heads

-        # shape constraints
+        # Shape constraints
         assert d_head == k.shape[-1] == v.shape[-1]
         assert d_head in {16, 32, 64, 128, 256}

+        # Change the tensors combining the heads with the batch dimension
         q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
         k = k.view(batch_size * k_heads, kv_seq_len, d_head)
         v = v.view(batch_size * k_heads, kv_seq_len, d_head)

+        # Make sure the tensors are contiguous and the strides are same
         assert q.is_contiguous()
         assert k.is_contiguous()
         assert v.is_contiguous()
+        assert k.stride() == v.stride()

+        # Tensor for the output
         o = torch.empty_like(q)
+        # Tensor for $\log_2 \sum_j e^{S_{ij}}$
         lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)

+        # The forward computation will be parallelized along the batch dimension and the queries in blocks of size `BLOCK_M`
         grid = lambda args: (triton.cdiv(q_seq_len, args["BLOCK_M"]), batch_size * k_heads * n_groups, 1)
-        ctx.grid = grid
         _attn_fwd[grid](
             q, k, v, sm_scale, lse, o,
             n_groups=n_groups,
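For orientation, here is a minimal usage sketch of the group query attention interface described by the new docstring. It is illustrative only: the import path is an assumption, and the sizes are arbitrary.

    import torch
    # Assumed import path; use wherever this file lives in your project
    from flash_attention import attention

    batch_size, n_heads, k_heads, d_head = 2, 8, 2, 64   # n_heads must be a multiple of k_heads
    q_seq_len = kv_seq_len = 128

    q = torch.randn(batch_size, n_heads, q_seq_len, d_head, device='cuda', dtype=torch.float16)
    k = torch.randn(batch_size, k_heads, kv_seq_len, d_head, device='cuda', dtype=torch.float16)
    v = torch.randn(batch_size, k_heads, kv_seq_len, d_head, device='cuda', dtype=torch.float16)

    # Softmax scale is typically 1 / sqrt(d_head)
    out = attention(q, k, v, True, d_head ** -0.5)
    assert out.shape == (batch_size, n_heads, q_seq_len, d_head)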
@@ -48,41 +63,61 @@ class AttentionFunc(torch.autograd.Function):
             is_causal=causal,
         )

+        # Save the reshaped inputs and outputs for the backward pass
         ctx.save_for_backward(q, k, v, o, lse)
         ctx.sm_scale = sm_scale
         ctx.n_groups = n_groups
-        ctx.d_head = d_head
         ctx.causal = causal

+        # Return the output in shape `[batch_size, n_heads, q_seq_len, d_head]`
         return o.view(batch_size, n_heads, q_seq_len, d_head)

     @staticmethod
-    def backward(ctx, do):
+    def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
+        """
+        The backward pass computes the gradients of the input tensors.
+
+        :param ctx: is the context for torch gradient descent
+        :param do: is the gradient tensor of the attention output with shape `[batch_size, n_heads, q_seq_len, d_head]`
+        """

+        # Get saved tensors and attributes
         n_groups = ctx.n_groups
         sm_scale = ctx.sm_scale
         causal = ctx.causal
         q, k, v, o, lse = ctx.saved_tensors

+        # Get shapes
         batch_size, n_heads, q_seq_len, d_head = do.shape
         _, kv_seq_len, _ = k.shape
         k_heads = n_heads // n_groups

+        # Combine the heads with the batch dimension of the output gradients tensor
         do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)

+        # Make sure it's contiguous and the strides are the same
         assert do.is_contiguous()
         assert k.stride() == v.stride()
         assert q.stride() == o.stride() == do.stride()

+        # Create tensors for input gradients
         dq = torch.empty_like(q)
         dk = torch.empty_like(k)
         dv = torch.empty_like(v)

-        RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
-        arg_k = k * (sm_scale * RCP_LN2)
-        BLOCK_M = 16
-        assert q_seq_len % BLOCK_M == 0
-        pre_grid = (q_seq_len // BLOCK_M, batch_size * k_heads)
+        # $\frac{1}{\log_e 2}$
+        RCP_LN2 = 1.4426950408889634
+        # Multiply $k$ by softmax scale
+        k_scaled = k * (sm_scale * RCP_LN2)
         # $D_i = P^T_{i:}dP_{i:} = do^T_io_i$
         pdp = torch.empty_like(lse)
+        # We use fixed `BLOCK_M` for backward pass on $D$
+        BLOCK_M = 16
+        assert q_seq_len % BLOCK_M == 0
+        # Compute $D_i$
+        #
+        # This is parallelized along the batch and query in blocks of size `BLOCK_M`
+        pre_grid = (q_seq_len // BLOCK_M, batch_size * k_heads)
         _attn_bwd_d[pre_grid](
             o, do,
             pdp,
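Two details in this hunk are easy to miss: `pdp` holds the per-row term $D_i = do_i \cdot o_i$ that the softmax backward needs, and `k` is pre-multiplied by `sm_scale / ln 2` so the kernels can use the cheaper base-2 exponential. A rough pure-PyTorch sketch of both (illustrative only, reusing the tensor names from the surrounding code):

    import torch

    # What `_attn_bwd_d` computes: D_i = sum_j P_ij dP_ij, which simplifies to do_i . o_i
    pdp_ref = (do * o).sum(dim=-1)   # shape [batch_size * k_heads, n_groups, q_seq_len]

    # Why k is scaled by RCP_LN2: exp(sm_scale * q.k) == exp2(q . (k * sm_scale / ln 2)),
    # so the kernels can evaluate softmax logits with exp2 and no extra scaling.
    RCP_LN2 = 1.4426950408889634
    s_exp = torch.exp(sm_scale * (q[0, 0].float() @ k[0].float().T))
    s_exp2 = torch.exp2(q[0, 0].float() @ (k[0].float() * (sm_scale * RCP_LN2)).T)
    assert torch.allclose(s_exp, s_exp2, atol=1e-3)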
@@ -92,34 +127,42 @@ class AttentionFunc(torch.autograd.Function):
             n_groups=n_groups,
             num_stages=1,
         )
+        # Compute $dK$ and $dV$
+        #
+        # This is parallelized along the batch and keys in blocks of size `BLOCK_N`
         grid = lambda args: (triton.cdiv(kv_seq_len, args['BLOCK_N']), batch_size * k_heads)
         _attn_bwd_dkdv[grid](
-            q, arg_k, v, sm_scale, do, dk, dv,
+            q, k_scaled, v, sm_scale, do, dk, dv,
             lse, pdp,
             q_seq_len, kv_seq_len, n_groups, d_head,
             is_causal=causal,

         )
+        # Compute $dQ$
+        #
+        # This is parallelized along the batch and queries in blocks of size `BLOCK_M`
         grid = lambda args: (triton.cdiv(q_seq_len, args["BLOCK_M"]), batch_size * k_heads * n_groups)
         _attn_bwd_dq[grid](
-            q, arg_k, v, do,
+            q, k_scaled, v, do,
             dq,
             lse, pdp,
             q_seq_len, kv_seq_len, n_groups, d_head,
             is_causal=causal,
         )

+        # Split the combined batch and heads
         dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
         dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
         dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)

+        #
         return dq, dk, dv, None, None


 attention = AttentionFunc.apply


-def _get_autotune_configs(inner_loop: str):
+def _get_autotune_configs(inner_loop: str) -> list:
     """
     #### Configs for auto-tuning
     """
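The two trailing `None`s line up with the non-tensor `causal` and `sm_scale` arguments of `forward`, so `attention` drops straight into autograd. A rough way to sanity-check the whole path against PyTorch's reference attention (a sketch: it assumes `attention` from this file is in scope, a PyTorch version recent enough for the `scale` argument, and uses `repeat_interleave` to expand the shared KV heads):

    import torch
    import torch.nn.functional as F

    def reference_gqa(q, k, v, causal, sm_scale, n_groups):
        # Expand each key/value head n_groups times so every query head gets its own copy
        k = k.repeat_interleave(n_groups, dim=1)
        v = v.repeat_interleave(n_groups, dim=1)
        return F.scaled_dot_product_attention(q, k, v, is_causal=causal, scale=sm_scale)

    q = torch.randn(2, 8, 128, 64, device='cuda', dtype=torch.float16, requires_grad=True)
    k = torch.randn(2, 2, 128, 64, device='cuda', dtype=torch.float16, requires_grad=True)
    v = torch.randn(2, 2, 128, 64, device='cuda', dtype=torch.float16, requires_grad=True)

    out = attention(q, k, v, True, 64 ** -0.5)
    ref = reference_gqa(q, k, v, True, 64 ** -0.5, n_groups=4)
    assert torch.allclose(out, ref, atol=1e-2)
    out.sum().backward()   # exercises _attn_bwd_d, _attn_bwd_dkdv and _attn_bwd_dq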
@@ -176,15 +219,15 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
     Stride `n` denote the stride on `seq_len` of key.

     """
-    start_m = tl.program_id(0)
+    i = tl.program_id(0)
     z = tl.program_id(1) // n_groups
     g = tl.program_id(1) % n_groups

-    # block pointers
+    # Create block pointers
     p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                             (q_seq_len, d_head),
                             (d_head, 1),
-                            (start_m * BLOCK_M, 0),
+                            (i * BLOCK_M, 0),
                             (BLOCK_M, d_head),
                             (1, 0))
     p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
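The base pointer arithmetic above is plain flat indexing into the query tensor that `forward` reshaped to `[batch_size * k_heads, n_groups, q_seq_len, d_head]`. A small illustrative check of that arithmetic (the sizes are arbitrary):

    import torch

    batch_k_heads, n_groups, q_seq_len, d_head = 4, 4, 128, 64
    q = torch.randn(batch_k_heads, n_groups, q_seq_len, d_head)

    z, g = 3, 1   # the indices the kernel decodes from tl.program_id(1)
    base = z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head

    # Starting at `base`, the next q_seq_len * d_head elements of the flat buffer are
    # exactly the (q_seq_len, d_head) slice this program instance owns.
    flat = q.reshape(-1)
    block = flat[base: base + q_seq_len * d_head].view(q_seq_len, d_head)
    assert torch.equal(block, q[z, g])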
@@ -202,19 +245,19 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
     p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                             (q_seq_len, d_head),
                             (d_head, 1),
-                            (start_m * BLOCK_M, 0),
+                            (i * BLOCK_M, 0),
                             (BLOCK_M, d_head),
                             (1, 0))
     p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
                               (q_seq_len,),
                               (1,),
-                              (start_m * BLOCK_M,),
+                              (i * BLOCK_M,),
                               (BLOCK_M,),
                               (0,))

-    # initialize offsets
-    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-    offs_n = tl.arange(0, BLOCK_N)
+    # Initialize offsets
+    offs_i = i * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_j = tl.arange(0, BLOCK_N)

     # Initialize $m_i$ and $l_i$
     b_m = tl.zeros([BLOCK_M], dtype=HI_PRES_TL) - float("inf")
@@ -228,21 +271,22 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
     b_q = tl.load(p_q)

     if is_causal:
-        # Run for ranges
+        # Up to the diagonal block
         b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q,
                                           p_kT, p_v,
                                           sm_scale,
                                           BLOCK_M, d_head, BLOCK_N,
-                                          offs_m, offs_n,
+                                          offs_i, offs_j,
                                           start_n=tl.full([], 0, tl.int32),  # type: ignore
-                                          steps=(start_m * BLOCK_M) // BLOCK_N,
+                                          steps=(i * BLOCK_M) // BLOCK_N,
                                           MASK=False,
                                           )
+        # Diagonal block with masking within it
         b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
                                           sm_scale,
                                           BLOCK_M, d_head, BLOCK_N,
-                                          offs_m, offs_n,
-                                          start_n=start_m * BLOCK_M,
+                                          offs_i, offs_j,
+                                          start_n=i * BLOCK_M,
                                           steps=BLOCK_M // BLOCK_N,
                                           MASK=True,
                                           )
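So in the causal case the key loop is split in two: whole blocks strictly below the diagonal run unmasked, and only the single diagonal block applies an element-wise mask. A small illustrative sketch of that diagonal mask using the kernel's `offs_i`/`offs_j` naming:

    import torch

    BLOCK_M = BLOCK_N = 4
    i = 2                                          # index of this query block
    offs_i = i * BLOCK_M + torch.arange(BLOCK_M)   # absolute query positions
    offs_j = torch.arange(BLOCK_N)                 # key offsets within the current block
    start_n = i * BLOCK_M                          # the diagonal block starts here

    # A query at position p may only attend to keys at positions <= p,
    # so inside the diagonal block the mask is lower-triangular.
    mask = offs_i[:, None] >= (start_n + offs_j)[None, :]
    # mask.int() ->
    # [[1, 0, 0, 0],
    #  [1, 1, 0, 0],
    #  [1, 1, 1, 0],
    #  [1, 1, 1, 1]]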
@@ -250,7 +294,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
         b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
                                           sm_scale,
                                           BLOCK_M, d_head, BLOCK_N,
-                                          offs_m, offs_n,
+                                          offs_i, offs_j,
                                           start_n=tl.full([], 0, tl.int32),  # type: ignore
                                           steps=kv_seq_len // BLOCK_N,
                                           MASK=False,
@@ -308,9 +352,10 @@ def _attn_fwd_inner(b_acc, b_l, b_m, b_q,
         b_p = b_p.to(b_q.dtype)
         b_acc += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)

-        # update m_i and l_i
+        # update $m_i$
         b_m = b_m_new

+        # Move pointers
         start_n += BLOCK_N
         p_v = tl.advance(p_v, (BLOCK_N, 0))
         p_kT = tl.advance(p_kT, (0, BLOCK_N))
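For context, `_attn_fwd_inner` is the online softmax step of flash attention: each key block updates a running maximum, denominator and unnormalised output, rescaling the old state by `exp2(m_old - m_new)` before adding the new block. A rough single-head PyTorch sketch of the recurrence (illustrative; it assumes `k_scaled` was already multiplied by `sm_scale / ln 2`, matching the base-2 exponentials used in the kernel):

    import torch

    def online_softmax_attention(q, k_scaled, v, BLOCK_N=32):
        # q: [M, d], k_scaled: [N, d] (already scaled by sm_scale / ln 2), v: [N, d]
        M, d = q.shape
        m = torch.full((M,), float('-inf'))       # running row maxima (base-2 logits)
        l = torch.zeros(M)                        # running softmax denominators
        acc = torch.zeros(M, v.shape[-1])         # unnormalised output accumulator
        for start in range(0, k_scaled.shape[0], BLOCK_N):
            s = q @ k_scaled[start:start + BLOCK_N].T        # one block of logits
            m_new = torch.maximum(m, s.max(dim=-1).values)
            p = torch.exp2(s - m_new[:, None])
            scale = torch.exp2(m - m_new)                    # rescale the old state
            l = l * scale + p.sum(dim=-1)
            acc = acc * scale[:, None] + p @ v[start:start + BLOCK_N]
            m = m_new
        return acc / l[:, None]                   # final normalisation, done once at the end

Presumably the `lse` tensor then stores `m + log2(l)` per query row, i.e. the $\log_2 \sum_j e^{S_{ij}}$ mentioned in the forward pass, so the backward kernels can recompute probabilities without another reduction.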
@@ -327,24 +372,26 @@ def _attn_bwd_d(t_o, t_do,
                 q_seq_len: tl.constexpr,
                 n_groups: tl.constexpr,
                 ):
-    m = tl.program_id(0) * BLOCK_M
+    i = tl.program_id(0) * BLOCK_M
     z = tl.program_id(1)

+    # Create block pointers
     p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
                             (n_groups, q_seq_len, d_head),
                             (q_seq_len * d_head, d_head, 1),
-                            (0, m, 0),
+                            (0, i, 0),
                             (n_groups, BLOCK_M, d_head),
                             (2, 1, 0))
     p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
                              (n_groups, q_seq_len, d_head),
                              (q_seq_len * d_head, d_head, 1),
-                             (0, m, 0),
+                             (0, i, 0),
                              (n_groups, BLOCK_M, d_head),
                              (2, 1, 0))
     p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
                               (n_groups, q_seq_len),
                               (q_seq_len, 1),
-                              (0, m),
+                              (0, i),
                               (n_groups, BLOCK_M),
                               (1, 0))

@@ -406,7 +453,9 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
     b_k = tl.load(p_k)
     b_v = tl.load(p_v)

+    # Iterate through queries that attend to these keys
     for g in range(n_groups):
+        # Create block pointers
         p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                                  (d_head, q_seq_len),
                                  (1, d_head),
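Because every key/value head serves `n_groups` query heads, the dK/dV kernel loops over the groups and accumulates each group's contribution into the same key block. A rough non-causal reference of the math being accumulated, for a single batch element and key head (illustrative only; the kernel does the same thing block by block):

    import torch

    def dkdv_reference(q, k, v, do, sm_scale, n_groups):
        # q, do: [n_groups, q_seq_len, d_head] query heads sharing this key head
        # k, v:  [kv_seq_len, d_head]
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)
        for g in range(n_groups):                          # same loop structure as the kernel
            s = sm_scale * (q[g] @ k.T)                    # attention scores S
            p = torch.softmax(s, dim=-1)                   # probabilities P
            dv += p.T @ do[g]                              # dV accumulates over all groups
            dp = do[g] @ v.T                               # dP
            d_i = (do[g] * (p @ v)).sum(-1, keepdim=True)  # D_i = do_i . o_i
            ds = p * (dp - d_i) * sm_scale                 # softmax backward identity
            dk += ds.T @ q[g]                              # dK accumulates over all groups
        return dk, dv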
@@ -575,6 +624,7 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
     z = tl.program_id(1) // n_groups
     g = tl.program_id(1) % n_groups

+    # Create block pointers
     p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                             (q_seq_len, d_head),
                             (d_head, 1),
@@ -709,4 +759,4 @@ def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
         p_vT = tl.advance(p_vT, (0, BLOCK_N))

     # Return accumulated $dq$
     return b_dq