import triton

import torch
from labml import logger, monit
from labml_nn.transformers.flash import attention

# high-precision dtype used for the reference softmax
HI_PRES_TORCH = torch.float32


@torch.no_grad()
def _calc_abs_rel_error(a: torch.Tensor, b: torch.Tensor, atol=1e-2):
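    """Return the maximum absolute error between `a` and `b`, and the maximum
    relative error after discounting the absolute tolerance `atol`.

    Example with hypothetical numbers: a = [1.00, 2.00], b = [1.02, 2.00] and
    atol=1e-2 give max_abs = 0.02 and max_rel = (0.02 - 0.01) / 1.02 ≈ 0.0098.
    """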
    d = (a - b).abs()
    max_abs = d.max()
    d = (d - atol).clamp(min=0)
    d = d / b.abs()
    max_rel = d.max()

    return max_abs.cpu().item(), max_rel.cpu().item()


def _test_op(batch_size, n_heads, k_heads, q_seq_len, kv_seq_len, d_head, causal, dtype, device):
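    """Check the Triton `attention` kernel against a PyTorch reference for the
    given shapes, comparing the output and the gradients of q, k and v."""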
    with monit.section(f'Init {q_seq_len} {kv_seq_len} {d_head}'):
        torch.manual_seed(20)
        q = (torch.empty((batch_size, n_heads, q_seq_len, d_head),
                         dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
        k = (torch.empty((batch_size, k_heads, kv_seq_len, d_head),
                         dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
        v = (torch.empty((batch_size, k_heads, kv_seq_len, d_head),
                         dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
        sm_scale = d_head ** -0.5
        d_out = torch.randn_like(q)

        # reference implementation
        mask = torch.tril(torch.ones((q_seq_len, kv_seq_len), device=device, dtype=torch.bool))
        torch.cuda.synchronize()

    # Compute the reference output with plain PyTorch ops, doing the softmax
    # in HI_PRES_TORCH (float32) so the Triton kernel is checked against a
    # higher-precision baseline
    with monit.section('Pytorch'):
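        # View q as (batch, k_heads, n_groups, q_seq_len, d_head) so that the
        # k_heads of k and v broadcast across the query-head groups
        # (grouped-query attention)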
        p = torch.matmul(q.view(batch_size, k_heads, -1, q_seq_len, d_head),
                         k.transpose(2, 3)[:, :, None, :, :]) * sm_scale
        if causal:
            p[:, :, :, ~mask] = float("-inf")
        p = torch.softmax(p.to(HI_PRES_TORCH), dim=-1).to(dtype)
        ref_out = torch.matmul(p, v[:, :, None, :, :])
        ref_out = ref_out.view(q.shape)
        ref_out.backward(d_out)
        ref_dv, v.grad = v.grad.clone(), None
        ref_dk, k.grad = k.grad.clone(), None
        ref_dq, q.grad = q.grad.clone(), None
        torch.cuda.synchronize()

    with monit.section('Triton'):
        assert q.dtype == dtype
        tri_out = attention(q, k, v, causal, sm_scale).to(dtype)
        monit.progress(0.5)

        tri_out.backward(d_out)
        monit.progress(0.9)
        tri_dv, v.grad = v.grad.clone(), None  # type: ignore
        tri_dk, k.grad = k.grad.clone(), None  # type: ignore
        tri_dq, q.grad = q.grad.clone(), None  # type: ignore
        torch.cuda.synchronize()

    with monit.section('Test') as s:
        # compare against the reference implementation
        passed = True
        if not torch.allclose(tri_out, ref_out, atol=1e-2, rtol=0.):
            abs_err, rel_err = _calc_abs_rel_error(ref_out, tri_out)
            logger.log(('[FAILED]', logger.Text.danger), f' Out mismatch {abs_err} {rel_err}')
            passed = False
        # looser relative tolerance for the gradients
        rtol = 1e-1
        if not torch.allclose(tri_dq, ref_dq, atol=1e-2, rtol=rtol):
            abs_err, rel_err = _calc_abs_rel_error(ref_dq, tri_dq)
            logger.log(('[FAILED]', logger.Text.danger), f' dQ mismatch {abs_err} {rel_err}')
            passed = False
        if not torch.allclose(tri_dv, ref_dv, atol=1e-2, rtol=rtol):
            abs_err, rel_err = _calc_abs_rel_error(ref_dv, tri_dv)
            logger.log(('[FAILED]', logger.Text.danger), f' dV mismatch {abs_err} {rel_err}')
            passed = False
        if not torch.allclose(tri_dk, ref_dk, atol=1e-2, rtol=rtol):
            abs_err, rel_err = _calc_abs_rel_error(ref_dk, tri_dk)
            logger.log(('[FAILED]', logger.Text.danger), f' dK mismatch {abs_err} {rel_err}')
            passed = False

        if passed:
            logger.log('[PASSED]', logger.Text.success)
            s.success = True
        else:
            s.success = False
        torch.cuda.synchronize()


def _perf_triton_fn(*, device, dtype, batch_size, k_heads, n_groups, seq_len, d_head, causal):
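    # Build a no-argument closure that runs the Triton attention forward pass,
    # for benchmarking with triton.testing.do_bench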
    q = torch.randn((batch_size, k_heads * n_groups, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
    k = torch.randn((batch_size, k_heads, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
    v = torch.randn((batch_size, k_heads, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
    sm_scale = d_head ** -0.5
    return lambda: attention(q, k, v, causal, sm_scale)


def _perf_flash(*, batch_size, k_heads, n_groups, seq_len, d_head, causal, device, dtype):
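    # Same benchmark closure using the flash-attn package; flash_attn_func
    # expects (batch, seq_len, heads, d_head) layout instead of
    # (batch, heads, seq_len, d_head)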
    q = torch.randn((batch_size, seq_len, k_heads * n_groups, d_head), dtype=dtype, device=device, requires_grad=True)
    k = torch.randn((batch_size, seq_len, k_heads, d_head), dtype=dtype, device=device, requires_grad=True)
    v = torch.randn((batch_size, seq_len, k_heads, d_head), dtype=dtype, device=device, requires_grad=True)
    # imported locally so the rest of the file works without flash-attn installed
    from flash_attn import flash_attn_func
    return lambda: flash_attn_func(q, k, v, causal=causal)


def _perf_fn(name, fn, *, batch_size, k_heads, n_groups, seq_len, d_head, causal, is_bwd: bool):
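    # Benchmark fn with triton.testing.do_bench and log the runtime and the
    # achieved TFLOPs/s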
    if is_bwd:
        # run the forward pass once, then benchmark only the backward pass
        o = fn()
        do = torch.randn_like(o)
        fn = lambda: o.backward(do, retain_graph=True)
    ms = triton.testing.do_bench(fn)

    # each matmul costs ~2 * seq_len^2 * d_head FLOPs per head (multiplies and
    # adds counted separately); the forward pass has two matmuls (Q K^T and P V)
    flops_per_matmul = 2.0 * batch_size * k_heads * n_groups * seq_len * seq_len * d_head
    total_flops = 2 * flops_per_matmul
    if causal:
        # causal masking computes only about half of the score matrix
        total_flops *= 0.5
    if is_bwd:
        total_flops *= 2.5  # 2.0(bwd) + 0.5(recompute)

    # do_bench reports milliseconds per call; convert to TFLOPs/s
    tf_ps = total_flops * 1e-12 / (ms * 1e-3)
    logger.log((f'{name}', logger.Text.key), ': ', f'{ms:,.1f}ms', ' ', f'{tf_ps:,.2f}TFps')


def _test():
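    """Run the correctness tests, then a small benchmark sweep."""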
    device = torch.device('cuda:0')
    torch.cuda.set_device(device)

    dtype = torch.float16

    # only works on post-Ampere GPUs right now;
    # arguments: batch_size, n_heads, k_heads, q_seq_len, kv_seq_len, d_head, causal
    _test_op(1, 4, 1, 2048, 2048, 128, True, dtype=dtype, device=device)
    _test_op(16, 32, 8, 2001, 4001, 128, False, dtype=dtype, device=device)
    _test_op(4, 32, 8, 2048, 1024, 128, False, dtype=dtype, device=device)
    _test_op(4, 32, 8, 2001, 4001, 128, True, dtype=dtype, device=device)

    _conf = {
        'batch_size': 16,
        'k_heads': 8,
        'n_groups': 4,
        'seq_len': 2048,
        'd_head': 128,
    }

    # benchmark forward and backward, non-causal and causal
    for _causal in [False, True]:
        for is_bwd in [False, True]:
            logger.log(f'{"Causal" if _causal else "Non-causal"}{" Backward" if is_bwd else ""}', logger.Text.title)
            _perf_fn('flash', _perf_flash(causal=_causal, device=device, dtype=dtype, **_conf),
                     is_bwd=is_bwd,
                     causal=_causal, **_conf)
            _perf_fn('triton', _perf_triton_fn(causal=_causal, device=device, dtype=dtype, **_conf),
                     is_bwd=is_bwd,
                     causal=_causal, **_conf)


if __name__ == "__main__":
    _test()