diff --git a/python/cudf_polars/cudf_polars/experimental/benchmarks/utils.py b/python/cudf_polars/cudf_polars/experimental/benchmarks/utils.py
index 10ad70c0c2a..2470be07802 100644
--- a/python/cudf_polars/cudf_polars/experimental/benchmarks/utils.py
+++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/utils.py
@@ -514,26 +514,67 @@ def print_query_plan(
 
 
 def initialize_dask_cluster(run_config: RunConfig, args: argparse.Namespace):  # type: ignore[no-untyped-def]
-    """Initialize a Dask distributed cluster."""
+    """
+    Initialize a Dask distributed cluster.
+
+    This function either creates a new LocalCUDACluster or connects to an
+    existing Dask cluster depending on the provided arguments.
+
+    Parameters
+    ----------
+    run_config : RunConfig
+        The run configuration.
+    args : argparse.Namespace
+        Parsed command line arguments. If ``args.scheduler_address`` or
+        ``args.scheduler_file`` is provided, we connect to an existing
+        cluster instead of creating a LocalCUDACluster.
+
+    Returns
+    -------
+    Client or None
+        A Dask distributed Client, or None if not using distributed mode.
+    """
     if run_config.cluster != "distributed":
         return None
 
-    from dask_cuda import LocalCUDACluster
     from distributed import Client
 
-    kwargs = {
-        "n_workers": run_config.n_workers,
-        "dashboard_address": ":8585",
-        "protocol": args.protocol,
-        "rmm_pool_size": args.rmm_pool_size,
-        "rmm_async": args.rmm_async,
-        "rmm_release_threshold": args.rmm_release_threshold,
-        "threads_per_worker": run_config.threads,
-    }
+    # Check if we should connect to an existing cluster
+    scheduler_address = args.scheduler_address
+    scheduler_file = args.scheduler_file
+
+    if scheduler_address is not None:
+        # Connect to existing cluster via scheduler address
+        client = Client(address=scheduler_address)
+        n_workers = len(client.scheduler_info().get("workers", {}))
+        print(
+            f"Connected to existing Dask cluster at {scheduler_address} "
+            f"with {n_workers} workers"
+        )
+    elif scheduler_file is not None:
+        # Connect to existing cluster via scheduler file
+        client = Client(scheduler_file=scheduler_file)
+        n_workers = len(client.scheduler_info().get("workers", {}))
+        print(
+            f"Connected to existing Dask cluster via scheduler file: {scheduler_file} "
+            f"with {n_workers} workers"
+        )
+    else:
+        # Create a new LocalCUDACluster
+        from dask_cuda import LocalCUDACluster
+
+        kwargs = {
+            "n_workers": run_config.n_workers,
+            "dashboard_address": ":8585",
+            "protocol": args.protocol,
+            "rmm_pool_size": args.rmm_pool_size,
+            "rmm_async": args.rmm_async,
+            "rmm_release_threshold": args.rmm_release_threshold,
+            "threads_per_worker": run_config.threads,
+        }
 
-    # Avoid UVM in distributed cluster
-    client = Client(LocalCUDACluster(**kwargs))
-    client.wait_for_workers(run_config.n_workers)
+        client = Client(LocalCUDACluster(**kwargs))
+        client.wait_for_workers(run_config.n_workers)
 
     if run_config.shuffle != "tasks":
         try:
@@ -730,6 +771,27 @@ def parse_args(
         type=int,
         help="Number of Dask-CUDA workers (requires 'distributed' cluster).",
     )
+    external_cluster_group = parser.add_mutually_exclusive_group()
+    external_cluster_group.add_argument(
+        "--scheduler-address",
+        default=None,
+        type=str,
+        help=textwrap.dedent("""\
+            Scheduler address for connecting to an existing Dask cluster.
+            If provided, a cluster is not created and worker
+            configuration options (--n-workers, --rmm-pool-size, etc.)
+            are ignored since the workers are assumed to be started separately."""),
+    )
+    external_cluster_group.add_argument(
+        "--scheduler-file",
+        default=None,
+        type=str,
+        help=textwrap.dedent("""\
+            Path to a scheduler file for connecting to an existing Dask cluster.
+            If provided, a cluster is not created and worker
+            configuration options (--n-workers, --rmm-pool-size, etc.)
+            are ignored since the workers are assumed to be started separately."""),
+    )
     parser.add_argument(
         "--blocksize",
         default=None,
@@ -933,6 +995,11 @@ def run_polars(
 
     client = initialize_dask_cluster(run_config, args)
 
+    # Update n_workers from the actual cluster when using scheduler file/address
+    if client is not None:
+        actual_n_workers = len(client.scheduler_info().get("workers", {}))
+        run_config = dataclasses.replace(run_config, n_workers=actual_n_workers)
+
     records: defaultdict[int, list[Record]] = defaultdict(list)
     engine: pl.GPUEngine | None = None
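For context, a minimal sketch (not part of the diff) of the connection path taken when --scheduler-address or --scheduler-file is supplied. It uses only public dask.distributed APIs; the scheduler file path and address shown are hypothetical placeholders.

    from distributed import Client

    # Connect via a scheduler file written by an externally started scheduler
    # (hypothetical path; e.g. `dask scheduler --scheduler-file /tmp/scheduler.json`).
    client = Client(scheduler_file="/tmp/scheduler.json")
    # Alternatively, connect directly to a known scheduler address:
    # client = Client(address="tcp://10.0.0.1:8786")

    # The benchmark reads the worker count from the live cluster and
    # updates run_config accordingly, rather than trusting --n-workers.
    n_workers = len(client.scheduler_info().get("workers", {}))
    print(f"Connected to existing Dask cluster with {n_workers} workers")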