95 changes: 81 additions & 14 deletions python/cudf_polars/cudf_polars/experimental/benchmarks/utils.py
@@ -514,26 +514,67 @@ def print_query_plan(


def initialize_dask_cluster(run_config: RunConfig, args: argparse.Namespace): # type: ignore[no-untyped-def]
"""Initialize a Dask distributed cluster."""
"""
Initialize a Dask distributed cluster.

This function either creates a new LocalCUDACluster or connects to an
existing Dask cluster depending on the provided arguments.

Parameters
----------
run_config : RunConfig
The run configuration.
args : argparse.Namespace
Parsed command line arguments. If ``args.scheduler_address`` or
``args.scheduler_file`` is provided, we connect to an existing
cluster instead of creating a LocalCUDACluster.

Returns
-------
Client or None
A Dask distributed Client, or None if not using distributed mode.
"""
if run_config.cluster != "distributed":
return None

from dask_cuda import LocalCUDACluster
from distributed import Client

kwargs = {
"n_workers": run_config.n_workers,
"dashboard_address": ":8585",
"protocol": args.protocol,
"rmm_pool_size": args.rmm_pool_size,
"rmm_async": args.rmm_async,
"rmm_release_threshold": args.rmm_release_threshold,
"threads_per_worker": run_config.threads,
}
# Check if we should connect to an existing cluster
scheduler_address = args.scheduler_address
scheduler_file = args.scheduler_file

if scheduler_address is not None:
# Connect to existing cluster via scheduler address
client = Client(address=scheduler_address)
n_workers = len(client.scheduler_info().get("workers", {}))
Review comment (Contributor):
Just a note: we serialize run_config.n_workers in the JSON output. When a scheduler file is provided the run_config.n_workers won't be accurate.

I wonder if we can mutate run_config.n_workers here? It's not ideal, but I think it's an OK tradeoff.

Reply (Member Author):
fixed in c8933c3

print(
f"Connected to existing Dask cluster at {scheduler_address} "
f"with {n_workers} workers"
)
elif scheduler_file is not None:
# Connect to existing cluster via scheduler file
client = Client(scheduler_file=scheduler_file)
n_workers = len(client.scheduler_info().get("workers", {}))
print(
f"Connected to existing Dask cluster via scheduler file: {scheduler_file} "
f"with {n_workers} workers"
)
else:
# Create a new LocalCUDACluster
from dask_cuda import LocalCUDACluster

kwargs = {
"n_workers": run_config.n_workers,
"dashboard_address": ":8585",
"protocol": args.protocol,
"rmm_pool_size": args.rmm_pool_size,
"rmm_async": args.rmm_async,
"rmm_release_threshold": args.rmm_release_threshold,
"threads_per_worker": run_config.threads,
}

# Avoid UVM in distributed cluster
client = Client(LocalCUDACluster(**kwargs))
client.wait_for_workers(run_config.n_workers)
client = Client(LocalCUDACluster(**kwargs))
client.wait_for_workers(run_config.n_workers)

if run_config.shuffle != "tasks":
try:
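
For orientation, the sketch below isolates the two connection paths this function now supports. It uses the same distributed and dask_cuda APIs as the diff, but the helper name, default worker count, and scheduler-file path are illustrative and not part of this PR.

# Illustrative sketch only: mirrors the branches in initialize_dask_cluster().
from dask_cuda import LocalCUDACluster
from distributed import Client


def get_client(scheduler_address=None, scheduler_file=None, n_workers=2):
    if scheduler_address is not None:
        # Workers are assumed to be started separately (e.g. via dask-cuda-worker).
        return Client(address=scheduler_address)
    if scheduler_file is not None:
        # The scheduler file is typically written by `dask scheduler --scheduler-file ...`.
        return Client(scheduler_file=scheduler_file)
    # Otherwise create a single-node GPU cluster and wait for its workers.
    client = Client(LocalCUDACluster(n_workers=n_workers))
    client.wait_for_workers(n_workers)
    return client


# Example usage (the path is hypothetical):
client = get_client(scheduler_file="/tmp/dask-scheduler.json")
print(len(client.scheduler_info().get("workers", {})), "workers registered")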
@@ -730,6 +771,27 @@ def parse_args(
type=int,
help="Number of Dask-CUDA workers (requires 'distributed' cluster).",
)
external_cluster_group = parser.add_mutually_exclusive_group()
external_cluster_group.add_argument(
"--scheduler-address",
default=None,
type=str,
help=textwrap.dedent("""\
Scheduler address for connecting to an existing Dask cluster.
If provided, a cluster is not created and worker
configuration options (--n-workers, --rmm-pool-size, etc.)
are ignored since the workers are assumed to be started separately."""),
)
external_cluster_group.add_argument(
"--scheduler-file",
default=None,
type=str,
help=textwrap.dedent("""\
Path to a scheduler file for connecting to an existing Dask cluster.
If provided, a cluster is not created and worker
configuration options (--n-workers, --rmm-pool-size, etc.)
are ignored since the workers are assumed to be started separately."""),
)
parser.add_argument(
"--blocksize",
default=None,
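
As a standalone illustration of the mutually exclusive group added above (the parser here is a stand-in for the real one built in parse_args(), and the file path is hypothetical), supplying both flags fails at parse time:

# Stand-in parser demonstrating the mutually exclusive scheduler flags.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument("--scheduler-address", default=None, type=str)
group.add_argument("--scheduler-file", default=None, type=str)

args = parser.parse_args(["--scheduler-file", "/tmp/dask-scheduler.json"])
print(args.scheduler_file)  # /tmp/dask-scheduler.json

# Passing both flags instead exits with:
# "argument --scheduler-file: not allowed with argument --scheduler-address"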
Expand Down Expand Up @@ -933,6 +995,11 @@ def run_polars(

client = initialize_dask_cluster(run_config, args)

# Update n_workers from the actual cluster when using scheduler file/address
if client is not None:
actual_n_workers = len(client.scheduler_info().get("workers", {}))
run_config = dataclasses.replace(run_config, n_workers=actual_n_workers)

records: defaultdict[int, list[Record]] = defaultdict(list)
engine: pl.GPUEngine | None = None
