Skip to content

Commit 1656ad3

Browse files
jinzhen-lin and mgoin authored
[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
1 parent fa59fe4 commit 1656ad3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+4363
-2232
lines changed

CMakeLists.txt

Lines changed: 78 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -354,8 +354,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
354354
# Only build Marlin kernels if we are building for at least some compatible archs.
355355
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
356356
# are not supported by Machete yet.
357-
# 9.0 for latest bf16 atomicAdd PTX
358-
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
357+
358+
# marlin arches for fp16 output
359+
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
360+
# marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX)
361+
cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
362+
# marlin arches for fp8 input
363+
# - sm80 doesn't support fp8 computation
364+
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SASS instruction
365+
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
366+
cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
367+
359368
if (MARLIN_ARCHS)
360369

361370
#
@@ -365,16 +374,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
365374
set(MARLIN_GEN_SCRIPT
366375
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py)
367376
file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH)
377+
list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR)
378+
set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})")
368379

369-
message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH}")
370-
message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}")
380+
message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
381+
message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
371382

372-
if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH}
373-
OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH})
383+
if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
384+
OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
374385
execute_process(
375386
COMMAND ${CMAKE_COMMAND} -E env
376387
PYTHONPATH=$PYTHONPATH
377-
${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT}
388+
${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
378389
RESULT_VARIABLE marlin_generation_result
379390
OUTPUT_VARIABLE marlin_generation_result
380391
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
@@ -387,28 +398,50 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
387398
"\nCheck the log for details: "
388399
"${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log")
389400
else()
390-
set(MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH}
391-
CACHE STRING "Last run Marlin generate script hash" FORCE)
401+
set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
402+
CACHE STRING "Last run Marlin generate script hash and arch" FORCE)
392403
message(STATUS "Marlin generation completed successfully.")
393404
endif()
394405
else()
395406
message(STATUS "Marlin generation script has not changed, skipping generation.")
396407
endif()
397408

398-
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/kernel_*.cu")
409+
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
399410
set_gencode_flags_for_srcs(
400411
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
401412
CUDA_ARCHS "${MARLIN_ARCHS}")
402413
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
403414
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
404415
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
405416
endif()
406-
407417
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
408418

419+
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
420+
set_gencode_flags_for_srcs(
421+
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
422+
CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
423+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
424+
set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
425+
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
426+
endif()
427+
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
428+
429+
if (MARLIN_FP8_ARCHS)
430+
file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu")
431+
set_gencode_flags_for_srcs(
432+
SRCS "${MARLIN_TEMPLATE_FP8_KERNEL_SRC}"
433+
CUDA_ARCHS "${MARLIN_FP8_ARCHS}")
434+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
435+
set_source_files_properties(${MARLIN_TEMPLATE_FP8_KERNEL_SRC}
436+
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
437+
endif()
438+
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_FP8_KERNEL_SRC})
439+
endif()
440+
409441
set(MARLIN_SRCS
410442
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
411443
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
444+
"csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu"
412445
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
413446
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
414447
set_gencode_flags_for_srcs(
@@ -941,8 +974,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
941974
CUDA_ARCHS "${CUDA_ARCHS}")
942975

943976
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
944-
# 9.0 for latest bf16 atomicAdd PTX
945-
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
977+
# moe marlin arches
978+
# note that we always set `use_atomic_add=False` for moe marlin now,
979+
# so we don't need 9.0 for bf16 atomicAdd PTX
980+
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
981+
# moe marlin arches for fp8 input
982+
# - sm80 doesn't support fp8 computation
983+
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SASS instruction
984+
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
985+
cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
946986
if (MARLIN_MOE_ARCHS)
947987

948988
#
@@ -952,16 +992,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
952992
set(MOE_MARLIN_GEN_SCRIPT
953993
${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py)
954994
file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH)
995+
list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR)
996+
set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MOE_MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})")
955997

956-
message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}")
957-
message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}")
998+
message(STATUS "Marlin MOE generation script hash with arch: ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
999+
message(STATUS "Last run Marlin MOE generate script hash with arch: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
9581000

959-
if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}
960-
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH})
1001+
if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
1002+
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
9611003
execute_process(
9621004
COMMAND ${CMAKE_COMMAND} -E env
9631005
PYTHONPATH=$PYTHONPATH
964-
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT}
1006+
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
9651007
RESULT_VARIABLE moe_marlin_generation_result
9661008
OUTPUT_VARIABLE moe_marlin_generation_output
9671009
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
@@ -974,24 +1016,36 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
9741016
"\nCheck the log for details: "
9751017
"${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log")
9761018
else()
977-
set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH}
1019+
set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
9781020
CACHE STRING "Last run Marlin MOE generate script hash" FORCE)
9791021
message(STATUS "Marlin MOE generation completed successfully.")
9801022
endif()
9811023
else()
9821024
message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
9831025
endif()
9841026

985-
file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu")
1027+
file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
1028+
list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
9861029
set_gencode_flags_for_srcs(
987-
SRCS "${MOE_WNAA16_MARLIN_SRC}"
1030+
SRCS "${MARLIN_MOE_SRC}"
9881031
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
9891032
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
990-
set_source_files_properties(${MOE_WNAA16_MARLIN_SRC}
1033+
set_source_files_properties(${MARLIN_MOE_SRC}
9911034
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
9921035
endif()
993-
994-
list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})
1036+
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
1037+
1038+
if (MARLIN_MOE_FP8_ARCHS)
1039+
file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu")
1040+
set_gencode_flags_for_srcs(
1041+
SRCS "${MARLIN_MOE_FP8_SRC}"
1042+
CUDA_ARCHS "${MARLIN_MOE_FP8_ARCHS}")
1043+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
1044+
set_source_files_properties(${MARLIN_MOE_FP8_SRC}
1045+
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
1046+
endif()
1047+
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC})
1048+
endif()
9951049

9961050
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
9971051
else()

benchmarks/kernels/benchmark_machete.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -237,6 +237,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
237237
b_q_weight=w_q,
238238
b_bias=None,
239239
b_scales=w_s,
240+
a_scales=None,
240241
global_scale=None,
241242
b_zeros=w_zp,
242243
g_idx=g_idx,

benchmarks/kernels/benchmark_marlin.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -263,7 +263,7 @@ def gen_allspark_params():
263263

264264
results.append(
265265
benchmark.Timer(
266-
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
266+
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
267267
globals=globals,
268268
label=label,
269269
sub_label=sub_label,
@@ -273,7 +273,7 @@ def gen_allspark_params():
273273

274274
results.append(
275275
benchmark.Timer(
276-
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
276+
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
277277
globals=globals,
278278
label=label,
279279
sub_label=sub_label,
Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1,2 @@
1-
kernel_*.cu
1+
sm*_kernel_*.cu
2+
kernel_selector.h

0 commit comments

Comments
 (0)