From 4561564dbd3c0ba858e79cfe68f4275a081689fc Mon Sep 17 00:00:00 2001 From: John Krasting Date: Tue, 18 Nov 2025 13:26:52 -0500 Subject: [PATCH] Option to control xarray dataset load options - Introduced `xr_opts` keyword to diag.open() to pass options to control the dataset loading. --- src/esnb/core/CaseGroup2.py | 4 ++-- src/esnb/core/NotebookDiagnostic.py | 34 ++++++++++++++++++++++++----- src/esnb/core/util_xr.py | 23 ++++++++++++++----- src/esnb/sites/gfdl.py | 2 +- 4 files changed, 49 insertions(+), 14 deletions(-) diff --git a/src/esnb/core/CaseGroup2.py b/src/esnb/core/CaseGroup2.py index 331f9a0..a8fecf6 100644 --- a/src/esnb/core/CaseGroup2.py +++ b/src/esnb/core/CaseGroup2.py @@ -353,8 +353,8 @@ def files(self): caselist = [x for x in flatten_list(self.cases)] return sorted(flatten_list([list(x.catalog.df["path"]) for x in caselist])) - def open_var(self, varname): - return open_var_from_group(self, varname) + def open_var(self, varname, xr_opts=None): + return open_var_from_group(self, varname, xr_opts=xr_opts) def resolve(self, varlist): """ diff --git a/src/esnb/core/NotebookDiagnostic.py b/src/esnb/core/NotebookDiagnostic.py index 6e2e65a..434887c 100644 --- a/src/esnb/core/NotebookDiagnostic.py +++ b/src/esnb/core/NotebookDiagnostic.py @@ -337,7 +337,14 @@ def dmget(self, status=False): else: gfdl.call_dmget(self.files, status=status) - def load(self, site="gfdl", dmget=False, use_cache=False, cache_format="zarr"): + def load( + self, + site="gfdl", + dmget=False, + use_cache=False, + cache_format="zarr", + xr_opts=None, + ): """ Load all groups by calling their load method. 
""" @@ -345,10 +352,21 @@ def load(self, site="gfdl", dmget=False, use_cache=False, cache_format="zarr"): _ = [x.load() for x in self.groups] else: self.loader( - site=site, dmget=dmget, use_cache=use_cache, cache_format=cache_format + site=site, + dmget=dmget, + use_cache=use_cache, + cache_format=cache_format, + xr_opts=xr_opts, ) - def loader(self, site="gfdl", dmget=False, use_cache=False, cache_format="zarr"): + def loader( + self, + site="gfdl", + dmget=False, + use_cache=False, + cache_format="zarr", + xr_opts=None, + ): diag = self groups = diag.groups variables = diag.variables @@ -379,13 +397,14 @@ def loader(self, site="gfdl", dmget=False, use_cache=False, cache_format="zarr") cached_file_name, decode_times=time_coder, decode_timedelta=True, + **xr_opts, ) else: raise ValueError( f"Trying to open unsupported cache type: {cache_format}" ) else: - ds = group.open_var(var.varname) + ds = group.open_var(var.varname, xr_opts=xr_opts) tcoord = "time" logger.info( @@ -427,9 +446,14 @@ def open( use_cache=False, cache_format="zarr", statics=False, + xr_opts=None, ): self.load( - site=site, dmget=dmget, use_cache=use_cache, cache_format=cache_format + site=site, + dmget=dmget, + use_cache=use_cache, + cache_format=cache_format, + xr_opts=xr_opts, ) if statics: logger.info("Loading dictionary of static files") diff --git a/src/esnb/core/util_xr.py b/src/esnb/core/util_xr.py index f032beb..afd1dba 100644 --- a/src/esnb/core/util_xr.py +++ b/src/esnb/core/util_xr.py @@ -8,13 +8,13 @@ logger = logging.getLogger(__name__) -def open_paths(files, varname=None): +def open_paths(files, varname=None, xr_opts=None): file_type = infer_source_data_file_types(files) logger.debug(f"Found {file_type} files: {files}") if file_type == "unix_file": logger.info(f"Opening local files in xarray: {files}") - _ds = open_xr(files) + _ds = open_xr(files, xr_opts=xr_opts) elif file_type == "google_cloud": logger.info(f"Opening Google Cloud stores in xarray: {files}") _ds = open_gcs(files) 
@@ -55,20 +55,31 @@ def open_gcs(files): return ds -def open_xr(files, xr_merge_opts=None): +def open_xr(files, xr_merge_opts=None, xr_opts=None): xr_merge_opts = ( {"coords": "minimal", "compat": "override"} if xr_merge_opts is None else xr_merge_opts ) + if xr_opts is None: + xr_opts = xr_merge_opts + + assert isinstance(xr_opts, dict), "`xr_opts` must be a dictionary of kwargs" + + if len(xr_opts) == 0: + xr_opts = {**xr_merge_opts} + + if len(xr_opts) > 0: + logger.debug(f"Options passed to `xr.open_mfdataset`: {xr_opts}") + time_coder = xr.coders.CFDatetimeCoder(use_cftime=True) ds = xr.open_mfdataset( files, decode_times=time_coder, decode_timedelta=True, chunks={}, - **xr_merge_opts, + **xr_opts, ) ds.attrs["files"] = files @@ -76,7 +87,7 @@ def open_xr(files, xr_merge_opts=None): return ds -def open_var_from_group(group, varname): +def open_var_from_group(group, varname, xr_opts=None): concat_dim = group.concat_dim concat_dim = [concat_dim] if not isinstance(concat_dim, list) else concat_dim @@ -96,7 +107,7 @@ def open_var_from_group(group, varname): nelements = len(case_elements) logger.debug(f"This case has {nelements} elements: {case_elements}") case_elements = [ - open_paths(x.files(variable_id=varname), varname=varname) + open_paths(x.files(variable_id=varname), varname=varname, xr_opts=xr_opts) for x in case_elements ] if nelements > 1: diff --git a/src/esnb/sites/gfdl.py b/src/esnb/sites/gfdl.py index 26f0be3..97730a1 100644 --- a/src/esnb/sites/gfdl.py +++ b/src/esnb/sites/gfdl.py @@ -293,7 +293,7 @@ def convert_to_momgrid(diag, positive_longitudes=False, fatal_on_error=True): ds.replace(_ds.data) except Exception as exc: logger.warning(f"Unable to convert dataset [{n}] with momgrid: {exc}") - if fatal_on_error: + if fatal_on_error: raise exc