diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index 40823da388..b29e04873f 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -43,6 +43,7 @@ is_sm120a_supported, is_sm121a_supported, LibraryError, + supports_backends, ) CUDNN_AVAILABLE = False @@ -1644,7 +1645,7 @@ def _validate_fp8_output_dtype(dtype: torch.dtype): @functools.cache -def build_cudnn_gemm_block_scale_dequantize_graph( +def create_cudnn_execution_plans_fp4_gemm( a_shape, a_stride, b_shape, @@ -1741,12 +1742,49 @@ def build_cudnn_gemm_block_scale_dequantize_graph( # in older cuDNN versions, so we deselect it. if (alpha is not None) and (not _is_cublas_fp4_available_in_cudnn()): graph.deselect_engines(["eng0"]) - graph.check_support() - graph.build_plans() return graph +@functools.cache +def build_plans_cudnn_fp4_gemm_graph( + a_shape, + a_stride, + b_shape, + b_stride, + a_descale_shape, + a_descale_stride, + b_descale_shape, + b_descale_stride, + ab_type, + o_type, + block_size, + device, + alpha, + use_nvfp4, +): + graph = create_cudnn_execution_plans_fp4_gemm( + a_shape, + a_stride, + b_shape, + b_stride, + a_descale_shape, + a_descale_stride, + b_descale_shape, + b_descale_stride, + ab_type, + o_type, + block_size, + device, + alpha, + use_nvfp4, + ) + + graph.check_support() + graph.build_plans() + return graph + + def execute_cudnn_gemm_fp4_graph( graph, a, @@ -1999,6 +2037,127 @@ def _expand_block_scale_tensor_shape(block_scale_tensor, batch_size): return (tuple(block_scale_shape), tuple(block_scale_stride)) +def _check_mm_fp4_backend_supported( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + alpha: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + out: Optional[torch.Tensor] = None, + block_size: int = 16, + use_8x4_sf_layout: bool = False, + backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", + use_nvfp4: bool = True, +): + # Generic checks + ## pre-check the input tensor, block scale tensor and alpha tensor + if a.ndim != 2 or b.ndim != 2: + raise ValueError(f"mm_fp4 accepts 2d tensors, got {a.shape} and {b.shape}") + if a.shape[1] != b.shape[0]: + raise ValueError( + f"K dimension mismatch in mm_fp4. got a.shape[1] = {a.shape[1]}, b.shape[0] = {b.shape[0]}" + ) + if a.dtype not in {torch.uint8, _get_native_fp4_dtype()} or b.dtype not in { + torch.uint8, + _get_native_fp4_dtype(), + }: + raise ValueError( + f"a and b must have float4_e2m1fn_x2 packed into uint8. " + f"Got {a.dtype} and {b.dtype}." + ) + if a_descale.dtype not in { + torch.float8_e4m3fn, + torch.uint8, + } or b_descale.dtype not in {torch.float8_e4m3fn, torch.uint8}: + raise ValueError( + f"a_descale and b_descale must have float8_e4m3fnx2 packed into uint8. " + f"Got {a_descale.dtype} and {b_descale.dtype}." + ) + if alpha is not None and alpha.dtype != torch.float: + raise ValueError(f"alpha must be a float tensor, got {alpha.dtype}") + if alpha is not None and alpha.numel() != 1: + raise ValueError(f"alpha must be a scalar, got {alpha.numel()}") + + if out_dtype not in (torch.bfloat16, torch.float16): + raise ValueError( + f"Unsupported output dtype: {out_dtype}. " + f"Only torch.bfloat16 and torch.float16 are supported for FP4 GEMM operations." 
+ ) + + if use_nvfp4 and block_size != 16: + raise ValueError("nvfp4 only supports block_size = 16.") + if not use_nvfp4 and block_size != 32: + raise ValueError("mxfp4 supports block_size = 32.") + + if backend != "trtllm" and use_8x4_sf_layout: + raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.") + if backend != "cudnn" and not use_nvfp4: + raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.") + + # Backend specific checks + if backend == "cudnn": + if ( + not use_nvfp4 + and _match_sm_version(a.device, ["120"]) + and cudnn.backend_version() < 91400 + ): + raise LibraryError( + "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0." + ) + + _check_cudnn_fp4_availability() + + # the fp4 cudnn graph will be shared for both mm and bmm, so + # here we need to get the 3d shape and stride including the + # batch dimension for both input and block scale tensors. + real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a) + real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b) + batch = real_a_shape[0] + expanded_a_descale_shape, expanded_a_descale_stride = ( + _expand_block_scale_tensor_shape(a_descale, batch) + ) + expanded_b_descale_shape, expanded_b_descale_stride = ( + _expand_block_scale_tensor_shape(b_descale, batch) + ) + + # build the fp4 cudnn graph + graph = create_cudnn_execution_plans_fp4_gemm( + real_a_shape, + real_a_stride, + real_b_shape, + real_b_stride, + expanded_a_descale_shape, + expanded_a_descale_stride, + expanded_b_descale_shape, + expanded_b_descale_stride, + cudnn.data_type.FP4_E2M1, + _torch_data_type_to_cudnn_data_type(out_dtype), + block_size, + a.device, + alpha, + use_nvfp4, + ) + graph.check_support() + + elif backend == "trtllm": + if out_dtype != torch.bfloat16: + raise ValueError( + f"Unsupported output dtype: {out_dtype}. " + f"Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations." + ) + elif backend == "cutlass": + # No additional checks for cutlass + pass + return True + + +@supports_backends( + ["cudnn", "trtllm", "cutlass"], + anti_capabilities={"trtllm": ["110"]}, + capability_tensor_arg="a", + problem_size_check=_check_mm_fp4_backend_supported, +) def mm_fp4( a: torch.Tensor, b: torch.Tensor, @@ -2073,59 +2232,6 @@ def mm_fp4( >>> out.shape torch.Size([48, 256]) """ - # pre-check the input tensor, block scale tensor and alpha tensor - if a.ndim != 2 or b.ndim != 2: - raise ValueError(f"mm_fp4 accepts 2d tensors, got {a.shape} and {b.shape}") - if a.shape[1] != b.shape[0]: - raise ValueError( - f"K dimension mismatch in mm_fp4. got a.shape[1] = {a.shape[1]}, b.shape[0] = {b.shape[0]}" - ) - if a.dtype not in {torch.uint8, _get_native_fp4_dtype()} or b.dtype not in { - torch.uint8, - _get_native_fp4_dtype(), - }: - raise ValueError( - f"a and b must have float4_e2m1fn_x2 packed into uint8. " - f"Got {a.dtype} and {b.dtype}." - ) - if a_descale.dtype not in { - torch.float8_e4m3fn, - torch.uint8, - } or b_descale.dtype not in {torch.float8_e4m3fn, torch.uint8}: - raise ValueError( - f"a_descale and b_descale must have float8_e4m3fnx2 packed into uint8. " - f"Got {a_descale.dtype} and {b_descale.dtype}." 
- ) - if alpha is not None and alpha.dtype != torch.float: - raise ValueError(f"alpha must be a float tensor, got {alpha.dtype}") - if alpha is not None and alpha.numel() != 1: - raise ValueError(f"alpha must be a scalar, got {alpha.numel()}") - - if out_dtype not in (torch.bfloat16, torch.float16): - raise ValueError( - f"Unsupported output dtype: {out_dtype}. " - f"Only torch.bfloat16 and torch.float16 are supported for FP4 GEMM operations." - ) - - if use_nvfp4 and block_size != 16: - raise ValueError("nvfp4 only supports block_size = 16.") - if not use_nvfp4 and block_size != 32: - raise ValueError("mxfp4 supports block_size = 32.") - if backend != "trtllm" and use_8x4_sf_layout: - raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.") - if backend == "trtllm" and _match_sm_version(a.device, ["110"]): - raise ValueError("TRTLLM FP4 GEMM is not supported on SM110.") - if backend != "cudnn" and not use_nvfp4: - raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.") - if ( - backend == "cudnn" - and not use_nvfp4 - and _match_sm_version(a.device, ["120"]) - and cudnn.backend_version() < 91400 - ): - raise LibraryError( - "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0." - ) # allocate the output tensor if not provided if out is None: @@ -2140,8 +2246,6 @@ def mm_fp4( ) if backend == "cudnn": - _check_cudnn_fp4_availability() - # the fp4 cudnn graph will be shared for both mm and bmm, so # here we need to get the 3d shape and stride including the # batch dimension for both input and block scale tensors. @@ -2156,7 +2260,7 @@ def mm_fp4( ) # build the fp4 cudnn graph - graph = build_cudnn_gemm_block_scale_dequantize_graph( + graph = build_plans_cudnn_fp4_gemm_graph( real_a_shape, real_a_stride, real_b_shape, @@ -2178,12 +2282,6 @@ def mm_fp4( graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer ) elif backend == "trtllm": - if out_dtype != torch.bfloat16: - raise ValueError( - f"Unsupported output dtype: {out_dtype}. " - f"Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations." - ) - get_trtllm_fp4_gemm_module().trtllm_fp4_gemm( a, b.T, diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 06deaf55ea..75a108cfb1 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -18,6 +18,8 @@ import math from enum import Enum from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union +import inspect + import torch import torch.version @@ -60,6 +62,12 @@ class LibraryError(Exception): pass +class BackendSupportedError(Exception): + """Custom exception for backend-related errors.""" + + pass + + def _expand_5d(x: torch.Tensor, kv_layout: str) -> torch.Tensor: if x.ndim not in [4, 5]: raise ValueError("x must be 4D or 5D") @@ -740,3 +748,134 @@ def get_shuffle_matrix_sf_a_row_indices( row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m) return row_indices + + +def supports_backends( + backends, + capabilities=None, + anti_capabilities=None, + capability_tensor_arg=None, + problem_size_check=None, +): + """Decorator to validate backend and capability support for functions. + + This decorator wraps functions to ensure they are only called with supported + backends and optionally validates compute capabilities for specific tensor arguments. + + Args: + backends (list): List of supported backend strings (e.g., ['cudnn', 'trtllm', 'cutlass']). + capabilities (dict, optional): Dictionary mapping backends to lists of supported + capabilities. 
Format: {'backend': ['capability1', 'capability2']}.
+            If provided, only listed capabilities are allowed for each backend.
+        anti_capabilities (dict, optional): Dictionary mapping backends to lists of
+            unsupported capabilities. Format: {'backend': ['capability1', 'capability2']}.
+            Takes precedence over capabilities - if a capability is in anti_capabilities,
+            it will be rejected even if listed in capabilities.
+        capability_tensor_arg (str, optional): Name of the tensor argument to use for
+            automatic compute capability detection. The tensor's device compute capability
+            will be extracted and used for validation.
+        problem_size_check (callable, optional): Function to check if the problem size is supported.
+    Returns:
+        callable: Decorator function that wraps the target function.
+
+    Raises:
+        BackendSupportedError: If the function is called with an unsupported backend
+            or backend/capability combination.
+        ValueError: If capability_tensor_arg is specified but the tensor is None.
+
+    Added Attributes:
+        The decorated function gains three additional methods:
+        - is_compute_capability_supported(capability): Returns True if any backend supports the capability.
+        - is_backend_supported(backend, capability=None): Returns True if the specific
+          backend supports the given capability.
+        - is_problem_size_supported(*args, **kwargs): Returns True if the problem size is supported.
+
+    Example:
+        >>> def simple_problem_size_check(input_tensor, backend):
+        ...     return True
+        >>> @supports_backends(["cudnn", "trtllm"],
+        ...                    problem_size_check=simple_problem_size_check,
+        ...                    capabilities={'cudnn': ['100', '110', '102']},
+        ...                    anti_capabilities={"trtllm": ["110"]},
+        ...                    capability_tensor_arg='input_tensor')
+        ... def my_function(input_tensor, backend):
+        ...     return f"Processing on {backend}"
+
+        >>> # Check support without calling
+        >>> my_function.is_compute_capability_supported('110')  # True
+        >>> my_function.is_backend_supported('cudnn', '110')  # True
+        >>> my_function.is_backend_supported('trtllm', '110')  # False (anti-capability)
+
+        >>> # Function calls with validation
+        >>> tensor = torch.tensor([1, 2, 3]).cuda()  # Assume SM_100
+        >>> result = my_function(tensor, backend='cudnn')  # OK
+        >>> result = my_function(tensor, backend='cutlass')  # Raises BackendSupportedError
+    """
+
+    def decorator(func):
+        # Returns True if backend is supported; with capability, also checks if backend specifically supports it
+        def is_backend_supported(backend, capability=None):
+            if backend not in backends:
+                return False
+            if capability:
+                # Anti-capabilities take precedence
+                if anti_capabilities and backend in anti_capabilities:
+                    if capability in anti_capabilities[backend]:
+                        return False
+                # Capabilities allow-list
+                if capabilities and backend in capabilities:
+                    return capability in capabilities[backend]
+            return True
+
+        # Returns True if any backend supports this capability
+        def is_compute_capability_supported(capability):
+            for backend in backends:
+                if capabilities and backend in capabilities:
+                    if capability in capabilities[backend]:
+                        return True
+                elif anti_capabilities and backend in anti_capabilities:
+                    if capability not in anti_capabilities[backend]:
+                        return True
+                else:
+                    return True
+            return False
+
+        def is_problem_size_supported(*args, **kwargs):
+            if not problem_size_check:
+                raise ValueError(
+                    f"Problem size check function is not provided for {func.__name__}"
+                )
+            else:
+                return problem_size_check(*args, **kwargs)
+
+        def wrapper(*args, **kwargs):
+            backend = kwargs.get("backend")
+            capability = 
None + if capability_tensor_arg: + tensor = kwargs.get(capability_tensor_arg) + # When it wasn't provided as a keyword argument, try to get it from the arguments + if tensor is None: + params = list(inspect.signature(func).parameters) + idx = params.index(capability_tensor_arg) + tensor = args[idx] + if tensor is None: + raise ValueError("Invalid tensor on capability support check") + major, minor = get_compute_capability(tensor.device) + capability = f"{major * 10 + minor}" + if not is_backend_supported(backend, capability): + extra = f" with capability {capability}" if capability else "" + raise BackendSupportedError( + f"{func.__name__} does not support backend '{backend}'{extra}" + ) + if not is_problem_size_supported(*args, **kwargs): + raise ValueError(f"Problem size is not supported for {func.__name__}") + return func(*args, **kwargs) + + wrapper.is_compute_capability_supported = is_compute_capability_supported + wrapper.is_backend_supported = is_backend_supported + wrapper.is_problem_size_supported = is_problem_size_supported + wrapper.__name__ = func.__name__ + wrapper.__doc__ = func.__doc__ + return wrapper + + return decorator diff --git a/tests/GEMM/test_mm_fp4.py b/tests/GEMM/test_mm_fp4.py index ef2d3b9659..5a7eac178d 100644 --- a/tests/GEMM/test_mm_fp4.py +++ b/tests/GEMM/test_mm_fp4.py @@ -89,7 +89,7 @@ def test_mm_fp4( cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0) assert cos_sim > 0.97 - except LibraryError: + except LibraryError as e: # TODO: Remove this check once cuDNN backend version is updated to 9.14.0 if ( backend == "cudnn" @@ -100,7 +100,7 @@ def test_mm_fp4( "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0." ) else: - pytest.fail("Unexpected LibraryError") + pytest.fail(str(e)) if __name__ == "__main__":
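
Usage sketch (illustrative, not part of the patch): the decorator attaches is_backend_supported, is_compute_capability_supported, and is_problem_size_supported to mm_fp4, so callers can query support up front instead of catching BackendSupportedError at call time. The tensor shapes and scale layouts below are placeholder assumptions that only need to satisfy the dtype and shape checks in _check_mm_fp4_backend_supported; a CUDA device is assumed.

# Sketch: probing the decorated mm_fp4 without executing a GEMM.
import torch
from flashinfer.gemm import mm_fp4

# Backend/capability introspection added by supports_backends:
mm_fp4.is_backend_supported("trtllm", "110")   # False (listed in anti_capabilities)
mm_fp4.is_backend_supported("cudnn", "110")    # True
mm_fp4.is_compute_capability_supported("110")  # True (cudnn has no capability restriction)

# Problem-size validation reuses _check_mm_fp4_backend_supported; inputs are
# illustrative packed-fp4 tensors (one uint8 byte holds two e2m1 values).
m, n, k = 48, 256, 128
a = torch.randint(0, 256, (m, k // 2), dtype=torch.uint8, device="cuda")
b = torch.randint(0, 256, (k // 2, n), dtype=torch.uint8, device="cuda")
a_descale = torch.ones(m, k // 16, device="cuda").to(torch.float8_e4m3fn)
b_descale = torch.ones(n, k // 16, device="cuda").to(torch.float8_e4m3fn)
# cutlass path performs no extra backend-specific checks, so this stays lightweight.
mm_fp4.is_problem_size_supported(a, b, a_descale, b_descale, backend="cutlass")  # True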