Skip to content

Commit c0dfc89

Browse files
hholtmann and mgoin authored
SM120 / NVFP4: add device guard and runtime SM dispatch to cutlass_scaled_fp4_mm (#29711)
Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: mgoin <mgoin64@gmail.com>
1 parent 44822d7 commit c0dfc89

File tree

1 file changed

+26
-13
lines changed

1 file changed

+26
-13
lines changed

csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
*/
1616

1717
#include <torch/all.h>
18+
#include <c10/cuda/CUDAGuard.h>
19+
#include "cutlass_extensions/common.hpp"
1820

1921
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
2022
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
@@ -32,23 +34,34 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
3234
torch::Tensor const& alpha);
3335
#endif
3436

35-
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
36-
torch::Tensor const& B, torch::Tensor const& A_sf,
37-
torch::Tensor const& B_sf,
38-
torch::Tensor const& alpha) {
39-
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
40-
return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
41-
#elif defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
42-
return cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha);
37+
void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A,
38+
const torch::Tensor& B, const torch::Tensor& A_sf,
39+
const torch::Tensor& B_sf,
40+
const torch::Tensor& alpha) {
41+
// Make sure we’re on A’s device.
42+
const c10::cuda::OptionalCUDAGuard device_guard(device_of(A));
43+
const int32_t sm = get_sm_version_num();
44+
45+
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
46+
if (sm >= 100 && sm < 120) {
47+
cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
48+
return;
49+
}
50+
#endif
51+
52+
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
53+
if (sm >= 120 && sm < 130) {
54+
cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha);
55+
return;
56+
}
4357
#endif
44-
TORCH_CHECK_NOT_IMPLEMENTED(false,
45-
"No compiled nvfp4 mm kernel, vLLM should "
46-
"be compiled using CUDA 12.8 and target "
47-
"compute capability 100 or above.");
58+
59+
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel for SM ", sm,
60+
". Recompile with CUDA >= 12.8 and CC >= 100.");
4861
}
4962

5063
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
5164
int runtimeVersion;
5265
cudaRuntimeGetVersion(&runtimeVersion);
5366
return cuda_device_capability >= 100 && runtimeVersion >= 12080;
54-
}
67+
}

0 commit comments

Comments
 (0)