flash comments
@@ -1,12 +1,67 @@
"""
This is based on the flash attention tutorial from [Triton](https://triton-lang.org/main/index.html)
# Flash Attention

## Forward pass

\begin{align}
S_{ij} &= q_i k_j^T
\\
P_{ij} &= \frac{e^{S_{ij}}}{\sum_j e^{S_{ij}}}
\\
O_i &= \sum_j P_{ij} v_j
\\
&= \frac{1}{\sum_j e^{S_{ij}}} \sum_j e^{S_{ij}} v_j
\end{align}

You can compute $O_i$, instead of doing the full softmax,
by computing the sum of exponents $l_i$ and the unnormalized output $\tilde{O}_i$
while iterating over keys:

\begin{align}
S_{ij} &= q_i k_j^T
\\
l_i &\leftarrow l_i + e^{S_{ij}}
\\
\tilde{O}_i &\leftarrow \tilde{O}_i + e^{S_{ij}} v_j
\end{align}

Finally you can compute,

$$O_i = \frac{\tilde{O}_i}{l_i}$$
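As a concrete illustration, here is a minimal PyTorch sketch of this running-sum formulation for a single query; the names and the `scale` factor are illustrative, not taken from the kernels below:

```python
import torch

d = 64
q = torch.randn(d)
keys = torch.randn(128, d)
values = torch.randn(128, d)
scale = d ** -0.5

l = torch.zeros(())        # running sum of exponents $l_i$
o_tilde = torch.zeros(d)   # running unnormalized output $\tilde{O}_i$

for k, v in zip(keys, values):
    s = (q @ k) * scale                  # $S_{ij}$
    l = l + torch.exp(s)
    o_tilde = o_tilde + torch.exp(s) * v

out = o_tilde / l                        # $O_i$
ref = torch.softmax(keys @ q * scale, dim=0) @ values
assert torch.allclose(out, ref, atol=1e-5)
```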

To make it numerically stable, flash attention subtracts the current max of $S_{ij}$ before exponentiating.

So it maintains the following while iterating over keys:

* $m_i$, the max $S_{ij}$
* $l_i$, the sum of exponents $\sum_j e^{S_{ij} - m_i}$, and
* $\tilde{O}_i$, the unnormalized output

For each block of keys $j_1 \dots j_2$ it updates them:

\begin{align}
m_i^{\text{new}} &= \max(m_i, \max_{j=j_1}^{j_2} S_{ij})
\\
\tilde{P}_{ij} &= \exp(S_{ij} - m_i^{\text{new}})
\\
l_i &\leftarrow e^{m_i - m_i^{\text{new}}} l_i + \sum_{j=j_1}^{j_2} \tilde{P}_{ij}
\\
\tilde{O}_i &\leftarrow e^{m_i - m_i^{\text{new}}} \tilde{O}_i + \sum_{j=j_1}^{j_2} \tilde{P}_{ij} v_j
\end{align}

Then finally,

$$O_i = \frac{\tilde{O}_i}{l_i}$$
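A sketch of these block updates in PyTorch, for a single query row; the block size of 32 and all names are illustrative:

```python
import torch

d = 64
q = torch.randn(d)
keys = torch.randn(128, d)
values = torch.randn(128, d)
scale = d ** -0.5

m = torch.tensor(float('-inf'))  # running max $m_i$
l = torch.zeros(())              # running sum of exponents $l_i$
o_tilde = torch.zeros(d)         # unnormalized output $\tilde{O}_i$

for j1 in range(0, 128, 32):                  # blocks of 32 keys
    s = keys[j1:j1 + 32] @ q * scale          # $S_{ij}$ for the block
    m_new = torch.maximum(m, s.max())         # $m_i^{\text{new}}$
    p = torch.exp(s - m_new)                  # $\tilde{P}_{ij}$
    alpha = torch.exp(m - m_new)              # rescales the old accumulators
    l = alpha * l + p.sum()
    o_tilde = alpha * o_tilde + p @ values[j1:j1 + 32]
    m = m_new

out = o_tilde / l
ref = torch.softmax(keys @ q * scale, dim=0) @ values
assert torch.allclose(out, ref, atol=1e-5)
```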

## Backward pass
"""

import triton
import triton.language as tl
from typing import Any, Tuple

import torch
from typing import Any, Tuple, Optional
import triton
import triton.language as tl

HI_PRES_TL: tl.constexpr = tl.float32
HI_PRES_TORCH: torch.dtype = torch.float32
@@ -14,7 +69,7 @@ HI_PRES_TORCH: torch.dtype = torch.float32

class AttentionFunc(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
causal: bool, sm_scale: float) -> torch.Tensor:
"""
Group query attention forward pass. Returns the output in shape `[batch_size, n_heads, q_seq_len, d_head]`.
@@ -52,8 +107,8 @@ class AttentionFunc(torch.autograd.Function):
# Tensor for $\log_2 \sum_j e^{S_{ij}}$
lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)

# The forward computation will be parallelized along the batch dimension and the queries in blocks of size `BLOCK_M`
grid = lambda args: (triton.cdiv(q_seq_len, args["BLOCK_M"]), batch_size * k_heads * n_groups, 1)
# The forward computation will be parallelized along the batch dimension and the queries in blocks of size `BLOCK_Q`
grid = lambda meta: (triton.cdiv(q_seq_len, meta["BLOCK_Q"]), batch_size * k_heads * n_groups, 1)
_attn_fwd[grid](
q, k, v, sm_scale, lse, o,
n_groups=n_groups,
@@ -111,17 +166,16 @@ class AttentionFunc(torch.autograd.Function):
k_scaled = k * (sm_scale * RCP_LN2)
# $D_i = P^T_{i:} dP_{i:} = do_i^T o_i$
pdp = torch.empty_like(lse)
# We use fixed `BLOCK_M` for backward pass on $D$
BLOCK_M = 16
assert q_seq_len % BLOCK_M == 0
# We use fixed `BLOCK_Q` for backward pass on $D$
BLOCK_Q = 16
# Compute $D_i$
#
# This is parallelized along the batch and query in blocks of size `BLOCK_M`
pre_grid = (q_seq_len // BLOCK_M, batch_size * k_heads)
# This is parallelized along the batch and query in blocks of size `BLOCK_Q`
pre_grid = (triton.cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads)
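`triton.cdiv` rounds the block count up, so the sequence length no longer needs to divide the block size (which is why the old `assert` is gone); the partially filled last block is handled by the `boundary_check` arguments further down. For example:

```python
import triton

# 2001 queries in blocks of 16 -> 126 blocks; the last block covers a single row
assert triton.cdiv(2001, 16) == 126 == (2001 + 16 - 1) // 16
```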
_attn_bwd_d[pre_grid](
o, do,
pdp,
BLOCK_M=16,
BLOCK_Q=16,
d_head=d_head,
q_seq_len=q_seq_len,
n_groups=n_groups,
@@ -129,8 +183,8 @@ class AttentionFunc(torch.autograd.Function):
)
# Compute $dK$ and $dV$
#
# This is parallelized along the batch and keys in blocks of size `BLOCK_N`
grid = lambda args: (triton.cdiv(kv_seq_len, args['BLOCK_N']), batch_size * k_heads)
# This is parallelized along the batch and keys in blocks of size `BLOCK_K`
grid = lambda meta: (triton.cdiv(kv_seq_len, meta['BLOCK_K']), batch_size * k_heads)
_attn_bwd_dkdv[grid](
q, k_scaled, v, sm_scale, do, dk, dv,
lse, pdp,
@@ -140,8 +194,8 @@ class AttentionFunc(torch.autograd.Function):
)
# Compute $dQ$
#
# This is parallelized along the batch and queries in blocks of size `BLOCK_M`
grid = lambda args: (triton.cdiv(q_seq_len, args["BLOCK_M"]), batch_size * k_heads * n_groups)
# This is parallelized along the batch and queries in blocks of size `BLOCK_Q`
grid = lambda meta: (triton.cdiv(q_seq_len, meta['BLOCK_Q']), batch_size * k_heads * n_groups)
_attn_bwd_dq[grid](
q, k_scaled, v, do,
dq,
@@ -168,7 +222,7 @@ def _get_autotune_configs(inner_loop: str) -> list:
"""

configs = []
# List possible BLOCK_M and BLOCK_N that satisfy BLOCK_M divisible by BLOCK_N
# List possible BLOCK_Q and BLOCK_K that satisfy BLOCK_Q divisible by BLOCK_K
# and also try to cover a wide range
for bm in [64, 128, 256]:
# We'll try bn in [16, 32, 64, 128] that are divisors and <= bm
@@ -182,9 +236,9 @@ def _get_autotune_configs(inner_loop: str) -> list:
if bm * bn < 128 * 128 and w == 8:
continue

configs.append(triton.Config({'BLOCK_M': bm, 'BLOCK_N': bn}, num_stages=s, num_warps=w))
configs.append(triton.Config({'BLOCK_Q': bm, 'BLOCK_K': bn}, num_stages=s, num_warps=w))

return configs
return configs[:1]


@triton.autotune(_get_autotune_configs(inner_loop='key'),
@@ -196,8 +250,8 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
kv_seq_len: tl.constexpr,
d_head: tl.constexpr,
is_causal: tl.constexpr,
BLOCK_M: tl.constexpr,  # q seq len block
BLOCK_N: tl.constexpr,  # k seq len block
BLOCK_Q: tl.constexpr,  # q seq len block
BLOCK_K: tl.constexpr,  # k seq len block
):
"""
:param t_q: query
@@ -210,8 +264,8 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
:param q_seq_len: query sequence length
:param kv_seq_len: key/value sequence length
:param d_head: size of a head
:param BLOCK_M: block size for query sequence length
:param BLOCK_N: block size for key sequence length
:param BLOCK_Q: block size for query sequence length
:param BLOCK_K: block size for key sequence length
:param is_causal: whether causal attention

Strides `z`, `h`, `m` and `d` denote the stride of the corresponding dimensions
@@ -227,111 +281,125 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
(q_seq_len, d_head),
(d_head, 1),
(i * BLOCK_M, 0),
(BLOCK_M, d_head),
(i * BLOCK_Q, 0),
(BLOCK_Q, d_head),
(1, 0))
p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
(kv_seq_len, d_head),
(d_head, 1),
(0, 0),
(BLOCK_N, d_head),
(BLOCK_K, d_head),
(1, 0))
p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
(d_head, kv_seq_len),
(1, d_head),
(0, 0),
(d_head, BLOCK_N),
(d_head, BLOCK_K),
(0, 1))
p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
(q_seq_len, d_head),
(d_head, 1),
(i * BLOCK_M, 0),
(BLOCK_M, d_head),
(i * BLOCK_Q, 0),
(BLOCK_Q, d_head),
(1, 0))
p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
(q_seq_len,),
(1,),
(i * BLOCK_M,),
(BLOCK_M,),
(i * BLOCK_Q,),
(BLOCK_Q,),
(0,))

# Initialize offsets
offs_i = i * BLOCK_M + tl.arange(0, BLOCK_M)
offs_j = tl.arange(0, BLOCK_N)
offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
i_mask = offs_i < q_seq_len
offs_j = tl.arange(0, BLOCK_K)

# Initialize $m_i$ and $l_i$
b_m = tl.zeros([BLOCK_M], dtype=HI_PRES_TL) - float("inf")
b_l = tl.zeros([BLOCK_M], dtype=HI_PRES_TL) + 1.0
b_m = tl.where(i_mask, -float("inf"), 0.0)
b_l = tl.where(i_mask, 1.0, 0.0)
# Accumulate $O$
b_acc = tl.zeros([BLOCK_M, d_head], dtype=HI_PRES_TL)
b_acc = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)

# softmax scale / log(2)
sm_scale = sm_scale * 1.44269504
# Load $Q_i$
b_q = tl.load(p_q)
b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
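The `boundary_check`/`padding_option` pair is what makes ragged sequence lengths safe: out-of-range rows read as zeros and out-of-range stores are suppressed. A minimal, self-contained sketch of the same pattern; `_copy_kernel` is a hypothetical example, not part of this file:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _copy_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    i = tl.program_id(0)
    p_x = tl.make_block_ptr(x_ptr, (n,), (1,), (i * BLOCK,), (BLOCK,), (0,))
    p_y = tl.make_block_ptr(y_ptr, (n,), (1,), (i * BLOCK,), (BLOCK,), (0,))
    # The last block runs past n; masked loads return zero, masked stores are dropped
    b = tl.load(p_x, boundary_check=(0,), padding_option="zero")
    tl.store(p_y, b, boundary_check=(0,))

x = torch.randn(2001, device='cuda')
y = torch.empty_like(x)
_copy_kernel[(triton.cdiv(2001, 128),)](x, y, 2001, BLOCK=128)
assert torch.equal(x, y)
```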

if is_causal:
# Up to the diagonal block
b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q,
p_kT, p_v,
sm_scale,
BLOCK_M, d_head, BLOCK_N,
BLOCK_Q, d_head, BLOCK_K,
offs_i, offs_j,
start_n=tl.full([], 0, tl.int32),  # type: ignore
steps=(i * BLOCK_M) // BLOCK_N,
j=tl.full([], 0, tl.int32),  # type: ignore
steps=(i * BLOCK_Q) // BLOCK_K,
MASK=False,
q_seq_len=q_seq_len,
kv_seq_len=kv_seq_len
)
# Diagonal block with masking within it
b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
sm_scale,
BLOCK_M, d_head, BLOCK_N,
BLOCK_Q, d_head, BLOCK_K,
offs_i, offs_j,
start_n=i * BLOCK_M,
steps=BLOCK_M // BLOCK_N,
j=i * BLOCK_Q,
steps=BLOCK_Q // BLOCK_K,
MASK=True,
q_seq_len=q_seq_len,
kv_seq_len=kv_seq_len
)
else:
b_acc, b_l, b_m = _attn_fwd_inner(b_acc, b_l, b_m, b_q, p_kT, p_v,
sm_scale,
BLOCK_M, d_head, BLOCK_N,
BLOCK_Q, d_head, BLOCK_K,
offs_i, offs_j,
start_n=tl.full([], 0, tl.int32),  # type: ignore
steps=kv_seq_len // BLOCK_N,
j=tl.full([], 0, tl.int32),  # type: ignore
steps=tl.cdiv(kv_seq_len, BLOCK_K),
MASK=False,
q_seq_len=q_seq_len,
kv_seq_len=kv_seq_len
)

# Update LSE
tl.store(p_lse, b_m + tl.math.log2(b_l))
tl.store(p_o, (b_acc / b_l[:, None]).to(t_o.type.element_ty))
tl.store(p_lse, b_m + tl.math.log2(b_l), boundary_check=(0,))
tl.store(p_o, (b_acc / b_l[:, None]).to(t_o.type.element_ty), boundary_check=(0,))
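Note that the whole pipeline works in base 2: scores are pre-multiplied by $\log_2 e \approx 1.44269504$, `exp2` stands in for `exp` (it maps directly to hardware), and the stored quantity is the base-2 log-sum-exp $m_i + \log_2 l_i$. The identity being exploited, in one line:

```python
import math

x = 0.9
# e^x == 2^(x * log2(e)); the kernels fold the log2(e) factor into the softmax scale
assert abs(math.exp(x) - 2.0 ** (x * 1.44269504)) < 1e-8
```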


@triton.jit
def _attn_fwd_inner(b_acc, b_l, b_m, b_q,
p_kT, p_v,
scale,
BLOCK_M: tl.constexpr,
BLOCK_Q: tl.constexpr,
d_head: tl.constexpr,
BLOCK_N: tl.constexpr,
offs_m, offs_n,
start_n,
BLOCK_K: tl.constexpr,
offs_i, offs_j,
j,
steps,
MASK: tl.constexpr,
q_seq_len: tl.constexpr,
kv_seq_len: tl.constexpr
):
tl.static_assert(BLOCK_M % BLOCK_N == 0)
tl.static_assert(BLOCK_Q % BLOCK_K == 0)

p_kT = tl.advance(p_kT, (0, start_n))
p_v = tl.advance(p_v, (start_n, 0))
p_kT = tl.advance(p_kT, (0, j))
p_v = tl.advance(p_v, (j, 0))

# loop over k, v and update accumulator
for _ in range(steps):
b_kT = tl.load(p_kT)
current_j = j + offs_j
j_mask = current_j < kv_seq_len

b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)

tl.static_assert(b_s.dtype == HI_PRES_TL)
b_s = b_s * scale
if MASK:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
b_s = b_s + tl.where(mask, 0, -1.0e6)
causal_mask = offs_i[:, None] >= (j + offs_j[None, :])
b_s = tl.where(causal_mask, b_s, -float("inf"))
# always apply seq mask
b_s = tl.where(j_mask[None, :], b_s, -float("inf"))
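In tensor terms the two `tl.where` masks do the following; a plain PyTorch sketch with illustrative names and positions:

```python
import torch

S = torch.randn(4, 8)                       # scores: one query block vs one key block
q_idx = 16 + torch.arange(4)                # absolute query positions of this block
k_idx = 12 + torch.arange(8)                # absolute key positions of this block
kv_seq_len = 18

causal = q_idx[:, None] >= k_idx[None, :]   # a query attends only to keys at or before it
valid = (k_idx < kv_seq_len)[None, :]       # drop zero-padded keys past the true length
S = S.masked_fill(~(causal & valid), float('-inf'))
```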

# $m_{i}^{\text{new}} = \max(m_i, \text{rowmax}(S_{ij}))$
tl.static_assert(len(b_s.shape) == 2)
@@ -347,7 +415,7 @@ def _attn_fwd_inner(b_acc, b_l, b_m, b_q,
b_l = b_l * b_m_m_new + b_l_new

# $O_i \leftarrow e^{m_i - m_{i}^{\text{new}}} O_i + \tilde{P}_{ij} v_j$
b_v = tl.load(p_v)
b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
b_acc = b_acc * b_m_m_new[:, None]
b_p = b_p.to(b_q.dtype)
b_acc += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)
@@ -356,9 +424,9 @@ def _attn_fwd_inner(b_acc, b_l, b_m, b_q,
b_m = b_m_new

# Move pointers
start_n += BLOCK_N
p_v = tl.advance(p_v, (BLOCK_N, 0))
p_kT = tl.advance(p_kT, (0, BLOCK_N))
j += BLOCK_K
p_v = tl.advance(p_v, (BLOCK_K, 0))
p_kT = tl.advance(p_kT, (0, BLOCK_K))

tl.static_assert(b_acc.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")
@@ -368,11 +436,11 @@ def _attn_fwd_inner(b_acc, b_l, b_m, b_q,
@triton.jit
def _attn_bwd_d(t_o, t_do,
t_pdp,
BLOCK_M: tl.constexpr, d_head: tl.constexpr,
BLOCK_Q: tl.constexpr, d_head: tl.constexpr,
q_seq_len: tl.constexpr,
n_groups: tl.constexpr,
):
i = tl.program_id(0) * BLOCK_M
i = tl.program_id(0) * BLOCK_Q
z = tl.program_id(1)

# Create block pointers
@@ -380,25 +448,25 @@ def _attn_bwd_d(t_o, t_do,
(n_groups, q_seq_len, d_head),
(q_seq_len * d_head, d_head, 1),
(0, i, 0),
(n_groups, BLOCK_M, d_head),
(n_groups, BLOCK_Q, d_head),
(2, 1, 0))
p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
(n_groups, q_seq_len, d_head),
(q_seq_len * d_head, d_head, 1),
(0, i, 0),
(n_groups, BLOCK_M, d_head),
(n_groups, BLOCK_Q, d_head),
(2, 1, 0))
p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
(n_groups, q_seq_len),
(q_seq_len, 1),
(0, i),
(n_groups, BLOCK_M),
(n_groups, BLOCK_Q),
(1, 0))

o = tl.load(p_o)
do = tl.load(p_do).to(HI_PRES_TL)
o = tl.load(p_o, boundary_check=(1,), padding_option="zero")
do = tl.load(p_do, boundary_check=(1,), padding_option="zero").to(HI_PRES_TL)
d = tl.sum(o * do, axis=-1)
tl.store(p_pdp, d)
tl.store(p_pdp, d, boundary_check=(1,))
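In dense PyTorch the quantity this kernel writes is just a row-wise dot product of the output with its gradient; a sketch with illustrative shapes:

```python
import torch

o = torch.randn(2, 4, 128, 64)   # output O: [batch, heads, q_seq_len, d_head]
do = torch.randn_like(o)         # upstream gradient dO
pdp = (o * do).sum(dim=-1)       # D_i = do_i . o_i, one scalar per query row
assert pdp.shape == (2, 4, 128)
```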


@triton.autotune(_get_autotune_configs(inner_loop='query'),
@@ -411,47 +479,47 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
n_groups: tl.constexpr, d_head: tl.constexpr,
is_causal: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
):
"""
Loop along m query; n % m == 0
"""
# K is already multiplied by scale
n = tl.program_id(0)
j = tl.program_id(0) * BLOCK_K
z = tl.program_id(1)

p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
(kv_seq_len, d_head),
(d_head, 1),
(n * BLOCK_N, 0),
(BLOCK_N, d_head),
(j, 0),
(BLOCK_K, d_head),
(1, 0))
p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
(kv_seq_len, d_head),
(d_head, 1),
(n * BLOCK_N, 0),
(BLOCK_N, d_head),
(j, 0),
(BLOCK_K, d_head),
(1, 0))
p_dk = tl.make_block_ptr(t_dk + z * kv_seq_len * d_head,
(kv_seq_len, d_head),
(d_head, 1),
(n * BLOCK_N, 0),
(BLOCK_N, d_head),
(j, 0),
(BLOCK_K, d_head),
(1, 0))
p_dv = tl.make_block_ptr(t_dv + z * kv_seq_len * d_head,
(kv_seq_len, d_head),
(d_head, 1),
(n * BLOCK_N, 0),
(BLOCK_N, d_head),
(j, 0),
(BLOCK_K, d_head),
(1, 0))

b_dv = tl.zeros([BLOCK_N, d_head], dtype=HI_PRES_TL)
b_dk = tl.zeros([BLOCK_N, d_head], dtype=HI_PRES_TL)
b_dv = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
b_dk = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)

# load K and V: they stay in SRAM throughout the inner loop.
b_k = tl.load(p_k)
b_v = tl.load(p_v)
b_k = tl.load(p_k, boundary_check=(0,), padding_option="zero")
b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")

# Iterate through queries that attend to the saved keys
for g in range(n_groups):
@@ -460,33 +528,33 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
(d_head, q_seq_len),
(1, d_head),
(0, 0),
(d_head, BLOCK_M),
(d_head, BLOCK_Q),
(0, 1))

p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
(q_seq_len, d_head),
(d_head, 1),
(0, 0),
(BLOCK_M, d_head),
(BLOCK_Q, d_head),
(1, 0))
p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
(q_seq_len,),
(1,),
(0,),
(BLOCK_M,),
(BLOCK_Q,),
(0,))
p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
(q_seq_len,),
(1,),
(0,),
(BLOCK_M,),
(BLOCK_Q,),
(0,))

# $$dk_j = \sum_i dS_{ij} q_i = \sum_i P_{ij} \big( do_i^T v_j - D_i \big) q_i$$
# $$dv_j = \sum_i P_{ij} do_i$$
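For reference, here are these backward-pass equations in dense PyTorch; a sketch of the math the kernels implement blockwise, not of the kernels themselves, with illustrative names and shapes:

```python
import torch

d = 64
scale = d ** -0.5
Q, K, V = (torch.randn(128, d) for _ in range(3))
dO = torch.randn(128, d)

P = torch.softmax(Q @ K.T * scale, dim=-1)   # P_ij
O = P @ V
D = (O * dO).sum(dim=-1)                     # D_i = do_i . o_i
dV = P.T @ dO                                # dv_j = sum_i P_ij do_i
dP = dO @ V.T                                # dP_ij = do_i . v_j
dS = P * (dP - D[:, None])                   # dS_ij = P_ij (dP_ij - D_i)
dK = dS.T @ Q * scale                        # dk_j = sum_i dS_ij q_i (times the score scale)
dQ = dS @ K * scale

# check against autograd
Q_, K_, V_ = (t.clone().requires_grad_() for t in (Q, K, V))
(torch.softmax(Q_ @ K_.T * scale, dim=-1) @ V_).backward(dO)
for a, b in ((dQ, Q_.grad), (dK, K_.grad), (dV, V_.grad)):
    assert torch.allclose(a, b, atol=1e-4)
```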

# Compute $dk$ and $dv$ along the masked blocks near the diagonal.
# Use smaller block size of MASK_BLOCK_M
# Use smaller block size of MASK_BLOCK_Q
# because there is a little extra computation.
if is_causal:
# loop along m
@@ -494,12 +562,14 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
b_dk, b_dv,
p_qT, b_k, b_v, p_do,
p_lse, p_pdp,
# You can use a smaller BLOCK_M if BLOCK_N is not divisible by BLOCK_M
BLOCK_M, BLOCK_N,
# You can use a smaller BLOCK_Q if BLOCK_K is not divisible by BLOCK_Q
BLOCK_Q, BLOCK_K,
d_head,
n=n * BLOCK_N, start_m=n * BLOCK_N,
steps=BLOCK_N // BLOCK_M,
MASK=True
j=j, i=j,
steps=BLOCK_K // BLOCK_Q,
MASK=True,
q_seq_len=q_seq_len,
kv_seq_len=kv_seq_len,
)

# Compute $dk$ and $dv$ for non-masked blocks.
@@ -507,65 +577,72 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
b_dk, b_dv,
p_qT, b_k, b_v, p_do,
p_lse, p_pdp,
BLOCK_M, BLOCK_N,
BLOCK_Q, BLOCK_K,
d_head,
n=n * BLOCK_N, start_m=(n + 1) * BLOCK_N,
steps=(q_seq_len - (n + 1) * BLOCK_N) // BLOCK_M,
j=j, i=j + BLOCK_K,
steps=tl.cdiv((q_seq_len - (j + BLOCK_K)), BLOCK_Q),
MASK=False,
q_seq_len=q_seq_len,
kv_seq_len=kv_seq_len
)
else:
b_dk, b_dv = _attn_bwd_dkdv_inner(
b_dk, b_dv,
p_qT, b_k, b_v, p_do,
p_lse, p_pdp,
BLOCK_M, BLOCK_N,
BLOCK_Q, BLOCK_K,
d_head,
n=n * BLOCK_N, start_m=tl.full([], 0, tl.int32),
steps=q_seq_len // BLOCK_M,
j=j, i=tl.full([], 0, tl.int32),
steps=tl.cdiv(q_seq_len, BLOCK_Q),
MASK=False,
q_seq_len=q_seq_len,
kv_seq_len=kv_seq_len
)

# Save $dv$
tl.store(p_dv, b_dv.to(t_dv.type.element_ty))
tl.store(p_dv, b_dv.to(t_dv.type.element_ty), boundary_check=(0,))

# Since we used $k = \text{scale} * \hat{k}$ where $\hat{k}$ are the original keys
# we multiply by scale again to get the gradient on the original keys.
b_dk *= sm_scale

# Save $dk$
tl.store(p_dk, b_dk.to(t_dk.type.element_ty))
tl.store(p_dk, b_dk.to(t_dk.type.element_ty), boundary_check=(0,))


@triton.jit
def _attn_bwd_dkdv_inner(b_dk, b_dv,
p_qT, b_k, b_v, p_do,
p_lse, p_pdp,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
d_head: tl.constexpr,
n, start_m, steps,
MASK: tl.constexpr):
j, i, steps,
MASK: tl.constexpr,
q_seq_len: tl.constexpr,
kv_seq_len: tl.constexpr):
"""Inner loop along m query"""

# To apply the mask
tl.static_assert(BLOCK_N % BLOCK_M == 0)
tl.static_assert(BLOCK_K % BLOCK_Q == 0)

# Offsets for mask computation
offs_m = start_m + tl.arange(0, BLOCK_M)
offs_n = n + tl.arange(0, BLOCK_N)
offs_i = i + tl.arange(0, BLOCK_Q)
i_mask = offs_i < q_seq_len
offs_j = j + tl.arange(0, BLOCK_K)

# Pointers
p_qT = tl.advance(p_qT, (0, start_m))
p_do = tl.advance(p_do, (start_m, 0))
p_lse = tl.advance(p_lse, (start_m,))
p_pdp = tl.advance(p_pdp, (start_m,))
p_qT = tl.advance(p_qT, (0, i))
p_do = tl.advance(p_do, (i, 0))
p_lse = tl.advance(p_lse, (i,))
p_pdp = tl.advance(p_pdp, (i,))

# Loop
for _ in range(steps):
# Load $q^T$
b_qT = tl.load(p_qT)
b_qT = tl.load(p_qT, boundary_check=(1,), padding_option="zero")

# $M_i = \log_2 L_i$
b_m = tl.load(p_lse)
b_m = tl.load(p_lse, boundary_check=(0,), padding_option="zero")

# $$P_{ij} = \frac{e^{q_i^T k_j}}{L_i} = e^{q_i^T k_j - M_i}$$
# Note that k is already multiplied by the softmax scale.
@@ -575,31 +652,30 @@ def _attn_bwd_dkdv_inner(b_dk, b_dv,

# Autoregressive masking.
if MASK:
mask = (offs_m[None, :] >= offs_n[:, None])
mask = (offs_i[None, :] >= offs_j[:, None])
b_pT = tl.where(mask, b_pT, 0.0)

b_pT = tl.where(i_mask[None, :], b_pT, 0.0)

# $$dv_j = \sum_i P_{ij} do_i$$
b_do = tl.load(p_do)
b_dv += tl.dot(b_pT.to(b_do.dtype),
b_do,
out_dtype=HI_PRES_TL)
b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
b_dv += tl.dot(b_pT.to(b_do.dtype), b_do, out_dtype=HI_PRES_TL)

# $$dk_j = \sum_i dS_{ij} q_i = \sum_i P_{ij} \big( dP^T_{i:} - D_i \big) q_i$$
b_pdp = tl.load(p_pdp)
b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
# $dP_{ij} = do_i^T v_j$
b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)
# $dS_{ij} = P_{ij} \big( dP_{i:} - D_i \big)$
b_dsT = b_pT * (b_dpT - b_pdp[None, :])
# $dk_j = \sum_i dS_{ij} q_i$
b_dk += tl.dot(b_dsT.to(b_qT.dtype),
tl.trans(b_qT), out_dtype=HI_PRES_TL)
b_dk += tl.dot(b_dsT.to(b_qT.dtype), tl.trans(b_qT), out_dtype=HI_PRES_TL)

# Increment pointers.
offs_m += BLOCK_M
p_lse = tl.advance(p_lse, (BLOCK_M,))
p_pdp = tl.advance(p_pdp, (BLOCK_M,))
p_qT = tl.advance(p_qT, (0, BLOCK_M))
p_do = tl.advance(p_do, (BLOCK_M, 0))
offs_i += BLOCK_Q
p_lse = tl.advance(p_lse, (BLOCK_Q,))
p_pdp = tl.advance(p_pdp, (BLOCK_Q,))
p_qT = tl.advance(p_qT, (0, BLOCK_Q))
p_do = tl.advance(p_do, (BLOCK_Q, 0))

# Return accumulated $dk$ and $dv$
return b_dk, b_dv
@@ -614,13 +690,13 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
n_groups: tl.constexpr, d_head: tl.constexpr,
is_causal: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
):
# $\log_e 2$
LN2: tl.constexpr = 0.6931471824645996  # type: ignore

m = tl.program_id(0)
i = tl.program_id(0) * BLOCK_Q
z = tl.program_id(1) // n_groups
g = tl.program_id(1) % n_groups

@@ -628,53 +704,53 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
(q_seq_len, d_head),
(d_head, 1),
(m * BLOCK_M, 0),
(BLOCK_M, d_head),
(i, 0),
(BLOCK_Q, d_head),
(1, 0))
p_dq = tl.make_block_ptr(t_dq + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
(q_seq_len, d_head),
(d_head, 1),
(m * BLOCK_M, 0),
(BLOCK_M, d_head),
(i, 0),
(BLOCK_Q, d_head),
(1, 0))
p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
(q_seq_len, d_head),
(d_head, 1),
(m * BLOCK_M, 0),
(BLOCK_M, d_head),
(i, 0),
(BLOCK_Q, d_head),
(1, 0))
p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
(d_head, kv_seq_len),
(1, d_head),
(0, 0),
(d_head, BLOCK_N),
(d_head, BLOCK_K),
(0, 1))
p_vT = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
(d_head, kv_seq_len),
(1, d_head),
(0, 0),
(d_head, BLOCK_N),
(d_head, BLOCK_K),
(0, 1))
p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
(q_seq_len,),
(1,),
(m * BLOCK_M,),
(BLOCK_M,),
(i,),
(BLOCK_Q,),
(0,))
p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
(q_seq_len,),
(1,),
(m * BLOCK_M,),
(BLOCK_M,),
(i,),
(BLOCK_Q,),
(0,))

b_q = tl.load(p_q)
b_do = tl.load(p_do)
b_pdp = tl.load(p_pdp)
b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")

b_dq = tl.zeros([BLOCK_M, d_head], dtype=HI_PRES_TL)
b_dq = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)

b_lse = tl.load(p_lse)
b_lse = tl.load(p_lse, boundary_check=(0,), padding_option="zero")

# $$dq_i = \sum_j dS_{ij} k_j = \sum_j P_{ij} \big( dP_{ij} - D_i \big) k_j$$

@@ -682,27 +758,33 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
# Compute $dQ$ for masked (diagonal) blocks.
b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
b_do, b_lse, b_pdp,
BLOCK_M, BLOCK_N,
m=m * BLOCK_M, start_n=m * BLOCK_M,
steps=BLOCK_M // BLOCK_N,
MASK=True
BLOCK_Q, BLOCK_K,
i=i, j=i,
steps=BLOCK_Q // BLOCK_K,
MASK=True,
q_seq_len=q_seq_len,
kv_seq_len=kv_seq_len
)

# Other blocks
b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
b_do, b_lse, b_pdp,
BLOCK_M, BLOCK_N,
m=m * BLOCK_M, start_n=tl.full([], 0, tl.int32),  # type: ignore
steps=(m * BLOCK_M) // BLOCK_N,
MASK=False
BLOCK_Q, BLOCK_K,
i=i, j=tl.full([], 0, tl.int32),  # type: ignore
steps=i // BLOCK_K,
MASK=False,
q_seq_len=q_seq_len,
kv_seq_len=kv_seq_len
)
else:
b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
b_do, b_lse, b_pdp,
BLOCK_M, BLOCK_N,
m=m * BLOCK_M, start_n=tl.full([], 0, tl.int32),  # type: ignore
steps=kv_seq_len // BLOCK_N,
MASK=False
BLOCK_Q, BLOCK_K,
i=i, j=tl.full([], 0, tl.int32),  # type: ignore
steps=tl.cdiv(kv_seq_len, BLOCK_K),
MASK=False,
q_seq_len=q_seq_len,
kv_seq_len=kv_seq_len
)

# Since $k$ was scaled by $\frac{1}{\log_e 2}$, and $dq_i = \sum_j dS_{ij} k_j$
@@ -710,37 +792,44 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
b_dq *= LN2

# Save $dq$
tl.store(p_dq, b_dq.to(t_dq.type.element_ty))
tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))


@triton.jit
def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
b_do, b_lse, b_pdp,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
m, start_n, steps,
MASK: tl.constexpr):
BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
i, j, steps,
MASK: tl.constexpr,
q_seq_len: tl.constexpr,
kv_seq_len: tl.constexpr):
"""Inner loop over n key"""
offs_m = m + tl.arange(0, BLOCK_M)
offs_i = i + tl.arange(0, BLOCK_Q)
offs_j = tl.arange(0, BLOCK_K)

p_kT = tl.advance(p_kT, (0, start_n))
p_vT = tl.advance(p_vT, (0, start_n))
p_kT = tl.advance(p_kT, (0, j))
p_vT = tl.advance(p_vT, (0, j))

tl.static_assert(BLOCK_M % BLOCK_N == 0, 'BLOCK_M must be divisible by BLOCK_N')
tl.static_assert(BLOCK_Q % BLOCK_K == 0, 'BLOCK_Q must be divisible by BLOCK_K')

for _ in range(steps):
current_j = j + offs_j
j_mask = current_j < kv_seq_len

# $$P_{ij} = \frac{e^{q_i^T k_j}}{L_i} = e^{q_i^T k_j - M_i}$$
# Note that k is already multiplied by the softmax scale.
# It is also divided by $\log_e 2$ so we can use $2^x$ instead of $e^x$
b_kT = tl.load(p_kT)
b_vT = tl.load(p_vT)
b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
b_vT = tl.load(p_vT, boundary_check=(1,), padding_option="zero")
b_qk = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
b_p = tl.math.exp2(b_qk - b_lse[:, None])

# Autoregressive masking.
if MASK:
offs_n = start_n + tl.arange(0, BLOCK_N)
mask = (offs_m[:, None] >= offs_n[None, :])
b_p = tl.where(mask, b_p, 0.0)
causal_mask = (offs_i[:, None] >= current_j[None, :])
b_p = tl.where(causal_mask, b_p, 0.0)

b_p = tl.where(j_mask[None, :], b_p, 0.0)

# $$dq_i = \sum_j dS_{ij} k_j = \sum_j P_{ij} \big( dP_{ij} - D_i \big) k_j$$

@@ -754,9 +843,9 @@ def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
out_dtype=HI_PRES_TL)

# Increment pointers.
start_n += BLOCK_N
p_kT = tl.advance(p_kT, (0, BLOCK_N))
p_vT = tl.advance(p_vT, (0, BLOCK_N))
j += BLOCK_K
p_kT = tl.advance(p_kT, (0, BLOCK_K))
p_vT = tl.advance(p_vT, (0, BLOCK_K))

# Return accumulated $dq$
return b_dq
return b_dq
@@ -19,7 +19,7 @@ def _calc_abs_rel_error(a: torch.Tensor, b: torch.Tensor, atol=1e-2):


def _test_op(batch_size, n_heads, k_heads, q_seq_len, kv_seq_len, d_head, causal, dtype, device):
with monit.section('Init'):
with monit.section(f'Init {q_seq_len} {kv_seq_len} {d_head}'):
torch.manual_seed(20)
q = (torch.empty((batch_size, n_heads, q_seq_len, d_head),
dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
@@ -88,8 +88,7 @@ def _test_op(batch_size, n_heads, k_heads, q_seq_len, kv_seq_len, d_head, causal
torch.cuda.synchronize()


def _perf_triton_fn(*, device,
dtype, batch_size, k_heads, n_groups, seq_len, d_head, causal, ):
def _perf_triton_fn(*, device, dtype, batch_size, k_heads, n_groups, seq_len, d_head, causal):
q = torch.randn((batch_size, k_heads * n_groups, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
k = torch.randn((batch_size, k_heads, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
v = torch.randn((batch_size, k_heads, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
@@ -97,8 +96,7 @@ def _perf_triton_fn(*, device,
return lambda: attention(q, k, v, causal, sm_scale)


def _perf_flash(*, batch_size, k_heads, n_groups, seq_len, d_head, causal, device,
dtype):
def _perf_flash(*, batch_size, k_heads, n_groups, seq_len, d_head, causal, device, dtype):
q = torch.randn((batch_size, seq_len, k_heads * n_groups, d_head), dtype=dtype, device=device, requires_grad=True)
k = torch.randn((batch_size, seq_len, k_heads, d_head), dtype=dtype, device=device, requires_grad=True)
v = torch.randn((batch_size, seq_len, k_heads, d_head), dtype=dtype, device=device, requires_grad=True)
@@ -128,13 +126,13 @@ def _test():
device = torch.device('cuda:0')
torch.cuda.set_device(device)

dtype = torch.bfloat16
dtype = torch.float16

# only works on post-Ampere GPUs right now
_test_op(1, 4, 1, 2048, 2048, 128, True, dtype=dtype, device=device)
_test_op(16, 32, 8, 2048, 4096, 128, False, dtype=dtype, device=device)
_test_op(16, 32, 8, 2001, 4001, 128, False, dtype=dtype, device=device)
_test_op(4, 32, 8, 2048, 1024, 128, False, dtype=dtype, device=device)
_test_op(4, 32, 8, 2048, 2048, 128, True, dtype=dtype, device=device)
_test_op(4, 32, 8, 2001, 4001, 128, True, dtype=dtype, device=device)
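The new shapes exercise grouped-query attention (`n_heads = k_heads * n_groups`) and ragged lengths like 2001/4001 that are not multiples of any block size. A sketch of the kind of dense reference such a test can compare against, expanding each KV head across its query group (names and shapes are illustrative):

```python
import torch

batch, k_heads, n_groups, q_len, kv_len, d = 2, 4, 2, 64, 80, 32
q = torch.randn(batch, k_heads * n_groups, q_len, d)
k = torch.randn(batch, k_heads, kv_len, d)
v = torch.randn(batch, k_heads, kv_len, d)

# repeat each KV head for the n_groups query heads that share it
k_rep = k.repeat_interleave(n_groups, dim=1)
v_rep = v.repeat_interleave(n_groups, dim=1)
ref = torch.softmax(q @ k_rep.transpose(-1, -2) * d ** -0.5, dim=-1) @ v_rep
assert ref.shape == (batch, k_heads * n_groups, q_len, d)
```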

_conf = {
'batch_size': 16,