feat: Add k_scale and v_scale to persistent attention #1322
Changes from all commits
@@ -90,6 +90,7 @@ def _run_attention(
     num_kv_heads=1,
     num_qo_heads=1,
     head_dim=128,
+    v_scale=None,
     layout="NHD",
     test_dtype=torch.bfloat16,
     logits_soft_cap=0.0,
@@ -140,7 +141,7 @@ def _run_attention(
     # --------- old scheduler --------- #
     wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
-        torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=dev),
+        torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=dev),
         kv_layout=layout,
         backend="fa2",
     )

Review comment: When running the test, the prefill kernel allocator hit OOM, which seems unrelated to this PR (it also reproduces on main), so I increased the buffer size.
Reply: No problem, please go ahead with it.
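For context, the wrapper's first constructor argument is a caller-allocated workspace buffer, and the 256 MB size above is an empirical choice rather than an API requirement. A minimal sketch of the allocation pattern used in the test follows; the device string and variable names are illustrative.

```python
import torch
import flashinfer

dev = "cuda:0"  # illustrative device
# Caller-provided uint8 workspace buffer; sized empirically
# (128 MB hit OOM in this test, so it was bumped to 256 MB).
workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=dev)
wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer,
    kv_layout="NHD",
    backend="fa2",
)
```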
@@ -159,7 +160,7 @@ def _run_attention(
         kv_data_type=test_dtype,
         logits_soft_cap=logits_soft_cap,
     )
-    out_old, lse_old = wrapper_old.run(q, kv_data, return_lse=True)
+    out_old, lse_old = wrapper_old.run(q, kv_data, return_lse=True, v_scale=v_scale)

     # --------- new / mixed scheduler --------- #
     wrapper = flashinfer.BatchAttention(kv_layout=layout)
@@ -178,7 +179,9 @@ def _run_attention(
         kv_data_type=test_dtype,
         logits_soft_cap=logits_soft_cap,
     )
-    out_new, lse_new = wrapper.run(q, kv_data, logits_soft_cap=logits_soft_cap)
+    out_new, lse_new = wrapper.run(
+        q, kv_data, v_scale=v_scale, logits_soft_cap=logits_soft_cap
+    )

     torch.cuda.synchronize()
     torch.testing.assert_close(out_old, out_new, rtol=1e-2, atol=1e-2)
@@ -191,6 +194,7 @@ def _run_attention(
 @pytest.mark.parametrize("num_kv_heads", [1, 4])
 @pytest.mark.parametrize("gqa_group_size", [1, 4, 7, 8])
 @pytest.mark.parametrize("head_dim", [64, 128, 256])
+@pytest.mark.parametrize("v_scale", [2.0, None])
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("layout", ["HND", "NHD"])
 @pytest.mark.parametrize("test_dtype", [torch.bfloat16, torch.float16])
@@ -201,6 +205,7 @@ def test_batch_attention_correctness(
     num_kv_heads,
     gqa_group_size,
     head_dim,
+    v_scale,
     causal,
     layout,
     test_dtype,
@@ -217,6 +222,7 @@ def test_batch_attention_correctness(
         num_kv_heads=num_kv_heads,
         num_qo_heads=num_qo_heads,
         head_dim=head_dim,
+        v_scale=v_scale,
         causal=causal,
         layout=layout,
         test_dtype=test_dtype,
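For readers unfamiliar with the parameter: v_scale is a per-tensor scale applied to the value cache (typically used to fold dequantization of a quantized KV cache into the kernel). Because attention output is linear in V, scaling V by v_scale is equivalent to scaling the unscaled output by v_scale, which is the property the parity test above exercises with v_scale=2.0. The sketch below is a plain-PyTorch illustration of that identity, not code from this PR; the naive_attention helper and shapes are assumptions.

```python
import torch

def naive_attention(q, k, v, v_scale=None):
    # q: [heads, q_len, head_dim]; k, v: [heads, kv_len, head_dim]
    sm_scale = q.shape[-1] ** -0.5
    scores = torch.einsum("hqd,hkd->hqk", q.float(), k.float()) * sm_scale
    p = torch.softmax(scores, dim=-1)
    v = v.float() if v_scale is None else v.float() * v_scale
    return torch.einsum("hqk,hkd->hqd", p, v)

q = torch.randn(4, 8, 128)   # [heads, q_len, head_dim]
k = torch.randn(4, 16, 128)  # [heads, kv_len, head_dim]
v = torch.randn(4, 16, 128)

# Scaling V inside the kernel equals scaling the output outside it.
torch.testing.assert_close(
    naive_attention(q, k, v, v_scale=2.0),
    2.0 * naive_attention(q, k, v),
)
```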
Review comment: We should add the option of keeping this value (v_scale) and sm_scale on device, so that we support dynamic scales in CUDA Graph mode (e.g. in #1630); we can leave that for future PRs anyway.
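To illustrate the point about CUDA Graph compatibility: a Python-float scale is baked into the kernel launch at capture time, whereas a scale kept in a device tensor can be overwritten between graph replays. The sketch below is purely illustrative, using a toy run_step stand-in rather than the flashinfer API (which, as noted, does not take device-resident scales here yet).

```python
import torch

def run_step(out, x, v_scale_buf):
    # Toy stand-in for an attention launch that reads v_scale from device memory.
    out.copy_(x * v_scale_buf)

x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
v_scale_buf = torch.tensor(1.0, device="cuda")  # dynamic scale lives on device

# Warm up on a side stream, then capture the launch into a CUDA graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    run_step(out, x, v_scale_buf)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    run_step(out, x, v_scale_buf)

# The replayed graph re-reads v_scale_buf, so the scale can change without
# re-capturing; a Python float argument would have been frozen at capture time.
v_scale_buf.fill_(2.0)
g.replay()
torch.cuda.synchronize()
assert torch.allclose(out, 2.0 * x)
```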