@@ -354,8 +354,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
354354 # Only build Marlin kernels if we are building for at least some compatible archs.
355355 # Keep building Marlin for 9.0 as there are some group sizes and shapes that
356356 # are not supported by Machete yet.
357- # 9.0 for latest bf16 atomicAdd PTX
358- cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS} " )
357+
358+ # marlin arches for fp16 output
359+ cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS} " )
360+ # marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX)
361+ cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS} " )
362+ # marlin arches for fp8 input
363+ # - sm80 doesn't support fp8 computation
364+ # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
365+ # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
366+ cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS} " )
367+
359368 if (MARLIN_ARCHS)
360369
361370 #
@@ -365,16 +374,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
365374 set (MARLIN_GEN_SCRIPT
366375 ${CMAKE_CURRENT_SOURCE_DIR} /csrc/quantization/gptq_marlin/generate_kernels.py)
367376 file (MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH)
377+ list (JOIN CUDA_ARCHS "," CUDA_ARCHS_STR)
378+ set (MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MARLIN_GEN_SCRIPT_HASH} (ARCH:${CUDA_ARCHS_STR} )" )
368379
369- message (STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH } " )
370- message (STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH } " )
380+ message (STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH } " )
381+ message (STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH } " )
371382
372- if (NOT DEFINED CACHE {MARLIN_GEN_SCRIPT_HASH }
373- OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH } STREQUAL ${MARLIN_GEN_SCRIPT_HASH } )
383+ if (NOT DEFINED CACHE {MARLIN_GEN_SCRIPT_HASH_AND_ARCH }
384+ OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH } STREQUAL ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH } )
374385 execute_process (
375386 COMMAND ${CMAKE_COMMAND} -E env
376387 PYTHONPATH=$PYTHONPATH
377- ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT}
388+ ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
378389 RESULT_VARIABLE marlin_generation_result
379390 OUTPUT_VARIABLE marlin_generation_result
380391 OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR} /marlin_generation.log
@@ -387,28 +398,50 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
387398 "\n Check the log for details: "
388399 "${CMAKE_CURRENT_BINARY_DIR} /marlin_generation.log" )
389400 else ()
390- set (MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH }
391- CACHE STRING "Last run Marlin generate script hash" FORCE)
401+ set (MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH }
402+ CACHE STRING "Last run Marlin generate script hash and arch " FORCE)
392403 message (STATUS "Marlin generation completed successfully." )
393404 endif ()
394405 else ()
395406 message (STATUS "Marlin generation script has not changed, skipping generation." )
396407 endif ()
397408
398- file (GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/kernel_* .cu" )
409+ file (GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16 .cu" )
399410 set_gencode_flags_for_srcs(
400411 SRCS "${MARLIN_TEMPLATE_KERNEL_SRC} "
401412 CUDA_ARCHS "${MARLIN_ARCHS} " )
402413 if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
403414 set_source_files_properties (${MARLIN_TEMPLATE_KERNEL_SRC}
404415 PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false" )
405416 endif ()
406-
407417 list (APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC} )
408418
419+ file (GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu" )
420+ set_gencode_flags_for_srcs(
421+ SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC} "
422+ CUDA_ARCHS "${MARLIN_BF16_ARCHS} " )
423+ if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
424+ set_source_files_properties (${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
425+ PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false" )
426+ endif ()
427+ list (APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC} )
428+
429+ if (MARLIN_FP8_ARCHS)
430+ file (GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu" )
431+ set_gencode_flags_for_srcs(
432+ SRCS "${MARLIN_TEMPLATE_FP8_KERNEL_SRC} "
433+ CUDA_ARCHS "${MARLIN_FP8_ARCHS} " )
434+ if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
435+ set_source_files_properties (${MARLIN_TEMPLATE_FP8_KERNEL_SRC}
436+ PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false" )
437+ endif ()
438+ list (APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_FP8_KERNEL_SRC} )
439+ endif ()
440+
409441 set (MARLIN_SRCS
410442 "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
411443 "csrc/quantization/gptq_marlin/gptq_marlin.cu"
444+ "csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu"
412445 "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
413446 "csrc/quantization/gptq_marlin/awq_marlin_repack.cu" )
414447 set_gencode_flags_for_srcs(
@@ -941,8 +974,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
941974 CUDA_ARCHS "${CUDA_ARCHS} " )
942975
943976 list (APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC} " )
944- # 9.0 for latest bf16 atomicAdd PTX
945- cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS} " )
977+ # moe marlin arches
978+ # note that we always set `use_atomic_add=False` for moe marlin now,
979+ # so we don't need 9.0 for bf16 atomicAdd PTX
980+ cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS} " )
981+ # moe marlin arches for fp8 input
982+ # - sm80 doesn't support fp8 computation
983+ # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
984+ # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
985+ cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS} " )
946986 if (MARLIN_MOE_ARCHS)
947987
948988 #
@@ -952,16 +992,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
952992 set (MOE_MARLIN_GEN_SCRIPT
953993 ${CMAKE_CURRENT_SOURCE_DIR} /csrc/moe/marlin_moe_wna16/generate_kernels.py)
954994 file (MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH)
995+ list (JOIN CUDA_ARCHS "," CUDA_ARCHS_STR)
996+ set (MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MOE_MARLIN_GEN_SCRIPT_HASH} (ARCH:${CUDA_ARCHS_STR} )" )
955997
956- message (STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH } " )
957- message (STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH } " )
998+ message (STATUS "Marlin MOE generation script hash with arch : ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH } " )
999+ message (STATUS "Last run Marlin MOE generate script hash with arch : $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH } " )
9581000
959- if (NOT DEFINED CACHE {MOE_MARLIN_GEN_SCRIPT_HASH }
960- OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH } STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH } )
1001+ if (NOT DEFINED CACHE {MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH }
1002+ OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH } STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH } )
9611003 execute_process (
9621004 COMMAND ${CMAKE_COMMAND} -E env
9631005 PYTHONPATH=$PYTHONPATH
964- ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT}
1006+ ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
9651007 RESULT_VARIABLE moe_marlin_generation_result
9661008 OUTPUT_VARIABLE moe_marlin_generation_output
9671009 OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR} /moe_marlin_generation.log
@@ -974,24 +1016,36 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
9741016 "\n Check the log for details: "
9751017 "${CMAKE_CURRENT_BINARY_DIR} /moe_marlin_generation.log" )
9761018 else ()
977- set (MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH }
1019+ set (MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH }
9781020 CACHE STRING "Last run Marlin MOE generate script hash" FORCE)
9791021 message (STATUS "Marlin MOE generation completed successfully." )
9801022 endif ()
9811023 else ()
9821024 message (STATUS "Marlin MOE generation script has not changed, skipping generation." )
9831025 endif ()
9841026
985- file (GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu" )
1027+ file (GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu" )
1028+ list (APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu" )
9861029 set_gencode_flags_for_srcs(
987- SRCS "${MOE_WNAA16_MARLIN_SRC } "
1030+ SRCS "${MARLIN_MOE_SRC } "
9881031 CUDA_ARCHS "${MARLIN_MOE_ARCHS} " )
9891032 if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
990- set_source_files_properties (${MOE_WNAA16_MARLIN_SRC }
1033+ set_source_files_properties (${MARLIN_MOE_SRC }
9911034 PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false" )
9921035 endif ()
993-
994- list (APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC} )
1036+ list (APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC} )
1037+
1038+ if (MARLIN_MOE_FP8_ARCHS)
1039+ file (GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu" )
1040+ set_gencode_flags_for_srcs(
1041+ SRCS "${MARLIN_MOE_FP8_SRC} "
1042+ CUDA_ARCHS "${MARLIN_MOE_FP8_ARCHS} " )
1043+ if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
1044+ set_source_files_properties (${MARLIN_MOE_FP8_SRC}
1045+ PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false" )
1046+ endif ()
1047+ list (APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC} )
1048+ endif ()
9951049
9961050 message (STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS} " )
9971051 else ()
0 commit comments