From 5fe40d08a0aeea72ea9f6ac967f9edae51498dee Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Thu, 29 Aug 2024 13:51:19 +0200 Subject: [PATCH 1/4] first parts of joyplot --- src/cfp/plotting/_plotting.py | 9 + src/cfp/plotting/_utils.py | 405 ++++++++++++++++++++++++++++++++++ 2 files changed, 414 insertions(+) diff --git a/src/cfp/plotting/_plotting.py b/src/cfp/plotting/_plotting.py index d856e147..1c5f31b2 100644 --- a/src/cfp/plotting/_plotting.py +++ b/src/cfp/plotting/_plotting.py @@ -175,3 +175,12 @@ def plot_condition_embedding( ax.yaxis.set_tick_params(labelsize=fontsize) return fig if return_fig else None + + +def plot_expressions(): + """Plot kernel density estimations of expressions. + + This function is adapted from https://github.com/leotac/joypy/blob/master/joypy/joyplot.py + + + """ diff --git a/src/cfp/plotting/_utils.py b/src/cfp/plotting/_utils.py index 1c161ce6..2de65a10 100644 --- a/src/cfp/plotting/_utils.py +++ b/src/cfp/plotting/_utils.py @@ -2,14 +2,20 @@ from typing import Any import anndata as ad +import matplotlib.pyplot as plt import numpy as np import pandas as pd import scanpy as sc +import scipy.stats as stats import seaborn as sns +from pandas.plotting._matplotlib.tools import create_subplots as _subplots +from pandas.plotting._matplotlib.tools import flatten_axes as _flatten +from scipy.stats import gaussian_kde from sklearn.decomposition import KernelPCA from sklearn.metrics.pairwise import cosine_similarity from cfp import _constants, _logging +from cfp._types import ArrayLike def set_plotting_vars( @@ -102,3 +108,402 @@ def _compute_kernel_pca_from_df( similarity_matrix ) return pd.DataFrame(data=X, columns=list(range(n_components)), index=df.index) + + +def _x_range(data: ArrayLike | list[float], extra: float = 0.2) -> ArrayLike: + """Compute the x_range for density estimation.""" + try: + sample_range = np.nanmax(data) - np.nanmin(data) + except ValueError: + return np.array([]) + if sample_range < 1e-6: + return np.array([np.nanmin(data), np.nanmax(data)]) + return np.linspace( + np.nanmin(data) - extra * sample_range, + np.nanmax(data) + extra * sample_range, + 1000, + ) + + +def _setup_axis( + ax: plt.Axes, + x_range: ArrayLike, + col_name: str | None = None, + grid: bool = False, + ylabelsize: int | None = None, + yrot: int | None = None, +) -> None: + """Setup the axis for the joyplot.""" + if col_name is not None: + ax.set_yticks([0]) + ax.set_yticklabels([col_name], fontsize=ylabelsize, rotation=yrot) + ax.yaxis.grid(grid) + else: + ax.yaxis.set_visible(False) + ax.patch.set_alpha(0) + ax.set_xlim([x_range.min(), x_range.max()]) + ax.tick_params(axis="both", which="both", length=0, pad=10) + + +def _get_alpha(i: int, n: int, start: float = 0.4, end: float = 1.0) -> float: + """Compute alpha value for plotting.""" + return start + (1 + i) * (end - start) / n + + +def _remove_na(data: list[Any] | ArrayLike | pd.Series) -> ArrayLike: + """Remove NA values from the data.""" + return pd.Series(data).dropna().values + + +def _moving_average(a: ArrayLike, n: int = 3, zero_padded: bool = False) -> ArrayLike: + """Calculate the moving average of order n.""" + ret = np.cumsum(a, dtype=float) + ret[n:] = ret[n:] - ret[:-n] + if zero_padded: + return ret / n + else: + return ret[n - 1 :] / n + + +def _joyplot( + data, + grid=False, + labels=None, + sublabels=None, + xlabels=True, + xlabelsize=None, + xrot=None, + ylabelsize=None, + yrot=None, + ax=None, + figsize=None, + hist=False, + bins=10, + fade=False, + xlim=None, + ylim="max", + fill=True, + linecolor=None, + overlap=1, + background=None, + range_style="all", + x_range=None, + tails=0.2, + title=None, + legend=False, + loc="upper right", + colormap=None, + color=None, + normalize=True, + floc=None, + **kwargs, +): + if fill is True and linecolor is None: + linecolor = "k" + + if sublabels is None: + legend = False + + def _get_color(i, num_axes, j, num_subgroups): + if isinstance(color, list): + return color[j] if num_subgroups > 1 else color[i] + elif color is not None: + return color + elif isinstance(colormap, list): + return colormap[j](i / num_axes) + elif color is None and colormap is None: + num_cycle_colors = len(plt.rcParams["axes.prop_cycle"].by_key()["color"]) + return plt.rcParams["axes.prop_cycle"].by_key()["color"][ + j % num_cycle_colors + ] + else: + return colormap(i / num_axes) + + ygrid = grid is True or grid == "y" or grid == "both" + xgrid = grid is True or grid == "x" or grid == "both" + + num_axes = len(data) + + if x_range is None: + global_x_range = _x_range([v for g in data for sg in g for v in sg]) + else: + global_x_range = _x_range(x_range, 0.0) + + # Each plot will have its own axis + fig, axes = _subplots( + naxes=num_axes, + ax=ax, + squeeze=False, + sharex=True, + sharey=False, + figsize=figsize, + layout_type="vertical", + ) + _axes = _flatten(axes) + + # The legend must be drawn in the last axis if we want it at the bottom. + if loc in (3, 4, 8) or "lower" in str(loc): + legend_axis = num_axes - 1 + else: + legend_axis = 0 + + # A couple of simple checks. + if labels is not None: + assert len(labels) == num_axes + if sublabels is not None: + assert all(len(g) == len(sublabels) for g in data) + if isinstance(color, list): + assert all(len(g) <= len(color) for g in data) + if isinstance(colormap, list): + assert all(len(g) == len(colormap) for g in data) + + for i, group in enumerate(data): + + a = _axes[i] + group_zorder = i + if fade: + kwargs["alpha"] = _get_alpha(i, num_axes) + + num_subgroups = len(group) + + if hist: + # matplotlib hist() already handles multiple subgroups in a histogram + a.hist( + group, + label=sublabels, + bins=bins, + color=color, + range=[min(global_x_range), max(global_x_range)], + edgecolor=linecolor, + zorder=group_zorder, + **kwargs, + ) + else: + for j, subgroup in enumerate(group): + + # Compute the x_range of the current plot + if range_style == "all": + # All plots have the same range + x_range = global_x_range + elif range_style == "own": + # Each plot has its own range + x_range = _x_range(subgroup, tails) + elif range_style == "group": + # Each plot has a range that covers the whole group + x_range = _x_range(group, tails) + elif isinstance(range_style, list | np.ndarray): + # All plots have exactly the range passed as argument + x_range = _x_range(range_style, 0.0) + else: + raise NotImplementedError("Unrecognized range style.") + + if sublabels is None: + sublabel = None + else: + sublabel = sublabels[j] + + element_zorder = group_zorder + j / (num_subgroups + 1) + element_color = _get_color(i, num_axes, j, num_subgroups) + + _plot_density( + a, + x_range, + subgroup, + fill=fill, + linecolor=linecolor, + label=sublabel, + zorder=element_zorder, + color=element_color, + bins=bins, + **kwargs, + ) + + # Setup the current axis: transparency, labels, spines. + col_name = None if labels is None else labels[i] + _setup_axis( + a, + global_x_range, + col_name=col_name, + grid=ygrid, + ylabelsize=ylabelsize, + yrot=yrot, + ) + + # When needed, draw the legend + if legend and i == legend_axis: + a.legend(loc=loc) + # Bypass alpha values, in case + for p in a.get_legend().get_patches(): + p.set_facecolor(p.get_facecolor()) + p.set_alpha(1.0) + for l in a.get_legend().get_lines(): + l.set_alpha(1.0) + + # Final adjustments + + # Set the y limit for the density plots. + # Since the y range in the subplots can vary significantly, + # different options are available. + if ylim == "max": + # Set all yaxis limit to the same value (max range among all) + max_ylim = max(a.get_ylim()[1] for a in _axes) + min_ylim = min(a.get_ylim()[0] for a in _axes) + for a in _axes: + a.set_ylim([min_ylim - 0.1 * (max_ylim - min_ylim), max_ylim]) + + elif ylim == "own": + # Do nothing, each axis keeps its own ylim + pass + + else: + # Set all yaxis lim to the argument value ylim + try: + for a in _axes: + a.set_ylim(ylim) + except ValueError: + raise ValueError( + "Warning: the value of ylim must be either 'max', 'own', or a tuple of length 2. The value you provided has no effect." + ) from None + + # Compute a final axis, used to apply global settings + last_axis = fig.add_subplot(1, 1, 1) + + # Background color + if background is not None: + last_axis.patch.set_facecolor(background) + + # This looks hacky, but all the axes share the x-axis, + # so they have the same lims and ticks + last_axis.set_xlim(_axes[0].get_xlim()) + if xlabels is True: + last_axis.set_xticks(np.array(_axes[0].get_xticks()[1:-1])) + for t in last_axis.get_xticklabels(): + t.set_visible(True) + t.set_fontsize(xlabelsize) + t.set_rotation(xrot) + + # If grid is enabled, do not allow xticks (they are ugly) + if xgrid: + last_axis.tick_params(axis="both", which="both", length=0) + else: + last_axis.xaxis.set_visible(False) + + last_axis.yaxis.set_visible(False) + last_axis.grid(xgrid) + + # Last axis on the back + last_axis.zorder = min(a.zorder for a in _axes) - 1 + _axes = list(_axes) + [last_axis] + + if title is not None: + plt.title(title) + + # The magic overlap happens here. + h_pad = 5 + (-5 * (1 + overlap)) + fig.tight_layout(h_pad=h_pad) + + return fig, _axes + + +def _plot_density( + ax, + x_range, + v, + kind="kde", + bw_method=None, + bins=50, + fill=False, + linecolor=None, + clip_on=True, + normalize=True, + floc=None, + **kwargs, +): + v = _remove_na(v) + if len(v) == 0 or len(x_range) == 0: + return + + if kind == "kde": + try: + gkde = gaussian_kde(v, bw_method=bw_method) + y = gkde.evaluate(x_range) + y = np.log(y + 1.0) + except ValueError: + # Handle cases where there is no data in a group. + y = np.zeros_like(x_range) + except np.linalg.LinAlgError as e: + # Handle singular matrix in kde computation. + distinct_values = np.unique(v) + if len(distinct_values) == 1: + # In case of a group with a single value val, + # that should have infinite density, + # return a δ(val) + val = distinct_values[0] + _logging.logger.warning( + f"The data contains a group with a single distinct value ({val}) " + "having infinite probability density. " + "Consider using a different visualization." + ) + + # Find index i of x_range + # such that x_range[i-1] < val ≤ x_range[i] + i = np.searchsorted(x_range, val) + + y = np.zeros_like(x_range) + y[i] = 1 + else: + raise e + + elif kind == "lognorm": + if floc is not None: + lnparam = stats.lognorm.fit(v, loc=floc) + else: + lnparam = stats.lognorm.fit(v) + + lpdf = stats.lognorm.pdf(x_range, lnparam[0], lnparam[1], lnparam[2]) + if normalize: + y = lpdf / lpdf.sum() + else: + y = lpdf + elif kind == "counts": + y, bin_edges = np.histogram(v, bins=bins, range=(min(x_range), max(x_range))) + # np.histogram returns the edges of the bins. + # We compute here the middle of the bins. + x_range = _moving_average(bin_edges, 2) + elif kind == "normalized_counts": + y, bin_edges = np.histogram( + v, bins=bins, density=False, range=(min(x_range), max(x_range)) + ) + # np.histogram returns the edges of the bins. + # We compute here the middle of the bins. + y = y / len(v) + x_range = _moving_average(bin_edges, 2) + elif kind == "values": + # Warning: to use values and get a meaningful visualization, + # x_range must also be manually set in the main function. + y = v + x_range = list(range(len(y))) + else: + raise NotImplementedError + + if fill: + + ax.fill_between(x_range, 0.0, y, clip_on=clip_on, **kwargs) + + # Hack to have a border at the bottom at the fill patch + # (of the same color of the fill patch) + # so that the fill reaches the same bottom margin as the edge lines + # with y value = 0.0 + kw = kwargs + kw["label"] = None + ax.plot(x_range, [0.0] * len(x_range), clip_on=clip_on, **kw) + + if linecolor is not None: + kwargs["color"] = linecolor + + # Remove the legend labels if we are plotting filled curve: + # we only want one entry per group in the legend (if shown). + if fill: + kwargs["label"] = None + + ax.plot(x_range, y, clip_on=clip_on, **kwargs) From 99ea28def89466df6033a1c1185d5b1b2e6f3b94 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Thu, 29 Aug 2024 16:10:37 +0200 Subject: [PATCH 2/4] joyplot working --- src/cfp/plotting/__init__.py | 4 +- src/cfp/plotting/_plotting.py | 167 +++++++++++++++++++++++++++++++- src/cfp/plotting/_utils.py | 61 ++++++++---- tests/plotting/conftest.py | 7 ++ tests/plotting/test_plotting.py | 15 ++- 5 files changed, 230 insertions(+), 24 deletions(-) diff --git a/src/cfp/plotting/__init__.py b/src/cfp/plotting/__init__.py index 1b9c072d..c8486788 100644 --- a/src/cfp/plotting/__init__.py +++ b/src/cfp/plotting/__init__.py @@ -1,3 +1,3 @@ -from cfp.plotting._plotting import plot_condition_embedding +from cfp.plotting._plotting import plot_condition_embedding, plot_densities -__all__ = ["plot_condition_embedding"] +__all__ = ["plot_condition_embedding", "plot_densities"] diff --git a/src/cfp/plotting/_plotting.py b/src/cfp/plotting/_plotting.py index 1c5f31b2..48bd08b9 100644 --- a/src/cfp/plotting/_plotting.py +++ b/src/cfp/plotting/_plotting.py @@ -1,22 +1,30 @@ import types +from collections.abc import Sequence from typing import Any, Literal import anndata as ad import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np +import pandas as pd import seaborn as sns from adjustText import adjust_text -from cfp import _constants +from cfp import _constants, _logging from cfp.plotting._utils import ( _compute_kernel_pca_from_df, _compute_pca_from_df, _compute_umap_from_df, _get_colors, + _grouped_df_to_standard, + _is_numeric, + _joyplot, + _remove_na, get_plotting_vars, ) +__all__ = ["plot_condition_embedding", "plot_densities"] + def plot_condition_embedding( adata: ad.AnnData, @@ -177,10 +185,165 @@ def plot_condition_embedding( return fig if return_fig else None -def plot_expressions(): +def plot_densities( + data: pd.DataFrame, + features: Sequence[str], + group_by: str | None = None, + ax: mpl.axes.Axes | None = None, + figsize: tuple[float, float] | None = None, + dpi: int | None = None, + xlabels: bool = False, + ylabels: bool = True, + xlabelsize: float | None = None, + xrot: float | None = None, + labels: Sequence[Any] = None, + ylabelsize: float | None = None, + yrot: float | None = None, + hist: bool = False, + bins: int = 10, + fade: bool = False, + ylim: Literal["max"] | tuple[float, float] | None = "max", + fill: bool = True, + linecolor: Any = None, + overlap: float = 1.0, + background: Any = None, + range_style: Literal["all", "individual", "group"] | list[float] = "all", + x_range: tuple[float, float] = None, + title: str | None = None, + colormap: str | mpl.colors.Colormap | None = None, + color: Any = None, + normalize: bool = True, + grid: bool = False, + **kwargs, +): """Plot kernel density estimations of expressions. This function is adapted from https://github.com/leotac/joypy/blob/master/joypy/joyplot.py + Parameters + ---------- + data + :class:`pandas.DataFrame` object containing (predicted) expression values. + features + Features whose density to plot. + group_by + Column in ``'data'`` to group by. + ax + :class:`matplotlib.axes.Axes` used for plotting. If :obj:`None`, create a new one. + figsize + Size of the figure. + dpi + Dots per inch. + xlabels + Whether to show x-axis labels. + ylabels + Whether to show y-axis labels. + xlabelsize + Size of the x-axis labels. + xrot + Rotation (in degrees) of the x-axis labels. + labels + Sequence of labels for each density plot. + ylabelsize + Size of the y-axis labels. + yrot + Rotation (in degrees) of the y-axis labels. + hist + If :obj:`True`, plot a histogram, otherwise a density plot. + bins + Number of bins to use, only applicable if ``hist`` is :obj:`True`. + fade + If :obj:`True`, automatically sets different values of transparency of the density plots. + ylim + Limits of the y-axis. + fill + Whether to fill the density plots. If :obj:`False`, only the lines are plotted. + linecolor: :mpltype:`color` + Color of the contour lines. + overlap + Overlap between the density plots. The higher the value, the more overlap between densities. + background: :mpltype:`color` + Background color of the plot. + range_style + Style of the range. Options are + + - "all" - all density plots have the same range, autmoatically determined. + - "individual" - every density plot has its own range, automatically determined. + - "group" - each plot has a range that covers the whole group + - type :obj:`list` - custom ranges for each density plot. + + x_range + Custom range for the x-axis, shared across all density plots. If :obj:`None`, set via ``'range_style'``. + title + Title of the plot. + colormap + Colormap to use. + color: :mpltype:`color` + Color of the density plots. + normalize + Whether to normalize the densities. + grid + Whether to show the grid. + kwargs + Additional keyword arguments for the plot. """ + if group_by is not None and isinstance(data, pd.DataFrame): + grouped = data.groupby(group_by) + if features is None: + features = list(data.columns) + features.remove(group_by) + converted, _labels, sublabels = _grouped_df_to_standard(grouped, features) + if labels is None: + labels = _labels + elif isinstance(data, pd.DataFrame): + if features is not None: + data = data[features] + converted = [ + [_remove_na(data[col])] for col in data.columns if _is_numeric(data[col]) + ] + labels = [col for col in data.columns if _is_numeric(data[col])] + sublabels = None + else: + raise TypeError(f"Unknown type for 'data': {type(data)!r}") + + if ylabels is False: + labels = None + + if all(len(subg) == 0 for g in converted for subg in g): + raise ValueError( + "No numeric values found. Joyplot requires at least a numeric column/group." + ) + + if any(len(subg) == 0 for g in converted for subg in g): + _logging.logger.warning("At least a column/group has no numeric values.") + + return _joyplot( + converted, + labels=labels, + sublabels=sublabels, + grid=grid, + xlabelsize=xlabelsize, + xrot=xrot, + ylabelsize=ylabelsize, + yrot=yrot, + ax=ax, + dpi=dpi, + figsize=figsize, + hist=hist, + bins=bins, + fade=fade, + ylim=ylim, + fill=fill, + linecolor=linecolor, + overlap=overlap, + background=background, + xlabels=xlabels, + range_style=range_style, + x_range=x_range, + title=title, + colormap=colormap, + color=color, + normalize=normalize, + **kwargs, + ) diff --git a/src/cfp/plotting/_utils.py b/src/cfp/plotting/_utils.py index 2de65a10..4600f6c8 100644 --- a/src/cfp/plotting/_utils.py +++ b/src/cfp/plotting/_utils.py @@ -1,13 +1,15 @@ from collections.abc import Sequence -from typing import Any +from typing import Any, Literal import anndata as ad +import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np import pandas as pd import scanpy as sc import scipy.stats as stats import seaborn as sns +from pandas.core.dtypes.common import is_number from pandas.plotting._matplotlib.tools import create_subplots as _subplots from pandas.plotting._matplotlib.tools import flatten_axes as _flatten from scipy.stats import gaussian_kde @@ -150,6 +152,11 @@ def _get_alpha(i: int, n: int, start: float = 0.4, end: float = 1.0) -> float: return start + (1 + i) * (end - start) / n +def _is_numeric(x): + """Whether the array x is numeric.""" + return all(is_number(i) for i in x) + + def _remove_na(data: list[Any] | ArrayLike | pd.Series) -> ArrayLike: """Remove NA values from the data.""" return pd.Series(data).dropna().values @@ -165,9 +172,24 @@ def _moving_average(a: ArrayLike, n: int = 3, zero_padded: bool = False) -> Arra return ret[n - 1 :] / n +def _grouped_df_to_standard(grouped, column): + converted = [] + labels = [] + for i, (key, group) in enumerate(grouped): + if column is not None: + group = group[column] + labels.append(key) + converted.append( + [_remove_na(group[c]) for c in group.columns if _is_numeric(group[c])] + ) + if i == 0: + sublabels = [col for col in group.columns if _is_numeric(group[col])] + return converted, labels, sublabels + + def _joyplot( - data, - grid=False, + data: pd.DataFrame, + grid: bool = False, labels=None, sublabels=None, xlabels=True, @@ -175,26 +197,27 @@ def _joyplot( xrot=None, ylabelsize=None, yrot=None, - ax=None, - figsize=None, - hist=False, + ax: mpl.axes.Axes | None = None, + dpi: int | None = None, + figsize: tuple[float, float] | None = None, + hist: bool = False, bins=10, - fade=False, + fade: bool = False, xlim=None, ylim="max", fill=True, linecolor=None, - overlap=1, - background=None, - range_style="all", - x_range=None, - tails=0.2, - title=None, + overlap: float = 1.0, + background: Any = None, + range_style: Literal["all", "individual", "group"] | list[float] = "all", + x_range: tuple[float, float] = None, + tails: float = 0.2, + title: str | None = None, legend=False, loc="upper right", - colormap=None, + colormap: str | mpl.colors.Colormap | None = None, color=None, - normalize=True, + normalize: bool = True, floc=None, **kwargs, ): @@ -232,6 +255,7 @@ def _get_color(i, num_axes, j, num_subgroups): # Each plot will have its own axis fig, axes = _subplots( naxes=num_axes, + dpi=dpi, ax=ax, squeeze=False, sharex=True, @@ -285,13 +309,13 @@ def _get_color(i, num_axes, j, num_subgroups): if range_style == "all": # All plots have the same range x_range = global_x_range - elif range_style == "own": + elif range_style == "individual": # Each plot has its own range x_range = _x_range(subgroup, tails) elif range_style == "group": # Each plot has a range that covers the whole group x_range = _x_range(group, tails) - elif isinstance(range_style, list | np.ndarray): + elif isinstance(range_style, list): # All plots have exactly the range passed as argument x_range = _x_range(range_style, 0.0) else: @@ -314,6 +338,7 @@ def _get_color(i, num_axes, j, num_subgroups): label=sublabel, zorder=element_zorder, color=element_color, + normalize=normalize, bins=bins, **kwargs, ) @@ -351,7 +376,7 @@ def _get_color(i, num_axes, j, num_subgroups): for a in _axes: a.set_ylim([min_ylim - 0.1 * (max_ylim - min_ylim), max_ylim]) - elif ylim == "own": + elif ylim is None: # Do nothing, each axis keeps its own ylim pass diff --git a/tests/plotting/conftest.py b/tests/plotting/conftest.py index b74dbbe2..9d0609ca 100644 --- a/tests/plotting/conftest.py +++ b/tests/plotting/conftest.py @@ -27,3 +27,10 @@ def adata_with_condition_embedding(adata_perturbation) -> ad.AnnData: adata_perturbation.uns[_constants.CFP_KEY] = {} adata_perturbation.uns[_constants.CFP_KEY][_constants.CONDITION_EMBEDDING] = df return adata_perturbation + + +@pytest.fixture +def df_joyplot() -> pd.DataFrame: + test_df = pd.DataFrame(data=np.random.randn(100, 2), columns=["x", "y"]) + test_df["dosage"] = np.random.choice([0.0, 1.0, 2.0], 100) + return test_df diff --git a/tests/plotting/test_plotting.py b/tests/plotting/test_plotting.py index a1dfe89e..bf233c60 100644 --- a/tests/plotting/test_plotting.py +++ b/tests/plotting/test_plotting.py @@ -1,10 +1,10 @@ import matplotlib.pyplot as plt import pytest -from cfp.plotting import plot_condition_embedding +from cfp.plotting import plot_condition_embedding, plot_densities -class TestCallbacks: +class TestPlotConditionEmbedding: @pytest.mark.parametrize( "embedding", ["raw_embedding", "UMAP", "PCA", "Kernel_PCA"] ) @@ -26,3 +26,14 @@ def test_plot_embeddings( ) assert isinstance(fig, plt.Figure) + + +class TestJoyPlot: + @pytest.mark.parametrize("features", [["x", "y"], "x"]) + @pytest.mark.parametrize("group_by", ["dosage", None]) + @pytest.mark.parametrize("hist", [True, False]) + def test_plot_joyplot(self, df_joyplot, features, group_by, hist): + fig = plot_densities( + df_joyplot, features=features, group_by=group_by, hist=hist + ) + assert isinstance(fig, plt.Figure) From 3936593cf6f55666b601d099de4acf913944bb9f Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Thu, 29 Aug 2024 17:01:36 +0200 Subject: [PATCH 3/4] clean --- src/cfp/plotting/_plotting.py | 16 ++++- src/cfp/plotting/_utils.py | 132 +++++++++++++--------------------- 2 files changed, 65 insertions(+), 83 deletions(-) diff --git a/src/cfp/plotting/_plotting.py b/src/cfp/plotting/_plotting.py index 48bd08b9..6419634a 100644 --- a/src/cfp/plotting/_plotting.py +++ b/src/cfp/plotting/_plotting.py @@ -189,6 +189,7 @@ def plot_densities( data: pd.DataFrame, features: Sequence[str], group_by: str | None = None, + density_fit: Literal["log1p", "raw"] = "raw", ax: mpl.axes.Axes | None = None, figsize: tuple[float, float] | None = None, dpi: int | None = None, @@ -214,6 +215,7 @@ def plot_densities( color: Any = None, normalize: bool = True, grid: bool = False, + return_fig: bool = True, **kwargs, ): """Plot kernel density estimations of expressions. @@ -228,6 +230,9 @@ def plot_densities( Features whose density to plot. group_by Column in ``'data'`` to group by. + density_fit + Type of density fit to use. If "raw", the kernel density estimation is plotted. If "log1p", the log1p + transformed values of the densities are plotted. ax :class:`matplotlib.axes.Axes` used for plotting. If :obj:`None`, create a new one. figsize @@ -284,16 +289,21 @@ def plot_densities( Whether to normalize the densities. grid Whether to show the grid. + return_fig + Whether to return the figure. kwargs Additional keyword arguments for the plot. + Returns + ------- + :class:`matplotlib.figure.Figure` if ``'return_fig'`` is :obj:`True`, else :obj:`None`. """ if group_by is not None and isinstance(data, pd.DataFrame): grouped = data.groupby(group_by) if features is None: features = list(data.columns) features.remove(group_by) - converted, _labels, sublabels = _grouped_df_to_standard(grouped, features) + converted, _labels, sublabels = _grouped_df_to_standard(grouped, features) # type: ignore[arg-type] if labels is None: labels = _labels elif isinstance(data, pd.DataFrame): @@ -318,7 +328,7 @@ def plot_densities( if any(len(subg) == 0 for g in converted for subg in g): _logging.logger.warning("At least a column/group has no numeric values.") - return _joyplot( + fig, axes = _joyplot( converted, labels=labels, sublabels=sublabels, @@ -345,5 +355,7 @@ def plot_densities( colormap=colormap, color=color, normalize=normalize, + density_fit=density_fit, **kwargs, ) + return fig if return_fig else None diff --git a/src/cfp/plotting/_utils.py b/src/cfp/plotting/_utils.py index 4600f6c8..90641fa1 100644 --- a/src/cfp/plotting/_utils.py +++ b/src/cfp/plotting/_utils.py @@ -7,7 +7,6 @@ import numpy as np import pandas as pd import scanpy as sc -import scipy.stats as stats import seaborn as sns from pandas.core.dtypes.common import is_number from pandas.plotting._matplotlib.tools import create_subplots as _subplots @@ -152,7 +151,7 @@ def _get_alpha(i: int, n: int, start: float = 0.4, end: float = 1.0) -> float: return start + (1 + i) * (end - start) / n -def _is_numeric(x): +def _is_numeric(x: ArrayLike) -> bool: """Whether the array x is numeric.""" return all(is_number(i) for i in x) @@ -172,7 +171,9 @@ def _moving_average(a: ArrayLike, n: int = 3, zero_padded: bool = False) -> Arra return ret[n - 1 :] / n -def _grouped_df_to_standard(grouped, column): +def _grouped_df_to_standard( + grouped: pd.api.typing.DataFrameGroupBy, column: str | None +) -> tuple[pd.DataFrame, list[str], list[str]]: converted = [] labels = [] for i, (key, group) in enumerate(grouped): @@ -190,6 +191,7 @@ def _grouped_df_to_standard(grouped, column): def _joyplot( data: pd.DataFrame, grid: bool = False, + density_fit: Literal["log1p", "raw"] = "raw", labels=None, sublabels=None, xlabels=True, @@ -203,7 +205,6 @@ def _joyplot( hist: bool = False, bins=10, fade: bool = False, - xlim=None, ylim="max", fill=True, linecolor=None, @@ -218,9 +219,8 @@ def _joyplot( colormap: str | mpl.colors.Colormap | None = None, color=None, normalize: bool = True, - floc=None, **kwargs, -): +) -> tuple[plt.Figure, list[plt.Axes]]: if fill is True and linecolor is None: linecolor = "k" @@ -240,7 +240,7 @@ def _get_color(i, num_axes, j, num_subgroups): j % num_cycle_colors ] else: - return colormap(i / num_axes) + return colormap(i / num_axes) # type: ignore[operator] ygrid = grid is True or grid == "y" or grid == "both" xgrid = grid is True or grid == "x" or grid == "both" @@ -250,7 +250,7 @@ def _get_color(i, num_axes, j, num_subgroups): if x_range is None: global_x_range = _x_range([v for g in data for sg in g for v in sg]) else: - global_x_range = _x_range(x_range, 0.0) + global_x_range = _x_range(x_range, 0.0) # type: ignore[arg-type] # Each plot will have its own axis fig, axes = _subplots( @@ -292,7 +292,7 @@ def _get_color(i, num_axes, j, num_subgroups): if hist: # matplotlib hist() already handles multiple subgroups in a histogram - a.hist( + ax = a.hist( group, label=sublabels, bins=bins, @@ -329,9 +329,9 @@ def _get_color(i, num_axes, j, num_subgroups): element_zorder = group_zorder + j / (num_subgroups + 1) element_color = _get_color(i, num_axes, j, num_subgroups) - _plot_density( + ax = _plot_density( a, - x_range, + x_range, # type: ignore[arg-type] subgroup, fill=fill, linecolor=linecolor, @@ -340,6 +340,7 @@ def _get_color(i, num_axes, j, num_subgroups): color=element_color, normalize=normalize, bins=bins, + density_fit=density_fit, **kwargs, ) @@ -431,85 +432,53 @@ def _get_color(i, num_axes, j, num_subgroups): def _plot_density( - ax, - x_range, - v, - kind="kde", + ax: plt.Axes, + x_range: ArrayLike, + v: ArrayLike, bw_method=None, - bins=50, - fill=False, - linecolor=None, - clip_on=True, - normalize=True, - floc=None, + fill: bool = False, + linecolor: Any = None, + density_fit: Literal["log1p", "raw"] = "raw", + clip_on: bool = True, **kwargs, -): +) -> plt.Axes: v = _remove_na(v) if len(v) == 0 or len(x_range) == 0: return - - if kind == "kde": - try: - gkde = gaussian_kde(v, bw_method=bw_method) - y = gkde.evaluate(x_range) + try: + gkde = gaussian_kde(v, bw_method=bw_method) + y = gkde.evaluate(x_range) + if density_fit == "log1p": y = np.log(y + 1.0) - except ValueError: - # Handle cases where there is no data in a group. - y = np.zeros_like(x_range) - except np.linalg.LinAlgError as e: - # Handle singular matrix in kde computation. - distinct_values = np.unique(v) - if len(distinct_values) == 1: - # In case of a group with a single value val, - # that should have infinite density, - # return a δ(val) - val = distinct_values[0] - _logging.logger.warning( - f"The data contains a group with a single distinct value ({val}) " - "having infinite probability density. " - "Consider using a different visualization." - ) - - # Find index i of x_range - # such that x_range[i-1] < val ≤ x_range[i] - i = np.searchsorted(x_range, val) - - y = np.zeros_like(x_range) - y[i] = 1 - else: - raise e - - elif kind == "lognorm": - if floc is not None: - lnparam = stats.lognorm.fit(v, loc=floc) + elif density_fit == "raw": + y = y else: - lnparam = stats.lognorm.fit(v) + raise ValueError("density_fit must be either 'log1p' or 'raw'.") + except ValueError: + # Handle cases where there is no data in a group. + y = np.zeros_like(x_range) + except np.linalg.LinAlgError as e: + # Handle singular matrix in kde computation. + distinct_values = np.unique(v) + if len(distinct_values) == 1: + # In case of a group with a single value val, + # that should have infinite density, + # return a δ(val) + val = distinct_values[0] + _logging.logger.warning( + f"The data contains a group with a single distinct value ({val}) " + "having infinite probability density. " + "Consider using a different visualization." + ) - lpdf = stats.lognorm.pdf(x_range, lnparam[0], lnparam[1], lnparam[2]) - if normalize: - y = lpdf / lpdf.sum() + # Find index i of x_range + # such that x_range[i-1] < val ≤ x_range[i] + i = np.searchsorted(x_range, val) + + y = np.zeros_like(x_range) + y[i] = 1 else: - y = lpdf - elif kind == "counts": - y, bin_edges = np.histogram(v, bins=bins, range=(min(x_range), max(x_range))) - # np.histogram returns the edges of the bins. - # We compute here the middle of the bins. - x_range = _moving_average(bin_edges, 2) - elif kind == "normalized_counts": - y, bin_edges = np.histogram( - v, bins=bins, density=False, range=(min(x_range), max(x_range)) - ) - # np.histogram returns the edges of the bins. - # We compute here the middle of the bins. - y = y / len(v) - x_range = _moving_average(bin_edges, 2) - elif kind == "values": - # Warning: to use values and get a meaningful visualization, - # x_range must also be manually set in the main function. - y = v - x_range = list(range(len(y))) - else: - raise NotImplementedError + raise e if fill: @@ -532,3 +501,4 @@ def _plot_density( kwargs["label"] = None ax.plot(x_range, y, clip_on=clip_on, **kwargs) + return ax From abb7f739c14ac302f8214d47544c151983631afe Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Thu, 29 Aug 2024 17:36:08 +0200 Subject: [PATCH 4/4] clean and simplify --- src/cfp/plotting/_plotting.py | 22 +++++++++------------- src/cfp/plotting/_utils.py | 3 +-- tests/plotting/test_plotting.py | 6 ++---- 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/src/cfp/plotting/_plotting.py b/src/cfp/plotting/_plotting.py index 6419634a..3523cf58 100644 --- a/src/cfp/plotting/_plotting.py +++ b/src/cfp/plotting/_plotting.py @@ -187,7 +187,7 @@ def plot_condition_embedding( def plot_densities( data: pd.DataFrame, - features: Sequence[str], + feature: str, group_by: str | None = None, density_fit: Literal["log1p", "raw"] = "raw", ax: mpl.axes.Axes | None = None, @@ -213,7 +213,6 @@ def plot_densities( title: str | None = None, colormap: str | mpl.colors.Colormap | None = None, color: Any = None, - normalize: bool = True, grid: bool = False, return_fig: bool = True, **kwargs, @@ -226,8 +225,8 @@ def plot_densities( ---------- data :class:`pandas.DataFrame` object containing (predicted) expression values. - features - Features whose density to plot. + feature + Column in ``'data'`` to plot. group_by Column in ``'data'`` to group by. density_fit @@ -285,8 +284,6 @@ def plot_densities( Colormap to use. color: :mpltype:`color` Color of the density plots. - normalize - Whether to normalize the densities. grid Whether to show the grid. return_fig @@ -300,15 +297,15 @@ def plot_densities( """ if group_by is not None and isinstance(data, pd.DataFrame): grouped = data.groupby(group_by) - if features is None: - features = list(data.columns) - features.remove(group_by) - converted, _labels, sublabels = _grouped_df_to_standard(grouped, features) # type: ignore[arg-type] + if feature is None: + feature = list(data.columns) + feature.remove(group_by) + converted, _labels, sublabels = _grouped_df_to_standard(grouped, feature) # type: ignore[arg-type] if labels is None: labels = _labels elif isinstance(data, pd.DataFrame): - if features is not None: - data = data[features] + if feature is not None: + data = data[feature] converted = [ [_remove_na(data[col])] for col in data.columns if _is_numeric(data[col]) ] @@ -354,7 +351,6 @@ def plot_densities( title=title, colormap=colormap, color=color, - normalize=normalize, density_fit=density_fit, **kwargs, ) diff --git a/src/cfp/plotting/_utils.py b/src/cfp/plotting/_utils.py index 90641fa1..7d6749c9 100644 --- a/src/cfp/plotting/_utils.py +++ b/src/cfp/plotting/_utils.py @@ -218,7 +218,6 @@ def _joyplot( loc="upper right", colormap: str | mpl.colors.Colormap | None = None, color=None, - normalize: bool = True, **kwargs, ) -> tuple[plt.Figure, list[plt.Axes]]: if fill is True and linecolor is None: @@ -338,7 +337,6 @@ def _get_color(i, num_axes, j, num_subgroups): label=sublabel, zorder=element_zorder, color=element_color, - normalize=normalize, bins=bins, density_fit=density_fit, **kwargs, @@ -440,6 +438,7 @@ def _plot_density( linecolor: Any = None, density_fit: Literal["log1p", "raw"] = "raw", clip_on: bool = True, + bins: int = 10, **kwargs, ) -> plt.Axes: v = _remove_na(v) diff --git a/tests/plotting/test_plotting.py b/tests/plotting/test_plotting.py index bf233c60..e4ceb460 100644 --- a/tests/plotting/test_plotting.py +++ b/tests/plotting/test_plotting.py @@ -29,11 +29,9 @@ def test_plot_embeddings( class TestJoyPlot: - @pytest.mark.parametrize("features", [["x", "y"], "x"]) + @pytest.mark.parametrize("features", [["x"]]) @pytest.mark.parametrize("group_by", ["dosage", None]) @pytest.mark.parametrize("hist", [True, False]) def test_plot_joyplot(self, df_joyplot, features, group_by, hist): - fig = plot_densities( - df_joyplot, features=features, group_by=group_by, hist=hist - ) + fig = plot_densities(df_joyplot, feature=features, group_by=group_by, hist=hist) assert isinstance(fig, plt.Figure)