diff --git a/src/valenspy/diagnostic/__init__.py b/src/valenspy/diagnostic/__init__.py index 0d9fac73..c67c0a93 100644 --- a/src/valenspy/diagnostic/__init__.py +++ b/src/valenspy/diagnostic/__init__.py @@ -7,4 +7,4 @@ from ._model2self import * from ._model2ref import * from ._ensemble2self import * -from ._ensemble2ref import * \ No newline at end of file +from ._ensemble2ref import * diff --git a/src/valenspy/diagnostic/_ensemble2ref.py b/src/valenspy/diagnostic/_ensemble2ref.py index 5cf51100..e1c74431 100644 --- a/src/valenspy/diagnostic/_ensemble2ref.py +++ b/src/valenspy/diagnostic/_ensemble2ref.py @@ -2,11 +2,33 @@ from valenspy.diagnostic.functions import * from valenspy.diagnostic.visualizations import * -__all__ = ["MetricsRankings"] +__all__ = [ + "MetricsRankings", + "EnsembleSubSelection", + ] MetricsRankings = Ensemble2Ref( calc_metrics_dt, plot_metric_ranking, "Metrics Rankings", "The rankings of ensemble members with respect to several metrics when compared to the reference." -) \ No newline at end of file +) + +EnsembleSubSelection = Ensemble2Ref( + case_sub_selection, + {"default": + default_plot_kwargs({ + "x": "var", + "y": "abs_change", + "selected": ["highest", "middle", "lowest"], + "sel_colors": {"highest": "red", "middle": "blue", "lowest": "green"} + })(ensemble_selection_boxplot), + "heatmap": + default_plot_kwargs({ + "index": "label", + "columns": "var", + "values": "rel_change" + })(ensemble_change_signal_heatmap)}, + "Ensemble Sub Selection", + "The sub selection of ensemble members." + ) diff --git a/src/valenspy/diagnostic/diagnostic.py b/src/valenspy/diagnostic/diagnostic.py index 926915dd..473a493e 100644 --- a/src/valenspy/diagnostic/diagnostic.py +++ b/src/valenspy/diagnostic/diagnostic.py @@ -2,7 +2,7 @@ import xarray as xr import matplotlib.pyplot as plt from valenspy.processing.mask import add_prudence_regions -from valenspy.diagnostic.plot_utils import _augment_kwargs +from valenspy.diagnostic.plot_utils import default_plot_kwargs, _augment_kwargs from valenspy._utilities import generate_parameters_doc import numpy as np import inspect @@ -18,16 +18,16 @@ class Diagnostic(): """An abstract class representing a diagnostic.""" def __init__( - self, diagnostic_function, plotting_function, name=None, description=None + self, diagnostic_function, plotting_functions, name=None, description=None ): """Initialize the Diagnostic. Parameters ---------- - diagnostic_function + diagnostic_function : function The function that applies a diagnostic to the data. - plotting_function - The function that visualizes the results of the diagnostic. + plotting_functions : function or dict of functions + The functions or dictionary of functions that visualize the diagnostic. name : str The name of the diagnostic. description : str @@ -36,7 +36,11 @@ def __init__( self.name = name self._description = description self.diagnostic_function = diagnostic_function - self.plotting_function = plotting_function + + if callable(plotting_functions): + plotting_functions = {"default": plotting_functions} + + self.plotting_functions = plotting_functions self.__signature__ = inspect.signature(self.diagnostic_function) self.__doc__ = self.description @@ -61,7 +65,7 @@ def apply(self, data): """ pass - def plot(self, result, title=None, **kwargs): + def plot(self, result, kind="default", title=None, **kwargs): """Plot the diagnostic. Single ax plots. Parameters @@ -78,12 +82,29 @@ def plot(self, result, title=None, **kwargs): ax : matplotlib.axis.Axis The axis (singular) of the plot. """ - ax = self.plotting_function(result, **kwargs) + ax = self.plotting_functions[kind](result, **kwargs) if not title: title = self.name ax.set_title(title) return ax + #Support easy access to the plotting functions + # class PlotAccessor: + # """An accessor to the plotting functions of the diagnostic.""" + + # def __init__(self, diagnostic): + # self.diagnostic = diagnostic + + # def __getattr__(self, kind): + # def plot_kind(*args, **kwargs): + # return self._diagnostic.plot(*args, kind=kind, **kwargs) + # return plot_kind + + # @property + # def plot(self): + # return self.PlotAccessor(self) + + @property def description(self): """Generate the docstring for the diagnostic.""" @@ -100,7 +121,7 @@ class DataSetDiagnostic(Diagnostic): """A class representing a diagnostic that operates on the level of single datasets.""" def __init__( - self, diagnostic_function, plotting_function, name=None, description=None, plot_type="single" + self, diagnostic_function, plotting_functions, name=None, description=None, plot_type="single" ): """ Initialize the DataSetDiagnostic. @@ -112,7 +133,7 @@ def __init__( If "single", plot_dt will plot all the leaves of the DataTree on the same axis. If "facetted", plot_dt will plot all the leaves of the DataTree on different axes. """ - super().__init__(diagnostic_function, plotting_function, name, description) + super().__init__(diagnostic_function, plotting_functions, name, description) self.plot_type = plot_type def __call__(self, data, *args, **kwargs): @@ -164,6 +185,7 @@ def apply(self, ds: xr.Dataset, *args, **kwargs): """ return self.diagnostic_function(ds, *args, **kwargs) + #Currently no support for different plotting kinds def plot_dt(self, dt, *args, **kwargs): if self.plot_type == "single": return self.plot_dt_single(dt, *args, **kwargs) @@ -257,20 +279,20 @@ class Model2Self(DataSetDiagnostic): """A class representing a diagnostic that compares a model to itself.""" def __init__( - self, diagnostic_function, plotting_function, name=None, description=None, plot_type="single" + self, diagnostic_function, plotting_functions, name=None, description=None, plot_type="single" ): """Initialize the Model2Self diagnostic.""" - super().__init__(diagnostic_function, plotting_function, name, description, plot_type) + super().__init__(diagnostic_function, plotting_functions, name, description, plot_type) class Model2Ref(DataSetDiagnostic): """A class representing a diagnostic that compares a model to a reference.""" def __init__( - self, diagnostic_function, plotting_function, name=None, description=None, plot_type="facetted" + self, diagnostic_function, plotting_functions, name=None, description=None, plot_type="facetted" ): """Initialize the Model2Ref diagnostic.""" - super().__init__(diagnostic_function, plotting_function, name, description, plot_type) + super().__init__(diagnostic_function, plotting_functions, name, description, plot_type) def apply(self, ds: xr.Dataset, ref: xr.Dataset, **kwargs): """Apply the diagnostic to the data. Only the common variables between the data and the reference are used. @@ -296,11 +318,11 @@ class Ensemble2Self(Diagnostic): """A class representing a diagnostic that compares an ensemble to itself.""" def __init__( - self, diagnostic_function, plotting_function, name=None, description=None, iterative_plotting=False + self, diagnostic_function, plotting_functions, name=None, description=None, iterative_plotting=False ): """Initialize the Ensemble2Self diagnostic.""" self.iterative_plotting = iterative_plotting - super().__init__(diagnostic_function, plotting_function, name, description) + super().__init__(diagnostic_function, plotting_functions, name, description) def apply(self, dt: DataTree, mask=None, **kwargs): @@ -321,7 +343,7 @@ def apply(self, dt: DataTree, mask=None, **kwargs): return self.diagnostic_function(dt, **kwargs) - def plot(self, result, variables=None, title=None, facetted=None, **kwargs): + def plot(self, result, kind="default", variables=None, title=None, facetted=None, **kwargs): """Plot the diagnostic. If facetted multiple plots on different axes are created. If not facetted, the plots are created on the same axis. @@ -349,10 +371,10 @@ class Ensemble2Ref(Diagnostic): """A class representing a diagnostic that compares an ensemble to a reference.""" def __init__( - self, diagnostic_function, plotting_function, name=None, description=None + self, diagnostic_function, plotting_functions, name=None, description=None ): """Initialize the Ensemble2Ref diagnostic.""" - super().__init__(diagnostic_function, plotting_function, name, description) + super().__init__(diagnostic_function, plotting_functions, name, description) def apply(self, dt: DataTree, ref, **kwargs): """Apply the diagnostic to the data. @@ -369,10 +391,13 @@ def apply(self, dt: DataTree, ref, **kwargs): DataTree or dict The data after applying the diagnostic as a DataTree or a dictionary of results with the tree nodes as keys. """ - # TODO: Add some checks to make sure the reference is a DataTree or a Dataset and contain common variables with the data. + #Make sure that the dt and ref are isomorphic + if isinstance(ref, DataTree): + dt = filter_like(dt, ref) + ref = filter_like(ref, dt) return self.diagnostic_function(dt, ref, **kwargs) - def plot(self, result, facetted=True, **kwargs): + def plot(self, result, kind="default", facetted=True, **kwargs): """Plot the diagnostic. If axes are provided, the diagnostic is plotted facetted. If ax is provided, the diagnostic is plotted non-facetted. @@ -392,9 +417,9 @@ def plot(self, result, facetted=True, **kwargs): raise ValueError("Either ax or axes can be provided, not both.") elif "ax" not in kwargs and "axes" not in kwargs: ax = plt.gca() - return self.plotting_function(result, ax=ax, **kwargs) + return self.plotting_functions[kind](result, ax=ax, **kwargs) else: - return self.plotting_function(result, **kwargs) + return self.plotting_functions[kind](result, **kwargs) def _common_vars(ds1, ds2): """Return the common variables in two datasets.""" @@ -411,3 +436,4 @@ def _initialize_multiaxis_plot(n, subplot_kws={}): nrows=n//2+1, ncols=2, figsize=(10, 5 * n), subplot_kw=subplot_kws ) return fig, axes + diff --git a/src/valenspy/diagnostic/functions.py b/src/valenspy/diagnostic/functions.py index 37855f90..62c1e43e 100644 --- a/src/valenspy/diagnostic/functions.py +++ b/src/valenspy/diagnostic/functions.py @@ -314,6 +314,50 @@ def calc_metrics_dt(dt_mod: DataTree, da_obs: xr.Dataset, metrics=None, pss_binw df = _add_ranks_metrics(df) return df +########################################## +# Ensemble2Ensemble diagnostic functions # +########################################## + +def case_sub_selection(dt_future: DataTree, dt_ref: DataTree, vars): + """ + Select 3 ensemble members based on avg normalized climate change in the variable of interest. + The two extreme and mediaan members are selected. + """ + #TODO - add direction of indicators + #TODO - Check consistnecy with paper (slightly different approach)! + #TEST how to leave an extra dimension left over - e.g. regions + #Test with different periods of variables (tas_JJA, etc) + #Improve the heatmap plot (square default, absolute values of the change?, larger squares) + dt_change = dt_future.mean() - dt_ref.mean() + dt_rel_change = dt_change / dt_ref.mean() + + data = [[x.path, var, x[var].values] for x in dt_rel_change.leaves for var in x.data_vars] + data_abs = [[x.path, var, x[var].values] for x in dt_change.leaves for var in x.data_vars] + + #Create one dataframe containing the relative change and the absolute change + df = pd.DataFrame(data, columns=["label", "var", "rel_change"]) + df_abs = pd.DataFrame(data_abs, columns=["label", "var", "abs_change"]) + + df = pd.merge(df, df_abs, on=["label", "var"]) + + df["rel_change"] = df["rel_change"].astype(float) + df["abs_change"] = df["abs_change"].astype(float) + + #Normalize the absolute change per variable + df["norm_rel_change"] = df.groupby("var")["rel_change"].transform(lambda x: (x - x.mean()) / x.std()) + df["norm_abs_change"] = df.groupby("var")["abs_change"].transform(lambda x: (x - x.mean()) / x.std()) + + #Get the rank per variable + df["rank"] = df["norm_abs_change"].groupby(df["var"]).rank(ascending=True, method='min') + df["member_rank"] = df["rank"].where(df["var"].isin(vars)).groupby(df["label"]).transform("mean").rank(ascending=True, method='min') # + + #Rank the ensemble members based on the mean of the normalized absolute change + df["highest"] = df["member_rank"] == 1 + df["lowest"] = df["member_rank"] == df["member_rank"].max() + df["middle"] = df["member_rank"] == np.floor(df["member_rank"].median()) + + return df + ################################## ####### Helper functions ######### diff --git a/src/valenspy/diagnostic/plot_utils.py b/src/valenspy/diagnostic/plot_utils.py index 931200e3..f24fe1f9 100644 --- a/src/valenspy/diagnostic/plot_utils.py +++ b/src/valenspy/diagnostic/plot_utils.py @@ -29,13 +29,14 @@ def _augment_kwargs(def_kwargs, **kwargs): cbar_kwargs = _merge_kwargs(def_kwargs.pop('cbar_kwargs'), kwargs.pop('cbar_kwargs', {})) def_kwargs['cbar_kwargs'] = cbar_kwargs + #Is this correct? Are the cbar_kwargs not overwritten if defined by the user? return _merge_kwargs(def_kwargs, kwargs) ###################################### ############## Wrappers ############## ###################################### -def default_plot_kwargs(kwargs): +def default_plot_kwargs(def_kwargs): """ Decorator to set the default keyword arguments for the plotting function. User will override and/or be augmented with the default keyword arguments. subplot_kws and cbar_kwargs can also be set as default keyword arguments for the plotting function. @@ -63,7 +64,7 @@ def decorator(plotting_function): @wraps(plotting_function) def wrapper(*args, **kwargs): - return plotting_function(*args, **_augment_kwargs(def_kwargs=kwargs, **kwargs)) + return plotting_function(*args, **_augment_kwargs(def_kwargs=def_kwargs, **kwargs)) return wrapper diff --git a/src/valenspy/diagnostic/visualizations.py b/src/valenspy/diagnostic/visualizations.py index cd702907..448722e2 100644 --- a/src/valenspy/diagnostic/visualizations.py +++ b/src/valenspy/diagnostic/visualizations.py @@ -541,6 +541,64 @@ def plot_metric_ranking(df_metric, ax=None, plot_colorbar=True, hex_color1 = Non return ax +########################################## +# Ensemble2Ensemble diagnostic functions # +########################################## + +#Add default plot kwargs at the diagnostic level +def ensemble_selection_boxplot(df, selected=None, sel_colors=None, ax=None, **kwargs): + """ + Create a boxplot of the ensemble members with the selected ensemble members highlighted. + + Parameters + ---------- + df : pd.DataFrame + A DataFrame containing the ensemble members and their values. + selected : str or list of str, optional + The column name(s) with boolean values indicating the selected ensemble members. + If None all members are plotted. + ax : matplotlib.axes.Axes, optional + The axes on which to plot the boxplot. If None, a new figure and axes are created. + **kwargs : dict + Additional keyword arguments passed to `seaborn.boxplot`. + """ + sns.boxplot(data=df, ax=ax, **kwargs) + + if selected is None: + sns.swarmplot(data=df, ax=ax, **kwargs) + else: + for sel in selected: + if sel in df.columns: + if isinstance(sel_colors, dict) and sel in sel_colors: + kwargs.update({"color": sel_colors[sel]}) + sns.swarmplot(data=df[df[sel]], ax=ax, **kwargs) + + ax = _get_gca(**kwargs) + return ax + +def ensemble_change_signal_heatmap(df, index=None, columns=None, values=None, ax=None, **kwargs): + """ + Create a heatmap of the ensemble change signal for different variables per ensemble member. + + Parameters + ---------- + df : pd.DataFrame + A DataFrame containing the ensemble change signals for different variables per ensemble member. + index : str, optional + The column name to use as the index for the heatmap. + columns : str, optional + The column name to use as the columns for the heatmap. + values : str, optional + The column name to use as the climate signal values for the heatmap. + ax : matplotlib.axes.Axes, optional + The axes on which to plot the heatmap. If None, a new figure and axes are created. + **kwargs : dict + Additional keyword arguments passed to `seaborn.heatmap`. + """ + sns.heatmap(data=df.pivot(index=index, columns=columns, values=values), ax=ax, **kwargs) + ax = _get_gca(**kwargs) + return ax + ################################## # Helper functions # ##################################