From 00b1eec47586694311afa35194c93a5ffac9c482 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 17 Sep 2025 14:51:29 +0000 Subject: [PATCH 1/5] upd --- include/flashinfer/attention/scheduler.cuh | 3 +- tests/test_batch_invariant_fa2.py | 215 +++++++++++++-------- 2 files changed, 134 insertions(+), 84 deletions(-) diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 4f888e716b..dd163bf272 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -599,9 +599,10 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, } o_indptr.push_back(o_indptr.back() + qo_len * num_chunks_kv); } - const size_t padded_batch_size = enable_cuda_graph ? std::max(max_batch_size_if_split, total_num_tiles_q) : new_batch_size; + printf("new_batch_size: %d, total_num_tiles_q: %d, padded_batch_size: %d, kv_chunk_size: %ld\n", + new_batch_size, total_num_tiles_q, padded_batch_size, kv_chunk_size); FLASHINFER_CHECK(new_batch_size <= padded_batch_size, "new batch size should not exceed padded batch size. If you are using fixed " "split size, please consider disabling cuda graph."); diff --git a/tests/test_batch_invariant_fa2.py b/tests/test_batch_invariant_fa2.py index 63e22f1354..160e7ede24 100644 --- a/tests/test_batch_invariant_fa2.py +++ b/tests/test_batch_invariant_fa2.py @@ -46,9 +46,10 @@ def warmup_jit(): yield -@pytest.mark.parametrize("batch_size", [5, 12]) -@pytest.mark.parametrize("invariant_bs", [4]) -@pytest.mark.parametrize("kv_len", [4096, 8192, 5000]) +@pytest.mark.parametrize("batch_size", [3, 4]) +@pytest.mark.parametrize("invariant_bs", [2]) +@pytest.mark.parametrize("kv_len", [4096, 5000]) +@pytest.mark.parametrize("qo_len", [128, 256]) @pytest.mark.parametrize("fixed_split_size", [2048]) @pytest.mark.parametrize("disable_split_kv", [True, False]) @pytest.mark.parametrize("page_size", [1, 8, 16]) @@ -57,10 +58,11 @@ def warmup_jit(): @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"]) -def test_batch_decode_tensor_cores( +def test_batch_prefill_tensor_cores( batch_size: int, invariant_bs: int, kv_len: int, + qo_len: int, fixed_split_size: int, disable_split_kv: bool, page_size: int, @@ -72,7 +74,14 @@ def test_batch_decode_tensor_cores( ): num_qo_heads = num_kv_heads * group_size q = torch.randn( - batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16 + batch_size * qo_len, + num_qo_heads, + head_dim, + device="cuda:0", + dtype=torch.float16, + ) + q_indptr = ( + torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len ) num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size @@ -108,12 +117,15 @@ def test_batch_decode_tensor_cores( (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0" ) - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + workspace_buffer = torch.empty( + 1024 * 1024 * 1024, dtype=torch.int8, device="cuda:0" + ) - wrapper_tensor_cores = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, kv_layout, use_tensor_cores=True + wrapper_tcore = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout ) - wrapper_tensor_cores.plan( + wrapper_tcore.plan( + q_indptr, kv_indptr, kv_indices, kv_last_page_len, @@ -122,18 +134,19 @@ def test_batch_decode_tensor_cores( head_dim, page_size, pos_encoding_mode=pos_encoding_mode, - data_type=torch.float16, q_data_type=torch.float16, + kv_data_type=torch.float16, fixed_split_size=fixed_split_size if not disable_split_kv else None, disable_split_kv=disable_split_kv, ) - o_tensor_cores, lse_tensor_cores = wrapper_tensor_cores.run( - q, kv_data, return_lse=True - ) + o_tensor_cores, lse_tensor_cores = wrapper_tcore.run(q, kv_data, return_lse=True) + # Test invariant batch size + q_indptr_invariant = q_indptr[: invariant_bs + 1] kv_indptr_invariant = kv_indptr[: invariant_bs + 1] kv_last_page_len_invariant = kv_last_page_len[:invariant_bs] - wrapper_tensor_cores.plan( + wrapper_tcore.plan( + q_indptr_invariant, kv_indptr_invariant, kv_indices, kv_last_page_len_invariant, @@ -142,50 +155,81 @@ def test_batch_decode_tensor_cores( head_dim, page_size, pos_encoding_mode=pos_encoding_mode, - data_type=torch.float16, q_data_type=torch.float16, + kv_data_type=torch.float16, fixed_split_size=fixed_split_size if not disable_split_kv else None, disable_split_kv=disable_split_kv, ) - o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tensor_cores.run( - q[:invariant_bs], kv_data, return_lse=True + o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tcore.run( + q[: invariant_bs * qo_len], kv_data, return_lse=True ) - assert torch.equal(o_tensor_cores[:invariant_bs], o_tensor_cores_invariant) - assert torch.equal(lse_tensor_cores[:invariant_bs], lse_tensor_cores_invariant) + # Compare outputs for the invariant batch size + assert torch.equal( + o_tensor_cores[: invariant_bs * qo_len], o_tensor_cores_invariant + ) + assert torch.equal( + lse_tensor_cores[: invariant_bs * qo_len], lse_tensor_cores_invariant + ) -# test that without fixed split size, precision is different -# TODO: this works for the first 29 cases, but then fails with "illegal memory access"..? + if disable_split_kv: + # Test cuda graph + del o_tensor_cores_invariant, lse_tensor_cores_invariant + g_invariant = torch.cuda.CUDAGraph() + with torch.cuda.graph(g_invariant): + o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tcore.run( + q[: invariant_bs * qo_len], kv_data, return_lse=True + ) + wrapper_tcore.plan( + q_indptr_invariant, + kv_indptr_invariant, + kv_indices, + kv_last_page_len_invariant, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + pos_encoding_mode=pos_encoding_mode, + q_data_type=torch.float16, + kv_data_type=torch.float16, + disable_split_kv=True, + ) + g_invariant.replay() -# wrapper_tensor_cores.plan( -# kv_indptr, -# kv_indices, -# kv_last_page_len, -# num_qo_heads, -# num_kv_heads, -# head_dim, -# page_size, -# ) -# o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tensor_cores.run( -# q[:invariant_bs], kv_data, return_lse=True -# ) -# try: -# torch.testing.assert_close( -# o_tensor_cores[:invariant_bs], o_tensor_cores_invariant, rtol=1e-7, atol=1e-7 -# ) -# torch.testing.assert_close( -# lse_tensor_cores[:invariant_bs], lse_tensor_cores_invariant, rtol=1e-7, atol=1e-7 -# ) -# except AssertionError: -# pass -# else: -# raise AssertionError("Precision is the same without fixed split size") + # capture for full batch + del o_tensor_cores, lse_tensor_cores + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + o_tensor_cores, lse_tensor_cores = wrapper_tcore.run( + q, kv_data, return_lse=True + ) + wrapper_tcore.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + pos_encoding_mode=pos_encoding_mode, + q_data_type=torch.float16, + kv_data_type=torch.float16, + disable_split_kv=True, + ) + g.replay() + # compare outputs + assert torch.equal( + o_tensor_cores_invariant, o_tensor_cores[: invariant_bs * qo_len] + ) + assert torch.equal( + lse_tensor_cores_invariant, lse_tensor_cores[: invariant_bs * qo_len] + ) -@pytest.mark.parametrize("batch_size", [3, 4]) -@pytest.mark.parametrize("invariant_bs", [2]) -@pytest.mark.parametrize("kv_len", [4096, 5000]) -@pytest.mark.parametrize("qo_len", [128, 256]) +@pytest.mark.parametrize("batch_size", [5, 12]) +@pytest.mark.parametrize("invariant_bs", [4]) +@pytest.mark.parametrize("kv_len", [4096, 8192, 5000]) @pytest.mark.parametrize("fixed_split_size", [2048]) @pytest.mark.parametrize("disable_split_kv", [True, False]) @pytest.mark.parametrize("page_size", [1, 8, 16]) @@ -194,11 +238,10 @@ def test_batch_decode_tensor_cores( @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"]) -def test_batch_prefill_tensor_cores( +def test_batch_decode_tensor_cores( batch_size: int, invariant_bs: int, kv_len: int, - qo_len: int, fixed_split_size: int, disable_split_kv: bool, page_size: int, @@ -210,14 +253,7 @@ def test_batch_prefill_tensor_cores( ): num_qo_heads = num_kv_heads * group_size q = torch.randn( - batch_size * qo_len, - num_qo_heads, - head_dim, - device="cuda:0", - dtype=torch.float16, - ) - q_indptr = ( - torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len + batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16 ) num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size @@ -253,15 +289,12 @@ def test_batch_prefill_tensor_cores( (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0" ) - workspace_buffer = torch.empty( - 1024 * 1024 * 1024, dtype=torch.int8, device="cuda:0" - ) + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0") - wrapper_tensor_cores = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout + wrapper_tcore = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, kv_layout, use_tensor_cores=True ) - wrapper_tensor_cores.plan( - q_indptr, + wrapper_tcore.plan( kv_indptr, kv_indices, kv_last_page_len, @@ -270,21 +303,16 @@ def test_batch_prefill_tensor_cores( head_dim, page_size, pos_encoding_mode=pos_encoding_mode, + data_type=torch.float16, q_data_type=torch.float16, - kv_data_type=torch.float16, fixed_split_size=fixed_split_size if not disable_split_kv else None, disable_split_kv=disable_split_kv, ) - o_tensor_cores, lse_tensor_cores = wrapper_tensor_cores.run( - q, kv_data, return_lse=True - ) + o_tensor_cores, lse_tensor_cores = wrapper_tcore.run(q, kv_data, return_lse=True) - # Test invariant batch size - q_indptr_invariant = q_indptr[: invariant_bs + 1] kv_indptr_invariant = kv_indptr[: invariant_bs + 1] kv_last_page_len_invariant = kv_last_page_len[:invariant_bs] - wrapper_tensor_cores.plan( - q_indptr_invariant, + wrapper_tcore.plan( kv_indptr_invariant, kv_indices, kv_last_page_len_invariant, @@ -293,19 +321,40 @@ def test_batch_prefill_tensor_cores( head_dim, page_size, pos_encoding_mode=pos_encoding_mode, + data_type=torch.float16, q_data_type=torch.float16, - kv_data_type=torch.float16, fixed_split_size=fixed_split_size if not disable_split_kv else None, disable_split_kv=disable_split_kv, ) - o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tensor_cores.run( - q[: invariant_bs * qo_len], kv_data, return_lse=True + o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tcore.run( + q[:invariant_bs], kv_data, return_lse=True ) + assert torch.equal(o_tensor_cores[:invariant_bs], o_tensor_cores_invariant) + assert torch.equal(lse_tensor_cores[:invariant_bs], lse_tensor_cores_invariant) - # Compare outputs for the invariant batch size - assert torch.equal( - o_tensor_cores[: invariant_bs * qo_len], o_tensor_cores_invariant - ) - assert torch.equal( - lse_tensor_cores[: invariant_bs * qo_len], lse_tensor_cores_invariant - ) + # test that without fixed split size, precision is different + # TODO: this works for the first 29 cases, but then fails with "illegal memory access"..? + + # wrapper_tcore.plan( + # kv_indptr, + # kv_indices, + # kv_last_page_len, + # num_qo_heads, + # num_kv_heads, + # head_dim, + # page_size, + # ) + # o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tcore.run( + # q[:invariant_bs], kv_data, return_lse=True + # ) + # try: + # torch.testing.assert_close( + # o_tensor_cores[:invariant_bs], o_tensor_cores_invariant, rtol=1e-7, atol=1e-7 + # ) + # torch.testing.assert_close( + # lse_tensor_cores[:invariant_bs], lse_tensor_cores_invariant, rtol=1e-7, atol=1e-7 + # ) + # except AssertionError: + # pass + # else: + # raise AssertionError("Precision is the same without fixed split size") From 7019ae70b6252658ed451c993fc9de5df7e14d8b Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 17 Sep 2025 14:53:49 +0000 Subject: [PATCH 2/5] upd --- tests/test_batch_invariant_fa2.py | 269 +++++++++++++++--------------- 1 file changed, 136 insertions(+), 133 deletions(-) diff --git a/tests/test_batch_invariant_fa2.py b/tests/test_batch_invariant_fa2.py index 160e7ede24..7959c12b78 100644 --- a/tests/test_batch_invariant_fa2.py +++ b/tests/test_batch_invariant_fa2.py @@ -46,6 +46,142 @@ def warmup_jit(): yield +# @pytest.mark.parametrize("batch_size", [5, 12]) +# @pytest.mark.parametrize("invariant_bs", [4]) +# @pytest.mark.parametrize("kv_len", [4096, 8192, 5000]) +# @pytest.mark.parametrize("fixed_split_size", [2048]) +# @pytest.mark.parametrize("disable_split_kv", [True, False]) +# @pytest.mark.parametrize("page_size", [1, 8, 16]) +# @pytest.mark.parametrize("num_kv_heads", [4]) +# @pytest.mark.parametrize("group_size", [1, 4, 8]) +# @pytest.mark.parametrize("head_dim", [128, 256]) +# @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) +# @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"]) +# def test_batch_decode_tensor_cores( +# batch_size: int, +# invariant_bs: int, +# kv_len: int, +# fixed_split_size: int, +# disable_split_kv: bool, +# page_size: int, +# num_kv_heads: int, +# group_size: int, +# head_dim: int, +# kv_layout: str, +# pos_encoding_mode: str, +# ): +# num_qo_heads = num_kv_heads * group_size +# q = torch.randn( +# batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16 +# ) +# num_pages_per_seq = (kv_len + page_size - 1) // page_size +# total_num_pages = num_pages_per_seq * batch_size +# kv_data = ( +# torch.randn( +# total_num_pages, +# 2, +# num_kv_heads, +# page_size, +# head_dim, +# device="cuda:0", +# dtype=torch.float16, +# ) +# / 10 +# if kv_layout == "HND" +# else torch.randn( +# total_num_pages, +# 2, +# page_size, +# num_kv_heads, +# head_dim, +# device="cuda:0", +# dtype=torch.float16, +# ) +# / 10 +# ) +# kv_indptr = ( +# torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) +# * num_pages_per_seq +# ) +# kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) +# kv_last_page_len = torch.full( +# (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0" +# ) + +# workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + +# wrapper_tcore = flashinfer.BatchDecodeWithPagedKVCacheWrapper( +# workspace_buffer, kv_layout, use_tensor_cores=True +# ) +# wrapper_tcore.plan( +# kv_indptr, +# kv_indices, +# kv_last_page_len, +# num_qo_heads, +# num_kv_heads, +# head_dim, +# page_size, +# pos_encoding_mode=pos_encoding_mode, +# data_type=torch.float16, +# q_data_type=torch.float16, +# fixed_split_size=fixed_split_size if not disable_split_kv else None, +# disable_split_kv=disable_split_kv, +# ) +# o_tensor_cores, lse_tensor_cores = wrapper_tcore.run( +# q, kv_data, return_lse=True +# ) + +# kv_indptr_invariant = kv_indptr[: invariant_bs + 1] +# kv_last_page_len_invariant = kv_last_page_len[:invariant_bs] +# wrapper_tcore.plan( +# kv_indptr_invariant, +# kv_indices, +# kv_last_page_len_invariant, +# num_qo_heads, +# num_kv_heads, +# head_dim, +# page_size, +# pos_encoding_mode=pos_encoding_mode, +# data_type=torch.float16, +# q_data_type=torch.float16, +# fixed_split_size=fixed_split_size if not disable_split_kv else None, +# disable_split_kv=disable_split_kv, +# ) +# o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tcore.run( +# q[:invariant_bs], kv_data, return_lse=True +# ) +# assert torch.equal(o_tensor_cores[:invariant_bs], o_tensor_cores_invariant) +# assert torch.equal(lse_tensor_cores[:invariant_bs], lse_tensor_cores_invariant) + + +# test that without fixed split size, precision is different +# TODO: this works for the first 29 cases, but then fails with "illegal memory access"..? + +# wrapper_tcore.plan( +# kv_indptr, +# kv_indices, +# kv_last_page_len, +# num_qo_heads, +# num_kv_heads, +# head_dim, +# page_size, +# ) +# o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tcore.run( +# q[:invariant_bs], kv_data, return_lse=True +# ) +# try: +# torch.testing.assert_close( +# o_tensor_cores[:invariant_bs], o_tensor_cores_invariant, rtol=1e-7, atol=1e-7 +# ) +# torch.testing.assert_close( +# lse_tensor_cores[:invariant_bs], lse_tensor_cores_invariant, rtol=1e-7, atol=1e-7 +# ) +# except AssertionError: +# pass +# else: +# raise AssertionError("Precision is the same without fixed split size") + + @pytest.mark.parametrize("batch_size", [3, 4]) @pytest.mark.parametrize("invariant_bs", [2]) @pytest.mark.parametrize("kv_len", [4096, 5000]) @@ -225,136 +361,3 @@ def test_batch_prefill_tensor_cores( assert torch.equal( lse_tensor_cores_invariant, lse_tensor_cores[: invariant_bs * qo_len] ) - - -@pytest.mark.parametrize("batch_size", [5, 12]) -@pytest.mark.parametrize("invariant_bs", [4]) -@pytest.mark.parametrize("kv_len", [4096, 8192, 5000]) -@pytest.mark.parametrize("fixed_split_size", [2048]) -@pytest.mark.parametrize("disable_split_kv", [True, False]) -@pytest.mark.parametrize("page_size", [1, 8, 16]) -@pytest.mark.parametrize("num_kv_heads", [4]) -@pytest.mark.parametrize("group_size", [1, 4, 8]) -@pytest.mark.parametrize("head_dim", [128, 256]) -@pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) -@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"]) -def test_batch_decode_tensor_cores( - batch_size: int, - invariant_bs: int, - kv_len: int, - fixed_split_size: int, - disable_split_kv: bool, - page_size: int, - num_kv_heads: int, - group_size: int, - head_dim: int, - kv_layout: str, - pos_encoding_mode: str, -): - num_qo_heads = num_kv_heads * group_size - q = torch.randn( - batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16 - ) - num_pages_per_seq = (kv_len + page_size - 1) // page_size - total_num_pages = num_pages_per_seq * batch_size - kv_data = ( - torch.randn( - total_num_pages, - 2, - num_kv_heads, - page_size, - head_dim, - device="cuda:0", - dtype=torch.float16, - ) - / 10 - if kv_layout == "HND" - else torch.randn( - total_num_pages, - 2, - page_size, - num_kv_heads, - head_dim, - device="cuda:0", - dtype=torch.float16, - ) - / 10 - ) - kv_indptr = ( - torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) - * num_pages_per_seq - ) - kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) - kv_last_page_len = torch.full( - (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0" - ) - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0") - - wrapper_tcore = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, kv_layout, use_tensor_cores=True - ) - wrapper_tcore.plan( - kv_indptr, - kv_indices, - kv_last_page_len, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - pos_encoding_mode=pos_encoding_mode, - data_type=torch.float16, - q_data_type=torch.float16, - fixed_split_size=fixed_split_size if not disable_split_kv else None, - disable_split_kv=disable_split_kv, - ) - o_tensor_cores, lse_tensor_cores = wrapper_tcore.run(q, kv_data, return_lse=True) - - kv_indptr_invariant = kv_indptr[: invariant_bs + 1] - kv_last_page_len_invariant = kv_last_page_len[:invariant_bs] - wrapper_tcore.plan( - kv_indptr_invariant, - kv_indices, - kv_last_page_len_invariant, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - pos_encoding_mode=pos_encoding_mode, - data_type=torch.float16, - q_data_type=torch.float16, - fixed_split_size=fixed_split_size if not disable_split_kv else None, - disable_split_kv=disable_split_kv, - ) - o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tcore.run( - q[:invariant_bs], kv_data, return_lse=True - ) - assert torch.equal(o_tensor_cores[:invariant_bs], o_tensor_cores_invariant) - assert torch.equal(lse_tensor_cores[:invariant_bs], lse_tensor_cores_invariant) - - # test that without fixed split size, precision is different - # TODO: this works for the first 29 cases, but then fails with "illegal memory access"..? - - # wrapper_tcore.plan( - # kv_indptr, - # kv_indices, - # kv_last_page_len, - # num_qo_heads, - # num_kv_heads, - # head_dim, - # page_size, - # ) - # o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tcore.run( - # q[:invariant_bs], kv_data, return_lse=True - # ) - # try: - # torch.testing.assert_close( - # o_tensor_cores[:invariant_bs], o_tensor_cores_invariant, rtol=1e-7, atol=1e-7 - # ) - # torch.testing.assert_close( - # lse_tensor_cores[:invariant_bs], lse_tensor_cores_invariant, rtol=1e-7, atol=1e-7 - # ) - # except AssertionError: - # pass - # else: - # raise AssertionError("Precision is the same without fixed split size") From 6e29284bbfc9208826ae9be370faf1c4b9fa6e75 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 17 Sep 2025 14:54:14 +0000 Subject: [PATCH 3/5] upd --- tests/test_batch_invariant_fa2.py | 204 +++++++++++++++--------------- 1 file changed, 101 insertions(+), 103 deletions(-) diff --git a/tests/test_batch_invariant_fa2.py b/tests/test_batch_invariant_fa2.py index 7959c12b78..63c9ede26e 100644 --- a/tests/test_batch_invariant_fa2.py +++ b/tests/test_batch_invariant_fa2.py @@ -46,112 +46,110 @@ def warmup_jit(): yield -# @pytest.mark.parametrize("batch_size", [5, 12]) -# @pytest.mark.parametrize("invariant_bs", [4]) -# @pytest.mark.parametrize("kv_len", [4096, 8192, 5000]) -# @pytest.mark.parametrize("fixed_split_size", [2048]) -# @pytest.mark.parametrize("disable_split_kv", [True, False]) -# @pytest.mark.parametrize("page_size", [1, 8, 16]) -# @pytest.mark.parametrize("num_kv_heads", [4]) -# @pytest.mark.parametrize("group_size", [1, 4, 8]) -# @pytest.mark.parametrize("head_dim", [128, 256]) -# @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) -# @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"]) -# def test_batch_decode_tensor_cores( -# batch_size: int, -# invariant_bs: int, -# kv_len: int, -# fixed_split_size: int, -# disable_split_kv: bool, -# page_size: int, -# num_kv_heads: int, -# group_size: int, -# head_dim: int, -# kv_layout: str, -# pos_encoding_mode: str, -# ): -# num_qo_heads = num_kv_heads * group_size -# q = torch.randn( -# batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16 -# ) -# num_pages_per_seq = (kv_len + page_size - 1) // page_size -# total_num_pages = num_pages_per_seq * batch_size -# kv_data = ( -# torch.randn( -# total_num_pages, -# 2, -# num_kv_heads, -# page_size, -# head_dim, -# device="cuda:0", -# dtype=torch.float16, -# ) -# / 10 -# if kv_layout == "HND" -# else torch.randn( -# total_num_pages, -# 2, -# page_size, -# num_kv_heads, -# head_dim, -# device="cuda:0", -# dtype=torch.float16, -# ) -# / 10 -# ) -# kv_indptr = ( -# torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) -# * num_pages_per_seq -# ) -# kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) -# kv_last_page_len = torch.full( -# (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0" -# ) +@pytest.mark.parametrize("batch_size", [5, 12]) +@pytest.mark.parametrize("invariant_bs", [4]) +@pytest.mark.parametrize("kv_len", [4096, 8192, 5000]) +@pytest.mark.parametrize("fixed_split_size", [2048]) +@pytest.mark.parametrize("disable_split_kv", [True, False]) +@pytest.mark.parametrize("page_size", [1, 8, 16]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("group_size", [1, 4, 8]) +@pytest.mark.parametrize("head_dim", [128, 256]) +@pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) +@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"]) +def test_batch_decode_tensor_cores( + batch_size: int, + invariant_bs: int, + kv_len: int, + fixed_split_size: int, + disable_split_kv: bool, + page_size: int, + num_kv_heads: int, + group_size: int, + head_dim: int, + kv_layout: str, + pos_encoding_mode: str, +): + num_qo_heads = num_kv_heads * group_size + q = torch.randn( + batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16 + ) + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + kv_data = ( + torch.randn( + total_num_pages, + 2, + num_kv_heads, + page_size, + head_dim, + device="cuda:0", + dtype=torch.float16, + ) + / 10 + if kv_layout == "HND" + else torch.randn( + total_num_pages, + 2, + page_size, + num_kv_heads, + head_dim, + device="cuda:0", + dtype=torch.float16, + ) + / 10 + ) + kv_indptr = ( + torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) + * num_pages_per_seq + ) + kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) + kv_last_page_len = torch.full( + (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0" + ) -# workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0") -# wrapper_tcore = flashinfer.BatchDecodeWithPagedKVCacheWrapper( -# workspace_buffer, kv_layout, use_tensor_cores=True -# ) -# wrapper_tcore.plan( -# kv_indptr, -# kv_indices, -# kv_last_page_len, -# num_qo_heads, -# num_kv_heads, -# head_dim, -# page_size, -# pos_encoding_mode=pos_encoding_mode, -# data_type=torch.float16, -# q_data_type=torch.float16, -# fixed_split_size=fixed_split_size if not disable_split_kv else None, -# disable_split_kv=disable_split_kv, -# ) -# o_tensor_cores, lse_tensor_cores = wrapper_tcore.run( -# q, kv_data, return_lse=True -# ) + wrapper_tcore = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, kv_layout, use_tensor_cores=True + ) + wrapper_tcore.plan( + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + pos_encoding_mode=pos_encoding_mode, + data_type=torch.float16, + q_data_type=torch.float16, + fixed_split_size=fixed_split_size if not disable_split_kv else None, + disable_split_kv=disable_split_kv, + ) + o_tensor_cores, lse_tensor_cores = wrapper_tcore.run(q, kv_data, return_lse=True) -# kv_indptr_invariant = kv_indptr[: invariant_bs + 1] -# kv_last_page_len_invariant = kv_last_page_len[:invariant_bs] -# wrapper_tcore.plan( -# kv_indptr_invariant, -# kv_indices, -# kv_last_page_len_invariant, -# num_qo_heads, -# num_kv_heads, -# head_dim, -# page_size, -# pos_encoding_mode=pos_encoding_mode, -# data_type=torch.float16, -# q_data_type=torch.float16, -# fixed_split_size=fixed_split_size if not disable_split_kv else None, -# disable_split_kv=disable_split_kv, -# ) -# o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tcore.run( -# q[:invariant_bs], kv_data, return_lse=True -# ) -# assert torch.equal(o_tensor_cores[:invariant_bs], o_tensor_cores_invariant) -# assert torch.equal(lse_tensor_cores[:invariant_bs], lse_tensor_cores_invariant) + kv_indptr_invariant = kv_indptr[: invariant_bs + 1] + kv_last_page_len_invariant = kv_last_page_len[:invariant_bs] + wrapper_tcore.plan( + kv_indptr_invariant, + kv_indices, + kv_last_page_len_invariant, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + pos_encoding_mode=pos_encoding_mode, + data_type=torch.float16, + q_data_type=torch.float16, + fixed_split_size=fixed_split_size if not disable_split_kv else None, + disable_split_kv=disable_split_kv, + ) + o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tcore.run( + q[:invariant_bs], kv_data, return_lse=True + ) + assert torch.equal(o_tensor_cores[:invariant_bs], o_tensor_cores_invariant) + assert torch.equal(lse_tensor_cores[:invariant_bs], lse_tensor_cores_invariant) # test that without fixed split size, precision is different From 3a712a34e30825123e31de6ef6085187a22103f5 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 18 Sep 2025 01:06:33 +0000 Subject: [PATCH 4/5] add missing plan before capture --- include/flashinfer/attention/scheduler.cuh | 2 - tests/test_batch_invariant_fa2.py | 134 +++++++++------------ 2 files changed, 55 insertions(+), 81 deletions(-) diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index dd163bf272..ea83b25850 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -601,8 +601,6 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, } const size_t padded_batch_size = enable_cuda_graph ? std::max(max_batch_size_if_split, total_num_tiles_q) : new_batch_size; - printf("new_batch_size: %d, total_num_tiles_q: %d, padded_batch_size: %d, kv_chunk_size: %ld\n", - new_batch_size, total_num_tiles_q, padded_batch_size, kv_chunk_size); FLASHINFER_CHECK(new_batch_size <= padded_batch_size, "new batch size should not exceed padded batch size. If you are using fixed " "split size, please consider disabling cuda graph."); diff --git a/tests/test_batch_invariant_fa2.py b/tests/test_batch_invariant_fa2.py index 63c9ede26e..76b65932e8 100644 --- a/tests/test_batch_invariant_fa2.py +++ b/tests/test_batch_invariant_fa2.py @@ -145,11 +145,11 @@ def test_batch_decode_tensor_cores( fixed_split_size=fixed_split_size if not disable_split_kv else None, disable_split_kv=disable_split_kv, ) - o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tcore.run( + o_invariant, lse_invariant = wrapper_tcore.run( q[:invariant_bs], kv_data, return_lse=True ) - assert torch.equal(o_tensor_cores[:invariant_bs], o_tensor_cores_invariant) - assert torch.equal(lse_tensor_cores[:invariant_bs], lse_tensor_cores_invariant) + assert torch.equal(o_tensor_cores[:invariant_bs], o_invariant) + assert torch.equal(lse_tensor_cores[:invariant_bs], lse_invariant) # test that without fixed split size, precision is different @@ -164,15 +164,15 @@ def test_batch_decode_tensor_cores( # head_dim, # page_size, # ) -# o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tcore.run( +# o_invariant, lse_invariant = wrapper_tcore.run( # q[:invariant_bs], kv_data, return_lse=True # ) # try: # torch.testing.assert_close( -# o_tensor_cores[:invariant_bs], o_tensor_cores_invariant, rtol=1e-7, atol=1e-7 +# o_tensor_cores[:invariant_bs], o_invariant, rtol=1e-7, atol=1e-7 # ) # torch.testing.assert_close( -# lse_tensor_cores[:invariant_bs], lse_tensor_cores_invariant, rtol=1e-7, atol=1e-7 +# lse_tensor_cores[:invariant_bs], lse_invariant, rtol=1e-7, atol=1e-7 # ) # except AssertionError: # pass @@ -258,104 +258,80 @@ def test_batch_prefill_tensor_cores( wrapper_tcore = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout ) - wrapper_tcore.plan( - q_indptr, - kv_indptr, - kv_indices, - kv_last_page_len, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - pos_encoding_mode=pos_encoding_mode, - q_data_type=torch.float16, - kv_data_type=torch.float16, - fixed_split_size=fixed_split_size if not disable_split_kv else None, - disable_split_kv=disable_split_kv, - ) + + def default_plan(): + wrapper_tcore.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + pos_encoding_mode=pos_encoding_mode, + q_data_type=torch.float16, + kv_data_type=torch.float16, + fixed_split_size=fixed_split_size if not disable_split_kv else None, + disable_split_kv=disable_split_kv, + ) + + def invariant_plan(): + wrapper_tcore.plan( + q_indptr_invariant, + kv_indptr_invariant, + kv_indices, + kv_last_page_len_invariant, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + pos_encoding_mode=pos_encoding_mode, + q_data_type=torch.float16, + kv_data_type=torch.float16, + fixed_split_size=fixed_split_size if not disable_split_kv else None, + disable_split_kv=disable_split_kv, + ) + + default_plan() o_tensor_cores, lse_tensor_cores = wrapper_tcore.run(q, kv_data, return_lse=True) # Test invariant batch size q_indptr_invariant = q_indptr[: invariant_bs + 1] kv_indptr_invariant = kv_indptr[: invariant_bs + 1] kv_last_page_len_invariant = kv_last_page_len[:invariant_bs] - wrapper_tcore.plan( - q_indptr_invariant, - kv_indptr_invariant, - kv_indices, - kv_last_page_len_invariant, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - pos_encoding_mode=pos_encoding_mode, - q_data_type=torch.float16, - kv_data_type=torch.float16, - fixed_split_size=fixed_split_size if not disable_split_kv else None, - disable_split_kv=disable_split_kv, - ) - o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tcore.run( + + invariant_plan() + o_invariant, lse_invariant = wrapper_tcore.run( q[: invariant_bs * qo_len], kv_data, return_lse=True ) # Compare outputs for the invariant batch size - assert torch.equal( - o_tensor_cores[: invariant_bs * qo_len], o_tensor_cores_invariant - ) - assert torch.equal( - lse_tensor_cores[: invariant_bs * qo_len], lse_tensor_cores_invariant - ) + assert torch.equal(o_tensor_cores[: invariant_bs * qo_len], o_invariant) + assert torch.equal(lse_tensor_cores[: invariant_bs * qo_len], lse_invariant) if disable_split_kv: # Test cuda graph - del o_tensor_cores_invariant, lse_tensor_cores_invariant + del o_invariant, lse_invariant g_invariant = torch.cuda.CUDAGraph() + invariant_plan() with torch.cuda.graph(g_invariant): - o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tcore.run( + o_invariant, lse_invariant = wrapper_tcore.run( q[: invariant_bs * qo_len], kv_data, return_lse=True ) - wrapper_tcore.plan( - q_indptr_invariant, - kv_indptr_invariant, - kv_indices, - kv_last_page_len_invariant, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - pos_encoding_mode=pos_encoding_mode, - q_data_type=torch.float16, - kv_data_type=torch.float16, - disable_split_kv=True, - ) + invariant_plan() g_invariant.replay() # capture for full batch del o_tensor_cores, lse_tensor_cores g = torch.cuda.CUDAGraph() + default_plan() with torch.cuda.graph(g): o_tensor_cores, lse_tensor_cores = wrapper_tcore.run( q, kv_data, return_lse=True ) - wrapper_tcore.plan( - q_indptr, - kv_indptr, - kv_indices, - kv_last_page_len, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - pos_encoding_mode=pos_encoding_mode, - q_data_type=torch.float16, - kv_data_type=torch.float16, - disable_split_kv=True, - ) + default_plan() g.replay() # compare outputs - assert torch.equal( - o_tensor_cores_invariant, o_tensor_cores[: invariant_bs * qo_len] - ) - assert torch.equal( - lse_tensor_cores_invariant, lse_tensor_cores[: invariant_bs * qo_len] - ) + assert torch.equal(o_invariant, o_tensor_cores[: invariant_bs * qo_len]) + assert torch.equal(lse_invariant, lse_tensor_cores[: invariant_bs * qo_len]) From 83b4abd6446d09ef5c1703dba2f78515adc92065 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Sat, 20 Sep 2025 03:23:12 +0000 Subject: [PATCH 5/5] add chunked prefill tests (not passing now, can leave for future pr) --- tests/test_batch_invariant_fa2.py | 292 ++++++++++++++++++------------ 1 file changed, 176 insertions(+), 116 deletions(-) diff --git a/tests/test_batch_invariant_fa2.py b/tests/test_batch_invariant_fa2.py index 76b65932e8..6928895a8e 100644 --- a/tests/test_batch_invariant_fa2.py +++ b/tests/test_batch_invariant_fa2.py @@ -25,16 +25,16 @@ def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes + [torch.bfloat16], # q_dtypes + [torch.bfloat16], # kv_dtypes [64, 128, 256], # head_dims [0, 1], # pos_encoding_modes [False], # use_sliding_windows [False], # use_logits_soft_caps ) + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes + [torch.bfloat16], # q_dtypes + [torch.bfloat16], # kv_dtypes [64, 128, 256], # head_dims [0, 1], # pos_encoding_modes [False], # use_sliding_windows @@ -46,110 +46,110 @@ def warmup_jit(): yield -@pytest.mark.parametrize("batch_size", [5, 12]) -@pytest.mark.parametrize("invariant_bs", [4]) -@pytest.mark.parametrize("kv_len", [4096, 8192, 5000]) -@pytest.mark.parametrize("fixed_split_size", [2048]) -@pytest.mark.parametrize("disable_split_kv", [True, False]) -@pytest.mark.parametrize("page_size", [1, 8, 16]) -@pytest.mark.parametrize("num_kv_heads", [4]) -@pytest.mark.parametrize("group_size", [1, 4, 8]) -@pytest.mark.parametrize("head_dim", [128, 256]) -@pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) -@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"]) -def test_batch_decode_tensor_cores( - batch_size: int, - invariant_bs: int, - kv_len: int, - fixed_split_size: int, - disable_split_kv: bool, - page_size: int, - num_kv_heads: int, - group_size: int, - head_dim: int, - kv_layout: str, - pos_encoding_mode: str, -): - num_qo_heads = num_kv_heads * group_size - q = torch.randn( - batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16 - ) - num_pages_per_seq = (kv_len + page_size - 1) // page_size - total_num_pages = num_pages_per_seq * batch_size - kv_data = ( - torch.randn( - total_num_pages, - 2, - num_kv_heads, - page_size, - head_dim, - device="cuda:0", - dtype=torch.float16, - ) - / 10 - if kv_layout == "HND" - else torch.randn( - total_num_pages, - 2, - page_size, - num_kv_heads, - head_dim, - device="cuda:0", - dtype=torch.float16, - ) - / 10 - ) - kv_indptr = ( - torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) - * num_pages_per_seq - ) - kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) - kv_last_page_len = torch.full( - (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0" - ) +# @pytest.mark.parametrize("batch_size", [5, 12]) +# @pytest.mark.parametrize("invariant_bs", [4]) +# @pytest.mark.parametrize("kv_len", [4096, 8192, 5000]) +# @pytest.mark.parametrize("fixed_split_size", [2048]) +# @pytest.mark.parametrize("disable_split_kv", [True, False]) +# @pytest.mark.parametrize("page_size", [1, 8, 16]) +# @pytest.mark.parametrize("num_kv_heads", [4]) +# @pytest.mark.parametrize("group_size", [1, 4, 8]) +# @pytest.mark.parametrize("head_dim", [128, 256]) +# @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) +# @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"]) +# def test_batch_decode_tensor_cores( +# batch_size: int, +# invariant_bs: int, +# kv_len: int, +# fixed_split_size: int, +# disable_split_kv: bool, +# page_size: int, +# num_kv_heads: int, +# group_size: int, +# head_dim: int, +# kv_layout: str, +# pos_encoding_mode: str, +# ): +# num_qo_heads = num_kv_heads * group_size +# q = torch.randn( +# batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=torch.bfloat16 +# ) +# num_pages_per_seq = (kv_len + page_size - 1) // page_size +# total_num_pages = num_pages_per_seq * batch_size +# kv_data = ( +# torch.randn( +# total_num_pages, +# 2, +# num_kv_heads, +# page_size, +# head_dim, +# device="cuda:0", +# dtype=torch.bfloat16, +# ) +# / 10 +# if kv_layout == "HND" +# else torch.randn( +# total_num_pages, +# 2, +# page_size, +# num_kv_heads, +# head_dim, +# device="cuda:0", +# dtype=torch.bfloat16, +# ) +# / 10 +# ) +# kv_indptr = ( +# torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) +# * num_pages_per_seq +# ) +# kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) +# kv_last_page_len = torch.full( +# (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0" +# ) - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0") +# workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0") - wrapper_tcore = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, kv_layout, use_tensor_cores=True - ) - wrapper_tcore.plan( - kv_indptr, - kv_indices, - kv_last_page_len, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - pos_encoding_mode=pos_encoding_mode, - data_type=torch.float16, - q_data_type=torch.float16, - fixed_split_size=fixed_split_size if not disable_split_kv else None, - disable_split_kv=disable_split_kv, - ) - o_tensor_cores, lse_tensor_cores = wrapper_tcore.run(q, kv_data, return_lse=True) +# wrapper_tcore = flashinfer.BatchDecodeWithPagedKVCacheWrapper( +# workspace_buffer, kv_layout, use_tensor_cores=True +# ) +# wrapper_tcore.plan( +# kv_indptr, +# kv_indices, +# kv_last_page_len, +# num_qo_heads, +# num_kv_heads, +# head_dim, +# page_size, +# pos_encoding_mode=pos_encoding_mode, +# data_type=torch.bfloat16, +# q_data_type=torch.bfloat16, +# fixed_split_size=fixed_split_size if not disable_split_kv else None, +# disable_split_kv=disable_split_kv, +# ) +# o_tensor_cores, lse_tensor_cores = wrapper_tcore.run(q, kv_data, return_lse=True) - kv_indptr_invariant = kv_indptr[: invariant_bs + 1] - kv_last_page_len_invariant = kv_last_page_len[:invariant_bs] - wrapper_tcore.plan( - kv_indptr_invariant, - kv_indices, - kv_last_page_len_invariant, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - pos_encoding_mode=pos_encoding_mode, - data_type=torch.float16, - q_data_type=torch.float16, - fixed_split_size=fixed_split_size if not disable_split_kv else None, - disable_split_kv=disable_split_kv, - ) - o_invariant, lse_invariant = wrapper_tcore.run( - q[:invariant_bs], kv_data, return_lse=True - ) - assert torch.equal(o_tensor_cores[:invariant_bs], o_invariant) - assert torch.equal(lse_tensor_cores[:invariant_bs], lse_invariant) +# kv_indptr_invariant = kv_indptr[: invariant_bs + 1] +# kv_last_page_len_invariant = kv_last_page_len[:invariant_bs] +# wrapper_tcore.plan( +# kv_indptr_invariant, +# kv_indices, +# kv_last_page_len_invariant, +# num_qo_heads, +# num_kv_heads, +# head_dim, +# page_size, +# pos_encoding_mode=pos_encoding_mode, +# data_type=torch.bfloat16, +# q_data_type=torch.bfloat16, +# fixed_split_size=fixed_split_size if not disable_split_kv else None, +# disable_split_kv=disable_split_kv, +# ) +# o_invariant, lse_invariant = wrapper_tcore.run( +# q[:invariant_bs], kv_data, return_lse=True +# ) +# assert torch.equal(o_tensor_cores[:invariant_bs], o_invariant) +# assert torch.equal(lse_tensor_cores[:invariant_bs], lse_invariant) # test that without fixed split size, precision is different @@ -183,12 +183,13 @@ def test_batch_decode_tensor_cores( @pytest.mark.parametrize("batch_size", [3, 4]) @pytest.mark.parametrize("invariant_bs", [2]) @pytest.mark.parametrize("kv_len", [4096, 5000]) -@pytest.mark.parametrize("qo_len", [128, 256]) -@pytest.mark.parametrize("fixed_split_size", [2048]) +# @pytest.mark.parametrize("qo_len", [128, 256]) +@pytest.mark.parametrize("qo_len", [2048]) +@pytest.mark.parametrize("fixed_split_size", [4096]) @pytest.mark.parametrize("disable_split_kv", [True, False]) @pytest.mark.parametrize("page_size", [1, 8, 16]) @pytest.mark.parametrize("num_kv_heads", [4]) -@pytest.mark.parametrize("group_size", [1, 4, 8]) +@pytest.mark.parametrize("group_size", [1, 4]) @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"]) @@ -212,7 +213,7 @@ def test_batch_prefill_tensor_cores( num_qo_heads, head_dim, device="cuda:0", - dtype=torch.float16, + dtype=torch.bfloat16, ) q_indptr = ( torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len @@ -227,7 +228,7 @@ def test_batch_prefill_tensor_cores( page_size, head_dim, device="cuda:0", - dtype=torch.float16, + dtype=torch.bfloat16, ) / 10 if kv_layout == "HND" @@ -238,7 +239,7 @@ def test_batch_prefill_tensor_cores( num_kv_heads, head_dim, device="cuda:0", - dtype=torch.float16, + dtype=torch.bfloat16, ) / 10 ) @@ -252,7 +253,7 @@ def test_batch_prefill_tensor_cores( ) workspace_buffer = torch.empty( - 1024 * 1024 * 1024, dtype=torch.int8, device="cuda:0" + 2048 * 1024 * 1024, dtype=torch.int8, device="cuda:0" ) wrapper_tcore = flashinfer.BatchPrefillWithPagedKVCacheWrapper( @@ -270,8 +271,8 @@ def default_plan(): head_dim, page_size, pos_encoding_mode=pos_encoding_mode, - q_data_type=torch.float16, - kv_data_type=torch.float16, + q_data_type=torch.bfloat16, + kv_data_type=torch.bfloat16, fixed_split_size=fixed_split_size if not disable_split_kv else None, disable_split_kv=disable_split_kv, ) @@ -287,8 +288,8 @@ def invariant_plan(): head_dim, page_size, pos_encoding_mode=pos_encoding_mode, - q_data_type=torch.float16, - kv_data_type=torch.float16, + q_data_type=torch.bfloat16, + kv_data_type=torch.bfloat16, fixed_split_size=fixed_split_size if not disable_split_kv else None, disable_split_kv=disable_split_kv, ) @@ -335,3 +336,62 @@ def invariant_plan(): # compare outputs assert torch.equal(o_invariant, o_tensor_cores[: invariant_bs * qo_len]) assert torch.equal(lse_invariant, lse_tensor_cores[: invariant_bs * qo_len]) + + # Test chunked prefill with 1024 chunk size using invariant batch size + chunk_size = 1024 + chunked_outputs = [] + chunked_lses = [] + + for chunk_start in range(0, qo_len, chunk_size): + chunk_end = min(chunk_start + chunk_size, qo_len) + current_chunk_size = chunk_end - chunk_start + + # Create chunked q_indptr for invariant batch size + q_indptr_chunk = ( + torch.arange(0, invariant_bs + 1, device="cuda:0", dtype=torch.int32) + * current_chunk_size + ) + + # Extract chunked q data for invariant batch size + q_chunk = torch.empty( + invariant_bs * current_chunk_size, + num_qo_heads, + head_dim, + device="cuda:0", + dtype=torch.bfloat16, + ) + for i in range(invariant_bs): + start_idx = i * qo_len + chunk_start + end_idx = i * qo_len + chunk_end + chunk_start_idx = i * current_chunk_size + chunk_end_idx = (i + 1) * current_chunk_size + q_chunk[chunk_start_idx:chunk_end_idx] = q[start_idx:end_idx] + + # Plan and run for this chunk using invariant batch size + wrapper_tcore.plan( + q_indptr_chunk, + kv_indptr_invariant, + kv_indices, + kv_last_page_len_invariant, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + pos_encoding_mode=pos_encoding_mode, + q_data_type=torch.bfloat16, + kv_data_type=torch.bfloat16, + fixed_split_size=fixed_split_size if not disable_split_kv else None, + disable_split_kv=disable_split_kv, + ) + o_chunk, lse_chunk = wrapper_tcore.run(q_chunk, kv_data, return_lse=True) + + chunked_outputs.append(o_chunk) + chunked_lses.append(lse_chunk) + + # Concatenate all chunked results + o_chunked = torch.cat(chunked_outputs, dim=0) + lse_chunked = torch.cat(chunked_lses, dim=0) + + # Compare chunked results with invariant results + assert torch.equal(o_tensor_cores[: invariant_bs * qo_len], o_chunked) + assert torch.equal(lse_tensor_cores[: invariant_bs * qo_len], lse_chunked)