Description
Your current environment
CUDA 13.1
VS Community 2026, with v143 build tools (14.44)
How you are installing vllm
C:\Env\vllm-windows\build\temp.win-amd64-cpython-312\Release> C:\Pkg\CUDA\13.1\bin\nvcc.exe -forward-unknown-to-host-compiler -DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1 -DPy_LIMITED_API=3 -DPy_NO_LINK_LIB -DQUTLASS_DISABLE_PYBIND=1 -DTARGET_CUDA_ARCH=120 -DTORCH_EXTENSION_NAME=_C -DUSE_C10D_GLOO -DUSE_DISTRIBUTED -D_C_EXPORTS -IC:\Env\vllm-windows\csrc -IC:\Env\vllm-windows\.deps\cutlass-src\include -IC:\Env\vllm-windows\.deps\cutlass-src\tools\util\include -IC:\Env\vllm-windows\.deps\qutlass-src -IC:\Env\vllm-windows\.deps\qutlass-src\qutlass -IC:\Env\vllm-windows\.deps\qutlass-src\qutlass\csrc\include -IC:\Env\vllm-windows\.deps\qutlass-src\qutlass\csrc\include\cutlass_extensions -isystem C:\Users\sqweek\AppData\Roaming\uv\python\cpython-3.12.11-windows-x86_64-none\Include -isystem C:\Env\cui\.venv\Lib\site-packages\torch\include -isystem C:\Env\cui\.venv\Lib\site-packages\torch\include\torch\csrc\api\include -isystem C:\Pkg\CUDA\13.1\include -D_WINDOWS -Xcompiler=" /EHsc" -DONNX_NAMESPACE=onnx_c2 --use-local-env -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=bad_friend_decl --Werror cross-execution-space-call --no-host-device-move-forward --expt-relaxed-constexpr --expt-extended-lambda -Xcompiler=" -O2 -Ob1" -DNDEBUG -Xcompiler /MD -std=c++17 -Xcompiler=-MD -Xcompiler=-Zi --expt-relaxed-constexpr -DENABLE_FP8 --threads=1 --compress-mode=size --disable-warnings -O2 -FS -Xptxas=-O2 -Xcompiler=/O2 -Xcompiler=/FS -Xcompiler=/Z7 -Xcompiler=/Zc:__cplusplus -Xcompiler=/DWIN32_LEAN_AND_MEAN -Xcompiler=/DUSE_CUDA -DENABLE_SCALED_MM_SM120=1 -DENABLE_SCALED_MM_C2X=1 -DENABLE_NVFP4_SM120=1 -DENABLE_CUTLASS_MOE_SM120=1 -DENABLE_CUTLASS_MLA=1 -gencode arch=compute_120f,code=sm_120f -MD -MT CMakeFiles\_C.dir\csrc\quantization\fp4\nvfp4_experts_quant.cu.obj -MF CMakeFiles\_C.dir\csrc\quantization\fp4\nvfp4_experts_quant.cu.obj.d -x cu -c C:\Env\vllm-windows\csrc\quantization\fp4\nvfp4_experts_quant.cu -o CMakeFiles\_C.dir\csrc\quantization\fp4\nvfp4_experts_quant.cu.obj -Xcompiler=-FdCMakeFiles\_C.dir\,-FS
nvcc warning : -Werror and -w are specified at the same time. NVCC will disable all warnings and not treat the warnings as errors.
nvfp4_experts_quant.cu
cl : Command line warning D9025 : overriding '/Zi' with '/Z7'
nvfp4_experts_quant.cu
cl : Command line warning D9025 : overriding '/Zi' with '/Z7'
nvfp4_experts_quant.cu
C:\Env\vllm-windows\csrc\quantization\fp4\nvfp4_experts_quant.cu(339): error: "FLOAT" has already been declared in the current scope
constexpr auto FLOAT = at::ScalarType::Float;
^
C:\Env\vllm-windows\csrc\quantization\fp4\nvfp4_experts_quant.cu(340): error: "INT" has already been declared in the current scope
constexpr auto INT = at::ScalarType::Int;
^
C:\Env\vllm-windows\csrc\quantization\fp4\nvfp4_experts_quant.cu(341): error: "UINT8" has already been declared in the current scope
constexpr auto UINT8 = at::ScalarType::Byte;
^
3 errors detected in the compilation of "C:/Env/vllm-windows/csrc/quantization/fp4/nvfp4_experts_quant.cu".
This appears to be a name collision with something; the following patch is enough to fix the compilation error:
diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu
index 32685c201..a956d6cb9 100644
--- a/csrc/quantization/fp4/nvfp4_experts_quant.cu
+++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu
@@ -334,11 +334,11 @@ void quant_impl(void* output, void* output_scale, void* input,
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m);
-constexpr auto HALF = at::ScalarType::Half;
-constexpr auto BF16 = at::ScalarType::BFloat16;
-constexpr auto FLOAT = at::ScalarType::Float;
-constexpr auto INT = at::ScalarType::Int;
-constexpr auto UINT8 = at::ScalarType::Byte;
+constexpr auto HALF_TYPE = at::ScalarType::Half;
+constexpr auto BF16_TYPE = at::ScalarType::BFloat16;
+constexpr auto FLOAT_TYPE = at::ScalarType::Float;
+constexpr auto INT_TYPE = at::ScalarType::Int;
+constexpr auto UINT8_TYPE = at::ScalarType::Byte;
// Common validation for fp4 experts quantization entry points.
static void validate_fp4_experts_quant_inputs(
@@ -361,14 +361,14 @@ static void validate_fp4_experts_quant_inputs(
TORCH_CHECK(input_offset_by_experts.dim() == 1);
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
- TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
- TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
- TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
- TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
+ TORCH_CHECK(input.scalar_type() == HALF_TYPE || input.scalar_type() == BF16_TYPE);
+ TORCH_CHECK(input_global_scale.scalar_type() == FLOAT_TYPE);
+ TORCH_CHECK(input_offset_by_experts.scalar_type() == INT_TYPE);
+ TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT_TYPE);
// output is uint8 (two nvfp4 values are packed into one uint8)
// output_scale is int32 (four fp8 values are packed into one int32)
- TORCH_CHECK(output.scalar_type() == UINT8);
- TORCH_CHECK(output_scale.scalar_type() == INT);
+ TORCH_CHECK(output.scalar_type() == UINT8_TYPE);
+ TORCH_CHECK(output_scale.scalar_type() == INT_TYPE);
const int BLOCK_SIZE = 16;
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
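To illustrate why the rename is enough, here is a minimal standalone sketch. It assumes a stand-in ScalarType enum in place of at::ScalarType (so it builds without PyTorch) and stand-in global typedefs for FLOAT, INT and UINT8: FLOAT is copied from the minwindef.h excerpt further down, while INT and UINT8 are only assumed to be declared the same way, as the compiler errors imply. dtypes_ok is a hypothetical helper mirroring a subset of the patched TORCH_CHECK dtype checks.

```cpp
// Hypothetical stand-in for at::ScalarType so the sketch builds without PyTorch.
enum class ScalarType { Half, BFloat16, Float, Int, Byte };

// Stand-ins for global-scope Windows SDK typedefs. FLOAT matches the
// minwindef.h excerpt in this issue; INT and UINT8 are assumed to be
// declared similarly, which is what the three compiler errors imply.
typedef float FLOAT;
typedef int INT;
typedef unsigned char UINT8;

// The original spelling would not compile in this scope:
//   constexpr auto FLOAT = ScalarType::Float;  // "FLOAT" already declared
// The patched names coexist with the typedefs without clashing.
constexpr auto HALF_TYPE = ScalarType::Half;
constexpr auto BF16_TYPE = ScalarType::BFloat16;
constexpr auto FLOAT_TYPE = ScalarType::Float;
constexpr auto INT_TYPE = ScalarType::Int;
constexpr auto UINT8_TYPE = ScalarType::Byte;

// Hypothetical helper mirroring some of the dtype checks in the patched
// validate_fp4_experts_quant_inputs, taking ScalarType values instead of tensors.
constexpr bool dtypes_ok(ScalarType input, ScalarType input_global_scale,
                         ScalarType output, ScalarType output_scale) {
  return (input == HALF_TYPE || input == BF16_TYPE) &&
         input_global_scale == FLOAT_TYPE && output == UINT8_TYPE &&
         output_scale == INT_TYPE;
}

// The Windows typedefs are still usable as types alongside the constants.
static_assert(sizeof(FLOAT) == sizeof(float), "FLOAT is still a type");
```

Because every check happens in constexpr context, compiling the snippet is itself the test: the renamed constants and the SDK-style typedefs live in the same global scope without a redeclaration error.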
Looking at the preprocessed output, this seems to be the culprit:
#line 155 "C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.26100.0\\shared\\minwindef.h"
typedef unsigned long DWORD;
typedef int BOOL;
typedef unsigned char BYTE;
typedef unsigned short WORD;
typedef float FLOAT;
typedef FLOAT *PFLOAT;
Maybe these typedefs are new in this SDK version? Note that the directory name says 10.0.26100.0, but when installing Visual Studio Community 2026 I selected "Windows 11 SDK (10.0.26100.7705)".
Also note that VS 2026 initially didn't work at all with CUDA 13.1: nvcc simply complained that the compiler version was incompatible/too new. I also had to install the older build tools ("C++/CLI support for v143 build tools (14.44-17.14)"), which I load into the environment via vcvarsall.bat x64 -vcvars_ver=14.44. In retrospect I probably should have installed VS 2022, but the docs said "latest version", so here I am XD
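Separately from the SDK version question, one possible alternative to renaming (not what the patch above does, and not tried against the real vLLM source) is to move the constants into a namespace: a variable declared there lives in a different scope from the global FLOAT typedef, so the original spelling is no longer a redeclaration. The namespace name fp4_experts_dtypes and the helper is_float are hypothetical; the typedef is copied from the preprocessed output above, and ScalarType again stands in for at::ScalarType.

```cpp
// Stand-in for the global-scope typedef from minwindef.h shown above.
typedef float FLOAT;

// Hypothetical stand-in for at::ScalarType.
enum class ScalarType { Half, BFloat16, Float, Int, Byte };

// Hypothetical namespace. FLOAT declared here is a different entity in a
// different scope from ::FLOAT, so this is not a redeclaration error.
namespace fp4_experts_dtypes {

constexpr auto FLOAT = ScalarType::Float;

// Unqualified lookup inside the namespace finds fp4_experts_dtypes::FLOAT
// before reaching the global typedef, so this compares against the enum value.
constexpr bool is_float(ScalarType t) { return t == FLOAT; }

}  // namespace fp4_experts_dtypes

// Outside the namespace, ::FLOAT is still the Windows typedef.
static_assert(sizeof(::FLOAT) == sizeof(float), "global FLOAT is a type");
```

The trade-off is that code using these constants must either live inside the namespace or qualify them; a global using namespace fp4_experts_dtypes; would make an unqualified FLOAT ambiguous between the typedef and the constant.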