
Commit 23ff744

Bump tvm ffi version to 0.1.4 (#2155)
## 📌 Description

Bump tvm ffi version to 0.1.4 and use `ffi::CUDADeviceGuard` instead of `cudaSetDevice`.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **Refactor**
  * Improved GPU device management across CUDA operations for more reliable multi-GPU support and automatic resource cleanup.
* **Chores**
  * Updated `apache-tvm-ffi` dependency to version 0.1.4 or later.
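For context, a minimal sketch of the RAII pattern the call sites switch to. The guard body below is a stand-in, assuming tvm-ffi's `ffi::CUDADeviceGuard` follows the usual save/set/restore convention; `PlanExample` and its parameter are hypothetical names for illustration, not code from this commit.

```cpp
#include <cuda_runtime.h>

namespace ffi {
// Stand-in for tvm-ffi's guard, shown only to illustrate the assumed
// behavior: remember the caller's current device, switch to the requested
// one, and restore the original when the guard leaves scope.
class CUDADeviceGuard {
 public:
  explicit CUDADeviceGuard(int device_id) {
    cudaGetDevice(&prev_device_);
    if (device_id != prev_device_) cudaSetDevice(device_id);
  }
  ~CUDADeviceGuard() { cudaSetDevice(prev_device_); }
  CUDADeviceGuard(const CUDADeviceGuard&) = delete;
  CUDADeviceGuard& operator=(const CUDADeviceGuard&) = delete;

 private:
  int prev_device_{0};
};
}  // namespace ffi

// Hypothetical call site mirroring the diffs below: a bare cudaSetDevice
// leaked the device change to the caller on every path out of the function;
// the guard scopes the change and undoes it automatically.
void PlanExample(int workspace_device_id) {
  ffi::CUDADeviceGuard device_guard(workspace_device_id);
  // ... build plan / launch kernels on workspace_device_id ...
}  // caller's original device restored here
```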
1 parent e59226b · commit 23ff744


43 files changed: +77 −76 lines (only part of the diff is shown below)

csrc/batch_attention.cu

Lines changed: 2 additions & 2 deletions

@@ -48,7 +48,7 @@ Array<int64_t> BatchPagedAttentionPlan(TensorView float_workspace_buffer,
 
   HolisticPlanInfo<2> plan_info;
 
-  cudaSetDevice(float_workspace_buffer.device().device_id);
+  ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id);
   const cudaStream_t stream = get_stream(float_workspace_buffer.device());
 
   cudaError_t status = TwoStageHolisticPlan<IdType>(
@@ -102,7 +102,7 @@ void BatchPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_wo
     v_stride_n = v_cache.stride(2);
   }
 
-  cudaSetDevice(q.device().device_id);
+  ffi::CUDADeviceGuard device_guard(q.device().device_id);
   const cudaStream_t stream = get_stream(q.device());
 
   DISPATCH_context(

csrc/batch_decode.cu

Lines changed: 2 additions & 2 deletions

@@ -53,7 +53,7 @@ Array<int64_t> BatchDecodeWithPagedKVCachePlan(
       << "CUDA cores template only supports equal head dim for QK and VO, please use tensor "
          "cores template for different head dim";
 
-  cudaSetDevice(float_workspace_buffer.device().device_id);
+  ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id);
   const cudaStream_t stream = get_stream(float_workspace_buffer.device());
   DISPATCH_context(
       DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE,
@@ -130,7 +130,7 @@ void BatchDecodeWithPagedKVCacheRun(TensorView float_workspace_buffer,
   }
   kv_cache_strides = k_strides.data();
 
-  cudaSetDevice(q.device().device_id);
+  ffi::CUDADeviceGuard device_guard(q.device().device_id);
   const cudaStream_t stream = get_stream(q.device());
 
   DISPATCH_context(

csrc/batch_decode_mla_cute_sm80.cu

Lines changed: 2 additions & 2 deletions

@@ -23,7 +23,7 @@ Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(ffi::TensorView float_workspac
       int_workspace_buffer.size(0) * get_element_size(int_workspace_buffer);
 
   DecodePlanInfo plan_info;
-  cudaSetDevice(float_workspace_buffer.device().device_id);
+  ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id);
   const cudaStream_t stream = get_stream(float_workspace_buffer.device());
 
   auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMlaCuteSM80<
@@ -103,7 +103,7 @@ void BatchDecodeWithPagedKVCacheRunMLA(
   }
   params.padded_batch_size = plan_info.padded_batch_size;
 
-  cudaSetDevice(paged_ckv_cache.device().device_id);
+  ffi::CUDADeviceGuard device_guard(paged_ckv_cache.device().device_id);
   const cudaStream_t stream = get_stream(paged_ckv_cache.device());
   cudaError_t status = BatchDecodeWithPagedKVCacheDispatchedMlaCuteSM80<HEAD_DIM_CKV, HEAD_DIM_KPE,
                                                                         QO_TILE_LEN, Params>(

csrc/batch_decode_mla_plan.cu

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@ Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(TensorView float_workspace_buf
                                                   TensorView indptr, int64_t batch_size,
                                                   int64_t num_qo_heads, int64_t page_size,
                                                   bool enable_cuda_graph) {
-  cudaSetDevice(float_workspace_buffer.device().device_id);
+  ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id);
   const cudaStream_t stream = get_stream(float_workspace_buffer.device());
 
   size_t float_workspace_size_in_bytes =

csrc/batch_decode_mla_run.cu

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ void BatchDecodeWithPagedKVCacheRunMLA(
   void* float_buffer = static_cast<void*>(float_workspace_buffer.data_ptr());
   void* int_buffer = static_cast<void*>(int_workspace_buffer.data_ptr());
 
-  cudaSetDevice(q_nope.device().device_id);
+  ffi::CUDADeviceGuard device_guard(q_nope.device().device_id);
   const cudaStream_t stream = get_stream(q_nope.device());
 
   paged_kv_mla_t<DTypeKV, IdType> paged_kv(

csrc/batch_mla_plan.cu

Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@ Array<int64_t> BatchMLAPagedAttentionPlan(TensorView float_workspace_buffer,
 
   int batch_size = kv_len.size(0);
 
-  cudaSetDevice(float_workspace_buffer.device().device_id);
+  ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id);
   const cudaStream_t stream = get_stream(float_workspace_buffer.device());
 
   cudaError_t status =

csrc/batch_mla_run.cu

Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ void BatchMLAPagedAttentionRun(TensorView float_workspace_buffer, TensorView int
   unsigned int o_stride_n = o.stride(0);
   unsigned int o_stride_h = o.stride(1);
 
-  cudaSetDevice(q_nope.device().device_id);
+  ffi::CUDADeviceGuard device_guard(q_nope.device().device_id);
   const cudaStream_t stream = get_stream(q_nope.device());
 
   DISPATCH_context(

csrc/batch_mla_sm90_plan.cu

Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@ Array<int64_t> BatchMLAPagedAttentionSM90Plan(TensorView float_workspace_buffer,
 
   int batch_size = kv_len.size(0);
 
-  cudaSetDevice(float_workspace_buffer.device().device_id);
+  ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id);
   const cudaStream_t stream = get_stream(float_workspace_buffer.device());
 
   cudaError_t status =

csrc/batch_mla_sm90_run.cu

Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ void BatchMLAPagedAttentionSM90Run(TensorView float_workspace_buffer,
   unsigned int o_stride_n = o.stride(0);
   unsigned int o_stride_h = o.stride(1);
 
-  cudaSetDevice(q_nope.device().device_id);
+  ffi::CUDADeviceGuard device_guard(q_nope.device().device_id);
   const cudaStream_t stream = get_stream(q_nope.device());
 
   DISPATCH_context(

csrc/batch_pod.cu

Lines changed: 2 additions & 2 deletions

@@ -100,7 +100,7 @@ void batch_pod_with_kv_cache_tensor(
   }
   kv_cache_strides_p = k_strides_p.data();
 
-  cudaSetDevice(float_workspace_buffer_p.device().device_id);
+  ffi::CUDADeviceGuard device_guard(float_workspace_buffer_p.device().device_id);
   const cudaStream_t stream = get_stream(float_workspace_buffer_p.device());
 
   // Decode setup (TensorView decode = batched prefill)
@@ -152,7 +152,7 @@ void batch_pod_with_kv_cache_tensor(
   kv_cache_strides_d = k_strides_d.data();
 
   // Already handled by prefill
-  // cudaSetDevice(float_workspace_buffer_d.device().device_id);
+  // ffi::CUDADeviceGuard device_guard(float_workspace_buffer_d.device().device_id);
   // const cudaStream_t stream = get_stream(float_workspace_buffer_d.device());
 
   DISPATCH_context(
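The batch_pod.cu hunk above keeps the decode-side guard commented out ("Already handled by prefill"): prefill and decode run in the same function body, so the prefill guard's scope already covers both phases. A small sketch of that reasoning, reusing the hypothetical guard stand-in from the sketch after the description (`RunPrefill`/`RunDecode` are illustrative names, not functions from this commit):

```cpp
// Illustrative stubs standing in for the prefill and decode phases.
void RunPrefill() { /* ... prefill kernel launches ... */ }
void RunDecode() { /* ... decode kernel launches ... */ }

// One function-scope guard covers both phases; a second guard for decode
// would set the same device again and restore it mid-function, gaining nothing.
void BatchPodSketch(int device_id) {
  ffi::CUDADeviceGuard device_guard(device_id);  // set once, at function scope
  RunPrefill();  // executes with device_id current
  RunDecode();   // still device_id; no second guard needed
}  // caller's original device restored exactly once, here
```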
