diff --git a/csrc/batch_attention.cu b/csrc/batch_attention.cu
index 3590d1ce88..722029c419 100644
--- a/csrc/batch_attention.cu
+++ b/csrc/batch_attention.cu
@@ -68,7 +68,9 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
                             at::Tensor v_cache, at::Tensor kv_indices, at::Tensor o,
                             std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
                             int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
-                            int64_t page_size, double sm_scale,
+                            int64_t page_size,
+                            double v_scale,  // must use double due to pytorch binding
+                            double sm_scale,
                             double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS) {
   HolisticPlanInfo<2> plan_info;
   plan_info.FromVector(tensor_to_vec(plan_info_vec));
@@ -171,6 +173,7 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
     params[i].v_stride_n = v_stride_n;

     params[i].sm_scale = sm_scale;
+    params[i].v_scale = v_scale;
     params[i].logits_soft_cap = logits_soft_cap;
     // NOTE(Wenxuan) directly using the additional_params_decl from generate_additional_params
     // will be problematic because of the params[i]
diff --git a/csrc/batch_attention_customize_config.jinja b/csrc/batch_attention_customize_config.jinja
index 941a18bcd3..3cf9312748 100644
--- a/csrc/batch_attention_customize_config.jinja
+++ b/csrc/batch_attention_customize_config.jinja
@@ -109,7 +109,8 @@ struct PersistentParams {
   uint32_t v_stride_n;

   float sm_scale;
-  double logits_soft_cap;
+  float logits_soft_cap;
+  float v_scale;

   {{ additional_params_decl }}
   PROFILER_PARAMS_DECL
diff --git a/csrc/batch_attention_jit_pybind.cu b/csrc/batch_attention_jit_pybind.cu
index d2c6f8c150..096891a4de 100644
--- a/csrc/batch_attention_jit_pybind.cu
+++ b/csrc/batch_attention_jit_pybind.cu
@@ -28,7 +28,7 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
                             at::Tensor v_cache, at::Tensor kv_indices, at::Tensor o,
                             std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
                             int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
-                            int64_t page_size, double sm_scale,
+                            int64_t page_size, double v_scale, double sm_scale,
                             double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS);

TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
diff --git a/flashinfer/attention.py b/flashinfer/attention.py
index a43798b26d..4d32d94ded 100644
--- a/flashinfer/attention.py
+++ b/flashinfer/attention.py
@@ -109,7 +109,6 @@ def plan(
        self._num_qo_heads = num_qo_heads
        self._num_kv_heads = num_kv_heads
        self._page_size = page_size
-        self._sm_scale = sm_scale
        self._use_profiler = use_profiler

        # No addtional buf allocated for CUDA graph tensor
@@ -135,6 +134,8 @@ def run(
        kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        out: Optional[torch.Tensor] = None,
        lse: Optional[torch.Tensor] = None,
+        k_scale: Optional[torch.Tensor] = None,
+        v_scale: Optional[torch.Tensor] = None,
        logits_soft_cap: float = 0.0,
        profiler_buffer: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -157,9 +158,13 @@
                q.shape[0], q.shape[1], device=q.device, dtype=torch.float32
            )
        head_dim_qk = q.shape[2]
-        if self._sm_scale is None:
-            self._sm_scale = 1.0 / math.sqrt(head_dim_qk)
-
+        sm_scale = self._sm_scale
+        if sm_scale is None:
+            sm_scale = 1.0 / math.sqrt(head_dim_qk)
+        if k_scale is not None:
+            sm_scale *= k_scale
+        if v_scale is None:
+            v_scale = 1.0
        # profiler_buffer is optional
        profiler_args = (profiler_buffer,) if self._use_profiler else ()

@@ -178,7 +183,8 @@
            self._num_qo_heads,
            self._num_kv_heads,
            self._page_size,
-            self._sm_scale,
+            v_scale,
+            sm_scale,
            logits_soft_cap,
            # ADDITIONAL_FUNC_PARAMS
            # PROFILER_FUNC_PARAMS
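On the Python side, `k_scale` never reaches the kernel: `run()` folds it into `sm_scale` because the pre-softmax logits are linear in K, while `v_scale` is forwarded to the C++ entry point and applied to the output. A minimal PyTorch sketch of that folding (illustrative shapes and names, not FlashInfer internals):

```python
import math
import torch

q = torch.randn(4, 64)   # [q_len, head_dim]
k = torch.randn(16, 64)  # [kv_len, head_dim]
k_scale = 0.5
sm_scale = 1.0 / math.sqrt(q.shape[-1])

# Scaling K directly vs. folding k_scale into sm_scale gives the same logits.
logits_scaled_k = (q @ (k * k_scale).T) * sm_scale
logits_folded = (q @ k.T) * (sm_scale * k_scale)
torch.testing.assert_close(logits_scaled_k, logits_folded)

# v_scale cannot be folded this way (softmax is not linear in the logits),
# which is why it is passed through to the kernel and applied to the output.
```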
diff --git a/include/flashinfer/attention/persistent.cuh b/include/flashinfer/attention/persistent.cuh
index d0f0f4ab8c..53c5be329d 100644
--- a/include/flashinfer/attention/persistent.cuh
+++ b/include/flashinfer/attention/persistent.cuh
@@ -140,7 +140,8 @@ __device__ __forceinline__ void write_o_(float (*o_frag)[KTraits::NUM_MMA_D_VO][

 template <typename KTraits>
 __device__ __forceinline__ void normalize_d(float (*o_frag)[KTraits::NUM_MMA_D_VO][8],
-                                            typename KTraits::DTypeQKAccum (*m)[2], float (*d)[2]) {
+                                            typename KTraits::DTypeQKAccum (*m)[2], float (*d)[2],
+                                            float v_scale = 1.0f) {
   using AttentionVariant = typename KTraits::AttentionVariant;
   if constexpr (AttentionVariant::use_softmax) {
     float d_rcp[KTraits::NUM_MMA_Q][2];
@@ -163,6 +164,9 @@ __device__ __forceinline__ void normalize_d(float (*o_frag)[KTraits::NUM_MMA_D_V
        for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
          o_frag[mma_q][mma_d][reg_id] =
              o_frag[mma_q][mma_d][reg_id] * d_rcp[mma_q][(reg_id >> 1) & 1];
+          if (v_scale != 1.0f) {
+            o_frag[mma_q][mma_d][reg_id] *= v_scale;
+          }
        }
      }
    }
@@ -391,7 +395,7 @@ struct BlockBatchPagedAttentionPersistent {
    threadblock_sync_mdo_states(o_frag, smem_storage, m, d, warp_idx, lane_idx, tid);

    // normalize d
-    normalize_d(o_frag, m, d);
+    normalize_d(o_frag, m, d, params.v_scale);

    // write back to global memory
    // o_indptr (partial_o): [packed_qo_len * num_kv_chunks, num_kv_heads, head_dim]
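In the kernel, `v_scale` is applied in the epilogue: `normalize_d` divides each output fragment by the softmax denominator `d` and, when `v_scale != 1.0f`, rescales it in the same register loop, so V itself is never touched. A small PyTorch sketch of why that ordering is equivalent to scaling V before the matmul (illustrative shapes, not the CUDA fragment code):

```python
import torch

p_unnorm = torch.rand(4, 16)         # unnormalized softmax weights (exp of shifted scores)
d = p_unnorm.sum(-1, keepdim=True)   # softmax denominator, the kernel's "d" state
v = torch.randn(16, 8)
v_scale = 2.0

scale_v_upfront = (p_unnorm @ (v * v_scale)) / d   # scale V before accumulation
scale_in_epilogue = (p_unnorm @ v) / d * v_scale   # scale while normalizing, as in normalize_d
torch.testing.assert_close(scale_v_upfront, scale_in_epilogue)
```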
diff --git a/tests/test_batch_attention.py b/tests/test_batch_attention.py
index 31a6f489d5..a717372305 100644
--- a/tests/test_batch_attention.py
+++ b/tests/test_batch_attention.py
@@ -90,6 +90,7 @@ def _run_attention(
    num_kv_heads=1,
    num_qo_heads=1,
    head_dim=128,
+    v_scale=None,
    layout="NHD",
    test_dtype=torch.bfloat16,
    logits_soft_cap=0.0,
@@ -140,7 +141,7 @@

    # --------- old scheduler --------- #
    wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
-        torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=dev),
+        torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=dev),
        kv_layout=layout,
        backend="fa2",
    )
@@ -159,7 +160,7 @@
        kv_data_type=test_dtype,
        logits_soft_cap=logits_soft_cap,
    )
-    out_old, lse_old = wrapper_old.run(q, kv_data, return_lse=True)
+    out_old, lse_old = wrapper_old.run(q, kv_data, return_lse=True, v_scale=v_scale)

    # --------- new / mixed scheduler --------- #
    wrapper = flashinfer.BatchAttention(kv_layout=layout)
@@ -178,7 +179,9 @@
        kv_data_type=test_dtype,
        logits_soft_cap=logits_soft_cap,
    )
-    out_new, lse_new = wrapper.run(q, kv_data, logits_soft_cap=logits_soft_cap)
+    out_new, lse_new = wrapper.run(
+        q, kv_data, v_scale=v_scale, logits_soft_cap=logits_soft_cap
+    )

    torch.cuda.synchronize()
    torch.testing.assert_close(out_old, out_new, rtol=1e-2, atol=1e-2)
@@ -191,6 +194,7 @@
 @pytest.mark.parametrize("num_kv_heads", [1, 4])
 @pytest.mark.parametrize("gqa_group_size", [1, 4, 7, 8])
 @pytest.mark.parametrize("head_dim", [64, 128, 256])
+@pytest.mark.parametrize("v_scale", [2.0, None])
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("layout", ["HND", "NHD"])
 @pytest.mark.parametrize("test_dtype", [torch.bfloat16, torch.float16])
@@ -201,6 +205,7 @@ def test_batch_attention_correctness(
    num_kv_heads,
    gqa_group_size,
    head_dim,
+    v_scale,
    causal,
    layout,
    test_dtype,
@@ -217,6 +222,7 @@ def test_batch_attention_correctness(
        num_kv_heads=num_kv_heads,
        num_qo_heads=num_qo_heads,
        head_dim=head_dim,
+        v_scale=v_scale,
        causal=causal,
        layout=layout,
        test_dtype=test_dtype,
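The new `v_scale` parametrization (2.0 and None) runs the old and new schedulers with the same scale and compares both outputs and LSE. A self-contained sketch of the two properties that comparison rests on (plain PyTorch, illustrative shapes, not the paged-KV path): the output scales linearly with `v_scale`, and the log-sum-exp does not depend on it.

```python
import math
import torch

def tiny_attention(q, k, v, v_scale=None):
    # q: [q_len, d], k/v: [kv_len, d]
    scores = (q @ k.T) / math.sqrt(q.shape[-1])
    o = torch.softmax(scores, dim=-1) @ v
    lse = torch.logsumexp(scores, dim=-1)
    return (o if v_scale is None else o * v_scale), lse

q, k, v = torch.randn(4, 64), torch.randn(16, 64), torch.randn(16, 64)
o_ref, lse_ref = tiny_attention(q, k, v)
o_scaled, lse_scaled = tiny_attention(q, k, v, v_scale=2.0)
torch.testing.assert_close(o_scaled, o_ref * 2.0)  # output picks up v_scale
torch.testing.assert_close(lse_scaled, lse_ref)    # LSE is unaffected
```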