Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ classifiers = [
]
dependencies = [
"aiohttp",
"dask",
"dask-jobqueue",
"importlib_metadata <5.0.0",
"importlib-resources",
"intake",
Expand Down
2 changes: 2 additions & 0 deletions src/esnb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .core.esnb_datastore import esnb_datastore
from .core.NotebookDiagnostic import NotebookDiagnostic
from .core.RequestedVariable import RequestedVariable
from .core.util_dask import init_dask_cluster

__all__ = [
"core",
Expand All @@ -24,6 +25,7 @@
"esnb_datastore",
"NotebookDiagnostic",
"RequestedVariable",
"init_dask_cluster",
"nbtools",
]

Expand Down
2 changes: 2 additions & 0 deletions src/esnb/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
util,
util2,
util_catalog,
util_dask,
util_mdtf,
util_xr,
)
Expand All @@ -23,6 +24,7 @@
"util",
"util2",
"util_case",
"util_dask",
"util_catalog",
"util_mdtf",
"util_xr",
Expand Down
64 changes: 64 additions & 0 deletions src/esnb/core/util_dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import inspect
import logging
import os

from dask.distributed import Client, LocalCluster

logger = logging.getLogger(__name__)


def _local_cluster(**kwargs):
    """Start a dask ``LocalCluster`` and attach a ``Client`` to it.

    Keyword arguments that ``LocalCluster.__init__`` does not declare are
    dropped with a warning rather than raising ``TypeError`` — this lets
    callers pass a single kwargs dict that mixes local and SLURM options.

    Parameters
    ----------
    **kwargs
        Options forwarded to ``dask.distributed.LocalCluster``.

    Returns
    -------
    tuple
        ``(cluster, client)`` pair.
    """
    # NOTE(review): if LocalCluster's signature includes a **kwargs
    # catch-all, names meant for it would also be filtered out here —
    # confirm against the installed dask version.
    valid_args = set(inspect.signature(LocalCluster).parameters)

    ignored_args = [k for k in kwargs if k not in valid_args]
    if ignored_args:
        for k in ignored_args:
            del kwargs[k]
        logger.warning(f"Ignoring options for LocalCluster: {ignored_args}")

    cluster = LocalCluster(**kwargs)
    client = Client(cluster)
    logger.info(f"Initializing local dask cluster: {cluster.dashboard_link}")
    return (cluster, client)


def init_dask_cluster(site="local", **kwargs):
    """Initialize a dask cluster appropriate for the given site.

    Parameters
    ----------
    site : str, optional
        ``"local"`` for a dask ``LocalCluster`` (default), or
        ``"gfdl_ppan"`` for a SLURM cluster on the GFDL PP/AN system.
    **kwargs
        Forwarded to the underlying cluster constructor. The shorthand
        ``portdash=<int>`` is translated to ``dashboard_address=":<int>"``.

    Returns
    -------
    tuple
        ``(cluster, client)`` pair.

    Raises
    ------
    ValueError
        If ``site`` is not one of the recognized names.
    """
    # Translate the `portdash` shorthand into dask's dashboard_address form.
    if "portdash" in kwargs:
        kwargs["dashboard_address"] = f":{kwargs.pop('portdash')}"

    if site == "local":
        cluster, client = _local_cluster(**kwargs)
    elif site == "gfdl_ppan":
        # Deferred import: the site module pulls in dask_jobqueue, which
        # is only needed (and may only be installed) on PP/AN.
        from esnb.sites.gfdl import dask_cluster_ppan

        cluster, client = dask_cluster_ppan(**kwargs)

        # dask_cluster_ppan returns (None, None) on failure rather than
        # raising, so degrade gracefully to a local cluster.
        if cluster is None:
            logger.warning("An error occurred; Falling back to dask LocalCluster")
            cluster, client = _local_cluster(**kwargs)
    else:
        raise ValueError(f"Unrecognized Dask Site: {site}")

    return (cluster, client)


def init_dask_cluster_test(**kwargs):
    """Convenience wrapper: start a GFDL PP/AN cluster with test defaults.

    Any keyword arguments override the built-in defaults below before
    being forwarded to ``init_dask_cluster``.

    Returns
    -------
    tuple
        ``(cluster, client)`` pair from ``init_dask_cluster``.
    """
    defaults = {
        "site": "gfdl_ppan",
        "walltime": "24:00:00",
        "highmem": True,
        "memory": "48GB",
        # Offset by uid so concurrent users get distinct dashboard ports.
        "portdash": os.getuid() + 6047,
    }
    defaults.update(kwargs)
    return init_dask_cluster(**defaults)
82 changes: 82 additions & 0 deletions src/esnb/sites/gfdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
and loading DORA catalogs via the esnb_datastore interface.
"""

import getpass
import logging
import os
import shutil
Expand All @@ -15,6 +16,8 @@
import intake
import pandas as pd
import requests
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

from esnb.core.esnb_datastore import esnb_datastore

Expand All @@ -39,6 +42,85 @@
logger = logging.getLogger(__name__)


def default_slurm_account(user=None):
    """Look up a user's default SLURM account via ``sacctmgr``.

    Parameters
    ----------
    user : str, optional
        Username to query; defaults to the current user.

    Returns
    -------
    str
        The default account name reported by ``sacctmgr``.
    """
    if user is None:
        user = getpass.getuser()
    else:
        user = str(user)
    output = subprocess.check_output(
        ["sacctmgr", "show", "user", user, "format=defaultaccount", "-nP"],
        text=True,
    )
    return output.strip()


def dask_cluster_ppan(highmem=False, **kwargs):
    """Start a dask SLURM cluster on the GFDL PP/AN system.

    Parameters
    ----------
    highmem : bool, optional
        If True, add a ``--exclude`` directive so jobs avoid the listed
        pp nodes. Defaults to False.
    **kwargs
        Overrides for the ``SLURMCluster`` defaults below. Three options
        are consumed here instead of being passed through:
        ``dashboard_address`` (wrapped into ``scheduler_options``),
        ``jobs`` (number of jobs to scale the cluster to), and ``wait``
        (block until at least one worker is online; default True).

    Returns
    -------
    tuple
        ``(cluster, client)``; both are ``None`` if startup or scaling
        failed.
    """
    if "account" not in kwargs:
        default_account = default_slurm_account()
        logger.info(f"Setting SLURM account to {default_account}")
        kwargs["account"] = default_account

    # dask_jobqueue expects dashboard_address nested inside
    # scheduler_options, not as a top-level constructor argument.
    if "dashboard_address" in kwargs:
        kwargs["scheduler_options"] = {
            "dashboard_address": kwargs.pop("dashboard_address")
        }

    # Options consumed by this function rather than by SLURMCluster.
    jobs = kwargs.pop("jobs", None)
    wait = kwargs.pop("wait", True)

    options = {
        "queue": "batch",
        "cores": 8,
        "processes": 2,
        "walltime": "01:00:00",
        "local_directory": "$TMPDIR",
        "death_timeout": 120,
        "job_name": "esnb-dask",
        "memory": "16GB",
    }

    if highmem:
        # presumably the excluded pp nodes are the standard-memory hosts,
        # leaving only high-memory ones eligible -- TODO confirm
        options["job_extra_directives"] = ["--exclude=pp[008-010],pp[013-075]"]

    # Caller-supplied kwargs take precedence over the defaults above.
    options = {**options, **kwargs}

    try:
        cluster = SLURMCluster(**options)
        client = Client(cluster)
        logger.info(
            f"Successfully started SLURM dask cluster: {cluster.dashboard_link}"
        )

    except Exception as exc:
        # Return (None, None) rather than raising so callers can fall
        # back to a local cluster.
        logger.warning(f"Unable to start SLURMCluster: {exc}")
        cluster = None
        client = None

    if (cluster is not None) and (jobs is not None):
        try:
            cluster.scale(jobs=int(jobs))
            logger.info(f"Scaling SLURM dask cluster to {jobs} jobs.")

            if wait:
                logger.info("Waiting for a worker to come online ...")
                client.wait_for_workers(n_workers=1)

        except Exception as exc:
            # Scaling failed: tear everything down and signal failure the
            # same way startup failure does.
            logger.warning(f"Unable to scale the SLURM cluster: {exc}")
            cluster.close()

            if client is not None:
                client.close()

            cluster = None
            client = None

    return (cluster, client)


def generate_gfdl_intake_catalog(pathpp, fre_cli=None):
logger.info(f"Generating intake catalog for: {pathpp}")
current_dir = os.getcwd()
Expand Down