From a3249a4f561cce10e4c6e918475455fc7ee84965 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 6 Oct 2025 00:21:21 +0000 Subject: [PATCH 1/8] Initial plan From af9d39c9d7a07b826b6964ce19bacd165f7147b3 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 6 Oct 2025 00:34:17 +0000 Subject: [PATCH 2/8] Add dynamic CU count detection and helper functions Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- .../benchmark_all_gather_gemm_pull.py | 9 ++++- .../benchmark_all_gather_gemm_push.py | 9 ++++- examples/07_gemm_all_scatter/benchmark.py | 7 +++- .../08_gemm_atomics_all_reduce/benchmark.py | 13 +++++-- .../09_gemm_one_shot_all_reduce/benchmark.py | 11 +++++- .../benchmark.py | 11 +++++- .../benchmark.py | 14 ++++++- .../benchmark.py | 12 +++++- examples/benchmark/bench_all_shapes.py | 37 ++++++++++++++----- iris/hip.py | 36 ++++++++++++++++++ scripts/link_bandwidth.py | 9 ++++- 11 files changed, 141 insertions(+), 27 deletions(-) diff --git a/benchmark/examples/benchmark_all_gather_gemm_pull.py b/benchmark/examples/benchmark_all_gather_gemm_pull.py index 01640050..7964d775 100644 --- a/benchmark/examples/benchmark_all_gather_gemm_pull.py +++ b/benchmark/examples/benchmark_all_gather_gemm_pull.py @@ -17,6 +17,7 @@ import importlib.util from pathlib import Path import iris +from iris.hip import get_cu_count, get_default_gemm_sms current_dir = Path(__file__).parent file_path = (current_dir / "../../examples/14_all_gather_gemm/all_gather_gemm_pull.py").resolve() @@ -63,7 +64,7 @@ def parse_args(): parser.add_argument("--BLK_N", type=int, default=64, help="Block size N for the kernel") parser.add_argument("--BLK_K", type=int, default=64, help="Block size K for the kernel") parser.add_argument("--gsize_m", type=int, default=6, help="Group size in M dimension") - parser.add_argument("--num_sms", type=int, default=304, help="Number of SMs for the kernel") + parser.add_argument("--num_sms", type=int, default=None, help="Number of SMs for the kernel (default: auto-detected)") parser.add_argument("--num_ranks", type=int, default=8, help="Number of GPUs to run the example on.") @@ -138,7 +139,11 @@ def worker(rank: int, world_size: int, init_url: str, args: argparse.Namespace): A_local_iris = shmem.empty((M, K_local), dtype=datatype) A_local_iris.copy_(A_local) - num_sms = torch.cuda.get_device_properties(rank).multi_processor_count + # Use provided num_sms or auto-detect + if run_args["num_sms"] is None: + num_sms = torch.cuda.get_device_properties(rank).multi_processor_count + else: + num_sms = run_args["num_sms"] main_stream = torch.cuda.Stream() kernel_timing = { diff --git a/benchmark/examples/benchmark_all_gather_gemm_push.py b/benchmark/examples/benchmark_all_gather_gemm_push.py index 1984006d..f245be48 100644 --- a/benchmark/examples/benchmark_all_gather_gemm_push.py +++ b/benchmark/examples/benchmark_all_gather_gemm_push.py @@ -17,6 +17,7 @@ import importlib.util from pathlib import Path import iris +from iris.hip import get_cu_count, get_default_gemm_sms current_dir = Path(__file__).parent file_path = (current_dir / "../../examples/14_all_gather_gemm/all_gather_gemm_push.py").resolve() @@ -62,7 +63,7 @@ def parse_args(): parser.add_argument("--BLK_N", type=int, default=64, help="Block size N for GEMM computation") parser.add_argument("--BLK_K", type=int, default=64, help="Block size K for tiling") parser.add_argument("--gsize_m", type=int, default=6, help="Group size in M dimension") - parser.add_argument("--num_sms", type=int, default=304, help="Number of SMs for the kernel") + parser.add_argument("--num_sms", type=int, default=None, help="Number of SMs for the kernel (default: auto-detected)") parser.add_argument("--num_ranks", type=int, default=8, help="Number of GPUs to run the example on.") @@ -142,7 +143,11 @@ def worker(rank: int, world_size: int, init_url: str, args: argparse.Namespace): num_k_tiles = (K_local + run_args["BLK_K"] - 1) // run_args["BLK_K"] signal_flags_iris = shmem.zeros((world_size, world_size, num_m_tiles, num_k_tiles), dtype=torch.int32) - num_sms = torch.cuda.get_device_properties(rank).multi_processor_count + # Use provided num_sms or auto-detect + if run_args["num_sms"] is None: + num_sms = torch.cuda.get_device_properties(rank).multi_processor_count + else: + num_sms = run_args["num_sms"] main_stream = torch.cuda.Stream() kernel_timing = { diff --git a/examples/07_gemm_all_scatter/benchmark.py b/examples/07_gemm_all_scatter/benchmark.py index 6d872e61..8cc76e58 100755 --- a/examples/07_gemm_all_scatter/benchmark.py +++ b/examples/07_gemm_all_scatter/benchmark.py @@ -15,6 +15,7 @@ from matmul_wrapper import matmul import iris +from iris.hip import get_cu_count, get_default_gemm_sms from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set from examples.common.validation import validate_gemm @@ -52,7 +53,7 @@ def parse_args(): parser.add_argument("--BLK_K", type=int, default=64, help="Block size K") parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") - parser.add_argument("--gemm_sms", type=int, default=304, help="Number of SMs for persistent GEMM algorithm") + parser.add_argument("--gemm_sms", type=int, default=None, help="Number of SMs for persistent GEMM algorithm (default: auto-detected)") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -69,6 +70,10 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): world_size = shmem.get_num_ranks() cu_count = shmem.get_cu_count() + # Set default SM values if not provided + if args["gemm_sms"] is None: + args["gemm_sms"] = get_default_gemm_sms(algorithm="all_scatter") + # GEMM datatype = torch.float32 if args["datatype"] == "fp16": diff --git a/examples/08_gemm_atomics_all_reduce/benchmark.py b/examples/08_gemm_atomics_all_reduce/benchmark.py index 31de4fa3..f47c6e1f 100755 --- a/examples/08_gemm_atomics_all_reduce/benchmark.py +++ b/examples/08_gemm_atomics_all_reduce/benchmark.py @@ -19,6 +19,7 @@ ) import iris +from iris.hip import get_cu_count, get_default_gemm_sms from matmul_wrapper import matmul from examples.common.validation import validate_gemm @@ -68,10 +69,8 @@ def parse_args(): parser.add_argument("--kpack", type=int, default=2, help="K packing size") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") - # For All Scatter, use: 288 - # For One Shot, use: 256 - parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for GEMM") - parser.add_argument("--total_sms", type=int, default=304, help="Total number of SMs") + parser.add_argument("--gemm_sms", type=int, default=None, help="Number of SMs for GEMM (default: auto-detected)") + parser.add_argument("--total_sms", type=int, default=None, help="Total number of SMs (default: auto-detected)") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -88,6 +87,12 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): world_size = shmem.get_num_ranks() cu_count = shmem.get_cu_count() + # Set default SM values if not provided + if args["total_sms"] is None: + args["total_sms"] = cu_count + if args["gemm_sms"] is None: + args["gemm_sms"] = get_default_gemm_sms(algorithm="all_reduce") + # GEMM datatype = torch.float32 if args["datatype"] == "fp16": diff --git a/examples/09_gemm_one_shot_all_reduce/benchmark.py b/examples/09_gemm_one_shot_all_reduce/benchmark.py index 212bc857..d4bca52e 100755 --- a/examples/09_gemm_one_shot_all_reduce/benchmark.py +++ b/examples/09_gemm_one_shot_all_reduce/benchmark.py @@ -19,6 +19,7 @@ ) import iris +from iris.hip import get_cu_count, get_default_gemm_sms from matmul_wrapper import matmul from examples.common.validation import validate_gemm @@ -68,8 +69,8 @@ def parse_args(): parser.add_argument("--kpack", type=int, default=2, help="K packing size") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") - parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for GEMM") - parser.add_argument("--total_sms", type=int, default=304, help="Total number of SMs") + parser.add_argument("--gemm_sms", type=int, default=None, help="Number of SMs for GEMM (default: auto-detected)") + parser.add_argument("--total_sms", type=int, default=None, help="Total number of SMs (default: auto-detected)") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -84,6 +85,12 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): world_size = shmem.get_num_ranks() cu_count = shmem.get_cu_count() + # Set default SM values if not provided + if args["total_sms"] is None: + args["total_sms"] = cu_count + if args["gemm_sms"] is None: + args["gemm_sms"] = get_default_gemm_sms(algorithm="all_reduce") + # GEMM datatype = torch.float32 if args["datatype"] == "fp16": diff --git a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py index bb49bacb..549318ba 100755 --- a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py +++ b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py @@ -16,6 +16,7 @@ from examples.common.validation import validate_gemm import iris +from iris.hip import get_cu_count, get_default_gemm_sms from matmul_wrapper import matmul @@ -54,9 +55,9 @@ def parse_args(): parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") parser.add_argument( - "--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm" + "--gemm_sms", type=int, default=None, help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)" ) - parser.add_argument("--num_sms", type=int, default=304, help="Number of total SMs for gemm + scatter kernel") + parser.add_argument("--num_sms", type=int, default=None, help="Number of total SMs for gemm + scatter kernel (default: auto-detected)") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -72,6 +73,12 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): world_size = shmem.get_num_ranks() cu_count = shmem.get_cu_count() + # Set default SM values if not provided + if args["num_sms"] is None: + args["num_sms"] = cu_count + if args["gemm_sms"] is None: + args["gemm_sms"] = get_default_gemm_sms(algorithm="wg_specialized") + # GEMM datatype = torch.float32 if args["datatype"] == "fp16": diff --git a/examples/11_gemm_all_scatter_producer_consumer/benchmark.py b/examples/11_gemm_all_scatter_producer_consumer/benchmark.py index 8059c26f..c723bd3f 100755 --- a/examples/11_gemm_all_scatter_producer_consumer/benchmark.py +++ b/examples/11_gemm_all_scatter_producer_consumer/benchmark.py @@ -16,6 +16,7 @@ from examples.common.validation import validate_gemm import iris +from iris.hip import get_cu_count, get_default_gemm_sms from matmul_wrapper import matmul from gemm_all_scatter_producer_consumer import persistent_all_scatter @@ -55,9 +56,9 @@ def parse_args(): parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") parser.add_argument( - "--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm" + "--gemm_sms", type=int, default=None, help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)" ) - parser.add_argument("--comm_sms", type=int, default=48, help="Number of SMs for All-Scatter kernel") + parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs for All-Scatter kernel (default: auto-detected)") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -73,6 +74,15 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): world_size = shmem.get_num_ranks() cu_count = shmem.get_cu_count() + # Set default SM values if not provided + if args["gemm_sms"] is None: + args["gemm_sms"] = get_default_gemm_sms(algorithm="wg_specialized") + if args["comm_sms"] is None: + # comm_sms is the leftover: total - next_power_of_2 + import math + next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 + args["comm_sms"] = cu_count - next_pow2 + # GEMM datatype = torch.float32 if args["datatype"] == "fp16": diff --git a/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py b/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py index ea041361..0248c169 100755 --- a/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py +++ b/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py @@ -16,6 +16,7 @@ from examples.common.validation import validate_gemm import iris +from iris.hip import get_cu_count, get_default_gemm_sms from matmul_wrapper import matmul from gemm_all_scatter_bulk_synchronous import persistent_all_scatter @@ -55,9 +56,9 @@ def parse_args(): parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") parser.add_argument( - "--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm" + "--gemm_sms", type=int, default=None, help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)" ) - parser.add_argument("--comm_sms", type=int, default=256, help="Number of SMs for All-Scatter kernel") + parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs for All-Scatter kernel (default: auto-detected)") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -73,6 +74,13 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): world_size = shmem.get_num_ranks() cu_count = shmem.get_cu_count() + # Set default SM values if not provided + if args["gemm_sms"] is None: + args["gemm_sms"] = get_default_gemm_sms(algorithm="wg_specialized") + if args["comm_sms"] is None: + # For bulk synchronous, use same as gemm_sms + args["comm_sms"] = get_default_gemm_sms(algorithm="wg_specialized") + # GEMM datatype = torch.float32 if args["datatype"] == "fp16": diff --git a/examples/benchmark/bench_all_shapes.py b/examples/benchmark/bench_all_shapes.py index 0f5ce5b3..74a2233c 100644 --- a/examples/benchmark/bench_all_shapes.py +++ b/examples/benchmark/bench_all_shapes.py @@ -7,6 +7,13 @@ import argparse import json +try: + from iris.hip import get_cu_count + + CU_COUNT_AVAILABLE = True +except Exception: + CU_COUNT_AVAILABLE = False + def launch_sbatch( config, @@ -110,16 +117,28 @@ def main(hashes, config, sbatch_script_content, input_json, tiling_json, dry_run if mkn not in mkn_gemm_tiles: mkn_gemm_tiles[mkn] = {key: entry[key] for key in optional_keys if key in entry} - if config["partition"] is not None: - if "mi300" in config["partition"]: - print("Running on MI300") - gemm_sms = 304 - elif "mi250" in config["partition"]: - print("Running on MI250") - gemm_sms = 104 + # Determine gemm_sms based on available GPU or partition name + if CU_COUNT_AVAILABLE: + try: + gemm_sms = get_cu_count() + print(f"Auto-detected CU count: {gemm_sms}") + except Exception: + # Fall back to partition-based detection + gemm_sms = None else: - print("Assuming MI300") - gemm_sms = 304 + gemm_sms = None + + if gemm_sms is None: + if config["partition"] is not None: + if "mi300" in config["partition"]: + print("Running on MI300 (partition-based)") + gemm_sms = 304 + elif "mi250" in config["partition"]: + print("Running on MI250 (partition-based)") + gemm_sms = 104 + else: + print("Assuming MI300 (default)") + gemm_sms = 304 enable_algorithms = False enable_mkn = True diff --git a/iris/hip.py b/iris/hip.py index ba9e2051..73cead7b 100644 --- a/iris/hip.py +++ b/iris/hip.py @@ -90,6 +90,42 @@ def get_cu_count(device_id=None): return cu_count.value +def get_default_gemm_sms(device_id=None, algorithm="all_scatter"): + """ + Compute the default number of SMs to use for GEMM operations. + + Args: + device_id: Device ID to query (default: current device) + algorithm: Algorithm type - "all_scatter", "all_reduce", or "wg_specialized" + + Returns: + int: Recommended number of SMs for GEMM + + For "all_scatter": Returns total CU count + For "wg_specialized": Returns next smaller power of 2 + For "all_reduce": Returns total - (total - next_pow2) // 3 + """ + import math + + cu_count = get_cu_count(device_id) + + if algorithm == "all_scatter": + return cu_count + + # Compute next smaller power of 2 + next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 + + if algorithm == "wg_specialized": + return next_pow2 + elif algorithm == "all_reduce": + # Reserve ~1/3 of leftover CUs for communication + leftover = cu_count - next_pow2 + return cu_count - leftover // 3 + else: + # Default to all_scatter behavior + return cu_count + + def get_rocm_version(): major, minor = -1, -1 diff --git a/scripts/link_bandwidth.py b/scripts/link_bandwidth.py index 29faf5f1..18dd266f 100644 --- a/scripts/link_bandwidth.py +++ b/scripts/link_bandwidth.py @@ -3,6 +3,13 @@ import json +try: + from iris.hip import get_cu_count + + cu_count = get_cu_count() +except Exception: + cu_count = 304 # Default for MI300 + # Sample input (replace with file read if needed) config = { "world_size": 8, @@ -26,7 +33,7 @@ "kpack": 2, "heap_size": 8589934592, "gemm_sms": 48, - "total_sms": 304, + "total_sms": cu_count, "communication_block_size": 256, "communication_sms_multiplier": 1, "M": 8192, From 93816776c2db8421ba48a56033a612300be48199 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 6 Oct 2025 00:49:40 +0000 Subject: [PATCH 3/8] Inline SM calculation logic at call sites instead of using helper function Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- .../benchmark_all_gather_gemm_pull.py | 6 ++-- .../benchmark_all_gather_gemm_push.py | 6 ++-- examples/07_gemm_all_scatter/benchmark.py | 12 +++++-- .../08_gemm_atomics_all_reduce/benchmark.py | 9 +++-- .../09_gemm_one_shot_all_reduce/benchmark.py | 9 +++-- .../benchmark.py | 19 +++++++--- .../benchmark.py | 17 ++++++--- .../benchmark.py | 20 ++++++++--- iris/hip.py | 36 ------------------- 9 files changed, 74 insertions(+), 60 deletions(-) diff --git a/benchmark/examples/benchmark_all_gather_gemm_pull.py b/benchmark/examples/benchmark_all_gather_gemm_pull.py index 7964d775..72b9c82d 100644 --- a/benchmark/examples/benchmark_all_gather_gemm_pull.py +++ b/benchmark/examples/benchmark_all_gather_gemm_pull.py @@ -17,7 +17,7 @@ import importlib.util from pathlib import Path import iris -from iris.hip import get_cu_count, get_default_gemm_sms +from iris.hip import get_cu_count current_dir = Path(__file__).parent file_path = (current_dir / "../../examples/14_all_gather_gemm/all_gather_gemm_pull.py").resolve() @@ -64,7 +64,9 @@ def parse_args(): parser.add_argument("--BLK_N", type=int, default=64, help="Block size N for the kernel") parser.add_argument("--BLK_K", type=int, default=64, help="Block size K for the kernel") parser.add_argument("--gsize_m", type=int, default=6, help="Group size in M dimension") - parser.add_argument("--num_sms", type=int, default=None, help="Number of SMs for the kernel (default: auto-detected)") + parser.add_argument( + "--num_sms", type=int, default=None, help="Number of SMs for the kernel (default: auto-detected)" + ) parser.add_argument("--num_ranks", type=int, default=8, help="Number of GPUs to run the example on.") diff --git a/benchmark/examples/benchmark_all_gather_gemm_push.py b/benchmark/examples/benchmark_all_gather_gemm_push.py index f245be48..492a2639 100644 --- a/benchmark/examples/benchmark_all_gather_gemm_push.py +++ b/benchmark/examples/benchmark_all_gather_gemm_push.py @@ -17,7 +17,7 @@ import importlib.util from pathlib import Path import iris -from iris.hip import get_cu_count, get_default_gemm_sms +from iris.hip import get_cu_count current_dir = Path(__file__).parent file_path = (current_dir / "../../examples/14_all_gather_gemm/all_gather_gemm_push.py").resolve() @@ -63,7 +63,9 @@ def parse_args(): parser.add_argument("--BLK_N", type=int, default=64, help="Block size N for GEMM computation") parser.add_argument("--BLK_K", type=int, default=64, help="Block size K for tiling") parser.add_argument("--gsize_m", type=int, default=6, help="Group size in M dimension") - parser.add_argument("--num_sms", type=int, default=None, help="Number of SMs for the kernel (default: auto-detected)") + parser.add_argument( + "--num_sms", type=int, default=None, help="Number of SMs for the kernel (default: auto-detected)" + ) parser.add_argument("--num_ranks", type=int, default=8, help="Number of GPUs to run the example on.") diff --git a/examples/07_gemm_all_scatter/benchmark.py b/examples/07_gemm_all_scatter/benchmark.py index 8cc76e58..dc230931 100755 --- a/examples/07_gemm_all_scatter/benchmark.py +++ b/examples/07_gemm_all_scatter/benchmark.py @@ -15,7 +15,7 @@ from matmul_wrapper import matmul import iris -from iris.hip import get_cu_count, get_default_gemm_sms +from iris.hip import get_cu_count from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set from examples.common.validation import validate_gemm @@ -53,7 +53,12 @@ def parse_args(): parser.add_argument("--BLK_K", type=int, default=64, help="Block size K") parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") - parser.add_argument("--gemm_sms", type=int, default=None, help="Number of SMs for persistent GEMM algorithm (default: auto-detected)") + parser.add_argument( + "--gemm_sms", + type=int, + default=None, + help="Number of SMs for persistent GEMM algorithm (default: auto-detected)", + ) parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -72,7 +77,8 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Set default SM values if not provided if args["gemm_sms"] is None: - args["gemm_sms"] = get_default_gemm_sms(algorithm="all_scatter") + # For all_scatter: use total CU count + args["gemm_sms"] = cu_count # GEMM datatype = torch.float32 diff --git a/examples/08_gemm_atomics_all_reduce/benchmark.py b/examples/08_gemm_atomics_all_reduce/benchmark.py index f47c6e1f..0683dc97 100755 --- a/examples/08_gemm_atomics_all_reduce/benchmark.py +++ b/examples/08_gemm_atomics_all_reduce/benchmark.py @@ -19,7 +19,7 @@ ) import iris -from iris.hip import get_cu_count, get_default_gemm_sms +from iris.hip import get_cu_count from matmul_wrapper import matmul from examples.common.validation import validate_gemm @@ -91,7 +91,12 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): if args["total_sms"] is None: args["total_sms"] = cu_count if args["gemm_sms"] is None: - args["gemm_sms"] = get_default_gemm_sms(algorithm="all_reduce") + # For all_reduce: reserve ~1/3 of leftover CUs for communication + import math + + next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 + leftover = cu_count - next_pow2 + args["gemm_sms"] = cu_count - leftover // 3 # GEMM datatype = torch.float32 diff --git a/examples/09_gemm_one_shot_all_reduce/benchmark.py b/examples/09_gemm_one_shot_all_reduce/benchmark.py index d4bca52e..352c3638 100755 --- a/examples/09_gemm_one_shot_all_reduce/benchmark.py +++ b/examples/09_gemm_one_shot_all_reduce/benchmark.py @@ -19,7 +19,7 @@ ) import iris -from iris.hip import get_cu_count, get_default_gemm_sms +from iris.hip import get_cu_count from matmul_wrapper import matmul from examples.common.validation import validate_gemm @@ -89,7 +89,12 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): if args["total_sms"] is None: args["total_sms"] = cu_count if args["gemm_sms"] is None: - args["gemm_sms"] = get_default_gemm_sms(algorithm="all_reduce") + # For all_reduce: reserve ~1/3 of leftover CUs for communication + import math + + next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 + leftover = cu_count - next_pow2 + args["gemm_sms"] = cu_count - leftover // 3 # GEMM datatype = torch.float32 diff --git a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py index 549318ba..1c9b94e9 100755 --- a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py +++ b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py @@ -16,7 +16,7 @@ from examples.common.validation import validate_gemm import iris -from iris.hip import get_cu_count, get_default_gemm_sms +from iris.hip import get_cu_count from matmul_wrapper import matmul @@ -55,9 +55,17 @@ def parse_args(): parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") parser.add_argument( - "--gemm_sms", type=int, default=None, help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)" + "--gemm_sms", + type=int, + default=None, + help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)", + ) + parser.add_argument( + "--num_sms", + type=int, + default=None, + help="Number of total SMs for gemm + scatter kernel (default: auto-detected)", ) - parser.add_argument("--num_sms", type=int, default=None, help="Number of total SMs for gemm + scatter kernel (default: auto-detected)") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -77,7 +85,10 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): if args["num_sms"] is None: args["num_sms"] = cu_count if args["gemm_sms"] is None: - args["gemm_sms"] = get_default_gemm_sms(algorithm="wg_specialized") + # For wg_specialized: use next smaller power of 2 + import math + + args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 # GEMM datatype = torch.float32 diff --git a/examples/11_gemm_all_scatter_producer_consumer/benchmark.py b/examples/11_gemm_all_scatter_producer_consumer/benchmark.py index c723bd3f..27dda8ff 100755 --- a/examples/11_gemm_all_scatter_producer_consumer/benchmark.py +++ b/examples/11_gemm_all_scatter_producer_consumer/benchmark.py @@ -16,7 +16,7 @@ from examples.common.validation import validate_gemm import iris -from iris.hip import get_cu_count, get_default_gemm_sms +from iris.hip import get_cu_count from matmul_wrapper import matmul from gemm_all_scatter_producer_consumer import persistent_all_scatter @@ -56,9 +56,14 @@ def parse_args(): parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") parser.add_argument( - "--gemm_sms", type=int, default=None, help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)" + "--gemm_sms", + type=int, + default=None, + help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)", + ) + parser.add_argument( + "--comm_sms", type=int, default=None, help="Number of SMs for All-Scatter kernel (default: auto-detected)" ) - parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs for All-Scatter kernel (default: auto-detected)") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -76,10 +81,14 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Set default SM values if not provided if args["gemm_sms"] is None: - args["gemm_sms"] = get_default_gemm_sms(algorithm="wg_specialized") + # For wg_specialized: use next smaller power of 2 + import math + + args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 if args["comm_sms"] is None: # comm_sms is the leftover: total - next_power_of_2 import math + next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 args["comm_sms"] = cu_count - next_pow2 diff --git a/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py b/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py index 0248c169..71bf4ff4 100755 --- a/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py +++ b/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py @@ -16,7 +16,7 @@ from examples.common.validation import validate_gemm import iris -from iris.hip import get_cu_count, get_default_gemm_sms +from iris.hip import get_cu_count from matmul_wrapper import matmul from gemm_all_scatter_bulk_synchronous import persistent_all_scatter @@ -56,9 +56,14 @@ def parse_args(): parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") parser.add_argument( - "--gemm_sms", type=int, default=None, help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)" + "--gemm_sms", + type=int, + default=None, + help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)", + ) + parser.add_argument( + "--comm_sms", type=int, default=None, help="Number of SMs for All-Scatter kernel (default: auto-detected)" ) - parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs for All-Scatter kernel (default: auto-detected)") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -76,10 +81,15 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Set default SM values if not provided if args["gemm_sms"] is None: - args["gemm_sms"] = get_default_gemm_sms(algorithm="wg_specialized") + # For wg_specialized: use next smaller power of 2 + import math + + args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 if args["comm_sms"] is None: # For bulk synchronous, use same as gemm_sms - args["comm_sms"] = get_default_gemm_sms(algorithm="wg_specialized") + import math + + args["comm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 # GEMM datatype = torch.float32 diff --git a/iris/hip.py b/iris/hip.py index 73cead7b..ba9e2051 100644 --- a/iris/hip.py +++ b/iris/hip.py @@ -90,42 +90,6 @@ def get_cu_count(device_id=None): return cu_count.value -def get_default_gemm_sms(device_id=None, algorithm="all_scatter"): - """ - Compute the default number of SMs to use for GEMM operations. - - Args: - device_id: Device ID to query (default: current device) - algorithm: Algorithm type - "all_scatter", "all_reduce", or "wg_specialized" - - Returns: - int: Recommended number of SMs for GEMM - - For "all_scatter": Returns total CU count - For "wg_specialized": Returns next smaller power of 2 - For "all_reduce": Returns total - (total - next_pow2) // 3 - """ - import math - - cu_count = get_cu_count(device_id) - - if algorithm == "all_scatter": - return cu_count - - # Compute next smaller power of 2 - next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 - - if algorithm == "wg_specialized": - return next_pow2 - elif algorithm == "all_reduce": - # Reserve ~1/3 of leftover CUs for communication - leftover = cu_count - next_pow2 - return cu_count - leftover // 3 - else: - # Default to all_scatter behavior - return cu_count - - def get_rocm_version(): major, minor = -1, -1 From 0de4d43ff82c577441d61461fd98b95dc338f112 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 8 Oct 2025 21:01:56 +0000 Subject: [PATCH 4/8] Fix all_reduce SM calculation: use next power of 2 for GEMM, rest for comm Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- examples/08_gemm_atomics_all_reduce/benchmark.py | 6 ++---- examples/09_gemm_one_shot_all_reduce/benchmark.py | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/examples/08_gemm_atomics_all_reduce/benchmark.py b/examples/08_gemm_atomics_all_reduce/benchmark.py index 0683dc97..2ab121d2 100755 --- a/examples/08_gemm_atomics_all_reduce/benchmark.py +++ b/examples/08_gemm_atomics_all_reduce/benchmark.py @@ -91,12 +91,10 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): if args["total_sms"] is None: args["total_sms"] = cu_count if args["gemm_sms"] is None: - # For all_reduce: reserve ~1/3 of leftover CUs for communication + # For all_reduce: use next smaller power of 2, rest for communication import math - next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 - leftover = cu_count - next_pow2 - args["gemm_sms"] = cu_count - leftover // 3 + args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 # GEMM datatype = torch.float32 diff --git a/examples/09_gemm_one_shot_all_reduce/benchmark.py b/examples/09_gemm_one_shot_all_reduce/benchmark.py index 352c3638..be18d45f 100755 --- a/examples/09_gemm_one_shot_all_reduce/benchmark.py +++ b/examples/09_gemm_one_shot_all_reduce/benchmark.py @@ -89,12 +89,10 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): if args["total_sms"] is None: args["total_sms"] = cu_count if args["gemm_sms"] is None: - # For all_reduce: reserve ~1/3 of leftover CUs for communication + # For all_reduce: use next smaller power of 2, rest for communication import math - next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 - leftover = cu_count - next_pow2 - args["gemm_sms"] = cu_count - leftover // 3 + args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 # GEMM datatype = torch.float32 From 00123679ec81a1d5b9e3d875d746fcaccfff90e1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 8 Oct 2025 21:30:09 +0000 Subject: [PATCH 5/8] Use torch.cuda API instead of iris.hip.get_cu_count, move math import to top Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- examples/07_gemm_all_scatter/benchmark.py | 3 +-- .../08_gemm_atomics_all_reduce/benchmark.py | 6 ++---- .../09_gemm_one_shot_all_reduce/benchmark.py | 6 ++---- .../benchmark.py | 6 ++---- .../benchmark.py | 8 ++------ .../benchmark.py | 8 ++------ examples/benchmark/bench_all_shapes.py | 20 +++++++------------ scripts/link_bandwidth.py | 8 +++++--- 8 files changed, 23 insertions(+), 42 deletions(-) diff --git a/examples/07_gemm_all_scatter/benchmark.py b/examples/07_gemm_all_scatter/benchmark.py index dc230931..67cfc251 100755 --- a/examples/07_gemm_all_scatter/benchmark.py +++ b/examples/07_gemm_all_scatter/benchmark.py @@ -15,7 +15,6 @@ from matmul_wrapper import matmul import iris -from iris.hip import get_cu_count from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set from examples.common.validation import validate_gemm @@ -73,11 +72,11 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() - cu_count = shmem.get_cu_count() # Set default SM values if not provided if args["gemm_sms"] is None: # For all_scatter: use total CU count + cu_count = torch.cuda.get_device_properties(rank).multi_processor_count args["gemm_sms"] = cu_count # GEMM diff --git a/examples/08_gemm_atomics_all_reduce/benchmark.py b/examples/08_gemm_atomics_all_reduce/benchmark.py index 2ab121d2..7b12431b 100755 --- a/examples/08_gemm_atomics_all_reduce/benchmark.py +++ b/examples/08_gemm_atomics_all_reduce/benchmark.py @@ -11,6 +11,7 @@ import os import argparse import json +import math from examples.common.utils import ( JSONWriter, @@ -19,7 +20,6 @@ ) import iris -from iris.hip import get_cu_count from matmul_wrapper import matmul from examples.common.validation import validate_gemm @@ -85,15 +85,13 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() - cu_count = shmem.get_cu_count() # Set default SM values if not provided + cu_count = torch.cuda.get_device_properties(rank).multi_processor_count if args["total_sms"] is None: args["total_sms"] = cu_count if args["gemm_sms"] is None: # For all_reduce: use next smaller power of 2, rest for communication - import math - args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 # GEMM diff --git a/examples/09_gemm_one_shot_all_reduce/benchmark.py b/examples/09_gemm_one_shot_all_reduce/benchmark.py index be18d45f..406bddd8 100755 --- a/examples/09_gemm_one_shot_all_reduce/benchmark.py +++ b/examples/09_gemm_one_shot_all_reduce/benchmark.py @@ -11,6 +11,7 @@ import os import argparse import json +import math from examples.common.utils import ( JSONWriter, @@ -19,7 +20,6 @@ ) import iris -from iris.hip import get_cu_count from matmul_wrapper import matmul from examples.common.validation import validate_gemm @@ -83,15 +83,13 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() - cu_count = shmem.get_cu_count() # Set default SM values if not provided + cu_count = torch.cuda.get_device_properties(rank).multi_processor_count if args["total_sms"] is None: args["total_sms"] = cu_count if args["gemm_sms"] is None: # For all_reduce: use next smaller power of 2, rest for communication - import math - args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 # GEMM diff --git a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py index 1c9b94e9..036410db 100755 --- a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py +++ b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py @@ -11,12 +11,12 @@ import os import argparse import json +import math from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set from examples.common.validation import validate_gemm import iris -from iris.hip import get_cu_count from matmul_wrapper import matmul @@ -79,15 +79,13 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() - cu_count = shmem.get_cu_count() # Set default SM values if not provided + cu_count = torch.cuda.get_device_properties(rank).multi_processor_count if args["num_sms"] is None: args["num_sms"] = cu_count if args["gemm_sms"] is None: # For wg_specialized: use next smaller power of 2 - import math - args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 # GEMM diff --git a/examples/11_gemm_all_scatter_producer_consumer/benchmark.py b/examples/11_gemm_all_scatter_producer_consumer/benchmark.py index 27dda8ff..8677d4fa 100755 --- a/examples/11_gemm_all_scatter_producer_consumer/benchmark.py +++ b/examples/11_gemm_all_scatter_producer_consumer/benchmark.py @@ -11,12 +11,12 @@ import os import argparse import json +import math from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set from examples.common.validation import validate_gemm import iris -from iris.hip import get_cu_count from matmul_wrapper import matmul from gemm_all_scatter_producer_consumer import persistent_all_scatter @@ -77,18 +77,14 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() - cu_count = shmem.get_cu_count() # Set default SM values if not provided + cu_count = torch.cuda.get_device_properties(rank).multi_processor_count if args["gemm_sms"] is None: # For wg_specialized: use next smaller power of 2 - import math - args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 if args["comm_sms"] is None: # comm_sms is the leftover: total - next_power_of_2 - import math - next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 args["comm_sms"] = cu_count - next_pow2 diff --git a/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py b/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py index 71bf4ff4..de2f855e 100755 --- a/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py +++ b/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py @@ -11,12 +11,12 @@ import os import argparse import json +import math from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set from examples.common.validation import validate_gemm import iris -from iris.hip import get_cu_count from matmul_wrapper import matmul from gemm_all_scatter_bulk_synchronous import persistent_all_scatter @@ -77,18 +77,14 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() - cu_count = shmem.get_cu_count() # Set default SM values if not provided + cu_count = torch.cuda.get_device_properties(rank).multi_processor_count if args["gemm_sms"] is None: # For wg_specialized: use next smaller power of 2 - import math - args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 if args["comm_sms"] is None: # For bulk synchronous, use same as gemm_sms - import math - args["comm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 # GEMM diff --git a/examples/benchmark/bench_all_shapes.py b/examples/benchmark/bench_all_shapes.py index 74a2233c..adb15657 100644 --- a/examples/benchmark/bench_all_shapes.py +++ b/examples/benchmark/bench_all_shapes.py @@ -6,13 +6,7 @@ from datetime import datetime import argparse import json - -try: - from iris.hip import get_cu_count - - CU_COUNT_AVAILABLE = True -except Exception: - CU_COUNT_AVAILABLE = False +import torch def launch_sbatch( @@ -118,14 +112,14 @@ def main(hashes, config, sbatch_script_content, input_json, tiling_json, dry_run mkn_gemm_tiles[mkn] = {key: entry[key] for key in optional_keys if key in entry} # Determine gemm_sms based on available GPU or partition name - if CU_COUNT_AVAILABLE: - try: - gemm_sms = get_cu_count() + try: + if torch.cuda.is_available(): + gemm_sms = torch.cuda.get_device_properties(0).multi_processor_count print(f"Auto-detected CU count: {gemm_sms}") - except Exception: - # Fall back to partition-based detection + else: gemm_sms = None - else: + except Exception: + # Fall back to partition-based detection gemm_sms = None if gemm_sms is None: diff --git a/scripts/link_bandwidth.py b/scripts/link_bandwidth.py index 18dd266f..1c79d250 100644 --- a/scripts/link_bandwidth.py +++ b/scripts/link_bandwidth.py @@ -2,11 +2,13 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. import json +import torch try: - from iris.hip import get_cu_count - - cu_count = get_cu_count() + if torch.cuda.is_available(): + cu_count = torch.cuda.get_device_properties(0).multi_processor_count + else: + cu_count = 304 # Default for MI300 except Exception: cu_count = 304 # Default for MI300 From 95c653b9e6b628c1552ab2e436977d47a7ada01a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 8 Oct 2025 21:47:33 +0000 Subject: [PATCH 6/8] Update run_args dict with auto-detected num_sms for correct logging Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- benchmark/examples/benchmark_all_gather_gemm_pull.py | 1 + benchmark/examples/benchmark_all_gather_gemm_push.py | 1 + 2 files changed, 2 insertions(+) diff --git a/benchmark/examples/benchmark_all_gather_gemm_pull.py b/benchmark/examples/benchmark_all_gather_gemm_pull.py index 72b9c82d..53a07e54 100644 --- a/benchmark/examples/benchmark_all_gather_gemm_pull.py +++ b/benchmark/examples/benchmark_all_gather_gemm_pull.py @@ -144,6 +144,7 @@ def worker(rank: int, world_size: int, init_url: str, args: argparse.Namespace): # Use provided num_sms or auto-detect if run_args["num_sms"] is None: num_sms = torch.cuda.get_device_properties(rank).multi_processor_count + run_args["num_sms"] = num_sms else: num_sms = run_args["num_sms"] diff --git a/benchmark/examples/benchmark_all_gather_gemm_push.py b/benchmark/examples/benchmark_all_gather_gemm_push.py index 492a2639..df586db9 100644 --- a/benchmark/examples/benchmark_all_gather_gemm_push.py +++ b/benchmark/examples/benchmark_all_gather_gemm_push.py @@ -148,6 +148,7 @@ def worker(rank: int, world_size: int, init_url: str, args: argparse.Namespace): # Use provided num_sms or auto-detect if run_args["num_sms"] is None: num_sms = torch.cuda.get_device_properties(rank).multi_processor_count + run_args["num_sms"] = num_sms else: num_sms = run_args["num_sms"] From 71404e41397272b2957c247b3348e321ebc45335 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 8 Oct 2025 21:50:51 +0000 Subject: [PATCH 7/8] Remove unused get_cu_count imports from benchmark files Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- benchmark/examples/benchmark_all_gather_gemm_pull.py | 1 - benchmark/examples/benchmark_all_gather_gemm_push.py | 1 - 2 files changed, 2 deletions(-) diff --git a/benchmark/examples/benchmark_all_gather_gemm_pull.py b/benchmark/examples/benchmark_all_gather_gemm_pull.py index 53a07e54..0307fa57 100644 --- a/benchmark/examples/benchmark_all_gather_gemm_pull.py +++ b/benchmark/examples/benchmark_all_gather_gemm_pull.py @@ -17,7 +17,6 @@ import importlib.util from pathlib import Path import iris -from iris.hip import get_cu_count current_dir = Path(__file__).parent file_path = (current_dir / "../../examples/14_all_gather_gemm/all_gather_gemm_pull.py").resolve() diff --git a/benchmark/examples/benchmark_all_gather_gemm_push.py b/benchmark/examples/benchmark_all_gather_gemm_push.py index df586db9..a56d53a7 100644 --- a/benchmark/examples/benchmark_all_gather_gemm_push.py +++ b/benchmark/examples/benchmark_all_gather_gemm_push.py @@ -17,7 +17,6 @@ import importlib.util from pathlib import Path import iris -from iris.hip import get_cu_count current_dir = Path(__file__).parent file_path = (current_dir / "../../examples/14_all_gather_gemm/all_gather_gemm_push.py").resolve() From ad29338e422702aa458fc288c401e87d70a90208 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 8 Oct 2025 22:05:18 +0000 Subject: [PATCH 8/8] Hoist repeated power-of-2 calculation to avoid duplication Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- examples/11_gemm_all_scatter_producer_consumer/benchmark.py | 5 +++-- examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py | 6 ++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/11_gemm_all_scatter_producer_consumer/benchmark.py b/examples/11_gemm_all_scatter_producer_consumer/benchmark.py index 8677d4fa..bfd6c567 100755 --- a/examples/11_gemm_all_scatter_producer_consumer/benchmark.py +++ b/examples/11_gemm_all_scatter_producer_consumer/benchmark.py @@ -80,12 +80,13 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Set default SM values if not provided cu_count = torch.cuda.get_device_properties(rank).multi_processor_count + next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 + if args["gemm_sms"] is None: # For wg_specialized: use next smaller power of 2 - args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 + args["gemm_sms"] = next_pow2 if args["comm_sms"] is None: # comm_sms is the leftover: total - next_power_of_2 - next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 args["comm_sms"] = cu_count - next_pow2 # GEMM diff --git a/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py b/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py index de2f855e..4d3548f8 100755 --- a/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py +++ b/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py @@ -80,12 +80,14 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Set default SM values if not provided cu_count = torch.cuda.get_device_properties(rank).multi_processor_count + next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 + if args["gemm_sms"] is None: # For wg_specialized: use next smaller power of 2 - args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 + args["gemm_sms"] = next_pow2 if args["comm_sms"] is None: # For bulk synchronous, use same as gemm_sms - args["comm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 + args["comm_sms"] = next_pow2 # GEMM datatype = torch.float32