From 0edf9ef3b85b28f22655f9f32907871297550282 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Mon, 1 Dec 2025 15:37:42 -0500
Subject: [PATCH] Rebase on main

Signed-off-by: mgoin
---
 .../quantization/fp4/nvfp4_scaled_mm_entry.cu | 39 ++++++++++++-------
 1 file changed, 26 insertions(+), 13 deletions(-)

diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu
index 9cba2828aac2..d9c4d24d8e1f 100644
--- a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu
+++ b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu
@@ -15,6 +15,8 @@
  */
 
 #include <torch/all.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "cutlass_extensions/common.hpp"
 
 #if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
 void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
@@ -32,23 +34,34 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
                                   torch::Tensor const& alpha);
 #endif
 
-void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
-                           torch::Tensor const& B, torch::Tensor const& A_sf,
-                           torch::Tensor const& B_sf,
-                           torch::Tensor const& alpha) {
-#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
-  return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
-#elif defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
-  return cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha);
+void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A,
+                           const torch::Tensor& B, const torch::Tensor& A_sf,
+                           const torch::Tensor& B_sf,
+                           const torch::Tensor& alpha) {
+  // Make sure we're on A's device.
+  const c10::cuda::OptionalCUDAGuard device_guard(device_of(A));
+  const int32_t sm = get_sm_version_num();
+
+#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
+  if (sm >= 100 && sm < 120) {
+    cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
+    return;
+  }
+#endif
+
+#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
+  if (sm >= 120 && sm < 130) {
+    cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha);
+    return;
+  }
 #endif
-  TORCH_CHECK_NOT_IMPLEMENTED(false,
-                              "No compiled nvfp4 mm kernel, vLLM should "
-                              "be compiled using CUDA 12.8 and target "
-                              "compute capability 100 or above.");
+
+  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel for SM ", sm,
+                              ". Recompile with CUDA >= 12.8 and CC >= 100.");
 }
 
 bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
   int runtimeVersion;
   cudaRuntimeGetVersion(&runtimeVersion);
-  return cuda_device_capability >= 100 && runtimeVersion >= 12080;
-}
\ No newline at end of file
+  return cuda_device_capability >= 100 && runtimeVersion >= 12080;
+}