Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/valenspy/diagnostic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
from ._model2self import *
from ._model2ref import *
from ._ensemble2self import *
from ._ensemble2ref import *
from ._ensemble2ref import *
26 changes: 24 additions & 2 deletions src/valenspy/diagnostic/_ensemble2ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,33 @@
from valenspy.diagnostic.functions import *
from valenspy.diagnostic.visualizations import *

__all__ = ["MetricsRankings"]
__all__ = [
"MetricsRankings",
"EnsembleSubSelection",
]

MetricsRankings = Ensemble2Ref(
calc_metrics_dt,
plot_metric_ranking,
"Metrics Rankings",
"The rankings of ensemble members with respect to several metrics when compared to the reference."
)
)

# Ensemble-to-reference diagnostic built on case_sub_selection.
# Two plot kinds are registered in the plotting-function dictionary; the key
# is the name used to look the plotting function up (the "default" entry and
# a "heatmap" entry).
EnsembleSubSelection = Ensemble2Ref(
    case_sub_selection,
    {"default":
        # Boxplot of the absolute change per variable; members flagged in the
        # "highest"/"middle"/"lowest" columns are overlaid in their own colour.
        default_plot_kwargs({
            "x": "var",
            "y": "abs_change",
            "selected": ["highest", "middle", "lowest"],
            "sel_colors": {"highest": "red", "middle": "blue", "lowest": "green"}
        })(ensemble_selection_boxplot),
     "heatmap":
        # Heatmap of the relative change, one row per member label and one
        # column per variable.
        default_plot_kwargs({
            "index": "label",
            "columns": "var",
            "values": "rel_change"
        })(ensemble_change_signal_heatmap)},
    "Ensemble Sub Selection",
    "The sub selection of ensemble members."
)
72 changes: 49 additions & 23 deletions src/valenspy/diagnostic/diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import xarray as xr
import matplotlib.pyplot as plt
from valenspy.processing.mask import add_prudence_regions
from valenspy.diagnostic.plot_utils import _augment_kwargs
from valenspy.diagnostic.plot_utils import default_plot_kwargs, _augment_kwargs
from valenspy._utilities import generate_parameters_doc
import numpy as np
import inspect
Expand All @@ -18,16 +18,16 @@ class Diagnostic():
"""An abstract class representing a diagnostic."""

def __init__(
self, diagnostic_function, plotting_function, name=None, description=None
self, diagnostic_function, plotting_functions, name=None, description=None
):
"""Initialize the Diagnostic.

Parameters
----------
diagnostic_function
diagnostic_function : function
The function that applies a diagnostic to the data.
plotting_function
The function that visualizes the results of the diagnostic.
plotting_functions : function or dict of functions
The functions or dictionary of functions that visualize the diagnostic.
name : str
The name of the diagnostic.
description : str
Expand All @@ -36,7 +36,11 @@ def __init__(
self.name = name
self._description = description
self.diagnostic_function = diagnostic_function
self.plotting_function = plotting_function

if callable(plotting_functions):
plotting_functions = {"default": plotting_functions}

self.plotting_functions = plotting_functions

self.__signature__ = inspect.signature(self.diagnostic_function)
self.__doc__ = self.description
Expand All @@ -61,7 +65,7 @@ def apply(self, data):
"""
pass

def plot(self, result, title=None, **kwargs):
def plot(self, result, kind="default", title=None, **kwargs):
"""Plot the diagnostic. Single ax plots.

Parameters
Expand All @@ -78,12 +82,29 @@ def plot(self, result, title=None, **kwargs):
ax : matplotlib.axis.Axis
The axis (singular) of the plot.
"""
ax = self.plotting_function(result, **kwargs)
ax = self.plotting_functions[kind](result, **kwargs)
if not title:
title = self.name
ax.set_title(title)
return ax

#Support easy access to the plotting functions
# class PlotAccessor:
# """An accessor to the plotting functions of the diagnostic."""

# def __init__(self, diagnostic):
# self.diagnostic = diagnostic

# def __getattr__(self, kind):
# def plot_kind(*args, **kwargs):
# return self.diagnostic.plot(*args, kind=kind, **kwargs)
# return plot_kind

# @property
# def plot(self):
# return self.PlotAccessor(self)


@property
def description(self):
"""Generate the docstring for the diagnostic."""
Expand All @@ -100,7 +121,7 @@ class DataSetDiagnostic(Diagnostic):
"""A class representing a diagnostic that operates on the level of single datasets."""

def __init__(
self, diagnostic_function, plotting_function, name=None, description=None, plot_type="single"
self, diagnostic_function, plotting_functions, name=None, description=None, plot_type="single"
):
"""
Initialize the DataSetDiagnostic.
Expand All @@ -112,7 +133,7 @@ def __init__(
If "single", plot_dt will plot all the leaves of the DataTree on the same axis.
If "facetted", plot_dt will plot all the leaves of the DataTree on different axes.
"""
super().__init__(diagnostic_function, plotting_function, name, description)
super().__init__(diagnostic_function, plotting_functions, name, description)
self.plot_type = plot_type

def __call__(self, data, *args, **kwargs):
Expand Down Expand Up @@ -164,6 +185,7 @@ def apply(self, ds: xr.Dataset, *args, **kwargs):
"""
return self.diagnostic_function(ds, *args, **kwargs)

#Currently no support for different plotting kinds
def plot_dt(self, dt, *args, **kwargs):
if self.plot_type == "single":
return self.plot_dt_single(dt, *args, **kwargs)
Expand Down Expand Up @@ -257,20 +279,20 @@ class Model2Self(DataSetDiagnostic):
"""A class representing a diagnostic that compares a model to itself."""

def __init__(
self, diagnostic_function, plotting_function, name=None, description=None, plot_type="single"
self, diagnostic_function, plotting_functions, name=None, description=None, plot_type="single"
):
"""Initialize the Model2Self diagnostic."""
super().__init__(diagnostic_function, plotting_function, name, description, plot_type)
super().__init__(diagnostic_function, plotting_functions, name, description, plot_type)


class Model2Ref(DataSetDiagnostic):
"""A class representing a diagnostic that compares a model to a reference."""

def __init__(
self, diagnostic_function, plotting_function, name=None, description=None, plot_type="facetted"
self, diagnostic_function, plotting_functions, name=None, description=None, plot_type="facetted"
):
"""Initialize the Model2Ref diagnostic."""
super().__init__(diagnostic_function, plotting_function, name, description, plot_type)
super().__init__(diagnostic_function, plotting_functions, name, description, plot_type)

def apply(self, ds: xr.Dataset, ref: xr.Dataset, **kwargs):
"""Apply the diagnostic to the data. Only the common variables between the data and the reference are used.
Expand All @@ -296,11 +318,11 @@ class Ensemble2Self(Diagnostic):
"""A class representing a diagnostic that compares an ensemble to itself."""

def __init__(
self, diagnostic_function, plotting_function, name=None, description=None, iterative_plotting=False
self, diagnostic_function, plotting_functions, name=None, description=None, iterative_plotting=False
):
"""Initialize the Ensemble2Self diagnostic."""
self.iterative_plotting = iterative_plotting
super().__init__(diagnostic_function, plotting_function, name, description)
super().__init__(diagnostic_function, plotting_functions, name, description)


def apply(self, dt: DataTree, mask=None, **kwargs):
Expand All @@ -321,7 +343,7 @@ def apply(self, dt: DataTree, mask=None, **kwargs):

return self.diagnostic_function(dt, **kwargs)

def plot(self, result, variables=None, title=None, facetted=None, **kwargs):
def plot(self, result, kind="default", variables=None, title=None, facetted=None, **kwargs):
"""Plot the diagnostic.

If facetted multiple plots on different axes are created. If not facetted, the plots are created on the same axis.
Expand Down Expand Up @@ -349,10 +371,10 @@ class Ensemble2Ref(Diagnostic):
"""A class representing a diagnostic that compares an ensemble to a reference."""

def __init__(
self, diagnostic_function, plotting_function, name=None, description=None
self, diagnostic_function, plotting_functions, name=None, description=None
):
"""Initialize the Ensemble2Ref diagnostic."""
super().__init__(diagnostic_function, plotting_function, name, description)
super().__init__(diagnostic_function, plotting_functions, name, description)

def apply(self, dt: DataTree, ref, **kwargs):
"""Apply the diagnostic to the data.
Expand All @@ -369,10 +391,13 @@ def apply(self, dt: DataTree, ref, **kwargs):
DataTree or dict
The data after applying the diagnostic as a DataTree or a dictionary of results with the tree nodes as keys.
"""
# TODO: Add some checks to make sure the reference is a DataTree or a Dataset and contain common variables with the data.
#Make sure that the dt and ref are isomorphic
if isinstance(ref, DataTree):
dt = filter_like(dt, ref)
ref = filter_like(ref, dt)
return self.diagnostic_function(dt, ref, **kwargs)

def plot(self, result, facetted=True, **kwargs):
def plot(self, result, kind="default", facetted=True, **kwargs):
"""Plot the diagnostic.

If axes are provided, the diagnostic is plotted facetted. If ax is provided, the diagnostic is plotted non-facetted.
Expand All @@ -392,9 +417,9 @@ def plot(self, result, facetted=True, **kwargs):
raise ValueError("Either ax or axes can be provided, not both.")
elif "ax" not in kwargs and "axes" not in kwargs:
ax = plt.gca()
return self.plotting_function(result, ax=ax, **kwargs)
return self.plotting_functions[kind](result, ax=ax, **kwargs)
else:
return self.plotting_function(result, **kwargs)
return self.plotting_functions[kind](result, **kwargs)

def _common_vars(ds1, ds2):
"""Return the common variables in two datasets."""
Expand All @@ -411,3 +436,4 @@ def _initialize_multiaxis_plot(n, subplot_kws={}):
nrows=n//2+1, ncols=2, figsize=(10, 5 * n), subplot_kw=subplot_kws
)
return fig, axes

44 changes: 44 additions & 0 deletions src/valenspy/diagnostic/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,50 @@ def calc_metrics_dt(dt_mod: DataTree, da_obs: xr.Dataset, metrics=None, pss_binw
df = _add_ranks_metrics(df)
return df

##########################################
# Ensemble2Ensemble diagnostic functions #
##########################################

def case_sub_selection(dt_future: DataTree, dt_ref: DataTree, vars):
    """
    Select three ensemble members based on their average normalized climate change signal.

    The change signal of every leaf/variable is computed as the difference between
    the mean of ``dt_future`` and the mean of ``dt_ref`` (absolute change) and as that
    difference divided by the mean of ``dt_ref`` (relative change). Members are ranked
    per variable on the normalized absolute change; the per-member mean of those ranks
    over the variables in ``vars`` gives the member rank from which the two extreme
    members and the median member are flagged.

    Parameters
    ----------
    dt_future : DataTree
        The ensemble for the future period, one leaf per ensemble member.
    dt_ref : DataTree
        The ensemble for the reference period, with the same structure as ``dt_future``.
    vars : str or list of str
        The variable(s) of interest on which the member ranking is based.

    Returns
    -------
    pd.DataFrame
        Long-format DataFrame with one row per (member, variable) and the columns
        ``label``, ``var``, ``rel_change``, ``abs_change``, ``norm_rel_change``,
        ``norm_abs_change``, ``rank``, ``member_rank``, ``highest``, ``lowest`` and
        ``middle``.
    """
    # TODO - add direction of indicators
    # TODO - check consistency with the paper (slightly different approach)!
    # TODO - test how to keep an extra dimension - e.g. regions
    # TODO - test with different periods of variables (tas_JJA, etc.)
    # TODO - improve the heatmap plot (square default, absolute values of the change?, larger squares)

    # A bare string is not list-like and would be rejected by Series.isin below.
    if isinstance(vars, str):
        vars = [vars]

    def _to_long_frame(dt, value_column):
        """Flatten the leaves of a tree into (label, var, value) rows."""
        rows = [
            [leaf.path, var, leaf[var].values]
            for leaf in dt.leaves
            for var in leaf.data_vars
        ]
        return pd.DataFrame(rows, columns=["label", "var", value_column])

    ref_mean = dt_ref.mean()
    dt_change = dt_future.mean() - ref_mean
    dt_rel_change = dt_change / ref_mean

    # One dataframe containing both the relative and the absolute change
    df = pd.merge(
        _to_long_frame(dt_rel_change, "rel_change"),
        _to_long_frame(dt_change, "abs_change"),
        on=["label", "var"],
    )
    df["rel_change"] = df["rel_change"].astype(float)
    df["abs_change"] = df["abs_change"].astype(float)

    # Normalize (z-score) the changes per variable
    for column in ("rel_change", "abs_change"):
        df[f"norm_{column}"] = df.groupby("var")[column].transform(
            lambda x: (x - x.mean()) / x.std()
        )

    # Rank of every member per variable (1 = smallest normalized absolute change)
    df["rank"] = df["norm_abs_change"].groupby(df["var"]).rank(ascending=True, method="min")

    # Mean rank per member over the variables of interest, ranked again.
    # The value is identical for all rows of one member.
    mean_rank = (
        df["rank"]
        .where(df["var"].isin(vars))
        .groupby(df["label"])
        .transform("mean")
    )
    df["member_rank"] = mean_rank.rank(ascending=True, method="min")

    # NOTE(review): with ascending ranks, member_rank == 1 is the member with the
    # smallest mean rank, yet it is flagged "highest" - confirm the intended direction.
    df["highest"] = df["member_rank"] == 1
    df["lowest"] = df["member_rank"] == df["member_rank"].max()

    # Pick the median member from one rank per member instead of floor(median) over
    # all rows: with method="min" the row-level ranks are spaced by the number of rows
    # per member, so for an even number of members floor(median) matched no member.
    # The lower of the two central members is used when the ensemble size is even.
    member_ranks = df.drop_duplicates(subset="label")["member_rank"].dropna().sort_values()
    if member_ranks.empty:
        df["middle"] = False
    else:
        df["middle"] = df["member_rank"] == member_ranks.iloc[(len(member_ranks) - 1) // 2]

    return df


##################################
####### Helper functions #########
Expand Down
5 changes: 3 additions & 2 deletions src/valenspy/diagnostic/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@ def _augment_kwargs(def_kwargs, **kwargs):
cbar_kwargs = _merge_kwargs(def_kwargs.pop('cbar_kwargs'), kwargs.pop('cbar_kwargs', {}))
def_kwargs['cbar_kwargs'] = cbar_kwargs

#Is this correct? Are the cbar_kwargs not overwritten if defined by the user?
return _merge_kwargs(def_kwargs, kwargs)

######################################
############## Wrappers ##############
######################################

def default_plot_kwargs(kwargs):
def default_plot_kwargs(def_kwargs):
"""
Decorator to set the default keyword arguments for the plotting function. Keyword arguments passed by the user override and/or augment these defaults.
subplot_kws and cbar_kwargs can also be set as default keyword arguments for the plotting function.
Expand Down Expand Up @@ -63,7 +64,7 @@ def decorator(plotting_function):

@wraps(plotting_function)
def wrapper(*args, **kwargs):
return plotting_function(*args, **_augment_kwargs(def_kwargs=kwargs, **kwargs))
return plotting_function(*args, **_augment_kwargs(def_kwargs=def_kwargs, **kwargs))

return wrapper

Expand Down
58 changes: 58 additions & 0 deletions src/valenspy/diagnostic/visualizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,64 @@ def plot_metric_ranking(df_metric, ax=None, plot_colorbar=True, hex_color1 = Non

return ax

##########################################
# Ensemble2Ensemble diagnostic functions #
##########################################

#Add default plot kwargs at the diagnostic level
def ensemble_selection_boxplot(df, selected=None, sel_colors=None, ax=None, **kwargs):
    """
    Create a boxplot of the ensemble members with the selected ensemble members highlighted.

    Parameters
    ----------
    df : pd.DataFrame
        A DataFrame containing the ensemble members and their values.
    selected : str or list of str, optional
        The column name(s) with boolean values indicating the selected ensemble members.
        Names that are not columns of ``df`` are skipped.
        If None all members are plotted.
    sel_colors : dict, optional
        Mapping from a name in ``selected`` to the color used for the points of that selection.
    ax : matplotlib.axes.Axes, optional
        The axes on which to plot the boxplot. If None, seaborn draws on the current axes.
    **kwargs : dict
        Additional keyword arguments passed to `seaborn.boxplot` and `seaborn.swarmplot`.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axes containing the plot.
    """
    # Keep the axes seaborn actually drew on so that an explicitly passed ``ax``
    # (which is not part of kwargs) is the one returned.
    ax = sns.boxplot(data=df, ax=ax, **kwargs)

    if selected is None:
        sns.swarmplot(data=df, ax=ax, **kwargs)
        return ax

    # A single column name would otherwise be iterated character by character.
    if isinstance(selected, str):
        selected = [selected]

    for sel in selected:
        if sel not in df.columns:
            continue
        # Copy per selection so the color of one selection does not leak into the next.
        swarm_kwargs = dict(kwargs)
        if isinstance(sel_colors, dict) and sel in sel_colors:
            swarm_kwargs["color"] = sel_colors[sel]
        sns.swarmplot(data=df[df[sel]], ax=ax, **swarm_kwargs)

    return ax

def ensemble_change_signal_heatmap(df, index=None, columns=None, values=None, ax=None, **kwargs):
    """
    Create a heatmap of the ensemble change signal for different variables per ensemble member.

    Parameters
    ----------
    df : pd.DataFrame
        A DataFrame containing the ensemble change signals for different variables per ensemble member.
    index : str, optional
        The column name to use as the index for the heatmap.
    columns : str, optional
        The column name to use as the columns for the heatmap.
    values : str, optional
        The column name to use as the climate signal values for the heatmap.
    ax : matplotlib.axes.Axes, optional
        The axes on which to plot the heatmap. If None, seaborn draws on the current axes.
    **kwargs : dict
        Additional keyword arguments passed to `seaborn.heatmap`.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axes containing the heatmap.
    """
    signal_table = df.pivot(index=index, columns=columns, values=values)
    # Return the axes seaborn drew on; ``ax`` is a named parameter and therefore never
    # part of kwargs, so looking it up through kwargs could not find it.
    return sns.heatmap(data=signal_table, ax=ax, **kwargs)

##################################
# Helper functions #
##################################
Expand Down
Loading