diff --git a/src/cfp/plotting/__init__.py b/src/cfp/plotting/__init__.py
index 1b9c072d..c8486788 100644
--- a/src/cfp/plotting/__init__.py
+++ b/src/cfp/plotting/__init__.py
@@ -1,3 +1,3 @@
-from cfp.plotting._plotting import plot_condition_embedding
+from cfp.plotting._plotting import plot_condition_embedding, plot_densities
 
-__all__ = ["plot_condition_embedding"]
+__all__ = ["plot_condition_embedding", "plot_densities"]
diff --git a/src/cfp/plotting/_plotting.py b/src/cfp/plotting/_plotting.py
index d856e147..3523cf58 100644
--- a/src/cfp/plotting/_plotting.py
+++ b/src/cfp/plotting/_plotting.py
@@ -1,22 +1,30 @@
 import types
+from collections.abc import Sequence
 from typing import Any, Literal
 
 import anndata as ad
 import matplotlib as mpl
 import matplotlib.pyplot as plt
 import numpy as np
+import pandas as pd
 import seaborn as sns
 from adjustText import adjust_text
 
-from cfp import _constants
+from cfp import _constants, _logging
 from cfp.plotting._utils import (
     _compute_kernel_pca_from_df,
     _compute_pca_from_df,
     _compute_umap_from_df,
     _get_colors,
+    _grouped_df_to_standard,
+    _is_numeric,
+    _joyplot,
+    _remove_na,
     get_plotting_vars,
 )
 
+__all__ = ["plot_condition_embedding", "plot_densities"]
+
 
 def plot_condition_embedding(
     adata: ad.AnnData,
@@ -175,3 +183,193 @@ def plot_condition_embedding(
     ax.yaxis.set_tick_params(labelsize=fontsize)
 
     return fig if return_fig else None
+
+
+def plot_densities(
+    data: pd.DataFrame,
+    feature: str | Sequence[str] | None,
+    group_by: str | None = None,
+    density_fit: Literal["log1p", "raw"] = "raw",
+    ax: mpl.axes.Axes | None = None,
+    figsize: tuple[float, float] | None = None,
+    dpi: int | None = None,
+    xlabels: bool = False,
+    ylabels: bool = True,
+    xlabelsize: float | None = None,
+    xrot: float | None = None,
+    labels: Sequence[Any] | None = None,
+    ylabelsize: float | None = None,
+    yrot: float | None = None,
+    hist: bool = False,
+    bins: int = 10,
+    fade: bool = False,
+    ylim: Literal["max"] | tuple[float, float] | None = "max",
+    fill: bool = True,
+    linecolor: Any = None,
+    overlap: float = 1.0,
+    background: Any = None,
+    range_style: Literal["all", "individual", "group"] | list[float] = "all",
+    x_range: tuple[float, float] | None = None,
+    title: str | None = None,
+    colormap: str | mpl.colors.Colormap | None = None,
+    color: Any = None,
+    grid: bool = False,
+    return_fig: bool = True,
+    **kwargs,
+) -> plt.Figure | None:
+    """Plot kernel density estimations of expressions.
+
+    This function is adapted from https://github.com/leotac/joypy/blob/master/joypy/joyplot.py
+
+    Parameters
+    ----------
+    data
+        :class:`pandas.DataFrame` object containing (predicted) expression values.
+    feature
+        Column or sequence of columns in ``'data'`` to plot. If :obj:`None`, all columns
+        (except ``'group_by'``) are considered. Non-numeric columns are skipped.
+    group_by
+        Column in ``'data'`` to group by.
+    density_fit
+        Type of density fit to use. If "raw", the kernel density estimation is plotted. If "log1p", the log1p
+        transformed values of the densities are plotted.
+    ax
+        :class:`matplotlib.axes.Axes` used for plotting. If :obj:`None`, create a new one.
+    figsize
+        Size of the figure.
+    dpi
+        Dots per inch.
+    xlabels
+        Whether to show x-axis labels.
+    ylabels
+        Whether to show y-axis labels.
+    xlabelsize
+        Size of the x-axis labels.
+    xrot
+        Rotation (in degrees) of the x-axis labels.
+    labels
+        Sequence of labels for each density plot. If :obj:`None`, the group keys (if ``'group_by'`` is set)
+        or the column names are used.
+    ylabelsize
+        Size of the y-axis labels.
+    yrot
+        Rotation (in degrees) of the y-axis labels.
+    hist
+        If :obj:`True`, plot a histogram, otherwise a density plot.
+    bins
+        Number of bins to use, only applicable if ``hist`` is :obj:`True`.
+    fade
+        If :obj:`True`, automatically sets different values of transparency of the density plots.
+    ylim
+        Limits of the y-axis. If "max", all density plots share the largest limits found; if :obj:`None`,
+        every density plot keeps its own limits.
+    fill
+        Whether to fill the density plots. If :obj:`False`, only the lines are plotted.
+    linecolor: :mpltype:`color`
+        Color of the contour lines.
+    overlap
+        Overlap between the density plots. The higher the value, the more overlap between densities.
+    background: :mpltype:`color`
+        Background color of the plot.
+    range_style
+        Style of the range. Options are
+
+        - "all" - all density plots have the same range, automatically determined.
+        - "individual" - every density plot has its own range, automatically determined.
+        - "group" - each plot has a range that covers the whole group
+        - type :obj:`list` - custom ranges for each density plot.
+
+    x_range
+        Custom range for the x-axis, shared across all density plots. If :obj:`None`, set via ``'range_style'``.
+    title
+        Title of the plot.
+    colormap
+        Colormap to use.
+    color: :mpltype:`color`
+        Color of the density plots.
+    grid
+        Whether to show the grid.
+    return_fig
+        Whether to return the figure.
+    kwargs
+        Additional keyword arguments for the plot.
+
+    Returns
+    -------
+    :class:`matplotlib.figure.Figure` if ``'return_fig'`` is :obj:`True`, else :obj:`None`.
+
+    Raises
+    ------
+    TypeError
+        If ``'data'`` is not a :class:`pandas.DataFrame`.
+    ValueError
+        If no numeric values are found in the selected columns/groups.
+    """
+    if not isinstance(data, pd.DataFrame):
+        raise TypeError(f"Unknown type for 'data': {type(data)!r}")
+
+    # Normalize to a list so that ``data[feature]`` always yields a data frame;
+    # selecting with a bare string would return a series without ``columns``.
+    if isinstance(feature, str):
+        feature = [feature]
+    elif feature is not None:
+        feature = list(feature)
+
+    if group_by is not None:
+        if feature is None:
+            feature = [col for col in data.columns if col != group_by]
+        grouped = data.groupby(group_by)
+        converted, _labels, sublabels = _grouped_df_to_standard(grouped, feature)
+        if labels is None:
+            labels = _labels
+    else:
+        if feature is not None:
+            data = data[feature]
+        numeric_cols = [col for col in data.columns if _is_numeric(data[col])]
+        converted = [[_remove_na(data[col])] for col in numeric_cols]
+        # Only fall back to the column names if the user did not pass labels.
+        if labels is None:
+            labels = numeric_cols
+        sublabels = None
+
+    if not ylabels:
+        labels = None
+
+    if all(len(subg) == 0 for g in converted for subg in g):
+        raise ValueError(
+            "No numeric values found. Joyplot requires at least a numeric column/group."
+        )
+
+    if any(len(subg) == 0 for g in converted for subg in g):
+        _logging.logger.warning("At least a column/group has no numeric values.")
+
+    fig, _ = _joyplot(
+        converted,
+        labels=labels,
+        sublabels=sublabels,
+        grid=grid,
+        xlabelsize=xlabelsize,
+        xrot=xrot,
+        ylabelsize=ylabelsize,
+        yrot=yrot,
+        ax=ax,
+        dpi=dpi,
+        figsize=figsize,
+        hist=hist,
+        bins=bins,
+        fade=fade,
+        ylim=ylim,
+        fill=fill,
+        linecolor=linecolor,
+        overlap=overlap,
+        background=background,
+        xlabels=xlabels,
+        range_style=range_style,
+        x_range=x_range,
+        title=title,
+        colormap=colormap,
+        color=color,
+        density_fit=density_fit,
+        **kwargs,
+    )
+    return fig if return_fig else None
diff --git a/src/cfp/plotting/_utils.py b/src/cfp/plotting/_utils.py
index 1c161ce6..7d6749c9 100644
--- a/src/cfp/plotting/_utils.py
+++ b/src/cfp/plotting/_utils.py
@@ -1,15 +1,22 @@
 from collections.abc import Sequence
-from typing import Any
+from typing import Any, Literal
 
 import anndata as ad
+import matplotlib as mpl
+import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import scanpy as sc
 import seaborn as sns
+from pandas.core.dtypes.common import is_number
+from pandas.plotting._matplotlib.tools import create_subplots as _subplots
+from pandas.plotting._matplotlib.tools import flatten_axes as _flatten
+from scipy.stats import gaussian_kde
 from sklearn.decomposition import KernelPCA
 from sklearn.metrics.pairwise import cosine_similarity
 
 from cfp import _constants, _logging
+from cfp._types import ArrayLike
 
 
 def set_plotting_vars(
@@ -102,3 +109,422 @@ def _compute_kernel_pca_from_df(
         similarity_matrix
     )
     return pd.DataFrame(data=X, columns=list(range(n_components)), index=df.index)
+
+
+def _x_range(data: ArrayLike | list[float], extra: float = 0.2) -> ArrayLike:
+    """Compute the grid of x values on which a density is evaluated.
+
+    Returns an empty array for empty input, the two end points if the data is
+    (almost) constant, and otherwise 1000 evenly spaced points spanning the data
+    range widened by ``extra`` times the range on both sides.
+    """
+    try:
+        lower, upper = np.nanmin(data), np.nanmax(data)
+    except ValueError:
+        # numpy raises for zero-size input: there is nothing to plot.
+        return np.array([])
+    sample_range = upper - lower
+    if sample_range < 1e-6:
+        return np.array([lower, upper])
+    return np.linspace(
+        lower - extra * sample_range,
+        upper + extra * sample_range,
+        1000,
+    )
+
+
+def _setup_axis(
+    ax: plt.Axes,
+    x_range: ArrayLike,
+    col_name: str | None = None,
+    grid: bool = False,
+    ylabelsize: int | None = None,
+    yrot: int | None = None,
+) -> None:
+    """Set up the y tick, grid, background and x limits of a single joyplot axis."""
+    if col_name is not None:
+        ax.set_yticks([0])
+        ax.set_yticklabels([col_name], fontsize=ylabelsize, rotation=yrot)
+        ax.yaxis.grid(grid)
+    else:
+        ax.yaxis.set_visible(False)
+    # Transparent background so that overlapping axes do not hide each other.
+    ax.patch.set_alpha(0)
+    ax.set_xlim([x_range.min(), x_range.max()])
+    ax.tick_params(axis="both", which="both", length=0, pad=10)
+
+
+def _get_alpha(i: int, n: int, start: float = 0.4, end: float = 1.0) -> float:
+    """Compute the transparency of the ``i``-th out of ``n`` plots, increasing towards ``end``."""
+    return start + (1 + i) * (end - start) / n
+
+
+def _is_numeric(x: ArrayLike) -> bool:
+    """Whether every element of the array x is a number."""
+    return all(is_number(i) for i in x)
+
+
+def _remove_na(data: list[Any] | ArrayLike | pd.Series) -> ArrayLike:
+    """Remove NA values from the data."""
+    return pd.Series(data).dropna().values
+
+
+def _moving_average(a: ArrayLike, n: int = 3, zero_padded: bool = False) -> ArrayLike:
+    """Calculate the moving average of order n."""
+    ret = np.cumsum(a, dtype=float)
+    ret[n:] = ret[n:] - ret[:-n]
+    if zero_padded:
+        return ret / n
+    return ret[n - 1 :] / n
+
+
+def _grouped_df_to_standard(
+    grouped: pd.api.typing.DataFrameGroupBy, column: str | Sequence[str] | None
+) -> tuple[list[list[ArrayLike]], list[Any], list[Any]]:
+    """Convert a grouped data frame to the nested lists expected by :func:`_joyplot`.
+
+    Returns one list of NA-free arrays (one per numeric column) for every group,
+    the group keys, and the names of the numeric columns of the first group.
+    """
+    if isinstance(column, str):
+        # A bare string would select a series, which has no ``columns``.
+        column = [column]
+    converted = []
+    labels = []
+    # Stays empty (instead of unbound) if there are no groups.
+    sublabels: list[Any] = []
+    for i, (key, group) in enumerate(grouped):
+        if column is not None:
+            group = group[list(column)]
+        numeric_cols = [c for c in group.columns if _is_numeric(group[c])]
+        labels.append(key)
+        converted.append([_remove_na(group[c]) for c in numeric_cols])
+        if i == 0:
+            sublabels = numeric_cols
+    return converted, labels, sublabels
+
+
+def _joyplot(
+    data: list[list[ArrayLike]],
+    grid: bool | Literal["x", "y", "both"] = False,
+    density_fit: Literal["log1p", "raw"] = "raw",
+    labels: Sequence[Any] | None = None,
+    sublabels: Sequence[Any] | None = None,
+    xlabels: bool = True,
+    xlabelsize: float | None = None,
+    xrot: float | None = None,
+    ylabelsize: float | None = None,
+    yrot: float | None = None,
+    ax: mpl.axes.Axes | None = None,
+    dpi: int | None = None,
+    figsize: tuple[float, float] | None = None,
+    hist: bool = False,
+    bins: int = 10,
+    fade: bool = False,
+    ylim: Literal["max"] | tuple[float, float] | None = "max",
+    fill: bool = True,
+    linecolor: Any = None,
+    overlap: float = 1.0,
+    background: Any = None,
+    range_style: Literal["all", "individual", "group"] | list[float] = "all",
+    x_range: tuple[float, float] | None = None,
+    tails: float = 0.2,
+    title: str | None = None,
+    legend: bool = False,
+    loc: str | int = "upper right",
+    colormap: str | mpl.colors.Colormap | list[mpl.colors.Colormap] | None = None,
+    color: Any = None,
+    **kwargs,
+) -> tuple[plt.Figure, list[plt.Axes]]:
+    """Draw one (optionally overlapping) density plot or histogram per group.
+
+    ``data`` holds one list per axis, each containing one array per subgroup.
+    Adapted from https://github.com/leotac/joypy/blob/master/joypy/joyplot.py
+
+    Returns the figure and the list of axes; the last axis is a background axis
+    spanning the whole figure that carries the shared x ticks, grid and title.
+    """
+    if fill is True and linecolor is None:
+        linecolor = "k"
+
+    if sublabels is None:
+        legend = False
+
+    # Resolve named colormaps once so that they can be called like colormap objects.
+    if isinstance(colormap, str):
+        colormap = plt.get_cmap(colormap)
+
+    def _get_color(i, num_axes, j, num_subgroups):
+        if isinstance(color, list):
+            return color[j] if num_subgroups > 1 else color[i]
+        if color is not None:
+            return color
+        if isinstance(colormap, list):
+            return colormap[j](i / num_axes)
+        if colormap is None:
+            cycle_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
+            return cycle_colors[j % len(cycle_colors)]
+        return colormap(i / num_axes)  # type: ignore[operator]
+
+    ygrid = grid is True or grid in ("y", "both")
+    xgrid = grid is True or grid in ("x", "both")
+
+    num_axes = len(data)
+
+    if x_range is None:
+        global_x_range = _x_range([v for g in data for sg in g for v in sg])
+    else:
+        global_x_range = _x_range(x_range, 0.0)  # type: ignore[arg-type]
+
+    # Each plot will have its own axis
+    fig, axes = _subplots(
+        naxes=num_axes,
+        dpi=dpi,
+        ax=ax,
+        squeeze=False,
+        sharex=True,
+        sharey=False,
+        figsize=figsize,
+        layout_type="vertical",
+    )
+    # Materialize the axes: they are indexed and iterated several times below.
+    _axes = list(_flatten(axes))
+
+    # The legend must be drawn in the last axis if we want it at the bottom.
+    if loc in (3, 4, 8) or "lower" in str(loc):
+        legend_axis = num_axes - 1
+    else:
+        legend_axis = 0
+
+    # Validate the inputs explicitly; assertions are stripped when running with ``-O``.
+    if labels is not None and len(labels) != num_axes:
+        raise ValueError(f"Expected {num_axes} labels, found {len(labels)}.")
+    if sublabels is not None and any(len(g) != len(sublabels) for g in data):
+        raise ValueError("Every group must contain one entry per sublabel.")
+    if isinstance(color, list) and any(len(g) > len(color) for g in data):
+        raise ValueError("Not enough colors for the number of subgroups.")
+    if isinstance(colormap, list) and any(len(g) != len(colormap) for g in data):
+        raise ValueError("Expected one colormap per subgroup.")
+
+    for i, group in enumerate(data):
+        a = _axes[i]
+        group_zorder = i
+        if fade:
+            kwargs["alpha"] = _get_alpha(i, num_axes)
+
+        num_subgroups = len(group)
+
+        if hist:
+            # matplotlib hist() already handles multiple subgroups in a histogram
+            a.hist(
+                group,
+                label=sublabels,
+                bins=bins,
+                color=color,
+                range=[min(global_x_range), max(global_x_range)],
+                edgecolor=linecolor,
+                zorder=group_zorder,
+                **kwargs,
+            )
+        else:
+            for j, subgroup in enumerate(group):
+                # Compute the x range of the current plot
+                if isinstance(range_style, list):
+                    # All plots have exactly the range passed as argument
+                    plot_x_range = _x_range(range_style, 0.0)
+                elif range_style == "all":
+                    # All plots have the same range
+                    plot_x_range = global_x_range
+                elif range_style == "individual":
+                    # Each plot has its own range
+                    plot_x_range = _x_range(subgroup, tails)
+                elif range_style == "group":
+                    # Each plot has a range that covers the whole group. The subgroups
+                    # are flattened first since they may differ in length.
+                    plot_x_range = _x_range([v for sg in group for v in sg], tails)
+                else:
+                    raise NotImplementedError("Unrecognized range style.")
+
+                sublabel = None if sublabels is None else sublabels[j]
+                element_zorder = group_zorder + j / (num_subgroups + 1)
+                element_color = _get_color(i, num_axes, j, num_subgroups)
+
+                _plot_density(
+                    a,
+                    plot_x_range,
+                    subgroup,
+                    fill=fill,
+                    linecolor=linecolor,
+                    label=sublabel,
+                    zorder=element_zorder,
+                    color=element_color,
+                    bins=bins,
+                    density_fit=density_fit,
+                    **kwargs,
+                )
+
+        # Setup the current axis: transparency, labels, spines.
+        col_name = None if labels is None else labels[i]
+        _setup_axis(
+            a,
+            global_x_range,
+            col_name=col_name,
+            grid=ygrid,
+            ylabelsize=ylabelsize,
+            yrot=yrot,
+        )
+
+        # When needed, draw the legend
+        if legend and i == legend_axis:
+            a.legend(loc=loc)
+            # Bypass alpha values, in case
+            for p in a.get_legend().get_patches():
+                p.set_facecolor(p.get_facecolor())
+                p.set_alpha(1.0)
+            for line in a.get_legend().get_lines():
+                line.set_alpha(1.0)
+
+    # Final adjustments
+
+    # Set the y limit for the density plots.
+    # Since the y range in the subplots can vary significantly,
+    # different options are available.
+    if ylim == "max":
+        # Set all yaxis limit to the same value (max range among all)
+        max_ylim = max(a.get_ylim()[1] for a in _axes)
+        min_ylim = min(a.get_ylim()[0] for a in _axes)
+        for a in _axes:
+            a.set_ylim([min_ylim - 0.1 * (max_ylim - min_ylim), max_ylim])
+
+    elif ylim is not None:
+        # Set all yaxis lim to the argument value ylim;
+        # if ylim is None, each axis keeps its own ylim.
+        try:
+            for a in _axes:
+                a.set_ylim(ylim)
+        except ValueError:
+            raise ValueError(
+                "The value of ylim must be either 'max', None, or a tuple of length 2."
+            ) from None
+
+    # Compute a final axis, used to apply global settings
+    last_axis = fig.add_subplot(1, 1, 1)
+
+    # Background color
+    if background is not None:
+        last_axis.patch.set_facecolor(background)
+
+    # This looks hacky, but all the axes share the x-axis,
+    # so they have the same lims and ticks
+    last_axis.set_xlim(_axes[0].get_xlim())
+    if xlabels is True:
+        last_axis.set_xticks(np.array(_axes[0].get_xticks()[1:-1]))
+        for t in last_axis.get_xticklabels():
+            t.set_visible(True)
+            t.set_fontsize(xlabelsize)
+            t.set_rotation(xrot)
+
+        # If grid is enabled, do not allow xticks (they are ugly)
+        if xgrid:
+            last_axis.tick_params(axis="both", which="both", length=0)
+    else:
+        last_axis.xaxis.set_visible(False)
+
+    last_axis.yaxis.set_visible(False)
+    last_axis.grid(xgrid)
+
+    # Last axis on the back
+    last_axis.set_zorder(min(a.zorder for a in _axes) - 1)
+    _axes.append(last_axis)
+
+    if title is not None:
+        # Set the title on the background axis explicitly rather than on whatever
+        # axis pyplot currently considers active.
+        last_axis.set_title(title)
+
+    # The magic overlap happens here.
+    h_pad = 5 + (-5 * (1 + overlap))
+    fig.tight_layout(h_pad=h_pad)
+
+    return fig, _axes
+
+
+def _plot_density(
+    ax: plt.Axes,
+    x_range: ArrayLike,
+    v: ArrayLike,
+    bw_method=None,
+    fill: bool = False,
+    linecolor: Any = None,
+    density_fit: Literal["log1p", "raw"] = "raw",
+    clip_on: bool = True,
+    bins: int = 10,
+    **kwargs,
+) -> plt.Axes:
+    """Draw the kernel density estimate of ``v``, evaluated on ``x_range``, on ``ax``.
+
+    ``bins`` is accepted (``_joyplot`` always passes it) but is not used for
+    density plots. Nothing is drawn if ``v`` (after removing NA values) or
+    ``x_range`` is empty.
+
+    Raises
+    ------
+    ValueError
+        If ``density_fit`` is neither "log1p" nor "raw".
+    """
+    # Validate outside of the ``try`` block below, which swallows ``ValueError``.
+    if density_fit not in ("log1p", "raw"):
+        raise ValueError("density_fit must be either 'log1p' or 'raw'.")
+
+    v = _remove_na(v)
+    if len(v) == 0 or len(x_range) == 0:
+        return ax
+
+    try:
+        y = gaussian_kde(v, bw_method=bw_method).evaluate(x_range)
+    except ValueError:
+        # Handle cases where the density cannot be estimated from the data in a group.
+        y = np.zeros_like(x_range)
+    except np.linalg.LinAlgError:
+        # Handle singular matrix in kde computation.
+        distinct_values = np.unique(v)
+        if len(distinct_values) != 1:
+            raise
+        # In case of a group with a single value val,
+        # that should have infinite density,
+        # return a δ(val)
+        val = distinct_values[0]
+        _logging.logger.warning(
+            f"The data contains a group with a single distinct value ({val}) "
+            "having infinite probability density. "
+            "Consider using a different visualization."
+        )
+
+        # Find index i of x_range such that x_range[i-1] < val ≤ x_range[i],
+        # clamped to the last position if val lies beyond the range.
+        i = min(int(np.searchsorted(x_range, val)), len(x_range) - 1)
+
+        y = np.zeros_like(x_range)
+        y[i] = 1
+    else:
+        if density_fit == "log1p":
+            y = np.log1p(y)
+
+    if fill:
+        ax.fill_between(x_range, 0.0, y, clip_on=clip_on, **kwargs)
+
+        # Remove the legend label from the line plots: we only want one entry
+        # per group in the legend (if shown), which is the filled patch.
+        kwargs["label"] = None
+
+        # Hack to have a border at the bottom at the fill patch
+        # (of the same color of the fill patch)
+        # so that the fill reaches the same bottom margin as the edge lines
+        # with y value = 0.0
+        ax.plot(x_range, np.zeros(len(x_range)), clip_on=clip_on, **kwargs)
+
+    if linecolor is not None:
+        kwargs["color"] = linecolor
+
+    ax.plot(x_range, y, clip_on=clip_on, **kwargs)
+    return ax
diff --git a/tests/plotting/conftest.py b/tests/plotting/conftest.py
index b74dbbe2..9d0609ca 100644
--- a/tests/plotting/conftest.py
+++ b/tests/plotting/conftest.py
@@ -27,3 +27,10 @@ def adata_with_condition_embedding(adata_perturbation) -> ad.AnnData:
     adata_perturbation.uns[_constants.CFP_KEY] = {}
     adata_perturbation.uns[_constants.CFP_KEY][_constants.CONDITION_EMBEDDING] = df
     return adata_perturbation
+
+
+@pytest.fixture
+def df_joyplot() -> pd.DataFrame:
+    test_df = pd.DataFrame(data=np.random.randn(100, 2), columns=["x", "y"])
+    test_df["dosage"] = np.random.choice([0.0, 1.0, 2.0], 100)
+    return test_df
diff --git a/tests/plotting/test_plotting.py b/tests/plotting/test_plotting.py
index a1dfe89e..e4ceb460 100644
--- a/tests/plotting/test_plotting.py
+++ b/tests/plotting/test_plotting.py
@@ -1,10 +1,10 @@
 import matplotlib.pyplot as plt
 import pytest
 
-from cfp.plotting import plot_condition_embedding
+from cfp.plotting import plot_condition_embedding, plot_densities
 
 
-class TestCallbacks:
+class TestPlotConditionEmbedding:
     @pytest.mark.parametrize(
         "embedding", ["raw_embedding", "UMAP", "PCA", "Kernel_PCA"]
     )
@@ -26,3 +26,12 @@ def test_plot_embeddings(
         )
 
         assert isinstance(fig, plt.Figure)
+
+
+class TestJoyPlot:
+    @pytest.mark.parametrize("features", [["x"], "x"])
+    @pytest.mark.parametrize("group_by", ["dosage", None])
+    @pytest.mark.parametrize("hist", [True, False])
+    def test_plot_joyplot(self, df_joyplot, features, group_by, hist):
+        fig = plot_densities(df_joyplot, feature=features, group_by=group_by, hist=hist)
+        assert isinstance(fig, plt.Figure)