Support checks PoC #1809
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Conversation
@supports_backends(
    ["cudnn", "trtllm", "cutlass"],
Is this redundant with the declaration of the backend parameter?

backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
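One way to remove the redundancy the comment points out is to derive the allowed backends from the `Literal` annotation itself. The sketch below is a hypothetical variant (the decorator name, `mm_example`, and its parameters are illustrative, not from the PR), assuming the wrapped function always declares a `backend` parameter with a `Literal` annotation and a default:

```python
import inspect
import typing
from typing import Literal


def supports_backends_from_annotation(func):
    # Hypothetical variant: read the allowed backends from the ``backend``
    # parameter's Literal annotation instead of repeating them in the decorator.
    allowed = typing.get_args(typing.get_type_hints(func)["backend"])
    default = inspect.signature(func).parameters["backend"].default

    def wrapper(*args, **kwargs):
        backend = kwargs.get("backend", default)
        if backend not in allowed:
            raise ValueError(f"backend must be one of {allowed}, got {backend!r}")
        return func(*args, **kwargs)

    wrapper.supported_backends = allowed
    return wrapper


@supports_backends_from_annotation
def mm_example(a, backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn"):
    return backend
```

With this shape, the list of backends lives in exactly one place (the signature), so the decorator and the annotation cannot drift apart. A limitation of the sketch: it only sees `backend` when passed as a keyword argument.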
flashinfer/utils.py
Outdated
def supports_backends(
    backends, capabilities=None, anti_capabilities=None, capability_tensor_arg=None
nit: I wonder if "cc" or "compute_capabilities" would be more clear than "capabilities" -- or did you mean to signal something more generic than compute capabilities?
wrapper.is_backend_supported = is_backend_supported
wrapper.is_problem_size_supported = is_problem_size_supported
wrapper.__name__ = func.__name__
wrapper.__doc__ = func.__doc__
Can use functools.wraps for more robust wrapping and standardized interface.
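To illustrate the suggestion: `functools.wraps` copies `__name__`, `__doc__`, `__module__`, and `__qualname__`, and also sets `__wrapped__` so tools like `inspect.signature` can see through the wrapper. A minimal sketch (the decorator body and the `gemm` example are simplified assumptions, not the PR's actual implementation):

```python
import functools


def supports_backends(backends):
    def decorator(func):
        @functools.wraps(func)  # copies __name__, __doc__, etc.; sets __wrapped__
        def wrapper(*args, **kwargs):
            backend = kwargs.get("backend")
            if backend is not None and backend not in backends:
                raise ValueError(f"Unsupported backend: {backend!r}")
            return func(*args, **kwargs)

        # Extra attributes can still be attached after wrapping.
        wrapper.is_backend_supported = lambda b: b in backends
        return wrapper

    return decorator


@supports_backends(["cudnn", "trtllm"])
def gemm(a, b, backend="cudnn"):
    """Multiply a by b on the chosen backend."""
    return backend
```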
backend = kwargs.get("backend")
capability = None
if capability_tensor_arg:
    tensor = kwargs.get(capability_tensor_arg)
Is there a reason why we need capability_tensor_arg instead of finding torch.Tensors automatically and getting the capability from them? We can also assert that they're all on the same device.
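The auto-discovery the comment suggests could look like the following sketch. To keep it self-contained it uses a duck-typed `FakeTensor` with a `.device` attribute in place of `torch.Tensor`; the helper name `infer_common_device` is hypothetical:

```python
def infer_common_device(args, kwargs):
    # Hypothetical helper: instead of naming one argument via
    # capability_tensor_arg, scan every argument for tensor-like objects
    # (anything with a .device attribute) and require device agreement.
    devices = {
        obj.device
        for obj in list(args) + list(kwargs.values())
        if hasattr(obj, "device")
    }
    if not devices:
        return None  # no tensor arguments: fall back to a default capability
    if len(devices) > 1:
        raise ValueError(f"all tensor arguments must be on the same device, got {sorted(devices)}")
    return devices.pop()  # the compute capability can then be read off this device


class FakeTensor:
    # Stand-in for torch.Tensor in this sketch.
    def __init__(self, device):
        self.device = device
```

This both removes the extra decorator parameter and gives the same-device assertion for free.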
capabilities=None,
anti_capabilities=None,
capability_tensor_arg=None,
problem_size_check=None,
Add type hints, perhaps excluding problem_size_check if it's too tedious, or just typing it as typing.Callable.
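One possible typed signature along the lines suggested, with `problem_size_check` kept as a bare `Callable`. The element type `str` for the capability lists is an assumption (the snippets above pass strings like "120"):

```python
from typing import Callable, Optional, Sequence


def supports_backends(
    backends: Sequence[str],
    capabilities: Optional[Sequence[str]] = None,
    anti_capabilities: Optional[Sequence[str]] = None,
    capability_tensor_arg: Optional[str] = None,
    problem_size_check: Optional[Callable[..., bool]] = None,
) -> Callable:
    ...  # body unchanged; only the signature is sketched here
```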
The checks currently live very far away from the implementation, and updating them to be consistent with each other can eventually become a maintenance problem. The conditional checks are also quite tricky to get correct. For example, it's not easy to tell if the mxfp4 checks are correct:

if not use_nvfp4 and block_size != 32:
    raise ValueError("mxfp4 supports block_size = 32.")
if backend != "cudnn" and not use_nvfp4:
    raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")

Shouldn't the checks be reordered to avoid confusing error messages?
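One possible reordering of those two checks, wrapped in a hypothetical helper (`check_fp4_quantization` is an illustrative name): decide first whether mxfp4 is supported at all for this backend, then validate the block size, so a user never sees a block-size error for a backend that cannot run mxfp4 anyway.

```python
def check_fp4_quantization(backend, use_nvfp4, block_size):
    if not use_nvfp4:  # mxfp4 path
        # Backend support is the more fundamental constraint, so check it first.
        if backend != "cudnn":
            raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")
        if block_size != 32:
            raise ValueError("mxfp4 supports block_size = 32.")
```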
Instead of having one top-level set of checks, each backend could declare its own requirement next to its implementation. For example:

def cudnn_gemm_fp4_requirement(
    # ...
):
    if (
        not use_nvfp4
        and _match_sm_version(a.device, ["120"])
        and cudnn.backend_version() < 91400
    ):
        raise LibraryError(
            "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0."
        )
    _check_cudnn_fp4_availability()
    # ...

@requirement(cudnn_gemm_fp4_requirement, capability=["100", "101", "102"])
def execute_cudnn_gemm_fp4_graph(
    # ...

@backend_requirement({
    "cudnn": execute_cudnn_gemm_fp4_graph.requirement,
    "trtllm": # ...
})
def mm_fp4(
    # ...

This also means that all requirements are enforced to be local to the backend and won't affect each other.
📌 Description
This PR adds is_*_supported checks for backend and compute capability, through decorators.

Example: (screenshot omitted)

Note that, to support uniformity across APIs, this is limited to backend and compute capability checks; any other type of check is still part of the regular function implementation.

Note further that compute capability checks are optional: by default, all compute capabilities are assumed to pass unless specified otherwise.
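The behavior described in the two notes above can be sketched as follows. This is a simplified guess at the decorator's surface (the real PR's signature also takes anti_capabilities, capability_tensor_arg, and problem_size_check); the key point shown is the default-pass rule for capabilities:

```python
def supports_backends(backends, capabilities=None):
    # Expose is_* query helpers on the wrapped function. When
    # ``capabilities`` is None, every compute capability is assumed
    # to pass, matching the note above.
    def decorator(func):
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        wrapper.is_backend_supported = lambda b: b in backends
        wrapper.is_compute_capability_supported = (
            (lambda cc: True)
            if capabilities is None
            else (lambda cc: cc in capabilities)
        )
        return wrapper

    return decorator


@supports_backends(["cudnn", "trtllm"], capabilities=["100", "120"])
def constrained(x):
    return x


@supports_backends(["cutlass"])  # no capabilities given: all assumed to pass
def unconstrained(x):
    return x
```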
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks

- I have installed pre-commit by running pip install pre-commit (or used my preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- New features or bug fixes include tests (e.g., unittest, etc.).
, etc.).Reviewer Notes