5 changes: 4 additions & 1 deletion csrc/batch_attention.cu
@@ -68,7 +68,9 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
                             at::Tensor v_cache, at::Tensor kv_indices, at::Tensor o,
                             std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
                             int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
-                            int64_t page_size, double sm_scale,
+                            int64_t page_size,
+                            double v_scale,  // must use double due to pytorch binding

Collaborator commented: We should add the option of making this value (v_scale) and sm_scale available on device, so that dynamic scales are supported in CUDAGraph mode (e.g. in #1630); we can leave it for future PRs. (A sketch of this concern follows this file's diff.)

+                            double sm_scale,
                             double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS) {
   HolisticPlanInfo<2> plan_info;
   plan_info.FromVector(tensor_to_vec(plan_info_vec));
@@ -171,6 +173,7 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
     params[i].v_stride_n = v_stride_n;

     params[i].sm_scale = sm_scale;
+    params[i].v_scale = v_scale;
     params[i].logits_soft_cap = logits_soft_cap;
     // NOTE(Wenxuan) directly using the additional_params_decl from generate_additional_params
     // will be problematic because of the params[i]
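Side note on the review comment above: the concern is that a scale passed as a host-side scalar gets baked into a captured CUDA graph, whereas a device-resident tensor can be updated in place between replays. The snippet below is a minimal, hypothetical PyTorch illustration of that difference (it is not FlashInfer code; it only assumes a CUDA-capable GPU and the standard torch.cuda.CUDAGraph API).

```python
import torch

x = torch.ones(4, device="cuda")
scale_host = 2.0                              # Python float: captured as a constant
scale_dev = torch.tensor(2.0, device="cuda")  # device tensor: read from memory at replay
out_host = torch.empty_like(x)
out_dev = torch.empty_like(x)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out_host.copy_(x * scale_host)  # the value 2.0 is frozen into the graph
    out_dev.copy_(x * scale_dev)    # reads scale_dev's storage on every replay

scale_dev.fill_(3.0)  # update the device-side scale in place
g.replay()
torch.cuda.synchronize()
print(out_host)  # still scaled by 2.0
print(out_dev)   # now scaled by 3.0
```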
3 changes: 2 additions & 1 deletion csrc/batch_attention_customize_config.jinja
@@ -109,7 +109,8 @@ struct PersistentParams {
   uint32_t v_stride_n;

   float sm_scale;
-  double logits_soft_cap;
+  float logits_soft_cap;
+  float v_scale;
   {{ additional_params_decl }}

   PROFILER_PARAMS_DECL
2 changes: 1 addition & 1 deletion csrc/batch_attention_jit_pybind.cu
@@ -28,7 +28,7 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
                             at::Tensor v_cache, at::Tensor kv_indices, at::Tensor o,
                             std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
                             int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
-                            int64_t page_size, double sm_scale,
+                            int64_t page_size, double v_scale, double sm_scale,
                             double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS);

 TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
16 changes: 11 additions & 5 deletions flashinfer/attention.py
@@ -109,7 +109,6 @@ def plan(
         self._num_qo_heads = num_qo_heads
         self._num_kv_heads = num_kv_heads
         self._page_size = page_size
-        self._sm_scale = sm_scale
         self._use_profiler = use_profiler

         # No addtional buf allocated for CUDA graph tensor
@@ -135,6 +134,8 @@ def run(
         kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
         out: Optional[torch.Tensor] = None,
         lse: Optional[torch.Tensor] = None,
+        k_scale: Optional[torch.Tensor] = None,
+        v_scale: Optional[torch.Tensor] = None,
         logits_soft_cap: float = 0.0,
         profiler_buffer: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -157,9 +158,13 @@
                 q.shape[0], q.shape[1], device=q.device, dtype=torch.float32
             )
         head_dim_qk = q.shape[2]
-        if self._sm_scale is None:
-            self._sm_scale = 1.0 / math.sqrt(head_dim_qk)
-
+        sm_scale = self._sm_scale
+        if sm_scale is None:
+            sm_scale = 1.0 / math.sqrt(head_dim_qk)
+        if k_scale is not None:
+            sm_scale *= k_scale
+        if v_scale is None:
+            v_scale = 1.0
         # profiler_buffer is optional
         profiler_args = (profiler_buffer,) if self._use_profiler else ()

@@ -178,7 +183,8 @@
             self._num_qo_heads,
             self._num_kv_heads,
             self._page_size,
-            self._sm_scale,
+            v_scale,
+            sm_scale,
             logits_soft_cap,
             # ADDITIONAL_FUNC_PARAMS
             # PROFILER_FUNC_PARAMS
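For reference, the sketch below mirrors the scale handling that run() now performs, written as a plain dense-attention function (hypothetical, not FlashInfer code; paged KV layout, GQA, and masking are ignored): sm_scale defaults to 1/sqrt(head_dim_qk), k_scale is folded into sm_scale, and v_scale (default 1.0) rescales the attention output.

```python
import math
import torch

def reference_scaled_attention(q, k, v, sm_scale=None, k_scale=None, v_scale=None):
    # q, k, v: [num_heads, seq_len, head_dim]
    head_dim_qk = q.shape[-1]
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim_qk)
    if k_scale is not None:
        sm_scale *= k_scale          # k_scale folds into the softmax scale
    if v_scale is None:
        v_scale = 1.0
    scores = (q @ k.transpose(-2, -1)) * sm_scale
    probs = torch.softmax(scores, dim=-1)
    return (probs @ v) * v_scale     # v_scale rescales the normalized output

q = torch.randn(2, 8, 64)
k = torch.randn(2, 8, 64)
v = torch.randn(2, 8, 64)
out = reference_scaled_attention(q, k, v, k_scale=0.5, v_scale=2.0)
```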
8 changes: 6 additions & 2 deletions include/flashinfer/attention/persistent.cuh
@@ -140,7 +140,8 @@ __device__ __forceinline__ void write_o_(float (*o_frag)[KTraits::NUM_MMA_D_VO][

 template <typename KTraits>
 __device__ __forceinline__ void normalize_d(float (*o_frag)[KTraits::NUM_MMA_D_VO][8],
-                                            typename KTraits::DTypeQKAccum (*m)[2], float (*d)[2]) {
+                                            typename KTraits::DTypeQKAccum (*m)[2], float (*d)[2],
+                                            float v_scale = 1.0f) {
   using AttentionVariant = typename KTraits::AttentionVariant;
   if constexpr (AttentionVariant::use_softmax) {
     float d_rcp[KTraits::NUM_MMA_Q][2];
@@ -163,6 +164,9 @@ __device__ __forceinline__ void normalize_d(float (*o_frag)[KTraits::NUM_MMA_D_V
         for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
           o_frag[mma_q][mma_d][reg_id] =
               o_frag[mma_q][mma_d][reg_id] * d_rcp[mma_q][(reg_id >> 1) & 1];
+          if (v_scale != 1.0f) {
+            o_frag[mma_q][mma_d][reg_id] *= v_scale;
+          }
         }
       }
     }
@@ -391,7 +395,7 @@ struct BlockBatchPagedAttentionPersistent {
     threadblock_sync_mdo_states<KTraits>(o_frag, smem_storage, m, d, warp_idx, lane_idx, tid);

     // normalize d
-    normalize_d<KTraits>(o_frag, m, d);
+    normalize_d<KTraits>(o_frag, m, d, params.v_scale);

     // write back to global memory
     // o_indptr (partial_o): [packed_qo_len * num_kv_chunks, num_kv_heads, head_dim]
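Placing the multiply inside normalize_d means v_scale is applied after softmax normalization. Because the softmax weights do not depend on V, this is mathematically equivalent to scaling V itself before attention (e.g. when v_scale carries a KV-cache dequantization factor). A small plain-PyTorch check of that equivalence (not kernel code; shapes are arbitrary):

```python
import torch

torch.manual_seed(0)
q = torch.randn(4, 16, 64)
k = torch.randn(4, 16, 64)
v = torch.randn(4, 16, 64)
sm_scale = 64 ** -0.5
v_scale = 2.0

p = torch.softmax((q @ k.transpose(-2, -1)) * sm_scale, dim=-1)
out_scale_after = (p @ v) * v_scale  # what normalize_d does with v_scale
out_scale_v = p @ (v * v_scale)      # scaling V up front

torch.testing.assert_close(out_scale_after, out_scale_v)
print("scaling the normalized output matches scaling V directly")
```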
12 changes: 9 additions & 3 deletions tests/test_batch_attention.py
@@ -90,6 +90,7 @@ def _run_attention(
     num_kv_heads=1,
     num_qo_heads=1,
     head_dim=128,
+    v_scale=None,
     layout="NHD",
     test_dtype=torch.bfloat16,
     logits_soft_cap=0.0,
@@ -140,7 +141,7 @@ def _run_attention(

     # --------- old scheduler --------- #
     wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
-        torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=dev),
+        torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=dev),

Edenzzzz (Contributor, author) commented on Sep 8, 2025: When running the test, the prefill kernel allocator hit an OOM, which seems unrelated to this PR (it also happens on main), so I increased the buffer size.

Collaborator: No problem, please go ahead with it.

         kv_layout=layout,
         backend="fa2",
     )
@@ -159,7 +160,7 @@ def _run_attention(
         kv_data_type=test_dtype,
         logits_soft_cap=logits_soft_cap,
     )
-    out_old, lse_old = wrapper_old.run(q, kv_data, return_lse=True)
+    out_old, lse_old = wrapper_old.run(q, kv_data, return_lse=True, v_scale=v_scale)

     # --------- new / mixed scheduler --------- #
     wrapper = flashinfer.BatchAttention(kv_layout=layout)
@@ -178,7 +179,9 @@ def _run_attention(
         kv_data_type=test_dtype,
         logits_soft_cap=logits_soft_cap,
     )
-    out_new, lse_new = wrapper.run(q, kv_data, logits_soft_cap=logits_soft_cap)
+    out_new, lse_new = wrapper.run(
+        q, kv_data, v_scale=v_scale, logits_soft_cap=logits_soft_cap
+    )

     torch.cuda.synchronize()
     torch.testing.assert_close(out_old, out_new, rtol=1e-2, atol=1e-2)
@@ -191,6 +194,7 @@ def _run_attention(
 @pytest.mark.parametrize("num_kv_heads", [1, 4])
 @pytest.mark.parametrize("gqa_group_size", [1, 4, 7, 8])
 @pytest.mark.parametrize("head_dim", [64, 128, 256])
+@pytest.mark.parametrize("v_scale", [2.0, None])
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("layout", ["HND", "NHD"])
 @pytest.mark.parametrize("test_dtype", [torch.bfloat16, torch.float16])
@@ -201,6 +205,7 @@ def test_batch_attention_correctness(
     num_kv_heads,
     gqa_group_size,
     head_dim,
+    v_scale,
     causal,
     layout,
     test_dtype,
@@ -217,6 +222,7 @@ def test_batch_attention_correctness(
         num_kv_heads=num_kv_heads,
         num_qo_heads=num_qo_heads,
         head_dim=head_dim,
+        v_scale=v_scale,
         causal=causal,
         layout=layout,
         test_dtype=test_dtype,