diff --git a/pyproject.toml b/pyproject.toml index 2b20173..3f72ec0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,8 @@ classifiers = [ ] dependencies = [ "aiohttp", + "dask", + "dask-jobqueue", "importlib_metadata <5.0.0", "importlib-resources", "intake", diff --git a/src/esnb/__init__.py b/src/esnb/__init__.py index 54bed02..e7f5981 100644 --- a/src/esnb/__init__.py +++ b/src/esnb/__init__.py @@ -11,6 +11,7 @@ from .core.esnb_datastore import esnb_datastore from .core.NotebookDiagnostic import NotebookDiagnostic from .core.RequestedVariable import RequestedVariable +from .core.util_dask import init_dask_cluster __all__ = [ "core", @@ -24,6 +25,7 @@ "esnb_datastore", "NotebookDiagnostic", "RequestedVariable", + "init_dask_cluster", "nbtools", ] diff --git a/src/esnb/core/__init__.py b/src/esnb/core/__init__.py index 96bf9a3..1a8f0d4 100644 --- a/src/esnb/core/__init__.py +++ b/src/esnb/core/__init__.py @@ -8,6 +8,7 @@ util, util2, util_catalog, + util_dask, util_mdtf, util_xr, ) @@ -23,6 +24,7 @@ "util", "util2", "util_case", + "util_dask", "util_catalog", "util_mdtf", "util_xr", diff --git a/src/esnb/core/util_dask.py b/src/esnb/core/util_dask.py new file mode 100644 index 0000000..f1c25ef --- /dev/null +++ b/src/esnb/core/util_dask.py @@ -0,0 +1,64 @@ +import inspect +import logging +import os + +from dask.distributed import Client, LocalCluster + +logger = logging.getLogger(__name__) + + +def _local_cluster(**kwargs): + sig = inspect.signature(LocalCluster) + params = sig.parameters + valid_args = [name for name, param in params.items()] + + ignored_args = [] + for k, v in kwargs.items(): + if k not in valid_args: + ignored_args.append(k) + + if len(ignored_args) > 0: + for x in ignored_args: + del kwargs[x] + logger.warning(f"Ignoring options for LocalCluster: {ignored_args}") + + cluster = LocalCluster(**kwargs) + client = Client(cluster) + logger.info(f"Initializing local dask cluster: {cluster.dashboard_link}") + return (cluster, client) + + +def init_dask_cluster(site="local", **kwargs): + if "portdash" in kwargs.keys(): + kwargs["dashboard_address"] = f":{kwargs['portdash']}" + del kwargs["portdash"] + + if site == "local": + cluster, client = _local_cluster(**kwargs) + elif site == "gfdl_ppan": + from esnb.sites.gfdl import dask_cluster_ppan + + cluster, client = dask_cluster_ppan(**kwargs) + + if cluster is None: + logger.warning("An error occured; Falling back to dask LocalCluster") + cluster, client = _local_cluster(**kwargs) + + else: + raise ValueError(f"Unrecognized Dask Site: {site}") + + return (cluster, client) + + +def init_dask_cluster_test(**kwargs): + options = { + "site": "gfdl_ppan", + "walltime": "24:00:00", + "highmem": True, + "memory": "48GB", + "portdash": os.getuid() + 6047, + } + + options = {**options, **kwargs} + + return init_dask_cluster(**options) diff --git a/src/esnb/sites/gfdl.py b/src/esnb/sites/gfdl.py index 61ed2c5..ff883f0 100644 --- a/src/esnb/sites/gfdl.py +++ b/src/esnb/sites/gfdl.py @@ -4,6 +4,7 @@ and loading DORA catalogs via the esnb_datastore interface. """ +import getpass import logging import os import shutil @@ -15,6 +16,8 @@ import intake import pandas as pd import requests +from dask.distributed import Client +from dask_jobqueue import SLURMCluster from esnb.core.esnb_datastore import esnb_datastore @@ -39,6 +42,85 @@ logger = logging.getLogger(__name__) +def default_slurm_account(user=None): + user = getpass.getuser() if user is None else str(user) + cmd = ["sacctmgr", "show", "user", user, "format=defaultaccount", "-nP"] + default_account = subprocess.check_output(cmd, text=True).strip() + return default_account + + +def dask_cluster_ppan(highmem=False, **kwargs): + if "account" not in kwargs.keys(): + default_account = default_slurm_account() + logger.info(f"Setting SLURM account to {default_account}") + kwargs["account"] = default_account + + if "dashboard_address" in kwargs.keys(): + kwargs["scheduler_options"] = {"dashboard_address": kwargs["dashboard_address"]} + del kwargs["dashboard_address"] + + if "jobs" in kwargs.keys(): + jobs = kwargs["jobs"] + del kwargs["jobs"] + else: + jobs = None + + if "wait" in kwargs.keys(): + wait = kwargs["wait"] + del kwargs["wait"] + else: + wait = True + + options = { + "queue": "batch", + "cores": 8, + "processes": 2, + "walltime": "01:00:00", + "local_directory": "$TMPDIR", + "death_timeout": 120, + "job_name": "esnb-dask", + "memory": "16GB", + } + + if highmem: + options["job_extra_directives"] = ["--exclude=pp[008-010],pp[013-075]"] + + options = {**options, **kwargs} + + try: + cluster = SLURMCluster(**options) + client = Client(cluster) + logger.info( + f"Successfully started SLURM dask cluster: {cluster.dashboard_link}" + ) + + except Exception as exc: + logger.warning(f"Unable to start SLURMCluster: {exc}") + cluster = None + client = None + + if (cluster is not None) and (jobs is not None): + try: + cluster.scale(jobs=int(jobs)) + logger.info(f"Scaling SLURM dask cluster to {jobs} jobs.") + + if wait: + logger.info("Waiting for a worker to come online ...") + client.wait_for_workers(n_workers=1) + + except Exception as exc: + logger.warning(f"Unable to scale the SLURM cluster: {exc}") + cluster.close() + + if client is not None: + client.close() + + cluster = None + client = None + + return (cluster, client) + + def generate_gfdl_intake_catalog(pathpp, fre_cli=None): logger.info(f"Generating intake catalog for: {pathpp}") current_dir = os.getcwd()