Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions flashinfer/cute_dsl/blockscaled_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,9 @@ def __init__(
:param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster.
:type cluster_shape_mn: Tuple[int, int]
"""
assert sm_version == "sm_100", (
"sm_100 is the only supported SM version for cute-dsl backend."
supported_sm_versions = ["sm_100", "sm_103"]
assert sm_version in supported_sm_versions, (
f"{supported_sm_versions} are the only supported SM versions for cute-dsl backend, but encountered {sm_version}"
)

self.acc_dtype = cutlass.Float32
Expand Down Expand Up @@ -561,7 +562,12 @@ def __init__(
self.cta_sync_bar_id = 0
self.epilog_sync_bar_id = 1
self.tmem_ptr_sync_bar_id = 2
self.smem_capacity = utils.get_smem_capacity_in_bytes(sm_version)

# HACK "sm_103" doesn't work yet for the query
# https://github.com/NVIDIA/cutlass/blob/5016493cc0d8650d5b2f6d2c2751cf49bc217e86/python/CuTeDSL/cutlass/utils/smem_allocator.py#L19
# self.smem_capacity = utils.get_smem_capacity_in_bytes(sm_version)
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
Comment on lines +566 to +569
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While this hack works for the currently supported SM versions, it's a bit fragile. Hardcoding "sm_100" will be incorrect if another SM version is added in the future that is supported by get_smem_capacity_in_bytes. A more robust approach would be to only apply the fallback for sm_103. This also makes the intent clearer and the hack easier to remove. I'd also recommend adding a TODO to track this technical debt.

Suggested change
# HACK "sm_103" doesn't work yet for the query
# https://github.com/NVIDIA/cutlass/blob/5016493cc0d8650d5b2f6d2c2751cf49bc217e86/python/CuTeDSL/cutlass/utils/smem_allocator.py#L19
# self.smem_capacity = utils.get_smem_capacity_in_bytes(sm_version)
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
# TODO: Remove this workaround once nvidia-cutlass-dsl supports sm_103 for smem capacity queries.
# HACK: "sm_103" is not yet supported by get_smem_capacity_in_bytes. Using "sm_100" as a fallback.
# See: https://github.com/NVIDIA/cutlass/blob/5016493cc0d8650d5b2f6d2c2751cf49bc217e86/python/CuTeDSL/cutlass/utils/smem_allocator.py#L19
smem_query_version = "sm_100" if sm_version == "sm_103" else sm_version
self.smem_capacity = utils.get_smem_capacity_in_bytes(smem_query_version)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This get_smem_capacity_in_bytes issue has been reported to the cutlass team internally; if there is a quick turnaround for it, I'll just patch it later. Also, since this is just one kernel: if it works, it works — this is not wrong.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no problem with this at the moment, considering sm_100 and sm_103 should have the same shared memory size.


SM100_TMEM_CAPACITY_COLUMNS = 512
self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ einops
ninja
numpy
nvidia-cudnn-frontend>=1.13.0
nvidia-cutlass-dsl>=4.2.1
nvidia-cutlass-dsl>=4.3.1
nvidia-ml-py
packaging>=24.2
requests
Expand Down
10 changes: 6 additions & 4 deletions tests/gemm/test_cute_dsl_blockscaled_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,12 @@ def test_blockscaled_gemm_python_interface(
):
torch.manual_seed(42)
device = torch.device("cuda:0")
major, minor = torch.cuda.get_device_capability(device)

if not (major == 10 and minor == 0):
pytest.skip("Cute-dsl backend is only supported on SM100.")
device_ver = torch.cuda.get_device_capability(device)
supported_device_vers = [(10, 0), (10, 3)]
if device_ver not in supported_device_vers:
pytest.skip(
f"Cute-dsl backend is only supported on {supported_device_vers}, skipping {device_ver}."
)

l, m = lm
k, n = kn
Expand Down