diff --git a/benchmark/examples/benchmark_all_gather_gemm_pull.py b/benchmark/examples/benchmark_all_gather_gemm_pull.py
index 01640050..0307fa57 100644
--- a/benchmark/examples/benchmark_all_gather_gemm_pull.py
+++ b/benchmark/examples/benchmark_all_gather_gemm_pull.py
@@ -63,7 +63,9 @@ def parse_args():
     parser.add_argument("--BLK_N", type=int, default=64, help="Block size N for the kernel")
     parser.add_argument("--BLK_K", type=int, default=64, help="Block size K for the kernel")
     parser.add_argument("--gsize_m", type=int, default=6, help="Group size in M dimension")
-    parser.add_argument("--num_sms", type=int, default=304, help="Number of SMs for the kernel")
+    parser.add_argument(
+        "--num_sms", type=int, default=None, help="Number of SMs for the kernel (default: auto-detected)"
+    )
     parser.add_argument("--num_ranks", type=int, default=8, help="Number of GPUs to run the example on.")

@@ -138,7 +140,12 @@ def worker(rank: int, world_size: int, init_url: str, args: argparse.Namespace):
     A_local_iris = shmem.empty((M, K_local), dtype=datatype)
     A_local_iris.copy_(A_local)

-    num_sms = torch.cuda.get_device_properties(rank).multi_processor_count
+    # Use provided num_sms or auto-detect
+    if run_args["num_sms"] is None:
+        num_sms = torch.cuda.get_device_properties(rank).multi_processor_count
+        run_args["num_sms"] = num_sms
+    else:
+        num_sms = run_args["num_sms"]

     main_stream = torch.cuda.Stream()
     kernel_timing = {
diff --git a/benchmark/examples/benchmark_all_gather_gemm_push.py b/benchmark/examples/benchmark_all_gather_gemm_push.py
index 1984006d..a56d53a7 100644
--- a/benchmark/examples/benchmark_all_gather_gemm_push.py
+++ b/benchmark/examples/benchmark_all_gather_gemm_push.py
@@ -62,7 +62,9 @@ def parse_args():
     parser.add_argument("--BLK_N", type=int, default=64, help="Block size N for GEMM computation")
     parser.add_argument("--BLK_K", type=int, default=64, help="Block size K for tiling")
     parser.add_argument("--gsize_m", type=int, default=6, help="Group size in M dimension")
-    parser.add_argument("--num_sms", type=int, default=304, help="Number of SMs for the kernel")
+    parser.add_argument(
+        "--num_sms", type=int, default=None, help="Number of SMs for the kernel (default: auto-detected)"
+    )
     parser.add_argument("--num_ranks", type=int, default=8, help="Number of GPUs to run the example on.")

@@ -142,7 +144,12 @@ def worker(rank: int, world_size: int, init_url: str, args: argparse.Namespace):
     num_k_tiles = (K_local + run_args["BLK_K"] - 1) // run_args["BLK_K"]
     signal_flags_iris = shmem.zeros((world_size, world_size, num_m_tiles, num_k_tiles), dtype=torch.int32)

-    num_sms = torch.cuda.get_device_properties(rank).multi_processor_count
+    # Use provided num_sms or auto-detect
+    if run_args["num_sms"] is None:
+        num_sms = torch.cuda.get_device_properties(rank).multi_processor_count
+        run_args["num_sms"] = num_sms
+    else:
+        num_sms = run_args["num_sms"]

     main_stream = torch.cuda.Stream()
     kernel_timing = {
diff --git a/examples/07_gemm_all_scatter/benchmark.py b/examples/07_gemm_all_scatter/benchmark.py
index 6d872e61..67cfc251 100755
--- a/examples/07_gemm_all_scatter/benchmark.py
+++ b/examples/07_gemm_all_scatter/benchmark.py
@@ -52,7 +52,12 @@ def parse_args():
     parser.add_argument("--BLK_K", type=int, default=64, help="Block size K")
     parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter")
     parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size")
-    parser.add_argument("--gemm_sms", type=int, default=304, help="Number of SMs for persistent GEMM algorithm")
+    parser.add_argument(
+        "--gemm_sms",
+        type=int,
+        default=None,
+        help="Number of SMs for persistent GEMM algorithm (default: auto-detected)",
+    )
     parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")

     return vars(parser.parse_args())
@@ -67,7 +72,12 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
     shmem = iris.iris(args["heap_size"])
     rank = shmem.get_rank()
     world_size = shmem.get_num_ranks()
-    cu_count = shmem.get_cu_count()
+
+    # Set default SM values if not provided
+    if args["gemm_sms"] is None:
+        # For all_scatter: use total CU count
+        cu_count = torch.cuda.get_device_properties(rank).multi_processor_count
+        args["gemm_sms"] = cu_count

     # GEMM
     datatype = torch.float32
diff --git a/examples/08_gemm_atomics_all_reduce/benchmark.py b/examples/08_gemm_atomics_all_reduce/benchmark.py
index 31de4fa3..7b12431b 100755
--- a/examples/08_gemm_atomics_all_reduce/benchmark.py
+++ b/examples/08_gemm_atomics_all_reduce/benchmark.py
@@ -11,6 +11,7 @@
 import os
 import argparse
 import json
+import math

 from examples.common.utils import (
     JSONWriter,
@@ -68,10 +69,8 @@ def parse_args():
     parser.add_argument("--kpack", type=int, default=2, help="K packing size")
     parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size")

-    # For All Scatter, use: 288
-    # For One Shot, use: 256
-    parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for GEMM")
-    parser.add_argument("--total_sms", type=int, default=304, help="Total number of SMs")
+    parser.add_argument("--gemm_sms", type=int, default=None, help="Number of SMs for GEMM (default: auto-detected)")
+    parser.add_argument("--total_sms", type=int, default=None, help="Total number of SMs (default: auto-detected)")
     parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")

     return vars(parser.parse_args())
@@ -86,7 +85,14 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
     shmem = iris.iris(args["heap_size"])
     rank = shmem.get_rank()
     world_size = shmem.get_num_ranks()
-    cu_count = shmem.get_cu_count()
+
+    # Set default SM values if not provided
+    cu_count = torch.cuda.get_device_properties(rank).multi_processor_count
+    if args["total_sms"] is None:
+        args["total_sms"] = cu_count
+    if args["gemm_sms"] is None:
+        # For all_reduce: use next smaller power of 2, rest for communication
+        args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1

     # GEMM
     datatype = torch.float32
diff --git a/examples/09_gemm_one_shot_all_reduce/benchmark.py b/examples/09_gemm_one_shot_all_reduce/benchmark.py
index 212bc857..406bddd8 100755
--- a/examples/09_gemm_one_shot_all_reduce/benchmark.py
+++ b/examples/09_gemm_one_shot_all_reduce/benchmark.py
@@ -11,6 +11,7 @@
 import os
 import argparse
 import json
+import math

 from examples.common.utils import (
     JSONWriter,
@@ -68,8 +69,8 @@ def parse_args():
     parser.add_argument("--kpack", type=int, default=2, help="K packing size")
     parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size")

-    parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for GEMM")
-    parser.add_argument("--total_sms", type=int, default=304, help="Total number of SMs")
+    parser.add_argument("--gemm_sms", type=int, default=None, help="Number of SMs for GEMM (default: auto-detected)")
+    parser.add_argument("--total_sms", type=int, default=None, help="Total number of SMs (default: auto-detected)")
parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -82,7 +83,14 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() - cu_count = shmem.get_cu_count() + + # Set default SM values if not provided + cu_count = torch.cuda.get_device_properties(rank).multi_processor_count + if args["total_sms"] is None: + args["total_sms"] = cu_count + if args["gemm_sms"] is None: + # For all_reduce: use next smaller power of 2, rest for communication + args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 # GEMM datatype = torch.float32 diff --git a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py index bb49bacb..036410db 100755 --- a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py +++ b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py @@ -11,6 +11,7 @@ import os import argparse import json +import math from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set from examples.common.validation import validate_gemm @@ -54,9 +55,17 @@ def parse_args(): parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") parser.add_argument( - "--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm" + "--gemm_sms", + type=int, + default=None, + help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)", + ) + parser.add_argument( + "--num_sms", + type=int, + default=None, + help="Number of total SMs for gemm + scatter kernel (default: auto-detected)", ) - parser.add_argument("--num_sms", type=int, default=304, help="Number of total SMs for gemm + scatter kernel") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -70,7 +79,14 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() - cu_count = shmem.get_cu_count() + + # Set default SM values if not provided + cu_count = torch.cuda.get_device_properties(rank).multi_processor_count + if args["num_sms"] is None: + args["num_sms"] = cu_count + if args["gemm_sms"] is None: + # For wg_specialized: use next smaller power of 2 + args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 # GEMM datatype = torch.float32 diff --git a/examples/11_gemm_all_scatter_producer_consumer/benchmark.py b/examples/11_gemm_all_scatter_producer_consumer/benchmark.py index 8059c26f..bfd6c567 100755 --- a/examples/11_gemm_all_scatter_producer_consumer/benchmark.py +++ b/examples/11_gemm_all_scatter_producer_consumer/benchmark.py @@ -11,6 +11,7 @@ import os import argparse import json +import math from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set from examples.common.validation import validate_gemm @@ -55,9 +56,14 @@ def parse_args(): parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") parser.add_argument( - "--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm" + 
"--gemm_sms", + type=int, + default=None, + help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)", + ) + parser.add_argument( + "--comm_sms", type=int, default=None, help="Number of SMs for All-Scatter kernel (default: auto-detected)" ) - parser.add_argument("--comm_sms", type=int, default=48, help="Number of SMs for All-Scatter kernel") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -71,7 +77,17 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() - cu_count = shmem.get_cu_count() + + # Set default SM values if not provided + cu_count = torch.cuda.get_device_properties(rank).multi_processor_count + next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 + + if args["gemm_sms"] is None: + # For wg_specialized: use next smaller power of 2 + args["gemm_sms"] = next_pow2 + if args["comm_sms"] is None: + # comm_sms is the leftover: total - next_power_of_2 + args["comm_sms"] = cu_count - next_pow2 # GEMM datatype = torch.float32 diff --git a/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py b/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py index ea041361..4d3548f8 100755 --- a/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py +++ b/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py @@ -11,6 +11,7 @@ import os import argparse import json +import math from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set from examples.common.validation import validate_gemm @@ -55,9 +56,14 @@ def parse_args(): parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") parser.add_argument( - "--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm" + "--gemm_sms", + type=int, + default=None, + help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)", + ) + parser.add_argument( + "--comm_sms", type=int, default=None, help="Number of SMs for All-Scatter kernel (default: auto-detected)" ) - parser.add_argument("--comm_sms", type=int, default=256, help="Number of SMs for All-Scatter kernel") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -71,7 +77,17 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): shmem = iris.iris(args["heap_size"]) rank = shmem.get_rank() world_size = shmem.get_num_ranks() - cu_count = shmem.get_cu_count() + + # Set default SM values if not provided + cu_count = torch.cuda.get_device_properties(rank).multi_processor_count + next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 + + if args["gemm_sms"] is None: + # For wg_specialized: use next smaller power of 2 + args["gemm_sms"] = next_pow2 + if args["comm_sms"] is None: + # For bulk synchronous, use same as gemm_sms + args["comm_sms"] = next_pow2 # GEMM datatype = torch.float32 diff --git a/examples/benchmark/bench_all_shapes.py b/examples/benchmark/bench_all_shapes.py index 0f5ce5b3..adb15657 100644 --- a/examples/benchmark/bench_all_shapes.py +++ b/examples/benchmark/bench_all_shapes.py @@ -6,6 +6,7 @@ from datetime import datetime import argparse import json +import torch def launch_sbatch( @@ -110,16 +111,28 @@ def 
main(hashes, config, sbatch_script_content, input_json, tiling_json, dry_run if mkn not in mkn_gemm_tiles: mkn_gemm_tiles[mkn] = {key: entry[key] for key in optional_keys if key in entry} - if config["partition"] is not None: - if "mi300" in config["partition"]: - print("Running on MI300") + # Determine gemm_sms based on available GPU or partition name + try: + if torch.cuda.is_available(): + gemm_sms = torch.cuda.get_device_properties(0).multi_processor_count + print(f"Auto-detected CU count: {gemm_sms}") + else: + gemm_sms = None + except Exception: + # Fall back to partition-based detection + gemm_sms = None + + if gemm_sms is None: + if config["partition"] is not None: + if "mi300" in config["partition"]: + print("Running on MI300 (partition-based)") + gemm_sms = 304 + elif "mi250" in config["partition"]: + print("Running on MI250 (partition-based)") + gemm_sms = 104 + else: + print("Assuming MI300 (default)") gemm_sms = 304 - elif "mi250" in config["partition"]: - print("Running on MI250") - gemm_sms = 104 - else: - print("Assuming MI300") - gemm_sms = 304 enable_algorithms = False enable_mkn = True diff --git a/scripts/link_bandwidth.py b/scripts/link_bandwidth.py index 29faf5f1..1c79d250 100644 --- a/scripts/link_bandwidth.py +++ b/scripts/link_bandwidth.py @@ -2,6 +2,15 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. import json +import torch + +try: + if torch.cuda.is_available(): + cu_count = torch.cuda.get_device_properties(0).multi_processor_count + else: + cu_count = 304 # Default for MI300 +except Exception: + cu_count = 304 # Default for MI300 # Sample input (replace with file read if needed) config = { @@ -26,7 +35,7 @@ "kpack": 2, "heap_size": 8589934592, "gemm_sms": 48, - "total_sms": 304, + "total_sms": cu_count, "communication_block_size": 256, "communication_sms_multiplier": 1, "M": 8192,