csrc/batch_attention.cu (22 changes: 12 additions & 10 deletions)

@@ -35,11 +35,12 @@ cudaError_t BatchPagedAttentionPersistent(const Params params_1, const Params pa
 
 using namespace flashinfer;
 
-Array<int64_t> BatchPagedAttentionPlan(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                                       Tensor page_locked_int_workspace_buffer, Tensor qo_indptr,
-                                       Tensor kv_indptr, Tensor kv_len, int64_t batch_size,
-                                       int64_t num_qo_heads, int64_t num_kv_heads,
-                                       int64_t head_dim_o, bool causal) {
+Array<int64_t> BatchPagedAttentionPlan(TensorView float_workspace_buffer,
+                                       TensorView int_workspace_buffer,
+                                       TensorView page_locked_int_workspace_buffer,
+                                       TensorView qo_indptr, TensorView kv_indptr,
+                                       TensorView kv_len, int64_t batch_size, int64_t num_qo_heads,
+                                       int64_t num_kv_heads, int64_t head_dim_o, bool causal) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer);
   size_t int_workspace_size_in_bytes =
@@ -63,11 +64,12 @@ Array<int64_t> BatchPagedAttentionPlan(Tensor float_workspace_buffer, Tensor int
   return Array(plan_info.ToVector());
 }
 
-void BatchPagedAttentionRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                            Array<int64_t> plan_info_vec, Tensor q, Tensor k_cache, Tensor v_cache,
-                            Tensor kv_indices, Tensor o, Optional<Tensor> maybe_lse,
-                            int64_t mask_mode_code, int64_t layout_code, int64_t num_qo_heads,
-                            int64_t num_kv_heads, int64_t page_size,
+void BatchPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+                            Array<int64_t> plan_info_vec, TensorView q, TensorView k_cache,
+                            TensorView v_cache, TensorView kv_indices, TensorView o,
+                            Optional<TensorView> maybe_lse, int64_t mask_mode_code,
+                            int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
+                            int64_t page_size,
                             double v_scale,  // must use double due to pytorch binding
                             double sm_scale,
                             double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS) {
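Aside from line wrapping, every hunk in this PR is the same mechanical change: owning tvm::ffi::Tensor parameters become non-owning tvm::ffi::TensorView parameters. The function bodies compile unchanged because both types reach the underlying DLTensor through operator-> (see float_workspace_buffer->shape[0] above). A minimal sketch of such a binding follows; the helper name, its arithmetic, and the header paths are illustrative assumptions, not code from this PR.

// Minimal sketch, assuming tvm::ffi::TensorView dereferences to the
// underlying DLTensor (shape/dtype/device) exactly as Tensor does.
// Header paths are assumptions; the real files include a project
// helper header instead.
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/function.h>

using tvm::ffi::TensorView;

// Hypothetical helper mirroring the workspace-size computation above.
int64_t WorkspaceBytes(TensorView workspace_buffer) {
  int64_t num_elems = workspace_buffer->shape[0];
  // DLDataType stores bits and lanes; round up to whole bytes.
  int64_t elem_bytes =
      (workspace_buffer->dtype.bits * workspace_buffer->dtype.lanes + 7) / 8;
  return num_elems * elem_bytes;
}

// Exported with the same macro the real bindings use.
TVM_FFI_DLL_EXPORT_TYPED_FUNC(workspace_bytes, WorkspaceBytes);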
csrc/batch_attention_jit_binding.cu (23 changes: 12 additions & 11 deletions)

@@ -19,18 +19,19 @@
 using tvm::ffi::Array;
 using tvm::ffi::Optional;
 
-Array<int64_t> BatchPagedAttentionPlan(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                                       Tensor page_locked_int_workspace_buffer, Tensor qo_indptr,
-                                       Tensor kv_indptr, Tensor kv_len, int64_t batch_size,
-                                       int64_t num_qo_heads, int64_t num_kv_heads,
-                                       int64_t head_dim_o, bool causal);
+Array<int64_t> BatchPagedAttentionPlan(TensorView float_workspace_buffer,
+                                       TensorView int_workspace_buffer,
+                                       TensorView page_locked_int_workspace_buffer,
+                                       TensorView qo_indptr, TensorView kv_indptr,
+                                       TensorView kv_len, int64_t batch_size, int64_t num_qo_heads,
+                                       int64_t num_kv_heads, int64_t head_dim_o, bool causal);
 
-void BatchPagedAttentionRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                            Array<int64_t> plan_info_vec, Tensor q, Tensor k_cache, Tensor v_cache,
-                            Tensor kv_indices, Tensor o, Optional<Tensor> maybe_lse,
-                            int64_t mask_mode_code, int64_t layout_code, int64_t num_qo_heads,
-                            int64_t num_kv_heads, int64_t page_size, double v_scale,
-                            double sm_scale,
+void BatchPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+                            Array<int64_t> plan_info_vec, TensorView q, TensorView k_cache,
+                            TensorView v_cache, TensorView kv_indices, TensorView o,
+                            Optional<TensorView> maybe_lse, int64_t mask_mode_code,
+                            int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
+                            int64_t page_size, double v_scale, double sm_scale,
                             double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS);
 
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, &BatchPagedAttentionPlan);
csrc/batch_decode.cu (20 changes: 11 additions & 9 deletions)

@@ -37,11 +37,11 @@ using tvm::ffi::Array;
 using tvm::ffi::Optional;
 
 Array<int64_t> BatchDecodeWithPagedKVCachePlan(
-    Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-    Tensor page_locked_int_workspace_buffer, Tensor indptr, int64_t batch_size,
+    TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+    TensorView page_locked_int_workspace_buffer, TensorView indptr, int64_t batch_size,
     int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph,
     int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, int64_t head_dim_vo,
-    Tensor empty_q_data, Tensor empty_kv_data) {
+    TensorView empty_q_data, TensorView empty_kv_data) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer);
   size_t int_workspace_size_in_bytes =
@@ -78,12 +78,14 @@ Array<int64_t> BatchDecodeWithPagedKVCachePlan(
   return Array(plan_info.ToVector());
 }
 
-void BatchDecodeWithPagedKVCacheRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                                    Array<int64_t> plan_info_vec, Tensor q, Tensor paged_k_cache,
-                                    Tensor paged_v_cache, Tensor paged_kv_indptr,
-                                    Tensor paged_kv_indices, Tensor paged_kv_last_page_len,
-                                    Tensor o, Optional<Tensor> maybe_lse, int64_t kv_layout_code,
-                                    int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS) {
+void BatchDecodeWithPagedKVCacheRun(TensorView float_workspace_buffer,
+                                    TensorView int_workspace_buffer, Array<int64_t> plan_info_vec,
+                                    TensorView q, TensorView paged_k_cache,
+                                    TensorView paged_v_cache, TensorView paged_kv_indptr,
+                                    TensorView paged_kv_indices, TensorView paged_kv_last_page_len,
+                                    TensorView o, Optional<TensorView> maybe_lse,
+                                    int64_t kv_layout_code, int64_t window_left,
+                                    bool enable_pdl ADDITIONAL_FUNC_PARAMS) {
   DecodePlanInfo plan_info;
   plan_info.FromVector(std::vector<int64_t>(plan_info_vec.begin(), plan_info_vec.end()));
   QKVLayout kv_layout = static_cast<QKVLayout>(kv_layout_code);
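The Optional<Tensor> maybe_lse parameters likewise become Optional<TensorView> with no change to how they are consumed. A sketch of the usual unwrap pattern, assuming tvm::ffi::Optional mirrors std::optional's has_value()/value() interface (which the signatures above imply but do not show); the helper and header paths are hypothetical:

// Sketch only; MaybeReadLse is a hypothetical helper, not from this PR.
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/optional.h>

using tvm::ffi::Optional;
using tvm::ffi::TensorView;

void MaybeReadLse(Optional<TensorView> maybe_lse) {
  if (maybe_lse.has_value()) {
    TensorView lse = maybe_lse.value();
    // The view exposes the DLTensor fields: lse->data, lse->shape, ...
    (void)lse;  // placeholder for a real consumer
  }
}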
csrc/batch_decode_jit_binding.cu (20 changes: 11 additions & 9 deletions)

@@ -21,18 +21,20 @@ using tvm::ffi::Array;
 using tvm::ffi::Optional;
 
 Array<int64_t> BatchDecodeWithPagedKVCachePlan(
-    Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-    Tensor page_locked_int_workspace_buffer, Tensor indptr, int64_t batch_size,
+    TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+    TensorView page_locked_int_workspace_buffer, TensorView indptr, int64_t batch_size,
     int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph,
     int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, int64_t head_dim_vo,
-    Tensor empty_q_data, Tensor empty_kv_data);
+    TensorView empty_q_data, TensorView empty_kv_data);
 
-void BatchDecodeWithPagedKVCacheRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                                    Array<int64_t> plan_info_vec, Tensor q, Tensor paged_k_cache,
-                                    Tensor paged_v_cache, Tensor paged_kv_indptr,
-                                    Tensor paged_kv_indices, Tensor paged_kv_last_page_len,
-                                    Tensor o, Optional<Tensor> maybe_lse, int64_t kv_layout_code,
-                                    int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS);
+void BatchDecodeWithPagedKVCacheRun(TensorView float_workspace_buffer,
+                                    TensorView int_workspace_buffer, Array<int64_t> plan_info_vec,
+                                    TensorView q, TensorView paged_k_cache,
+                                    TensorView paged_v_cache, TensorView paged_kv_indptr,
+                                    TensorView paged_kv_indices, TensorView paged_kv_last_page_len,
+                                    TensorView o, Optional<TensorView> maybe_lse,
+                                    int64_t kv_layout_code, int64_t window_left,
+                                    bool enable_pdl ADDITIONAL_FUNC_PARAMS);
 
 // Batched decode with paged KV-Cache plan
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, BatchDecodeWithPagedKVCachePlan);
csrc/batch_decode_mla_binding.cu (23 changes: 11 additions & 12 deletions)

@@ -5,21 +5,20 @@
 using tvm::ffi::Array;
 using tvm::ffi::Optional;
 
-Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(Tensor float_workspace_buffer,
-                                                  Tensor int_workspace_buffer,
-                                                  Tensor page_locked_int_workspace_buffer,
-                                                  Tensor indptr, int64_t batch_size,
+Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(TensorView float_workspace_buffer,
+                                                  TensorView int_workspace_buffer,
+                                                  TensorView page_locked_int_workspace_buffer,
+                                                  TensorView indptr, int64_t batch_size,
                                                   int64_t num_qo_heads, int64_t page_size,
                                                   bool enable_cuda_graph);
 
-void BatchDecodeWithPagedKVCacheRunMLA(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                                       Array<int64_t> plan_info_vec, Tensor q_nope, Tensor q_pe,
-                                       Tensor paged_ckv_cache, Tensor paged_kpe_cache,
-                                       Tensor paged_kv_indptr, Tensor paged_kv_indices,
-                                       Tensor paged_kv_last_page_len, Tensor o, double sm_scale,
-                                       int64_t window_left, double logits_soft_cap,
-                                       double rope_scale, double rope_theta,
-                                       Optional<Tensor> maybe_lse, bool enable_pdl);
+void BatchDecodeWithPagedKVCacheRunMLA(
+    TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+    Array<int64_t> plan_info_vec, TensorView q_nope, TensorView q_pe, TensorView paged_ckv_cache,
+    TensorView paged_kpe_cache, TensorView paged_kv_indptr, TensorView paged_kv_indices,
+    TensorView paged_kv_last_page_len, TensorView o, double sm_scale, int64_t window_left,
+    double logits_soft_cap, double rope_scale, double rope_theta, Optional<TensorView> maybe_lse,
+    bool enable_pdl);
 
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, BatchDecodeWithPagedKVCachePlanMLA);
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, BatchDecodeWithPagedKVCacheRunMLA);
csrc/batch_decode_mla_cute_sm80.cu (20 changes: 11 additions & 9 deletions)

@@ -11,10 +11,10 @@ using namespace flashinfer;
 using tvm::ffi::Array;
 using tvm::ffi::Optional;
 
-Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(ffi::Tensor float_workspace_buffer,
-                                                  ffi::Tensor int_workspace_buffer,
-                                                  ffi::Tensor page_locked_int_workspace_buffer,
-                                                  ffi::Tensor indptr, int64_t batch_size,
+Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(ffi::TensorView float_workspace_buffer,
+                                                  ffi::TensorView int_workspace_buffer,
+                                                  ffi::TensorView page_locked_int_workspace_buffer,
+                                                  ffi::TensorView indptr, int64_t batch_size,
                                                   int64_t num_qo_heads, int64_t page_size,
                                                   bool enable_cuda_graph) {
   size_t float_workspace_size_in_bytes =
@@ -43,11 +43,13 @@ Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(ffi::Tensor float_workspace_bu
 }
 
 void BatchDecodeWithPagedKVCacheRunMLA(
-    ffi::Tensor float_workspace_buffer, ffi::Tensor int_workspace_buffer,
-    Array<int64_t> plan_info_vec, ffi::Tensor q_nope, ffi::Tensor q_pe, ffi::Tensor paged_ckv_cache,
-    ffi::Tensor paged_kpe_cache, ffi::Tensor paged_kv_indptr, ffi::Tensor paged_kv_indices,
-    ffi::Tensor paged_kv_last_page_len, ffi::Tensor o, double sm_scale, int64_t window_left,
-    double logits_soft_cap, double rope_scale, double rope_theta, Optional<ffi::Tensor> maybe_lse,
+    ffi::TensorView float_workspace_buffer, ffi::TensorView int_workspace_buffer,
+    Array<int64_t> plan_info_vec, ffi::TensorView q_nope, ffi::TensorView q_pe,
+    ffi::TensorView paged_ckv_cache, ffi::TensorView paged_kpe_cache,
+    ffi::TensorView paged_kv_indptr, ffi::TensorView paged_kv_indices,
+    ffi::TensorView paged_kv_last_page_len, ffi::TensorView o, double sm_scale, int64_t window_left,
+    double logits_soft_cap, double rope_scale, double rope_theta,
+    Optional<ffi::TensorView> maybe_lse,
     bool enable_pdl  // fake placeholder, sm80 does not support pdl
 ) {
   DecodePlanInfo plan_info;
csrc/batch_decode_mla_plan.cu (8 changes: 4 additions & 4 deletions)

@@ -9,10 +9,10 @@ using namespace flashinfer;
 
 using tvm::ffi::Array;
 
-Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(Tensor float_workspace_buffer,
-                                                  Tensor int_workspace_buffer,
-                                                  Tensor page_locked_int_workspace_buffer,
-                                                  Tensor indptr, int64_t batch_size,
+Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(TensorView float_workspace_buffer,
+                                                  TensorView int_workspace_buffer,
+                                                  TensorView page_locked_int_workspace_buffer,
+                                                  TensorView indptr, int64_t batch_size,
                                                   int64_t num_qo_heads, int64_t page_size,
                                                   bool enable_cuda_graph) {
   cudaSetDevice(float_workspace_buffer->device.device_id);
csrc/batch_decode_mla_run.cu (15 changes: 7 additions & 8 deletions)

@@ -10,14 +10,13 @@ using namespace flashinfer;
 using tvm::ffi::Array;
 using tvm::ffi::Optional;
 
-void BatchDecodeWithPagedKVCacheRunMLA(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                                       Array<int64_t> plan_info_vec, Tensor q_nope, Tensor q_pe,
-                                       Tensor paged_ckv_cache, Tensor paged_kpe_cache,
-                                       Tensor paged_kv_indptr, Tensor paged_kv_indices,
-                                       Tensor paged_kv_last_page_len, Tensor o, double sm_scale,
-                                       int64_t window_left, double logits_soft_cap,
-                                       double rope_scale, double rope_theta,
-                                       Optional<Tensor> maybe_lse, bool enable_pdl) {
+void BatchDecodeWithPagedKVCacheRunMLA(
+    TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+    Array<int64_t> plan_info_vec, TensorView q_nope, TensorView q_pe, TensorView paged_ckv_cache,
+    TensorView paged_kpe_cache, TensorView paged_kv_indptr, TensorView paged_kv_indices,
+    TensorView paged_kv_last_page_len, TensorView o, double sm_scale, int64_t window_left,
+    double logits_soft_cap, double rope_scale, double rope_theta, Optional<TensorView> maybe_lse,
+    bool enable_pdl) {
   DecodePlanInfo plan_info;
   plan_info.FromVector(std::vector<int64_t>(plan_info_vec.begin(), plan_info_vec.end()));
 
csrc/batch_mla_binding.cu (19 changes: 10 additions & 9 deletions)

@@ -20,16 +20,17 @@
 using tvm::ffi::Array;
 using tvm::ffi::Optional;
 
-Array<int64_t> BatchMLAPagedAttentionPlan(Tensor float_workspace_buffer,
-                                          Tensor int_workspace_buffer,
-                                          Tensor page_locked_int_workspace_buffer, Tensor qo_indptr,
-                                          Tensor kv_indptr, Tensor kv_len, int64_t num_heads,
-                                          int64_t head_dim_o, bool causal);
+Array<int64_t> BatchMLAPagedAttentionPlan(TensorView float_workspace_buffer,
+                                          TensorView int_workspace_buffer,
+                                          TensorView page_locked_int_workspace_buffer,
+                                          TensorView qo_indptr, TensorView kv_indptr,
+                                          TensorView kv_len, int64_t num_heads, int64_t head_dim_o,
+                                          bool causal);
 
-void BatchMLAPagedAttentionRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                               Array<int64_t> plan_info_vec, Tensor q_nope, Tensor q_pe,
-                               Tensor ckv_cache, Tensor kpe_cache, Tensor kv_indices, Tensor o,
-                               Optional<Tensor> maybe_lse, int64_t mask_mode_code,
+void BatchMLAPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+                               Array<int64_t> plan_info_vec, TensorView q_nope, TensorView q_pe,
+                               TensorView ckv_cache, TensorView kpe_cache, TensorView kv_indices,
+                               TensorView o, Optional<TensorView> maybe_lse, int64_t mask_mode_code,
                                int64_t num_heads, int64_t page_size, double sm_scale);
 
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, BatchMLAPagedAttentionPlan);
csrc/batch_mla_plan.cu (11 changes: 6 additions & 5 deletions)

@@ -23,11 +23,12 @@ using namespace flashinfer;
 
 using tvm::ffi::Array;
 
-Array<int64_t> BatchMLAPagedAttentionPlan(Tensor float_workspace_buffer,
-                                          Tensor int_workspace_buffer,
-                                          Tensor page_locked_int_workspace_buffer, Tensor qo_indptr,
-                                          Tensor kv_indptr, Tensor kv_len, int64_t num_heads,
-                                          int64_t head_dim_o, bool causal) {
+Array<int64_t> BatchMLAPagedAttentionPlan(TensorView float_workspace_buffer,
+                                          TensorView int_workspace_buffer,
+                                          TensorView page_locked_int_workspace_buffer,
+                                          TensorView qo_indptr, TensorView kv_indptr,
+                                          TensorView kv_len, int64_t num_heads, int64_t head_dim_o,
+                                          bool causal) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer);
   size_t int_workspace_size_in_bytes =