From ce7ed4e1d85e727c0cf771f1004f44d586009392 Mon Sep 17 00:00:00 2001
From: reginabuehler
Date: Tue, 17 Feb 2026 17:41:52 +0100
Subject: [PATCH 1/2] feat: add ClusterLocal as scheduler on a cluster that
 does not need a remote connection

---
 src/queens/schedulers/__init__.py      |   1 +
 src/queens/schedulers/cluster_local.py | 232 +++++++++++++++++++++++++
 2 files changed, 233 insertions(+)
 create mode 100644 src/queens/schedulers/cluster_local.py

diff --git a/src/queens/schedulers/__init__.py b/src/queens/schedulers/__init__.py
index 1bbcf70b7..b5a8777de 100644
--- a/src/queens/schedulers/__init__.py
+++ b/src/queens/schedulers/__init__.py
@@ -23,6 +23,7 @@
 if TYPE_CHECKING:
     from queens.schedulers._scheduler import Scheduler
     from queens.schedulers.cluster import Cluster
+    from queens.schedulers.cluster_local import ClusterLocal
     from queens.schedulers.local import Local
     from queens.schedulers.pool import Pool
 
diff --git a/src/queens/schedulers/cluster_local.py b/src/queens/schedulers/cluster_local.py
new file mode 100644
index 000000000..15ae2336f
--- /dev/null
+++ b/src/queens/schedulers/cluster_local.py
@@ -0,0 +1,232 @@
+#
+# SPDX-License-Identifier: LGPL-3.0-or-later
+# Copyright (c) 2024-2025, QUEENS contributors.
+#
+# This file is part of QUEENS.
+#
+# QUEENS is free software: you can redistribute it and/or modify it under the terms of the GNU
+# Lesser General Public License as published by the Free Software Foundation, either version 3 of
+# the License, or (at your option) any later version. QUEENS is distributed in the hope that it
+# will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
+# You should have received a copy of the GNU Lesser General Public License along with QUEENS. If
+# not, see <https://www.gnu.org/licenses/>.
+#
+"""Local cluster scheduler for QUEENS runs."""
+
+import logging
+import time
+from datetime import timedelta
+
+from dask.distributed import Client
+
+from queens.schedulers._dask import Dask
+from queens.schedulers.cluster import VALID_WORKLOAD_MANAGERS, timedelta_to_str
+from queens.utils.logger_settings import log_init_args
+from queens.utils.remote_operations import get_port
+from queens.utils.rsync import rsync
+from queens.utils.valid_options import get_option
+
+_logger = logging.getLogger(__name__)
+
+
+class ClusterLocal(Dask):
+    """Cluster (local) scheduler for QUEENS.
+
+    Can be used to schedule jobs via a cluster workload manager to which
+    QUEENS has local access, i.e., no remote (network) connection is required.
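+
+    Example (illustrative values; assumes a SLURM workload manager):
+        scheduler = ClusterLocal(
+            experiment_name="my_experiment",
+            workload_manager="slurm",
+            walltime="01:00:00",
+            num_jobs=4,
+            num_procs=8,
+        )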
+    """
+
+    @log_init_args
+    def __init__(
+        self,
+        experiment_name,
+        workload_manager,
+        walltime,
+        num_jobs=1,
+        min_jobs=0,
+        num_procs=1,
+        num_nodes=1,
+        queue=None,
+        cluster_internal_address=None,
+        restart_workers=False,
+        allowed_failures=5,
+        verbose=True,
+        experiment_base_dir=None,
+        overwrite_existing_experiment=False,
+        job_script_prologue=None,
+    ):
+        """Initialize the local cluster scheduler.
+
+        The total number of cores per job is given by num_procs * num_nodes.
+
+        Args:
+            experiment_name (str): Name of the current experiment
+            workload_manager (str): Workload manager ("pbs" or "slurm")
+            walltime (str): Walltime for each worker job in the format "hh:mm:ss"
+            num_jobs (int, opt): Maximum number of parallel jobs
+            min_jobs (int, opt): Minimum number of active workers for the cluster
+            num_procs (int, opt): Number of processors per job per node
+            num_nodes (int, opt): Number of cluster nodes per job
+            queue (str, opt): Destination queue for each worker job
+            cluster_internal_address (str, opt): Internal address of the cluster
+            restart_workers (bool): If True, restart workers after each finished job. For larger
+                jobs (>1 min) this should be set to True in most cases.
+            allowed_failures (int): Number of allowed failures for a task before an error is raised
+            verbose (bool, opt): Verbosity of evaluations. Defaults to True.
+            experiment_base_dir (str, Path): Base directory for the simulation outputs
+            overwrite_existing_experiment (bool): If True, overwrite the experiment directory if it
+                already exists. If False, prompt the user for confirmation before overwriting.
+            job_script_prologue (list, opt): List of commands to be executed before starting a
+                worker.
+        """
+        self.workload_manager = workload_manager
+        self.walltime = walltime
+        self.min_jobs = min_jobs
+        self.num_nodes = num_nodes
+        self.queue = queue
+        self.cluster_internal_address = cluster_internal_address
+        self.allowed_failures = allowed_failures
+        self.job_script_prologue = job_script_prologue
+
+        # get the path of the experiment directory on the local machine
+        experiment_dir = self.local_experiment_dir(
+            experiment_name, experiment_base_dir, overwrite_existing_experiment
+        )
+
+        _logger.debug(
+            "experiment directory: %s",
+            experiment_dir,
+        )
+
+        super().__init__(
+            experiment_name=experiment_name,
+            experiment_dir=experiment_dir,
+            num_jobs=num_jobs,
+            num_procs=num_procs,
+            restart_workers=restart_workers,
+            verbose=verbose,
+        )
+
+    def _start_cluster_and_connect_client(self):
+        """Start a Dask cluster and a client that connects to it.
+
+        Returns:
+            client (Client): Dask client that is connected to and submits computations to the
+                Dask cluster.
+        """
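+        # Worked example of the walltime bookkeeping done below, assuming the
+        # user passed walltime="01:00:00" (an illustrative value):
+        #   jobqueue walltime = 01:00:00 + 5 min = 01:05:00
+        #   worker lifetime   = 01:00:00 + 2 min = 01:02:00 (+/- 2 min stagger)
+        # so workers retire about 3 minutes before the jobqueue kills the job.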
+ """ + # collect all settings for the dask cluster + dask_cluster_options = get_option(VALID_WORKLOAD_MANAGERS, self.workload_manager) + job_extra_directives = dask_cluster_options["job_extra_directives"]( + self.num_nodes, self.num_procs + ) + job_directives_skip = dask_cluster_options["job_directives_skip"] + if self.queue is None: + job_directives_skip.append("#SBATCH -p") + + hours, minutes, seconds = map(int, self.walltime.split(":")) + walltime_delta = timedelta(hours=hours, minutes=minutes, seconds=seconds) + + # Increase jobqueue walltime by 5 minutes to kill dask workers in time + walltime = timedelta_to_str(walltime_delta + timedelta(minutes=5)) + + # dask worker lifetime = walltime - 3m +/- 2m + worker_lifetime = str(int((walltime_delta + timedelta(minutes=2)).total_seconds())) + "s" + + remote_port = get_port() + local_port_dashboard = get_port() + remote_port_dashboard = get_port() + + scheduler_options = { + "port": remote_port, + "dashboard_address": remote_port_dashboard, + "allowed_failures": self.allowed_failures, + } + if self.cluster_internal_address: + scheduler_options["contact_address"] = f"{self.cluster_internal_address}:{remote_port}" + dask_cluster_kwargs = { + "job_name": self.experiment_name, + "queue": self.queue, + "memory": "10TB", + "scheduler_options": scheduler_options, + "walltime": walltime, + "log_directory": str(self.experiment_dir), + "job_directives_skip": job_directives_skip, + "job_extra_directives": [job_extra_directives], + "worker_extra_args": ["--lifetime", worker_lifetime, "--lifetime-stagger", "2m"], + # keep this hardcoded to 1, the number of threads for the mpi run is handled by + # job_extra_directives. Note that the number of workers is not the number of parallel + # simulations! + "cores": 1, + "processes": 1, + "n_workers": 1, + } + dask_cluster_adapt_kwargs = { + "minimum_jobs": self.min_jobs, + "maximum_jobs": self.num_jobs, + } + + dask_cluster_options = get_option(VALID_WORKLOAD_MANAGERS, self.workload_manager) + dask_cluster_cls = dask_cluster_options["dask_cluster_cls"] + + try: + _logger.info("Starting dask cluster of type: %s", dask_cluster_cls) + _logger.debug("Dask cluster kwargs:") + _logger.debug(dask_cluster_kwargs) + cluster = dask_cluster_cls(**dask_cluster_kwargs) + + _logger.info("Adapting dask cluster settings") + _logger.debug("Dask cluster adapt kwargs:") + _logger.debug(dask_cluster_adapt_kwargs) + cluster.adapt(**dask_cluster_adapt_kwargs) + + _logger.info("Dask cluster info:") + _logger.info(cluster) + + dask_jobscript = self.experiment_dir / "dask_jobscript.sh" + _logger.info("Writing dask jobscript to:") + _logger.info(dask_jobscript) + dask_jobscript.write_text(str(cluster.job_script())) + except Exception as e: + raise RuntimeError() from e + + for i in range(20, 0, -1): # 20 tries to connect + _logger.debug("Trying to connect to Dask Cluster: try #%d", i) + try: + client = Client(cluster) + break + except OSError as exc: + if i == 1: + raise OSError() from exc + time.sleep(1) + + _logger.debug("Submitting dummy job to check basic functionality of client.") + client.submit(lambda: "Dummy job").result(timeout=180) + _logger.debug("Dummy job was successful.") + _logger.info( + "To view the Dask dashboard open this link in your browser: " + "http://localhost:%i/status", + local_port_dashboard, + ) + return client + + def restart_worker(self, worker): + """Restart a worker. + + This method retires a dask worker. + The Client.adapt method of dask takes cares of submitting new workers subsequently. 
From 835aaf0681f36f76c2c5f5878f1851f310c8d7c1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Regina=20B=C3=BChler?=
Date: Fri, 27 Mar 2026 17:43:17 +0100
Subject: [PATCH 2/2] Add missing option to `dask_cluster_kwargs`
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Sebastian Brandstäter <45557303+sbrandstaeter@users.noreply.github.com>
---
 src/queens/schedulers/cluster_local.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/queens/schedulers/cluster_local.py b/src/queens/schedulers/cluster_local.py
index 15ae2336f..24145d52d 100644
--- a/src/queens/schedulers/cluster_local.py
+++ b/src/queens/schedulers/cluster_local.py
@@ -154,6 +154,7 @@ def _start_cluster_and_connect_client(self):
             "job_directives_skip": job_directives_skip,
             "job_extra_directives": [job_extra_directives],
             "worker_extra_args": ["--lifetime", worker_lifetime, "--lifetime-stagger", "2m"],
+            "job_script_prologue": self.job_script_prologue,
             # keep this hardcoded to 1; the number of threads for the mpi run is handled by
             # job_extra_directives. Note that the number of workers is not the number of parallel
             # simulations!
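
A sketch of the new `job_script_prologue` option in use (illustrative only;
the module and path names are placeholders for whatever environment setup the
worker jobs need):

    from queens.schedulers.cluster_local import ClusterLocal

    scheduler = ClusterLocal(
        experiment_name="my_experiment",
        workload_manager="slurm",
        walltime="01:00:00",
        # commands executed in the job script before each dask worker starts
        job_script_prologue=[
            "module load gcc",
            "source /path/to/venv/bin/activate",
        ],
    )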