Allow for scheduler file and existing dask cluster when using pdsh #21024
Changes from all commits: 1a4c3a0, 154401d, 646d1b5, 350455b, 3b1bae7, 0744cb4, fc2cab9, aa88799, 7fa5a2c, 9c83470, c8933c3, f422439, 7481a1e, 75f3f45, fb91090
```diff
@@ -514,26 +514,67 @@ def print_query_plan(
 def initialize_dask_cluster(run_config: RunConfig, args: argparse.Namespace):  # type: ignore[no-untyped-def]
-    """Initialize a Dask distributed cluster."""
+    """
+    Initialize a Dask distributed cluster.
+
+    This function either creates a new LocalCUDACluster or connects to an
+    existing Dask cluster depending on the provided arguments.
+
+    Parameters
+    ----------
+    run_config : RunConfig
+        The run configuration.
+    args : argparse.Namespace
+        Parsed command line arguments. If ``args.scheduler_address`` or
+        ``args.scheduler_file`` is provided, we connect to an existing
+        cluster instead of creating a LocalCUDACluster.
+
+    Returns
+    -------
+    Client or None
+        A Dask distributed Client, or None if not using distributed mode.
+    """
     if run_config.cluster != "distributed":
         return None
 
-    from dask_cuda import LocalCUDACluster
     from distributed import Client
 
-    kwargs = {
-        "n_workers": run_config.n_workers,
-        "dashboard_address": ":8585",
-        "protocol": args.protocol,
-        "rmm_pool_size": args.rmm_pool_size,
-        "rmm_async": args.rmm_async,
-        "rmm_release_threshold": args.rmm_release_threshold,
-        "threads_per_worker": run_config.threads,
-    }
+    # Check if we should connect to an existing cluster
+    scheduler_address = args.scheduler_address
+    scheduler_file = args.scheduler_file
+
+    if scheduler_address is not None:
+        # Connect to existing cluster via scheduler address
+        client = Client(address=scheduler_address)
+        n_workers = len(client.scheduler_info().get("workers", {}))
```
Contributor (inline comment on the `n_workers` line above): Just a note: we serialize […] I wonder if we can mutate […]

Member (Author): fixed in c8933c3
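For readers unfamiliar with the call that comment is attached to: `Client.scheduler_info()` returns a dictionary describing the scheduler, and its `"workers"` entry maps each worker address to per-worker metadata, so its length is the current worker count. A minimal illustration against a throwaway CPU-only cluster (a sketch, not part of this PR):

```python
from distributed import Client, LocalCluster

# Small local cluster just to show the structure of scheduler_info().
cluster = LocalCluster(n_workers=2, threads_per_worker=1)
client = Client(cluster)

info = client.scheduler_info()
workers = info.get("workers", {})  # {worker_address: {...per-worker metadata...}}
print(len(workers))  # 2

client.close()
cluster.close()
```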
```diff
+        print(
+            f"Connected to existing Dask cluster at {scheduler_address} "
+            f"with {n_workers} workers"
+        )
+    elif scheduler_file is not None:
+        # Connect to existing cluster via scheduler file
+        client = Client(scheduler_file=scheduler_file)
+        n_workers = len(client.scheduler_info().get("workers", {}))
+        print(
+            f"Connected to existing Dask cluster via scheduler file: {scheduler_file} "
+            f"with {n_workers} workers"
+        )
+    else:
+        # Create a new LocalCUDACluster
+        from dask_cuda import LocalCUDACluster
+
+        kwargs = {
+            "n_workers": run_config.n_workers,
+            "dashboard_address": ":8585",
+            "protocol": args.protocol,
+            "rmm_pool_size": args.rmm_pool_size,
+            "rmm_async": args.rmm_async,
+            "rmm_release_threshold": args.rmm_release_threshold,
+            "threads_per_worker": run_config.threads,
+        }
 
-    # Avoid UVM in distributed cluster
-    client = Client(LocalCUDACluster(**kwargs))
-    client.wait_for_workers(run_config.n_workers)
+        # Avoid UVM in distributed cluster
+        client = Client(LocalCUDACluster(**kwargs))
+        client.wait_for_workers(run_config.n_workers)
 
     if run_config.shuffle != "tasks":
         try:
```
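As a rough end-to-end illustration of the new scheduler-file path, the sketch below uses a CPU-only `LocalCluster` in place of `LocalCUDACluster` and writes its connection details to a scheduler file, then connects a second client the way the `--scheduler-file` branch above does. The file path and worker counts are arbitrary example values.

```python
import os
import tempfile

from distributed import Client, LocalCluster

# Stand-in for an externally managed cluster (e.g. one started via pdsh):
# a local CPU cluster whose connection details are written to a scheduler file.
cluster = LocalCluster(n_workers=2, threads_per_worker=1)
owner = Client(cluster)

scheduler_file = os.path.join(tempfile.mkdtemp(), "scheduler.json")
owner.write_scheduler_file(scheduler_file)

# Mirrors the scheduler-file branch: connect without creating workers,
# then take the worker count from the live cluster.
client = Client(scheduler_file=scheduler_file)
n_workers = len(client.scheduler_info().get("workers", {}))
print(f"Connected via {scheduler_file} with {n_workers} workers")

client.close()
owner.close()
cluster.close()
```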
```diff
@@ -730,6 +771,27 @@ def parse_args(
         type=int,
         help="Number of Dask-CUDA workers (requires 'distributed' cluster).",
     )
+    external_cluster_group = parser.add_mutually_exclusive_group()
+    external_cluster_group.add_argument(
+        "--scheduler-address",
+        default=None,
+        type=str,
+        help=textwrap.dedent("""\
+            Scheduler address for connecting to an existing Dask cluster.
+            If provided, a cluster is not created and worker
+            configuration options (--n-workers, --rmm-pool-size, etc.)
+            are ignored since the workers are assumed to be started separately."""),
+    )
+    external_cluster_group.add_argument(
+        "--scheduler-file",
+        default=None,
+        type=str,
+        help=textwrap.dedent("""\
+            Path to a scheduler file for connecting to an existing Dask cluster.
+            If provided, a cluster is not created and worker
+            configuration options (--n-workers, --rmm-pool-size, etc.)
+            are ignored since the workers are assumed to be started separately."""),
+    )
     parser.add_argument(
         "--blocksize",
         default=None,
```
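Because the two options are registered on a mutually exclusive group, passing both is rejected by argparse itself rather than by later validation. A small standalone sketch of that behavior, using a toy parser rather than the benchmark's real one:

```python
import argparse

# Toy parser mirroring the mutually exclusive group added above.
parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument("--scheduler-address", default=None, type=str)
group.add_argument("--scheduler-file", default=None, type=str)

# Either option alone is fine; omitting both leaves the defaults (None).
args = parser.parse_args(["--scheduler-file", "/tmp/scheduler.json"])
print(args.scheduler_file, args.scheduler_address)  # /tmp/scheduler.json None

# Passing both exits with an error like:
#   "argument --scheduler-file: not allowed with argument --scheduler-address"
# parser.parse_args(["--scheduler-address", "tcp://10.0.0.1:8786",
#                    "--scheduler-file", "/tmp/scheduler.json"])
```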
```diff
@@ -933,6 +995,11 @@ def run_polars(
 
     client = initialize_dask_cluster(run_config, args)
 
+    # Update n_workers from the actual cluster when using scheduler file/address
+    if client is not None:
+        actual_n_workers = len(client.scheduler_info().get("workers", {}))
+        run_config = dataclasses.replace(run_config, n_workers=actual_n_workers)
+
     records: defaultdict[int, list[Record]] = defaultdict(list)
     engine: pl.GPUEngine | None = None
```
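This update relies on `dataclasses.replace`, which builds a new instance instead of mutating the existing one (the pattern referenced by the review comment above and fixed in c8933c3). A minimal sketch of that pattern with a hypothetical two-field `RunConfig`; the real dataclass has more fields:

```python
import dataclasses


# Hypothetical stand-in for the benchmark's RunConfig.
@dataclasses.dataclass(frozen=True)
class RunConfig:
    n_workers: int
    threads: int


config = RunConfig(n_workers=8, threads=1)

# dataclasses.replace returns a new instance rather than mutating the
# original, so anything already holding `config` is unaffected.
updated = dataclasses.replace(config, n_workers=16)
print(config.n_workers, updated.n_workers)  # 8 16
```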