Conversation

@nvmbreughe (Contributor) commented Sep 29, 2025

📌 Description

This PR adds is_*_supported checks for backend and compute capability, implemented through decorators.

  1. This lets callers check support before running a kernel.
  2. It also wraps the original function so the support check runs automatically before execution.

Example of (1):
[Screenshot omitted: shows a support check being queried before running a kernel.]
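Roughly, the usage pattern for (1) looks like the following sketch. The decorator name, the capabilities parameter, and the is_backend_supported attribute come from this PR's diff; the kernel name, arguments, and the exact is_backend_supported signature are illustrative assumptions.

    # Hypothetical usage sketch; mm_fp4's arguments and the exact
    # is_backend_supported signature are assumptions, not copied from the diff.
    @supports_backends(["cudnn", "trtllm", "cutlass"], capabilities=["100", "103"])
    def mm_fp4(a, b, backend="cudnn"):
        ...

    # (1) Query support up front, without running the kernel:
    mm_fp4.is_backend_supported("trtllm")  # -> True or False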

Note that, to keep behavior uniform across APIs, this mechanism is limited to backend and compute-capability checks. Any other kind of check remains part of the regular function implementation.

Note further that compute-capability checks are optional; by default, all compute capabilities are assumed to pass unless specified otherwise.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes



On the diff:

    @supports_backends(
        ["cudnn", "trtllm", "cutlass"],

Member: Is this redundant with the declaration of the backend parameter?

    backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
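One way to avoid the duplication the reviewer points at (a hypothetical sketch, not something in this PR) is to derive the allowed backends from the parameter's Literal annotation:

    from typing import Literal, get_args, get_type_hints

    def backends_from_signature(func):
        # Read the allowed values out of the `backend` parameter's
        # Literal[...] annotation instead of repeating them in the decorator.
        hints = get_type_hints(func)
        return list(get_args(hints["backend"]))

    def mm_fp4(a, b, backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn"):
        ...

    backends_from_signature(mm_fp4)  # -> ['cudnn', 'trtllm', 'cutlass']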



On the diff:

    def supports_backends(
        backends, capabilities=None, anti_capabilities=None, capability_tensor_arg=None

Member: nit: I wonder if "cc" or "compute_capabilities" would be clearer than "capabilities" -- or did you mean to signal something more generic than compute capabilities?

On the diff:

    wrapper.is_backend_supported = is_backend_supported
    wrapper.is_problem_size_supported = is_problem_size_supported
    wrapper.__name__ = func.__name__
    wrapper.__doc__ = func.__doc__

Contributor: Can use functools.wraps for more robust wrapping and a standardized interface.
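A minimal sketch of the suggested change, assuming a simplified version of this PR's decorator (the default-backend handling and error message are illustrative):

    import functools

    def supports_backends(backends, capabilities=None, anti_capabilities=None,
                          capability_tensor_arg=None):
        def decorator(func):
            def is_backend_supported(backend):
                return backend in backends

            # functools.wraps copies __name__, __doc__, __module__, __qualname__,
            # and sets __wrapped__ in one step, instead of assigning them manually.
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                backend = kwargs.get("backend", backends[0])
                if not is_backend_supported(backend):
                    raise ValueError(f"backend {backend!r} is not supported")
                return func(*args, **kwargs)

            # Expose the check so callers can query support before running.
            wrapper.is_backend_supported = is_backend_supported
            return wrapper

        return decorator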

On the diff:

    backend = kwargs.get("backend")
    capability = None
    if capability_tensor_arg:
        tensor = kwargs.get(capability_tensor_arg)

Contributor: Is there a reason why we need capability_tensor_arg instead of finding torch.Tensors automatically and getting the capability from them? We could also assert that they're all on the same device.
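A rough sketch of the suggested automatic discovery (a hypothetical helper, not part of the PR):

    import torch

    def _infer_capability(args, kwargs):
        # Collect all CUDA tensors among the arguments, assert they share a
        # device, and read the compute capability from that device.
        tensors = [v for v in (*args, *kwargs.values())
                   if isinstance(v, torch.Tensor) and v.is_cuda]
        if not tensors:
            return None
        devices = {t.device for t in tensors}
        assert len(devices) == 1, f"tensors span multiple devices: {devices}"
        major, minor = torch.cuda.get_device_capability(tensors[0].device)
        return f"{major}{minor}"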

On the diff:

    capabilities=None,
    anti_capabilities=None,
    capability_tensor_arg=None,
    problem_size_check=None,

@nvjullin (Contributor) commented Oct 1, 2025: Add type hints, perhaps excluding problem_size_check if it's too tedious, or just use typing.Callable.
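For example, the signature could be annotated along these lines (a sketch; the parameter types are inferred from their names and defaults, so treat them as assumptions):

    from typing import Callable, List, Optional

    def supports_backends(
        backends: List[str],
        capabilities: Optional[List[str]] = None,
        anti_capabilities: Optional[List[str]] = None,
        capability_tensor_arg: Optional[str] = None,
        problem_size_check: Optional[Callable[..., bool]] = None,
    ):
        ...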

@nvjullin (Contributor) commented Oct 1, 2025

The checks currently live very far away from the implementation and updating them to be consistent with each other can eventually become a maintenance problem. The conditional checks are also quite tricky to get correct. For example, it's not easy to tell if the mxfp4 checks are correct.

    if not use_nvfp4 and block_size != 32:
        raise ValueError("mxfp4 supports block_size = 32.")

    if backend != "cudnn" and not use_nvfp4:
        raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")

Shouldn't the checks be reordered to avoid confusing error messages? (See the reordered sketch after this list.)

  1. User tries trtllm + block_size=16 and is rejected with "mxfp4 supports block_size = 32."
  2. User then tries trtllm + block_size=32 and is only then rejected with "Only cudnn FP4 GEMM supports mxfp4 quantization."
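For instance, swapping the two checks quoted above surfaces the backend restriction first, so the trtllm user gets the relevant message on the first try:

    # Reordered sketch of the two checks above: report the backend
    # restriction before the block_size restriction.
    if backend != "cudnn" and not use_nvfp4:
        raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")

    if not use_nvfp4 and block_size != 32:
        raise ValueError("mxfp4 supports block_size = 32.")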

Instead of having one top-level supports_backends, perhaps consider a two-level design:

  1. A local requirement decorator, with a requirement function written for each backend entrypoint
  2. A top-level backend_requirement decorator that composes those requirements

For example:

    def cudnn_gemm_fp4_requirement(
        # ...
    ):
        if (
            not use_nvfp4
            and _match_sm_version(a.device, ["120"])
            and cudnn.backend_version() < 91400
        ):
            raise LibraryError(
                "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0."
            )

        _check_cudnn_fp4_availability()
        # ...


    @requirement(cudnn_gemm_fp4_requirement, capability=["100", "101", "102"])
    def execute_cudnn_gemm_fp4_graph(
        # ...


    @backend_requirement({
        "cudnn": execute_cudnn_gemm_fp4_graph.requirement,
        "trtllm": # ...
    })
    def mm_fp4(
        # ...
This also means that all requirements are enforced to be local to the backend and won't affect each other.
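As a rough sketch of how the composing decorator could work, assuming each requirement function raises on failure (this is an illustration of the idea, not the proposed implementation):

    import functools

    def backend_requirement(requirements):
        # Dispatch to the per-backend requirement check, then run the kernel.
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                backend = kwargs.get("backend", "cudnn")
                check = requirements.get(backend)
                if check is None:
                    raise ValueError(f"unsupported backend: {backend!r}")
                check(*args, **kwargs)  # raises if this backend's requirements fail
                return func(*args, **kwargs)
            return wrapper
        return decorator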
