From 37579fec2895dfe45d1b6c69608df47a732122cb Mon Sep 17 00:00:00 2001 From: Kobe Vandelanotte Date: Fri, 28 Feb 2025 14:51:20 +0100 Subject: [PATCH 1/5] fix default kwarg wrapping issue --- src/valenspy/diagnostic/plot_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/valenspy/diagnostic/plot_utils.py b/src/valenspy/diagnostic/plot_utils.py index 931200e3..f24fe1f9 100644 --- a/src/valenspy/diagnostic/plot_utils.py +++ b/src/valenspy/diagnostic/plot_utils.py @@ -29,13 +29,14 @@ def _augment_kwargs(def_kwargs, **kwargs): cbar_kwargs = _merge_kwargs(def_kwargs.pop('cbar_kwargs'), kwargs.pop('cbar_kwargs', {})) def_kwargs['cbar_kwargs'] = cbar_kwargs + #Is this correct? Are the cbar_kwargs not overwritten if defined by the user? return _merge_kwargs(def_kwargs, kwargs) ###################################### ############## Wrappers ############## ###################################### -def default_plot_kwargs(kwargs): +def default_plot_kwargs(def_kwargs): """ Decorator to set the default keyword arguments for the plotting function. User will override and/or be augmented with the default keyword arguments. subplot_kws and cbar_kwargs can also be set as default keyword arguments for the plotting function. 
@@ -63,7 +64,7 @@ def decorator(plotting_function): @wraps(plotting_function) def wrapper(*args, **kwargs): - return plotting_function(*args, **_augment_kwargs(def_kwargs=kwargs, **kwargs)) + return plotting_function(*args, **_augment_kwargs(def_kwargs=def_kwargs, **kwargs)) return wrapper From 95a3d7d9710e7f1b4ea26bdd3225e00781b39685 Mon Sep 17 00:00:00 2001 From: Kobe Vandelanotte Date: Fri, 28 Feb 2025 14:52:11 +0100 Subject: [PATCH 2/5] Ensemble selection diagnostic v1 --- src/valenspy/diagnostic/diagnostic.py | 16 ++++++++-- src/valenspy/diagnostic/functions.py | 31 +++++++++++++++++++ src/valenspy/diagnostic/visualizations.py | 36 +++++++++++++++++++++++ 3 files changed, 81 insertions(+), 2 deletions(-) diff --git a/src/valenspy/diagnostic/diagnostic.py b/src/valenspy/diagnostic/diagnostic.py index 95c16e84..490912b3 100644 --- a/src/valenspy/diagnostic/diagnostic.py +++ b/src/valenspy/diagnostic/diagnostic.py @@ -2,8 +2,7 @@ import xarray as xr import matplotlib.pyplot as plt from valenspy.processing.mask import add_prudence_regions -from valenspy.diagnostic.plot_utils import _augment_kwargs - +from valenspy.diagnostic.plot_utils import default_plot_kwargs, _augment_kwargs #Import get_axis from xarray from xarray.plot.utils import get_axis @@ -498,4 +497,17 @@ def _initialize_multiaxis_plot(n, subplot_kws={}): plot_metric_ranking, "Metrics Rankings", "The rankings of ensemble members with respect to several metrics when compared to the reference." +) + +# Ensemble2Ref diagnostics +EnsembleSubSelection = Ensemble2Ref( + case_sub_selection, + default_plot_kwargs({ + "x": "var", + "y": "abs_change", + "selected": ["highest", "middle", "lowest"], + "sel_colors": {"highest": "red", "middle": "blue", "lowest": "green"} + })(ensemble_selection_boxplot), + "Ensemble Sub Selection", + "The sub selection of ensemble members." 
) \ No newline at end of file diff --git a/src/valenspy/diagnostic/functions.py b/src/valenspy/diagnostic/functions.py index 16ff1362..95638d01 100644 --- a/src/valenspy/diagnostic/functions.py +++ b/src/valenspy/diagnostic/functions.py @@ -261,6 +261,37 @@ def calc_metrics_dt(dt_mod: DataTree, da_obs: xr.Dataset, metrics=None, pss_binw df = _add_ranks_metrics(df) return df +########################################## +# Ensemble2Ensemble diagnostic functions # +########################################## + +def case_sub_selection(dt_future: DataTree, dt_ref: DataTree, vars): + """ + Select 3 ensemble members based on avg normalized climate change in the variable of interest. + The two extreme and mediaan members are selected. + """ + dt_change = dt_future.mean() - dt_ref.mean() + dt_rel_change = dt_change / dt_ref.mean() + + data = [[x.path, var, x[var].values] for x in dt_rel_change.leaves for var in vars if var in x] + data_abs = [[x.path, var, x[var].values] for x in dt_change.leaves for var in vars if var in x] + + #Create one dataframe containing the relative change and the absolute change + df = pd.DataFrame(data, columns=["label", "var", "rel_change"]) + df_abs = pd.DataFrame(data_abs, columns=["label", "var", "abs_change"]) + + df = pd.merge(df, df_abs, on=["label", "var"]) + df["rel_change"] = df["rel_change"].astype(float) + df["abs_change"] = df["abs_change"].astype(float) + + df["mean"] = df["rel_change"].groupby(df["label"]).transform("mean") + df["rank"] = df["mean"].groupby(df["var"]).rank(ascending=True, method='min') + df["highest"] = df["rank"] == 1 + df["lowest"] = df["rank"] == df["rank"].max() + df["middle"] = df["rank"] == np.floor(df["rank"].median()) + + return df + ################################## ####### Helper functions ######### diff --git a/src/valenspy/diagnostic/visualizations.py b/src/valenspy/diagnostic/visualizations.py index 744bedc6..c4463055 100644 --- a/src/valenspy/diagnostic/visualizations.py +++ 
b/src/valenspy/diagnostic/visualizations.py @@ -496,6 +496,42 @@ def plot_metric_ranking(df_metric, ax=None, plot_colorbar=True, hex_color1 = Non return ax +########################################## +# Ensemble2Ensemble diagnostic functions # +########################################## + +#Add default plot kwargs at the diagnostic level +def ensemble_selection_boxplot(df, selected=None, sel_colors=None, ax=None, **kwargs): + """ + Create a boxplot of the ensemble members with the selected ensemble members highlighted. + + Parameters + ---------- + df : pd.DataFrame + A DataFrame containing the ensemble members and their values. + selected : str or list of str, optional + The column name(s) with boolean values indicating the selected ensemble members. + If None all members are plotted. + ax : matplotlib.axes.Axes, optional + The axes on which to plot the boxplot. If None, a new figure and axes are created. + **kwargs : dict + Additional keyword arguments passed to `seaborn.boxplot`. + """ + sns.boxplot(data=df, ax=ax, **kwargs) + + if selected is None: + sns.swarmplot(data=df, ax=ax, **kwargs) + else: + for sel in selected: + if sel in df.columns: + if isinstance(sel_colors, dict) and sel in sel_colors: + kwargs.update({"color": sel_colors[sel]}) + sns.swarmplot(data=df[df[sel]], ax=ax, **kwargs) + + ax = _get_gca(**kwargs) + return ax + + ################################## # Helper functions # ################################## From 74c23347da27af70b13926358efbe42ed36cb3db Mon Sep 17 00:00:00 2001 From: Kobe Vandelanotte Date: Fri, 28 Feb 2025 14:52:25 +0100 Subject: [PATCH 3/5] EnsembleSubSelection diagnostic --- src/valenspy/diagnostic/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/valenspy/diagnostic/__init__.py b/src/valenspy/diagnostic/__init__.py index 23cac8c2..0069e7eb 100644 --- a/src/valenspy/diagnostic/__init__.py +++ b/src/valenspy/diagnostic/__init__.py @@ -8,4 +8,5 @@ SpatialBias, TemporalBias, DiurnalCycleBias, + 
EnsembleSubSelection ) \ No newline at end of file From 44485433193c0425b4eed4dda1808a93e86b4f73 Mon Sep 17 00:00:00 2001 From: Kobe Vandelanotte Date: Thu, 6 Mar 2025 18:15:54 +0100 Subject: [PATCH 4/5] ensemble plot --- poetry.lock | 142 +++++++++++++++------- pyproject.toml | 2 +- src/valenspy/diagnostic/diagnostic.py | 100 ++++++++++----- src/valenspy/diagnostic/functions.py | 29 +++-- src/valenspy/diagnostic/visualizations.py | 22 ++++ src/valenspy/input/manager.py | 2 +- 6 files changed, 212 insertions(+), 85 deletions(-) diff --git a/poetry.lock b/poetry.lock index 07eb3eec..a6fbef23 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1339,28 +1339,6 @@ perf = ["ipython"] test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-perf (>=0.9.2)"] type = ["pytest-mypy"] -[[package]] -name = "importlib-resources" -version = "6.4.5" -description = "Read resources from Python packages" -optional = false -python-versions = ">=3.8" -files = [ - {file = "importlib_resources-6.4.5-py3-none-any.whl", hash = "sha256:ac29d5f956f01d5e4bb63102a5a19957f1b9175e45649977264a1416783bb717"}, - {file = "importlib_resources-6.4.5.tar.gz", hash = "sha256:980862a1d16c9e147a59603677fa2aa5fd82b87f223b6cb870695bcfce830065"}, -] - -[package.dependencies] -zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} - -[package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] -cover = ["pytest-cov"] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -enabler = ["pytest-enabler (>=2.2)"] -test = ["jaraco.test (>=5.4)", "pytest (>=6,!=8.1.*)", "zipp (>=3.17)"] -type = ["pytest-mypy"] - [[package]] name = "iniconfig" version = "2.0.0" @@ -1438,7 +1416,6 @@ prompt-toolkit = ">=3.0.41,<3.1.0" pygments = ">=2.4.0" stack-data = "*" traitlets = ">=5" -typing-extensions = {version = "*", markers = "python_version < 
\"3.10\""} [package.extras] all = ["black", "curio", "docrepr", "exceptiongroup", "ipykernel", "ipyparallel", "ipywidgets", "matplotlib", "matplotlib (!=3.2.0)", "nbconvert", "nbformat", "notebook", "numpy (>=1.22)", "pandas", "pickleshare", "pytest (<7)", "pytest (<7.1)", "pytest-asyncio (<0.22)", "qtconsole", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "trio", "typing-extensions"] @@ -1598,7 +1575,6 @@ files = [ ] [package.dependencies] -importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" python-dateutil = ">=2.8.2" pyzmq = ">=23.0" @@ -1786,6 +1762,36 @@ files = [ {file = "kiwisolver-1.4.7.tar.gz", hash = "sha256:9893ff81bd7107f7b685d3017cc6583daadb4fc26e4a888350df530e41980a60"}, ] +[[package]] +name = "llvmlite" +version = "0.44.0" +description = "lightweight wrapper around basic LLVM functionality" +optional = false +python-versions = ">=3.10" +files = [ + {file = "llvmlite-0.44.0-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:9fbadbfba8422123bab5535b293da1cf72f9f478a65645ecd73e781f962ca614"}, + {file = "llvmlite-0.44.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cccf8eb28f24840f2689fb1a45f9c0f7e582dd24e088dcf96e424834af11f791"}, + {file = "llvmlite-0.44.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7202b678cdf904823c764ee0fe2dfe38a76981f4c1e51715b4cb5abb6cf1d9e8"}, + {file = "llvmlite-0.44.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:40526fb5e313d7b96bda4cbb2c85cd5374e04d80732dd36a282d72a560bb6408"}, + {file = "llvmlite-0.44.0-cp310-cp310-win_amd64.whl", hash = "sha256:41e3839150db4330e1b2716c0be3b5c4672525b4c9005e17c7597f835f351ce2"}, + {file = "llvmlite-0.44.0-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:eed7d5f29136bda63b6d7804c279e2b72e08c952b7c5df61f45db408e0ee52f3"}, + {file = "llvmlite-0.44.0-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:ace564d9fa44bb91eb6e6d8e7754977783c68e90a471ea7ce913bff30bd62427"}, + {file = "llvmlite-0.44.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5d22c3bfc842668168a786af4205ec8e3ad29fb1bc03fd11fd48460d0df64c1"}, + {file = "llvmlite-0.44.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f01a394e9c9b7b1d4e63c327b096d10f6f0ed149ef53d38a09b3749dcf8c9610"}, + {file = "llvmlite-0.44.0-cp311-cp311-win_amd64.whl", hash = "sha256:d8489634d43c20cd0ad71330dde1d5bc7b9966937a263ff1ec1cebb90dc50955"}, + {file = "llvmlite-0.44.0-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:1d671a56acf725bf1b531d5ef76b86660a5ab8ef19bb6a46064a705c6ca80aad"}, + {file = "llvmlite-0.44.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5f79a728e0435493611c9f405168682bb75ffd1fbe6fc360733b850c80a026db"}, + {file = "llvmlite-0.44.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0143a5ef336da14deaa8ec26c5449ad5b6a2b564df82fcef4be040b9cacfea9"}, + {file = "llvmlite-0.44.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d752f89e31b66db6f8da06df8b39f9b91e78c5feea1bf9e8c1fba1d1c24c065d"}, + {file = "llvmlite-0.44.0-cp312-cp312-win_amd64.whl", hash = "sha256:eae7e2d4ca8f88f89d315b48c6b741dcb925d6a1042da694aa16ab3dd4cbd3a1"}, + {file = "llvmlite-0.44.0-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:319bddd44e5f71ae2689859b7203080716448a3cd1128fb144fe5c055219d516"}, + {file = "llvmlite-0.44.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:9c58867118bad04a0bb22a2e0068c693719658105e40009ffe95c7000fcde88e"}, + {file = "llvmlite-0.44.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46224058b13c96af1365290bdfebe9a6264ae62fb79b2b55693deed11657a8bf"}, + {file = "llvmlite-0.44.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:aa0097052c32bf721a4efc03bd109d335dfa57d9bffb3d4c24cc680711b8b4fc"}, + {file = 
"llvmlite-0.44.0-cp313-cp313-win_amd64.whl", hash = "sha256:2fb7c4f2fb86cbae6dca3db9ab203eeea0e22d73b99bc2341cdf9de93612e930"}, + {file = "llvmlite-0.44.0.tar.gz", hash = "sha256:07667d66a5d150abed9157ab6c0b9393c9356f229784a4385c02f99e94fc94d4"}, +] + [[package]] name = "locket" version = "1.0.0" @@ -1944,7 +1950,6 @@ files = [ contourpy = ">=1.0.1" cycler = ">=0.10" fonttools = ">=4.22.0" -importlib-resources = {version = ">=3.2.0", markers = "python_version < \"3.10\""} kiwisolver = ">=1.3.1" numpy = ">=1.23" packaging = ">=20.0" @@ -2157,7 +2162,6 @@ files = [ beautifulsoup4 = "*" bleach = "!=5.0.0" defusedxml = "*" -importlib-metadata = {version = ">=3.6", markers = "python_version < \"3.10\""} jinja2 = ">=3.0" jupyter-core = ">=4.7" jupyterlab-pygments = "*" @@ -2310,6 +2314,40 @@ files = [ {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, ] +[[package]] +name = "numba" +version = "0.61.0" +description = "compiling Python code using LLVM" +optional = false +python-versions = ">=3.10" +files = [ + {file = "numba-0.61.0-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:9cab9783a700fa428b1a54d65295122bc03b3de1d01fb819a6b9dbbddfdb8c43"}, + {file = "numba-0.61.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:46c5ae094fb3706f5adf9021bfb7fc11e44818d61afee695cdee4eadfed45e98"}, + {file = "numba-0.61.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6fb74e81aa78a2303e30593d8331327dfc0d2522b5db05ac967556a26db3ef87"}, + {file = "numba-0.61.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:0ebbd4827091384ab8c4615ba1b3ca8bc639a3a000157d9c37ba85d34cd0da1b"}, + {file = "numba-0.61.0-cp310-cp310-win_amd64.whl", hash = "sha256:43aa4d7d10c542d3c78106b8481e0cbaaec788c39ee8e3d7901682748ffdf0b4"}, + {file = "numba-0.61.0-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:bf64c2d0f3d161af603de3825172fb83c2600bcb1d53ae8ea568d4c53ba6ac08"}, + {file = 
"numba-0.61.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:de5aa7904741425f28e1028b85850b31f0a245e9eb4f7c38507fb893283a066c"}, + {file = "numba-0.61.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:21c2fe25019267a608e2710a6a947f557486b4b0478b02e45a81cf606a05a7d4"}, + {file = "numba-0.61.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:74250b26ed6a1428763e774dc5b2d4e70d93f73795635b5412b8346a4d054574"}, + {file = "numba-0.61.0-cp311-cp311-win_amd64.whl", hash = "sha256:b72bbc8708e98b3741ad0c63f9929c47b623cc4ee86e17030a4f3e301e8401ac"}, + {file = "numba-0.61.0-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:152146ecdbb8d8176f294e9f755411e6f270103a11c3ff50cecc413f794e52c8"}, + {file = "numba-0.61.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5cafa6095716fcb081618c28a8d27bf7c001e09696f595b41836dec114be2905"}, + {file = "numba-0.61.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ffe9fe373ed30638d6e20a0269f817b2c75d447141f55a675bfcf2d1fe2e87fb"}, + {file = "numba-0.61.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9f25f7fef0206d55c1cfb796ad833cbbc044e2884751e56e798351280038484c"}, + {file = "numba-0.61.0-cp312-cp312-win_amd64.whl", hash = "sha256:550d389573bc3b895e1ccb18289feea11d937011de4d278b09dc7ed585d1cdcb"}, + {file = "numba-0.61.0-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:b96fafbdcf6f69b69855273e988696aae4974115a815f6818fef4af7afa1f6b8"}, + {file = "numba-0.61.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5f6c452dca1de8e60e593f7066df052dd8da09b243566ecd26d2b796e5d3087d"}, + {file = "numba-0.61.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:44240e694d4aa321430c97b21453e46014fe6c7b8b7d932afa7f6a88cc5d7e5e"}, + {file = "numba-0.61.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:764f0e47004f126f58c3b28e0a02374c420a9d15157b90806d68590f5c20cc89"}, + {file = "numba-0.61.0-cp313-cp313-win_amd64.whl", hash = 
"sha256:074cd38c5b1f9c65a4319d1f3928165f48975ef0537ad43385b2bd908e6e2e35"}, + {file = "numba-0.61.0.tar.gz", hash = "sha256:888d2e89b8160899e19591467e8fdd4970e07606e1fbc248f239c89818d5f925"}, +] + +[package.dependencies] +llvmlite = "==0.44.*" +numpy = ">=1.24,<2.2" + [[package]] name = "numpy" version = "2.0.2" @@ -2720,7 +2758,6 @@ cleo = ">=2.1.0,<3.0.0" crashtest = ">=0.4.1,<0.5.0" dulwich = ">=0.21.2,<0.22.0" fastjsonschema = ">=2.18.0,<3.0.0" -importlib-metadata = {version = ">=4.4", markers = "python_version < \"3.10\""} installer = ">=0.7.0,<0.8.0" keyring = ">=24.0.0,<25.0.0" packaging = ">=23.1" @@ -3511,7 +3548,6 @@ certifi = "*" click = ">=4.0" click-plugins = "*" cligj = ">=0.5" -importlib-metadata = {version = "*", markers = "python_version < \"3.10\""} numpy = ">=1.24" pyparsing = "*" @@ -3943,6 +3979,28 @@ files = [ {file = "soupsieve-2.6.tar.gz", hash = "sha256:e2e68417777af359ec65daac1057404a3c8a5455bb8abc36f1a9866ab1a51abb"}, ] +[[package]] +name = "sparse" +version = "0.15.5" +description = "Sparse n-dimensional arrays for the PyData ecosystem" +optional = false +python-versions = ">=3.8" +files = [ + {file = "sparse-0.15.5-py2.py3-none-any.whl", hash = "sha256:cf608731f8564916443427bca323fe118e8c25a712ddf02dbe7673a961139706"}, + {file = "sparse-0.15.5.tar.gz", hash = "sha256:4c76ce0c96f5cd5c31b7e79e650f0022424c2b16f05f10049e9c6381ee4be266"}, +] + +[package.dependencies] +numba = ">=0.49" +numpy = ">=1.17" +scipy = ">=0.19" + +[package.extras] +all = ["matrepr", "sparse[docs,tox]"] +docs = ["sphinx", "sphinx_rtd_theme"] +tests = ["dask[array]", "pre-commit", "pytest (>=3.5)", "pytest-cov"] +tox = ["sparse[tests]", "tox"] + [[package]] name = "sphinx" version = "7.4.7" @@ -3960,7 +4018,6 @@ babel = ">=2.13" colorama = {version = ">=0.4.6", markers = "sys_platform == \"win32\""} docutils = ">=0.20,<0.22" imagesize = ">=1.3" -importlib-metadata = {version = ">=6.0", markers = "python_version < \"3.10\""} Jinja2 = ">=3.1" packaging = ">=23.0" 
Pygments = ">=2.17" @@ -4385,27 +4442,28 @@ files = [ [[package]] name = "xarray" -version = "2024.7.0" +version = "2025.1.2" description = "N-D labeled arrays and datasets in Python" optional = false -python-versions = ">=3.9" +python-versions = ">=3.10" files = [ - {file = "xarray-2024.7.0-py3-none-any.whl", hash = "sha256:1b0fd51ec408474aa1f4a355d75c00cc1c02bd425d97b2c2e551fd21810e7f64"}, - {file = "xarray-2024.7.0.tar.gz", hash = "sha256:4cae512d121a8522d41e66d942fb06c526bc1fd32c2c181d5fe62fe65b671638"}, + {file = "xarray-2025.1.2-py3-none-any.whl", hash = "sha256:a7ad6a36c6e0becd67f8aff6a7808d20e4bdcd344debb5205f0a34b1a4a7f8d6"}, + {file = "xarray-2025.1.2.tar.gz", hash = "sha256:e7675c79ac69d274dd3b3c5450ce57176928d2792947576251ed1c7df1783224"}, ] [package.dependencies] -numpy = ">=1.23" -packaging = ">=23.1" -pandas = ">=2.0" +numpy = ">=1.24" +packaging = ">=23.2" +pandas = ">=2.1" [package.extras] -accel = ["bottleneck", "flox", "numbagg", "opt-einsum", "scipy"] -complete = ["xarray[accel,dev,io,parallel,viz]"] -dev = ["hypothesis", "mypy", "pre-commit", "pytest", "pytest-cov", "pytest-env", "pytest-timeout", "pytest-xdist", "ruff", "xarray[complete]"] +accel = ["bottleneck", "flox", "numba (>=0.54)", "numbagg", "opt_einsum", "scipy"] +complete = ["xarray[accel,etc,io,parallel,viz]"] +dev = ["hypothesis", "jinja2", "mypy", "pre-commit", "pytest", "pytest-cov", "pytest-env", "pytest-timeout", "pytest-xdist", "ruff (>=0.8.0)", "sphinx", "sphinx_autosummary_accessors", "xarray[complete]"] +etc = ["sparse"] io = ["cftime", "fsspec", "h5netcdf", "netCDF4", "pooch", "pydap", "scipy", "zarr"] parallel = ["dask[complete]"] -viz = ["matplotlib", "nc-time-axis", "seaborn"] +viz = ["cartopy", "matplotlib", "nc-time-axis", "seaborn"] [[package]] name = "xarray-datatree" @@ -4591,5 +4649,5 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" -python-versions = ">=3.10,<4.0" -content-hash = "2fe86a5d9995efd2264db2eb4badc00cf319f4372bb22bda0f1494fbcc76528c" 
+python-versions = "^3.10" +content-hash = "dc31e08f220afd91858167191e311fbfcfd5ede73b2597d1b8bd1cd69c72270a" diff --git a/pyproject.toml b/pyproject.toml index bdd5c50b..8110a0a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ python = "^3.10" regionmask = "^0.12.1" scipy = "^1.13.0" shapely = "^2.0.3" -xarray = ">2024.2.0,<2024.10" +xarray = ">2024.2.0" xarray-datatree = "^0.0.14" seaborn = "^0.13.2" xesmf = "^0.8.8" diff --git a/src/valenspy/diagnostic/diagnostic.py b/src/valenspy/diagnostic/diagnostic.py index 490912b3..c8ed2330 100644 --- a/src/valenspy/diagnostic/diagnostic.py +++ b/src/valenspy/diagnostic/diagnostic.py @@ -1,29 +1,26 @@ -from datatree import DataTree +from xarray import DataTree import xarray as xr import matplotlib.pyplot as plt from valenspy.processing.mask import add_prudence_regions from valenspy.diagnostic.plot_utils import default_plot_kwargs, _augment_kwargs -#Import get_axis from xarray -from xarray.plot.utils import get_axis from abc import abstractmethod import warnings - class Diagnostic: """An abstract class representing a diagnostic.""" def __init__( - self, diagnostic_function, plotting_function, name=None, description=None + self, diagnostic_function, plotting_functions, name=None, description=None ): """Initialize the Diagnostic. Parameters ---------- - diagnostic_function + diagnostic_function : function The function that applies a diagnostic to the data. - plotting_function - The function that visualizes the results of the diagnostic. + plotting_functions : function or dict of functions + The functions or dictionary of functions that visualize the diagnostic. name : str The name of the diagnostic. 
description : str @@ -32,7 +29,11 @@ def __init__( self.name = name self._description = description self.diagnostic_function = diagnostic_function - self.plotting_function = plotting_function + + if callable(plotting_functions): + plotting_functions = {"default": plotting_functions} + + self.plotting_functions = plotting_functions def __call__(self, data, *args, **kwargs): return self.apply(data, *args, **kwargs) @@ -54,7 +55,7 @@ def apply(self, data): """ pass - def plot(self, result, title=None, **kwargs): + def plot(self, result, kind="default", title=None, **kwargs): """Plot the diagnostic. Single ax plots. Parameters @@ -67,22 +68,39 @@ def plot(self, result, title=None, **kwargs): ax : matplotlib.axis.Axis The axis (singular) of the plot. """ - ax = self.plotting_function(result, **kwargs) + ax = self.plotting_functions[kind](result, **kwargs) if not title: title = self.name ax.set_title(title) return ax + #Support easy access to the plotting functions + # class PlotAccessor: + # """An accessor to the plotting functions of the diagnostic.""" + + # def __init__(self, diagnostic): + # self.diagnostic = diagnostic + + # def __getattr__(self, kind): + # def plot_kind(*args, **kwargs): + # return self._diagnostic.plot(*args, kind=kind, **kwargs) + # return plot_kind + + # @property + # def plot(self): + # return self.PlotAccessor(self) + + @property def description(self): """Return the description of the diagnostic a combination of the name, the type and the description and the docstring of the diagnostic and plot functions.""" - return f"{self.name} ({self.__class__.__name__})\n{self._description}\n Diagnostic function: {self.diagnostic_function.__name__}\n {self.diagnostic_function.__doc__}\n Visualization function: {self.plotting_function.__name__}\n {self.plotting_function.__doc__}" + return f"{self.name} ({self.__class__.__name__})\n{self._description}\n Diagnostic function: {self.diagnostic_function.__name__}\n {self.diagnostic_function.__doc__}\n 
Visualization function: {self.plotting_functions.__name__}\n" class DataSetDiagnostic(Diagnostic): """A class representing a diagnostic that operates on the level of single datasets.""" def __init__( - self, diagnostic_function, plotting_function, name=None, description=None, plot_type="single" + self, diagnostic_function, plotting_functions, name=None, description=None, plot_type="single" ): """ Initialize the DataSetDiagnostic. @@ -94,7 +112,7 @@ def __init__( If "single", plot_dt will plot all the leaves of the DataTree on the same axis. If "facetted", plot_dt will plot all the leaves of the DataTree on different axes. """ - super().__init__(diagnostic_function, plotting_function, name, description) + super().__init__(diagnostic_function, plotting_functions, name, description) self.plot_type = plot_type def __call__(self, data, *args, **kwargs): @@ -118,6 +136,7 @@ def apply_dt(self, dt: DataTree, *args, **kwargs): """ return dt.map_over_subtree(self.apply, *args, **kwargs) + #Currently no support for different plotting kinds def plot_dt(self, dt, *args, **kwargs): if self.plot_type == "single": return self.plot_dt_single(dt, *args, **kwargs) @@ -211,10 +230,10 @@ class Model2Self(DataSetDiagnostic): """A class representing a diagnostic that compares a model to itself.""" def __init__( - self, diagnostic_function, plotting_function, name=None, description=None, plot_type="single" + self, diagnostic_function, plotting_functions, name=None, description=None, plot_type="single" ): """Initialize the Model2Self diagnostic.""" - super().__init__(diagnostic_function, plotting_function, name, description, plot_type) + super().__init__(diagnostic_function, plotting_functions, name, description, plot_type) def apply(self, ds: xr.Dataset, mask=None, **kwargs): """Apply the diagnostic to the data. 
@@ -256,10 +275,10 @@ class Model2Ref(DataSetDiagnostic): """A class representing a diagnostic that compares a model to a reference.""" def __init__( - self, diagnostic_function, plotting_function, name=None, description=None, plot_type="facetted" + self, diagnostic_function, plotting_functions, name=None, description=None, plot_type="facetted" ): """Initialize the Model2Ref diagnostic.""" - super().__init__(diagnostic_function, plotting_function, name, description, plot_type) + super().__init__(diagnostic_function, plotting_functions, name, description, plot_type) def apply(self, ds: xr.Dataset, ref: xr.Dataset, mask=None, **kwargs): """Apply the diagnostic to the data. Only the common variables between the data and the reference are used. @@ -310,11 +329,11 @@ class Ensemble2Self(Diagnostic): """A class representing a diagnostic that compares an ensemble to itself.""" def __init__( - self, diagnostic_function, plotting_function, name=None, description=None, iterative_plotting=False + self, diagnostic_function, plotting_functions, name=None, description=None, iterative_plotting=False ): """Initialize the Ensemble2Self diagnostic.""" self.iterative_plotting = iterative_plotting - super().__init__(diagnostic_function, plotting_function, name, description) + super().__init__(diagnostic_function, plotting_functions, name, description) def apply(self, dt: DataTree, mask=None, **kwargs): @@ -335,7 +354,7 @@ def apply(self, dt: DataTree, mask=None, **kwargs): return self.diagnostic_function(dt, **kwargs) - def plot(self, result, variables=None, title=None, facetted=None, **kwargs): + def plot(self, result, kind="default", variables=None, title=None, facetted=None, **kwargs): """Plot the diagnostic. If facetted multiple plots on different axes are created. If not facetted, the plots are created on the same axis. 
@@ -363,10 +382,10 @@ class Ensemble2Ref(Diagnostic): """A class representing a diagnostic that compares an ensemble to a reference.""" def __init__( - self, diagnostic_function, plotting_function, name=None, description=None + self, diagnostic_function, plotting_functions, name=None, description=None ): """Initialize the Ensemble2Ref diagnostic.""" - super().__init__(diagnostic_function, plotting_function, name, description) + super().__init__(diagnostic_function, plotting_functions, name, description) def apply(self, dt: DataTree, ref, **kwargs): """Apply the diagnostic to the data. @@ -383,10 +402,13 @@ def apply(self, dt: DataTree, ref, **kwargs): DataTree or dict The data after applying the diagnostic as a DataTree or a dictionary of results with the tree nodes as keys. """ - # TODO: Add some checks to make sure the reference is a DataTree or a Dataset and contain common variables with the data. + #Make sure that the dt and ref are isomorphic + if isinstance(ref, DataTree): + dt = filter_like(dt, ref) + ref = filter_like(ref, dt) return self.diagnostic_function(dt, ref, **kwargs) - def plot(self, result, facetted=True, **kwargs): + def plot(self, result, kind="default", facetted=True, **kwargs): """Plot the diagnostic. If axes are provided, the diagnostic is plotted facetted. If ax is provided, the diagnostic is plotted non-facetted. 
@@ -406,9 +428,9 @@ def plot(self, result, facetted=True, **kwargs): raise ValueError("Either ax or axes can be provided, not both.") elif "ax" not in kwargs and "axes" not in kwargs: ax = plt.gca() - return self.plotting_function(result, ax=ax, **kwargs) + return self.plotting_functions[kind](result, ax=ax, **kwargs) else: - return self.plotting_function(result, **kwargs) + return self.plotting_functions[kind](result, **kwargs) def _common_vars(ds1, ds2): """Return the common variables in two datasets.""" @@ -426,6 +448,11 @@ def _initialize_multiaxis_plot(n, subplot_kws={}): ) return fig, axes +def filter_like(dt, other): + """Filter the dt by the ref.""" + other = {key for key,_ in other.subtree_with_keys} + return dt.filter(lambda node: node.relative_to(dt) in other) + # ============================================================================= # Pre-made diagnostics # ============================================================================= @@ -502,12 +529,19 @@ def _initialize_multiaxis_plot(n, subplot_kws={}): # Ensemble2Ref diagnostics EnsembleSubSelection = Ensemble2Ref( case_sub_selection, - default_plot_kwargs({ - "x": "var", - "y": "abs_change", - "selected": ["highest", "middle", "lowest"], - "sel_colors": {"highest": "red", "middle": "blue", "lowest": "green"} - })(ensemble_selection_boxplot), + {"default": + default_plot_kwargs({ + "x": "var", + "y": "abs_change", + "selected": ["highest", "middle", "lowest"], + "sel_colors": {"highest": "red", "middle": "blue", "lowest": "green"} + })(ensemble_selection_boxplot), + "heatmap": + default_plot_kwargs({ + "index": "label", + "columns": "var", + "values": "rel_change" + })(ensemble_change_signal_heatmap)}, "Ensemble Sub Selection", "The sub selection of ensemble members." 
) \ No newline at end of file diff --git a/src/valenspy/diagnostic/functions.py b/src/valenspy/diagnostic/functions.py index 95638d01..ad0274cf 100644 --- a/src/valenspy/diagnostic/functions.py +++ b/src/valenspy/diagnostic/functions.py @@ -1,7 +1,7 @@ import xarray as xr import numpy as np from scipy.stats import spearmanr -from datatree import DataTree +from xarray import DataTree import pandas as pd from functools import partial @@ -270,25 +270,38 @@ def case_sub_selection(dt_future: DataTree, dt_ref: DataTree, vars): Select 3 ensemble members based on avg normalized climate change in the variable of interest. The two extreme and mediaan members are selected. """ + #TODO - add direction of indicators + #TODO - Check consistnecy with paper (slightly different approach)! + #TEST how to leave an extra dimension left over - e.g. regions + #Test with different periods of variables (tas_JJA, etc) + #Improve the heatmap plot (square default, absolute values of the change?, larger squares) dt_change = dt_future.mean() - dt_ref.mean() dt_rel_change = dt_change / dt_ref.mean() - data = [[x.path, var, x[var].values] for x in dt_rel_change.leaves for var in vars if var in x] - data_abs = [[x.path, var, x[var].values] for x in dt_change.leaves for var in vars if var in x] + data = [[x.path, var, x[var].values] for x in dt_rel_change.leaves for var in x.data_vars] + data_abs = [[x.path, var, x[var].values] for x in dt_change.leaves for var in x.data_vars] #Create one dataframe containing the relative change and the absolute change df = pd.DataFrame(data, columns=["label", "var", "rel_change"]) df_abs = pd.DataFrame(data_abs, columns=["label", "var", "abs_change"]) df = pd.merge(df, df_abs, on=["label", "var"]) + df["rel_change"] = df["rel_change"].astype(float) df["abs_change"] = df["abs_change"].astype(float) - df["mean"] = df["rel_change"].groupby(df["label"]).transform("mean") - df["rank"] = df["mean"].groupby(df["var"]).rank(ascending=True, method='min') - df["highest"] = 
df["rank"] == 1 - df["lowest"] = df["rank"] == df["rank"].max() - df["middle"] = df["rank"] == np.floor(df["rank"].median()) + #Normalize the absolute change per variable + df["norm_rel_change"] = df.groupby("var")["rel_change"].transform(lambda x: (x - x.mean()) / x.std()) + df["norm_abs_change"] = df.groupby("var")["abs_change"].transform(lambda x: (x - x.mean()) / x.std()) + + #Get the rank per variable + df["rank"] = df["norm_abs_change"].groupby(df["var"]).rank(ascending=True, method='min') + df["member_rank"] = df["rank"].where(df["var"].isin(vars)).groupby(df["label"]).transform("mean").rank(ascending=True, method='min') # + + #Rank the ensemble members based on the mean of the normalized absolute change + df["highest"] = df["member_rank"] == 1 + df["lowest"] = df["member_rank"] == df["member_rank"].max() + df["middle"] = df["member_rank"] == np.floor(df["member_rank"].median()) return df diff --git a/src/valenspy/diagnostic/visualizations.py b/src/valenspy/diagnostic/visualizations.py index c4463055..8196fb06 100644 --- a/src/valenspy/diagnostic/visualizations.py +++ b/src/valenspy/diagnostic/visualizations.py @@ -531,6 +531,28 @@ def ensemble_selection_boxplot(df, selected=None, sel_colors=None, ax=None, **kw ax = _get_gca(**kwargs) return ax +def ensemble_change_signal_heatmap(df, index=None, columns=None, values=None, ax=None, **kwargs): + """ + Create a heatmap of the ensemble change signal for different variables per ensemble member. + + Parameters + ---------- + df : pd.DataFrame + A DataFrame containing the ensemble change signals for different variables per ensemble member. + index : str, optional + The column name to use as the index for the heatmap. + columns : str, optional + The column name to use as the columns for the heatmap. + values : str, optional + The column name to use as the climate signal values for the heatmap. + ax : matplotlib.axes.Axes, optional + The axes on which to plot the heatmap. If None, a new figure and axes are created. 
+ **kwargs : dict + Additional keyword arguments passed to `seaborn.heatmap`. + """ + sns.heatmap(data=df.pivot(index=index, columns=columns, values=values), ax=ax, **kwargs) + ax = _get_gca(**kwargs) + return ax ################################## # Helper functions # diff --git a/src/valenspy/input/manager.py b/src/valenspy/input/manager.py index 2b9f999c..bd86d95d 100644 --- a/src/valenspy/input/manager.py +++ b/src/valenspy/input/manager.py @@ -1,6 +1,6 @@ from pathlib import Path import xarray as xr -from datatree import DataTree +from xarray import DataTree import re import glob From 8a05ef3e54399d463253e87c0477395b21e3c5af Mon Sep 17 00:00:00 2001 From: kobebryant432 Date: Thu, 12 Feb 2026 14:05:37 +0100 Subject: [PATCH 5/5] move diagnostic to dedicated folder --- src/valenspy/diagnostic/_ensemble2ref.py | 26 ++++++++++++++++++++++-- src/valenspy/diagnostic/diagnostic.py | 19 ----------------- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/valenspy/diagnostic/_ensemble2ref.py b/src/valenspy/diagnostic/_ensemble2ref.py index 5cf51100..e1c74431 100644 --- a/src/valenspy/diagnostic/_ensemble2ref.py +++ b/src/valenspy/diagnostic/_ensemble2ref.py @@ -2,11 +2,33 @@ from valenspy.diagnostic.functions import * from valenspy.diagnostic.visualizations import * -__all__ = ["MetricsRankings"] +__all__ = [ + "MetricsRankings", + "EnsembleSubSelection", + ] MetricsRankings = Ensemble2Ref( calc_metrics_dt, plot_metric_ranking, "Metrics Rankings", "The rankings of ensemble members with respect to several metrics when compared to the reference." 
-) \ No newline at end of file +) + +EnsembleSubSelection = Ensemble2Ref( + case_sub_selection, + {"default": + default_plot_kwargs({ + "x": "var", + "y": "abs_change", + "selected": ["highest", "middle", "lowest"], + "sel_colors": {"highest": "red", "middle": "blue", "lowest": "green"} + })(ensemble_selection_boxplot), + "heatmap": + default_plot_kwargs({ + "index": "label", + "columns": "var", + "values": "rel_change" + })(ensemble_change_signal_heatmap)}, + "Ensemble Sub Selection", + "The sub selection of ensemble members." + ) diff --git a/src/valenspy/diagnostic/diagnostic.py b/src/valenspy/diagnostic/diagnostic.py index f1ece74e..473a493e 100644 --- a/src/valenspy/diagnostic/diagnostic.py +++ b/src/valenspy/diagnostic/diagnostic.py @@ -437,22 +437,3 @@ def _initialize_multiaxis_plot(n, subplot_kws={}): ) return fig, axes -# Ensemble2Ref diagnostics -EnsembleSubSelection = Ensemble2Ref( - case_sub_selection, - {"default": - default_plot_kwargs({ - "x": "var", - "y": "abs_change", - "selected": ["highest", "middle", "lowest"], - "sel_colors": {"highest": "red", "middle": "blue", "lowest": "green"} - })(ensemble_selection_boxplot), - "heatmap": - default_plot_kwargs({ - "index": "label", - "columns": "var", - "values": "rel_change" - })(ensemble_change_signal_heatmap)}, - "Ensemble Sub Selection", - "The sub selection of ensemble members." -)