 */

#include <torch/all.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "cutlass_extensions/common.hpp"

#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
@@ -32,23 +34,34 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
                                  torch::Tensor const& alpha);
#endif

-void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
-                           torch::Tensor const& B, torch::Tensor const& A_sf,
-                           torch::Tensor const& B_sf,
-                           torch::Tensor const& alpha) {
-#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
-  return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
-#elif defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
-  return cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha);
+void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A,
+                           const torch::Tensor& B, const torch::Tensor& A_sf,
+                           const torch::Tensor& B_sf,
+                           const torch::Tensor& alpha) {
+  // Make sure we're on A's device.
+  const c10::cuda::OptionalCUDAGuard device_guard(device_of(A));
+  const int32_t sm = get_sm_version_num();
+
+#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
+  if (sm >= 100 && sm < 120) {
+    cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
+    return;
+  }
+#endif
+
+#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
+  if (sm >= 120 && sm < 130) {
+    cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha);
+    return;
+  }
#endif
-  TORCH_CHECK_NOT_IMPLEMENTED(false,
-                              "No compiled nvfp4 mm kernel, vLLM should "
-                              "be compiled using CUDA 12.8 and target "
-                              "compute capability 100 or above.");
+
+  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel for SM ", sm,
+                              ". Recompile with CUDA >= 12.8 and CC >= 100.");
}

bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
  int runtimeVersion;
  cudaRuntimeGetVersion(&runtimeVersion);
  return cuda_device_capability >= 100 && runtimeVersion >= 12080;
-}
+}
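The dispatch above keys off `get_sm_version_num()` from `cutlass_extensions/common.hpp`, which is assumed to return the current device's compute capability encoded as major * 10 + minor (e.g. 100 for SM100, 120 for SM120), so the `sm >= 100 && sm < 120` and `sm >= 120 && sm < 130` ranges select whichever kernel was actually compiled in. A minimal sketch of such a helper using only the CUDA runtime API; the name `sketch_get_sm_version_num` is illustrative, not the header's actual contents:

// Illustrative sketch only; the real helper lives in cutlass_extensions/common.hpp.
// Assumed encoding: compute capability major * 10 + minor (e.g. 10.0 -> 100, 12.0 -> 120).
#include <cuda_runtime.h>
#include <cstdint>

int32_t sketch_get_sm_version_num() {
  int device = 0;
  cudaGetDevice(&device);  // query the currently active CUDA device
  int major = 0, minor = 0;
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
  return static_cast<int32_t>(major * 10 + minor);
}

Guarding each branch with both a compile-time `#if defined(...)` check and a runtime SM range means builds that target only one architecture still compile, while unsupported devices fall through to the TORCH_CHECK_NOT_IMPLEMENTED error, which now reports the detected SM.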