diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index 520029f0ec..d5f363839a 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -235,17 +235,7 @@ def dtype_str_to_torch_dtype(dtype_str): "10.3": ["cudnn", "cublas", "cutlass"], "12.0": ["cudnn", "cublas"], }, - "mm_fp4": { - "7.5": [], - "8.0": [], - "8.6": [], - "8.9": [], - "9.0": [], - "10.0": ["cudnn", "trtllm", "cutlass"], - "10.3": ["cudnn", "trtllm", "cutlass"], - "12.0": ["cudnn", "cutlass"], - "12.1": ["cudnn", "cutlass"], - }, + # Note: mm_fp4 uses support checkers to filter backends, so it is not listed here # MOE "trtllm_fp4_block_scale_moe": { "7.5": [], diff --git a/benchmarks/routines/gemm.py b/benchmarks/routines/gemm.py index 17336189d0..9f95f17fb4 100644 --- a/benchmarks/routines/gemm.py +++ b/benchmarks/routines/gemm.py @@ -131,7 +131,7 @@ def parse_gemm_args(line, parser): required=False, nargs="+", default=["cudnn"], - choices=["cudnn", "cublas", "trtllm", "cutlass"], + choices=["cudnn", "cublas", "trtllm", "cutlass", "auto"], help="Kernel backends to test. Default: cudnn", ) parser.add_argument( @@ -790,61 +790,14 @@ def testMmFp4(args): run_refcheck = args.refcheck use_128x4_sf_layout = args.use_128x4_sf_layout use_nvfp4 = args.use_nvfp4 - autotune_supported_backends = ["cutlass", "trtllm"] + autotune_supported_backends = ["cudnn", "cutlass", "trtllm", "auto"] res = [] - backends = filter_backends_by_compute_capability(backends, args.routine, device) - res_dtype = dtype_str_to_torch_dtype(args.out_dtype) if res_dtype not in [torch.bfloat16, torch.float16]: raise ValueError( f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16." ) - ## Done parsing input arguments - - if "trtllm" in backends: - remove_trtllm = False - if res_dtype == torch.float16: - print("[INFO] trtllm backend does not support float16 output") - remove_trtllm = True - if remove_trtllm: - backends.remove("trtllm") - if not use_nvfp4: - print( - "[INFO] trtllm backend does not support mxfp4 quantization (use_nvfp4=False)" - ) - backends.remove("trtllm") - if "cutlass" in backends: - remove_cutlass = False - if not use_128x4_sf_layout: - print("[INFO] cutlass backend does not support use_128x4_sf_layout=False") - remove_cutlass = True - if not use_nvfp4: - print( - "[INFO] cutlass backend does not support mxfp4 quantization (use_nvfp4=False)" - ) - backends.remove("cutlass") - if remove_cutlass: - backends.remove("cutlass") - if "cudnn" in backends: - remove_cudnn = False - if not use_128x4_sf_layout: - print("[INFO] cudnn backend does not support use_128x4_sf_layout=False") - remove_cudnn = True - if remove_cudnn: - backends.remove("cudnn") - if getattr(args, "autotune", False): - backends_to_remove = [] - for cur_backend in backends: - if cur_backend not in autotune_supported_backends: - print(f"[INFO] {cur_backend} backend does not support autotune") - backends_to_remove.append(cur_backend) - for cur_backend in backends_to_remove: - backends.remove(cur_backend) - - if len(backends) == 0: - print("[ERROR] No backends to test. Exiting.") - return input = torch.randn([m, k], device=device, dtype=torch.bfloat16) mat2 = torch.randn([n, k], device=device, dtype=torch.bfloat16) @@ -886,11 +839,22 @@ def testMmFp4(args): print(f"[VVERBOSE] {mat2_fp4.dtype = }") alpha = 1.0 / (global_sf_input * global_sf_mat2) if use_nvfp4 else None - # res = torch.empty([m, n], device="cuda", dtype=res_dtype) + # Completed preparing inputs. Now programmatically filter backends + block_size = 16 if use_nvfp4 else 32 + backends_to_remove = [] - def run_backend(backend): - if backend in ["cudnn", "trtllm", "cutlass"]: - return flashinfer.gemm.mm_fp4( + for backend in backends: + # Skip autotune check for now (handled separately below) + if ( + getattr(args, "autotune", False) + and backend not in autotune_supported_backends + ): + print(f"[INFO] {backend} backend does not support autotune") + backends_to_remove.append(backend) + continue + + try: + flashinfer.gemm.mm_fp4( a=input_fp4, b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T, a_descale=input_inv_s, @@ -904,6 +868,34 @@ def run_backend(backend): backend=backend, use_nvfp4=use_nvfp4, ) + except Exception as e: + print( + f"[INFO] {backend} backend does not support this configuration: {type(e).__name__}: {e}" + ) + backends_to_remove.append(backend) + + # Remove unsupported backends + for backend in backends_to_remove: + backends.remove(backend) + + if len(backends) == 0: + print("[ERROR] No backends passed validation. Exiting.") + return + + def run_backend(backend): + if backend in ["cudnn", "trtllm", "cutlass", "auto"]: + return flashinfer.gemm.mm_fp4( + a=input_fp4, + b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T, + a_descale=input_inv_s, + b_descale=mat2_inv_s.T if backend != "trtllm" else mat2_inv_s_trtllm.T, + alpha=alpha, + out_dtype=res_dtype, + block_size=block_size, + use_8x4_sf_layout=not use_128x4_sf_layout, + backend=backend, + use_nvfp4=use_nvfp4, + ) else: raise ValueError(f"Unsupported backend: {backend}") @@ -917,12 +909,11 @@ def run_backend(backend): args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10 ) for cur_backend in backends: - if cur_backend in autotune_supported_backends: - if args.verbose >= 1: - print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters") - with autotune(True): - for _ in range(warmup_iters): - run_backend(cur_backend) + if args.verbose >= 1: + print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters") + with autotune(True): + for _ in range(warmup_iters): + run_backend(cur_backend) # Storage for timing results and outputs backend_times = {backend: [] for backend in backends} diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index ac0fbab4a0..589c651aca 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -54,6 +54,7 @@ from ..jit.gemm import gen_trtllm_gen_gemm_module from ..jit.gemm import gen_tgv_gemm_sm10x_module from ..jit.gemm import gen_deepgemm_sm100_module +from ..jit.cpp_ext import get_cuda_version CUDNN_AVAILABLE = False @@ -406,87 +407,50 @@ def fp8_gemm_sm100( def _create_cutlass_fp4_gemm_module(module, op_name: str, tuner_name: str): """Helper function to create cutlass FP4 GEMM module.""" - class CutlassFp4GemmRunner(TunableRunner): - def __init__(self): - self._fp4_gemm_runner = module.fp4_gemm + def cutlass_fp4_gemm_runner(): + class CutlassFp4GemmRunner(TunableRunner): + def __init__(self): + self._fp4_gemm_runner = module.fp4_gemm - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: - return list(range(module.fp4_gemm_tactic_num())) - - def forward( - self, - inputs: List[torch.Tensor], - tactic: int = -1, - do_preparation: bool = False, - **kwargs, - ): - a, b, a_descale, b_descale, alpha, out, workspace_buffer = inputs - module.fp4_gemm( - a, b, a_descale, b_descale, alpha, out, workspace_buffer, tactic - ) - return out - - @register_custom_op( - op_name, - mutates_args=(""), - ) - def cutlass_fp4_gemm( - a: torch.Tensor, - b: torch.Tensor, - a_descale: torch.Tensor, - b_descale: torch.Tensor, - alpha: torch.Tensor, - out: torch.Tensor, - workspace_buffer: torch.Tensor, - ): - tuner = AutoTuner.get() - - a_tensor_index = 0 - a_scale_tensor_index = 2 - out_tensor_index = 5 - - def pad_up(x, y): - return ((x + y - 1) // y) * y - - tuning_config = TuningConfig( - dynamic_tensor_specs=( - DynamicTensorSpec( - (a_tensor_index,), - (0,), - get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2, - ), - ), - constraint_specs=( - ConstraintSpec( - a_scale_tensor_index, - 0, - lambda shapes: pad_up(shapes[a_tensor_index][0], 128), - ), - ConstraintSpec( - out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0] - ), - ), - ) - - fp4_runner = CutlassFp4GemmRunner() + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + return list(range(module.fp4_gemm_tactic_num())) - inputs = [a, b, a_descale, b_descale, alpha, out, workspace_buffer] - _, tactic = tuner.choose_one( - tuner_name, - [fp4_runner], - tuning_config, - inputs, - ) + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + do_preparation: bool = False, + **kwargs, + ): + ( + a, + b, + a_descale, + b_descale, + alpha, + _, + out, + _, + _, + workspace_buffer, + ) = inputs + if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn: + a_descale = a_descale.view(torch.uint8) + if b.dtype == torch.uint8 and b_descale.dtype == torch.float8_e4m3fn: + b_descale = b_descale.view(torch.uint8) + module.fp4_gemm( + a, b.T, a_descale, b_descale.T, alpha, out, workspace_buffer, tactic + ) + return out - fp4_runner(inputs=inputs, tactic=tactic) + return CutlassFp4GemmRunner() return SimpleNamespace( - cutlass_fp4_gemm=cutlass_fp4_gemm, + cutlass_fp4_gemm_runner=cutlass_fp4_gemm_runner, ) @@ -508,6 +472,17 @@ def get_gemm_sm120_module_cutlass_fp4(): ) +def get_cutlass_fp4_gemm_module( + sm_major: int, +): + if sm_major in [10, 11]: + return get_gemm_sm100_module_cutlass_fp4() + elif sm_major == 12: + return get_gemm_sm120_module_cutlass_fp4() + else: + raise ValueError(f"Unsupported SM major version: {sm_major}") + + @functools.cache def get_tgv_gemm_sm10x_module( dtype: torch.dtype = torch.bfloat16, use_sm_100f: bool = False @@ -1139,7 +1114,6 @@ def _check_cudnn_fp4_availability(): def _is_cublas_fp4_available_in_cudnn(): """Check if cuBLAS backend for FP4 GEMM is available in cuDNN.""" - _check_cudnn_availability() # Check cuDNN backend version for FP4 support (requires cudnn_version == 9.11.1 or cudnn_version >= 9.13) backend_version = cudnn.backend_version() @@ -1191,7 +1165,6 @@ def create_cudnn_execution_plans_fp4_gemm( alpha_is_not_none, use_nvfp4, ): - _check_cudnn_availability() stream = torch.cuda.current_stream(device) with cudnn.graph(_get_cudnn_handle(stream)) as (graph, _): scale_type = cudnn.data_type.FP8_E4M3 if use_nvfp4 else cudnn.data_type.FP8_E8M0 @@ -1292,7 +1265,9 @@ def build_plans_cudnn_fp4_gemm_graph( device, alpha, use_nvfp4, + tactic: int = -1, ): + # Graph should have been already cached, when we ran _cudnn_gemm_fp4_requirement graph = create_cudnn_execution_plans_fp4_gemm( a_shape, a_stride, @@ -1311,7 +1286,10 @@ def build_plans_cudnn_fp4_gemm_graph( ) graph.check_support() - graph.build_plans() + if tactic != -1: + graph.build_plan_at_index(tactic) + else: + graph.build_plans() return graph @@ -1324,6 +1302,7 @@ def execute_cudnn_gemm_fp4_graph( alpha, c_final, workspace_buffer, + tactic: int = -1, ): variant_pack = { UIDs.A_UID.value: a.view(get_native_fp4_dtype()), @@ -1343,7 +1322,12 @@ def execute_cudnn_gemm_fp4_graph( stream = torch.cuda.current_stream(a.device) - graph.execute(variant_pack, workspace_buffer, handle=_get_cudnn_handle(stream)) + if tactic == -1: + graph.execute(variant_pack, workspace_buffer, handle=_get_cudnn_handle(stream)) + else: + graph.execute_plan_at_index( + variant_pack, workspace_buffer, tactic, handle=_get_cudnn_handle(stream) + ) @functools.cache @@ -1677,7 +1661,54 @@ def mm_fp8( return out -def _check_mm_fp4_problem_size( +def _get_cudnn_fp4_gemm_graph( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + alpha: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + out: Optional[torch.Tensor] = None, + block_size: int = 16, + use_nvfp4: bool = True, + tactic: int = -1, +): + # the fp4 cudnn graph will be shared for both mm and bmm, so + # here we need to get the 3d shape and stride including the + # batch dimension for both input and block scale tensors. + real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a) + real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b) + batch = real_a_shape[0] + expanded_a_descale_shape, expanded_a_descale_stride = ( + _expand_block_scale_tensor_shape(a_descale, batch) + ) + expanded_b_descale_shape, expanded_b_descale_stride = ( + _expand_block_scale_tensor_shape(b_descale, batch) + ) + + # build the fp4 cudnn graph + # Constructed graph is cached, via @functools.cache decorator. + graph = build_plans_cudnn_fp4_gemm_graph( + real_a_shape, + real_a_stride, + real_b_shape, + real_b_stride, + expanded_a_descale_shape, + expanded_a_descale_stride, + expanded_b_descale_shape, + expanded_b_descale_stride, + cudnn.data_type.FP4_E2M1, + _torch_data_type_to_cudnn_data_type(out_dtype), + block_size, + a.device, + alpha is not None, + use_nvfp4, + tactic=tactic, + ) + return graph + + +def _cudnn_gemm_fp4( a: torch.Tensor, b: torch.Tensor, a_descale: torch.Tensor, @@ -1686,8 +1717,114 @@ def _check_mm_fp4_problem_size( out_dtype: torch.dtype = torch.bfloat16, out: Optional[torch.Tensor] = None, block_size: int = 16, - use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", + use_nvfp4: bool = True, + workspace_buffer: torch.Tensor = None, + tactic: int = -1, +): + # Graph should have been already cached, when we ran _cudnn_gemm_fp4_requirement + graph = _get_cudnn_fp4_gemm_graph( + a=a, + b=b, + a_descale=a_descale, + b_descale=b_descale, + alpha=alpha, + out_dtype=out_dtype, + out=out, + block_size=block_size, + use_nvfp4=use_nvfp4, + tactic=tactic, + ) + # execute the fp4 cudnn graph + execute_cudnn_gemm_fp4_graph( + graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer, tactic=tactic + ) + + +def _cudnn_gemm_fp4_runner(): + class CudnnFp4GemmRunner(TunableRunner): + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + # cudnn has heuristic for fp4 gemm, so we only need to use the default tactic + ( + a, + b, + a_descale, + b_descale, + alpha, + out_dtype, + out, + block_size, + use_nvfp4, + workspace_buffer, + ) = inputs + + # Graph should have been already cached, when we ran _cudnn_gemm_fp4_requirement + graph = _get_cudnn_fp4_gemm_graph( + a=a, + b=b, + a_descale=a_descale, + b_descale=b_descale, + alpha=alpha, + out_dtype=out_dtype, + out=out, + block_size=block_size, + use_nvfp4=use_nvfp4, + tactic=-1, + ) + + num_plans = graph.get_execution_plan_count() + return list(range(num_plans)) + + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + do_preparation: bool = False, + **kwargs, + ) -> torch.Tensor: + ( + a, + b, + a_descale, + b_descale, + alpha, + out_dtype, + out, + block_size, + use_nvfp4, + workspace_buffer, + ) = inputs + _cudnn_gemm_fp4( + a, + b, + a_descale, + b_descale, + alpha, + out_dtype, + out, + block_size, + use_nvfp4, + workspace_buffer, + tactic=tactic, + ) + + return CudnnFp4GemmRunner() + + +def _check_mm_fp4_problem_size( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + alpha: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + out: Optional[torch.Tensor] = None, # unused + block_size: int = 16, + use_8x4_sf_layout: bool = False, # unused + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", # unused use_nvfp4: bool = True, ): # Generic checks @@ -1725,11 +1862,6 @@ def _check_mm_fp4_problem_size( f"Only torch.bfloat16 and torch.float16 are supported for FP4 GEMM operations." ) - if backend != "trtllm" and use_8x4_sf_layout: - raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.") - if backend != "cudnn" and not use_nvfp4: - raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.") - if use_nvfp4 and block_size != 16: raise ValueError("nvfp4 only supports block_size = 16.") if not use_nvfp4 and block_size != 32: @@ -1746,12 +1878,14 @@ def _cudnn_gemm_fp4_requirement( b_descale: torch.Tensor, alpha: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - out: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, # unused block_size: int = 16, use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", # unused use_nvfp4: bool = True, ): + if use_8x4_sf_layout: + raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.") if ( not use_nvfp4 and _match_sm_version(a.device, ["120"]) @@ -1774,7 +1908,8 @@ def _cudnn_gemm_fp4_requirement( _expand_block_scale_tensor_shape(b_descale, batch) ) - # build the fp4 cudnn graph + # build the fp4 cudnn graph. This graph will be cached & reused in mm_fp4() + # because the graph is constructed with @functools.cache decorator graph = create_cudnn_execution_plans_fp4_gemm( real_a_shape, real_a_stride, @@ -1798,18 +1933,20 @@ def _cudnn_gemm_fp4_requirement( @supported_compute_capability([100, 103]) def _trtllm_gemm_fp4_requirement( - a: torch.Tensor, - b: torch.Tensor, - a_descale: torch.Tensor, - b_descale: torch.Tensor, - alpha: Optional[torch.Tensor] = None, + a: torch.Tensor, # unused + b: torch.Tensor, # unused + a_descale: torch.Tensor, # unused + b_descale: torch.Tensor, # unused + alpha: Optional[torch.Tensor] = None, # unused out_dtype: torch.dtype = torch.bfloat16, - out: Optional[torch.Tensor] = None, - block_size: int = 16, - use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", + out: Optional[torch.Tensor] = None, # unused + block_size: int = 16, # unused + use_8x4_sf_layout: bool = False, # unused + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", # unused use_nvfp4: bool = True, ): + if not use_nvfp4: + raise ValueError("Only cudnn and auto FP4 GEMM supports mxfp4 quantization.") if out_dtype != torch.bfloat16: raise ValueError( f"Unsupported output dtype: {out_dtype}. " @@ -1820,6 +1957,27 @@ def _trtllm_gemm_fp4_requirement( @supported_compute_capability([100, 103, 110, 120, 121]) def _cutlass_gemm_fp4_requirement( + a: torch.Tensor, # unused + b: torch.Tensor, # unused + a_descale: torch.Tensor, # unused + b_descale: torch.Tensor, # unused + alpha: Optional[torch.Tensor] = None, # unused + out_dtype: torch.dtype = torch.bfloat16, # unused + out: Optional[torch.Tensor] = None, # unused + block_size: int = 16, # unused + use_8x4_sf_layout: bool = False, + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", # unused + use_nvfp4: bool = True, +): + if use_8x4_sf_layout: + raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.") + if not use_nvfp4: + raise ValueError("Only cudnn and auto FP4 GEMM supports mxfp4 quantization.") + return True + + +def _heuristic_func_mm_fp4( + suitable_backends: List[str], a: torch.Tensor, b: torch.Tensor, a_descale: torch.Tensor, @@ -1829,19 +1987,42 @@ def _cutlass_gemm_fp4_requirement( out: Optional[torch.Tensor] = None, block_size: int = 16, use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "cudnn", use_nvfp4: bool = True, ): - return True + r""" + Heuristic function for mm_fp4 backend selection. Routes to either cudnn or cutlass. + Note: trtllm is not considered in the backend selection because it requires a specific + input quantization (swizzling/shuffling) that differs from the preparation used + for cudnn and cutlass backends. + + Logic for which comes first: + - If cuda version is 12 - use cutlass. + - If cuda version is 13 and cudnn version is less than 9.15 - use cutlass. + - If cuda version is 13 and cudnn version is 9.15 or greater - use cudnn. + + """ + cuda_major = get_cuda_version().major + # If cuda version is 13 or greater: + # cudnn is more performant if cudnn version is 9.15 or greater. + if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91500: + candidate_backends = ("cudnn", "cutlass") + # Otherwise, prioritize cutlass + else: + candidate_backends = ("cutlass", "cudnn") + + # Filter and return only supported backends + return [c for c in candidate_backends if c in suitable_backends] @backend_requirement( { - "cudnn": _cudnn_gemm_fp4_requirement, # Each backend has its own requirement function + "cudnn": _cudnn_gemm_fp4_requirement, "trtllm": _trtllm_gemm_fp4_requirement, "cutlass": _cutlass_gemm_fp4_requirement, }, - common_check=_check_mm_fp4_problem_size, # Shape checks common to all backends + common_check=_check_mm_fp4_problem_size, + heuristic_func=_heuristic_func_mm_fp4, # result stored in mm_fp4.suitable_auto_backends ) def mm_fp4( a: torch.Tensor, @@ -1853,7 +2034,7 @@ def mm_fp4( out: Optional[torch.Tensor] = None, block_size: int = 16, use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", use_nvfp4: bool = True, ) -> torch.Tensor: r"""MM FP4 @@ -1887,8 +2068,8 @@ def mm_fp4( use_8x4_sf_layout: bool Whether to use 8x4 scale factor layout or 128x4 scale factor layout, defaults to False. - backend: Literal["cudnn", "trtllm", "cutlass"] - Backend to use, defaults to "cudnn". + backend: Literal["cudnn", "trtllm", "cutlass", "auto"] + Backend to use, defaults to "auto", which automatically selects the best backend between cudnn and cutlass. use_nvfp4: bool Whether to use nvfp4 quantization or mxfp4 quantization, defaults to False. @@ -1930,70 +2111,78 @@ def mm_fp4( "mm_fp4_workspace", DEFAULT_WORKSPACE_SIZE, a.device ) - if backend == "cudnn": - # the fp4 cudnn graph will be shared for both mm and bmm, so - # here we need to get the 3d shape and stride including the - # batch dimension for both input and block scale tensors. - real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a) - real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b) - batch = real_a_shape[0] - expanded_a_descale_shape, expanded_a_descale_stride = ( - _expand_block_scale_tensor_shape(a_descale, batch) - ) - expanded_b_descale_shape, expanded_b_descale_stride = ( - _expand_block_scale_tensor_shape(b_descale, batch) - ) + # Auto-select the best backend + if backend == "auto": + backends = mm_fp4.suitable_auto_backends + else: + backends = [backend] - # build the fp4 cudnn graph - graph = build_plans_cudnn_fp4_gemm_graph( - real_a_shape, - real_a_stride, - real_b_shape, - real_b_stride, - expanded_a_descale_shape, - expanded_a_descale_stride, - expanded_b_descale_shape, - expanded_b_descale_stride, - cudnn.data_type.FP4_E2M1, - _torch_data_type_to_cudnn_data_type(out_dtype), - block_size, - a.device, - alpha is not None, - use_nvfp4, - ) + # At this point, backends contains a supported backend if specified, or all supported backends if backend='auto'. + # Lazy initialization of runners to avoid overhead of creating a new runner that will not be used + major, _ = get_compute_capability(a.device) - # execute the fp4 cudnn graph - execute_cudnn_gemm_fp4_graph( - graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer - ) - elif backend == "trtllm": - get_trtllm_fp4_gemm_module().trtllm_fp4_gemm( - a, - b.T, - a_descale, - b_descale.T, - alpha, - out, - use_8x4_sf_layout=use_8x4_sf_layout, - workspace_buffer=workspace_buffer, - ) - elif backend == "cutlass": - # cutlass require uint8 scale when a/b is fp4 packed uint8. - if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn: - a_descale = a_descale.view(torch.uint8) - if b.dtype == torch.uint8 and b_descale.dtype == torch.float8_e4m3fn: - b_descale = b_descale.view(torch.uint8) - - # Dispatch to the correct module based on device architecture - major, _ = get_compute_capability(a.device) - if major == 12: - gemm_module = get_gemm_sm120_module_cutlass_fp4() - else: - gemm_module = get_gemm_sm100_module_cutlass_fp4() + backend_to_runner_factory = { + "cudnn": lambda: _cudnn_gemm_fp4_runner(), + "trtllm": lambda: get_trtllm_fp4_gemm_module().trtllm_fp4_gemm_runner( + use_8x4_sf_layout + ), + "cutlass": lambda: get_cutlass_fp4_gemm_module(major).cutlass_fp4_gemm_runner(), + } + runners = [backend_to_runner_factory[cur_backend]() for cur_backend in backends] - gemm_module.cutlass_fp4_gemm( - a, b.T, a_descale, b_descale.T, alpha, out, workspace_buffer - ) + # Now we have a list of runners for desired & supported backends. + tuner = AutoTuner.get() + + a_tensor_index = 0 + a_scale_tensor_index = 2 + out_tensor_index = 6 + + def pad_up(x, y): + return ((x + y - 1) // y) * y + + tuning_config = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + (a_tensor_index,), + (0,), + get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2, + ), + ), + constraint_specs=( + ConstraintSpec( + a_scale_tensor_index, + 0, + lambda shapes: pad_up( + shapes[a_tensor_index][0], 8 if use_8x4_sf_layout else 128 + ), + ), + ConstraintSpec( + out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0] + ), + ), + ) + + inputs = [ + a, + b, + a_descale, + b_descale, + alpha, + out_dtype, + out, + block_size, + use_nvfp4, + workspace_buffer, + ] + runner, tactic = tuner.choose_one( + "fp4_gemm", + runners, + tuning_config, + inputs, + ) + + runner(inputs=inputs, tactic=tactic) return out @@ -2355,139 +2544,82 @@ def get_trtllm_fp4_gemm_module(): op = mod.build_and_load() setup_cubin_loader(mod.get_library_path()) - class TrtllmFp4GemmRunner(TunableRunner): - def __init__(self, use_8x4_sf_layout: bool = True): - self._fp4_gemm_runner = op.trtllm_gemm - self._use_8x4_sf_layout = use_8x4_sf_layout + def trtllm_fp4_gemm_runner(use_8x4_sf_layout: bool = True): + class TrtllmFp4GemmRunner(TunableRunner): + def __init__(self, use_8x4_sf_layout: bool = True): + self._fp4_gemm_runner = op.trtllm_gemm + self._use_8x4_sf_layout = use_8x4_sf_layout - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: - a_tensor_index = 1 - b_tensor_index = 2 - - a = profile.get_opt_shapes()[a_tensor_index] - b = profile.get_opt_shapes()[b_tensor_index] - m = a[0] - n = b[0] - k = a[1] * 2 - ( - workspace_buffer, - a, - b, - a_descale, - b_descale, - alpha, - out, - ) = inputs - type_e2m1 = 0 - type_bf16 = 2 - return list( - op.trtllm_gemm_tactics( - m, n, k, type_e2m1, type_bf16, self._use_8x4_sf_layout + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + a_tensor_index = 1 + b_tensor_index = 2 + + a = profile.get_opt_shapes()[a_tensor_index] + b = profile.get_opt_shapes()[b_tensor_index] + m = a[0] + n = b[0] + k = a[1] * 2 + ( + a, + b, + a_descale, + b_descale, + alpha, + _, + out, + _, + _, + workspace_buffer, + ) = inputs + type_e2m1 = 0 + type_bf16 = 2 + return list( + op.trtllm_gemm_tactics( + m, n, k, type_e2m1, type_bf16, self._use_8x4_sf_layout + ) ) - ) - def forward( - self, - inputs: List[torch.Tensor], - tactic: int = -1, - do_preparation: bool = False, - **kwargs, - ): - ( - workspace_buffer, - a, - b, - a_descale, - b_descale, - alpha, - out, - ) = inputs - op.trtllm_gemm( - workspace_buffer, - a, - b, - a_descale, - b_descale, - alpha, - out, - self._use_8x4_sf_layout, - tactic, - ) - return out - - @register_custom_op( - "flashinfer::trtllm_fp4_gemm", - mutates_args=(""), - ) - def trtllm_fp4_gemm( - a: torch.Tensor, - b: torch.Tensor, - a_descale: torch.Tensor, - b_descale: torch.Tensor, - alpha: torch.Tensor, - out: torch.Tensor, - use_8x4_sf_layout: bool, - workspace_buffer: torch.Tensor, - ): - tuner = AutoTuner.get() - - a_tensor_index = 1 - a_scale_tensor_index = 3 - out_tensor_index = 6 - - def pad_up(x, y): - return ((x + y - 1) // y) * y - - tuning_config = TuningConfig( - dynamic_tensor_specs=( - DynamicTensorSpec( - (a_tensor_index,), - (0,), - get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2, - ), - ), - constraint_specs=( - ConstraintSpec( - a_scale_tensor_index, - 0, - lambda shapes: pad_up( - shapes[a_tensor_index][0], 8 if use_8x4_sf_layout else 128 - ), - ), - ConstraintSpec( - out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0] - ), - ), - ) - - fp4_runner = TrtllmFp4GemmRunner(use_8x4_sf_layout) - - inputs = [ - workspace_buffer, - a, - b, - a_descale, - b_descale, - alpha, - out, - ] - _, tactic = tuner.choose_one( - "trtllm_fp4_gemm_8x4" if use_8x4_sf_layout else "trtllm_fp4_gemm_128x4", - [fp4_runner], - tuning_config, - inputs, - ) + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + do_preparation: bool = False, + **kwargs, + ): + ( + a, + b, + a_descale, + b_descale, + alpha, + _, + out, + _, + _, + workspace_buffer, + ) = inputs + self._fp4_gemm_runner( + workspace_buffer, + a, + b.T, + a_descale, + b_descale.T, + alpha, + out, + self._use_8x4_sf_layout, + tactic, + ) + return out - fp4_runner(inputs=inputs, tactic=tactic) + return TrtllmFp4GemmRunner(use_8x4_sf_layout) # Register the module return SimpleNamespace( - trtllm_fp4_gemm=trtllm_fp4_gemm, + trtllm_fp4_gemm_runner=trtllm_fp4_gemm_runner, ) diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 76689bab84..e323125efa 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -921,6 +921,13 @@ def backend_requirement( backends. Should accept the same arguments as the decorated function and return True if requirements are met, False otherwise. In the case where the kernel function does not have any specific backends, this can be decorated with @supported_compute_capability to specify the function's supported compute capabilities. + heuristic_func : callable, optional + A function that performs heuristic backend selection when backend is "auto". + Must be provided if backend is "auto". Does not do anything if backend is not "auto". + Should accept the same arguments as the decorated function. + Should return an ordered list of runnable backends with the most preferred backend first. + When decorated function is not autotuned, the first backend in the heuristic list will be run. + When decorated function is autotuned, the backends in the heuristic list will be autotuned over to find the best backend. Returns ------- @@ -1076,8 +1083,8 @@ def suitable_auto_backends(cc, *args, **kwargs): except ValueError: continue # If a heuristic function is provided, filter the suitable backends based on the heuristic function - if heuristic_func is not None: - suitable_backends = heuristic_func(suitable_backends, *args, **kwargs) + assert heuristic_func is not None, "Heuristic function must be provided" + suitable_backends = heuristic_func(suitable_backends, *args, **kwargs) if not suitable_backends: return False wrapper.suitable_auto_backends = suitable_backends diff --git a/tests/gemm/test_mm_fp4.py b/tests/gemm/test_mm_fp4.py index 9d7a7abbbd..cc85f6126a 100644 --- a/tests/gemm/test_mm_fp4.py +++ b/tests/gemm/test_mm_fp4.py @@ -12,16 +12,7 @@ from flashinfer.gemm.gemm_base import CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR -# TODO: Consdier splitting this function up for the various backends -@pytest.mark.parametrize("m", [1, 48, 128, 256, 512]) -@pytest.mark.parametrize("n", [128, 256, 512]) -@pytest.mark.parametrize("k", [128, 256, 512]) -@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("backend", ["trtllm", "cudnn", "cutlass"]) -@pytest.mark.parametrize("use_128x4_sf_layout", [False, True]) -@pytest.mark.parametrize("auto_tuning", [False, True]) -@pytest.mark.parametrize("fp4_type", ["nvfp4", "mxfp4", "mxfp4_alpha"]) -def test_mm_fp4( +def _test_mm_fp4( m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type ): use_nvfp4 = fp4_type == "nvfp4" @@ -40,10 +31,8 @@ def test_mm_fp4( pytest.skip("trtllm gemm does not support SM110/SM120/SM121 GPUs.") if not use_128x4_sf_layout and backend != "trtllm": pytest.skip("Skipping test for non-trtllm fp4 with use_128x4_sf_layout=False") - if auto_tuning and backend == "cudnn": - pytest.skip("Skipping test for cudnn fp4 with auto_tuning=True") - if not use_nvfp4 and backend != "cudnn": - pytest.skip("mx_fp4 is only supported for cudnn backend") + if not use_nvfp4 and backend not in ["cudnn", "auto"]: + pytest.skip("mx_fp4 is only supported for cudnn and auto backends") input = torch.randn([m, k], device="cuda", dtype=torch.bfloat16) mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16) @@ -105,5 +94,38 @@ def test_mm_fp4( pytest.fail(str(e)) +# TODO: Consdier splitting this function up for the various backends +@pytest.mark.parametrize("m", [1, 48, 128, 256, 512]) +@pytest.mark.parametrize("n", [128, 256, 512]) +@pytest.mark.parametrize("k", [128, 256, 512]) +@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("backend", ["trtllm", "cudnn", "cutlass"]) +@pytest.mark.parametrize("use_128x4_sf_layout", [False, True]) +@pytest.mark.parametrize("auto_tuning", [False, True]) +@pytest.mark.parametrize("fp4_type", ["nvfp4", "mxfp4", "mxfp4_alpha"]) +def test_mm_fp4( + m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type +): + # Non-auto backends + _test_mm_fp4( + m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type + ) + + +# Split tests for checking auto functionality +@pytest.mark.parametrize("m", [1, 48, 256, 512]) +@pytest.mark.parametrize("n", [256, 512]) +@pytest.mark.parametrize("k", [256, 512]) +@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("use_128x4_sf_layout", [True]) +@pytest.mark.parametrize("auto_tuning", [False, True]) +@pytest.mark.parametrize("fp4_type", ["nvfp4", "mxfp4", "mxfp4_alpha"]) +def test_mm_fp4_backend_auto( + m, n, k, res_dtype, use_128x4_sf_layout, auto_tuning, fp4_type +): + # Some test cases for auto backend. + _test_mm_fp4(m, n, k, res_dtype, "auto", use_128x4_sf_layout, auto_tuning, fp4_type) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index ebbda781fb..f8659c2e44 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -344,7 +344,27 @@ def _cutlass_check(x, backend): def _cudnn_check(x, backend): return x.shape[0] > 5 - @backend_requirement({"cutlass": _cutlass_check, "cudnn": _cudnn_check}) + # When using an auto backend, some heuristic function must exist + def _heuristic_func(suitable_backends, x, backend): + candidate_backends = None + if x.shape[0] > 5: + candidate_backends = ["cudnn", "cutlass"] + else: + candidate_backends = ["cutlass", "cudnn"] + + heuristic_backends = [] + for backend in candidate_backends: + if backend in suitable_backends: + heuristic_backends.append(backend) + return heuristic_backends + + @backend_requirement( + backend_checks={ + "cutlass": _cutlass_check, + "cudnn": _cudnn_check, + }, + heuristic_func=_heuristic_func, + ) def my_kernel(x, backend="auto"): backends = my_kernel.suitable_auto_backends if x.shape[0] > 5: