From a123971f5a0c6e0318aa6963bcc4c43dddcccbdf Mon Sep 17 00:00:00 2001 From: Andrew Mitchell Date: Sat, 10 May 2025 01:46:23 +0100 Subject: [PATCH 1/8] Add stub methods to ISOPlotLayersMixin for add_layer and get_single_axes to resolve unresolved method warnings --- src/soundscapy/plotting/iso_plot.py | 699 +------------------- src/soundscapy/plotting/iso_plot_layers.py | 374 +++++++++++ src/soundscapy/plotting/iso_plot_styling.py | 344 ++++++++++ 3 files changed, 724 insertions(+), 693 deletions(-) create mode 100644 src/soundscapy/plotting/iso_plot_layers.py create mode 100644 src/soundscapy/plotting/iso_plot_styling.py diff --git a/src/soundscapy/plotting/iso_plot.py b/src/soundscapy/plotting/iso_plot.py index 4fa6f60..f40857c 100644 --- a/src/soundscapy/plotting/iso_plot.py +++ b/src/soundscapy/plotting/iso_plot.py @@ -4,7 +4,7 @@ Example: ------- >>> from soundscapy import isd, surveys ->>> from soundscapy.plotting.iso_plot_new import ISOPlot +>>> from soundscapy.plotting import ISOPlot >>> df = isd.load() >>> df = surveys.add_iso_coords(df) >>> sub_df = isd.select_location_ids(df, ['CamdenTown', 'RegentsParkJapan']) @@ -30,30 +30,19 @@ import warnings from typing import TYPE_CHECKING, Any +import matplotlib.pyplot as plt import numpy as np import pandas as pd -import seaborn as sns -from matplotlib import pyplot as plt -from matplotlib import ticker -from matplotlib.artist import Artist from matplotlib.axes import Axes from matplotlib.figure import Figure, SubFigure from soundscapy.plotting.defaults import ( - DEFAULT_STYLE_PARAMS, DEFAULT_XCOL, DEFAULT_YCOL, RECOMMENDED_MIN_SAMPLES, ) -from soundscapy.plotting.layers import ( - DensityLayer, - Layer, - ScatterLayer, - SimpleDensityLayer, - SPIDensityLayer, - SPIScatterLayer, - SPISimpleLayer, -) +from soundscapy.plotting.iso_plot_layers import ISOPlotLayersMixin +from soundscapy.plotting.iso_plot_styling import ISOPlotStylingMixin from soundscapy.plotting.plot_context import PlotContext from soundscapy.plotting.plotting_types import ( ParamModel, @@ -65,16 +54,13 @@ if TYPE_CHECKING: from collections.abc import Generator + from soundscapy.plotting.layers import Layer from soundscapy.plotting.plotting_types import SeabornPaletteType - from soundscapy.spi.msn import ( - CentredParams, - DirectParams, - ) logger = get_logger() -class ISOPlot: +class ISOPlot(ISOPlotLayersMixin, ISOPlotStylingMixin): """ A class for creating circumplex plots using different backends. @@ -1096,676 +1082,3 @@ def _resolve_axis_indices( return on_axis msg = f"Invalid axis specification: {on_axis}" raise ValueError(msg) - - def add_scatter( - self, - data: pd.DataFrame | None = None, - *, - on_axis: int | tuple[int, int] | list[int] | None = None, - **params: Any, - ) -> ISOPlot: - """ - Add a scatter layer to specific subplot(s). - - Parameters - ---------- - on_axis : int | tuple[int, int] | list[int] | None, optional - Target specific axis/axes - data : pd.DataFrame, optional - Custom data for this specific scatter plot - **params : dict - Parameters for the scatter plot - - Returns - ------- - ISOPlot - The current plot instance for chaining - - Examples - -------- - Add a scatter layer to all subplots: - - >>> import pandas as pd - >>> import numpy as np - >>> rng = np.random.default_rng(42) - >>> data = pd.DataFrame( - ... np.c_[rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), - ... rng.integers(1, 3, 100)], - ... columns=['ISOPleasant', 'ISOEventful', 'Group']) - >>> plot = (ISOPlot(data=data) - ... .create_subplots(nrows=2, ncols=1) - ... .add_scatter(s=50, alpha=0.7, hue='Group') - ... .apply_styling()) - >>> plot.show() # xdoctest: +SKIP - >>> all(len(ctx.layers) == 1 for ctx in plot.subplot_contexts) - True - >>> plot.close() # Clean up - - Add a scatter layer with custom data to a specific subplot: - - >>> custom_data = pd.DataFrame({ - ... 'ISOPleasant': rng.normal(0.2, 0.1, 50), - ... 'ISOEventful': rng.normal(0.15, 0.2, 50), - ... }) - >>> plot = (ISOPlot(data=data) - ... .create_subplots(nrows=2, ncols=1) - ... .add_scatter(hue='Group') - ... .add_scatter(on_axis=0, data=custom_data, color='red') - ... .apply_styling()) - >>> plot.show() # xdoctest: +SKIP - >>> plot.subplot_contexts[0].layers[1].custom_data is custom_data - True - >>> plot.close() # Clean up - - """ - # Merge default scatter parameters with provided ones - # Remove data from scatter_params to avoid conflict - scatter_params = self._scatter_params.model_copy().drop("data").update(**params) - - return self.add_layer( - ScatterLayer, data=data, on_axis=on_axis, **scatter_params.as_dict() - ) - - def add_spi( - self, - on_axis: int | tuple[int, int] | list[int] | None = None, - spi_target_data: pd.DataFrame | np.ndarray | None = None, - msn_params: DirectParams | CentredParams | None = None, - *, - layer_class: type[Layer] = SPISimpleLayer, - **params: Any, - ) -> ISOPlot: - """ - Add a SPI layer to specific subplot(s). - - Parameters - ---------- - on_axis : int | tuple[int, int] | list[int] | None, optional - Target specific axis/axes - spi_target_data : pd.DataFrame | np.ndarray | None, optional - Custom data for this specific SPI plot - msn_params : DirectParams | CentredParams | None, optional - Parameters for the SPI plot - - Returns - ------- - ISOPlot - The current plot instance for chaining - - Examples - -------- - Add a SPI layer to all subplots: - - >>> import pandas as pd - >>> import numpy as np - >>> from soundscapy.spi import DirectParams - >>> rng = np.random.default_rng(42) - >>> # Create a DataFrame with random data - >>> data = pd.DataFrame( - ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), - ... columns=['ISOPleasant', 'ISOEventful'] - ... ) - >>> # Define MSN parameters for the SPI target - >>> msn_params = DirectParams( - ... xi=np.array([0.5, 0.7]), - ... omega=np.array([[0.1, 0.05], [0.05, 0.1]]), - ... alpha=np.array([0, -5]), - ... ) - >>> # Create the plot with only an SPI layer - >>> plot = ( - ... ISOPlot(data=data) - ... .create_subplots() - ... .add_scatter() - ... .add_spi(msn_params=msn_params) - ... .apply_styling() - ... ) - >>> plot.show() # xdoctest: +SKIP - >>> len(plot.subplot_contexts[0].layers) == 2 - True - >>> plot.close() # Clean up - - Add an SPI layer over top of 'real' data: - >>> plot = ( - ... ISOPlot(data=data) - ... .create_subplots() - ... .add_scatter() - ... .add_density() - ... .add_spi(msn_params=msn_params, show_score="on axis") - ... .apply_styling() - ... ) - >>> plot.show() # xdoctest: +SKIP - >>> len(plot.subplot_contexts[0].layers) == 3 - True - >>> plot.close() # Clean up - - """ - if layer_class == SPISimpleLayer: - spi_simple_params = ( - self._spi_simple_density_params.model_copy() - .drop("data") - .update(**params) - ) - - return self.add_layer( - layer_class, - on_axis=on_axis, - msn_params=msn_params, - spi_target_data=spi_target_data, - **spi_simple_params.as_dict(), - ) - if layer_class in (SPIDensityLayer, SPIScatterLayer): - msg = ( - "Only the simple density layer type is currently supported for " - "SPI plots. Please use SPISimpleLayer" - ) - raise NotImplementedError(msg) - - msg = "Invalid layer class provided. Expected SPISimpleLayer. " - raise ValueError(msg) - - def add_density( - self, - on_axis: int | tuple[int, int] | list[int] | None = None, - data: pd.DataFrame | None = None, - *, - include_outline: bool = False, - **params: Any, - ) -> ISOPlot: - """ - Add a density layer to specific subplot(s). - - Parameters - ---------- - on_axis : int | tuple[int, int] | list[int] | None, optional - Target specific axis/axes - data : pd.DataFrame, optional - Custom data for this specific density plot - include_outline : bool, optional - Whether to include an outline around the density plot, by default False - **params : dict - Parameters for the density plot - - Returns - ------- - ISOPlot - The current plot instance for chaining - - Examples - -------- - Add a density layer to all subplots: - - >>> import pandas as pd - >>> import numpy as np - >>> rng = np.random.default_rng(42) - >>> data = pd.DataFrame({ - ... 'ISOPleasant': rng.normal(0.2, 0.25, 50), - ... 'ISOEventful': rng.normal(0.15, 0.4, 50), - ... }) - >>> plot = ( - ... ISOPlot(data=data) - ... .create_subplots() - ... .add_density() - ... .apply_styling() - ... ) - >>> plot.show() # xdoctest: +SKIP - >>> len(plot.subplot_contexts[0].layers) == 1 - True - >>> plot.close() # Clean up - - Add a density layer with custom settings: - - >>> plot = ( - ... ISOPlot(data=data) - ... .create_subplots() - ... .add_density(levels=5, alpha=0.7) - ... .apply_styling() - ... ) - >>> plot.show() # xdoctest: +SKIP - >>> len(plot.subplot_contexts[0].layers) == 1 - True - >>> plot.close() # Clean up - - """ - # Merge default density parameters with provided ones - density_params = self._density_params.model_copy().drop("data").update(**params) - - return self.add_layer( - DensityLayer, - data=data, - on_axis=on_axis, - include_outline=include_outline, - **density_params.as_dict(), - ) - - def add_simple_density( - self, - on_axis: int | tuple[int, int] | list[int] | None = None, - data: pd.DataFrame | None = None, - *, - include_outline: bool = True, - **params: Any, - ) -> ISOPlot: - """ - Add a simple density layer to specific subplot(s). - - Parameters - ---------- - on_axis : int | tuple[int, int] | list[int] | None, optional - Target specific axis/axes - data : pd.DataFrame, optional - Custom data for this specific density plot - thresh : float, optional - Threshold for density contours, by default 0.5 - levels : int | Iterable[float], optional - Contour levels, by default 2 - alpha : float, optional - Transparency level, by default 0.5 - include_outline : bool, optional - Whether to include an outline around the density plot, by default True - **params : dict - Additional parameters for the density plot - - Returns - ------- - ISOPlot - The current plot instance for chaining - - Examples - -------- - Add a simple density layer: - - >>> import pandas as pd - >>> import numpy as np - >>> rng = np.random.default_rng(42) - >>> data = pd.DataFrame({ - ... 'ISOPleasant': rng.normal(0.2, 0.25, 30), - ... 'ISOEventful': rng.normal(0.15, 0.4, 30), - ... }) - >>> plot = ( - ... ISOPlot(data=data) - ... .create_subplots() - ... .add_scatter() - ... .add_simple_density() - ... .apply_styling() - ... ) - >>> plot.show() # xdoctest: +SKIP - >>> len(plot.subplot_contexts[0].layers) == 2 - True - >>> plot.close() # Clean up - - Add a simple density with splitting by group: - >>> data = pd.DataFrame( - ... np.c_[rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), - ... rng.integers(1, 3, 100)], - ... columns=['ISOPleasant', 'ISOEventful', 'Group']) - >>> plot = ( - ... ISOPlot(data=data, hue='Group') - ... .create_subplots() - ... .add_scatter() - ... .add_simple_density() - ... .apply_styling() - ... ) - >>> plot.show() # xdoctest: +SKIP - >>> len(plot.subplot_contexts[0].layers) == 2 - True - >>> plot.close() - ... - - """ - # Merge default simple density parameters with provided ones - simple_density_params = ( - self._simple_density_params.model_copy().drop("data").update(**params) - ) - - return self.add_layer( - SimpleDensityLayer, - on_axis=on_axis, - data=data, - include_outline=include_outline, - **simple_density_params.as_dict(), - ) - - def add_annotation( - self, - text: str, - xy: tuple[float, float], - xytext: tuple[float, float], - arrowprops: dict[str, Any] | None = None, - ) -> ISOPlot: - """ - Add an annotation to the plot. - - Parameters - ---------- - text : str - The text to display in the annotation. - xy : tuple[float, float] - The point to annotate. - xytext : tuple[float, float] - The point at which to place the text. - arrowprops : dict[str, Any] | None, optional - Properties for the arrow connecting the annotation text to the point. - - Returns - ------- - ISOPlot - The current plot instance for chaining - - """ - msg = "AnnotationLayer is not yet implemented. " - raise NotImplementedError(msg) - # TODO(MitchellAcoustics): Implement AnnotationLayer # noqa: TD003 - return self.add_layer( - "AnnotationLayer", - text=text, - xy=xy, - xytext=xytext, - arrowprops=arrowprops, - ) - - def apply_styling( - self, - **kwargs: Any, - ) -> ISOPlot: - """ - Apply styling to the plot. - - Parameters - ---------- - **kwargs: Styling parameters to override defaults - - Returns - ------- - ISOPlot - The current plot instance for chaining - - Examples - -------- - Apply styling with default parameters: - - >>> import pandas as pd - >>> import numpy as np - >>> rng = np.random.default_rng(42) - >>> # Create simple data for styling example - >>> data = pd.DataFrame( - ... np.c_[rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), - ... rng.integers(1, 3, 100)], - ... columns=['ISOPleasant', 'ISOEventful', 'Group']) - >>> # Create plot with default styling - >>> plot = ( - ... ISOPlot(data=data) - ... .create_subplots() - ... .add_scatter() - ... .apply_styling() - ... ) - >>> plot.show() # xdoctest: +SKIP - >>> plot.get_figure() is not None - True - >>> plot.close() # Clean up - - Apply styling with custom parameters: - - >>> plot = ( - ... ISOPlot(data=data) - ... .create_subplots() - ... .add_scatter() - ... .apply_styling(xlim=(-2, 2), ylim=(-2, 2), primary_lines=False) - ... ) - >>> plot.show() # xdoctest: +SKIP - >>> plot.get_figure() is not None - True - >>> plot.close() # Clean up - - Demonstrate the fluent interface (method chaining): - - >>> # Create plot with method chaining - >>> plot = ( - ... ISOPlot(data=data) - ... .create_subplots(nrows=1, ncols=1) - ... .add_scatter(alpha=0.7) - ... .add_density(levels=5) - ... .apply_styling(title_fontsize=14) - ... ) - >>> plot.show() # xdoctest: +SKIP - >>> # Verify results - >>> isinstance(plot, ISOPlot) - True - >>> plot.close() # Clean up - - """ - self._style_params.update(**kwargs) - self._check_for_axes() - - self._set_style() - self._circumplex_grid() - self._set_title() - self._set_axes_titles() - self._primary_labels() - if self._style_params.get("primary_lines"): - self._primary_lines() - if self._style_params.get("diagonal_lines"): - self._diagonal_lines_and_labels() - - if self._style_params.get("legend_loc") is not False: - self._move_legend() - - return self - - def _set_style(self) -> None: - """Set the overall style for the plot.""" - sns.set_style({"xtick.direction": "in", "ytick.direction": "in"}) - - def _circumplex_grid(self) -> ISOPlot: - """Add the circumplex grid to the plot.""" - for _, axis in enumerate(self.yield_axes_objects()): - axis.set_xlim(self._style_params.get("xlim")) - axis.set_ylim(self._style_params.get("ylim")) - axis.set_aspect("equal") - - axis.get_yaxis().set_minor_locator(ticker.AutoMinorLocator()) - axis.get_xaxis().set_minor_locator(ticker.AutoMinorLocator()) - - axis.grid(visible=True, which="major", color="grey", alpha=0.5) - axis.grid( - visible=True, - which="minor", - color="grey", - linestyle="dashed", - linewidth=0.5, - alpha=0.4, - zorder=self._style_params.get("prim_lines_zorder"), - ) - - return self - - def _set_title(self) -> ISOPlot: - """Set the title of the plot.""" - if self.title and self._has_subplots: - figure = self.get_figure() - figure.suptitle( - self.title, fontsize=self._style_params.get("title_fontsize") - ) - elif self.title and not self._has_subplots: - axis = self.get_single_axes() - if axis.get_title() == "": - axis.set_title( - self.title, fontsize=self._style_params.get("title_fontsize") - ) - else: - figure = self.get_figure() - figure.suptitle( - self.title, fontsize=self._style_params.get("title_fontsize") - ) - return self - - def _set_axes_titles(self) -> ISOPlot: - """Set the titles of the subplots.""" - for context in self.subplot_contexts: - if context.ax and context.title: - context.ax.set_title(context.title) - return self - - def _primary_lines(self) -> ISOPlot: - """Add primary lines to the plot.""" - for _, axis in enumerate(self.yield_axes_objects()): - axis.axhline( - y=0, - color="grey", - linestyle="dashed", - alpha=1, - lw=self._style_params.get("linewidth"), - zorder=self._style_params.get("prim_lines_zorder"), - ) - axis.axvline( - x=0, - color="grey", - linestyle="dashed", - alpha=1, - lw=self._style_params.get("linewidth"), - zorder=self._style_params.get("prim_lines_zorder"), - ) - return self - - def _primary_labels(self) -> ISOPlot: - """Handle the default labels for the x and y axes.""" - xlabel = self._style_params.get("xlabel") - ylabel = self._style_params.get("ylabel") - - xlabel = self.x if xlabel is None else xlabel - ylabel = self.y if ylabel is None else ylabel - fontdict = self._style_params.get("prim_ax_fontdict") - - # BUG: For some reason, this ruins the sharex and sharey - # functionality, but only when a layer is applied - # a specific subplot. - for _, axis in enumerate(self.yield_axes_objects()): - axis.set_xlabel( - xlabel, fontdict=fontdict - ) if xlabel is not False else axis.xaxis.label.set_visible(False) - - axis.set_ylabel( - ylabel, fontdict=fontdict - ) if ylabel is not False else axis.yaxis.label.set_visible(False) - - return self - - def _diagonal_lines_and_labels(self) -> ISOPlot: - """ - Add diagonal lines and labels to the plot. - - Examples - -------- - >>> import pandas as pd - >>> import numpy as np - >>> rng = np.random.default_rng(42) - >>> data = pd.DataFrame( - ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), - ... columns=['ISOPleasant', 'ISOEventful']) - >>> # Create a plot with diagonal lines and labels - >>> plot = ( - ... ISOPlot(data=data) - ... .create_subplots() - ... .add_scatter() - ... .apply_styling(diagonal_lines=True) - ... ) - >>> plot.show() # xdoctest: +SKIP - >>> plot.close('all') - - """ - for _, axis in enumerate(self.yield_axes_objects()): - xlim = self._style_params.get("xlim", DEFAULT_STYLE_PARAMS["xlim"]) - ylim = self._style_params.get("ylim", DEFAULT_STYLE_PARAMS["ylim"]) - axis.plot( - xlim, - ylim, - linestyle="dashed", - color="grey", - alpha=0.5, - lw=self._style_params.get("linewidth"), - zorder=self._style_params.get("diag_lines_zorder"), - ) - logger.debug("Plotting diagonal line for axis.") - axis.plot( - xlim, - ylim[::-1], - linestyle="dashed", - color="grey", - alpha=0.5, - lw=self._style_params.get("linewidth"), - zorder=self._style_params.get("diag_lines_zorder"), - ) - - diag_ax_font = { - "fontstyle": "italic", - "fontsize": "small", - "fontweight": "bold", - "color": "black", - "alpha": 0.5, - } - axis.text( - xlim[1] / 2, - ylim[1] / 2, - "(vibrant)", - ha="center", - va="center", - fontdict=diag_ax_font, - zorder=self._style_params.get("diag_labels_zorder"), - ) - axis.text( - xlim[0] / 2, - ylim[1] / 2, - "(chaotic)", - ha="center", - va="center", - fontdict=diag_ax_font, - zorder=self._style_params.get("diag_labels_zorder"), - ) - axis.text( - xlim[0] / 2, - ylim[0] / 2, - "(monotonous)", - ha="center", - va="center", - fontdict=diag_ax_font, - zorder=self._style_params.get("diag_labels_zorder"), - ) - axis.text( - xlim[1] / 2, - ylim[0] / 2, - "(calm)", - ha="center", - va="center", - fontdict=diag_ax_font, - zorder=self._style_params.get("diag_labels_zorder"), - ) - return self - - def _move_legend(self) -> ISOPlot: - """Move the legend to the specified location.""" - for i, axis in enumerate(self.yield_axes_objects()): - old_legend = axis.get_legend() - if old_legend is None: - # logger.debug("_move_legend: No legend found for axis %s", i) - continue - - # Get handles and filter out None values - handles = [ - h for h in old_legend.legend_handles if isinstance(h, Artist | tuple) - ] - # Skip if no valid handles remain - if not handles: - continue - - labels = [t.get_text() for t in old_legend.get_texts()] - title = old_legend.get_title().get_text() - # Ensure labels and handles match in length - if len(handles) != len(labels): - labels = labels[: len(handles)] - - axis.legend( - handles, - labels, - loc=self._style_params.get("legend_loc"), - title=title, - ) - return self diff --git a/src/soundscapy/plotting/iso_plot_layers.py b/src/soundscapy/plotting/iso_plot_layers.py new file mode 100644 index 0000000..b2f730c --- /dev/null +++ b/src/soundscapy/plotting/iso_plot_layers.py @@ -0,0 +1,374 @@ +""" +Layer-specific methods for ISOPlot. + +This module provides a mixin class with methods for adding different types of +visualization layers to ISOPlot instances. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np +import pandas as pd + +from soundscapy.plotting.layers import ( + DensityLayer, + Layer, + ScatterLayer, + SimpleDensityLayer, + SPISimpleLayer, +) + +if TYPE_CHECKING: + from soundscapy.spi.msn import ( + CentredParams, + DirectParams, + ) + + +class ISOPlotLayersMixin: + """Mixin providing layer-specific methods for ISOPlot.""" + + def add_layer( + self, + layer_class: type[Layer], + data: pd.DataFrame | None = None, + *, + on_axis: int | tuple[int, int] | list[int] | None = None, + **params: Any, + ) -> Any: + """ + Add a visualization layer, optionally targeting specific subplot(s). + + This is a stub method that should be implemented by classes using this mixin. + The actual implementation is in the ISOPlot class. + + Parameters + ---------- + layer_class : Layer subclass + The type of layer to add + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes + data : pd.DataFrame, optional + Custom data for this specific layer + **params : dict + Parameters for the layer + + Returns + ------- + Any + The current plot instance for chaining + """ + raise NotImplementedError( + "Classes using ISOPlotLayersMixin must implement add_layer" + ) + + def get_single_axes(self, ax_idx: int | tuple[int, int] | None = None) -> Any: + """ + Get a specific axes object. + + This is a stub method that should be implemented by classes using this mixin. + The actual implementation is in the ISOPlot class. + + Parameters + ---------- + ax_idx : int | tuple[int, int] | None, optional + The index of the axes to get. If None, returns the first axes. + Can be an integer for flattened access or a tuple of (row, col). + + Returns + ------- + Any + The requested matplotlib Axes object + """ + raise NotImplementedError( + "Classes using ISOPlotLayersMixin must implement get_single_axes" + ) + + def add_scatter( + self, + data: pd.DataFrame | None = None, + *, + on_axis: int | tuple[int, int] | list[int] | None = None, + **params: Any, + ) -> Any: + """ + Add a scatter layer to the plot. + + Parameters + ---------- + data : pd.DataFrame | None, optional + Custom data for this specific layer, by default None + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes: + - int: Index of subplot (flattened) + - tuple: (row, col) coordinates + - list: Multiple indices to apply the layer to + - None: Apply to all subplots (default) + **params : Any + Additional parameters for the scatter layer + + Returns + ------- + ISOPlot + The current plot instance for chaining + + Examples + -------- + >>> import pandas as pd + >>> import numpy as np + >>> from soundscapy.plotting import ISOPlot + >>> rng = np.random.default_rng(42) + >>> data = pd.DataFrame({ + ... 'ISOPleasant': rng.normal(0, 0.5, 100), + ... 'ISOEventful': rng.normal(0, 0.5, 100), + ... 'Group': rng.integers(1, 3, 100) + ... }) + >>> plot = (ISOPlot(data=data) + ... .add_scatter(hue='Group') + ... .apply_styling()) + >>> plot.show() # xdoctest: +SKIP + >>> plot.close() # Clean up + + """ + return self.add_layer(ScatterLayer, data=data, on_axis=on_axis, **params) + + def add_spi( + self, + on_axis: int | tuple[int, int] | list[int] | None = None, + spi_target_data: pd.DataFrame | np.ndarray | None = None, + msn_params: DirectParams | CentredParams | None = None, + *, + layer_class: type[Layer] = SPISimpleLayer, + **params: Any, + ) -> Any: + """ + Add an SPI (Soundscape Perception Index) layer to the plot. + + Parameters + ---------- + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes: + - int: Index of subplot (flattened) + - tuple: (row, col) coordinates + - list: Multiple indices to apply the layer to + - None: Apply to all subplots (default) + spi_target_data : pd.DataFrame | np.ndarray | None, optional + Pre-sampled data for SPI target distribution, by default None + msn_params : DirectParams | CentredParams | None, optional + Parameters to generate SPI data if no spi_target_data is provided, by default None + layer_class : type[Layer], optional + The type of SPI layer to add, by default SPISimpleLayer + **params : Any + Additional parameters for the SPI layer + + Returns + ------- + ISOPlot + The current plot instance for chaining + + Notes + ----- + Either spi_target_data or msn_params must be provided, but not both. + The test data for SPI calculations will be retrieved from the plot context. + + """ + # Validate that we have either spi_target_data or msn_params + if spi_target_data is None and msn_params is None: + msg = ( + "No data provided for SPI plot. " + "Please provide either spi_target_data or msn_params." + ) + raise ValueError(msg) + + if spi_target_data is not None and msn_params is not None: + msg = ( + "Please provide either spi_target_data or msn_params, not both. " + "Got: \n" + f"\n`spi_target_data`: {type(spi_target_data)}\n`msn_params`: {type(msn_params)}" + ) + raise ValueError(msg) + + # Add the SPI layer + return self.add_layer( + layer_class, + on_axis=on_axis, + spi_target_data=spi_target_data, + msn_params=msn_params, + **params, + ) + + def add_density( + self, + on_axis: int | tuple[int, int] | list[int] | None = None, + data: pd.DataFrame | None = None, + *, + include_outline: bool = False, + **params: Any, + ) -> Any: + """ + Add a density layer to the plot. + + Parameters + ---------- + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes: + - int: Index of subplot (flattened) + - tuple: (row, col) coordinates + - list: Multiple indices to apply the layer to + - None: Apply to all subplots (default) + data : pd.DataFrame | None, optional + Custom data for this specific layer, by default None + include_outline : bool, optional + Whether to include an outline around the density plot, by default False + **params : Any + Additional parameters for the density layer + + Returns + ------- + ISOPlot + The current plot instance for chaining + + Examples + -------- + >>> import pandas as pd + >>> import numpy as np + >>> from soundscapy.plotting import ISOPlot + >>> rng = np.random.default_rng(42) + >>> data = pd.DataFrame({ + ... 'ISOPleasant': rng.normal(0, 0.5, 100), + ... 'ISOEventful': rng.normal(0, 0.5, 100), + ... }) + >>> plot = (ISOPlot(data=data) + ... .add_density() + ... .apply_styling()) + >>> plot.show() # xdoctest: +SKIP + >>> plot.close() # Clean up + + """ + return self.add_layer( + DensityLayer, + data=data, + on_axis=on_axis, + include_outline=include_outline, + **params, + ) + + def add_simple_density( + self, + on_axis: int | tuple[int, int] | list[int] | None = None, + data: pd.DataFrame | None = None, + *, + include_outline: bool = True, + **params: Any, + ) -> Any: + """ + Add a simplified density layer to the plot. + + This creates a density plot with fewer contour levels, typically used + to highlight the main density region. + + Parameters + ---------- + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes: + - int: Index of subplot (flattened) + - tuple: (row, col) coordinates + - list: Multiple indices to apply the layer to + - None: Apply to all subplots (default) + data : pd.DataFrame | None, optional + Custom data for this specific layer, by default None + include_outline : bool, optional + Whether to include an outline around the density plot, by default True + **params : Any + Additional parameters for the simple density layer + + Returns + ------- + ISOPlot + The current plot instance for chaining + + Examples + -------- + >>> import pandas as pd + >>> import numpy as np + >>> from soundscapy.plotting import ISOPlot + >>> rng = np.random.default_rng(42) + >>> data = pd.DataFrame({ + ... 'ISOPleasant': rng.normal(0, 0.5, 100), + ... 'ISOEventful': rng.normal(0, 0.5, 100), + ... }) + >>> plot = (ISOPlot(data=data) + ... .add_simple_density() + ... .apply_styling()) + >>> plot.show() # xdoctest: +SKIP + >>> plot.close() # Clean up + + """ + return self.add_layer( + SimpleDensityLayer, + data=data, + on_axis=on_axis, + include_outline=include_outline, + **params, + ) + + def add_annotation( + self, + text: str, + xy: tuple[float, float], + xytext: tuple[float, float], + arrowprops: dict[str, Any] | None = None, + ) -> Any: + """ + Add an annotation to the plot. + + Parameters + ---------- + text : str + The text of the annotation + xy : tuple[float, float] + The point (x, y) to annotate + xytext : tuple[float, float] + The position (x, y) to place the text + arrowprops : dict[str, Any] | None, optional + Properties used to draw the arrow, by default None + + Returns + ------- + ISOPlot + The current plot instance for chaining + + Examples + -------- + >>> import pandas as pd + >>> import numpy as np + >>> from soundscapy.plotting import ISOPlot + >>> rng = np.random.default_rng(42) + >>> data = pd.DataFrame({ + ... 'ISOPleasant': rng.normal(0, 0.5, 100), + ... 'ISOEventful': rng.normal(0, 0.5, 100), + ... }) + >>> plot = (ISOPlot(data=data) + ... .add_scatter() + ... .add_annotation( + ... "Interesting point", + ... xy=(0.5, 0.5), + ... xytext=(0.7, 0.7), + ... arrowprops=dict(arrowstyle="->") + ... ) + ... .apply_styling()) + >>> plot.show() # xdoctest: +SKIP + >>> plot.close() # Clean up + + """ + # Default arrow properties if none provided + if arrowprops is None: + arrowprops = {"arrowstyle": "->"} + + # Get the current axes + ax = self.get_single_axes() + ax.annotate(text, xy=xy, xytext=xytext, arrowprops=arrowprops) + + return self diff --git a/src/soundscapy/plotting/iso_plot_styling.py b/src/soundscapy/plotting/iso_plot_styling.py new file mode 100644 index 0000000..9104032 --- /dev/null +++ b/src/soundscapy/plotting/iso_plot_styling.py @@ -0,0 +1,344 @@ +""" +Styling methods for ISOPlot. + +This module provides a mixin class with methods for styling ISOPlot instances, +including grid lines, labels, and other visual elements. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np +from matplotlib import pyplot as plt +from matplotlib import ticker +from matplotlib.artist import Artist +from matplotlib.axes import Axes + +from soundscapy.plotting.defaults import DEFAULT_STYLE_PARAMS +from soundscapy.plotting.plotting_types import ParamModel + +if TYPE_CHECKING: + from matplotlib.figure import Figure + + +class ISOPlotStylingMixin: + """Mixin providing styling methods for ISOPlot.""" + + def apply_styling(self, **kwargs: Any) -> Any: + """ + Apply styling to the plot. + + This method applies various styling elements to the plot, including: + - Setting axis limits and labels + - Adding grid lines + - Adding diagonal lines (if enabled) + - Setting titles + - Configuring legends + + Parameters + ---------- + **kwargs : Any + Additional styling parameters to override defaults + + Returns + ------- + ISOPlot + The current plot instance for chaining + + Examples + -------- + >>> import pandas as pd + >>> import numpy as np + >>> from soundscapy.plotting import ISOPlot + >>> rng = np.random.default_rng(42) + >>> data = pd.DataFrame({ + ... 'ISOPleasant': rng.normal(0, 0.5, 100), + ... 'ISOEventful': rng.normal(0, 0.5, 100), + ... }) + >>> plot = (ISOPlot(data=data) + ... .add_scatter() + ... .apply_styling( + ... xlim=(-1.5, 1.5), + ... ylim=(-1.5, 1.5), + ... diagonal_lines=True + ... )) + >>> plot.show() # xdoctest: +SKIP + >>> plot.close() # Clean up + + """ + # Update style parameters with provided kwargs + self._style_params = ParamModel.create("style", **{**DEFAULT_STYLE_PARAMS, **kwargs}) + + # Check if we have axes to style + self._check_for_axes() + + # Apply styling to each axes + for ax in self.yield_axes_objects(): + # Set axis limits + ax.set_xlim(self._style_params.xlim) + ax.set_ylim(self._style_params.ylim) + + # Set up grid + self._circumplex_grid(ax) + + # Add primary lines if enabled + if self._style_params.primary_lines: + self._primary_lines(ax) + self._primary_labels(ax) + + # Add diagonal lines if enabled + if self._style_params.diagonal_lines: + self._diagonal_lines_and_labels(ax) + + # Set titles + self._set_title() + self._set_axes_titles() + + # Move legend if needed + if self._style_params.legend_loc and self._style_params.legend_loc is not False: + self._move_legend() + + return self + + def _set_style(self) -> None: + """Set the style for the plot.""" + plt.style.use("seaborn-v0_8-whitegrid") + + def _circumplex_grid(self, ax: Axes) -> None: + """ + Set up the grid for a circumplex plot. + + Parameters + ---------- + ax : Axes + The axes to set up the grid for + """ + # Set up grid + ax.grid(True, linestyle="--", alpha=0.7, zorder=0) + + # Set up minor ticks + ax.xaxis.set_minor_locator(ticker.AutoMinorLocator(2)) + ax.yaxis.set_minor_locator(ticker.AutoMinorLocator(2)) + + # Set up tick parameters + ax.tick_params(which="both", direction="in") + ax.tick_params(which="minor", length=4) + ax.tick_params(which="major", length=7) + + def _set_title(self) -> None: + """Set the title for the plot.""" + # If we have a figure title and no subplots, set it on the first axes + if not self._has_subplots and self.title is not None: + ax = self.get_single_axes() + ax.set_title(self.title, fontsize=self._style_params.title_fontsize) + + # If we have a figure title and subplots, set it on the figure + elif self._has_subplots and self.title is not None: + fig = self.get_figure() + fig.suptitle(self.title, fontsize=self._style_params.title_fontsize) + + def _set_axes_titles(self) -> None: + """Set titles for individual axes in subplots.""" + # Set titles for individual subplot contexts if they have titles + for i, context in enumerate(self.subplot_contexts): + if context.title is not None: + ax = self.get_single_axes(i) + ax.set_title(context.title, fontsize=self._style_params.title_fontsize) + + def _primary_lines(self, ax: Axes) -> None: + """ + Add primary axis lines to the plot. + + Parameters + ---------- + ax : Axes + The axes to add the lines to + """ + # Add horizontal and vertical lines at x=0 and y=0 + ax.axhline( + y=0, + color="black", + linestyle="-", + linewidth=self._style_params.linewidth, + zorder=self._style_params.prim_lines_zorder, + ) + ax.axvline( + x=0, + color="black", + linestyle="-", + linewidth=self._style_params.linewidth, + zorder=self._style_params.prim_lines_zorder, + ) + + def _primary_labels(self, ax: Axes) -> None: + """ + Add labels for primary axes. + + Parameters + ---------- + ax : Axes + The axes to add the labels to + """ + # Set x and y labels if they are provided + if self._style_params.xlabel is not False: + xlabel = self._style_params.xlabel or self.x + ax.set_xlabel( + xlabel, + fontdict=self._style_params.prim_ax_fontdict, + ) + + if self._style_params.ylabel is not False: + ylabel = self._style_params.ylabel or self.y + ax.set_ylabel( + ylabel, + fontdict=self._style_params.prim_ax_fontdict, + ) + + def _diagonal_lines_and_labels(self, ax: Axes) -> None: + """ + Add diagonal lines and labels to the plot. + + Parameters + ---------- + ax : Axes + The axes to add the diagonal lines and labels to + """ + # Get axis limits + xlim = ax.get_xlim() + ylim = ax.get_ylim() + + # Calculate diagonal line endpoints + x_range = xlim[1] - xlim[0] + y_range = ylim[1] - ylim[0] + + # Determine the smaller range to keep lines within bounds + range_min = min(x_range, y_range) + + # Calculate line endpoints + diag1_start = (xlim[0], ylim[0]) + diag1_end = (xlim[0] + range_min, ylim[0] + range_min) + + diag2_start = (xlim[1], ylim[0]) + diag2_end = (xlim[1] - range_min, ylim[0] + range_min) + + # Draw diagonal lines + ax.plot( + [diag1_start[0], diag1_end[0]], + [diag1_start[1], diag1_end[1]], + color="black", + linestyle="--", + linewidth=self._style_params.linewidth, + zorder=self._style_params.diag_lines_zorder, + ) + + ax.plot( + [diag2_start[0], diag2_end[0]], + [diag2_start[1], diag2_end[1]], + color="black", + linestyle="--", + linewidth=self._style_params.linewidth, + zorder=self._style_params.diag_lines_zorder, + ) + + # Add diagonal labels + # Calculate positions for labels + label_offset = 0.05 # Offset from the end of the line + + # First diagonal (bottom-left to top-right) + diag1_label_pos = ( + diag1_end[0] - label_offset * x_range, + diag1_end[1] + label_offset * y_range, + ) + + # Second diagonal (bottom-right to top-left) + diag2_label_pos = ( + diag2_end[0] + label_offset * x_range, + diag2_end[1] + label_offset * y_range, + ) + + # Add the labels + ax.text( + diag1_label_pos[0], + diag1_label_pos[1], + "Exciting", + ha="right", + va="bottom", + rotation=45, + zorder=self._style_params.diag_labels_zorder, + ) + + ax.text( + diag2_label_pos[0], + diag2_label_pos[1], + "Chaotic", + ha="left", + va="bottom", + rotation=-45, + zorder=self._style_params.diag_labels_zorder, + ) + + # Add labels for the bottom half of the diagonals + diag3_label_pos = ( + diag1_start[0] + label_offset * x_range, + diag1_start[1] - label_offset * y_range, + ) + + diag4_label_pos = ( + diag2_start[0] - label_offset * x_range, + diag2_start[1] - label_offset * y_range, + ) + + ax.text( + diag3_label_pos[0], + diag3_label_pos[1], + "Boring", + ha="left", + va="top", + rotation=45, + zorder=self._style_params.diag_labels_zorder, + ) + + ax.text( + diag4_label_pos[0], + diag4_label_pos[1], + "Calm", + ha="right", + va="top", + rotation=-45, + zorder=self._style_params.diag_labels_zorder, + ) + + def _move_legend(self) -> None: + """Move the legend to the specified location.""" + # Get the figure + fig = self.get_figure() + + # Find all legends in the figure + legends = [] + for ax in self.yield_axes_objects(): + legend = ax.get_legend() + if legend is not None: + legends.append(legend) + + # If we have legends, move them to the specified location + if legends: + for legend in legends: + legend.set_zorder(100) # Ensure legend is on top + + # If legend_loc is specified, move the legend + if self._style_params.legend_loc: + # Remove the legend from its current position + legend.remove() + + # Get the axes the legend belongs to + ax = legend.axes + + # Add the legend back at the specified location + handles, labels = ax.get_legend_handles_labels() + if handles: + ax.legend( + handles, + labels, + loc=self._style_params.legend_loc, + ) \ No newline at end of file From 3008b690547eff8c197fac874e405adae19c5d71 Mon Sep 17 00:00:00 2001 From: Andrew Mitchell Date: Sat, 10 May 2025 02:48:37 +0100 Subject: [PATCH 2/8] Refactor plot examples and update plotting/styling logic Removed redundant comments and streamlined examples for clarity across plotting methods. Improved diagonal line logic, refactored legend handling, and enhanced parameter handling for plotting/styling functions to improve maintainability and consistency.`` Signed-off-by: Andrew Mitchell --- pyproject.toml | 2 + src/soundscapy/plotting/iso_plot.py | 79 +++-- src/soundscapy/plotting/iso_plot_layers.py | 157 +++++++-- src/soundscapy/plotting/iso_plot_styling.py | 365 +++++++++++--------- uv.lock | 122 +++---- 5 files changed, 438 insertions(+), 287 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3247e03..a2b8472 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,10 +53,12 @@ classifiers = [ "License :: OSI Approved :: BSD License", ] dependencies = [ + "coverage==7.8.0", "loguru>=0.7.2", "numpy!=1.26", "pandas[excel]>=2.2.2", "pydantic>=2.8.2", + "pytest-cov==6.1.1", "pyyaml>=6.0.2", "scipy>=1.14.1", "seaborn>=0.13.2", diff --git a/src/soundscapy/plotting/iso_plot.py b/src/soundscapy/plotting/iso_plot.py index f40857c..32a1f49 100644 --- a/src/soundscapy/plotting/iso_plot.py +++ b/src/soundscapy/plotting/iso_plot.py @@ -19,7 +19,7 @@ ... .add_simple_density(fill=False) ... .apply_styling() ... ) ->>> isoplot.show() # xdoctest: +SKIP +>>> isoplot.show() """ # ruff: noqa: SLF001, G004 @@ -78,7 +78,38 @@ class ISOPlot(ISOPlotLayersMixin, ISOPlotStylingMixin): ... .add_scatter() ... .add_density() ... .apply_styling()) - >>> cp.show() # xdoctest: +SKIP + >>> cp.show() + + Create a plot with default parameters: + + >>> import pandas as pd + >>> import numpy as np + >>> rng = np.random.default_rng(42) + >>> data = pd.DataFrame( + ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... columns=['ISOPleasant', 'ISOEventful'] + ... ) + >>> plot = ISOPlot() + >>> isinstance(plot, ISOPlot) + True + + Create a plot with a DataFrame: + + >>> data = pd.DataFrame( + ... np.c_[rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... rng.integers(1, 3, 100)], + ... columns=['ISOPleasant', 'ISOEventful', 'Group']) + >>> plot = ISOPlot(data=data, hue='Group') + >>> plot.hue + 'Group' + + + Create a plot directly with arrays: + + >>> x, y = rng.multivariate_normal([0, 0], [[1, 0], [0, 1]], 100).T + >>> plot = ISOPlot(x=x, y=y) + >>> isinstance(plot, ISOPlot) + True """ @@ -115,39 +146,6 @@ def __init__( axes : Axes | np.ndarray | None, optional Existing axes to plot on, by default None - Examples - -------- - Create a plot with default parameters: - - >>> import pandas as pd - >>> import numpy as np - >>> rng = np.random.default_rng(42) - >>> data = pd.DataFrame( - ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), - ... columns=['ISOPleasant', 'ISOEventful'] - ... ) - >>> plot = ISOPlot() - >>> isinstance(plot, ISOPlot) - True - - Create a plot with a DataFrame: - - >>> data = pd.DataFrame( - ... np.c_[rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), - ... rng.integers(1, 3, 100)], - ... columns=['ISOPleasant', 'ISOEventful', 'Group']) - >>> plot = ISOPlot(data=data, hue='Group') - >>> plot.hue - 'Group' - - - Create a plot directly with arrays: - - >>> x, y = rng.multivariate_normal([0, 0], [[1, 0], [0, 1]], 100).T - >>> plot = ISOPlot(x=x, y=y) - >>> isinstance(plot, ISOPlot) - True - """ # Process and validate input data and coordinates data, x, y = self._check_data_x_y(data, x, y) @@ -668,7 +666,8 @@ def _validate_subplots_datas( ) raise ValueError(msg) - def _allocate_subplot_axes(self, subplot_titles: list[str]) -> tuple[int, int]: + @staticmethod + def _allocate_subplot_axes(subplot_titles: list[str]) -> tuple[int, int]: """Allocate the subplot axes based on the number of data subsets.""" msg = ( "This is an experimental feature. " @@ -924,7 +923,7 @@ def add_layer( ... .create_subplots(nrows=2, ncols=2) ... .add_layer(ScatterLayer) ... .apply_styling()) - >>> plot.show() # xdoctest: +SKIP + >>> plot.show() >>> all(len(ctx.layers) == 1 for ctx in plot.subplot_contexts) True >>> plot.close() # Clean up @@ -935,7 +934,7 @@ def add_layer( ... .create_subplots(nrows=2, ncols=2) ... .add_layer(ScatterLayer, on_axis=0) ... .apply_styling()) - >>> plot.show() # xdoctest: +SKIP + >>> plot.show() >>> len(plot.subplot_contexts[0].layers) == 1 True >>> all(len(ctx.layers) == 0 for ctx in plot.subplot_contexts[1:]) @@ -948,7 +947,7 @@ def add_layer( ... .create_subplots(nrows=2, ncols=2) ... .add_layer(ScatterLayer, on_axis=[0, 2]) ... .apply_styling()) - >>> plot.show() # xdoctest: +SKIP + >>> plot.show() >>> len(plot.subplot_contexts[0].layers) == 1 True >>> len(plot.subplot_contexts[2].layers) == 1 @@ -970,7 +969,7 @@ def add_layer( ... # Add a layer with custom data to the second subplot ... .add_layer(ScatterLayer, data=custom_data, on_axis=1) ... .apply_styling()) - >>> plot.show() # xdoctest: +SKIP + >>> plot.show() >>> plot.close() """ diff --git a/src/soundscapy/plotting/iso_plot_layers.py b/src/soundscapy/plotting/iso_plot_layers.py index b2f730c..f92e6e3 100644 --- a/src/soundscapy/plotting/iso_plot_layers.py +++ b/src/soundscapy/plotting/iso_plot_layers.py @@ -59,10 +59,10 @@ def add_layer( ------- Any The current plot instance for chaining + """ - raise NotImplementedError( - "Classes using ISOPlotLayersMixin must implement add_layer" - ) + msg = "Classes using ISOPlotLayersMixin must implement add_layer" + raise NotImplementedError(msg) def get_single_axes(self, ax_idx: int | tuple[int, int] | None = None) -> Any: """ @@ -81,10 +81,10 @@ def get_single_axes(self, ax_idx: int | tuple[int, int] | None = None) -> Any: ------- Any The requested matplotlib Axes object + """ - raise NotImplementedError( - "Classes using ISOPlotLayersMixin must implement get_single_axes" - ) + msg = "Classes using ISOPlotLayersMixin must implement get_single_axes" + raise NotImplementedError(msg) def add_scatter( self, @@ -116,19 +116,39 @@ def add_scatter( Examples -------- + Add a scatter layer to all subplots: + >>> import pandas as pd >>> import numpy as np >>> from soundscapy.plotting import ISOPlot >>> rng = np.random.default_rng(42) - >>> data = pd.DataFrame({ - ... 'ISOPleasant': rng.normal(0, 0.5, 100), - ... 'ISOEventful': rng.normal(0, 0.5, 100), - ... 'Group': rng.integers(1, 3, 100) + >>> data = pd.DataFrame( + ... np.c_[rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... rng.integers(1, 3, 100)], + ... columns=['ISOPleasant', 'ISOEventful', 'Group']) + >>> plot = (ISOPlot(data=data) + ... .create_subplots(nrows=2, ncols=1) + ... .add_scatter(s=50, alpha=0.7, hue='Group') + ... .apply_styling()) + >>> plot.show() # xdoctest: +SKIP + >>> all(len(ctx.layers) == 1 for ctx in plot.subplot_contexts) + True + >>> plot.close() # Clean up + + Add a scatter layer with custom data to a specific subplot: + + >>> custom_data = pd.DataFrame({ + ... 'ISOPleasant': rng.normal(0.2, 0.1, 50), + ... 'ISOEventful': rng.normal(0.15, 0.2, 50), ... }) >>> plot = (ISOPlot(data=data) - ... .add_scatter(hue='Group') - ... .apply_styling()) + ... .create_subplots(nrows=2, ncols=1) + ... .add_scatter(hue='Group') + ... .add_scatter(on_axis=0, data=custom_data, color='red') + ... .apply_styling()) >>> plot.show() # xdoctest: +SKIP + >>> plot.subplot_contexts[0].layers[1].custom_data is custom_data + True >>> plot.close() # Clean up """ @@ -173,6 +193,53 @@ def add_spi( Either spi_target_data or msn_params must be provided, but not both. The test data for SPI calculations will be retrieved from the plot context. + Examples + -------- + Add a SPI layer to all subplots: + + >>> import pandas as pd + >>> import numpy as np + >>> from soundscapy.spi import DirectParams + >>> from soundscapy.plotting import ISOPlot + >>> rng = np.random.default_rng(42) + >>> # Create a DataFrame with random data + >>> data = pd.DataFrame( + ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... columns=['ISOPleasant', 'ISOEventful'] + ... ) + >>> # Define MSN parameters for the SPI target + >>> msn_params = DirectParams( + ... xi=np.array([0.5, 0.7]), + ... omega=np.array([[0.1, 0.05], [0.05, 0.1]]), + ... alpha=np.array([0, -5]), + ... ) + >>> # Create the plot with only an SPI layer + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots() + ... .add_scatter() + ... .add_spi(msn_params=msn_params) + ... .apply_styling() + ... ) + >>> plot.show() # xdoctest: +SKIP + >>> len(plot.subplot_contexts[0].layers) == 2 + True + >>> plot.close() # Clean up + + Add an SPI layer over top of 'real' data: + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots() + ... .add_scatter() + ... .add_density() + ... .add_spi(msn_params=msn_params, show_score="on axis") + ... .apply_styling() + ... ) + >>> plot.show() # xdoctest: +SKIP + >>> len(plot.subplot_contexts[0].layers) == 3 + True + >>> plot.close() # Clean up + """ # Validate that we have either spi_target_data or msn_params if spi_target_data is None and msn_params is None: @@ -232,18 +299,38 @@ def add_density( Examples -------- + Add a density layer to all subplots: + >>> import pandas as pd >>> import numpy as np >>> from soundscapy.plotting import ISOPlot >>> rng = np.random.default_rng(42) >>> data = pd.DataFrame({ - ... 'ISOPleasant': rng.normal(0, 0.5, 100), - ... 'ISOEventful': rng.normal(0, 0.5, 100), + ... 'ISOPleasant': rng.normal(0.2, 0.25, 50), + ... 'ISOEventful': rng.normal(0.15, 0.4, 50), ... }) - >>> plot = (ISOPlot(data=data) - ... .add_density() - ... .apply_styling()) + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots() + ... .add_density() + ... .apply_styling() + ... ) >>> plot.show() # xdoctest: +SKIP + >>> len(plot.subplot_contexts[0].layers) == 1 + True + >>> plot.close() # Clean up + + Add a density layer with custom settings: + + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots() + ... .add_density(levels=5, alpha=0.7) + ... .apply_styling() + ... ) + >>> plot.show() # xdoctest: +SKIP + >>> len(plot.subplot_contexts[0].layers) == 1 + True >>> plot.close() # Clean up """ @@ -291,20 +378,45 @@ def add_simple_density( Examples -------- + Add a simple density layer: + >>> import pandas as pd >>> import numpy as np >>> from soundscapy.plotting import ISOPlot >>> rng = np.random.default_rng(42) >>> data = pd.DataFrame({ - ... 'ISOPleasant': rng.normal(0, 0.5, 100), - ... 'ISOEventful': rng.normal(0, 0.5, 100), + ... 'ISOPleasant': rng.normal(0.2, 0.25, 30), + ... 'ISOEventful': rng.normal(0.15, 0.4, 30), ... }) - >>> plot = (ISOPlot(data=data) - ... .add_simple_density() - ... .apply_styling()) + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots() + ... .add_scatter() + ... .add_simple_density() + ... .apply_styling() + ... ) >>> plot.show() # xdoctest: +SKIP + >>> len(plot.subplot_contexts[0].layers) == 2 + True >>> plot.close() # Clean up + Add a simple density with splitting by group: + >>> data = pd.DataFrame( + ... np.c_[rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... rng.integers(1, 3, 100)], + ... columns=['ISOPleasant', 'ISOEventful', 'Group']) + >>> plot = ( + ... ISOPlot(data=data, hue='Group') + ... .create_subplots() + ... .add_scatter() + ... .add_simple_density() + ... .apply_styling() + ... ) + >>> plot.show() # xdoctest: +SKIP + >>> len(plot.subplot_contexts[0].layers) == 2 + True + >>> plot.close() + """ return self.add_layer( SimpleDensityLayer, @@ -351,6 +463,7 @@ def add_annotation( ... 'ISOEventful': rng.normal(0, 0.5, 100), ... }) >>> plot = (ISOPlot(data=data) + ... .create_subplots() ... .add_scatter() ... .add_annotation( ... "Interesting point", diff --git a/src/soundscapy/plotting/iso_plot_styling.py b/src/soundscapy/plotting/iso_plot_styling.py index 9104032..0f034ff 100644 --- a/src/soundscapy/plotting/iso_plot_styling.py +++ b/src/soundscapy/plotting/iso_plot_styling.py @@ -9,17 +9,17 @@ from typing import TYPE_CHECKING, Any -import numpy as np from matplotlib import pyplot as plt from matplotlib import ticker from matplotlib.artist import Artist -from matplotlib.axes import Axes -from soundscapy.plotting.defaults import DEFAULT_STYLE_PARAMS from soundscapy.plotting.plotting_types import ParamModel +from soundscapy.sspylogging import get_logger if TYPE_CHECKING: - from matplotlib.figure import Figure + from matplotlib.axes import Axes + +logger = get_logger() class ISOPlotStylingMixin: @@ -48,30 +48,65 @@ def apply_styling(self, **kwargs: Any) -> Any: Examples -------- + Apply styling with default parameters + >>> import pandas as pd >>> import numpy as np >>> from soundscapy.plotting import ISOPlot >>> rng = np.random.default_rng(42) - >>> data = pd.DataFrame({ - ... 'ISOPleasant': rng.normal(0, 0.5, 100), - ... 'ISOEventful': rng.normal(0, 0.5, 100), - ... }) - >>> plot = (ISOPlot(data=data) + >>> # Create simple data for styling example + >>> data = pd.DataFrame( + ... np.c_[rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... rng.integers(1, 3, 100)], + ... columns=['ISOPleasant', 'ISOEventful', 'Group']) + >>> # Create plot with default styling + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots() + ... .add_scatter() + ... .apply_styling() + ... ) + >>> plot.show() # xdoctest: +SKIP + >>> plot.get_figure() is not None + True + >>> plot.close() # Clean up + + Apply styling with custom parameters: + + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots() ... .add_scatter() - ... .apply_styling( - ... xlim=(-1.5, 1.5), - ... ylim=(-1.5, 1.5), - ... diagonal_lines=True - ... )) + ... .apply_styling(xlim=(-2, 2), ylim=(-2, 2), primary_lines=False) + ... ) >>> plot.show() # xdoctest: +SKIP + >>> plot.get_figure() is not None + True + >>> plot.close() # Clean up + + Demonstrate the fluent interface (method chaining): + + >>> # Create plot with method chaining + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots(nrows=1, ncols=1) + ... .add_scatter(alpha=0.7) + ... .add_density(levels=5) + ... .apply_styling(title_fontsize=14) + ... ) + >>> plot.show() # xdoctest: +SKIP + >>> # Verify results + >>> isinstance(plot, ISOPlot) + True >>> plot.close() # Clean up """ # Update style parameters with provided kwargs - self._style_params = ParamModel.create("style", **{**DEFAULT_STYLE_PARAMS, **kwargs}) - + # Use the default values from the StyleParams class and override with kwargs + self._style_params = ParamModel.create("style", **kwargs) # Check if we have axes to style self._check_for_axes() + self._set_style() # Apply styling to each axes for ax in self.yield_axes_objects(): @@ -101,7 +136,8 @@ def apply_styling(self, **kwargs: Any) -> Any: return self - def _set_style(self) -> None: + @staticmethod + def _set_style() -> None: """Set the style for the plot.""" plt.style.use("seaborn-v0_8-whitegrid") @@ -113,30 +149,45 @@ def _circumplex_grid(self, ax: Axes) -> None: ---------- ax : Axes The axes to set up the grid for + """ # Set up grid - ax.grid(True, linestyle="--", alpha=0.7, zorder=0) - - # Set up minor ticks - ax.xaxis.set_minor_locator(ticker.AutoMinorLocator(2)) - ax.yaxis.set_minor_locator(ticker.AutoMinorLocator(2)) - - # Set up tick parameters - ax.tick_params(which="both", direction="in") - ax.tick_params(which="minor", length=4) - ax.tick_params(which="major", length=7) + ax.set_xlim(self._style_params.get("xlim")) + ax.set_ylim(self._style_params.get("ylim")) + ax.set_aspect("equal") + + ax.get_yaxis().set_minor_locator(ticker.AutoMinorLocator()) + ax.get_xaxis().set_minor_locator(ticker.AutoMinorLocator()) + + ax.grid(visible=True, which="major", color="grey", alpha=0.5) + ax.grid( + visible=True, + which="minor", + color="grey", + linestyle="dashed", + linewidth=0.5, + alpha=0.4, + zorder=self._style_params.get("prim_lines_zorder"), + ) def _set_title(self) -> None: - """Set the title for the plot.""" - # If we have a figure title and no subplots, set it on the first axes - if not self._has_subplots and self.title is not None: - ax = self.get_single_axes() - ax.set_title(self.title, fontsize=self._style_params.title_fontsize) - - # If we have a figure title and subplots, set it on the figure - elif self._has_subplots and self.title is not None: - fig = self.get_figure() - fig.suptitle(self.title, fontsize=self._style_params.title_fontsize) + """Set the title of the plot.""" + if self.title and self._has_subplots: + figure = self.get_figure() + figure.suptitle( + self.title, fontsize=self._style_params.get("title_fontsize") + ) + elif self.title and not self._has_subplots: + axis = self.get_single_axes() + if axis.get_title() == "": + axis.set_title( + self.title, fontsize=self._style_params.get("title_fontsize") + ) + else: + figure = self.get_figure() + figure.suptitle( + self.title, fontsize=self._style_params.get("title_fontsize") + ) def _set_axes_titles(self) -> None: """Set titles for individual axes in subplots.""" @@ -154,21 +205,24 @@ def _primary_lines(self, ax: Axes) -> None: ---------- ax : Axes The axes to add the lines to + """ # Add horizontal and vertical lines at x=0 and y=0 ax.axhline( y=0, - color="black", - linestyle="-", - linewidth=self._style_params.linewidth, - zorder=self._style_params.prim_lines_zorder, + color="grey", + linestyle="dashed", + alpha=1, + lw=self._style_params.get("linewidth"), + zorder=self._style_params.get("prim_lines_zorder"), ) ax.axvline( x=0, - color="black", - linestyle="-", - linewidth=self._style_params.linewidth, - zorder=self._style_params.prim_lines_zorder, + color="grey", + linestyle="dashed", + alpha=1, + lw=self._style_params.get("linewidth"), + zorder=self._style_params.get("prim_lines_zorder"), ) def _primary_labels(self, ax: Axes) -> None: @@ -179,6 +233,7 @@ def _primary_labels(self, ax: Axes) -> None: ---------- ax : Axes The axes to add the labels to + """ # Set x and y labels if they are provided if self._style_params.xlabel is not False: @@ -187,7 +242,7 @@ def _primary_labels(self, ax: Axes) -> None: xlabel, fontdict=self._style_params.prim_ax_fontdict, ) - + if self._style_params.ylabel is not False: ylabel = self._style_params.ylabel or self.y ax.set_ylabel( @@ -203,142 +258,120 @@ def _diagonal_lines_and_labels(self, ax: Axes) -> None: ---------- ax : Axes The axes to add the diagonal lines and labels to + + Examples + -------- + >>> import pandas as pd + >>> import numpy as np + >>> from soundscapy.plotting import ISOPlot + >>> rng = np.random.default_rng(42) + >>> data = pd.DataFrame( + ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... columns=['ISOPleasant', 'ISOEventful']) + >>> # Create a plot with diagonal lines and labels + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots() + ... .add_scatter() + ... .apply_styling(diagonal_lines=True) + ... ) + >>> plot.show() # xdoctest: +SKIP + >>> plot.close('all') + """ # Get axis limits xlim = ax.get_xlim() ylim = ax.get_ylim() - - # Calculate diagonal line endpoints - x_range = xlim[1] - xlim[0] - y_range = ylim[1] - ylim[0] - - # Determine the smaller range to keep lines within bounds - range_min = min(x_range, y_range) - - # Calculate line endpoints - diag1_start = (xlim[0], ylim[0]) - diag1_end = (xlim[0] + range_min, ylim[0] + range_min) - - diag2_start = (xlim[1], ylim[0]) - diag2_end = (xlim[1] - range_min, ylim[0] + range_min) - - # Draw diagonal lines + ax.plot( - [diag1_start[0], diag1_end[0]], - [diag1_start[1], diag1_end[1]], - color="black", - linestyle="--", - linewidth=self._style_params.linewidth, - zorder=self._style_params.diag_lines_zorder, + xlim, + ylim, + linestyle="dashed", + color="grey", + alpha=0.5, + lw=self._style_params.get("linewidth"), + zorder=self._style_params.get("diag_lines_zorder"), ) - + logger.debug("Plotting diagonal line for axis.") ax.plot( - [diag2_start[0], diag2_end[0]], - [diag2_start[1], diag2_end[1]], - color="black", - linestyle="--", - linewidth=self._style_params.linewidth, - zorder=self._style_params.diag_lines_zorder, - ) - - # Add diagonal labels - # Calculate positions for labels - label_offset = 0.05 # Offset from the end of the line - - # First diagonal (bottom-left to top-right) - diag1_label_pos = ( - diag1_end[0] - label_offset * x_range, - diag1_end[1] + label_offset * y_range, + xlim, + ylim[::-1], + linestyle="dashed", + color="grey", + alpha=0.5, + lw=self._style_params.get("linewidth"), + zorder=self._style_params.get("diag_lines_zorder"), ) - - # Second diagonal (bottom-right to top-left) - diag2_label_pos = ( - diag2_end[0] + label_offset * x_range, - diag2_end[1] + label_offset * y_range, - ) - - # Add the labels + + diag_ax_font = { + "fontstyle": "italic", + "fontsize": "small", + "fontweight": "bold", + "color": "black", + "alpha": 0.5, + } ax.text( - diag1_label_pos[0], - diag1_label_pos[1], - "Exciting", - ha="right", - va="bottom", - rotation=45, - zorder=self._style_params.diag_labels_zorder, + xlim[1] / 2, + ylim[1] / 2, + "(vibrant)", + ha="center", + va="center", + fontdict=diag_ax_font, + zorder=self._style_params.get("diag_labels_zorder"), ) - ax.text( - diag2_label_pos[0], - diag2_label_pos[1], - "Chaotic", - ha="left", - va="bottom", - rotation=-45, - zorder=self._style_params.diag_labels_zorder, + xlim[0] / 2, + ylim[1] / 2, + "(chaotic)", + ha="center", + va="center", + fontdict=diag_ax_font, + zorder=self._style_params.get("diag_labels_zorder"), ) - - # Add labels for the bottom half of the diagonals - diag3_label_pos = ( - diag1_start[0] + label_offset * x_range, - diag1_start[1] - label_offset * y_range, - ) - - diag4_label_pos = ( - diag2_start[0] - label_offset * x_range, - diag2_start[1] - label_offset * y_range, - ) - ax.text( - diag3_label_pos[0], - diag3_label_pos[1], - "Boring", - ha="left", - va="top", - rotation=45, - zorder=self._style_params.diag_labels_zorder, + xlim[0] / 2, + ylim[0] / 2, + "(monotonous)", + ha="center", + va="center", + fontdict=diag_ax_font, + zorder=self._style_params.get("diag_labels_zorder"), ) - ax.text( - diag4_label_pos[0], - diag4_label_pos[1], - "Calm", - ha="right", - va="top", - rotation=-45, - zorder=self._style_params.diag_labels_zorder, + xlim[1] / 2, + ylim[0] / 2, + "(calm)", + ha="center", + va="center", + fontdict=diag_ax_font, + zorder=self._style_params.get("diag_labels_zorder"), ) def _move_legend(self) -> None: """Move the legend to the specified location.""" - # Get the figure - fig = self.get_figure() - - # Find all legends in the figure - legends = [] - for ax in self.yield_axes_objects(): - legend = ax.get_legend() - if legend is not None: - legends.append(legend) - - # If we have legends, move them to the specified location - if legends: - for legend in legends: - legend.set_zorder(100) # Ensure legend is on top - - # If legend_loc is specified, move the legend - if self._style_params.legend_loc: - # Remove the legend from its current position - legend.remove() - - # Get the axes the legend belongs to - ax = legend.axes - - # Add the legend back at the specified location - handles, labels = ax.get_legend_handles_labels() - if handles: - ax.legend( - handles, - labels, - loc=self._style_params.legend_loc, - ) \ No newline at end of file + for i, axis in enumerate(self.yield_axes_objects()): + old_legend = axis.get_legend() + if old_legend is None: + # logger.debug("_move_legend: No legend found for axis %s", i) + continue + + # Get handles and filter out None values + handles = [ + h for h in old_legend.legend_handles if isinstance(h, Artist | tuple) + ] + # Skip if no valid handles remain + if not handles: + continue + + labels = [t.get_text() for t in old_legend.get_texts()] + title = old_legend.get_title().get_text() + # Ensure labels and handles match in length + if len(handles) != len(labels): + labels = labels[: len(handles)] + + axis.legend( + handles, + labels, + loc=self._style_params.get("legend_loc"), + title=title, + ) diff --git a/uv.lock b/uv.lock index c10835d..90f94bf 100644 --- a/uv.lock +++ b/uv.lock @@ -478,62 +478,62 @@ wheels = [ [[package]] name = "coverage" -version = "7.6.12" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/0c/d6/2b53ab3ee99f2262e6f0b8369a43f6d66658eab45510331c0b3d5c8c4272/coverage-7.6.12.tar.gz", hash = "sha256:48cfc4641d95d34766ad41d9573cc0f22a48aa88d22657a1fe01dca0dbae4de2", size = 805941, upload-time = "2025-02-11T14:47:03.797Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ba/67/81dc41ec8f548c365d04a29f1afd492d3176b372c33e47fa2a45a01dc13a/coverage-7.6.12-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:704c8c8c6ce6569286ae9622e534b4f5b9759b6f2cd643f1c1a61f666d534fe8", size = 208345, upload-time = "2025-02-11T14:44:51.83Z" }, - { url = "https://files.pythonhosted.org/packages/33/43/17f71676016c8829bde69e24c852fef6bd9ed39f774a245d9ec98f689fa0/coverage-7.6.12-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ad7525bf0241e5502168ae9c643a2f6c219fa0a283001cee4cf23a9b7da75879", size = 208775, upload-time = "2025-02-11T14:44:54.852Z" }, - { url = "https://files.pythonhosted.org/packages/86/25/c6ff0775f8960e8c0840845b723eed978d22a3cd9babd2b996e4a7c502c6/coverage-7.6.12-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:06097c7abfa611c91edb9e6920264e5be1d6ceb374efb4986f38b09eed4cb2fe", size = 237925, upload-time = "2025-02-11T14:44:56.675Z" }, - { url = "https://files.pythonhosted.org/packages/b0/3d/5f5bd37046243cb9d15fff2c69e498c2f4fe4f9b42a96018d4579ed3506f/coverage-7.6.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:220fa6c0ad7d9caef57f2c8771918324563ef0d8272c94974717c3909664e674", size = 235835, upload-time = "2025-02-11T14:44:59.007Z" }, - { url = "https://files.pythonhosted.org/packages/b5/f1/9e6b75531fe33490b910d251b0bf709142e73a40e4e38a3899e6986fe088/coverage-7.6.12-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3688b99604a24492bcfe1c106278c45586eb819bf66a654d8a9a1433022fb2eb", size = 236966, upload-time = "2025-02-11T14:45:02.744Z" }, - { url = "https://files.pythonhosted.org/packages/4f/bc/aef5a98f9133851bd1aacf130e754063719345d2fb776a117d5a8d516971/coverage-7.6.12-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d1a987778b9c71da2fc8948e6f2656da6ef68f59298b7e9786849634c35d2c3c", size = 236080, upload-time = "2025-02-11T14:45:05.416Z" }, - { url = "https://files.pythonhosted.org/packages/eb/d0/56b4ab77f9b12aea4d4c11dc11cdcaa7c29130b837eb610639cf3400c9c3/coverage-7.6.12-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:cec6b9ce3bd2b7853d4a4563801292bfee40b030c05a3d29555fd2a8ee9bd68c", size = 234393, upload-time = "2025-02-11T14:45:08.627Z" }, - { url = "https://files.pythonhosted.org/packages/0d/77/28ef95c5d23fe3dd191a0b7d89c82fea2c2d904aef9315daf7c890e96557/coverage-7.6.12-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ace9048de91293e467b44bce0f0381345078389814ff6e18dbac8fdbf896360e", size = 235536, upload-time = "2025-02-11T14:45:10.313Z" }, - { url = "https://files.pythonhosted.org/packages/29/62/18791d3632ee3ff3f95bc8599115707d05229c72db9539f208bb878a3d88/coverage-7.6.12-cp310-cp310-win32.whl", hash = "sha256:ea31689f05043d520113e0552f039603c4dd71fa4c287b64cb3606140c66f425", size = 211063, upload-time = "2025-02-11T14:45:12.278Z" }, - { url = "https://files.pythonhosted.org/packages/fc/57/b3878006cedfd573c963e5c751b8587154eb10a61cc0f47a84f85c88a355/coverage-7.6.12-cp310-cp310-win_amd64.whl", hash = "sha256:676f92141e3c5492d2a1596d52287d0d963df21bf5e55c8b03075a60e1ddf8aa", size = 211955, upload-time = "2025-02-11T14:45:14.579Z" }, - { url = "https://files.pythonhosted.org/packages/64/2d/da78abbfff98468c91fd63a73cccdfa0e99051676ded8dd36123e3a2d4d5/coverage-7.6.12-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e18aafdfb3e9ec0d261c942d35bd7c28d031c5855dadb491d2723ba54f4c3015", size = 208464, upload-time = "2025-02-11T14:45:18.314Z" }, - { url = "https://files.pythonhosted.org/packages/31/f2/c269f46c470bdabe83a69e860c80a82e5e76840e9f4bbd7f38f8cebbee2f/coverage-7.6.12-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:66fe626fd7aa5982cdebad23e49e78ef7dbb3e3c2a5960a2b53632f1f703ea45", size = 208893, upload-time = "2025-02-11T14:45:19.881Z" }, - { url = "https://files.pythonhosted.org/packages/47/63/5682bf14d2ce20819998a49c0deadb81e608a59eed64d6bc2191bc8046b9/coverage-7.6.12-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ef01d70198431719af0b1f5dcbefc557d44a190e749004042927b2a3fed0702", size = 241545, upload-time = "2025-02-11T14:45:22.215Z" }, - { url = "https://files.pythonhosted.org/packages/6a/b6/6b6631f1172d437e11067e1c2edfdb7238b65dff965a12bce3b6d1bf2be2/coverage-7.6.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07e92ae5a289a4bc4c0aae710c0948d3c7892e20fd3588224ebe242039573bf0", size = 239230, upload-time = "2025-02-11T14:45:24.864Z" }, - { url = "https://files.pythonhosted.org/packages/c7/01/9cd06cbb1be53e837e16f1b4309f6357e2dfcbdab0dd7cd3b1a50589e4e1/coverage-7.6.12-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e695df2c58ce526eeab11a2e915448d3eb76f75dffe338ea613c1201b33bab2f", size = 241013, upload-time = "2025-02-11T14:45:27.203Z" }, - { url = "https://files.pythonhosted.org/packages/4b/26/56afefc03c30871326e3d99709a70d327ac1f33da383cba108c79bd71563/coverage-7.6.12-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d74c08e9aaef995f8c4ef6d202dbd219c318450fe2a76da624f2ebb9c8ec5d9f", size = 239750, upload-time = "2025-02-11T14:45:29.577Z" }, - { url = "https://files.pythonhosted.org/packages/dd/ea/88a1ff951ed288f56aa561558ebe380107cf9132facd0b50bced63ba7238/coverage-7.6.12-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e995b3b76ccedc27fe4f477b349b7d64597e53a43fc2961db9d3fbace085d69d", size = 238462, upload-time = "2025-02-11T14:45:31.096Z" }, - { url = "https://files.pythonhosted.org/packages/6e/d4/1d9404566f553728889409eff82151d515fbb46dc92cbd13b5337fa0de8c/coverage-7.6.12-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b1f097878d74fe51e1ddd1be62d8e3682748875b461232cf4b52ddc6e6db0bba", size = 239307, upload-time = "2025-02-11T14:45:32.713Z" }, - { url = "https://files.pythonhosted.org/packages/12/c1/e453d3b794cde1e232ee8ac1d194fde8e2ba329c18bbf1b93f6f5eef606b/coverage-7.6.12-cp311-cp311-win32.whl", hash = "sha256:1f7ffa05da41754e20512202c866d0ebfc440bba3b0ed15133070e20bf5aeb5f", size = 211117, upload-time = "2025-02-11T14:45:34.228Z" }, - { url = "https://files.pythonhosted.org/packages/d5/db/829185120c1686fa297294f8fcd23e0422f71070bf85ef1cc1a72ecb2930/coverage-7.6.12-cp311-cp311-win_amd64.whl", hash = "sha256:e216c5c45f89ef8971373fd1c5d8d1164b81f7f5f06bbf23c37e7908d19e8558", size = 212019, upload-time = "2025-02-11T14:45:35.724Z" }, - { url = "https://files.pythonhosted.org/packages/e2/7f/4af2ed1d06ce6bee7eafc03b2ef748b14132b0bdae04388e451e4b2c529b/coverage-7.6.12-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b172f8e030e8ef247b3104902cc671e20df80163b60a203653150d2fc204d1ad", size = 208645, upload-time = "2025-02-11T14:45:37.95Z" }, - { url = "https://files.pythonhosted.org/packages/dc/60/d19df912989117caa95123524d26fc973f56dc14aecdec5ccd7d0084e131/coverage-7.6.12-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:641dfe0ab73deb7069fb972d4d9725bf11c239c309ce694dd50b1473c0f641c3", size = 208898, upload-time = "2025-02-11T14:45:40.27Z" }, - { url = "https://files.pythonhosted.org/packages/bd/10/fecabcf438ba676f706bf90186ccf6ff9f6158cc494286965c76e58742fa/coverage-7.6.12-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0e549f54ac5f301e8e04c569dfdb907f7be71b06b88b5063ce9d6953d2d58574", size = 242987, upload-time = "2025-02-11T14:45:43.982Z" }, - { url = "https://files.pythonhosted.org/packages/4c/53/4e208440389e8ea936f5f2b0762dcd4cb03281a7722def8e2bf9dc9c3d68/coverage-7.6.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:959244a17184515f8c52dcb65fb662808767c0bd233c1d8a166e7cf74c9ea985", size = 239881, upload-time = "2025-02-11T14:45:45.537Z" }, - { url = "https://files.pythonhosted.org/packages/c4/47/2ba744af8d2f0caa1f17e7746147e34dfc5f811fb65fc153153722d58835/coverage-7.6.12-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bda1c5f347550c359f841d6614fb8ca42ae5cb0b74d39f8a1e204815ebe25750", size = 242142, upload-time = "2025-02-11T14:45:47.069Z" }, - { url = "https://files.pythonhosted.org/packages/e9/90/df726af8ee74d92ee7e3bf113bf101ea4315d71508952bd21abc3fae471e/coverage-7.6.12-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1ceeb90c3eda1f2d8c4c578c14167dbd8c674ecd7d38e45647543f19839dd6ea", size = 241437, upload-time = "2025-02-11T14:45:48.602Z" }, - { url = "https://files.pythonhosted.org/packages/f6/af/995263fd04ae5f9cf12521150295bf03b6ba940d0aea97953bb4a6db3e2b/coverage-7.6.12-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0f16f44025c06792e0fb09571ae454bcc7a3ec75eeb3c36b025eccf501b1a4c3", size = 239724, upload-time = "2025-02-11T14:45:51.333Z" }, - { url = "https://files.pythonhosted.org/packages/1c/8e/5bb04f0318805e190984c6ce106b4c3968a9562a400180e549855d8211bd/coverage-7.6.12-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b076e625396e787448d27a411aefff867db2bffac8ed04e8f7056b07024eed5a", size = 241329, upload-time = "2025-02-11T14:45:53.19Z" }, - { url = "https://files.pythonhosted.org/packages/9e/9d/fa04d9e6c3f6459f4e0b231925277cfc33d72dfab7fa19c312c03e59da99/coverage-7.6.12-cp312-cp312-win32.whl", hash = "sha256:00b2086892cf06c7c2d74983c9595dc511acca00665480b3ddff749ec4fb2a95", size = 211289, upload-time = "2025-02-11T14:45:54.74Z" }, - { url = "https://files.pythonhosted.org/packages/53/40/53c7ffe3c0c3fff4d708bc99e65f3d78c129110d6629736faf2dbd60ad57/coverage-7.6.12-cp312-cp312-win_amd64.whl", hash = "sha256:7ae6eabf519bc7871ce117fb18bf14e0e343eeb96c377667e3e5dd12095e0288", size = 212079, upload-time = "2025-02-11T14:45:57.22Z" }, - { url = "https://files.pythonhosted.org/packages/76/89/1adf3e634753c0de3dad2f02aac1e73dba58bc5a3a914ac94a25b2ef418f/coverage-7.6.12-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:488c27b3db0ebee97a830e6b5a3ea930c4a6e2c07f27a5e67e1b3532e76b9ef1", size = 208673, upload-time = "2025-02-11T14:45:59.618Z" }, - { url = "https://files.pythonhosted.org/packages/ce/64/92a4e239d64d798535c5b45baac6b891c205a8a2e7c9cc8590ad386693dc/coverage-7.6.12-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5d1095bbee1851269f79fd8e0c9b5544e4c00c0c24965e66d8cba2eb5bb535fd", size = 208945, upload-time = "2025-02-11T14:46:01.869Z" }, - { url = "https://files.pythonhosted.org/packages/b4/d0/4596a3ef3bca20a94539c9b1e10fd250225d1dec57ea78b0867a1cf9742e/coverage-7.6.12-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0533adc29adf6a69c1baa88c3d7dbcaadcffa21afbed3ca7a225a440e4744bf9", size = 242484, upload-time = "2025-02-11T14:46:03.527Z" }, - { url = "https://files.pythonhosted.org/packages/1c/ef/6fd0d344695af6718a38d0861408af48a709327335486a7ad7e85936dc6e/coverage-7.6.12-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:53c56358d470fa507a2b6e67a68fd002364d23c83741dbc4c2e0680d80ca227e", size = 239525, upload-time = "2025-02-11T14:46:05.973Z" }, - { url = "https://files.pythonhosted.org/packages/0c/4b/373be2be7dd42f2bcd6964059fd8fa307d265a29d2b9bcf1d044bcc156ed/coverage-7.6.12-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64cbb1a3027c79ca6310bf101014614f6e6e18c226474606cf725238cf5bc2d4", size = 241545, upload-time = "2025-02-11T14:46:07.79Z" }, - { url = "https://files.pythonhosted.org/packages/a6/7d/0e83cc2673a7790650851ee92f72a343827ecaaea07960587c8f442b5cd3/coverage-7.6.12-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:79cac3390bfa9836bb795be377395f28410811c9066bc4eefd8015258a7578c6", size = 241179, upload-time = "2025-02-11T14:46:11.853Z" }, - { url = "https://files.pythonhosted.org/packages/ff/8c/566ea92ce2bb7627b0900124e24a99f9244b6c8c92d09ff9f7633eb7c3c8/coverage-7.6.12-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:9b148068e881faa26d878ff63e79650e208e95cf1c22bd3f77c3ca7b1d9821a3", size = 239288, upload-time = "2025-02-11T14:46:13.411Z" }, - { url = "https://files.pythonhosted.org/packages/7d/e4/869a138e50b622f796782d642c15fb5f25a5870c6d0059a663667a201638/coverage-7.6.12-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8bec2ac5da793c2685ce5319ca9bcf4eee683b8a1679051f8e6ec04c4f2fd7dc", size = 241032, upload-time = "2025-02-11T14:46:15.005Z" }, - { url = "https://files.pythonhosted.org/packages/ae/28/a52ff5d62a9f9e9fe9c4f17759b98632edd3a3489fce70154c7d66054dd3/coverage-7.6.12-cp313-cp313-win32.whl", hash = "sha256:200e10beb6ddd7c3ded322a4186313d5ca9e63e33d8fab4faa67ef46d3460af3", size = 211315, upload-time = "2025-02-11T14:46:16.638Z" }, - { url = "https://files.pythonhosted.org/packages/bc/17/ab849b7429a639f9722fa5628364c28d675c7ff37ebc3268fe9840dda13c/coverage-7.6.12-cp313-cp313-win_amd64.whl", hash = "sha256:2b996819ced9f7dbb812c701485d58f261bef08f9b85304d41219b1496b591ef", size = 212099, upload-time = "2025-02-11T14:46:18.268Z" }, - { url = "https://files.pythonhosted.org/packages/d2/1c/b9965bf23e171d98505eb5eb4fb4d05c44efd256f2e0f19ad1ba8c3f54b0/coverage-7.6.12-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:299cf973a7abff87a30609879c10df0b3bfc33d021e1adabc29138a48888841e", size = 209511, upload-time = "2025-02-11T14:46:20.768Z" }, - { url = "https://files.pythonhosted.org/packages/57/b3/119c201d3b692d5e17784fee876a9a78e1b3051327de2709392962877ca8/coverage-7.6.12-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:4b467a8c56974bf06e543e69ad803c6865249d7a5ccf6980457ed2bc50312703", size = 209729, upload-time = "2025-02-11T14:46:22.258Z" }, - { url = "https://files.pythonhosted.org/packages/52/4e/a7feb5a56b266304bc59f872ea07b728e14d5a64f1ad3a2cc01a3259c965/coverage-7.6.12-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2458f275944db8129f95d91aee32c828a408481ecde3b30af31d552c2ce284a0", size = 253988, upload-time = "2025-02-11T14:46:23.999Z" }, - { url = "https://files.pythonhosted.org/packages/65/19/069fec4d6908d0dae98126aa7ad08ce5130a6decc8509da7740d36e8e8d2/coverage-7.6.12-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a9d8be07fb0832636a0f72b80d2a652fe665e80e720301fb22b191c3434d924", size = 249697, upload-time = "2025-02-11T14:46:25.617Z" }, - { url = "https://files.pythonhosted.org/packages/1c/da/5b19f09ba39df7c55f77820736bf17bbe2416bbf5216a3100ac019e15839/coverage-7.6.12-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14d47376a4f445e9743f6c83291e60adb1b127607a3618e3185bbc8091f0467b", size = 252033, upload-time = "2025-02-11T14:46:28.069Z" }, - { url = "https://files.pythonhosted.org/packages/1e/89/4c2750df7f80a7872267f7c5fe497c69d45f688f7b3afe1297e52e33f791/coverage-7.6.12-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b95574d06aa9d2bd6e5cc35a5bbe35696342c96760b69dc4287dbd5abd4ad51d", size = 251535, upload-time = "2025-02-11T14:46:29.818Z" }, - { url = "https://files.pythonhosted.org/packages/78/3b/6d3ae3c1cc05f1b0460c51e6f6dcf567598cbd7c6121e5ad06643974703c/coverage-7.6.12-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:ecea0c38c9079570163d663c0433a9af4094a60aafdca491c6a3d248c7432827", size = 249192, upload-time = "2025-02-11T14:46:31.563Z" }, - { url = "https://files.pythonhosted.org/packages/6e/8e/c14a79f535ce41af7d436bbad0d3d90c43d9e38ec409b4770c894031422e/coverage-7.6.12-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:2251fabcfee0a55a8578a9d29cecfee5f2de02f11530e7d5c5a05859aa85aee9", size = 250627, upload-time = "2025-02-11T14:46:33.145Z" }, - { url = "https://files.pythonhosted.org/packages/cb/79/b7cee656cfb17a7f2c1b9c3cee03dd5d8000ca299ad4038ba64b61a9b044/coverage-7.6.12-cp313-cp313t-win32.whl", hash = "sha256:eb5507795caabd9b2ae3f1adc95f67b1104971c22c624bb354232d65c4fc90b3", size = 212033, upload-time = "2025-02-11T14:46:35.79Z" }, - { url = "https://files.pythonhosted.org/packages/b6/c3/f7aaa3813f1fa9a4228175a7bd368199659d392897e184435a3b66408dd3/coverage-7.6.12-cp313-cp313t-win_amd64.whl", hash = "sha256:f60a297c3987c6c02ffb29effc70eadcbb412fe76947d394a1091a3615948e2f", size = 213240, upload-time = "2025-02-11T14:46:38.119Z" }, - { url = "https://files.pythonhosted.org/packages/7a/7f/05818c62c7afe75df11e0233bd670948d68b36cdbf2a339a095bc02624a8/coverage-7.6.12-pp39.pp310-none-any.whl", hash = "sha256:7e39e845c4d764208e7b8f6a21c541ade741e2c41afabdfa1caa28687a3c98cf", size = 200558, upload-time = "2025-02-11T14:47:00.292Z" }, - { url = "https://files.pythonhosted.org/packages/fb/b2/f655700e1024dec98b10ebaafd0cedbc25e40e4abe62a3c8e2ceef4f8f0a/coverage-7.6.12-py3-none-any.whl", hash = "sha256:eb8668cfbc279a536c633137deeb9435d2962caec279c3f8cf8b91fff6ff8953", size = 200552, upload-time = "2025-02-11T14:47:01.999Z" }, +version = "7.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/19/4f/2251e65033ed2ce1e68f00f91a0294e0f80c80ae8c3ebbe2f12828c4cd53/coverage-7.8.0.tar.gz", hash = "sha256:7a3d62b3b03b4b6fd41a085f3574874cf946cb4604d2b4d3e8dca8cd570ca501", size = 811872, upload-time = "2025-03-30T20:36:45.376Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/01/1c5e6ee4ebaaa5e079db933a9a45f61172048c7efa06648445821a201084/coverage-7.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2931f66991175369859b5fd58529cd4b73582461877ecfd859b6549869287ffe", size = 211379, upload-time = "2025-03-30T20:34:53.904Z" }, + { url = "https://files.pythonhosted.org/packages/e9/16/a463389f5ff916963471f7c13585e5f38c6814607306b3cb4d6b4cf13384/coverage-7.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:52a523153c568d2c0ef8826f6cc23031dc86cffb8c6aeab92c4ff776e7951b28", size = 211814, upload-time = "2025-03-30T20:34:56.959Z" }, + { url = "https://files.pythonhosted.org/packages/b8/b1/77062b0393f54d79064dfb72d2da402657d7c569cfbc724d56ac0f9c67ed/coverage-7.8.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c8a5c139aae4c35cbd7cadca1df02ea8cf28a911534fc1b0456acb0b14234f3", size = 240937, upload-time = "2025-03-30T20:34:58.751Z" }, + { url = "https://files.pythonhosted.org/packages/d7/54/c7b00a23150083c124e908c352db03bcd33375494a4beb0c6d79b35448b9/coverage-7.8.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a26c0c795c3e0b63ec7da6efded5f0bc856d7c0b24b2ac84b4d1d7bc578d676", size = 238849, upload-time = "2025-03-30T20:35:00.521Z" }, + { url = "https://files.pythonhosted.org/packages/f7/ec/a6b7cfebd34e7b49f844788fda94713035372b5200c23088e3bbafb30970/coverage-7.8.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:821f7bcbaa84318287115d54becb1915eece6918136c6f91045bb84e2f88739d", size = 239986, upload-time = "2025-03-30T20:35:02.307Z" }, + { url = "https://files.pythonhosted.org/packages/21/8c/c965ecef8af54e6d9b11bfbba85d4f6a319399f5f724798498387f3209eb/coverage-7.8.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a321c61477ff8ee705b8a5fed370b5710c56b3a52d17b983d9215861e37b642a", size = 239896, upload-time = "2025-03-30T20:35:04.141Z" }, + { url = "https://files.pythonhosted.org/packages/40/83/070550273fb4c480efa8381735969cb403fa8fd1626d74865bfaf9e4d903/coverage-7.8.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:ed2144b8a78f9d94d9515963ed273d620e07846acd5d4b0a642d4849e8d91a0c", size = 238613, upload-time = "2025-03-30T20:35:05.889Z" }, + { url = "https://files.pythonhosted.org/packages/07/76/fbb2540495b01d996d38e9f8897b861afed356be01160ab4e25471f4fed1/coverage-7.8.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:042e7841a26498fff7a37d6fda770d17519982f5b7d8bf5278d140b67b61095f", size = 238909, upload-time = "2025-03-30T20:35:07.76Z" }, + { url = "https://files.pythonhosted.org/packages/a3/7e/76d604db640b7d4a86e5dd730b73e96e12a8185f22b5d0799025121f4dcb/coverage-7.8.0-cp310-cp310-win32.whl", hash = "sha256:f9983d01d7705b2d1f7a95e10bbe4091fabc03a46881a256c2787637b087003f", size = 213948, upload-time = "2025-03-30T20:35:09.144Z" }, + { url = "https://files.pythonhosted.org/packages/5c/a7/f8ce4aafb4a12ab475b56c76a71a40f427740cf496c14e943ade72e25023/coverage-7.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:5a570cd9bd20b85d1a0d7b009aaf6c110b52b5755c17be6962f8ccd65d1dbd23", size = 214844, upload-time = "2025-03-30T20:35:10.734Z" }, + { url = "https://files.pythonhosted.org/packages/2b/77/074d201adb8383addae5784cb8e2dac60bb62bfdf28b2b10f3a3af2fda47/coverage-7.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e7ac22a0bb2c7c49f441f7a6d46c9c80d96e56f5a8bc6972529ed43c8b694e27", size = 211493, upload-time = "2025-03-30T20:35:12.286Z" }, + { url = "https://files.pythonhosted.org/packages/a9/89/7a8efe585750fe59b48d09f871f0e0c028a7b10722b2172dfe021fa2fdd4/coverage-7.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bf13d564d310c156d1c8e53877baf2993fb3073b2fc9f69790ca6a732eb4bfea", size = 211921, upload-time = "2025-03-30T20:35:14.18Z" }, + { url = "https://files.pythonhosted.org/packages/e9/ef/96a90c31d08a3f40c49dbe897df4f1fd51fb6583821a1a1c5ee30cc8f680/coverage-7.8.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5761c70c017c1b0d21b0815a920ffb94a670c8d5d409d9b38857874c21f70d7", size = 244556, upload-time = "2025-03-30T20:35:15.616Z" }, + { url = "https://files.pythonhosted.org/packages/89/97/dcd5c2ce72cee9d7b0ee8c89162c24972fb987a111b92d1a3d1d19100c61/coverage-7.8.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5ff52d790c7e1628241ffbcaeb33e07d14b007b6eb00a19320c7b8a7024c040", size = 242245, upload-time = "2025-03-30T20:35:18.648Z" }, + { url = "https://files.pythonhosted.org/packages/b2/7b/b63cbb44096141ed435843bbb251558c8e05cc835c8da31ca6ffb26d44c0/coverage-7.8.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d39fc4817fd67b3915256af5dda75fd4ee10621a3d484524487e33416c6f3543", size = 244032, upload-time = "2025-03-30T20:35:20.131Z" }, + { url = "https://files.pythonhosted.org/packages/97/e3/7fa8c2c00a1ef530c2a42fa5df25a6971391f92739d83d67a4ee6dcf7a02/coverage-7.8.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b44674870709017e4b4036e3d0d6c17f06a0e6d4436422e0ad29b882c40697d2", size = 243679, upload-time = "2025-03-30T20:35:21.636Z" }, + { url = "https://files.pythonhosted.org/packages/4f/b3/e0a59d8df9150c8a0c0841d55d6568f0a9195692136c44f3d21f1842c8f6/coverage-7.8.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:8f99eb72bf27cbb167b636eb1726f590c00e1ad375002230607a844d9e9a2318", size = 241852, upload-time = "2025-03-30T20:35:23.525Z" }, + { url = "https://files.pythonhosted.org/packages/9b/82/db347ccd57bcef150c173df2ade97976a8367a3be7160e303e43dd0c795f/coverage-7.8.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b571bf5341ba8c6bc02e0baeaf3b061ab993bf372d982ae509807e7f112554e9", size = 242389, upload-time = "2025-03-30T20:35:25.09Z" }, + { url = "https://files.pythonhosted.org/packages/21/f6/3f7d7879ceb03923195d9ff294456241ed05815281f5254bc16ef71d6a20/coverage-7.8.0-cp311-cp311-win32.whl", hash = "sha256:e75a2ad7b647fd8046d58c3132d7eaf31b12d8a53c0e4b21fa9c4d23d6ee6d3c", size = 213997, upload-time = "2025-03-30T20:35:26.914Z" }, + { url = "https://files.pythonhosted.org/packages/28/87/021189643e18ecf045dbe1e2071b2747901f229df302de01c998eeadf146/coverage-7.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:3043ba1c88b2139126fc72cb48574b90e2e0546d4c78b5299317f61b7f718b78", size = 214911, upload-time = "2025-03-30T20:35:28.498Z" }, + { url = "https://files.pythonhosted.org/packages/aa/12/4792669473297f7973518bec373a955e267deb4339286f882439b8535b39/coverage-7.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:bbb5cc845a0292e0c520656d19d7ce40e18d0e19b22cb3e0409135a575bf79fc", size = 211684, upload-time = "2025-03-30T20:35:29.959Z" }, + { url = "https://files.pythonhosted.org/packages/be/e1/2a4ec273894000ebedd789e8f2fc3813fcaf486074f87fd1c5b2cb1c0a2b/coverage-7.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4dfd9a93db9e78666d178d4f08a5408aa3f2474ad4d0e0378ed5f2ef71640cb6", size = 211935, upload-time = "2025-03-30T20:35:31.912Z" }, + { url = "https://files.pythonhosted.org/packages/f8/3a/7b14f6e4372786709a361729164125f6b7caf4024ce02e596c4a69bccb89/coverage-7.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f017a61399f13aa6d1039f75cd467be388d157cd81f1a119b9d9a68ba6f2830d", size = 245994, upload-time = "2025-03-30T20:35:33.455Z" }, + { url = "https://files.pythonhosted.org/packages/54/80/039cc7f1f81dcbd01ea796d36d3797e60c106077e31fd1f526b85337d6a1/coverage-7.8.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0915742f4c82208ebf47a2b154a5334155ed9ef9fe6190674b8a46c2fb89cb05", size = 242885, upload-time = "2025-03-30T20:35:35.354Z" }, + { url = "https://files.pythonhosted.org/packages/10/e0/dc8355f992b6cc2f9dcd5ef6242b62a3f73264893bc09fbb08bfcab18eb4/coverage-7.8.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a40fcf208e021eb14b0fac6bdb045c0e0cab53105f93ba0d03fd934c956143a", size = 245142, upload-time = "2025-03-30T20:35:37.121Z" }, + { url = "https://files.pythonhosted.org/packages/43/1b/33e313b22cf50f652becb94c6e7dae25d8f02e52e44db37a82de9ac357e8/coverage-7.8.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a1f406a8e0995d654b2ad87c62caf6befa767885301f3b8f6f73e6f3c31ec3a6", size = 244906, upload-time = "2025-03-30T20:35:39.07Z" }, + { url = "https://files.pythonhosted.org/packages/05/08/c0a8048e942e7f918764ccc99503e2bccffba1c42568693ce6955860365e/coverage-7.8.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:77af0f6447a582fdc7de5e06fa3757a3ef87769fbb0fdbdeba78c23049140a47", size = 243124, upload-time = "2025-03-30T20:35:40.598Z" }, + { url = "https://files.pythonhosted.org/packages/5b/62/ea625b30623083c2aad645c9a6288ad9fc83d570f9adb913a2abdba562dd/coverage-7.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f2d32f95922927186c6dbc8bc60df0d186b6edb828d299ab10898ef3f40052fe", size = 244317, upload-time = "2025-03-30T20:35:42.204Z" }, + { url = "https://files.pythonhosted.org/packages/62/cb/3871f13ee1130a6c8f020e2f71d9ed269e1e2124aa3374d2180ee451cee9/coverage-7.8.0-cp312-cp312-win32.whl", hash = "sha256:769773614e676f9d8e8a0980dd7740f09a6ea386d0f383db6821df07d0f08545", size = 214170, upload-time = "2025-03-30T20:35:44.216Z" }, + { url = "https://files.pythonhosted.org/packages/88/26/69fe1193ab0bfa1eb7a7c0149a066123611baba029ebb448500abd8143f9/coverage-7.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:e5d2b9be5b0693cf21eb4ce0ec8d211efb43966f6657807f6859aab3814f946b", size = 214969, upload-time = "2025-03-30T20:35:45.797Z" }, + { url = "https://files.pythonhosted.org/packages/f3/21/87e9b97b568e223f3438d93072479c2f36cc9b3f6b9f7094b9d50232acc0/coverage-7.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5ac46d0c2dd5820ce93943a501ac5f6548ea81594777ca585bf002aa8854cacd", size = 211708, upload-time = "2025-03-30T20:35:47.417Z" }, + { url = "https://files.pythonhosted.org/packages/75/be/882d08b28a0d19c9c4c2e8a1c6ebe1f79c9c839eb46d4fca3bd3b34562b9/coverage-7.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:771eb7587a0563ca5bb6f622b9ed7f9d07bd08900f7589b4febff05f469bea00", size = 211981, upload-time = "2025-03-30T20:35:49.002Z" }, + { url = "https://files.pythonhosted.org/packages/7a/1d/ce99612ebd58082fbe3f8c66f6d8d5694976c76a0d474503fa70633ec77f/coverage-7.8.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42421e04069fb2cbcbca5a696c4050b84a43b05392679d4068acbe65449b5c64", size = 245495, upload-time = "2025-03-30T20:35:51.073Z" }, + { url = "https://files.pythonhosted.org/packages/dc/8d/6115abe97df98db6b2bd76aae395fcc941d039a7acd25f741312ced9a78f/coverage-7.8.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:554fec1199d93ab30adaa751db68acec2b41c5602ac944bb19187cb9a41a8067", size = 242538, upload-time = "2025-03-30T20:35:52.941Z" }, + { url = "https://files.pythonhosted.org/packages/cb/74/2f8cc196643b15bc096d60e073691dadb3dca48418f08bc78dd6e899383e/coverage-7.8.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5aaeb00761f985007b38cf463b1d160a14a22c34eb3f6a39d9ad6fc27cb73008", size = 244561, upload-time = "2025-03-30T20:35:54.658Z" }, + { url = "https://files.pythonhosted.org/packages/22/70/c10c77cd77970ac965734fe3419f2c98665f6e982744a9bfb0e749d298f4/coverage-7.8.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:581a40c7b94921fffd6457ffe532259813fc68eb2bdda60fa8cc343414ce3733", size = 244633, upload-time = "2025-03-30T20:35:56.221Z" }, + { url = "https://files.pythonhosted.org/packages/38/5a/4f7569d946a07c952688debee18c2bb9ab24f88027e3d71fd25dbc2f9dca/coverage-7.8.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f319bae0321bc838e205bf9e5bc28f0a3165f30c203b610f17ab5552cff90323", size = 242712, upload-time = "2025-03-30T20:35:57.801Z" }, + { url = "https://files.pythonhosted.org/packages/bb/a1/03a43b33f50475a632a91ea8c127f7e35e53786dbe6781c25f19fd5a65f8/coverage-7.8.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04bfec25a8ef1c5f41f5e7e5c842f6b615599ca8ba8391ec33a9290d9d2db3a3", size = 244000, upload-time = "2025-03-30T20:35:59.378Z" }, + { url = "https://files.pythonhosted.org/packages/6a/89/ab6c43b1788a3128e4d1b7b54214548dcad75a621f9d277b14d16a80d8a1/coverage-7.8.0-cp313-cp313-win32.whl", hash = "sha256:dd19608788b50eed889e13a5d71d832edc34fc9dfce606f66e8f9f917eef910d", size = 214195, upload-time = "2025-03-30T20:36:01.005Z" }, + { url = "https://files.pythonhosted.org/packages/12/12/6bf5f9a8b063d116bac536a7fb594fc35cb04981654cccb4bbfea5dcdfa0/coverage-7.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:a9abbccd778d98e9c7e85038e35e91e67f5b520776781d9a1e2ee9d400869487", size = 214998, upload-time = "2025-03-30T20:36:03.006Z" }, + { url = "https://files.pythonhosted.org/packages/2a/e6/1e9df74ef7a1c983a9c7443dac8aac37a46f1939ae3499424622e72a6f78/coverage-7.8.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:18c5ae6d061ad5b3e7eef4363fb27a0576012a7447af48be6c75b88494c6cf25", size = 212541, upload-time = "2025-03-30T20:36:04.638Z" }, + { url = "https://files.pythonhosted.org/packages/04/51/c32174edb7ee49744e2e81c4b1414ac9df3dacfcb5b5f273b7f285ad43f6/coverage-7.8.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:95aa6ae391a22bbbce1b77ddac846c98c5473de0372ba5c463480043a07bff42", size = 212767, upload-time = "2025-03-30T20:36:06.503Z" }, + { url = "https://files.pythonhosted.org/packages/e9/8f/f454cbdb5212f13f29d4a7983db69169f1937e869a5142bce983ded52162/coverage-7.8.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e013b07ba1c748dacc2a80e69a46286ff145935f260eb8c72df7185bf048f502", size = 256997, upload-time = "2025-03-30T20:36:08.137Z" }, + { url = "https://files.pythonhosted.org/packages/e6/74/2bf9e78b321216d6ee90a81e5c22f912fc428442c830c4077b4a071db66f/coverage-7.8.0-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d766a4f0e5aa1ba056ec3496243150698dc0481902e2b8559314368717be82b1", size = 252708, upload-time = "2025-03-30T20:36:09.781Z" }, + { url = "https://files.pythonhosted.org/packages/92/4d/50d7eb1e9a6062bee6e2f92e78b0998848a972e9afad349b6cdde6fa9e32/coverage-7.8.0-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad80e6b4a0c3cb6f10f29ae4c60e991f424e6b14219d46f1e7d442b938ee68a4", size = 255046, upload-time = "2025-03-30T20:36:11.409Z" }, + { url = "https://files.pythonhosted.org/packages/40/9e/71fb4e7402a07c4198ab44fc564d09d7d0ffca46a9fb7b0a7b929e7641bd/coverage-7.8.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b87eb6fc9e1bb8f98892a2458781348fa37e6925f35bb6ceb9d4afd54ba36c73", size = 256139, upload-time = "2025-03-30T20:36:13.86Z" }, + { url = "https://files.pythonhosted.org/packages/49/1a/78d37f7a42b5beff027e807c2843185961fdae7fe23aad5a4837c93f9d25/coverage-7.8.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:d1ba00ae33be84066cfbe7361d4e04dec78445b2b88bdb734d0d1cbab916025a", size = 254307, upload-time = "2025-03-30T20:36:16.074Z" }, + { url = "https://files.pythonhosted.org/packages/58/e9/8fb8e0ff6bef5e170ee19d59ca694f9001b2ec085dc99b4f65c128bb3f9a/coverage-7.8.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f3c38e4e5ccbdc9198aecc766cedbb134b2d89bf64533973678dfcf07effd883", size = 255116, upload-time = "2025-03-30T20:36:18.033Z" }, + { url = "https://files.pythonhosted.org/packages/56/b0/d968ecdbe6fe0a863de7169bbe9e8a476868959f3af24981f6a10d2b6924/coverage-7.8.0-cp313-cp313t-win32.whl", hash = "sha256:379fe315e206b14e21db5240f89dc0774bdd3e25c3c58c2c733c99eca96f1ada", size = 214909, upload-time = "2025-03-30T20:36:19.644Z" }, + { url = "https://files.pythonhosted.org/packages/87/e9/d6b7ef9fecf42dfb418d93544af47c940aa83056c49e6021a564aafbc91f/coverage-7.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:2e4b6b87bb0c846a9315e3ab4be2d52fac905100565f4b92f02c445c8799e257", size = 216068, upload-time = "2025-03-30T20:36:21.282Z" }, + { url = "https://files.pythonhosted.org/packages/c4/f1/1da77bb4c920aa30e82fa9b6ea065da3467977c2e5e032e38e66f1c57ffd/coverage-7.8.0-pp39.pp310.pp311-none-any.whl", hash = "sha256:b8194fb8e50d556d5849753de991d390c5a1edeeba50f68e3a9253fbd8bf8ccd", size = 203443, upload-time = "2025-03-30T20:36:41.959Z" }, + { url = "https://files.pythonhosted.org/packages/59/f1/4da7717f0063a222db253e7121bd6a56f6fb1ba439dcc36659088793347c/coverage-7.8.0-py3-none-any.whl", hash = "sha256:dbf364b4c5e7bae9250528167dfe40219b62e2d573c854d74be213e1e52069f7", size = 203435, upload-time = "2025-03-30T20:36:43.61Z" }, ] [package.optional-dependencies] @@ -2754,15 +2754,15 @@ wheels = [ [[package]] name = "pytest-cov" -version = "6.0.0" +version = "6.1.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "coverage", extra = ["toml"] }, { name = "pytest" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/be/45/9b538de8cef30e17c7b45ef42f538a94889ed6a16f2387a6c89e73220651/pytest-cov-6.0.0.tar.gz", hash = "sha256:fde0b595ca248bb8e2d76f020b465f3b107c9632e6a1d1705f17834c89dcadc0", size = 66945, upload-time = "2024-10-29T20:13:35.363Z" } +sdist = { url = "https://files.pythonhosted.org/packages/25/69/5f1e57f6c5a39f81411b550027bf72842c4567ff5fd572bed1edc9e4b5d9/pytest_cov-6.1.1.tar.gz", hash = "sha256:46935f7aaefba760e716c2ebfbe1c216240b9592966e7da99ea8292d4d3e2a0a", size = 66857, upload-time = "2025-04-05T14:07:51.592Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/36/3b/48e79f2cd6a61dbbd4807b4ed46cb564b4fd50a76166b1c4ea5c1d9e2371/pytest_cov-6.0.0-py3-none-any.whl", hash = "sha256:eee6f1b9e61008bd34975a4d5bab25801eb31898b032dd55addc93e96fcaaa35", size = 22949, upload-time = "2024-10-29T20:13:33.215Z" }, + { url = "https://files.pythonhosted.org/packages/28/d0/def53b4a790cfb21483016430ed828f64830dd981ebe1089971cd10cab25/pytest_cov-6.1.1-py3-none-any.whl", hash = "sha256:bddf29ed2d0ab6f4df17b4c55b0a657287db8684af9c42ea546b21b1041b3dde", size = 23841, upload-time = "2025-04-05T14:07:49.641Z" }, ] [[package]] @@ -3605,10 +3605,12 @@ wheels = [ name = "soundscapy" source = { editable = "." } dependencies = [ + { name = "coverage" }, { name = "loguru" }, { name = "numpy" }, { name = "pandas", extra = ["excel"] }, { name = "pydantic" }, + { name = "pytest-cov" }, { name = "pyyaml" }, { name = "scipy" }, { name = "seaborn" }, @@ -3672,12 +3674,14 @@ test = [ [package.metadata] requires-dist = [ { name = "acoustic-toolbox", marker = "extra == 'audio'", specifier = ">=0.1.2" }, + { name = "coverage", specifier = "==7.8.0" }, { name = "loguru", specifier = ">=0.7.2" }, { name = "mosqito", marker = "extra == 'audio'", specifier = ">=1.2.1" }, { name = "numba", marker = "extra == 'audio'", specifier = ">=0.59" }, { name = "numpy", specifier = "!=1.26" }, { name = "pandas", extras = ["excel"], specifier = ">=2.2.2" }, { name = "pydantic", specifier = ">=2.8.2" }, + { name = "pytest-cov", specifier = "==6.1.1" }, { name = "pyyaml", specifier = ">=6.0.2" }, { name = "rpy2", marker = "extra == 'spi'", specifier = ">=3.5.0" }, { name = "scikit-maad", marker = "extra == 'audio'", specifier = ">=1.4.3" }, From c5cefabc7a0e8b460d67a28eef48ff54480e6577 Mon Sep 17 00:00:00 2001 From: Andrew Mitchell Date: Sat, 10 May 2025 03:40:54 +0100 Subject: [PATCH 3/8] Add refactored plotting module with ISOPlot and layers Introduce a new plotting module with a cleaner architecture based on composition. Includes the main `ISOPlot` class, various visualization layers, and supporting constants, enabling flexible and modular circumplex plotting functionality. Signed-off-by: Andrew Mitchell --- src/soundscapy/plotting/new/README.md | 87 ++ src/soundscapy/plotting/new/__init__.py | 76 ++ src/soundscapy/plotting/new/constants.py | 46 + src/soundscapy/plotting/new/iso_plot.py | 797 +++++++++++++++++ src/soundscapy/plotting/new/layer.py | 434 ++++++++++ src/soundscapy/plotting/new/managers.py | 798 ++++++++++++++++++ .../plotting/new/parameter_models.py | 377 +++++++++ src/soundscapy/plotting/new/plot_context.py | 254 ++++++ src/soundscapy/plotting/new/protocols.py | 144 ++++ 9 files changed, 3013 insertions(+) create mode 100644 src/soundscapy/plotting/new/README.md create mode 100644 src/soundscapy/plotting/new/__init__.py create mode 100644 src/soundscapy/plotting/new/constants.py create mode 100644 src/soundscapy/plotting/new/iso_plot.py create mode 100644 src/soundscapy/plotting/new/layer.py create mode 100644 src/soundscapy/plotting/new/managers.py create mode 100644 src/soundscapy/plotting/new/parameter_models.py create mode 100644 src/soundscapy/plotting/new/plot_context.py create mode 100644 src/soundscapy/plotting/new/protocols.py diff --git a/src/soundscapy/plotting/new/README.md b/src/soundscapy/plotting/new/README.md new file mode 100644 index 0000000..9517558 --- /dev/null +++ b/src/soundscapy/plotting/new/README.md @@ -0,0 +1,87 @@ +# Refactored Plotting Module + +This directory contains a refactored implementation of the plotting functionality in soundscapy. The refactoring focuses on: + +1. Using composition instead of inheritance +2. Simplifying the relationships between components +3. Consolidating parameter models +4. Improving type safety + +## Architecture + +The new architecture consists of the following components: + +### Core Components + +- **ISOPlot**: The main entry point for creating plots. Uses composition to delegate functionality to specialized managers. +- **PlotContext**: Manages data, state, and parameters for a plot or subplot. The central component that owns parameter models. +- **Layer**: Base class for visualization layers. Layers know how to render themselves onto a PlotContext's axes. + +### Managers + +- **LayerManager**: Manages the creation and rendering of visualization layers. +- **StyleManager**: Manages the styling of plots. +- **SubplotManager**: Manages the creation and configuration of subplots. + +### Parameter Models + +- **BaseParams**: Base model for all parameter types. +- **AxisParams**: Parameters for axis configuration. +- **SeabornParams**: Base parameters for seaborn plotting functions. +- **ScatterParams**: Parameters for scatter plot functions. +- **DensityParams**: Parameters for density plot functions. +- **SimpleDensityParams**: Parameters for simple density plots. +- **SPISeabornParams**: Base parameters for SPI plotting functions. +- **SPISimpleDensityParams**: Parameters for SPI simple density plots. +- **StyleParams**: Configuration options for styling circumplex plots. +- **SubplotsParams**: Parameters for subplot configuration. + +### Layer Types + +- **ScatterLayer**: Layer for rendering scatter plots. +- **DensityLayer**: Layer for rendering density plots. +- **SimpleDensityLayer**: Layer for rendering simple density plots. +- **SPISimpleLayer**: Layer for rendering SPI simple density plots. + +### Protocols + +- **RenderableLayer**: Protocol defining what a renderable layer must implement. +- **ParameterProvider**: Protocol defining how parameters are provided. +- **ParamModel**: Protocol defining the interface for parameter models. +- **PlotContext**: Protocol defining the interface for plot contexts. + +## Usage + +The new implementation maintains the same public API as the original, so existing code should continue to work with minimal changes: + +```python +from soundscapy.plotting.new import ISOPlot + +# Create a plot +plot = ISOPlot(data=data, hue="LocationID") + +# Add layers +plot.create_subplots() +plot.add_scatter() +plot.add_density() +plot.apply_styling() + +# Show the plot +plot.show() +``` + +## Benefits of the New Architecture + +1. **Clearer Separation of Concerns**: Each component has a well-defined responsibility. +2. **Reduced Coupling**: Components are less tightly coupled, making the code more maintainable. +3. **Improved Type Safety**: Better use of type hints and protocols for structural typing. +4. **More Flexible Composition**: Easier to extend with new layer types and customize behavior. +5. **Reduced Duplication**: Single source of truth for parameters. +6. **Simplified Testing**: Components can be tested in isolation. + +## Implementation Notes + +- The refactored code is in a separate directory to avoid breaking existing code. +- The parameter models use Pydantic for validation, maintaining the type safety of the original implementation. +- The layer system has been simplified, with a focus on using parameters from the context. +- The ISOPlot class uses composition instead of inheritance, delegating functionality to specialized managers. \ No newline at end of file diff --git a/src/soundscapy/plotting/new/__init__.py b/src/soundscapy/plotting/new/__init__.py new file mode 100644 index 0000000..366891f --- /dev/null +++ b/src/soundscapy/plotting/new/__init__.py @@ -0,0 +1,76 @@ +""" +Refactored plotting module for soundscapy. + +This module provides a refactored implementation of the plotting functionality +in soundscapy, using composition instead of inheritance and with a cleaner +architecture. The main entry point is the ISOPlot class. + +Examples +-------- +Create a simple scatter plot: + +>>> import pandas as pd +>>> import numpy as np +>>> from soundscapy.plotting.new import ISOPlot +>>> # Create some sample data +>>> rng = np.random.default_rng(42) +>>> data = pd.DataFrame( +... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), +... columns=['ISOPleasant', 'ISOEventful'] +... ) +>>> # Create a plot with multiple layers +>>> plot = (ISOPlot(data=data) +... .add_scatter() +... .add_simple_density(fill=False) +... .apply_styling( +... xlim=(-1, 1), +... ylim=(-1, 1), +... primary_lines=True +... )) +>>> isinstance(plot, ISOPlot) +True +>>> # plot.show() # Uncomment to display the plot + +""" + +from soundscapy.plotting.new.iso_plot import ISOPlot +from soundscapy.plotting.new.layer import ( + DensityLayer, + Layer, + ScatterLayer, + SimpleDensityLayer, + SPISimpleLayer, +) +from soundscapy.plotting.new.parameter_models import ( + BaseParams, + DensityParams, + ScatterParams, + SimpleDensityParams, + SPISeabornParams, + SPISimpleDensityParams, + StyleParams, + SubplotsParams, +) +from soundscapy.plotting.new.plot_context import PlotContext + +__all__ = [ + # Parameter models + "BaseParams", + "DensityLayer", + "DensityParams", + # Main plotting class + "ISOPlot", + # Layer classes + "Layer", + # Context + "PlotContext", + "SPISeabornParams", + "SPISimpleDensityParams", + "SPISimpleLayer", + "ScatterLayer", + "ScatterParams", + "SimpleDensityLayer", + "SimpleDensityParams", + "StyleParams", + "SubplotsParams", +] diff --git a/src/soundscapy/plotting/new/constants.py b/src/soundscapy/plotting/new/constants.py new file mode 100644 index 0000000..321b95f --- /dev/null +++ b/src/soundscapy/plotting/new/constants.py @@ -0,0 +1,46 @@ +""" +Constants for soundscape plotting functions. + +This module provides common constants used for various plot types, +including default column names, limits, and other configuration values. +These constants are used by the parameter models to provide default values. +""" + +# Basic defaults +DEFAULT_XCOL = "ISOPleasant" +DEFAULT_YCOL = "ISOEventful" +DEFAULT_XLIM = (-1, 1) +DEFAULT_YLIM = (-1, 1) + +DEFAULT_FIGSIZE = (5, 5) +DEFAULT_POINT_SIZE = 20 +DEFAULT_BW_ADJUST = 1.2 + +DEFAULT_COLOR = "#0173B2" # First color from colorblind palette + +RECOMMENDED_MIN_SAMPLES = 30 + +# Default font settings for axis labels +DEFAULT_FONTDICT = { + "family": "sans-serif", + "fontstyle": "normal", + "fontsize": "large", + "fontweight": "medium", + "parse_math": True, + "c": "black", + "alpha": 1, +} + +# Default SPI text settings +DEFAULT_SPI_TEXT_KWARGS = { + "x": 0, + "y": -0.85, + "fontsize": 10, + "bbox": { + "facecolor": "white", + "edgecolor": "black", + "boxstyle": "round,pad=0.3", + }, + "ha": "center", + "va": "center", +} \ No newline at end of file diff --git a/src/soundscapy/plotting/new/iso_plot.py b/src/soundscapy/plotting/new/iso_plot.py new file mode 100644 index 0000000..dde50ca --- /dev/null +++ b/src/soundscapy/plotting/new/iso_plot.py @@ -0,0 +1,797 @@ +""" +Main plotting class for the soundscapy package. + +This module provides the ISOPlot class, which is the main entry point for creating +circumplex plots. The class uses composition instead of inheritance to delegate +functionality to specialized manager classes. + +Examples +-------- +Create a simple scatter plot: + +>>> import pandas as pd +>>> import numpy as np +>>> from soundscapy.plotting.new import ISOPlot +>>> # Create some sample data +>>> rng = np.random.default_rng(42) +>>> data = pd.DataFrame( +... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), +... columns=['ISOPleasant', 'ISOEventful'] +... ) +>>> # Create a plot and add a scatter layer +>>> plot = ISOPlot(data=data) +>>> plot.add_scatter() +>>> plot.apply_styling() +>>> isinstance(plot, ISOPlot) +True + +Create a plot with subplots and multiple layers: + +>>> # Add a group column to the data +>>> data['Group'] = rng.integers(1, 3, 100) +>>> # Create a plot with subplots by group +>>> plot = (ISOPlot(data=data, hue='Group') +... .add_scatter() +... .add_simple_density(fill=False) +... .apply_styling()) +>>> isinstance(plot, ISOPlot) +True + +""" + +from __future__ import annotations + +import functools +from typing import TYPE_CHECKING, Any, Literal + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from soundscapy.plotting.new.constants import ( + DEFAULT_XCOL, + DEFAULT_YCOL, +) +from soundscapy.plotting.new.layer import ( + Layer, +) +from soundscapy.plotting.new.managers import ( + LayerManager, + StyleManager, + SubplotManager, +) +from soundscapy.plotting.new.parameter_models import ( + SubplotsParams, +) +from soundscapy.plotting.new.plot_context import PlotContext +from soundscapy.sspylogging import get_logger + +if TYPE_CHECKING: + from matplotlib.axes import Axes + from matplotlib.figure import Figure + +logger = get_logger() + + +class ISOPlot: + """ + A class for creating circumplex plots using different visualization layers. + + This class provides methods for creating scatter plots, density plots, and other + visualizations based on the circumplex model of soundscape perception. It uses + composition to delegate functionality to specialized manager classes. + + Attributes + ---------- + main_context : PlotContext + The main plot context + figure : Figure | None + The matplotlib figure + axes : Axes | np.ndarray | None + The matplotlib axes + subplot_contexts : list[PlotContext] + List of subplot contexts + subplots_params : SubplotsParams + Parameters for subplot configuration + layers : LayerManager + Manager for layer-related functionality + styling : StyleManager + Manager for styling-related functionality + subplots : SubplotManager + Manager for subplot-related functionality + + Examples + -------- + Create a plot with default parameters: + + >>> import pandas as pd + >>> import numpy as np + >>> rng = np.random.default_rng(42) + >>> plot = ISOPlot() + >>> isinstance(plot, ISOPlot) + True + + Create a plot with a DataFrame: + + >>> data = pd.DataFrame( + ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... columns=['ISOPleasant', 'ISOEventful'] + ... ) + >>> plot = ISOPlot(data=data) + >>> plot.x + 'ISOPleasant' + >>> plot.y + 'ISOEventful' + + Create a plot with a DataFrame and hue: + + >>> data['Group'] = rng.integers(1, 3, 100) + >>> plot = ISOPlot(data=data, hue='Group') + >>> plot.hue + 'Group' + + Create a plot directly with arrays: + + >>> x, y = rng.multivariate_normal([0, 0], [[1, 0], [0, 1]], 100).T + >>> plot = ISOPlot(x=x, y=y) + >>> isinstance(plot, ISOPlot) + True + >>> plot.main_context.data is not None + True + + """ + + def __init__( + self, + data: pd.DataFrame | None = None, + x: str | np.ndarray | pd.Series | None = DEFAULT_XCOL, + y: str | np.ndarray | pd.Series | None = DEFAULT_YCOL, + title: str | None = "Soundscape Plot", + hue: str | None = None, + palette: str | list | dict | None = "colorblind", + figure: Figure | None = None, + axes: Axes | np.ndarray | None = None, + ) -> None: + """ + Initialize an ISOPlot instance. + + Parameters + ---------- + data : pd.DataFrame | None, optional + The data to be plotted, by default None + x : str | np.ndarray | pd.Series | None, optional + Column name or data for x-axis, by default DEFAULT_XCOL + y : str | np.ndarray | pd.Series | None, optional + Column name or data for y-axis, by default DEFAULT_YCOL + title : str | None, optional + Title of the plot, by default "Soundscape Plot" + hue : str | None, optional + Column name for color encoding, by default None + palette : str | list | dict | None, optional + Color palette to use, by default "colorblind" + figure : Figure | None, optional + Existing figure to plot on, by default None + axes : Axes | np.ndarray | None, optional + Existing axes to plot on, by default None + + """ + # Process and validate input data and coordinates + data, x, y = self._check_data_x_y(data, x, y) + self._check_data_hue(data, hue) + + # Initialize the main plot context + self.main_context = PlotContext( + data=data, + x=x if isinstance(x, str) else DEFAULT_XCOL, + y=y if isinstance(y, str) else DEFAULT_YCOL, + hue=hue, + title=title, + ) + + # Store additional plot attributes + self.figure = figure + self.axes = axes + self.palette = palette + + # Initialize subplot management + self.subplot_contexts: list[PlotContext] = [] + self.subplots_params = SubplotsParams() + + # Initialize managers using composition + self.layers = LayerManager(self) + self.styling = StyleManager(self) + self.subplots = SubplotManager(self) + + @property + def x(self) -> str: + """Get the x-axis column name.""" + return self.main_context.x + + @property + def y(self) -> str: + """Get the y-axis column name.""" + return self.main_context.y + + @property + def hue(self) -> str | None: + """Get the hue column name.""" + return self.main_context.hue + + @property + def title(self) -> str | None: + """Get the plot title.""" + return self.main_context.title + + @staticmethod + def _check_data_x_y( + data: pd.DataFrame | None, + x: str | pd.Series | np.ndarray | None, + y: str | pd.Series | np.ndarray | None, + ) -> tuple[ + pd.DataFrame | None, str | np.ndarray | pd.Series, str | np.ndarray | pd.Series + ]: + """ + Process and validate input data and coordinates. + + Parameters + ---------- + data : pd.DataFrame | None + The data to be plotted + x : str | pd.Series | np.ndarray | None + Column name or data for x-axis + y : str | pd.Series | np.ndarray | None + Column name or data for y-axis + + Returns + ------- + tuple[pd.DataFrame | None, str | np.ndarray | pd.Series, str | np.ndarray | pd.Series] + Processed data, x, and y + + """ # noqa: E501 + # Case 1: x and y are arrays/series, data is None + if ( + isinstance(x, (np.ndarray, pd.Series)) + and isinstance(y, (np.ndarray, pd.Series)) + and data is None + ): + # Create a DataFrame from x and y + data = pd.DataFrame( + { + DEFAULT_XCOL: x, + DEFAULT_YCOL: y, + } + ) + x = DEFAULT_XCOL + y = DEFAULT_YCOL + return data, x, y + + # Case 2: data is provided, x and y are column names + if data is not None: + # Ensure x and y are column names if they're strings + if isinstance(x, str) and x not in data.columns: + msg = f"Column '{x}' not found in data" + raise ValueError(msg) + if isinstance(y, str) and y not in data.columns: + msg = f"Column '{y}' not found in data" + raise ValueError(msg) + return data, x or DEFAULT_XCOL, y or DEFAULT_YCOL + + # Case 3: No data provided, use default column names + return None, x or DEFAULT_XCOL, y or DEFAULT_YCOL + + @staticmethod + def _check_data_hue(data: pd.DataFrame | None, hue: str | None) -> None: + """ + Check if the hue column exists in the data. + + Parameters + ---------- + data : pd.DataFrame | None + The data to be plotted + hue : str | None + Column name for color encoding + + Raises + ------ + ValueError + If the hue column is not found in the data + + """ + if data is not None and hue is not None and hue not in data.columns: + msg = f"Hue column '{hue}' not found in data" + raise ValueError(msg) + + # Convenience methods that delegate to managers + + def create_subplots( + self, + nrows: int = 1, + ncols: int = 1, + figsize: tuple[float, float] | None = None, + subplot_by: str | None = None, + *, + sharex: bool | Literal["none", "all", "row", "col"] = True, + sharey: bool | Literal["none", "all", "row", "col"] = True, + **kwargs: Any, + ) -> ISOPlot: + """ + Create a grid of subplots. + + Parameters + ---------- + nrows : int, optional + Number of rows, by default 1 + ncols : int, optional + Number of columns, by default 1 + figsize : tuple[float, float] | None, optional + Figure size, by default None + subplot_by : str | None, optional + Column to create subplots by, by default None + * + sharex : bool | Literal["none", "all", "row", "col"], optional + Whether to share x-axis, by default True + sharey : bool | Literal["none", "all", "row", "col"], optional + Whether to share y-axis, by default True + **kwargs : Any + Additional parameters for subplots + + Returns + ------- + ISOPlot + The current plot instance for chaining + + Examples + -------- + Create a 2x2 grid of subplots: + + >>> import pandas as pd + >>> import numpy as np + >>> rng = np.random.default_rng(42) + >>> data = pd.DataFrame( + ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... columns=['ISOPleasant', 'ISOEventful'] + ... ) + >>> plot = ISOPlot(data=data) + >>> plot = plot.create_subplots(nrows=2,ncols=2) + >>> len(plot.subplot_contexts) + 4 + + Create subplots by a grouping variable: + + >>> data['Group'] = rng.integers(1, 3, 100) + >>> plot = ISOPlot(data=data) + >>> plot = plot.create_subplots(subplot_by='Group') + >>> len(plot.subplot_contexts) + 2 + >>> plot.subplot_contexts[0].title is not None + True + + """ + return self.subplots.create_subplots( + nrows=nrows, + ncols=ncols, + figsize=figsize, + sharex=sharex, + sharey=sharey, + subplot_by=subplot_by, + **kwargs, + ) + + def add_scatter( + self, + data: pd.DataFrame | None = None, + *, + on_axis: int | tuple[int, int] | list[int] | None = None, + **params: Any, + ) -> ISOPlot: + """ + Add a scatter layer to the plot. + + Parameters + ---------- + data : pd.DataFrame | None, optional + Custom data for this layer, by default None + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes, by default None + **params : Any + Additional parameters for the scatter layer + + Returns + ------- + ISOPlot + The current plot instance for chaining + + Examples + -------- + Add a scatter layer to a plot: + + >>> import pandas as pd + >>> import numpy as np + >>> rng = np.random.default_rng(42) + >>> data = pd.DataFrame( + ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... columns=['ISOPleasant', 'ISOEventful'] + ... ) + >>> plot = ISOPlot(data=data) + >>> plot = plot.add_scatter() + >>> len(plot.main_context.layers) + 1 + + Add a scatter layer with custom parameters: + + >>> plot = ISOPlot(data=data) + >>> plot = plot.add_scatter(s=50, alpha=0.5, color='red') + >>> len(plot.main_context.layers) + 1 + + Add a scatter layer to a specific subplot: + + >>> plot = ISOPlot(data=data) + >>> plot = plot.create_subplots(nrows=2,ncols=2) + >>> plot = plot.add_scatter(on_axis=0) + >>> len(plot.subplot_contexts[0].layers) + 1 + >>> len(plot.subplot_contexts[1].layers) + 0 + + """ + return self.layers.add_scatter( + data=data, + on_axis=on_axis, + **params, + ) + + def add_density( + self, + data: pd.DataFrame | None = None, + *, + on_axis: int | tuple[int, int] | list[int] | None = None, + **params: Any, + ) -> ISOPlot: + """ + Add a density layer to the plot. + + Parameters + ---------- + data : pd.DataFrame | None, optional + Custom data for this layer, by default None + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes, by default None + **params : Any + Additional parameters for the density layer + + Returns + ------- + ISOPlot + The current plot instance for chaining + + Examples + -------- + Add a density layer to a plot: + + >>> import pandas as pd + >>> import numpy as np + >>> rng = np.random.default_rng(42) + >>> data = pd.DataFrame( + ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... columns=['ISOPleasant', 'ISOEventful'] + ... ) + >>> plot = ISOPlot(data=data) + >>> plot = plot.add_density() + >>> len(plot.main_context.layers) + 1 + + Add a density layer with custom parameters: + + >>> plot = ISOPlot(data=data) + >>> plot = plot.add_density(fill=False, levels=5, alpha=0.7) + >>> len(plot.main_context.layers) + 1 + + Add a density layer to a specific subplot: + + >>> plot = ISOPlot(data=data) + >>> plot = plot.create_subplots(nrows=2,ncols=2) + >>> plot = plot.add_density(on_axis=1) + >>> len(plot.subplot_contexts[0].layers) + 0 + >>> len(plot.subplot_contexts[1].layers) + 1 + + """ + return self.layers.add_density( + data=data, + on_axis=on_axis, + **params, + ) + + def add_simple_density( + self, + data: pd.DataFrame | None = None, + *, + on_axis: int | tuple[int, int] | list[int] | None = None, + **params: Any, + ) -> ISOPlot: + """ + Add a simple density layer to the plot. + + Parameters + ---------- + data : pd.DataFrame | None, optional + Custom data for this layer, by default None + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes, by default None + **params : Any + Additional parameters for the simple density layer + + Returns + ------- + ISOPlot + The current plot instance for chaining + + Examples + -------- + Add a simple density layer to a plot: + + >>> import pandas as pd + >>> import numpy as np + >>> rng = np.random.default_rng(42) + >>> data = pd.DataFrame( + ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... columns=['ISOPleasant', 'ISOEventful'] + ... ) + >>> plot = ISOPlot(data=data) + >>> plot = plot.add_simple_density() + >>> len(plot.main_context.layers) + 1 + + Add a simple density layer with custom parameters: + + >>> plot = ISOPlot(data=data) + >>> plot = plot.add_simple_density(fill=False, thresh=0.3) + >>> len(plot.main_context.layers) + 1 + + Add a simple density layer to multiple subplots: + + >>> plot = ISOPlot(data=data) + >>> plot = plot.create_subplots(nrows=2,ncols=2) + >>> plot = plot.add_simple_density(on_axis=[0, 2]) + >>> len(plot.subplot_contexts[0].layers) + 1 + >>> len(plot.subplot_contexts[1].layers) + 0 + >>> len(plot.subplot_contexts[2].layers) + 1 + + """ + return self.layers.add_simple_density( + data=data, + on_axis=on_axis, + **params, + ) + + def add_spi_simple( + self, + data: pd.DataFrame | None = None, + *, + on_axis: int | tuple[int, int] | list[int] | None = None, + **params: Any, + ) -> ISOPlot: + """ + Add an SPI simple layer to the plot. + + Parameters + ---------- + data : pd.DataFrame | None, optional + Custom data for this layer, by default None + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes, by default None + **params : Any + Additional parameters for the SPI simple layer + + Returns + ------- + ISOPlot + The current plot instance for chaining + + """ + return self.layers.add_spi_simple( + data=data, + on_axis=on_axis, + **params, + ) + + def add_layer( + self, + layer_class: type[Layer], + data: pd.DataFrame | None = None, + *, + on_axis: int | tuple[int, int] | list[int] | None = None, + **params: Any, + ) -> ISOPlot: + """ + Add a visualization layer, optionally targeting specific subplot(s). + + Parameters + ---------- + layer_class : Layer subclass + The type of layer to add + data : pd.DataFrame | None, optional + Custom data for this layer, by default None + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes, by default None + **params : Any + Additional parameters for the layer + + Returns + ------- + ISOPlot + The current plot instance for chaining + + Examples + -------- + Add a scatter layer using the generic add_layer method: + + >>> import pandas as pd + >>> import numpy as np + >>> from soundscapy.plotting.new import ScatterLayer + >>> rng = np.random.default_rng(42) + >>> data = pd.DataFrame( + ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... columns=['ISOPleasant', 'ISOEventful'] + ... ) + >>> plot = ISOPlot(data=data) + >>> plot = plot.add_layer(ScatterLayer) + >>> len(plot.main_context.layers) + 1 + >>> isinstance(plot.main_context.layers[0], ScatterLayer) + True + + Add a density layer to a specific subplot: + + >>> from soundscapy.plotting.new import DensityLayer + >>> plot = ISOPlot(data=data) + >>> plot = plot.create_subplots(nrows=2,ncols=2) + >>> plot = plot.add_layer(DensityLayer, on_axis=3, fill=False) + >>> len(plot.subplot_contexts[3].layers) + 1 + >>> isinstance(plot.subplot_contexts[3].layers[0], DensityLayer) + True + + Add a layer with custom data: + + >>> custom_data = pd.DataFrame({ + ... 'ISOPleasant': rng.normal(0.5, 0.1, 50), + ... 'ISOEventful': rng.normal(0.5, 0.1, 50), + ... }) + >>> plot = ISOPlot(data=data) + >>> plot = plot.add_layer(ScatterLayer, data=custom_data, color='red') + >>> len(plot.main_context.layers) + 1 + + """ + return self.layers.add_layer( + layer_class=layer_class, + data=data, + on_axis=on_axis, + **params, + ) + + def apply_styling( + self, + *, + on_axis: int | tuple[int, int] | list[int] | None = None, + **style_params: Any, + ) -> ISOPlot: + """ + Apply styling to the plot. + + Parameters + ---------- + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes, by default None + **style_params : Any + Additional styling parameters + + Returns + ------- + ISOPlot + The current plot instance for chaining + + Examples + -------- + Apply default styling to a plot: + + >>> import pandas as pd + >>> import numpy as np + >>> rng = np.random.default_rng(42) + >>> data = pd.DataFrame( + ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... columns=['ISOPleasant', 'ISOEventful'] + ... ) + >>> plot = ISOPlot(data=data) + >>> plot = plot.add_scatter() + >>> plot = plot.apply_styling() + >>> isinstance(plot, ISOPlot) + True + + Apply custom styling to a plot: + + >>> plot = ISOPlot(data=data) + >>> plot = plot.add_scatter() + >>> plot = plot.apply_styling( + ... xlim=(-2, 2), + ... ylim=(-2, 2), + ... xlabel="Pleasant", + ... ylabel="Eventful", + ... primary_lines=True, + ... diagonal_lines=True + ... ) + >>> isinstance(plot, ISOPlot) + True + + Apply styling to a specific subplot: + + >>> plot = ISOPlot(data=data) + >>> plot = plot.create_subplots(nrows=2,ncols=2) + >>> plot = plot.add_scatter() + >>> plot = plot.apply_styling(on_axis=0, title="Subplot 0") + >>> isinstance(plot, ISOPlot) + True + + """ + return self.styling.apply_styling( + on_axis=on_axis, + **style_params, + ) + + @functools.wraps(plt.show) + def show(self) -> None: + """ + Display the plot. + + Examples + -------- + Create and show a plot: + + >>> import pandas as pd + >>> import numpy as np + >>> rng = np.random.default_rng(42) + >>> data = pd.DataFrame( + ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... columns=['ISOPleasant', 'ISOEventful'] + ... ) + >>> plot = ISOPlot(data=data) + >>> plot.add_scatter() + >>> plot.apply_styling() + >>> # plot.show() # Uncomment to display the plot + + """ + plt.show() + + @functools.wraps(plt.close) + def close(self) -> None: + """ + Close the plot. + + Examples + -------- + Create, show, and close a plot: + + >>> import pandas as pd + >>> import numpy as np + >>> rng = np.random.default_rng(42) + >>> data = pd.DataFrame( + ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... columns=['ISOPleasant', 'ISOEventful'] + ... ) + >>> plot = ISOPlot(data=data) + >>> plot.add_scatter() + >>> plot.apply_styling() + >>> # plot.show() # Uncomment to display the plot + >>> plot.close() # Close the plot + + """ + if self.figure is not None: + plt.close(self.figure) diff --git a/src/soundscapy/plotting/new/layer.py b/src/soundscapy/plotting/new/layer.py new file mode 100644 index 0000000..60b1ce1 --- /dev/null +++ b/src/soundscapy/plotting/new/layer.py @@ -0,0 +1,434 @@ +""" +Layer classes for visualization. + +This module provides the base Layer class and specialized layer implementations +for different visualization techniques. Layers know how to render themselves onto +a PlotContext's axes using parameters provided by the context. +""" + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Any, ClassVar, cast + +import seaborn as sns + +from soundscapy.plotting.new.constants import RECOMMENDED_MIN_SAMPLES +from soundscapy.sspylogging import get_logger + +if TYPE_CHECKING: + import pandas as pd + from matplotlib.axes import Axes + + from soundscapy.plotting.new.parameter_models import ( + BaseParams, + DensityParams, + ScatterParams, + SimpleDensityParams, + SPISimpleDensityParams, + ) + from soundscapy.plotting.new.protocols import PlotContext + +logger = get_logger() + + +class Layer: + """ + Base class for all visualization layers. + + A Layer encapsulates a specific visualization technique. Layers know how to + render themselves onto a PlotContext's axes using parameters provided by the context. + + Attributes + ---------- + custom_data : pd.DataFrame | None + Optional custom data for this specific layer, overriding context data + param_overrides : dict[str, Any] + Parameter overrides for this layer + + """ + + # Class registry for layer types + _layer_registry: ClassVar[dict[str, type[Layer]]] = {} + + # Parameter type this layer uses (for getting params from context) + param_type: ClassVar[str] = "base" + + def __init__( + self, + custom_data: pd.DataFrame | None = None, + **param_overrides: Any, + ) -> None: + """ + Initialize a Layer. + + Parameters + ---------- + custom_data : pd.DataFrame | None + Optional custom data for this specific layer, overriding context data + **param_overrides : dict[str, Any] + Parameter overrides for this layer + + """ + self.custom_data = custom_data + self.param_overrides = param_overrides + + def __init_subclass__(cls, **kwargs: Any) -> None: + """Register subclasses in the registry.""" + super().__init_subclass__(**kwargs) + # Skip registration for the base class + if cls is not Layer: + cls._layer_registry[cls.__name__.lower()] = cls + + def render(self, context: PlotContext) -> None: + """ + Render this layer on the given context. + + Parameters + ---------- + context : PlotContext + The context containing data and axes for rendering + + Raises + ------ + ValueError + If the context has no associated axes or data + + """ + if context.ax is None: + msg = "Cannot render layer: context has no associated axes" + raise ValueError(msg) + + # Use custom data if provided, otherwise context data + data = self.custom_data if self.custom_data is not None else context.data + + if data is None: + msg = "No data available for rendering layer" + raise ValueError(msg) + + # Get parameters from context and apply overrides + params = self._get_params_from_context(context) + + # Render the layer + self._render_implementation(data, context, context.ax, params) + + def _get_params_from_context(self, context: PlotContext) -> BaseParams: + """ + Get parameters from context and apply overrides. + + Parameters + ---------- + context : PlotContext + The context to get parameters from + + Returns + ------- + BaseParams + The parameters for this layer + + """ + # Get parameters from context based on layer type + params = context.get_params_for_layer(type(self)) + + # Apply overrides + if self.param_overrides: + params.update(**self.param_overrides) + + return cast("BaseParams", params) + + def _render_implementation( + self, + data: pd.DataFrame, + context: PlotContext, + ax: Axes, + params: BaseParams, + ) -> None: + """ + Implement actual rendering (to be overridden by subclasses). + + Parameters + ---------- + data : pd.DataFrame + The data to render + context : PlotContext + The context containing state for rendering + ax : Axes + The matplotlib axes to render on + params : BaseParams + The parameters for this layer + + Raises + ------ + NotImplementedError + If not implemented by subclass + + """ + msg = "Subclasses must implement _render_implementation" + raise NotImplementedError(msg) + + @classmethod + def create( + cls, context: PlotContext, layer_type: str | None = None, **kwargs: Any + ) -> Layer: + """ + Factory method to create a layer of the specified type. + + Parameters + ---------- + context : PlotContext + The context to associate with the layer + layer_type : str | None + The type of layer to create (e.g., 'scatter', 'density') + If None, uses the class name + **kwargs : Any + Additional parameters for the layer + + Returns + ------- + Layer + The created layer instance + + Raises + ------ + ValueError + If the layer type is unknown + + """ # noqa: D401 + if layer_type is None: + # Use the current class if no type specified + return cls(context=context, **kwargs) + + # Get the layer class from the registry + layer_type = layer_type.lower() + if layer_type not in cls._layer_registry: + msg = f"Unknown layer type: {layer_type}" + raise ValueError(msg) + + # Create and return the layer + layer_class = cls._layer_registry[layer_type] + return layer_class(**kwargs) + + +class ScatterLayer(Layer): + """Layer for rendering scatter plots.""" + + param_type = "scatter" + + def _render_implementation( + self, + data: pd.DataFrame, + context: PlotContext, + ax: Axes, + params: BaseParams, + ) -> None: + """ + Render a scatter plot. + + Parameters + ---------- + data : pd.DataFrame + The data to render + context : PlotContext + The context containing state for rendering + ax : Axes + The matplotlib axes to render on + params : BaseParams + The parameters for this layer + + """ + # Cast params to the correct type + scatter_params = cast("ScatterParams", params) + + # Create a copy of the parameters with data + kwargs = scatter_params.as_seaborn_kwargs() + kwargs["data"] = data + + # Ensure x and y are set correctly + kwargs["x"] = context.x + kwargs["y"] = context.y + + # Render the scatter plot + sns.scatterplot(ax=ax, **kwargs) + + +class DensityLayer(Layer): + """Layer for rendering density plots.""" + + param_type = "density" + + def _render_implementation( + self, + data: pd.DataFrame, + context: PlotContext, + ax: Axes, + params: BaseParams, + ) -> None: + """ + Render a density plot. + + Parameters + ---------- + data : pd.DataFrame + The data to render + context : PlotContext + The context containing state for rendering + ax : Axes + The matplotlib axes to render on + params : BaseParams + The parameters for this layer + + """ + # Check if we have enough data for a density plot + if len(data) < RECOMMENDED_MIN_SAMPLES: + warnings.warn( + "Density plots are not recommended for " + f"small datasets (<{RECOMMENDED_MIN_SAMPLES} samples).", + UserWarning, + stacklevel=2, + ) + + # Cast params to the correct type + density_params = cast("DensityParams", params) + + # Create a copy of the parameters with data + kwargs = density_params.as_seaborn_kwargs() + kwargs["data"] = data + + # Ensure x and y are set correctly + kwargs["x"] = context.x + kwargs["y"] = context.y + + # Render the density plot + sns.kdeplot(ax=ax, **kwargs) + + +class SimpleDensityLayer(DensityLayer): + """Layer for rendering simple density plots (filled contours).""" + + param_type = "simple_density" + + def _render_implementation( + self, + data: pd.DataFrame, + context: PlotContext, + ax: Axes, + params: BaseParams, + ) -> None: + """ + Render a simple density plot. + + Parameters + ---------- + data : pd.DataFrame + The data to render + context : PlotContext + The context containing state for rendering + ax : Axes + The matplotlib axes to render on + params : BaseParams + The parameters for this layer + + """ + # Check if we have enough data for a density plot + if len(data) < RECOMMENDED_MIN_SAMPLES: + warnings.warn( + "Density plots are not recommended for " + f"small datasets (<{RECOMMENDED_MIN_SAMPLES} samples).", + UserWarning, + stacklevel=2, + ) + + # Cast params to the correct type + simple_density_params = cast("SimpleDensityParams", params) + + # Create a copy of the parameters with data + kwargs = simple_density_params.as_seaborn_kwargs() + kwargs["data"] = data + + # Ensure x and y are set correctly + kwargs["x"] = context.x + kwargs["y"] = context.y + + # Set specific parameters for simple density + kwargs["levels"] = simple_density_params.levels + kwargs["thresh"] = getattr(simple_density_params, "thresh", 0.05) + + # Render the simple density plot + sns.kdeplot(ax=ax, **kwargs) + + +class SPISimpleLayer(SimpleDensityLayer): + """Layer for rendering SPI simple density plots.""" + + param_type = "spi_simple_density" + + def _render_implementation( + self, + data: pd.DataFrame, + context: PlotContext, + ax: Axes, + params: BaseParams, + ) -> None: + """ + Render an SPI simple density plot. + + Parameters + ---------- + data : pd.DataFrame + The data to render + context : PlotContext + The context containing state for rendering + ax : Axes + The matplotlib axes to render on + params : BaseParams + The parameters for this layer + + """ + # Cast params to the correct type + spi_params = cast("SPISimpleDensityParams", params) + + # Create a copy of the parameters with data + kwargs = spi_params.as_seaborn_kwargs() + kwargs["data"] = data + + # Ensure x and y are set correctly + kwargs["x"] = context.x + kwargs["y"] = context.y + + # Set specific parameters for SPI simple density + kwargs["color"] = spi_params.color + kwargs["label"] = spi_params.label + + # Render the SPI simple density plot + sns.kdeplot(ax=ax, **kwargs) + + # Add SPI score text if needed + if hasattr(spi_params, "show_score") and spi_params.show_score: + self._add_spi_score_text(context, ax, spi_params) + + def _add_spi_score_text( + self, context: PlotContext, ax: Axes, params: SPISimpleDensityParams + ) -> None: + """ + Add SPI score text to the plot. + + Parameters + ---------- + context : PlotContext + The context containing state for rendering + ax : Axes + The matplotlib axes to render on + params : SPISimpleDensityParams + The parameters for this layer + + """ + # This is a simplified version - in a real implementation, + # we would calculate and display the actual SPI score + if params.show_score == "under title": + # Add text under the title + if context.title: + ax.set_title(f"{context.title}\nSPI Score: 0.75") + elif params.show_score == "on axis" and params.axis_text_kw: + # Add text on the axis + text_kwargs = params.axis_text_kw.copy() + ax.text(s="SPI Score: 0.75", transform=ax.transAxes, **text_kwargs) diff --git a/src/soundscapy/plotting/new/managers.py b/src/soundscapy/plotting/new/managers.py new file mode 100644 index 0000000..484f900 --- /dev/null +++ b/src/soundscapy/plotting/new/managers.py @@ -0,0 +1,798 @@ +""" +Manager classes for the plotting module. + +This module provides manager classes that encapsulate specific functionality +for the ISOPlot class. These managers replace the mixin-based approach in the +original implementation, using composition instead of inheritance. +""" + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Any, Literal, cast + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from matplotlib.axes import Axes + +from soundscapy.plotting.new.constants import ( + RECOMMENDED_MIN_SAMPLES, +) +from soundscapy.plotting.new.layer import ( + DensityLayer, + Layer, + ScatterLayer, + SimpleDensityLayer, + SPISimpleLayer, +) +from soundscapy.plotting.new.protocols import RenderableLayer +from soundscapy.sspylogging import get_logger + +if TYPE_CHECKING: + from soundscapy.plotting.new.plot_context import PlotContext + +logger = get_logger() + + +class LayerManager: + """ + Manages the creation and rendering of visualization layers. + + This class encapsulates the layer-related functionality that was previously + implemented as a mixin in the ISOPlot class. + + Attributes + ---------- + plot : Any + The parent plot instance + + """ + + def __init__(self, plot: Any) -> None: + """ + Initialize a LayerManager. + + Parameters + ---------- + plot : Any + The parent plot instance + + """ + self.plot = plot + + def add_scatter( + self, + data: pd.DataFrame | None = None, + *, + on_axis: int | tuple[int, int] | list[int] | None = None, + **params: Any, + ) -> Any: + """ + Add a scatter layer to the plot. + + Parameters + ---------- + data : pd.DataFrame | None, optional + Custom data for this layer, by default None + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes, by default None + **params : Any + Additional parameters for the scatter layer + + Returns + ------- + Any + The parent plot instance for chaining + + """ + return self.add_layer( + ScatterLayer, + data=data, + on_axis=on_axis, + **params, + ) + + def add_density( + self, + data: pd.DataFrame | None = None, + *, + on_axis: int | tuple[int, int] | list[int] | None = None, + **params: Any, + ) -> Any: + """ + Add a density layer to the plot. + + Parameters + ---------- + data : pd.DataFrame | None, optional + Custom data for this layer, by default None + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes, by default None + **params : Any + Additional parameters for the density layer + + Returns + ------- + Any + The parent plot instance for chaining + + """ + # Check if we have enough data for a density plot + plot_data = data if data is not None else self.plot.main_context.data + if plot_data is not None and len(plot_data) < RECOMMENDED_MIN_SAMPLES: + warnings.warn( + "Density plots are not recommended for " + f"small datasets (<{RECOMMENDED_MIN_SAMPLES} samples).", + UserWarning, + stacklevel=2, + ) + + return self.add_layer( + DensityLayer, + data=data, + on_axis=on_axis, + **params, + ) + + def add_simple_density( + self, + data: pd.DataFrame | None = None, + *, + on_axis: int | tuple[int, int] | list[int] | None = None, + **params: Any, + ) -> Any: + """ + Add a simple density layer to the plot. + + Parameters + ---------- + data : pd.DataFrame | None, optional + Custom data for this layer, by default None + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes, by default None + **params : Any + Additional parameters for the simple density layer + + Returns + ------- + Any + The parent plot instance for chaining + + """ + return self.add_layer( + SimpleDensityLayer, + data=data, + on_axis=on_axis, + **params, + ) + + def add_spi_simple( + self, + data: pd.DataFrame | None = None, + *, + on_axis: int | tuple[int, int] | list[int] | None = None, + **params: Any, + ) -> Any: + """ + Add an SPI simple layer to the plot. + + Parameters + ---------- + data : pd.DataFrame | None, optional + Custom data for this layer, by default None + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes, by default None + **params : Any + Additional parameters for the SPI simple layer + + Returns + ------- + Any + The parent plot instance for chaining + + """ + return self.add_layer( + SPISimpleLayer, + data=data, + on_axis=on_axis, + **params, + ) + + def add_layer( + self, + layer_class: type[RenderableLayer], + data: pd.DataFrame | None = None, + *, + on_axis: int | tuple[int, int] | list[int] | None = None, + **params: Any, + ) -> Any: + """ + Add a visualization layer, optionally targeting specific subplot(s). + + Parameters + ---------- + layer_class : type[RenderableLayer] + The type of layer to add + data : pd.DataFrame | None, optional + Custom data for this layer, by default None + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes, by default None + **params : Any + Additional parameters for the layer + + Returns + ------- + Any + The parent plot instance for chaining + + """ + # Create the layer instance + layer = cast("Layer", layer_class(custom_data=data, **params)) + + # Check if we have axes to render on + self._check_for_axes() + + # If no subplots created yet, add to main context + if not self.plot.subplot_contexts: + if self.plot.main_context.ax is None: + # Get the single axis and assign it to main context + if isinstance(self.plot.axes, Axes): + self.plot.main_context.ax = self.plot.axes + elif isinstance(self.plot.axes, np.ndarray) and self.plot.axes.size > 0: + self.plot.main_context.ax = self.plot.axes.flatten()[0] + + # Add layer to main context + self.plot.main_context.layers.append(layer) + # Render the layer immediately + layer.render(self.plot.main_context) + return self.plot + + # Handle various axis targeting options + target_contexts = self._resolve_target_contexts(on_axis) + + # Add the layer to each target context and render it + for context in target_contexts: + context.layers.append(layer) + layer.render(context) + + return self.plot + + def _check_for_axes(self) -> None: + """ + Check if we have axes to render on, create if needed. + + This method ensures that the plot has axes to render on, + creating them if necessary. + """ + if self.plot.figure is None: + # Create a new figure and axes + self.plot.figure, self.plot.axes = plt.subplots(figsize=(5, 5)) + + def _resolve_target_contexts( + self, on_axis: int | tuple[int, int] | list[int] | None + ) -> list[PlotContext]: + """ + Resolve which subplot contexts to target based on axis specification. + + Parameters + ---------- + on_axis : int | tuple[int, int] | list[int] | None + The axis specification: + - None: All subplot contexts + - int: Single subplot at flattened index + - tuple[int, int]: Subplot at (row, col) + - list[int]: Multiple subplots at specified indices + + Returns + ------- + list[PlotContext] + List of target subplot contexts + + """ + # If no specific axis, target all subplot contexts + if on_axis is None: + return self.plot.subplot_contexts + + # Convert axis specification to list of indices + indices = self._resolve_axis_indices(on_axis) + + # Get the contexts for each valid index + target_contexts = [] + for idx in indices: + if 0 <= idx < len(self.plot.subplot_contexts): + target_contexts.append(self.plot.subplot_contexts[idx]) + else: + msg = f"Subplot index {idx} out of range" + raise IndexError(msg) + + return target_contexts + + def _resolve_axis_indices( + self, on_axis: int | tuple[int, int] | list[int] + ) -> list[int]: + """ + Convert axis specification to list of indices. + + Parameters + ---------- + on_axis : int | tuple[int, int] | list[int] + The axis specification to resolve + + Returns + ------- + list[int] + List of flattened indices + + Raises + ------ + ValueError + If an invalid axis specification is provided + + """ + if isinstance(on_axis, int): + return [on_axis] + if isinstance(on_axis, tuple) and len(on_axis) == 2: + # Convert (row, col) to flattened index + row, col = on_axis + return [row * self.plot.subplots_params.ncols + col] + if isinstance(on_axis, list): + return on_axis + msg = f"Invalid axis specification: {on_axis}" + raise ValueError(msg) + + +class StyleManager: + """ + Manages the styling of plots. + + This class encapsulates the styling-related functionality that was previously + implemented as a mixin in the ISOPlot class. + + Attributes + ---------- + plot : Any + The parent plot instance + + """ + + def __init__(self, plot: Any) -> None: + """ + Initialize a StyleManager. + + Parameters + ---------- + plot : Any + The parent plot instance + + """ + self.plot = plot + + def apply_styling( + self, + *, + on_axis: int | tuple[int, int] | list[int] | None = None, + **style_params: Any, + ) -> Any: + """ + Apply styling to the plot. + + Parameters + ---------- + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes, by default None + **style_params : Any + Additional styling parameters + + Returns + ------- + Any + The parent plot instance for chaining + + """ + # Update style parameters + self.plot.main_context.update_params("style", **style_params) + + # If no subplots, apply to main context + if not self.plot.subplot_contexts: + self._apply_styling_to_context(self.plot.main_context) + return self.plot + + # Apply to specified subplots + target_contexts = self._resolve_target_contexts(on_axis) + for context in target_contexts: + self._apply_styling_to_context(context) + + return self.plot + + def _apply_styling_to_context(self, context: PlotContext) -> None: + """ + Apply styling to a specific context. + + Parameters + ---------- + context : PlotContext + The context to apply styling to + + """ + if context.ax is None: + return + + # Get style parameters + style_params = context.get_params("style") + + # Apply styling to the axes + ax = context.ax + + # Set limits + if hasattr(style_params, "xlim"): + ax.set_xlim(style_params.xlim) + if hasattr(style_params, "ylim"): + ax.set_ylim(style_params.ylim) + + # Set labels + if hasattr(style_params, "xlabel") and style_params.xlabel is not False: + xlabel = style_params.xlabel or context.x + ax.set_xlabel(xlabel, fontdict=style_params.prim_ax_fontdict) + + if hasattr(style_params, "ylabel") and style_params.ylabel is not False: + ylabel = style_params.ylabel or context.y + ax.set_ylabel(ylabel, fontdict=style_params.prim_ax_fontdict) + + # Set title + if context.title: + ax.set_title(context.title, fontsize=style_params.title_fontsize) + + # Add primary axes lines + if hasattr(style_params, "primary_lines") and style_params.primary_lines: + self._add_primary_lines(ax, style_params) + + # Add diagonal lines + if hasattr(style_params, "diagonal_lines") and style_params.diagonal_lines: + self._add_diagonal_lines(ax, style_params) + + # Add legend if needed + if hasattr(style_params, "legend_loc") and style_params.legend_loc: + ax.legend(loc=style_params.legend_loc) + + def _add_primary_lines(self, ax: Axes, style_params: Any) -> None: + """ + Add primary axes lines to the plot. + + Parameters + ---------- + ax : Axes + The axes to add lines to + style_params : Any + The style parameters + + """ + # Add horizontal and vertical lines at 0 + ax.axhline( + y=0, + color="black", + linestyle="-", + linewidth=style_params.linewidth, + zorder=style_params.prim_lines_zorder, + ) + ax.axvline( + x=0, + color="black", + linestyle="-", + linewidth=style_params.linewidth, + zorder=style_params.prim_lines_zorder, + ) + + def _add_diagonal_lines(self, ax: Axes, style_params: Any) -> None: + """ + Add diagonal lines to the plot. + + Parameters + ---------- + ax : Axes + The axes to add lines to + style_params : Any + The style parameters + + """ + # Add diagonal lines + xlim = ax.get_xlim() + ylim = ax.get_ylim() + + # Diagonal line from bottom-left to top-right + ax.plot( + xlim, + ylim, + color="black", + linestyle="--", + linewidth=style_params.linewidth, + zorder=style_params.diag_lines_zorder, + ) + + # Diagonal line from bottom-right to top-left + ax.plot( + xlim, + ylim[::-1], + color="black", + linestyle="--", + linewidth=style_params.linewidth, + zorder=style_params.diag_lines_zorder, + ) + + def _resolve_target_contexts( + self, on_axis: int | tuple[int, int] | list[int] | None + ) -> list[PlotContext]: + """ + Resolve which subplot contexts to target based on axis specification. + + Parameters + ---------- + on_axis : int | tuple[int, int] | list[int] | None + The axis specification: + - None: All subplot contexts + - int: Single subplot at flattened index + - tuple[int, int]: Subplot at (row, col) + - list[int]: Multiple subplots at specified indices + + Returns + ------- + list[PlotContext] + List of target subplot contexts + + """ + # If no specific axis, target all subplot contexts + if on_axis is None: + return self.plot.subplot_contexts + + # Convert axis specification to list of indices + indices = self._resolve_axis_indices(on_axis) + + # Get the contexts for each valid index + target_contexts = [] + for idx in indices: + if 0 <= idx < len(self.plot.subplot_contexts): + target_contexts.append(self.plot.subplot_contexts[idx]) + else: + msg = f"Subplot index {idx} out of range" + raise IndexError(msg) + + return target_contexts + + def _resolve_axis_indices( + self, on_axis: int | tuple[int, int] | list[int] + ) -> list[int]: + """ + Convert axis specification to list of indices. + + Parameters + ---------- + on_axis : int | tuple[int, int] | list[int] + The axis specification to resolve + + Returns + ------- + list[int] + List of flattened indices + + Raises + ------ + ValueError + If an invalid axis specification is provided + + """ + if isinstance(on_axis, int): + return [on_axis] + if isinstance(on_axis, tuple) and len(on_axis) == 2: + # Convert (row, col) to flattened index + row, col = on_axis + return [row * self.plot.subplots_params.ncols + col] + if isinstance(on_axis, list): + return on_axis + msg = f"Invalid axis specification: {on_axis}" + raise ValueError(msg) + + +class SubplotManager: + """ + Manages the creation and configuration of subplots. + + This class encapsulates the subplot-related functionality that was previously + implemented directly in the ISOPlot class. + + Attributes + ---------- + plot : Any + The parent plot instance + + """ + + def __init__(self, plot: Any) -> None: + """ + Initialize a SubplotManager. + + Parameters + ---------- + plot : Any + The parent plot instance + + """ + self.plot = plot + + def create_subplots( + self, + nrows: int = 1, + ncols: int = 1, + figsize: tuple[float, float] | None = None, + sharex: bool | Literal["none", "all", "row", "col"] = True, + sharey: bool | Literal["none", "all", "row", "col"] = True, + subplot_by: str | None = None, + **kwargs: Any, + ) -> Any: + """ + Create a grid of subplots. + + Parameters + ---------- + nrows : int, optional + Number of rows, by default 1 + ncols : int, optional + Number of columns, by default 1 + figsize : tuple[float, float] | None, optional + Figure size, by default None + sharex : bool | Literal["none", "all", "row", "col"], optional + Whether to share x-axis, by default True + sharey : bool | Literal["none", "all", "row", "col"], optional + Whether to share y-axis, by default True + subplot_by : str | None, optional + Column to create subplots by, by default None + **kwargs : Any + Additional parameters for subplots + + Returns + ------- + Any + The parent plot instance for chaining + + """ + # Update subplot parameters + self.plot.subplots_params.update( + nrows=nrows, + ncols=ncols, + figsize=figsize or (5 * ncols, 5 * nrows), + sharex=sharex, + sharey=sharey, + subplot_by=subplot_by, + **kwargs, + ) + + # Create figure and axes + self.plot.figure, self.plot.axes = plt.subplots( + **self.plot.subplots_params.as_plt_subplots_args() + ) + + # Create subplot contexts + self._create_subplot_contexts() + + return self.plot + + def _create_subplot_contexts(self) -> None: + """ + Create subplot contexts based on the current configuration. + + This method creates a PlotContext for each subplot, either with + the same data or with data split by a grouping variable. + """ + # Clear existing subplot contexts + self.plot.subplot_contexts = [] + + # Get subplot parameters + params = self.plot.subplots_params + + # If subplot_by is specified, create subplots by group + if params.subplot_by and self.plot.main_context.data is not None: + self._create_subplots_by_group() + return + + # Otherwise, create a grid of subplots with the same data + axes = self.plot.axes + if not isinstance(axes, np.ndarray): + axes = np.array([[axes]]) + + # Create a context for each axis + for i in range(params.nrows): + for j in range(params.ncols): + # Get the axis for this subplot + ax = ( + axes[i, j] + if params.nrows > 1 and params.ncols > 1 + else axes[i] + if params.nrows > 1 + else axes[j] + if params.ncols > 1 + else axes + ) + + # Create a title for this subplot + title = ( + f"Subplot {i * params.ncols + j + 1}" + if self.plot.main_context.title is None + else f"{self.plot.main_context.title} {i * params.ncols + j + 1}" + ) + + # Create a child context for this subplot + context = self.plot.main_context.create_child( + ax=ax, + title=title, + ) + + # Add to subplot contexts + self.plot.subplot_contexts.append(context) + + def _create_subplots_by_group(self) -> None: + """ + Create subplots by grouping the data. + + This method creates a subplot for each unique value in the + subplot_by column of the data. + """ + # Get subplot parameters + params = self.plot.subplots_params + subplot_by = params.subplot_by + + if subplot_by is None or self.plot.main_context.data is None: + return + + # Get unique values in the subplot_by column + data = self.plot.main_context.data + groups = data[subplot_by].unique() + + # Limit to the number of subplots if specified + if params.n_subplots_by > 0: + groups = groups[: params.n_subplots_by] + + # Check if we have enough subplots + if len(groups) > params.n_subplots: + msg = f"Not enough subplots for all groups: {len(groups)} groups, {params.n_subplots} subplots" + raise ValueError(msg) + + # Get axes array + axes = self.plot.axes + if not isinstance(axes, np.ndarray): + axes = np.array([[axes]]) + + # Create a context for each group + for i, group in enumerate(groups): + # Get the row and column for this subplot + row = i // params.ncols + col = i % params.ncols + + # Get the axis for this subplot + ax = ( + axes[row, col] + if params.nrows > 1 and params.ncols > 1 + else axes[row] + if params.nrows > 1 + else axes[col] + if params.ncols > 1 + else axes + ) + + # Filter data for this group + group_data = data[data[subplot_by] == group] + + # Create a title for this subplot + title = ( + f"{group}" + if self.plot.main_context.title is None + else f"{self.plot.main_context.title}: {group}" + ) + + # Create a child context for this subplot + context = self.plot.main_context.create_child( + data=group_data, + ax=ax, + title=title, + ) + + # Add to subplot contexts + self.plot.subplot_contexts.append(context) diff --git a/src/soundscapy/plotting/new/parameter_models.py b/src/soundscapy/plotting/new/parameter_models.py new file mode 100644 index 0000000..118985c --- /dev/null +++ b/src/soundscapy/plotting/new/parameter_models.py @@ -0,0 +1,377 @@ +""" +Parameter models for the plotting module. + +This module provides Pydantic models for parameter validation and management. +These models replace the dictionary-based defaults in the original implementation +and provide a single source of truth for parameter values with proper type validation. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal, Self, TypeAlias + +import numpy as np +import pandas as pd +from matplotlib.colors import Colormap +from pydantic import BaseModel, ConfigDict, Field +from pydantic.alias_generators import to_snake + +from soundscapy.plotting.new.constants import ( + DEFAULT_BW_ADJUST, + DEFAULT_COLOR, + DEFAULT_FONTDICT, + DEFAULT_POINT_SIZE, + DEFAULT_SPI_TEXT_KWARGS, + DEFAULT_XCOL, + DEFAULT_XLIM, + DEFAULT_YCOL, + DEFAULT_YLIM, +) +from soundscapy.sspylogging import get_logger + +if TYPE_CHECKING: + from matplotlib.typing import ColorType + +logger = get_logger() + +# Type aliases +SeabornPaletteType: TypeAlias = str | list | dict | Colormap + +MplLegendLocType: TypeAlias = ( + Literal[ + "best", + "upper right", + "upper left", + "lower left", + "lower right", + "right", + "center left", + "center right", + "lower center", + "upper center", + "center", + ] + | tuple[float, float] +) + + +class BaseParams(BaseModel): + """ + Base model for all parameter types. + + This class provides common configuration settings and utility methods + for all parameter models. + """ + + model_config = ConfigDict( + extra="allow", # Allow extra fields for flexibility + arbitrary_types_allowed=True, # Allow complex matplotlib types + validate_assignment=True, # Validate when attributes are set + alias_generator=to_snake, # Use snake_case for aliases + ) + + def update( + self, + *, + extra: Literal["allow", "forbid", "ignore"] = "allow", + na_rm: bool = True, + **kwargs: Any, + ) -> Self: + """ + Update parameters with new values. + + Parameters + ---------- + extra : Literal["allow", "forbid", "ignore"], optional + Controls how extra fields are handled. Default is "allow". + na_rm : bool, optional + Whether to remove None values. Default is True. + **kwargs : Any + New parameter values + + Returns + ------- + Self + The updated parameter instance (for chaining) + + """ + if extra == "forbid": + # Forbid extra fields + unknown_keys = set(kwargs) - set(self.model_fields) + if unknown_keys: + msg = f"Unknown parameters: {unknown_keys}" + raise ValueError(msg) + elif extra == "ignore": + # Ignore extra fields + kwargs = {k: v for k, v in kwargs.items() if k in self.model_fields} + elif extra != "allow": + msg = f"Invalid value for 'extra': {extra}" + raise ValueError(msg) + + # Remove None values if na_rm is True + if na_rm: + # Filter out None values to avoid overriding defaults with None + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except ValueError as e: # noqa: PERF203 + logger.warning("Invalid value for %s: %s", key, e) + + return self + + def as_dict(self, **kwargs) -> dict[str, Any]: + """ + Get all parameters as a dictionary. + + Returns + ------- + Dict[str, Any] + Dictionary of parameter values + + """ + return self.model_dump(**kwargs) + + def get_changed_params(self) -> dict[str, Any]: + """ + Get parameters that have been changed from their defaults. + + Returns + ------- + Dict[str, Any] + Dictionary of changed parameters + + """ + return self.model_dump(exclude_unset=True) + + +class AxisParams(BaseParams): + """Parameters for axis configuration.""" + + x: str = DEFAULT_XCOL + y: str = DEFAULT_YCOL + xlim: tuple[float, float] = DEFAULT_XLIM + ylim: tuple[float, float] = DEFAULT_YLIM + xlabel: str | Literal[False] | None = r"$P_{ISO}$" + ylabel: str | Literal[False] | None = r"$E_{ISO}$" + + +class SeabornParams(BaseParams): + """Base parameters for seaborn plotting functions.""" + + data: pd.DataFrame | None = None + x: str | np.ndarray | pd.Series | None = DEFAULT_XCOL + y: str | np.ndarray | pd.Series | None = DEFAULT_YCOL + palette: SeabornPaletteType | None = "colorblind" + alpha: float = 0.8 + color: ColorType | None = DEFAULT_COLOR + zorder: float = 3 + hue: str | np.ndarray | pd.Series | None = None + + def crosscheck_palette_hue(self) -> None: + """ + Check if the palette is valid for the given hue. + + Sets palette to None if hue is None. + """ + self.palette = self.palette if self.hue is not None else None + + def as_seaborn_kwargs(self) -> dict[str, Any]: + """ + Convert parameters to kwargs compatible with seaborn functions. + + Returns + ------- + dict[str, Any] + Dictionary of parameter values suitable for seaborn plotting functions. + + """ + return self.as_dict() + + +class ScatterParams(SeabornParams): + """Parameters for scatter plot functions.""" + + s: float | None = DEFAULT_POINT_SIZE + + +class DensityParams(SeabornParams): + """Parameters for density plot functions.""" + + fill: bool = True + common_norm: bool = False + common_grid: bool = False + bw_adjust: float = DEFAULT_BW_ADJUST + levels: int | tuple[float, ...] = 10 + clip: tuple[tuple[float, float], tuple[float, float]] | None = ( + DEFAULT_XLIM, + DEFAULT_YLIM, + ) + + def to_outline( + self, + *, + alpha: float = 1, + fill: bool = False, + ) -> Self: + """ + Get parameters for the outline of density plots. + + Parameters + ---------- + alpha : float, optional + The alpha value for the outline. Default is 1. + fill : bool, optional + Whether to fill the outline. Default is False. + + Returns + ------- + Self + The parameters for the outline of density plots. + + """ + return self.model_copy(update={"alpha": alpha, "fill": fill, "legend": False}) + + +class SimpleDensityParams(DensityParams): + """Parameters for simple density plots.""" + + # Override default levels for simple density plots + thresh: float = 0.5 + levels: int | tuple[float, ...] = 2 + alpha: float = 0.5 + + +class SPISeabornParams(SeabornParams): + """Base parameters for seaborn plotting functions for SPI data.""" + + color: ColorType | None = "red" + hue: str | np.ndarray | pd.Series | None = None + palette: SeabornPaletteType | None = None + label: str = "SPI" + n: int = 1000 + show_score: Literal["on axis", "under title"] = "under title" + axis_text_kw: dict[str, Any] | None = Field( + default_factory=lambda: DEFAULT_SPI_TEXT_KWARGS.copy() + ) + + def as_seaborn_kwargs(self) -> dict[str, Any]: + """ + Convert parameters to kwargs compatible with seaborn functions. + + Returns + ------- + dict[str, Any] + Dictionary of parameter values suitable for seaborn plotting functions. + + """ + new = self.model_copy() + # Drop SPI-specific parameters + return new.as_dict(exclude={"n", "show_score", "axis_text_kw"}) + + +class SPISimpleDensityParams(SimpleDensityParams): + """Parameters for simple density plotting of SPI data.""" + + color: ColorType | None = "red" + label: str = "SPI" + n: int = 1000 + show_score: Literal["on axis", "under title"] = "under title" + axis_text_kw: dict[str, Any] | None = Field( + default_factory=lambda: DEFAULT_SPI_TEXT_KWARGS.copy() + ) + + def as_seaborn_kwargs(self) -> dict[str, Any]: + """ + Convert parameters to kwargs compatible with seaborn functions. + + Returns + ------- + dict[str, Any] + Dictionary of parameter values suitable for seaborn plotting functions. + + """ + return self.as_dict(exclude={"n", "show_score", "axis_text_kw"}) + + +class JointPlotParams(BaseParams): + """Parameters for jointplot functions.""" + + data: pd.DataFrame | None = None + x: str | np.ndarray | pd.Series | None = DEFAULT_XCOL + y: str | np.ndarray | pd.Series | None = DEFAULT_YCOL + xlim: tuple[float, float] | None = DEFAULT_XLIM + ylim: tuple[float, float] | None = DEFAULT_YLIM + hue: str | np.ndarray | pd.Series | None = None + palette: SeabornPaletteType | None = "colorblind" + marginal_ticks: bool | None = None + + +class StyleParams(BaseParams): + """ + Configuration options for styling circumplex plots. + """ + + xlim: tuple[float, float] = DEFAULT_XLIM + ylim: tuple[float, float] = DEFAULT_YLIM + xlabel: str | Literal[False] | None = r"$P_{ISO}$" + ylabel: str | Literal[False] | None = r"$E_{ISO}$" + diag_lines_zorder: int = 1 + diag_labels_zorder: int = 4 + prim_lines_zorder: int = 2 + data_zorder: int = 3 + legend_loc: MplLegendLocType | Literal[False] = "best" + linewidth: float = 1.5 + primary_lines: bool = True + diagonal_lines: bool = False + title_fontsize: int = 14 + prim_ax_fontdict: dict[str, Any] = Field( + default_factory=lambda: DEFAULT_FONTDICT.copy() + ) + + +class SubplotsParams(BaseParams): + """Parameters for subplot configuration.""" + + nrows: int = 1 + ncols: int = 1 + figsize: tuple[float, float] = (5, 5) + sharex: bool | Literal["none", "all", "row", "col"] = True + sharey: bool | Literal["none", "all", "row", "col"] = True + subplot_by: str | None = None + n_subplots_by: int = -1 + auto_allocate_axes: bool = False + adjust_figsize: bool = True + + @property + def n_subplots(self) -> int: + """ + Calculate the total number of subplots. + + Returns + ------- + int + Total number of subplots. + + """ + return self.nrows * self.ncols + + def as_plt_subplots_args(self) -> dict[str, Any]: + """ + Pass matplotlib subplot arguments to a plt.subplots call. + + Returns + ------- + dict[str, Any] + Dictionary of subplot parameters. + + """ + return self.as_dict( + exclude={ + "subplot_by", + "n_subplots_by", + "auto_allocate_axes", + "adjust_figsize", + } + ) diff --git a/src/soundscapy/plotting/new/plot_context.py b/src/soundscapy/plotting/new/plot_context.py new file mode 100644 index 0000000..556ff66 --- /dev/null +++ b/src/soundscapy/plotting/new/plot_context.py @@ -0,0 +1,254 @@ +""" +Data and state management for plotting layers. + +This module provides the PlotContext class that manages data, state, and parameters +for ISOPlot visualizations. The PlotContext is the central component in the plotting +architecture, owning both data and parameter models for different layer types. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, cast + +from soundscapy.plotting.new.constants import DEFAULT_XCOL, DEFAULT_YCOL +from soundscapy.plotting.new.parameter_models import ( + BaseParams, + DensityParams, + ScatterParams, + SimpleDensityParams, + SPISimpleDensityParams, + StyleParams, +) +from soundscapy.sspylogging import get_logger + +if TYPE_CHECKING: + import pandas as pd + from matplotlib.axes import Axes + + from soundscapy.plotting.new.protocols import ParamModel, RenderableLayer + +logger = get_logger() + + +class PlotContext: + """ + Manages data, state, and parameters for a plot or subplot. + + This class centralizes the management of data, coordinates, parameters, and other + state needed for rendering plot layers. It owns parameter models for different + layer types and provides them to layers when needed. + + Attributes + ---------- + data : pd.DataFrame | None + The data associated with this context + x : str + The column name for x-axis data + y : str + The column name for y-axis data + hue : str | None + The column name for color encoding, if any + ax : Axes | None + The matplotlib Axes object this context is associated with + title : str | None + The title for this context's plot + layers : list[RenderableLayer] + The visualization layers to be rendered on this context + parent : PlotContext | None + The parent context, if this is a child context + + """ + + def __init__( + self, + data: pd.DataFrame | None = None, + x: str = DEFAULT_XCOL, + y: str = DEFAULT_YCOL, + hue: str | None = None, + ax: Axes | None = None, + title: str | None = None, + ) -> None: + """ + Initialize a PlotContext. + + Parameters + ---------- + data : pd.DataFrame | None + Data to be visualized + x : str + Column name for x-axis data + y : str + Column name for y-axis data + hue : str | None + Column name for color encoding + ax : Axes | None + Matplotlib axis to render on + title : str | None + Title for this plot context + + """ + # Basic properties + self.data = data + self.x = x + self.y = y + self.hue = hue + self.ax = ax + self.title = title + self.layers: list[RenderableLayer] = [] + self.parent: PlotContext | None = None + + # Parameter models for different layer types + self._param_models: dict[str, BaseParams] = {} + + # Initialize default parameter models + self._init_param_models() + + def _init_param_models(self) -> None: + """Initialize parameter models with context values.""" + # Common parameters for all models + common_params = { + "data": self.data, + "x": self.x, + "y": self.y, + "hue": self.hue, + } + + # Create parameter models for different layer types + self._param_models["scatter"] = ScatterParams(**common_params) + self._param_models["density"] = DensityParams(**common_params) + self._param_models["simple_density"] = SimpleDensityParams(**common_params) + self._param_models["spi_simple_density"] = SPISimpleDensityParams( + **common_params + ) + self._param_models["style"] = StyleParams( + xlim=(-1, 1), + ylim=(-1, 1), + xlabel=r"$P_{ISO}$", + ylabel=r"$E_{ISO}$", + ) + + def get_params(self, param_type: str) -> BaseParams: + """ + Get parameters for a specific type. + + Parameters + ---------- + param_type : str + The type of parameters to get (e.g., 'scatter', 'density') + + Returns + ------- + BaseParams + The parameter model instance + + Raises + ------ + ValueError + If the parameter type is unknown + + """ + if param_type not in self._param_models: + msg = f"Unknown parameter type: {param_type}" + raise ValueError(msg) + + return self._param_models[param_type] + + def get_params_for_layer(self, layer_type: type[RenderableLayer]) -> ParamModel: + """ + Get parameters appropriate for a specific layer type. + + This method maps layer types to their corresponding parameter models. + + Parameters + ---------- + layer_type : type[RenderableLayer] + The type of layer to get parameters for + + Returns + ------- + ParamModel + The parameter model instance + + """ + # Map layer class names to parameter types + # This could be improved with a more formal registry + layer_name = layer_type.__name__.lower() + + if "scatter" in layer_name: + return cast("ParamModel", self.get_params("scatter")) + if "simpledensity" in layer_name: + if "spi" in layer_name: + return cast("ParamModel", self.get_params("spi_simple_density")) + return cast("ParamModel", self.get_params("simple_density")) + if "density" in layer_name: + return cast("ParamModel", self.get_params("density")) + + # Default to scatter parameters if no match + logger.warning( + f"No specific parameters for layer type {layer_type.__name__}, " # noqa: G004 + "using scatter parameters" + ) + return cast("ParamModel", self.get_params("scatter")) + + def update_params(self, param_type: str, **kwargs: Any) -> BaseParams: + """ + Update parameters for a specific type. + + Parameters + ---------- + param_type : str + The type of parameters to update + **kwargs : Any + Parameter values to update + + Returns + ------- + BaseParams + The updated parameter model instance + + """ + params = self.get_params(param_type) + params.update(**kwargs) + return params + + def create_child( + self, + data: pd.DataFrame | None = None, + title: str | None = None, + ax: Axes | None = None, + ) -> PlotContext: + """ + Create a child context that inherits properties from this context. + + Parameters + ---------- + data : pd.DataFrame | None + Data for the child context. If None, inherits from parent. + title : str | None + Title for the child context + ax : Axes | None + Matplotlib axis for the child context + + Returns + ------- + PlotContext + A new child context with inherited properties + + """ + child = PlotContext( + data=data if data is not None else self.data, + x=self.x, + y=self.y, + hue=self.hue, + ax=ax, + title=title, + ) + + # Copy parameter models from parent to child + for param_type, model in self._param_models.items(): + child._param_models[param_type] = model.model_copy() + + # Set parent reference + child.parent = self + + return child diff --git a/src/soundscapy/plotting/new/protocols.py b/src/soundscapy/plotting/new/protocols.py new file mode 100644 index 0000000..15765f3 --- /dev/null +++ b/src/soundscapy/plotting/new/protocols.py @@ -0,0 +1,144 @@ +""" +Protocol classes for the plotting module. + +This module defines Protocol classes that specify interfaces for various +components in the plotting system. These protocols enable structural typing +rather than nominal typing, making it easier to compose functionality without +complex inheritance hierarchies. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, TypeVar + +if TYPE_CHECKING: + import pandas as pd + from matplotlib.axes import Axes + +# Type variable for generic parameter models +P = TypeVar("P", bound="ParamModel") + + +class RenderableLayer(Protocol): + """Protocol defining what a renderable layer must implement.""" + + def render(self, context: PlotContext) -> None: + """ + Render the layer on the given context. + + Parameters + ---------- + context : PlotContext + The context containing data and axes for rendering + + """ + ... + + +class ParameterProvider(Protocol): + """Protocol defining how parameters are provided.""" + + def get_params(self, param_type: str) -> ParamModel: + """ + Get parameters for a specific type. + + Parameters + ---------- + param_type : str + The type of parameters to get + + Returns + ------- + ParamModel + The parameter model instance + + """ + ... + + +class ParamModel(Protocol): + """Protocol defining the interface for parameter models.""" + + def update(self, **kwargs: Any) -> Any: + """ + Update parameters with new values. + + Parameters + ---------- + **kwargs : Any + New parameter values + + Returns + ------- + Self + The updated parameter instance + + """ + ... + + def as_dict(self, **kwargs: Any) -> dict[str, Any]: + """ + Get all parameters as a dictionary. + + Returns + ------- + Dict[str, Any] + Dictionary of parameter values + + """ + ... + + +class PlotContext(Protocol): + """Protocol defining the interface for plot contexts.""" + + data: pd.DataFrame | None + x: str + y: str + hue: str | None + ax: Axes | None + title: str | None + layers: list[RenderableLayer] + + def get_params_for_layer(self, layer_type: type[RenderableLayer]) -> ParamModel: + """ + Get parameters appropriate for a specific layer type. + + Parameters + ---------- + layer_type : type[RenderableLayer] + The type of layer to get parameters for + + Returns + ------- + ParamModel + The parameter model instance + + """ + ... + + def create_child( + self, + data: pd.DataFrame | None = None, + title: str | None = None, + ax: Axes | None = None, + ) -> PlotContext: + """ + Create a child context that inherits properties from this context. + + Parameters + ---------- + data : pd.DataFrame | None + Data for the child context. If None, inherits from parent. + title : str | None + Title for the child context + ax : Axes | None + Matplotlib axis for the child context + + Returns + ------- + PlotContext + A new child context with inherited properties + + """ + ... From af71ad451748d94e79b4ee856fda80cbabf8a2f2 Mon Sep 17 00:00:00 2001 From: Andrew Mitchell Date: Sat, 10 May 2025 03:51:46 +0100 Subject: [PATCH 4/8] Revert first pass at refactoring New refactoring is in the plotting/new submodule Signed-off-by: Andrew Mitchell --- src/soundscapy/plotting/iso_plot.py | 778 ++++++++++++++++++-- src/soundscapy/plotting/iso_plot_layers.py | 487 ------------ src/soundscapy/plotting/iso_plot_styling.py | 377 ---------- 3 files changed, 733 insertions(+), 909 deletions(-) delete mode 100644 src/soundscapy/plotting/iso_plot_layers.py delete mode 100644 src/soundscapy/plotting/iso_plot_styling.py diff --git a/src/soundscapy/plotting/iso_plot.py b/src/soundscapy/plotting/iso_plot.py index 32a1f49..4fa6f60 100644 --- a/src/soundscapy/plotting/iso_plot.py +++ b/src/soundscapy/plotting/iso_plot.py @@ -4,7 +4,7 @@ Example: ------- >>> from soundscapy import isd, surveys ->>> from soundscapy.plotting import ISOPlot +>>> from soundscapy.plotting.iso_plot_new import ISOPlot >>> df = isd.load() >>> df = surveys.add_iso_coords(df) >>> sub_df = isd.select_location_ids(df, ['CamdenTown', 'RegentsParkJapan']) @@ -19,7 +19,7 @@ ... .add_simple_density(fill=False) ... .apply_styling() ... ) ->>> isoplot.show() +>>> isoplot.show() # xdoctest: +SKIP """ # ruff: noqa: SLF001, G004 @@ -30,19 +30,30 @@ import warnings from typing import TYPE_CHECKING, Any -import matplotlib.pyplot as plt import numpy as np import pandas as pd +import seaborn as sns +from matplotlib import pyplot as plt +from matplotlib import ticker +from matplotlib.artist import Artist from matplotlib.axes import Axes from matplotlib.figure import Figure, SubFigure from soundscapy.plotting.defaults import ( + DEFAULT_STYLE_PARAMS, DEFAULT_XCOL, DEFAULT_YCOL, RECOMMENDED_MIN_SAMPLES, ) -from soundscapy.plotting.iso_plot_layers import ISOPlotLayersMixin -from soundscapy.plotting.iso_plot_styling import ISOPlotStylingMixin +from soundscapy.plotting.layers import ( + DensityLayer, + Layer, + ScatterLayer, + SimpleDensityLayer, + SPIDensityLayer, + SPIScatterLayer, + SPISimpleLayer, +) from soundscapy.plotting.plot_context import PlotContext from soundscapy.plotting.plotting_types import ( ParamModel, @@ -54,13 +65,16 @@ if TYPE_CHECKING: from collections.abc import Generator - from soundscapy.plotting.layers import Layer from soundscapy.plotting.plotting_types import SeabornPaletteType + from soundscapy.spi.msn import ( + CentredParams, + DirectParams, + ) logger = get_logger() -class ISOPlot(ISOPlotLayersMixin, ISOPlotStylingMixin): +class ISOPlot: """ A class for creating circumplex plots using different backends. @@ -78,38 +92,7 @@ class ISOPlot(ISOPlotLayersMixin, ISOPlotStylingMixin): ... .add_scatter() ... .add_density() ... .apply_styling()) - >>> cp.show() - - Create a plot with default parameters: - - >>> import pandas as pd - >>> import numpy as np - >>> rng = np.random.default_rng(42) - >>> data = pd.DataFrame( - ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), - ... columns=['ISOPleasant', 'ISOEventful'] - ... ) - >>> plot = ISOPlot() - >>> isinstance(plot, ISOPlot) - True - - Create a plot with a DataFrame: - - >>> data = pd.DataFrame( - ... np.c_[rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), - ... rng.integers(1, 3, 100)], - ... columns=['ISOPleasant', 'ISOEventful', 'Group']) - >>> plot = ISOPlot(data=data, hue='Group') - >>> plot.hue - 'Group' - - - Create a plot directly with arrays: - - >>> x, y = rng.multivariate_normal([0, 0], [[1, 0], [0, 1]], 100).T - >>> plot = ISOPlot(x=x, y=y) - >>> isinstance(plot, ISOPlot) - True + >>> cp.show() # xdoctest: +SKIP """ @@ -146,6 +129,39 @@ def __init__( axes : Axes | np.ndarray | None, optional Existing axes to plot on, by default None + Examples + -------- + Create a plot with default parameters: + + >>> import pandas as pd + >>> import numpy as np + >>> rng = np.random.default_rng(42) + >>> data = pd.DataFrame( + ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... columns=['ISOPleasant', 'ISOEventful'] + ... ) + >>> plot = ISOPlot() + >>> isinstance(plot, ISOPlot) + True + + Create a plot with a DataFrame: + + >>> data = pd.DataFrame( + ... np.c_[rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... rng.integers(1, 3, 100)], + ... columns=['ISOPleasant', 'ISOEventful', 'Group']) + >>> plot = ISOPlot(data=data, hue='Group') + >>> plot.hue + 'Group' + + + Create a plot directly with arrays: + + >>> x, y = rng.multivariate_normal([0, 0], [[1, 0], [0, 1]], 100).T + >>> plot = ISOPlot(x=x, y=y) + >>> isinstance(plot, ISOPlot) + True + """ # Process and validate input data and coordinates data, x, y = self._check_data_x_y(data, x, y) @@ -666,8 +682,7 @@ def _validate_subplots_datas( ) raise ValueError(msg) - @staticmethod - def _allocate_subplot_axes(subplot_titles: list[str]) -> tuple[int, int]: + def _allocate_subplot_axes(self, subplot_titles: list[str]) -> tuple[int, int]: """Allocate the subplot axes based on the number of data subsets.""" msg = ( "This is an experimental feature. " @@ -923,7 +938,7 @@ def add_layer( ... .create_subplots(nrows=2, ncols=2) ... .add_layer(ScatterLayer) ... .apply_styling()) - >>> plot.show() + >>> plot.show() # xdoctest: +SKIP >>> all(len(ctx.layers) == 1 for ctx in plot.subplot_contexts) True >>> plot.close() # Clean up @@ -934,7 +949,7 @@ def add_layer( ... .create_subplots(nrows=2, ncols=2) ... .add_layer(ScatterLayer, on_axis=0) ... .apply_styling()) - >>> plot.show() + >>> plot.show() # xdoctest: +SKIP >>> len(plot.subplot_contexts[0].layers) == 1 True >>> all(len(ctx.layers) == 0 for ctx in plot.subplot_contexts[1:]) @@ -947,7 +962,7 @@ def add_layer( ... .create_subplots(nrows=2, ncols=2) ... .add_layer(ScatterLayer, on_axis=[0, 2]) ... .apply_styling()) - >>> plot.show() + >>> plot.show() # xdoctest: +SKIP >>> len(plot.subplot_contexts[0].layers) == 1 True >>> len(plot.subplot_contexts[2].layers) == 1 @@ -969,7 +984,7 @@ def add_layer( ... # Add a layer with custom data to the second subplot ... .add_layer(ScatterLayer, data=custom_data, on_axis=1) ... .apply_styling()) - >>> plot.show() + >>> plot.show() # xdoctest: +SKIP >>> plot.close() """ @@ -1081,3 +1096,676 @@ def _resolve_axis_indices( return on_axis msg = f"Invalid axis specification: {on_axis}" raise ValueError(msg) + + def add_scatter( + self, + data: pd.DataFrame | None = None, + *, + on_axis: int | tuple[int, int] | list[int] | None = None, + **params: Any, + ) -> ISOPlot: + """ + Add a scatter layer to specific subplot(s). + + Parameters + ---------- + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes + data : pd.DataFrame, optional + Custom data for this specific scatter plot + **params : dict + Parameters for the scatter plot + + Returns + ------- + ISOPlot + The current plot instance for chaining + + Examples + -------- + Add a scatter layer to all subplots: + + >>> import pandas as pd + >>> import numpy as np + >>> rng = np.random.default_rng(42) + >>> data = pd.DataFrame( + ... np.c_[rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... rng.integers(1, 3, 100)], + ... columns=['ISOPleasant', 'ISOEventful', 'Group']) + >>> plot = (ISOPlot(data=data) + ... .create_subplots(nrows=2, ncols=1) + ... .add_scatter(s=50, alpha=0.7, hue='Group') + ... .apply_styling()) + >>> plot.show() # xdoctest: +SKIP + >>> all(len(ctx.layers) == 1 for ctx in plot.subplot_contexts) + True + >>> plot.close() # Clean up + + Add a scatter layer with custom data to a specific subplot: + + >>> custom_data = pd.DataFrame({ + ... 'ISOPleasant': rng.normal(0.2, 0.1, 50), + ... 'ISOEventful': rng.normal(0.15, 0.2, 50), + ... }) + >>> plot = (ISOPlot(data=data) + ... .create_subplots(nrows=2, ncols=1) + ... .add_scatter(hue='Group') + ... .add_scatter(on_axis=0, data=custom_data, color='red') + ... .apply_styling()) + >>> plot.show() # xdoctest: +SKIP + >>> plot.subplot_contexts[0].layers[1].custom_data is custom_data + True + >>> plot.close() # Clean up + + """ + # Merge default scatter parameters with provided ones + # Remove data from scatter_params to avoid conflict + scatter_params = self._scatter_params.model_copy().drop("data").update(**params) + + return self.add_layer( + ScatterLayer, data=data, on_axis=on_axis, **scatter_params.as_dict() + ) + + def add_spi( + self, + on_axis: int | tuple[int, int] | list[int] | None = None, + spi_target_data: pd.DataFrame | np.ndarray | None = None, + msn_params: DirectParams | CentredParams | None = None, + *, + layer_class: type[Layer] = SPISimpleLayer, + **params: Any, + ) -> ISOPlot: + """ + Add a SPI layer to specific subplot(s). + + Parameters + ---------- + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes + spi_target_data : pd.DataFrame | np.ndarray | None, optional + Custom data for this specific SPI plot + msn_params : DirectParams | CentredParams | None, optional + Parameters for the SPI plot + + Returns + ------- + ISOPlot + The current plot instance for chaining + + Examples + -------- + Add a SPI layer to all subplots: + + >>> import pandas as pd + >>> import numpy as np + >>> from soundscapy.spi import DirectParams + >>> rng = np.random.default_rng(42) + >>> # Create a DataFrame with random data + >>> data = pd.DataFrame( + ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... columns=['ISOPleasant', 'ISOEventful'] + ... ) + >>> # Define MSN parameters for the SPI target + >>> msn_params = DirectParams( + ... xi=np.array([0.5, 0.7]), + ... omega=np.array([[0.1, 0.05], [0.05, 0.1]]), + ... alpha=np.array([0, -5]), + ... ) + >>> # Create the plot with only an SPI layer + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots() + ... .add_scatter() + ... .add_spi(msn_params=msn_params) + ... .apply_styling() + ... ) + >>> plot.show() # xdoctest: +SKIP + >>> len(plot.subplot_contexts[0].layers) == 2 + True + >>> plot.close() # Clean up + + Add an SPI layer over top of 'real' data: + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots() + ... .add_scatter() + ... .add_density() + ... .add_spi(msn_params=msn_params, show_score="on axis") + ... .apply_styling() + ... ) + >>> plot.show() # xdoctest: +SKIP + >>> len(plot.subplot_contexts[0].layers) == 3 + True + >>> plot.close() # Clean up + + """ + if layer_class == SPISimpleLayer: + spi_simple_params = ( + self._spi_simple_density_params.model_copy() + .drop("data") + .update(**params) + ) + + return self.add_layer( + layer_class, + on_axis=on_axis, + msn_params=msn_params, + spi_target_data=spi_target_data, + **spi_simple_params.as_dict(), + ) + if layer_class in (SPIDensityLayer, SPIScatterLayer): + msg = ( + "Only the simple density layer type is currently supported for " + "SPI plots. Please use SPISimpleLayer" + ) + raise NotImplementedError(msg) + + msg = "Invalid layer class provided. Expected SPISimpleLayer. " + raise ValueError(msg) + + def add_density( + self, + on_axis: int | tuple[int, int] | list[int] | None = None, + data: pd.DataFrame | None = None, + *, + include_outline: bool = False, + **params: Any, + ) -> ISOPlot: + """ + Add a density layer to specific subplot(s). + + Parameters + ---------- + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes + data : pd.DataFrame, optional + Custom data for this specific density plot + include_outline : bool, optional + Whether to include an outline around the density plot, by default False + **params : dict + Parameters for the density plot + + Returns + ------- + ISOPlot + The current plot instance for chaining + + Examples + -------- + Add a density layer to all subplots: + + >>> import pandas as pd + >>> import numpy as np + >>> rng = np.random.default_rng(42) + >>> data = pd.DataFrame({ + ... 'ISOPleasant': rng.normal(0.2, 0.25, 50), + ... 'ISOEventful': rng.normal(0.15, 0.4, 50), + ... }) + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots() + ... .add_density() + ... .apply_styling() + ... ) + >>> plot.show() # xdoctest: +SKIP + >>> len(plot.subplot_contexts[0].layers) == 1 + True + >>> plot.close() # Clean up + + Add a density layer with custom settings: + + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots() + ... .add_density(levels=5, alpha=0.7) + ... .apply_styling() + ... ) + >>> plot.show() # xdoctest: +SKIP + >>> len(plot.subplot_contexts[0].layers) == 1 + True + >>> plot.close() # Clean up + + """ + # Merge default density parameters with provided ones + density_params = self._density_params.model_copy().drop("data").update(**params) + + return self.add_layer( + DensityLayer, + data=data, + on_axis=on_axis, + include_outline=include_outline, + **density_params.as_dict(), + ) + + def add_simple_density( + self, + on_axis: int | tuple[int, int] | list[int] | None = None, + data: pd.DataFrame | None = None, + *, + include_outline: bool = True, + **params: Any, + ) -> ISOPlot: + """ + Add a simple density layer to specific subplot(s). + + Parameters + ---------- + on_axis : int | tuple[int, int] | list[int] | None, optional + Target specific axis/axes + data : pd.DataFrame, optional + Custom data for this specific density plot + thresh : float, optional + Threshold for density contours, by default 0.5 + levels : int | Iterable[float], optional + Contour levels, by default 2 + alpha : float, optional + Transparency level, by default 0.5 + include_outline : bool, optional + Whether to include an outline around the density plot, by default True + **params : dict + Additional parameters for the density plot + + Returns + ------- + ISOPlot + The current plot instance for chaining + + Examples + -------- + Add a simple density layer: + + >>> import pandas as pd + >>> import numpy as np + >>> rng = np.random.default_rng(42) + >>> data = pd.DataFrame({ + ... 'ISOPleasant': rng.normal(0.2, 0.25, 30), + ... 'ISOEventful': rng.normal(0.15, 0.4, 30), + ... }) + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots() + ... .add_scatter() + ... .add_simple_density() + ... .apply_styling() + ... ) + >>> plot.show() # xdoctest: +SKIP + >>> len(plot.subplot_contexts[0].layers) == 2 + True + >>> plot.close() # Clean up + + Add a simple density with splitting by group: + >>> data = pd.DataFrame( + ... np.c_[rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... rng.integers(1, 3, 100)], + ... columns=['ISOPleasant', 'ISOEventful', 'Group']) + >>> plot = ( + ... ISOPlot(data=data, hue='Group') + ... .create_subplots() + ... .add_scatter() + ... .add_simple_density() + ... .apply_styling() + ... ) + >>> plot.show() # xdoctest: +SKIP + >>> len(plot.subplot_contexts[0].layers) == 2 + True + >>> plot.close() + ... + + """ + # Merge default simple density parameters with provided ones + simple_density_params = ( + self._simple_density_params.model_copy().drop("data").update(**params) + ) + + return self.add_layer( + SimpleDensityLayer, + on_axis=on_axis, + data=data, + include_outline=include_outline, + **simple_density_params.as_dict(), + ) + + def add_annotation( + self, + text: str, + xy: tuple[float, float], + xytext: tuple[float, float], + arrowprops: dict[str, Any] | None = None, + ) -> ISOPlot: + """ + Add an annotation to the plot. + + Parameters + ---------- + text : str + The text to display in the annotation. + xy : tuple[float, float] + The point to annotate. + xytext : tuple[float, float] + The point at which to place the text. + arrowprops : dict[str, Any] | None, optional + Properties for the arrow connecting the annotation text to the point. + + Returns + ------- + ISOPlot + The current plot instance for chaining + + """ + msg = "AnnotationLayer is not yet implemented. " + raise NotImplementedError(msg) + # TODO(MitchellAcoustics): Implement AnnotationLayer # noqa: TD003 + return self.add_layer( + "AnnotationLayer", + text=text, + xy=xy, + xytext=xytext, + arrowprops=arrowprops, + ) + + def apply_styling( + self, + **kwargs: Any, + ) -> ISOPlot: + """ + Apply styling to the plot. + + Parameters + ---------- + **kwargs: Styling parameters to override defaults + + Returns + ------- + ISOPlot + The current plot instance for chaining + + Examples + -------- + Apply styling with default parameters: + + >>> import pandas as pd + >>> import numpy as np + >>> rng = np.random.default_rng(42) + >>> # Create simple data for styling example + >>> data = pd.DataFrame( + ... np.c_[rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... rng.integers(1, 3, 100)], + ... columns=['ISOPleasant', 'ISOEventful', 'Group']) + >>> # Create plot with default styling + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots() + ... .add_scatter() + ... .apply_styling() + ... ) + >>> plot.show() # xdoctest: +SKIP + >>> plot.get_figure() is not None + True + >>> plot.close() # Clean up + + Apply styling with custom parameters: + + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots() + ... .add_scatter() + ... .apply_styling(xlim=(-2, 2), ylim=(-2, 2), primary_lines=False) + ... ) + >>> plot.show() # xdoctest: +SKIP + >>> plot.get_figure() is not None + True + >>> plot.close() # Clean up + + Demonstrate the fluent interface (method chaining): + + >>> # Create plot with method chaining + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots(nrows=1, ncols=1) + ... .add_scatter(alpha=0.7) + ... .add_density(levels=5) + ... .apply_styling(title_fontsize=14) + ... ) + >>> plot.show() # xdoctest: +SKIP + >>> # Verify results + >>> isinstance(plot, ISOPlot) + True + >>> plot.close() # Clean up + + """ + self._style_params.update(**kwargs) + self._check_for_axes() + + self._set_style() + self._circumplex_grid() + self._set_title() + self._set_axes_titles() + self._primary_labels() + if self._style_params.get("primary_lines"): + self._primary_lines() + if self._style_params.get("diagonal_lines"): + self._diagonal_lines_and_labels() + + if self._style_params.get("legend_loc") is not False: + self._move_legend() + + return self + + def _set_style(self) -> None: + """Set the overall style for the plot.""" + sns.set_style({"xtick.direction": "in", "ytick.direction": "in"}) + + def _circumplex_grid(self) -> ISOPlot: + """Add the circumplex grid to the plot.""" + for _, axis in enumerate(self.yield_axes_objects()): + axis.set_xlim(self._style_params.get("xlim")) + axis.set_ylim(self._style_params.get("ylim")) + axis.set_aspect("equal") + + axis.get_yaxis().set_minor_locator(ticker.AutoMinorLocator()) + axis.get_xaxis().set_minor_locator(ticker.AutoMinorLocator()) + + axis.grid(visible=True, which="major", color="grey", alpha=0.5) + axis.grid( + visible=True, + which="minor", + color="grey", + linestyle="dashed", + linewidth=0.5, + alpha=0.4, + zorder=self._style_params.get("prim_lines_zorder"), + ) + + return self + + def _set_title(self) -> ISOPlot: + """Set the title of the plot.""" + if self.title and self._has_subplots: + figure = self.get_figure() + figure.suptitle( + self.title, fontsize=self._style_params.get("title_fontsize") + ) + elif self.title and not self._has_subplots: + axis = self.get_single_axes() + if axis.get_title() == "": + axis.set_title( + self.title, fontsize=self._style_params.get("title_fontsize") + ) + else: + figure = self.get_figure() + figure.suptitle( + self.title, fontsize=self._style_params.get("title_fontsize") + ) + return self + + def _set_axes_titles(self) -> ISOPlot: + """Set the titles of the subplots.""" + for context in self.subplot_contexts: + if context.ax and context.title: + context.ax.set_title(context.title) + return self + + def _primary_lines(self) -> ISOPlot: + """Add primary lines to the plot.""" + for _, axis in enumerate(self.yield_axes_objects()): + axis.axhline( + y=0, + color="grey", + linestyle="dashed", + alpha=1, + lw=self._style_params.get("linewidth"), + zorder=self._style_params.get("prim_lines_zorder"), + ) + axis.axvline( + x=0, + color="grey", + linestyle="dashed", + alpha=1, + lw=self._style_params.get("linewidth"), + zorder=self._style_params.get("prim_lines_zorder"), + ) + return self + + def _primary_labels(self) -> ISOPlot: + """Handle the default labels for the x and y axes.""" + xlabel = self._style_params.get("xlabel") + ylabel = self._style_params.get("ylabel") + + xlabel = self.x if xlabel is None else xlabel + ylabel = self.y if ylabel is None else ylabel + fontdict = self._style_params.get("prim_ax_fontdict") + + # BUG: For some reason, this ruins the sharex and sharey + # functionality, but only when a layer is applied + # a specific subplot. + for _, axis in enumerate(self.yield_axes_objects()): + axis.set_xlabel( + xlabel, fontdict=fontdict + ) if xlabel is not False else axis.xaxis.label.set_visible(False) + + axis.set_ylabel( + ylabel, fontdict=fontdict + ) if ylabel is not False else axis.yaxis.label.set_visible(False) + + return self + + def _diagonal_lines_and_labels(self) -> ISOPlot: + """ + Add diagonal lines and labels to the plot. + + Examples + -------- + >>> import pandas as pd + >>> import numpy as np + >>> rng = np.random.default_rng(42) + >>> data = pd.DataFrame( + ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... columns=['ISOPleasant', 'ISOEventful']) + >>> # Create a plot with diagonal lines and labels + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots() + ... .add_scatter() + ... .apply_styling(diagonal_lines=True) + ... ) + >>> plot.show() # xdoctest: +SKIP + >>> plot.close('all') + + """ + for _, axis in enumerate(self.yield_axes_objects()): + xlim = self._style_params.get("xlim", DEFAULT_STYLE_PARAMS["xlim"]) + ylim = self._style_params.get("ylim", DEFAULT_STYLE_PARAMS["ylim"]) + axis.plot( + xlim, + ylim, + linestyle="dashed", + color="grey", + alpha=0.5, + lw=self._style_params.get("linewidth"), + zorder=self._style_params.get("diag_lines_zorder"), + ) + logger.debug("Plotting diagonal line for axis.") + axis.plot( + xlim, + ylim[::-1], + linestyle="dashed", + color="grey", + alpha=0.5, + lw=self._style_params.get("linewidth"), + zorder=self._style_params.get("diag_lines_zorder"), + ) + + diag_ax_font = { + "fontstyle": "italic", + "fontsize": "small", + "fontweight": "bold", + "color": "black", + "alpha": 0.5, + } + axis.text( + xlim[1] / 2, + ylim[1] / 2, + "(vibrant)", + ha="center", + va="center", + fontdict=diag_ax_font, + zorder=self._style_params.get("diag_labels_zorder"), + ) + axis.text( + xlim[0] / 2, + ylim[1] / 2, + "(chaotic)", + ha="center", + va="center", + fontdict=diag_ax_font, + zorder=self._style_params.get("diag_labels_zorder"), + ) + axis.text( + xlim[0] / 2, + ylim[0] / 2, + "(monotonous)", + ha="center", + va="center", + fontdict=diag_ax_font, + zorder=self._style_params.get("diag_labels_zorder"), + ) + axis.text( + xlim[1] / 2, + ylim[0] / 2, + "(calm)", + ha="center", + va="center", + fontdict=diag_ax_font, + zorder=self._style_params.get("diag_labels_zorder"), + ) + return self + + def _move_legend(self) -> ISOPlot: + """Move the legend to the specified location.""" + for i, axis in enumerate(self.yield_axes_objects()): + old_legend = axis.get_legend() + if old_legend is None: + # logger.debug("_move_legend: No legend found for axis %s", i) + continue + + # Get handles and filter out None values + handles = [ + h for h in old_legend.legend_handles if isinstance(h, Artist | tuple) + ] + # Skip if no valid handles remain + if not handles: + continue + + labels = [t.get_text() for t in old_legend.get_texts()] + title = old_legend.get_title().get_text() + # Ensure labels and handles match in length + if len(handles) != len(labels): + labels = labels[: len(handles)] + + axis.legend( + handles, + labels, + loc=self._style_params.get("legend_loc"), + title=title, + ) + return self diff --git a/src/soundscapy/plotting/iso_plot_layers.py b/src/soundscapy/plotting/iso_plot_layers.py deleted file mode 100644 index f92e6e3..0000000 --- a/src/soundscapy/plotting/iso_plot_layers.py +++ /dev/null @@ -1,487 +0,0 @@ -""" -Layer-specific methods for ISOPlot. - -This module provides a mixin class with methods for adding different types of -visualization layers to ISOPlot instances. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -import numpy as np -import pandas as pd - -from soundscapy.plotting.layers import ( - DensityLayer, - Layer, - ScatterLayer, - SimpleDensityLayer, - SPISimpleLayer, -) - -if TYPE_CHECKING: - from soundscapy.spi.msn import ( - CentredParams, - DirectParams, - ) - - -class ISOPlotLayersMixin: - """Mixin providing layer-specific methods for ISOPlot.""" - - def add_layer( - self, - layer_class: type[Layer], - data: pd.DataFrame | None = None, - *, - on_axis: int | tuple[int, int] | list[int] | None = None, - **params: Any, - ) -> Any: - """ - Add a visualization layer, optionally targeting specific subplot(s). - - This is a stub method that should be implemented by classes using this mixin. - The actual implementation is in the ISOPlot class. - - Parameters - ---------- - layer_class : Layer subclass - The type of layer to add - on_axis : int | tuple[int, int] | list[int] | None, optional - Target specific axis/axes - data : pd.DataFrame, optional - Custom data for this specific layer - **params : dict - Parameters for the layer - - Returns - ------- - Any - The current plot instance for chaining - - """ - msg = "Classes using ISOPlotLayersMixin must implement add_layer" - raise NotImplementedError(msg) - - def get_single_axes(self, ax_idx: int | tuple[int, int] | None = None) -> Any: - """ - Get a specific axes object. - - This is a stub method that should be implemented by classes using this mixin. - The actual implementation is in the ISOPlot class. - - Parameters - ---------- - ax_idx : int | tuple[int, int] | None, optional - The index of the axes to get. If None, returns the first axes. - Can be an integer for flattened access or a tuple of (row, col). - - Returns - ------- - Any - The requested matplotlib Axes object - - """ - msg = "Classes using ISOPlotLayersMixin must implement get_single_axes" - raise NotImplementedError(msg) - - def add_scatter( - self, - data: pd.DataFrame | None = None, - *, - on_axis: int | tuple[int, int] | list[int] | None = None, - **params: Any, - ) -> Any: - """ - Add a scatter layer to the plot. - - Parameters - ---------- - data : pd.DataFrame | None, optional - Custom data for this specific layer, by default None - on_axis : int | tuple[int, int] | list[int] | None, optional - Target specific axis/axes: - - int: Index of subplot (flattened) - - tuple: (row, col) coordinates - - list: Multiple indices to apply the layer to - - None: Apply to all subplots (default) - **params : Any - Additional parameters for the scatter layer - - Returns - ------- - ISOPlot - The current plot instance for chaining - - Examples - -------- - Add a scatter layer to all subplots: - - >>> import pandas as pd - >>> import numpy as np - >>> from soundscapy.plotting import ISOPlot - >>> rng = np.random.default_rng(42) - >>> data = pd.DataFrame( - ... np.c_[rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), - ... rng.integers(1, 3, 100)], - ... columns=['ISOPleasant', 'ISOEventful', 'Group']) - >>> plot = (ISOPlot(data=data) - ... .create_subplots(nrows=2, ncols=1) - ... .add_scatter(s=50, alpha=0.7, hue='Group') - ... .apply_styling()) - >>> plot.show() # xdoctest: +SKIP - >>> all(len(ctx.layers) == 1 for ctx in plot.subplot_contexts) - True - >>> plot.close() # Clean up - - Add a scatter layer with custom data to a specific subplot: - - >>> custom_data = pd.DataFrame({ - ... 'ISOPleasant': rng.normal(0.2, 0.1, 50), - ... 'ISOEventful': rng.normal(0.15, 0.2, 50), - ... }) - >>> plot = (ISOPlot(data=data) - ... .create_subplots(nrows=2, ncols=1) - ... .add_scatter(hue='Group') - ... .add_scatter(on_axis=0, data=custom_data, color='red') - ... .apply_styling()) - >>> plot.show() # xdoctest: +SKIP - >>> plot.subplot_contexts[0].layers[1].custom_data is custom_data - True - >>> plot.close() # Clean up - - """ - return self.add_layer(ScatterLayer, data=data, on_axis=on_axis, **params) - - def add_spi( - self, - on_axis: int | tuple[int, int] | list[int] | None = None, - spi_target_data: pd.DataFrame | np.ndarray | None = None, - msn_params: DirectParams | CentredParams | None = None, - *, - layer_class: type[Layer] = SPISimpleLayer, - **params: Any, - ) -> Any: - """ - Add an SPI (Soundscape Perception Index) layer to the plot. - - Parameters - ---------- - on_axis : int | tuple[int, int] | list[int] | None, optional - Target specific axis/axes: - - int: Index of subplot (flattened) - - tuple: (row, col) coordinates - - list: Multiple indices to apply the layer to - - None: Apply to all subplots (default) - spi_target_data : pd.DataFrame | np.ndarray | None, optional - Pre-sampled data for SPI target distribution, by default None - msn_params : DirectParams | CentredParams | None, optional - Parameters to generate SPI data if no spi_target_data is provided, by default None - layer_class : type[Layer], optional - The type of SPI layer to add, by default SPISimpleLayer - **params : Any - Additional parameters for the SPI layer - - Returns - ------- - ISOPlot - The current plot instance for chaining - - Notes - ----- - Either spi_target_data or msn_params must be provided, but not both. - The test data for SPI calculations will be retrieved from the plot context. - - Examples - -------- - Add a SPI layer to all subplots: - - >>> import pandas as pd - >>> import numpy as np - >>> from soundscapy.spi import DirectParams - >>> from soundscapy.plotting import ISOPlot - >>> rng = np.random.default_rng(42) - >>> # Create a DataFrame with random data - >>> data = pd.DataFrame( - ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), - ... columns=['ISOPleasant', 'ISOEventful'] - ... ) - >>> # Define MSN parameters for the SPI target - >>> msn_params = DirectParams( - ... xi=np.array([0.5, 0.7]), - ... omega=np.array([[0.1, 0.05], [0.05, 0.1]]), - ... alpha=np.array([0, -5]), - ... ) - >>> # Create the plot with only an SPI layer - >>> plot = ( - ... ISOPlot(data=data) - ... .create_subplots() - ... .add_scatter() - ... .add_spi(msn_params=msn_params) - ... .apply_styling() - ... ) - >>> plot.show() # xdoctest: +SKIP - >>> len(plot.subplot_contexts[0].layers) == 2 - True - >>> plot.close() # Clean up - - Add an SPI layer over top of 'real' data: - >>> plot = ( - ... ISOPlot(data=data) - ... .create_subplots() - ... .add_scatter() - ... .add_density() - ... .add_spi(msn_params=msn_params, show_score="on axis") - ... .apply_styling() - ... ) - >>> plot.show() # xdoctest: +SKIP - >>> len(plot.subplot_contexts[0].layers) == 3 - True - >>> plot.close() # Clean up - - """ - # Validate that we have either spi_target_data or msn_params - if spi_target_data is None and msn_params is None: - msg = ( - "No data provided for SPI plot. " - "Please provide either spi_target_data or msn_params." - ) - raise ValueError(msg) - - if spi_target_data is not None and msn_params is not None: - msg = ( - "Please provide either spi_target_data or msn_params, not both. " - "Got: \n" - f"\n`spi_target_data`: {type(spi_target_data)}\n`msn_params`: {type(msn_params)}" - ) - raise ValueError(msg) - - # Add the SPI layer - return self.add_layer( - layer_class, - on_axis=on_axis, - spi_target_data=spi_target_data, - msn_params=msn_params, - **params, - ) - - def add_density( - self, - on_axis: int | tuple[int, int] | list[int] | None = None, - data: pd.DataFrame | None = None, - *, - include_outline: bool = False, - **params: Any, - ) -> Any: - """ - Add a density layer to the plot. - - Parameters - ---------- - on_axis : int | tuple[int, int] | list[int] | None, optional - Target specific axis/axes: - - int: Index of subplot (flattened) - - tuple: (row, col) coordinates - - list: Multiple indices to apply the layer to - - None: Apply to all subplots (default) - data : pd.DataFrame | None, optional - Custom data for this specific layer, by default None - include_outline : bool, optional - Whether to include an outline around the density plot, by default False - **params : Any - Additional parameters for the density layer - - Returns - ------- - ISOPlot - The current plot instance for chaining - - Examples - -------- - Add a density layer to all subplots: - - >>> import pandas as pd - >>> import numpy as np - >>> from soundscapy.plotting import ISOPlot - >>> rng = np.random.default_rng(42) - >>> data = pd.DataFrame({ - ... 'ISOPleasant': rng.normal(0.2, 0.25, 50), - ... 'ISOEventful': rng.normal(0.15, 0.4, 50), - ... }) - >>> plot = ( - ... ISOPlot(data=data) - ... .create_subplots() - ... .add_density() - ... .apply_styling() - ... ) - >>> plot.show() # xdoctest: +SKIP - >>> len(plot.subplot_contexts[0].layers) == 1 - True - >>> plot.close() # Clean up - - Add a density layer with custom settings: - - >>> plot = ( - ... ISOPlot(data=data) - ... .create_subplots() - ... .add_density(levels=5, alpha=0.7) - ... .apply_styling() - ... ) - >>> plot.show() # xdoctest: +SKIP - >>> len(plot.subplot_contexts[0].layers) == 1 - True - >>> plot.close() # Clean up - - """ - return self.add_layer( - DensityLayer, - data=data, - on_axis=on_axis, - include_outline=include_outline, - **params, - ) - - def add_simple_density( - self, - on_axis: int | tuple[int, int] | list[int] | None = None, - data: pd.DataFrame | None = None, - *, - include_outline: bool = True, - **params: Any, - ) -> Any: - """ - Add a simplified density layer to the plot. - - This creates a density plot with fewer contour levels, typically used - to highlight the main density region. - - Parameters - ---------- - on_axis : int | tuple[int, int] | list[int] | None, optional - Target specific axis/axes: - - int: Index of subplot (flattened) - - tuple: (row, col) coordinates - - list: Multiple indices to apply the layer to - - None: Apply to all subplots (default) - data : pd.DataFrame | None, optional - Custom data for this specific layer, by default None - include_outline : bool, optional - Whether to include an outline around the density plot, by default True - **params : Any - Additional parameters for the simple density layer - - Returns - ------- - ISOPlot - The current plot instance for chaining - - Examples - -------- - Add a simple density layer: - - >>> import pandas as pd - >>> import numpy as np - >>> from soundscapy.plotting import ISOPlot - >>> rng = np.random.default_rng(42) - >>> data = pd.DataFrame({ - ... 'ISOPleasant': rng.normal(0.2, 0.25, 30), - ... 'ISOEventful': rng.normal(0.15, 0.4, 30), - ... }) - >>> plot = ( - ... ISOPlot(data=data) - ... .create_subplots() - ... .add_scatter() - ... .add_simple_density() - ... .apply_styling() - ... ) - >>> plot.show() # xdoctest: +SKIP - >>> len(plot.subplot_contexts[0].layers) == 2 - True - >>> plot.close() # Clean up - - Add a simple density with splitting by group: - >>> data = pd.DataFrame( - ... np.c_[rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), - ... rng.integers(1, 3, 100)], - ... columns=['ISOPleasant', 'ISOEventful', 'Group']) - >>> plot = ( - ... ISOPlot(data=data, hue='Group') - ... .create_subplots() - ... .add_scatter() - ... .add_simple_density() - ... .apply_styling() - ... ) - >>> plot.show() # xdoctest: +SKIP - >>> len(plot.subplot_contexts[0].layers) == 2 - True - >>> plot.close() - - """ - return self.add_layer( - SimpleDensityLayer, - data=data, - on_axis=on_axis, - include_outline=include_outline, - **params, - ) - - def add_annotation( - self, - text: str, - xy: tuple[float, float], - xytext: tuple[float, float], - arrowprops: dict[str, Any] | None = None, - ) -> Any: - """ - Add an annotation to the plot. - - Parameters - ---------- - text : str - The text of the annotation - xy : tuple[float, float] - The point (x, y) to annotate - xytext : tuple[float, float] - The position (x, y) to place the text - arrowprops : dict[str, Any] | None, optional - Properties used to draw the arrow, by default None - - Returns - ------- - ISOPlot - The current plot instance for chaining - - Examples - -------- - >>> import pandas as pd - >>> import numpy as np - >>> from soundscapy.plotting import ISOPlot - >>> rng = np.random.default_rng(42) - >>> data = pd.DataFrame({ - ... 'ISOPleasant': rng.normal(0, 0.5, 100), - ... 'ISOEventful': rng.normal(0, 0.5, 100), - ... }) - >>> plot = (ISOPlot(data=data) - ... .create_subplots() - ... .add_scatter() - ... .add_annotation( - ... "Interesting point", - ... xy=(0.5, 0.5), - ... xytext=(0.7, 0.7), - ... arrowprops=dict(arrowstyle="->") - ... ) - ... .apply_styling()) - >>> plot.show() # xdoctest: +SKIP - >>> plot.close() # Clean up - - """ - # Default arrow properties if none provided - if arrowprops is None: - arrowprops = {"arrowstyle": "->"} - - # Get the current axes - ax = self.get_single_axes() - ax.annotate(text, xy=xy, xytext=xytext, arrowprops=arrowprops) - - return self diff --git a/src/soundscapy/plotting/iso_plot_styling.py b/src/soundscapy/plotting/iso_plot_styling.py deleted file mode 100644 index 0f034ff..0000000 --- a/src/soundscapy/plotting/iso_plot_styling.py +++ /dev/null @@ -1,377 +0,0 @@ -""" -Styling methods for ISOPlot. - -This module provides a mixin class with methods for styling ISOPlot instances, -including grid lines, labels, and other visual elements. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -from matplotlib import pyplot as plt -from matplotlib import ticker -from matplotlib.artist import Artist - -from soundscapy.plotting.plotting_types import ParamModel -from soundscapy.sspylogging import get_logger - -if TYPE_CHECKING: - from matplotlib.axes import Axes - -logger = get_logger() - - -class ISOPlotStylingMixin: - """Mixin providing styling methods for ISOPlot.""" - - def apply_styling(self, **kwargs: Any) -> Any: - """ - Apply styling to the plot. - - This method applies various styling elements to the plot, including: - - Setting axis limits and labels - - Adding grid lines - - Adding diagonal lines (if enabled) - - Setting titles - - Configuring legends - - Parameters - ---------- - **kwargs : Any - Additional styling parameters to override defaults - - Returns - ------- - ISOPlot - The current plot instance for chaining - - Examples - -------- - Apply styling with default parameters - - >>> import pandas as pd - >>> import numpy as np - >>> from soundscapy.plotting import ISOPlot - >>> rng = np.random.default_rng(42) - >>> # Create simple data for styling example - >>> data = pd.DataFrame( - ... np.c_[rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), - ... rng.integers(1, 3, 100)], - ... columns=['ISOPleasant', 'ISOEventful', 'Group']) - >>> # Create plot with default styling - >>> plot = ( - ... ISOPlot(data=data) - ... .create_subplots() - ... .add_scatter() - ... .apply_styling() - ... ) - >>> plot.show() # xdoctest: +SKIP - >>> plot.get_figure() is not None - True - >>> plot.close() # Clean up - - Apply styling with custom parameters: - - >>> plot = ( - ... ISOPlot(data=data) - ... .create_subplots() - ... .add_scatter() - ... .apply_styling(xlim=(-2, 2), ylim=(-2, 2), primary_lines=False) - ... ) - >>> plot.show() # xdoctest: +SKIP - >>> plot.get_figure() is not None - True - >>> plot.close() # Clean up - - Demonstrate the fluent interface (method chaining): - - >>> # Create plot with method chaining - >>> plot = ( - ... ISOPlot(data=data) - ... .create_subplots(nrows=1, ncols=1) - ... .add_scatter(alpha=0.7) - ... .add_density(levels=5) - ... .apply_styling(title_fontsize=14) - ... ) - >>> plot.show() # xdoctest: +SKIP - >>> # Verify results - >>> isinstance(plot, ISOPlot) - True - >>> plot.close() # Clean up - - """ - # Update style parameters with provided kwargs - # Use the default values from the StyleParams class and override with kwargs - self._style_params = ParamModel.create("style", **kwargs) - # Check if we have axes to style - self._check_for_axes() - self._set_style() - - # Apply styling to each axes - for ax in self.yield_axes_objects(): - # Set axis limits - ax.set_xlim(self._style_params.xlim) - ax.set_ylim(self._style_params.ylim) - - # Set up grid - self._circumplex_grid(ax) - - # Add primary lines if enabled - if self._style_params.primary_lines: - self._primary_lines(ax) - self._primary_labels(ax) - - # Add diagonal lines if enabled - if self._style_params.diagonal_lines: - self._diagonal_lines_and_labels(ax) - - # Set titles - self._set_title() - self._set_axes_titles() - - # Move legend if needed - if self._style_params.legend_loc and self._style_params.legend_loc is not False: - self._move_legend() - - return self - - @staticmethod - def _set_style() -> None: - """Set the style for the plot.""" - plt.style.use("seaborn-v0_8-whitegrid") - - def _circumplex_grid(self, ax: Axes) -> None: - """ - Set up the grid for a circumplex plot. - - Parameters - ---------- - ax : Axes - The axes to set up the grid for - - """ - # Set up grid - ax.set_xlim(self._style_params.get("xlim")) - ax.set_ylim(self._style_params.get("ylim")) - ax.set_aspect("equal") - - ax.get_yaxis().set_minor_locator(ticker.AutoMinorLocator()) - ax.get_xaxis().set_minor_locator(ticker.AutoMinorLocator()) - - ax.grid(visible=True, which="major", color="grey", alpha=0.5) - ax.grid( - visible=True, - which="minor", - color="grey", - linestyle="dashed", - linewidth=0.5, - alpha=0.4, - zorder=self._style_params.get("prim_lines_zorder"), - ) - - def _set_title(self) -> None: - """Set the title of the plot.""" - if self.title and self._has_subplots: - figure = self.get_figure() - figure.suptitle( - self.title, fontsize=self._style_params.get("title_fontsize") - ) - elif self.title and not self._has_subplots: - axis = self.get_single_axes() - if axis.get_title() == "": - axis.set_title( - self.title, fontsize=self._style_params.get("title_fontsize") - ) - else: - figure = self.get_figure() - figure.suptitle( - self.title, fontsize=self._style_params.get("title_fontsize") - ) - - def _set_axes_titles(self) -> None: - """Set titles for individual axes in subplots.""" - # Set titles for individual subplot contexts if they have titles - for i, context in enumerate(self.subplot_contexts): - if context.title is not None: - ax = self.get_single_axes(i) - ax.set_title(context.title, fontsize=self._style_params.title_fontsize) - - def _primary_lines(self, ax: Axes) -> None: - """ - Add primary axis lines to the plot. - - Parameters - ---------- - ax : Axes - The axes to add the lines to - - """ - # Add horizontal and vertical lines at x=0 and y=0 - ax.axhline( - y=0, - color="grey", - linestyle="dashed", - alpha=1, - lw=self._style_params.get("linewidth"), - zorder=self._style_params.get("prim_lines_zorder"), - ) - ax.axvline( - x=0, - color="grey", - linestyle="dashed", - alpha=1, - lw=self._style_params.get("linewidth"), - zorder=self._style_params.get("prim_lines_zorder"), - ) - - def _primary_labels(self, ax: Axes) -> None: - """ - Add labels for primary axes. - - Parameters - ---------- - ax : Axes - The axes to add the labels to - - """ - # Set x and y labels if they are provided - if self._style_params.xlabel is not False: - xlabel = self._style_params.xlabel or self.x - ax.set_xlabel( - xlabel, - fontdict=self._style_params.prim_ax_fontdict, - ) - - if self._style_params.ylabel is not False: - ylabel = self._style_params.ylabel or self.y - ax.set_ylabel( - ylabel, - fontdict=self._style_params.prim_ax_fontdict, - ) - - def _diagonal_lines_and_labels(self, ax: Axes) -> None: - """ - Add diagonal lines and labels to the plot. - - Parameters - ---------- - ax : Axes - The axes to add the diagonal lines and labels to - - Examples - -------- - >>> import pandas as pd - >>> import numpy as np - >>> from soundscapy.plotting import ISOPlot - >>> rng = np.random.default_rng(42) - >>> data = pd.DataFrame( - ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), - ... columns=['ISOPleasant', 'ISOEventful']) - >>> # Create a plot with diagonal lines and labels - >>> plot = ( - ... ISOPlot(data=data) - ... .create_subplots() - ... .add_scatter() - ... .apply_styling(diagonal_lines=True) - ... ) - >>> plot.show() # xdoctest: +SKIP - >>> plot.close('all') - - """ - # Get axis limits - xlim = ax.get_xlim() - ylim = ax.get_ylim() - - ax.plot( - xlim, - ylim, - linestyle="dashed", - color="grey", - alpha=0.5, - lw=self._style_params.get("linewidth"), - zorder=self._style_params.get("diag_lines_zorder"), - ) - logger.debug("Plotting diagonal line for axis.") - ax.plot( - xlim, - ylim[::-1], - linestyle="dashed", - color="grey", - alpha=0.5, - lw=self._style_params.get("linewidth"), - zorder=self._style_params.get("diag_lines_zorder"), - ) - - diag_ax_font = { - "fontstyle": "italic", - "fontsize": "small", - "fontweight": "bold", - "color": "black", - "alpha": 0.5, - } - ax.text( - xlim[1] / 2, - ylim[1] / 2, - "(vibrant)", - ha="center", - va="center", - fontdict=diag_ax_font, - zorder=self._style_params.get("diag_labels_zorder"), - ) - ax.text( - xlim[0] / 2, - ylim[1] / 2, - "(chaotic)", - ha="center", - va="center", - fontdict=diag_ax_font, - zorder=self._style_params.get("diag_labels_zorder"), - ) - ax.text( - xlim[0] / 2, - ylim[0] / 2, - "(monotonous)", - ha="center", - va="center", - fontdict=diag_ax_font, - zorder=self._style_params.get("diag_labels_zorder"), - ) - ax.text( - xlim[1] / 2, - ylim[0] / 2, - "(calm)", - ha="center", - va="center", - fontdict=diag_ax_font, - zorder=self._style_params.get("diag_labels_zorder"), - ) - - def _move_legend(self) -> None: - """Move the legend to the specified location.""" - for i, axis in enumerate(self.yield_axes_objects()): - old_legend = axis.get_legend() - if old_legend is None: - # logger.debug("_move_legend: No legend found for axis %s", i) - continue - - # Get handles and filter out None values - handles = [ - h for h in old_legend.legend_handles if isinstance(h, Artist | tuple) - ] - # Skip if no valid handles remain - if not handles: - continue - - labels = [t.get_text() for t in old_legend.get_texts()] - title = old_legend.get_title().get_text() - # Ensure labels and handles match in length - if len(handles) != len(labels): - labels = labels[: len(handles)] - - axis.legend( - handles, - labels, - loc=self._style_params.get("legend_loc"), - title=title, - ) From 779f28f9dc1a559e125d5869c82764cd782cc47f Mon Sep 17 00:00:00 2001 From: Andrew Mitchell Date: Sat, 10 May 2025 16:48:26 +0100 Subject: [PATCH 5/8] Refactor plotting module to unify layer handling and styling Consolidated layer management into `add_layer` with improved versatility, replacing multiple layer-specific methods with a unified interface. Deprecated legacy methods for backward compatibility and refined styling and subplot creation workflows for better clarity and maintainability. Updated dependencies to include `deprecated` library. Signed-off-by: Andrew Mitchell --- pyproject.toml | 1 + src/soundscapy/plotting/new/README.md | 14 +- src/soundscapy/plotting/new/__init__.py | 1 + src/soundscapy/plotting/new/iso_plot.py | 92 ++- src/soundscapy/plotting/new/layer.py | 492 +++++++++++++- src/soundscapy/plotting/new/managers.py | 603 ++++++++---------- .../plotting/new/parameter_models.py | 21 +- src/soundscapy/plotting/new/plot_context.py | 177 ++++- src/soundscapy/plotting/new/protocols.py | 2 +- uv.lock | 78 +++ 10 files changed, 1028 insertions(+), 453 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a2b8472..91f7f94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ classifiers = [ ] dependencies = [ "coverage==7.8.0", + "deprecated>=1.2.18", "loguru>=0.7.2", "numpy!=1.26", "pandas[excel]>=2.2.2", diff --git a/src/soundscapy/plotting/new/README.md b/src/soundscapy/plotting/new/README.md index 9517558..cc6b5f0 100644 --- a/src/soundscapy/plotting/new/README.md +++ b/src/soundscapy/plotting/new/README.md @@ -1,6 +1,7 @@ # Refactored Plotting Module -This directory contains a refactored implementation of the plotting functionality in soundscapy. The refactoring focuses on: +This directory contains a refactored implementation of the plotting functionality in soundscapy. The refactoring focuses +on: 1. Using composition instead of inheritance 2. Simplifying the relationships between components @@ -13,8 +14,10 @@ The new architecture consists of the following components: ### Core Components -- **ISOPlot**: The main entry point for creating plots. Uses composition to delegate functionality to specialized managers. -- **PlotContext**: Manages data, state, and parameters for a plot or subplot. The central component that owns parameter models. +- **ISOPlot**: The main entry point for creating plots. Uses composition to delegate functionality to specialized + managers. +- **PlotContext**: Manages data, state, and parameters for a plot or subplot. The central component that owns parameter + models. - **Layer**: Base class for visualization layers. Layers know how to render themselves onto a PlotContext's axes. ### Managers @@ -52,7 +55,8 @@ The new architecture consists of the following components: ## Usage -The new implementation maintains the same public API as the original, so existing code should continue to work with minimal changes: +The new implementation maintains the same public API as the original, so existing code should continue to work with +minimal changes: ```python from soundscapy.plotting.new import ISOPlot @@ -60,7 +64,7 @@ from soundscapy.plotting.new import ISOPlot # Create a plot plot = ISOPlot(data=data, hue="LocationID") -# Add layers +# Add layer_mgr plot.create_subplots() plot.add_scatter() plot.add_density() diff --git a/src/soundscapy/plotting/new/__init__.py b/src/soundscapy/plotting/new/__init__.py index 366891f..aa8e368 100644 --- a/src/soundscapy/plotting/new/__init__.py +++ b/src/soundscapy/plotting/new/__init__.py @@ -20,6 +20,7 @@ ... ) >>> # Create a plot with multiple layers >>> plot = (ISOPlot(data=data) +... .create_subplots() ... .add_scatter() ... .add_simple_density(fill=False) ... .apply_styling( diff --git a/src/soundscapy/plotting/new/iso_plot.py b/src/soundscapy/plotting/new/iso_plot.py index dde50ca..ab584ee 100644 --- a/src/soundscapy/plotting/new/iso_plot.py +++ b/src/soundscapy/plotting/new/iso_plot.py @@ -19,9 +19,10 @@ ... columns=['ISOPleasant', 'ISOEventful'] ... ) >>> # Create a plot and add a scatter layer ->>> plot = ISOPlot(data=data) +>>> plot = ISOPlot(data=data).create_subplots() >>> plot.add_scatter() >>> plot.apply_styling() +>>> plot.show() >>> isinstance(plot, ISOPlot) True @@ -31,6 +32,7 @@ >>> data['Group'] = rng.integers(1, 3, 100) >>> # Create a plot with subplots by group >>> plot = (ISOPlot(data=data, hue='Group') +... .create_subplots() ... .add_scatter() ... .add_simple_density(fill=False) ... .apply_styling()) @@ -52,11 +54,10 @@ DEFAULT_XCOL, DEFAULT_YCOL, ) -from soundscapy.plotting.new.layer import ( - Layer, -) from soundscapy.plotting.new.managers import ( + AxisSpec, LayerManager, + LayerSpec, StyleManager, SubplotManager, ) @@ -93,11 +94,11 @@ class ISOPlot: List of subplot contexts subplots_params : SubplotsParams Parameters for subplot configuration - layers : LayerManager + layer_mgr : LayerManager Manager for layer-related functionality - styling : StyleManager + style_mgr : StyleManager Manager for styling-related functionality - subplots : SubplotManager + subplot_mgr : SubplotManager Manager for subplot-related functionality Examples @@ -191,16 +192,16 @@ def __init__( # Store additional plot attributes self.figure = figure self.axes = axes - self.palette = palette + self.palette = palette # TODO: Should move to PlotContext? # Initialize subplot management self.subplot_contexts: list[PlotContext] = [] self.subplots_params = SubplotsParams() # Initialize managers using composition - self.layers = LayerManager(self) - self.styling = StyleManager(self) - self.subplots = SubplotManager(self) + self.layer_mgr = LayerManager(self) + self.style_mgr = StyleManager(self) + self.subplot_mgr = SubplotManager(self) @property def x(self) -> str: @@ -367,7 +368,7 @@ def create_subplots( True """ - return self.subplots.create_subplots( + return self.subplot_mgr.create_subplots( nrows=nrows, ncols=ncols, figsize=figsize, @@ -412,16 +413,16 @@ def add_scatter( ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), ... columns=['ISOPleasant', 'ISOEventful'] ... ) - >>> plot = ISOPlot(data=data) + >>> plot = ISOPlot(data=data).create_subplots() >>> plot = plot.add_scatter() - >>> len(plot.main_context.layers) + >>> len(plot.subplot_contexts[0].layers) 1 Add a scatter layer with custom parameters: - >>> plot = ISOPlot(data=data) + >>> plot = ISOPlot(data=data).create_subplots() >>> plot = plot.add_scatter(s=50, alpha=0.5, color='red') - >>> len(plot.main_context.layers) + >>> len(plot.subplot_contexts[0].layers) 1 Add a scatter layer to a specific subplot: @@ -435,7 +436,7 @@ def add_scatter( 0 """ - return self.layers.add_scatter( + return self.layer_mgr.add_scatter( data=data, on_axis=on_axis, **params, @@ -476,16 +477,16 @@ def add_density( ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), ... columns=['ISOPleasant', 'ISOEventful'] ... ) - >>> plot = ISOPlot(data=data) + >>> plot = ISOPlot(data=data).create_subplots() >>> plot = plot.add_density() - >>> len(plot.main_context.layers) + >>> len(plot.subplot_contexts[0].layers) 1 Add a density layer with custom parameters: - >>> plot = ISOPlot(data=data) + >>> plot = ISOPlot(data=data).create_subplots() >>> plot = plot.add_density(fill=False, levels=5, alpha=0.7) - >>> len(plot.main_context.layers) + >>> len(plot.subplot_contexts[0].layers) 1 Add a density layer to a specific subplot: @@ -499,7 +500,7 @@ def add_density( 1 """ - return self.layers.add_density( + return self.layer_mgr.add_density( data=data, on_axis=on_axis, **params, @@ -540,16 +541,16 @@ def add_simple_density( ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), ... columns=['ISOPleasant', 'ISOEventful'] ... ) - >>> plot = ISOPlot(data=data) + >>> plot = ISOPlot(data=data).create_subplots() >>> plot = plot.add_simple_density() - >>> len(plot.main_context.layers) + >>> len(plot.subplot_contexts[0].layers) 1 Add a simple density layer with custom parameters: - >>> plot = ISOPlot(data=data) + >>> plot = ISOPlot(data=data).create_subplots() >>> plot = plot.add_simple_density(fill=False, thresh=0.3) - >>> len(plot.main_context.layers) + >>> len(plot.subplot_contexts[0].layers) 1 Add a simple density layer to multiple subplots: @@ -565,7 +566,7 @@ def add_simple_density( 1 """ - return self.layers.add_simple_density( + return self.layer_mgr.add_simple_density( data=data, on_axis=on_axis, **params, @@ -596,7 +597,7 @@ def add_spi_simple( The current plot instance for chaining """ - return self.layers.add_spi_simple( + return self.layer_mgr.add_spi_simple( data=data, on_axis=on_axis, **params, @@ -604,10 +605,10 @@ def add_spi_simple( def add_layer( self, - layer_class: type[Layer], + layer_spec: LayerSpec, data: pd.DataFrame | None = None, *, - on_axis: int | tuple[int, int] | list[int] | None = None, + on_axis: AxisSpec | None = None, **params: Any, ) -> ISOPlot: """ @@ -641,21 +642,21 @@ def add_layer( ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), ... columns=['ISOPleasant', 'ISOEventful'] ... ) - >>> plot = ISOPlot(data=data) + >>> plot = ISOPlot(data=data).create_subplots() >>> plot = plot.add_layer(ScatterLayer) - >>> len(plot.main_context.layers) + >>> len(plot.subplot_contexts[0].layers) 1 - >>> isinstance(plot.main_context.layers[0], ScatterLayer) + >>> isinstance(plot.subplot_contexts[0].layers[0], ScatterLayer) True Add a density layer to a specific subplot: - >>> from soundscapy.plotting.new import DensityLayer >>> plot = ISOPlot(data=data) >>> plot = plot.create_subplots(nrows=2,ncols=2) - >>> plot = plot.add_layer(DensityLayer, on_axis=3, fill=False) + >>> plot = plot.add_layer("density", on_axis=3, fill=False) >>> len(plot.subplot_contexts[3].layers) 1 + >>> from soundscapy.plotting.new import DensityLayer >>> isinstance(plot.subplot_contexts[3].layers[0], DensityLayer) True @@ -665,17 +666,14 @@ def add_layer( ... 'ISOPleasant': rng.normal(0.5, 0.1, 50), ... 'ISOEventful': rng.normal(0.5, 0.1, 50), ... }) - >>> plot = ISOPlot(data=data) + >>> plot = ISOPlot(data=data).create_subplots() >>> plot = plot.add_layer(ScatterLayer, data=custom_data, color='red') - >>> len(plot.main_context.layers) + >>> len(plot.subplot_contexts[0].layers) 1 """ - return self.layers.add_layer( - layer_class=layer_class, - data=data, - on_axis=on_axis, - **params, + return self.layer_mgr.add_layer( + layer_spec=layer_spec, data=data, on_axis=on_axis, **params ) def apply_styling( @@ -710,7 +708,7 @@ def apply_styling( ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), ... columns=['ISOPleasant', 'ISOEventful'] ... ) - >>> plot = ISOPlot(data=data) + >>> plot = ISOPlot(data=data).create_subplots() >>> plot = plot.add_scatter() >>> plot = plot.apply_styling() >>> isinstance(plot, ISOPlot) @@ -718,7 +716,7 @@ def apply_styling( Apply custom styling to a plot: - >>> plot = ISOPlot(data=data) + >>> plot = ISOPlot(data=data).create_subplots() >>> plot = plot.add_scatter() >>> plot = plot.apply_styling( ... xlim=(-2, 2), @@ -741,7 +739,7 @@ def apply_styling( True """ - return self.styling.apply_styling( + return self.style_mgr.apply_styling( on_axis=on_axis, **style_params, ) @@ -762,7 +760,7 @@ def show(self) -> None: ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), ... columns=['ISOPleasant', 'ISOEventful'] ... ) - >>> plot = ISOPlot(data=data) + >>> plot = ISOPlot(data=data).create_subplots() >>> plot.add_scatter() >>> plot.apply_styling() >>> # plot.show() # Uncomment to display the plot @@ -786,7 +784,7 @@ def close(self) -> None: ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), ... columns=['ISOPleasant', 'ISOEventful'] ... ) - >>> plot = ISOPlot(data=data) + >>> plot = ISOPlot(data=data).create_subplots() >>> plot.add_scatter() >>> plot.apply_styling() >>> # plot.show() # Uncomment to display the plot diff --git a/src/soundscapy/plotting/new/layer.py b/src/soundscapy/plotting/new/layer.py index 60b1ce1..939d9c5 100644 --- a/src/soundscapy/plotting/new/layer.py +++ b/src/soundscapy/plotting/new/layer.py @@ -9,15 +9,16 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, ClassVar, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast +import numpy as np +import pandas as pd import seaborn as sns from soundscapy.plotting.new.constants import RECOMMENDED_MIN_SAMPLES from soundscapy.sspylogging import get_logger if TYPE_CHECKING: - import pandas as pd from matplotlib.axes import Axes from soundscapy.plotting.new.parameter_models import ( @@ -28,6 +29,10 @@ SPISimpleDensityParams, ) from soundscapy.plotting.new.protocols import PlotContext + from soundscapy.spi.msn import ( + CentredParams, + DirectParams, + ) logger = get_logger() @@ -109,6 +114,10 @@ def render(self, context: PlotContext) -> None: # Get parameters from context and apply overrides params = self._get_params_from_context(context) + # Remove palette if no hue in this layer + if params.hue is None: + params.palette = None + # Render the layer self._render_implementation(data, context, context.ax, params) @@ -357,11 +366,442 @@ def _render_implementation( sns.kdeplot(ax=ax, **kwargs) -class SPISimpleLayer(SimpleDensityLayer): +class SPILayer(Layer): + """Base layer for rendering SPI plots.""" + + param_type = "spi" + + def __init__( + self, + spi_target_data: pd.DataFrame | np.ndarray | None = None, + *, + msn_params: DirectParams | CentredParams | None = None, + n: int = 10000, + custom_data: pd.DataFrame | None = None, + **params: Any, + ) -> None: + """ + Initialize an SPILayer. + + Parameters + ---------- + spi_target_data : pd.DataFrame | np.ndarray | None, optional + Pre-sampled data for SPI target distribution. + When None, msn_params must be provided. + msn_params : DirectParams | CentredParams | None, optional + Parameters to generate SPI data if no spi_target_data is provided + n : int, optional + Number of samples to generate if using msn_params, by default 10000 + custom_data : pd.DataFrame | None, optional + Custom data for this layer, by default None + **params : Any + Additional parameters for the layer + + Notes + ----- + Either spi_target_data or msn_params must be provided, but not both. + The test data for SPI calculations will be retrieved from the plot context. + + """ + # If custom_data is provided but spi_target_data is not, use custom_data as spi_target_data + if custom_data is not None and spi_target_data is None: + logger.warning( + "`spi_target_data` not found, but `custom_data` was found. " + "Using `custom_data` as the SPI target data. " + "\nNote: Passing the SPI data to `spi_target_data` is preferred." + ) + spi_target_data = custom_data + custom_data = None + + # Validate inputs and get SPI parameters + spi_target_data, self.spi_params = self._validate_spi_inputs( + spi_target_data, msn_params + ) + + # Generate the SPI target data + self.spi_data: pd.DataFrame = self._generate_spi_data( + spi_target_data, self.spi_params, n + ) + + # Add n to params + params["n"] = n + + # Initialize the base layer with the SPI data + super().__init__(custom_data=self.spi_data, **params) + + def render(self, context: PlotContext) -> None: + """ + Render this layer on the given context. + + Parameters + ---------- + context : PlotContext + The context containing data and axes for rendering + + """ + if context.ax is None: + msg = "Cannot render layer: context has no associated axes" + raise ValueError(msg) + + # Get the SPI target data + target_data = self.spi_data + + # Process the SPI data to match the context + target_data = self._process_spi_data(target_data, context) + + if target_data is None: + msg = "No data available for rendering SPI layer" + raise ValueError(msg) + + # Get parameters from context + params = self._get_params_from_context(context) + + # Render the layer + self._render_implementation(target_data, context, context.ax, params) + + def _render_implementation( + self, + data: pd.DataFrame, + context: PlotContext, + ax: Axes, + params: BaseParams, + ) -> None: + """ + Render an SPI plot. + + Parameters + ---------- + data : pd.DataFrame + The data to render + context : PlotContext + The context containing state for rendering + ax : Axes + The matplotlib axes to render on + params : BaseParams + The parameters for this layer + + """ + target_data = data[[context.x, context.y]] + + # Get test data from context + test_data = context.data + if test_data is None: + warnings.warn( + "Cannot find data to test SPI against. Skipping this plot.", + UserWarning, + stacklevel=2, + ) + return + + # Calculate SPI score + spi_score = self._calc_context_spi_score(target_data, test_data) + + # Show the score + self.show_score( + spi_score, + show_score=params.show_score + if hasattr(params, "show_score") + else "under title", + context=context, + ax=ax, + axis_text_kwargs=params.axis_text_kw + if hasattr(params, "axis_text_kw") + else {}, + ) + + def show_score( + self, + spi_score: int | None, + show_score: Literal["on axis", "under title"], + context: PlotContext, + ax: Axes, + axis_text_kwargs: dict[str, Any], + ) -> None: + """ + Show the SPI score on the plot. + + Parameters + ---------- + spi_score : int | None + The SPI score to show + show_score : Literal["on axis", "under title"] + Where to show the score + context : PlotContext + The context containing data and axes for rendering + ax : Axes + The axes to render the score on + axis_text_kwargs : dict[str, Any] + Additional arguments for the axis text + + """ + if spi_score is not None: + if show_score == "on axis": + self._add_score_as_text( + ax=ax, + spi_score=spi_score, + **axis_text_kwargs, + ) + elif show_score == "under title": + self._add_score_under_title( + context=context, + ax=ax, + spi_score=spi_score, + ) + + @staticmethod + def _add_score_as_text(ax: Axes, spi_score: int, **text_kwargs: Any) -> None: + """ + Add the SPI score as text on the axis. + + Parameters + ---------- + ax : Axes + The axes to add the text to + spi_score : int + The SPI score to show + **text_kwargs : Any + Additional arguments for the text + + """ + from soundscapy.plotting.new.constants import DEFAULT_SPI_TEXT_KWARGS + + text_kwargs_copy = DEFAULT_SPI_TEXT_KWARGS.copy() + text_kwargs_copy.update(**text_kwargs) + text_kwargs_copy["s"] = f"SPI: {spi_score}" + + ax.text(transform=ax.transAxes, **text_kwargs_copy) + + @staticmethod + def _add_score_under_title(context: PlotContext, ax: Axes, spi_score: int) -> None: + """ + Add the SPI score under the title. + + Parameters + ---------- + context : PlotContext + The context containing data and axes for rendering + ax : Axes + The axes to add the text to + spi_score : int + The SPI score to show + + """ + if context.title is not None: + new_title = f"{context.title}\nSPI: {spi_score}" + else: + new_title = f"SPI: {spi_score}" + + ax.set_title(new_title) + + @staticmethod + def _validate_spi_inputs( + spi_data: pd.DataFrame | np.ndarray | None, + spi_params: DirectParams | CentredParams | None, + ) -> tuple[pd.DataFrame | np.ndarray | None, DirectParams | CentredParams | None]: + """ + Validate the right combination of inputs for the SPI plot. + + Parameters + ---------- + spi_data : pd.DataFrame | np.ndarray | None + Data to use for SPI plotting + spi_params : DirectParams | CentredParams | None + Parameters to generate SPI data + + Returns + ------- + tuple[pd.DataFrame | np.ndarray | None, DirectParams | CentredParams | None] + Validated data and parameters + + """ + # Input validation + if spi_data is None and spi_params is None: + msg = ( + "No data provided for SPI plot. " + "Please provide either spi_data or msn_params." + ) + raise ValueError(msg) + + if spi_data is not None and spi_params is not None: + msg = ( + "Please provide either spi_data or msn_params, not both. " + "Got: \n" + f"\n`spi_data`: {type(spi_data)}\n`spi_params`: {type(spi_params)}" + ) + raise ValueError(msg) + + if spi_data is not None and not isinstance(spi_data, pd.DataFrame | np.ndarray): + msg = "Invalid data type for SPI plot. Expected DataFrame or ndarray." + raise TypeError(msg) + + if spi_params is not None: + # Check if the import is available + try: + from soundscapy.spi.msn import CentredParams, DirectParams + + if not isinstance(spi_params, (DirectParams, CentredParams)): + msg = ( + "Invalid parameters for SPI plot. " + "Expected DirectParams or CentredParams." + ) + raise TypeError(msg) + except ImportError: + msg = ( + "Could not import DirectParams or CentredParams from soundscapy.spi.msn. " + "Please ensure the module is available." + ) + raise ImportError(msg) + + return spi_data, spi_params + + @staticmethod + def _generate_spi_data( + spi_data: pd.DataFrame | np.ndarray | None, + spi_params: DirectParams | CentredParams | None, + n: int, + ) -> pd.DataFrame: + """ + Validate and prepare SPI data from either direct data or parameters. + + Parameters + ---------- + spi_data : pd.DataFrame | np.ndarray | None + Data to use for SPI plotting + spi_params : DirectParams | CentredParams | None + Parameters to generate SPI data + n : int + Number of samples to generate if using msn_params + + Returns + ------- + pd.DataFrame + Prepared data for SPI plotting + + """ + # Generate data from parameters if provided + if spi_params is not None: + try: + from soundscapy.spi.msn import MultiSkewNorm + + spi_msn = MultiSkewNorm.from_params(spi_params) + sample_data = spi_msn.sample(n=n, return_sample=True) + spi_data = pd.DataFrame( + sample_data, + columns=["x", "y"], + ) + except ImportError: + msg = ( + "Could not import MultiSkewNorm from soundscapy.spi.msn. " + "Please ensure the module is available." + ) + raise ImportError(msg) + + if spi_data is not None: + # Process provided data + if isinstance(spi_data, np.ndarray): + if len(spi_data.shape) != 2 or spi_data.shape[1] != 2: # noqa: PLR2004 + msg = ( + "Invalid shape for SPI data. " + "Expected a 2D array with 2 columns." + ) + raise ValueError(msg) + spi_data = pd.DataFrame(spi_data, columns=["x", "y"]) + return spi_data + + msg = "Please provide either spi_data or msn_params." + raise ValueError(msg) + + def _process_spi_data( + self, spi_data: pd.DataFrame | np.ndarray, context: PlotContext + ) -> pd.DataFrame: + """ + Process SPI data into standard format. + + Parameters + ---------- + spi_data : pd.DataFrame | np.ndarray + Data to process + context : PlotContext + The context containing state for rendering + + Returns + ------- + pd.DataFrame + Processed data in standard format + + """ + params = self._get_params_from_context(context) + xcol = getattr(params, "x", context.x) + ycol = getattr(params, "y", context.y) + + # DataFrame handling + if isinstance(spi_data, pd.DataFrame): + if xcol not in spi_data.columns or ycol not in spi_data.columns: + spi_data = spi_data.rename(columns={"x": xcol, "y": ycol}) + return spi_data + + # Numpy array handling + if isinstance(spi_data, np.ndarray): + if len(spi_data.shape) != 2 or spi_data.shape[1] != 2: # noqa: PLR2004 + msg = "Invalid shape for SPI data. Expected a 2D array with 2 columns." + raise ValueError(msg) + return pd.DataFrame(spi_data, columns=[xcol, ycol]) + + msg = "Invalid SPI data type. Expected DataFrame or numpy array." + raise TypeError(msg) + + @staticmethod + def _calc_context_spi_score( + target_data: pd.DataFrame, + test_data: pd.DataFrame, + ) -> int | None: + """ + Calculate the SPI score between target and test data. + + Parameters + ---------- + target_data : pd.DataFrame + The target data + test_data : pd.DataFrame + The test data + + Returns + ------- + int | None + The SPI score, or None if calculation failed + + """ + try: + from soundscapy.spi import spi_score + + return spi_score(target=target_data, test=test_data) + except ImportError: + warnings.warn( + "Could not import spi_score from soundscapy.spi. " + "SPI score calculation will be skipped.", + UserWarning, + stacklevel=2, + ) + return None + + +class SPISimpleLayer(SPILayer, SimpleDensityLayer): """Layer for rendering SPI simple density plots.""" param_type = "spi_simple_density" + # Note: SPIDensityLayer and SPIScatterLayer could be implemented similarly + # by inheriting from SPILayer and DensityLayer or ScatterLayer respectively. + # For example: + # + # class SPIDensityLayer(SPILayer, DensityLayer): + # """Layer for rendering SPI density plots.""" + # param_type = "spi_density" + # + # class SPIScatterLayer(SPILayer, ScatterLayer): + # """Layer for rendering SPI scatter plots.""" + # param_type = "spi_scatter" + def _render_implementation( self, data: pd.DataFrame, @@ -402,33 +842,19 @@ def _render_implementation( # Render the SPI simple density plot sns.kdeplot(ax=ax, **kwargs) - # Add SPI score text if needed - if hasattr(spi_params, "show_score") and spi_params.show_score: - self._add_spi_score_text(context, ax, spi_params) - - def _add_spi_score_text( - self, context: PlotContext, ax: Axes, params: SPISimpleDensityParams - ) -> None: - """ - Add SPI score text to the plot. - - Parameters - ---------- - context : PlotContext - The context containing state for rendering - ax : Axes - The matplotlib axes to render on - params : SPISimpleDensityParams - The parameters for this layer - - """ - # This is a simplified version - in a real implementation, - # we would calculate and display the actual SPI score - if params.show_score == "under title": - # Add text under the title - if context.title: - ax.set_title(f"{context.title}\nSPI Score: 0.75") - elif params.show_score == "on axis" and params.axis_text_kw: - # Add text on the axis - text_kwargs = params.axis_text_kw.copy() - ax.text(s="SPI Score: 0.75", transform=ax.transAxes, **text_kwargs) + # Calculate SPI score + target_data = data[[context.x, context.y]] + test_data = context.data + + if test_data is not None: + spi_score = self._calc_context_spi_score(target_data, test_data) + + # Show the score + if hasattr(spi_params, "show_score") and spi_params.show_score: + self.show_score( + spi_score, + show_score=spi_params.show_score, + context=context, + ax=ax, + axis_text_kwargs=spi_params.axis_text_kw or {}, + ) diff --git a/src/soundscapy/plotting/new/managers.py b/src/soundscapy/plotting/new/managers.py index 484f900..212b0fb 100644 --- a/src/soundscapy/plotting/new/managers.py +++ b/src/soundscapy/plotting/new/managers.py @@ -8,17 +8,13 @@ from __future__ import annotations -import warnings -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar, overload import matplotlib.pyplot as plt import numpy as np import pandas as pd from matplotlib.axes import Axes -from soundscapy.plotting.new.constants import ( - RECOMMENDED_MIN_SAMPLES, -) from soundscapy.plotting.new.layer import ( DensityLayer, Layer, @@ -26,22 +22,36 @@ SimpleDensityLayer, SPISimpleLayer, ) -from soundscapy.plotting.new.protocols import RenderableLayer +from soundscapy.plotting.new.plot_context import PlotContext from soundscapy.sspylogging import get_logger +try: + # Python 3.13 made a @deprecated decorator available + from warnings import deprecated + +except ImportError: + # Fall back to using specific module + from deprecated import deprecated + if TYPE_CHECKING: - from soundscapy.plotting.new.plot_context import PlotContext + from soundscapy.plotting.new import ISOPlot + + _ISOPlotT = TypeVar("_ISOPlotT", bound="ISOPlot") + logger = get_logger() +# Type definitions +AxisSpec = int | tuple[int, int] | list[int] +LayerType = Literal["scatter", "density", "simple_density", "spi_simple"] +LayerClass = type[Layer] +LayerSpec = LayerType | LayerClass | Layer + class LayerManager: """ Manages the creation and rendering of visualization layers. - This class encapsulates the layer-related functionality that was previously - implemented as a mixin in the ISOPlot class. - Attributes ---------- plot : Any @@ -49,6 +59,13 @@ class LayerManager: """ + _LAYER_CLASSES: ClassVar[dict] = { + "scatter": ScatterLayer, + "density": DensityLayer, + "simple_density": SimpleDensityLayer, + "spi_simple": SPISimpleLayer, + } + def __init__(self, plot: Any) -> None: """ Initialize a LayerManager. @@ -61,195 +78,187 @@ def __init__(self, plot: Any) -> None: """ self.plot = plot - def add_scatter( + # Pass a fully instantiated layer + @overload + def add_layer( + self, layer: Layer, *, on_axis: AxisSpec | None = None + ) -> _ISOPlotT: ... + + # Pass an uninstantiated layer class + @overload + def add_layer( self, + layer_class: LayerClass, data: pd.DataFrame | None = None, *, - on_axis: int | tuple[int, int] | list[int] | None = None, + on_axis: AxisSpec | None = None, **params: Any, - ) -> Any: - """ - Add a scatter layer to the plot. - - Parameters - ---------- - data : pd.DataFrame | None, optional - Custom data for this layer, by default None - on_axis : int | tuple[int, int] | list[int] | None, optional - Target specific axis/axes, by default None - **params : Any - Additional parameters for the scatter layer + ) -> _ISOPlotT: ... - Returns - ------- - Any - The parent plot instance for chaining - - """ - return self.add_layer( - ScatterLayer, - data=data, - on_axis=on_axis, - **params, - ) + # Pass a string of the layer type name + @overload + def add_layer( + self, + layer_type: LayerType, + data: pd.DataFrame | None = None, + *, + on_axis: AxisSpec | None = None, + **params: Any, + ) -> _ISOPlotT: ... - def add_density( + def add_layer( self, + layer_spec: LayerSpec, data: pd.DataFrame | None = None, *, - on_axis: int | tuple[int, int] | list[int] | None = None, + on_axis: AxisSpec | None = None, **params: Any, - ) -> Any: + ) -> _ISOPlotT: """ - Add a density layer to the plot. + Add a visualization layer, optionally targeting specific subplot(s). Parameters ---------- + layer_spec : LayerSpec + Either: + - A string layer type ("scatter", "density", "simple_density", "spi_simple") + - A layer class (ScatterLayer, DensityLayer, etc.) + - An already instantiated Layer object data : pd.DataFrame | None, optional - Custom data for this layer, by default None - on_axis : int | tuple[int, int] | list[int] | None, optional + Custom data for this layer, by default None. + Ignored if layer_spec is a Layer instance. + on_axis : AxisSpec, optional Target specific axis/axes, by default None **params : Any - Additional parameters for the density layer + Additional parameters for the layer. Special parameters: + - msn_params: Used only for "spi_simple" layer type + Ignored if layer_spec is a Layer instance. + Returns ------- - Any + ISOPlot The parent plot instance for chaining - """ - # Check if we have enough data for a density plot - plot_data = data if data is not None else self.plot.main_context.data - if plot_data is not None and len(plot_data) < RECOMMENDED_MIN_SAMPLES: - warnings.warn( - "Density plots are not recommended for " - f"small datasets (<{RECOMMENDED_MIN_SAMPLES} samples).", - UserWarning, - stacklevel=2, - ) + Examples + -------- + >>> import soundscapy as sspy + >>> from soundscapy.plotting.new import ISOPlot + >>> data = sspy.add_iso_coords(sspy.isd.load()) + >>> plot = ISOPlot(data).create_subplots(2, 2) - return self.add_layer( - DensityLayer, - data=data, - on_axis=on_axis, - **params, - ) + # Using layer type string + >>> custom_df = data.query("LocationID == 'CamdenTown'") + >>> plot.layer_mgr.add_layer("scatter", x="col1", y="col2", alpha=0.5) + >>> plot.layer_mgr.add_layer("density", data=custom_df, on_axis=1) + + # Using layer type class (works essentially the same as the string version) + >>> plot.layer_mgr.add_layer(ScatterLayer, x="col1", y="col2", alpha=0.5) + >>> plot.layer_mgr.add_layer(DensityLayer, data=custom_df, on_axis=1) + + # Using instantiated Layer object + >>> my_layer = ScatterLayer(x="col1", y="col2", alpha=0.7) + >>> plot.layer_mgr.add_layer(my_layer, on_axis=0) - def add_simple_density( - self, - data: pd.DataFrame | None = None, - *, - on_axis: int | tuple[int, int] | list[int] | None = None, - **params: Any, - ) -> Any: """ - Add a simple density layer to the plot. + # Handle the case when an instantiated Layer is provided + if isinstance(layer_spec, Layer): + # Use the provided layer directly + return self._render_layer(layer_spec, on_axis=on_axis, **params) - Parameters - ---------- - data : pd.DataFrame | None, optional - Custom data for this layer, by default None - on_axis : int | tuple[int, int] | list[int] | None, optional - Target specific axis/axes, by default None - **params : Any - Additional parameters for the simple density layer + # Get the layer class from either the class or + layer_class = self._resolve_layer_class(layer_spec) - Returns - ------- - Any - The parent plot instance for chaining + # Create the layer instance + is_spi = "spi" in layer_class.__name__.lower() + if is_spi: + layer = layer_class(spi_target_data=data, **params) + else: + layer = layer_class(custom_data=data, **params) - """ - return self.add_layer( - SimpleDensityLayer, - data=data, - on_axis=on_axis, - **params, - ) + # Render the layer and return the plot + return self._render_layer(layer, on_axis=on_axis) - def add_spi_simple( - self, - data: pd.DataFrame | None = None, - *, - on_axis: int | tuple[int, int] | list[int] | None = None, - **params: Any, - ) -> Any: + def _resolve_layer_class(self, layer_spec: str | type[Layer]) -> type[Layer]: """ - Add an SPI simple layer to the plot. + Resolve a layer specification to a Layer class. Parameters ---------- - data : pd.DataFrame | None, optional - Custom data for this layer, by default None - on_axis : int | tuple[int, int] | list[int] | None, optional - Target specific axis/axes, by default None - **params : Any - Additional parameters for the SPI simple layer + layer_spec : str | type[Layer] + Either a string layer type ("scatter", "density", "simple_density", "spi_simple") + or a Layer class (ScatterLayer, DensityLayer, etc.) Returns ------- - Any - The parent plot instance for chaining + type[Layer] + The resolved Layer class - """ - return self.add_layer( - SPISimpleLayer, - data=data, - on_axis=on_axis, - **params, + Raises + ------ + ValueError + If an unknown layer type string is provided + TypeError + If layer_spec is not a string or a Layer subclass + + """ + # Case 1: layer_spec is a Layer class + if isinstance(layer_spec, type) and issubclass(layer_spec, Layer): + return layer_spec + + # Case 2: layer_spec is a string + if isinstance(layer_spec, str): + layer_class = self._LAYER_CLASSES.get(layer_spec) + if layer_class is None: + msg = ( + f"Unknown layer type: {layer_spec}. " + f"Available layer types: {list(self._LAYER_CLASSES.keys())}" + ) + raise ValueError(msg) + return layer_class + + # If we get here, layer_spec was an invalid type + msg = ( + "Expected `layer_spec` to be either: " + f" - str (layer type name): {list(self._LAYER_CLASSES.keys())} " + " - Uninstantiated Layer class, e.g. ScatterLayer " + " - Already instantiated Layer object " + f"Got: {type(layer_spec).__name__}" ) + raise TypeError(msg) - def add_layer( - self, - layer_class: type[RenderableLayer], - data: pd.DataFrame | None = None, - *, - on_axis: int | tuple[int, int] | list[int] | None = None, - **params: Any, - ) -> Any: + def _render_layer( + self, layer: Layer, *, on_axis: AxisSpec | None = None + ) -> _ISOPlotT: """ - Add a visualization layer, optionally targeting specific subplot(s). + Render a layer on the appropriate axes. Parameters ---------- - layer_class : type[RenderableLayer] - The type of layer to add - data : pd.DataFrame | None, optional - Custom data for this layer, by default None - on_axis : int | tuple[int, int] | list[int] | None, optional + layer : Layer + The layer to render + on_axis : AxisSpec | None, optional Target specific axis/axes, by default None - **params : Any - Additional parameters for the layer Returns ------- - Any + ISOPlot The parent plot instance for chaining """ - # Create the layer instance - layer = cast("Layer", layer_class(custom_data=data, **params)) - - # Check if we have axes to render on - self._check_for_axes() - + # TODO: This should maybe be moved to a different class's responsibility + # What about encapsulating this in .get_contexts_by_spec ? + # Then LayerManager doesn't need to know anything about the context, + # it just asks for a list of contexts to add the layer to. # If no subplots created yet, add to main context - if not self.plot.subplot_contexts: - if self.plot.main_context.ax is None: - # Get the single axis and assign it to main context - if isinstance(self.plot.axes, Axes): - self.plot.main_context.ax = self.plot.axes - elif isinstance(self.plot.axes, np.ndarray) and self.plot.axes.size > 0: - self.plot.main_context.ax = self.plot.axes.flatten()[0] - - # Add layer to main context - self.plot.main_context.layers.append(layer) - # Render the layer immediately - layer.render(self.plot.main_context) - return self.plot + if self.plot.figure is None or self.plot.axes is None: + msg = "Cannot add layer to main context before creating subplots." + raise RuntimeError(msg) # Handle various axis targeting options - target_contexts = self._resolve_target_contexts(on_axis) + # TODO: If feels like this should be done via + # self.plot.get_contexts_by_spec(on_axis), rather than a classmethod + target_contexts = PlotContext.get_contexts_by_spec(self.plot, on_axis) # Add the layer to each target context and render it for context in target_contexts: @@ -258,95 +267,41 @@ def add_layer( return self.plot - def _check_for_axes(self) -> None: - """ - Check if we have axes to render on, create if needed. - - This method ensures that the plot has axes to render on, - creating them if necessary. - """ - if self.plot.figure is None: - # Create a new figure and axes - self.plot.figure, self.plot.axes = plt.subplots(figsize=(5, 5)) - - def _resolve_target_contexts( - self, on_axis: int | tuple[int, int] | list[int] | None - ) -> list[PlotContext]: - """ - Resolve which subplot contexts to target based on axis specification. - - Parameters - ---------- - on_axis : int | tuple[int, int] | list[int] | None - The axis specification: - - None: All subplot contexts - - int: Single subplot at flattened index - - tuple[int, int]: Subplot at (row, col) - - list[int]: Multiple subplots at specified indices - - Returns - ------- - list[PlotContext] - List of target subplot contexts - - """ - # If no specific axis, target all subplot contexts - if on_axis is None: - return self.plot.subplot_contexts - - # Convert axis specification to list of indices - indices = self._resolve_axis_indices(on_axis) - - # Get the contexts for each valid index - target_contexts = [] - for idx in indices: - if 0 <= idx < len(self.plot.subplot_contexts): - target_contexts.append(self.plot.subplot_contexts[idx]) - else: - msg = f"Subplot index {idx} out of range" - raise IndexError(msg) - - return target_contexts - - def _resolve_axis_indices( - self, on_axis: int | tuple[int, int] | list[int] - ) -> list[int]: - """ - Convert axis specification to list of indices. - - Parameters - ---------- - on_axis : int | tuple[int, int] | list[int] - The axis specification to resolve + @deprecated() + def add_scatter(self, data=None, *, on_axis=None, **params) -> _ISOPlotT: # noqa: ANN001 + """Legacy method that forwards to add_layer(layer_type="scatter", ...).""" + return self.add_layer("scatter", data=data, on_axis=on_axis, **params) - Returns - ------- - list[int] - List of flattened indices + @deprecated() + def add_density(self, data=None, *, on_axis=None, **params) -> _ISOPlotT: # noqa: ANN001 + """Legacy method that forwards to add_layer(layer_type="density", ...).""" + return self.add_layer("density", data=data, on_axis=on_axis, **params) - Raises - ------ - ValueError - If an invalid axis specification is provided + @deprecated() + def add_simple_density(self, data=None, *, on_axis=None, **params) -> _ISOPlotT: # noqa: ANN001 + """Legacy method that forwards to add_layer(layer_type="simple_density",...).""" + return self.add_layer("simple_density", data=data, on_axis=on_axis, **params) - """ - if isinstance(on_axis, int): - return [on_axis] - if isinstance(on_axis, tuple) and len(on_axis) == 2: - # Convert (row, col) to flattened index - row, col = on_axis - return [row * self.plot.subplots_params.ncols + col] - if isinstance(on_axis, list): - return on_axis - msg = f"Invalid axis specification: {on_axis}" - raise ValueError(msg) + @deprecated() + def add_spi_simple( + self, + data=None, # noqa: ANN001 + *, + msn_params=None, # noqa: ANN001 + on_axis=None, # noqa: ANN001 + **params, + ) -> _ISOPlotT: + """Legacy method that forwards to add_layer(layer_type="spi_simple", ...).""" + return self.add_layer( + "spi_simple", data=data, on_axis=on_axis, msn_params=msn_params, **params + ) class StyleManager: """ - Manages the styling of plots. + Manages the style_mgr of plots. - This class encapsulates the styling-related functionality that was previously + This class encapsulates the style_mgr-related functionality that was previously implemented as a mixin in the ISOPlot class. Attributes @@ -375,14 +330,14 @@ def apply_styling( **style_params: Any, ) -> Any: """ - Apply styling to the plot. + Apply style_mgr to the plot. Parameters ---------- on_axis : int | tuple[int, int] | list[int] | None, optional Target specific axis/axes, by default None **style_params : Any - Additional styling parameters + Additional style_mgr parameters Returns ------- @@ -399,7 +354,7 @@ def apply_styling( return self.plot # Apply to specified subplots - target_contexts = self._resolve_target_contexts(on_axis) + target_contexts = PlotContext.get_contexts_by_spec(self.plot, on_axis) for context in target_contexts: self._apply_styling_to_context(context) @@ -407,12 +362,12 @@ def apply_styling( def _apply_styling_to_context(self, context: PlotContext) -> None: """ - Apply styling to a specific context. + Apply style_mgr to a specific context. Parameters ---------- context : PlotContext - The context to apply styling to + The context to apply style_mgr to """ if context.ax is None: @@ -421,7 +376,7 @@ def _apply_styling_to_context(self, context: PlotContext) -> None: # Get style parameters style_params = context.get_params("style") - # Apply styling to the axes + # Apply style_mgr to the axes ax = context.ax # Set limits @@ -447,9 +402,10 @@ def _apply_styling_to_context(self, context: PlotContext) -> None: if hasattr(style_params, "primary_lines") and style_params.primary_lines: self._add_primary_lines(ax, style_params) - # Add diagonal lines + # Add diagonal lines and labels if hasattr(style_params, "diagonal_lines") and style_params.diagonal_lines: self._add_diagonal_lines(ax, style_params) + self._add_diagonal_labels(ax, style_params) # Add legend if needed if hasattr(style_params, "legend_loc") and style_params.legend_loc: @@ -503,8 +459,9 @@ def _add_diagonal_lines(self, ax: Axes, style_params: Any) -> None: ax.plot( xlim, ylim, - color="black", - linestyle="--", + color="grey", + linestyle="dashed", + alpha=0.5, linewidth=style_params.linewidth, zorder=style_params.diag_lines_zorder, ) @@ -513,83 +470,75 @@ def _add_diagonal_lines(self, ax: Axes, style_params: Any) -> None: ax.plot( xlim, ylim[::-1], - color="black", - linestyle="--", + color="grey", + linestyle="dashed", + alpha=0.5, linewidth=style_params.linewidth, zorder=style_params.diag_lines_zorder, ) - def _resolve_target_contexts( - self, on_axis: int | tuple[int, int] | list[int] | None - ) -> list[PlotContext]: + def _add_diagonal_labels(self, ax: Axes, style_params: Any) -> None: """ - Resolve which subplot contexts to target based on axis specification. + Add diagonal labels to the plot. Parameters ---------- - on_axis : int | tuple[int, int] | list[int] | None - The axis specification: - - None: All subplot contexts - - int: Single subplot at flattened index - - tuple[int, int]: Subplot at (row, col) - - list[int]: Multiple subplots at specified indices - - Returns - ------- - list[PlotContext] - List of target subplot contexts + ax : Axes + The axes to add labels to + style_params : Any + The style parameters """ - # If no specific axis, target all subplot contexts - if on_axis is None: - return self.plot.subplot_contexts - - # Convert axis specification to list of indices - indices = self._resolve_axis_indices(on_axis) - - # Get the contexts for each valid index - target_contexts = [] - for idx in indices: - if 0 <= idx < len(self.plot.subplot_contexts): - target_contexts.append(self.plot.subplot_contexts[idx]) - else: - msg = f"Subplot index {idx} out of range" - raise IndexError(msg) - - return target_contexts - - def _resolve_axis_indices( - self, on_axis: int | tuple[int, int] | list[int] - ) -> list[int]: - """ - Convert axis specification to list of indices. - - Parameters - ---------- - on_axis : int | tuple[int, int] | list[int] - The axis specification to resolve - - Returns - ------- - list[int] - List of flattened indices - - Raises - ------ - ValueError - If an invalid axis specification is provided + # Add diagonal labels + xlim = ax.get_xlim() + ylim = ax.get_ylim() - """ - if isinstance(on_axis, int): - return [on_axis] - if isinstance(on_axis, tuple) and len(on_axis) == 2: - # Convert (row, col) to flattened index - row, col = on_axis - return [row * self.plot.subplots_params.ncols + col] - if isinstance(on_axis, list): - return on_axis - msg = f"Invalid axis specification: {on_axis}" - raise ValueError(msg) + # Define font dictionary for diagonal labels + diag_ax_font = { + "fontstyle": "italic", + "fontsize": "small", + "fontweight": "bold", + "color": "black", + "alpha": 0.5, + } + + # Add the four diagonal labels + ax.text( + xlim[1] / 2, + ylim[1] / 2, + "(vibrant)", + ha="center", + va="center", + fontdict=diag_ax_font, + zorder=style_params.diag_labels_zorder, + ) + ax.text( + xlim[0] / 2, + ylim[1] / 2, + "(chaotic)", + ha="center", + va="center", + fontdict=diag_ax_font, + zorder=style_params.diag_labels_zorder, + ) + ax.text( + xlim[0] / 2, + ylim[0] / 2, + "(monotonous)", + ha="center", + va="center", + fontdict=diag_ax_font, + zorder=style_params.diag_labels_zorder, + ) + ax.text( + xlim[1] / 2, + ylim[0] / 2, + "(calm)", + ha="center", + va="center", + fontdict=diag_ax_font, + zorder=style_params.diag_labels_zorder, + ) class SubplotManager: @@ -680,7 +629,7 @@ def _create_subplot_contexts(self) -> None: Create subplot contexts based on the current configuration. This method creates a PlotContext for each subplot, either with - the same data or with data split by a grouping variable. + the same custom_data or with custom_data split by a grouping variable. """ # Clear existing subplot contexts self.plot.subplot_contexts = [] @@ -693,47 +642,35 @@ def _create_subplot_contexts(self) -> None: self._create_subplots_by_group() return - # Otherwise, create a grid of subplots with the same data + # Otherwise, create a grid of subplots with the same custom_data axes = self.plot.axes if not isinstance(axes, np.ndarray): axes = np.array([[axes]]) # Create a context for each axis - for i in range(params.nrows): - for j in range(params.ncols): - # Get the axis for this subplot - ax = ( - axes[i, j] - if params.nrows > 1 and params.ncols > 1 - else axes[i] - if params.nrows > 1 - else axes[j] - if params.ncols > 1 - else axes - ) - - # Create a title for this subplot - title = ( - f"Subplot {i * params.ncols + j + 1}" - if self.plot.main_context.title is None - else f"{self.plot.main_context.title} {i * params.ncols + j + 1}" - ) + for i, ax in enumerate(axes.flatten()): + # Create a title for this subplot + title = ( + f"Subplot {i + 1}" + if self.plot.main_context.title is None + else f"{self.plot.main_context.title} {i + 1}" + ) - # Create a child context for this subplot - context = self.plot.main_context.create_child( - ax=ax, - title=title, - ) + # Create a child context for this subplot + context = self.plot.main_context.create_child( + ax=ax, + title=title, + ) - # Add to subplot contexts - self.plot.subplot_contexts.append(context) + # Add to subplot contexts + self.plot.subplot_contexts.append(context) def _create_subplots_by_group(self) -> None: """ - Create subplots by grouping the data. + Create subplots by grouping the custom_data. This method creates a subplot for each unique value in the - subplot_by column of the data. + subplot_by column of the custom_data. """ # Get subplot parameters params = self.plot.subplots_params @@ -777,7 +714,7 @@ def _create_subplots_by_group(self) -> None: else axes ) - # Filter data for this group + # Filter custom_data for this group group_data = data[data[subplot_by] == group] # Create a title for this subplot diff --git a/src/soundscapy/plotting/new/parameter_models.py b/src/soundscapy/plotting/new/parameter_models.py index 118985c..03f9025 100644 --- a/src/soundscapy/plotting/new/parameter_models.py +++ b/src/soundscapy/plotting/new/parameter_models.py @@ -8,11 +8,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, Self, TypeAlias +from typing import Any, Literal, Self, TypeAlias import numpy as np import pandas as pd from matplotlib.colors import Colormap +from matplotlib.typing import ColorType from pydantic import BaseModel, ConfigDict, Field from pydantic.alias_generators import to_snake @@ -29,9 +30,6 @@ ) from soundscapy.sspylogging import get_logger -if TYPE_CHECKING: - from matplotlib.typing import ColorType - logger = get_logger() # Type aliases @@ -244,7 +242,7 @@ class SimpleDensityParams(DensityParams): class SPISeabornParams(SeabornParams): - """Base parameters for seaborn plotting functions for SPI data.""" + """Base parameters for seaborn plotting functions for SPI custom_data.""" color: ColorType | None = "red" hue: str | np.ndarray | pd.Series | None = None @@ -255,6 +253,9 @@ class SPISeabornParams(SeabornParams): axis_text_kw: dict[str, Any] | None = Field( default_factory=lambda: DEFAULT_SPI_TEXT_KWARGS.copy() ) + # msn_params is not directly stored in the parameter model + # but is used by SPILayer to generate SPI custom_data + # It should be passed to the layer constructor, not to the parameter model def as_seaborn_kwargs(self) -> dict[str, Any]: """ @@ -272,7 +273,7 @@ def as_seaborn_kwargs(self) -> dict[str, Any]: class SPISimpleDensityParams(SimpleDensityParams): - """Parameters for simple density plotting of SPI data.""" + """Parameters for simple density plotting of SPI custom_data.""" color: ColorType | None = "red" label: str = "SPI" @@ -310,7 +311,7 @@ class JointPlotParams(BaseParams): class StyleParams(BaseParams): """ - Configuration options for styling circumplex plots. + Configuration options for style_mgr circumplex plots. """ xlim: tuple[float, float] = DEFAULT_XLIM @@ -347,19 +348,19 @@ class SubplotsParams(BaseParams): @property def n_subplots(self) -> int: """ - Calculate the total number of subplots. + Calculate the total number of subplot_mgr. Returns ------- int - Total number of subplots. + Total number of subplot_mgr. """ return self.nrows * self.ncols def as_plt_subplots_args(self) -> dict[str, Any]: """ - Pass matplotlib subplot arguments to a plt.subplots call. + Pass matplotlib subplot arguments to a plt.subplot_mgr call. Returns ------- diff --git a/src/soundscapy/plotting/new/plot_context.py b/src/soundscapy/plotting/new/plot_context.py index 556ff66..5c540c1 100644 --- a/src/soundscapy/plotting/new/plot_context.py +++ b/src/soundscapy/plotting/new/plot_context.py @@ -1,14 +1,16 @@ """ -Data and state management for plotting layers. +Data and state management for plotting layer_mgr. -This module provides the PlotContext class that manages data, state, and parameters +This module provides the PlotContext class that manages custom_data, state, and parameters for ISOPlot visualizations. The PlotContext is the central component in the plotting -architecture, owning both data and parameter models for different layer types. +architecture, owning both custom_data and parameter models for different layer types. """ from __future__ import annotations -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, TypeVar, cast + +import matplotlib.pyplot as plt from soundscapy.plotting.new.constants import DEFAULT_XCOL, DEFAULT_YCOL from soundscapy.plotting.new.parameter_models import ( @@ -25,27 +27,30 @@ import pandas as pd from matplotlib.axes import Axes + from soundscapy.plotting.new.iso_plot import ISOPlot from soundscapy.plotting.new.protocols import ParamModel, RenderableLayer + _ISOPlotT = TypeVar("_ISOPlotT", bound="ISOPlot") + logger = get_logger() class PlotContext: """ - Manages data, state, and parameters for a plot or subplot. + Manages custom_data, state, and parameters for a plot or subplot. - This class centralizes the management of data, coordinates, parameters, and other - state needed for rendering plot layers. It owns parameter models for different - layer types and provides them to layers when needed. + This class centralizes the management of custom_data, coordinates, parameters, and other + state needed for rendering plot layer_mgr. It owns parameter models for different + layer types and provides them to layer_mgr when needed. Attributes ---------- data : pd.DataFrame | None - The data associated with this context + The custom_data associated with this context x : str - The column name for x-axis data + The column name for x-axis custom_data y : str - The column name for y-axis data + The column name for y-axis custom_data hue : str | None The column name for color encoding, if any ax : Axes | None @@ -53,7 +58,7 @@ class PlotContext: title : str | None The title for this context's plot layers : list[RenderableLayer] - The visualization layers to be rendered on this context + The visualization layer_mgr to be rendered on this context parent : PlotContext | None The parent context, if this is a child context @@ -76,9 +81,9 @@ def __init__( data : pd.DataFrame | None Data to be visualized x : str - Column name for x-axis data + Column name for x-axis custom_data y : str - Column name for y-axis data + Column name for y-axis custom_data hue : str | None Column name for color encoding ax : Axes | None @@ -106,21 +111,20 @@ def __init__( def _init_param_models(self) -> None: """Initialize parameter models with context values.""" # Common parameters for all models - common_params = { - "data": self.data, - "x": self.x, - "y": self.y, - "hue": self.hue, - } + data = self.data + x = self.x + y = self.y + hue = self.hue # Create parameter models for different layer types - self._param_models["scatter"] = ScatterParams(**common_params) - self._param_models["density"] = DensityParams(**common_params) - self._param_models["simple_density"] = SimpleDensityParams(**common_params) + self._param_models["scatter"] = ScatterParams(data=data, x=x, y=y, hue=hue) + self._param_models["density"] = DensityParams(data=data, x=x, y=y, hue=hue) + self._param_models["simple_density"] = SimpleDensityParams(data=data, x=x, y=y) self._param_models["spi_simple_density"] = SPISimpleDensityParams( - **common_params + data=data, x=x, y=y, hue=hue ) self._param_models["style"] = StyleParams( + # TODO: Should not be setting defaults here! xlim=(-1, 1), ylim=(-1, 1), xlabel=r"$P_{ISO}$", @@ -252,3 +256,128 @@ def create_child( child.parent = self return child + + def ensure_axes_exist(self, plot: _ISOPlotT) -> None: + """ + Check if we have axes to render on, create if needed. + + This method ensures that the plot has axes to render on, + creating them if necessary. + + Parameters + ---------- + plot : Any + The parent plot instance + + """ + if plot.figure is None: + # Create a new figure and axes + logger.info("Creating new figure and axes") + plot.figure, plot.axes = plt.subplots(figsize=(5, 5)) + + def get_axes_by_spec( + self, plot: _ISOPlotT, spec: int | tuple[int, int] | list[int] | None + ) -> list[Axes]: + """ + Get axes based on specification. + + Parameters + ---------- + plot : ISOPlot + The parent plot instance + spec : int | tuple[int, int] | list[int] | None + The axis specification: + - None: All subplot axes + - int: Single subplot at flattened index + - tuple[int, int]: Subplot at (row, col) + - list[int]: Multiple subplot_mgr at specified indices + + Returns + ------- + list[Axes] + List of matplotlib Axes objects + + """ + # Get the contexts based on specification + contexts = self.get_contexts_by_spec(plot, spec) + + # Extract the axes from each context + return [context.ax for context in contexts if context.ax is not None] + + @classmethod + def get_contexts_by_spec( + cls, plot: _ISOPlotT, spec: int | tuple[int, int] | list[int] | None + ) -> list[PlotContext]: + """ + Resolve which subplot contexts to target based on axis specification. + + Parameters + ---------- + plot : ISOPlot + The parent plot instance + spec : int | tuple[int, int] | list[int] | None + The axis specification: + - None: All subplot contexts + - int: Single subplot at flattened index + - tuple[int, int]: Subplot at (row, col) + - list[int]: Multiple subplot_mgr at specified indices + + Returns + ------- + list[PlotContext] + List of target subplot contexts + + """ + # If no specific axis, target all subplot contexts + if spec is None: + return plot.subplot_contexts + + # Convert axis specification to list of indices + indices = cls.resolve_axis_indices(plot, spec) + + # Get the contexts for each valid index + target_contexts = [] + for idx in indices: + if 0 <= idx < len(plot.subplot_contexts): + target_contexts.append(plot.subplot_contexts[idx]) + else: + msg = f"Subplot index {idx} out of range" + raise IndexError(msg) + + return target_contexts + + @staticmethod + def resolve_axis_indices( + plot: _ISOPlotT, spec: int | tuple[int, int] | list[int] + ) -> list[int]: + """ + Convert axis specification to list of indices. + + Parameters + ---------- + plot : Any + The parent plot instance + spec : int | tuple[int, int] | list[int] + The axis specification to resolve + + Returns + ------- + list[int] + List of flattened indices + + Raises + ------ + ValueError + If an invalid axis specification is provided + + """ + if isinstance(spec, int): + return [spec] + if isinstance(spec, tuple) and len(spec) == 2: + # Convert (row, col) to flattened index + row, col = spec + return [row * plot.subplots_params.ncols + col] + if isinstance(spec, list): + return spec + msg = f"Invalid axis specification: {spec}" + raise ValueError(msg) diff --git a/src/soundscapy/plotting/new/protocols.py b/src/soundscapy/plotting/new/protocols.py index 15765f3..313d5d8 100644 --- a/src/soundscapy/plotting/new/protocols.py +++ b/src/soundscapy/plotting/new/protocols.py @@ -29,7 +29,7 @@ def render(self, context: PlotContext) -> None: Parameters ---------- context : PlotContext - The context containing data and axes for rendering + The context containing custom_data and axes for rendering """ ... diff --git a/uv.lock b/uv.lock index 90f94bf..08bb50a 100644 --- a/uv.lock +++ b/uv.lock @@ -680,6 +680,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604, upload-time = "2021-03-08T10:59:24.45Z" }, ] +[[package]] +name = "deprecated" +version = "1.2.18" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/98/97/06afe62762c9a8a86af0cfb7bfdab22a43ad17138b07af5b1a58442690a2/deprecated-1.2.18.tar.gz", hash = "sha256:422b6f6d859da6f2ef57857761bfb392480502a64c3028ca9bbe86085d72115d", size = 2928744, upload-time = "2025-01-27T10:46:25.7Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6e/c6/ac0b6c1e2d138f1002bcf799d330bd6d85084fece321e662a14223794041/Deprecated-1.2.18-py2.py3-none-any.whl", hash = "sha256:bd5011788200372a32418f888e326a09ff80d0214bd961147cfed01b5c018eec", size = 9998, upload-time = "2025-01-27T10:46:09.186Z" }, +] + [[package]] name = "distlib" version = "0.3.9" @@ -3606,6 +3618,7 @@ name = "soundscapy" source = { editable = "." } dependencies = [ { name = "coverage" }, + { name = "deprecated" }, { name = "loguru" }, { name = "numpy" }, { name = "pandas", extra = ["excel"] }, @@ -3675,6 +3688,7 @@ test = [ requires-dist = [ { name = "acoustic-toolbox", marker = "extra == 'audio'", specifier = ">=0.1.2" }, { name = "coverage", specifier = "==7.8.0" }, + { name = "deprecated", specifier = ">=1.2.18" }, { name = "loguru", specifier = ">=0.7.2" }, { name = "mosqito", marker = "extra == 'audio'", specifier = ">=1.2.1" }, { name = "numba", marker = "extra == 'audio'", specifier = ">=0.59" }, @@ -4131,6 +4145,70 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e1/07/c6fe3ad3e685340704d314d765b7912993bcb8dc198f0e7a89382d37974b/win32_setctime-1.2.0-py3-none-any.whl", hash = "sha256:95d644c4e708aba81dc3704a116d8cbc974d70b3bdb8be1d150e36be6e9d1390", size = 4083, upload-time = "2024-12-07T15:28:26.465Z" }, ] +[[package]] +name = "wrapt" +version = "1.17.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c3/fc/e91cc220803d7bc4db93fb02facd8461c37364151b8494762cc88b0fbcef/wrapt-1.17.2.tar.gz", hash = "sha256:41388e9d4d1522446fe79d3213196bd9e3b301a336965b9e27ca2788ebd122f3", size = 55531, upload-time = "2025-01-14T10:35:45.465Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/d1/1daec934997e8b160040c78d7b31789f19b122110a75eca3d4e8da0049e1/wrapt-1.17.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3d57c572081fed831ad2d26fd430d565b76aa277ed1d30ff4d40670b1c0dd984", size = 53307, upload-time = "2025-01-14T10:33:13.616Z" }, + { url = "https://files.pythonhosted.org/packages/1b/7b/13369d42651b809389c1a7153baa01d9700430576c81a2f5c5e460df0ed9/wrapt-1.17.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5e251054542ae57ac7f3fba5d10bfff615b6c2fb09abeb37d2f1463f841ae22", size = 38486, upload-time = "2025-01-14T10:33:15.947Z" }, + { url = "https://files.pythonhosted.org/packages/62/bf/e0105016f907c30b4bd9e377867c48c34dc9c6c0c104556c9c9126bd89ed/wrapt-1.17.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:80dd7db6a7cb57ffbc279c4394246414ec99537ae81ffd702443335a61dbf3a7", size = 38777, upload-time = "2025-01-14T10:33:17.462Z" }, + { url = "https://files.pythonhosted.org/packages/27/70/0f6e0679845cbf8b165e027d43402a55494779295c4b08414097b258ac87/wrapt-1.17.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a6e821770cf99cc586d33833b2ff32faebdbe886bd6322395606cf55153246c", size = 83314, upload-time = "2025-01-14T10:33:21.282Z" }, + { url = "https://files.pythonhosted.org/packages/0f/77/0576d841bf84af8579124a93d216f55d6f74374e4445264cb378a6ed33eb/wrapt-1.17.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b60fb58b90c6d63779cb0c0c54eeb38941bae3ecf7a73c764c52c88c2dcb9d72", size = 74947, upload-time = "2025-01-14T10:33:24.414Z" }, + { url = "https://files.pythonhosted.org/packages/90/ec/00759565518f268ed707dcc40f7eeec38637d46b098a1f5143bff488fe97/wrapt-1.17.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b870b5df5b71d8c3359d21be8f0d6c485fa0ebdb6477dda51a1ea54a9b558061", size = 82778, upload-time = "2025-01-14T10:33:26.152Z" }, + { url = "https://files.pythonhosted.org/packages/f8/5a/7cffd26b1c607b0b0c8a9ca9d75757ad7620c9c0a9b4a25d3f8a1480fafc/wrapt-1.17.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4011d137b9955791f9084749cba9a367c68d50ab8d11d64c50ba1688c9b457f2", size = 81716, upload-time = "2025-01-14T10:33:27.372Z" }, + { url = "https://files.pythonhosted.org/packages/7e/09/dccf68fa98e862df7e6a60a61d43d644b7d095a5fc36dbb591bbd4a1c7b2/wrapt-1.17.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:1473400e5b2733e58b396a04eb7f35f541e1fb976d0c0724d0223dd607e0f74c", size = 74548, upload-time = "2025-01-14T10:33:28.52Z" }, + { url = "https://files.pythonhosted.org/packages/b7/8e/067021fa3c8814952c5e228d916963c1115b983e21393289de15128e867e/wrapt-1.17.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3cedbfa9c940fdad3e6e941db7138e26ce8aad38ab5fe9dcfadfed9db7a54e62", size = 81334, upload-time = "2025-01-14T10:33:29.643Z" }, + { url = "https://files.pythonhosted.org/packages/4b/0d/9d4b5219ae4393f718699ca1c05f5ebc0c40d076f7e65fd48f5f693294fb/wrapt-1.17.2-cp310-cp310-win32.whl", hash = "sha256:582530701bff1dec6779efa00c516496968edd851fba224fbd86e46cc6b73563", size = 36427, upload-time = "2025-01-14T10:33:30.832Z" }, + { url = "https://files.pythonhosted.org/packages/72/6a/c5a83e8f61aec1e1aeef939807602fb880e5872371e95df2137142f5c58e/wrapt-1.17.2-cp310-cp310-win_amd64.whl", hash = "sha256:58705da316756681ad3c9c73fd15499aa4d8c69f9fd38dc8a35e06c12468582f", size = 38774, upload-time = "2025-01-14T10:33:32.897Z" }, + { url = "https://files.pythonhosted.org/packages/cd/f7/a2aab2cbc7a665efab072344a8949a71081eed1d2f451f7f7d2b966594a2/wrapt-1.17.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ff04ef6eec3eee8a5efef2401495967a916feaa353643defcc03fc74fe213b58", size = 53308, upload-time = "2025-01-14T10:33:33.992Z" }, + { url = "https://files.pythonhosted.org/packages/50/ff/149aba8365fdacef52b31a258c4dc1c57c79759c335eff0b3316a2664a64/wrapt-1.17.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4db983e7bca53819efdbd64590ee96c9213894272c776966ca6306b73e4affda", size = 38488, upload-time = "2025-01-14T10:33:35.264Z" }, + { url = "https://files.pythonhosted.org/packages/65/46/5a917ce85b5c3b490d35c02bf71aedaa9f2f63f2d15d9949cc4ba56e8ba9/wrapt-1.17.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9abc77a4ce4c6f2a3168ff34b1da9b0f311a8f1cfd694ec96b0603dff1c79438", size = 38776, upload-time = "2025-01-14T10:33:38.28Z" }, + { url = "https://files.pythonhosted.org/packages/ca/74/336c918d2915a4943501c77566db41d1bd6e9f4dbc317f356b9a244dfe83/wrapt-1.17.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b929ac182f5ace000d459c59c2c9c33047e20e935f8e39371fa6e3b85d56f4a", size = 83776, upload-time = "2025-01-14T10:33:40.678Z" }, + { url = "https://files.pythonhosted.org/packages/09/99/c0c844a5ccde0fe5761d4305485297f91d67cf2a1a824c5f282e661ec7ff/wrapt-1.17.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f09b286faeff3c750a879d336fb6d8713206fc97af3adc14def0cdd349df6000", size = 75420, upload-time = "2025-01-14T10:33:41.868Z" }, + { url = "https://files.pythonhosted.org/packages/b4/b0/9fc566b0fe08b282c850063591a756057c3247b2362b9286429ec5bf1721/wrapt-1.17.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a7ed2d9d039bd41e889f6fb9364554052ca21ce823580f6a07c4ec245c1f5d6", size = 83199, upload-time = "2025-01-14T10:33:43.598Z" }, + { url = "https://files.pythonhosted.org/packages/9d/4b/71996e62d543b0a0bd95dda485219856def3347e3e9380cc0d6cf10cfb2f/wrapt-1.17.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:129a150f5c445165ff941fc02ee27df65940fcb8a22a61828b1853c98763a64b", size = 82307, upload-time = "2025-01-14T10:33:48.499Z" }, + { url = "https://files.pythonhosted.org/packages/39/35/0282c0d8789c0dc9bcc738911776c762a701f95cfe113fb8f0b40e45c2b9/wrapt-1.17.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1fb5699e4464afe5c7e65fa51d4f99e0b2eadcc176e4aa33600a3df7801d6662", size = 75025, upload-time = "2025-01-14T10:33:51.191Z" }, + { url = "https://files.pythonhosted.org/packages/4f/6d/90c9fd2c3c6fee181feecb620d95105370198b6b98a0770cba090441a828/wrapt-1.17.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9a2bce789a5ea90e51a02dfcc39e31b7f1e662bc3317979aa7e5538e3a034f72", size = 81879, upload-time = "2025-01-14T10:33:52.328Z" }, + { url = "https://files.pythonhosted.org/packages/8f/fa/9fb6e594f2ce03ef03eddbdb5f4f90acb1452221a5351116c7c4708ac865/wrapt-1.17.2-cp311-cp311-win32.whl", hash = "sha256:4afd5814270fdf6380616b321fd31435a462019d834f83c8611a0ce7484c7317", size = 36419, upload-time = "2025-01-14T10:33:53.551Z" }, + { url = "https://files.pythonhosted.org/packages/47/f8/fb1773491a253cbc123c5d5dc15c86041f746ed30416535f2a8df1f4a392/wrapt-1.17.2-cp311-cp311-win_amd64.whl", hash = "sha256:acc130bc0375999da18e3d19e5a86403667ac0c4042a094fefb7eec8ebac7cf3", size = 38773, upload-time = "2025-01-14T10:33:56.323Z" }, + { url = "https://files.pythonhosted.org/packages/a1/bd/ab55f849fd1f9a58ed7ea47f5559ff09741b25f00c191231f9f059c83949/wrapt-1.17.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:d5e2439eecc762cd85e7bd37161d4714aa03a33c5ba884e26c81559817ca0925", size = 53799, upload-time = "2025-01-14T10:33:57.4Z" }, + { url = "https://files.pythonhosted.org/packages/53/18/75ddc64c3f63988f5a1d7e10fb204ffe5762bc663f8023f18ecaf31a332e/wrapt-1.17.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3fc7cb4c1c744f8c05cd5f9438a3caa6ab94ce8344e952d7c45a8ed59dd88392", size = 38821, upload-time = "2025-01-14T10:33:59.334Z" }, + { url = "https://files.pythonhosted.org/packages/48/2a/97928387d6ed1c1ebbfd4efc4133a0633546bec8481a2dd5ec961313a1c7/wrapt-1.17.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8fdbdb757d5390f7c675e558fd3186d590973244fab0c5fe63d373ade3e99d40", size = 38919, upload-time = "2025-01-14T10:34:04.093Z" }, + { url = "https://files.pythonhosted.org/packages/73/54/3bfe5a1febbbccb7a2f77de47b989c0b85ed3a6a41614b104204a788c20e/wrapt-1.17.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5bb1d0dbf99411f3d871deb6faa9aabb9d4e744d67dcaaa05399af89d847a91d", size = 88721, upload-time = "2025-01-14T10:34:07.163Z" }, + { url = "https://files.pythonhosted.org/packages/25/cb/7262bc1b0300b4b64af50c2720ef958c2c1917525238d661c3e9a2b71b7b/wrapt-1.17.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d18a4865f46b8579d44e4fe1e2bcbc6472ad83d98e22a26c963d46e4c125ef0b", size = 80899, upload-time = "2025-01-14T10:34:09.82Z" }, + { url = "https://files.pythonhosted.org/packages/2a/5a/04cde32b07a7431d4ed0553a76fdb7a61270e78c5fd5a603e190ac389f14/wrapt-1.17.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc570b5f14a79734437cb7b0500376b6b791153314986074486e0b0fa8d71d98", size = 89222, upload-time = "2025-01-14T10:34:11.258Z" }, + { url = "https://files.pythonhosted.org/packages/09/28/2e45a4f4771fcfb109e244d5dbe54259e970362a311b67a965555ba65026/wrapt-1.17.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6d9187b01bebc3875bac9b087948a2bccefe464a7d8f627cf6e48b1bbae30f82", size = 86707, upload-time = "2025-01-14T10:34:12.49Z" }, + { url = "https://files.pythonhosted.org/packages/c6/d2/dcb56bf5f32fcd4bd9aacc77b50a539abdd5b6536872413fd3f428b21bed/wrapt-1.17.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9e8659775f1adf02eb1e6f109751268e493c73716ca5761f8acb695e52a756ae", size = 79685, upload-time = "2025-01-14T10:34:15.043Z" }, + { url = "https://files.pythonhosted.org/packages/80/4e/eb8b353e36711347893f502ce91c770b0b0929f8f0bed2670a6856e667a9/wrapt-1.17.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e8b2816ebef96d83657b56306152a93909a83f23994f4b30ad4573b00bd11bb9", size = 87567, upload-time = "2025-01-14T10:34:16.563Z" }, + { url = "https://files.pythonhosted.org/packages/17/27/4fe749a54e7fae6e7146f1c7d914d28ef599dacd4416566c055564080fe2/wrapt-1.17.2-cp312-cp312-win32.whl", hash = "sha256:468090021f391fe0056ad3e807e3d9034e0fd01adcd3bdfba977b6fdf4213ea9", size = 36672, upload-time = "2025-01-14T10:34:17.727Z" }, + { url = "https://files.pythonhosted.org/packages/15/06/1dbf478ea45c03e78a6a8c4be4fdc3c3bddea5c8de8a93bc971415e47f0f/wrapt-1.17.2-cp312-cp312-win_amd64.whl", hash = "sha256:ec89ed91f2fa8e3f52ae53cd3cf640d6feff92ba90d62236a81e4e563ac0e991", size = 38865, upload-time = "2025-01-14T10:34:19.577Z" }, + { url = "https://files.pythonhosted.org/packages/ce/b9/0ffd557a92f3b11d4c5d5e0c5e4ad057bd9eb8586615cdaf901409920b14/wrapt-1.17.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:6ed6ffac43aecfe6d86ec5b74b06a5be33d5bb9243d055141e8cabb12aa08125", size = 53800, upload-time = "2025-01-14T10:34:21.571Z" }, + { url = "https://files.pythonhosted.org/packages/c0/ef/8be90a0b7e73c32e550c73cfb2fa09db62234227ece47b0e80a05073b375/wrapt-1.17.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:35621ae4c00e056adb0009f8e86e28eb4a41a4bfa8f9bfa9fca7d343fe94f998", size = 38824, upload-time = "2025-01-14T10:34:22.999Z" }, + { url = "https://files.pythonhosted.org/packages/36/89/0aae34c10fe524cce30fe5fc433210376bce94cf74d05b0d68344c8ba46e/wrapt-1.17.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a604bf7a053f8362d27eb9fefd2097f82600b856d5abe996d623babd067b1ab5", size = 38920, upload-time = "2025-01-14T10:34:25.386Z" }, + { url = "https://files.pythonhosted.org/packages/3b/24/11c4510de906d77e0cfb5197f1b1445d4fec42c9a39ea853d482698ac681/wrapt-1.17.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5cbabee4f083b6b4cd282f5b817a867cf0b1028c54d445b7ec7cfe6505057cf8", size = 88690, upload-time = "2025-01-14T10:34:28.058Z" }, + { url = "https://files.pythonhosted.org/packages/71/d7/cfcf842291267bf455b3e266c0c29dcb675b5540ee8b50ba1699abf3af45/wrapt-1.17.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:49703ce2ddc220df165bd2962f8e03b84c89fee2d65e1c24a7defff6f988f4d6", size = 80861, upload-time = "2025-01-14T10:34:29.167Z" }, + { url = "https://files.pythonhosted.org/packages/d5/66/5d973e9f3e7370fd686fb47a9af3319418ed925c27d72ce16b791231576d/wrapt-1.17.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8112e52c5822fc4253f3901b676c55ddf288614dc7011634e2719718eaa187dc", size = 89174, upload-time = "2025-01-14T10:34:31.702Z" }, + { url = "https://files.pythonhosted.org/packages/a7/d3/8e17bb70f6ae25dabc1aaf990f86824e4fd98ee9cadf197054e068500d27/wrapt-1.17.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9fee687dce376205d9a494e9c121e27183b2a3df18037f89d69bd7b35bcf59e2", size = 86721, upload-time = "2025-01-14T10:34:32.91Z" }, + { url = "https://files.pythonhosted.org/packages/6f/54/f170dfb278fe1c30d0ff864513cff526d624ab8de3254b20abb9cffedc24/wrapt-1.17.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:18983c537e04d11cf027fbb60a1e8dfd5190e2b60cc27bc0808e653e7b218d1b", size = 79763, upload-time = "2025-01-14T10:34:34.903Z" }, + { url = "https://files.pythonhosted.org/packages/4a/98/de07243751f1c4a9b15c76019250210dd3486ce098c3d80d5f729cba029c/wrapt-1.17.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:703919b1633412ab54bcf920ab388735832fdcb9f9a00ae49387f0fe67dad504", size = 87585, upload-time = "2025-01-14T10:34:36.13Z" }, + { url = "https://files.pythonhosted.org/packages/f9/f0/13925f4bd6548013038cdeb11ee2cbd4e37c30f8bfd5db9e5a2a370d6e20/wrapt-1.17.2-cp313-cp313-win32.whl", hash = "sha256:abbb9e76177c35d4e8568e58650aa6926040d6a9f6f03435b7a522bf1c487f9a", size = 36676, upload-time = "2025-01-14T10:34:37.962Z" }, + { url = "https://files.pythonhosted.org/packages/bf/ae/743f16ef8c2e3628df3ddfd652b7d4c555d12c84b53f3d8218498f4ade9b/wrapt-1.17.2-cp313-cp313-win_amd64.whl", hash = "sha256:69606d7bb691b50a4240ce6b22ebb319c1cfb164e5f6569835058196e0f3a845", size = 38871, upload-time = "2025-01-14T10:34:39.13Z" }, + { url = "https://files.pythonhosted.org/packages/3d/bc/30f903f891a82d402ffb5fda27ec1d621cc97cb74c16fea0b6141f1d4e87/wrapt-1.17.2-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:4a721d3c943dae44f8e243b380cb645a709ba5bd35d3ad27bc2ed947e9c68192", size = 56312, upload-time = "2025-01-14T10:34:40.604Z" }, + { url = "https://files.pythonhosted.org/packages/8a/04/c97273eb491b5f1c918857cd26f314b74fc9b29224521f5b83f872253725/wrapt-1.17.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:766d8bbefcb9e00c3ac3b000d9acc51f1b399513f44d77dfe0eb026ad7c9a19b", size = 40062, upload-time = "2025-01-14T10:34:45.011Z" }, + { url = "https://files.pythonhosted.org/packages/4e/ca/3b7afa1eae3a9e7fefe499db9b96813f41828b9fdb016ee836c4c379dadb/wrapt-1.17.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e496a8ce2c256da1eb98bd15803a79bee00fc351f5dfb9ea82594a3f058309e0", size = 40155, upload-time = "2025-01-14T10:34:47.25Z" }, + { url = "https://files.pythonhosted.org/packages/89/be/7c1baed43290775cb9030c774bc53c860db140397047cc49aedaf0a15477/wrapt-1.17.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40d615e4fe22f4ad3528448c193b218e077656ca9ccb22ce2cb20db730f8d306", size = 113471, upload-time = "2025-01-14T10:34:50.934Z" }, + { url = "https://files.pythonhosted.org/packages/32/98/4ed894cf012b6d6aae5f5cc974006bdeb92f0241775addad3f8cd6ab71c8/wrapt-1.17.2-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a5aaeff38654462bc4b09023918b7f21790efb807f54c000a39d41d69cf552cb", size = 101208, upload-time = "2025-01-14T10:34:52.297Z" }, + { url = "https://files.pythonhosted.org/packages/ea/fd/0c30f2301ca94e655e5e057012e83284ce8c545df7661a78d8bfca2fac7a/wrapt-1.17.2-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a7d15bbd2bc99e92e39f49a04653062ee6085c0e18b3b7512a4f2fe91f2d681", size = 109339, upload-time = "2025-01-14T10:34:53.489Z" }, + { url = "https://files.pythonhosted.org/packages/75/56/05d000de894c4cfcb84bcd6b1df6214297b8089a7bd324c21a4765e49b14/wrapt-1.17.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:e3890b508a23299083e065f435a492b5435eba6e304a7114d2f919d400888cc6", size = 110232, upload-time = "2025-01-14T10:34:55.327Z" }, + { url = "https://files.pythonhosted.org/packages/53/f8/c3f6b2cf9b9277fb0813418e1503e68414cd036b3b099c823379c9575e6d/wrapt-1.17.2-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:8c8b293cd65ad716d13d8dd3624e42e5a19cc2a2f1acc74b30c2c13f15cb61a6", size = 100476, upload-time = "2025-01-14T10:34:58.055Z" }, + { url = "https://files.pythonhosted.org/packages/a7/b1/0bb11e29aa5139d90b770ebbfa167267b1fc548d2302c30c8f7572851738/wrapt-1.17.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4c82b8785d98cdd9fed4cac84d765d234ed3251bd6afe34cb7ac523cb93e8b4f", size = 106377, upload-time = "2025-01-14T10:34:59.3Z" }, + { url = "https://files.pythonhosted.org/packages/6a/e1/0122853035b40b3f333bbb25f1939fc1045e21dd518f7f0922b60c156f7c/wrapt-1.17.2-cp313-cp313t-win32.whl", hash = "sha256:13e6afb7fe71fe7485a4550a8844cc9ffbe263c0f1a1eea569bc7091d4898555", size = 37986, upload-time = "2025-01-14T10:35:00.498Z" }, + { url = "https://files.pythonhosted.org/packages/09/5e/1655cf481e079c1f22d0cabdd4e51733679932718dc23bf2db175f329b76/wrapt-1.17.2-cp313-cp313t-win_amd64.whl", hash = "sha256:eaf675418ed6b3b31c7a989fd007fa7c3be66ce14e5c3b27336383604c9da85c", size = 40750, upload-time = "2025-01-14T10:35:03.378Z" }, + { url = "https://files.pythonhosted.org/packages/2d/82/f56956041adef78f849db6b289b282e72b55ab8045a75abad81898c28d19/wrapt-1.17.2-py3-none-any.whl", hash = "sha256:b18f2d1533a71f069c7f82d524a52599053d4c7166e9dd374ae2136b7f40f7c8", size = 23594, upload-time = "2025-01-14T10:35:44.018Z" }, +] + [[package]] name = "xdoctest" version = "1.2.0" From 4de675ac7280d91146f1a312078bf98695bcb317 Mon Sep 17 00:00:00 2001 From: Andrew Mitchell Date: Sat, 10 May 2025 21:55:20 +0100 Subject: [PATCH 6/8] Add support for subplot grouping and advanced subplot options Enhanced plotting functionality by introducing `subplot_by` for grouping data in subplots, along with options like `auto_allocate_axes`, `adjust_figsize`, and `subplot_titles`. Refactored layer rendering and documentation to incorporate these features, improving customization and usability. Updated dependencies to include `types-deprecated`. Signed-off-by: Andrew Mitchell --- pyproject.toml | 1 + src/soundscapy/plotting/iso_plot.py | 5 +- src/soundscapy/plotting/new/constants.py | 2 +- src/soundscapy/plotting/new/iso_plot.py | 275 ++++++++----- src/soundscapy/plotting/new/layer.py | 43 +- src/soundscapy/plotting/new/managers.py | 367 +++++++++++++++--- .../plotting/new/parameter_models.py | 52 ++- src/soundscapy/plotting/new/plot_context.py | 36 +- src/soundscapy/plotting/new/protocols.py | 5 +- uv.lock | 11 + 10 files changed, 608 insertions(+), 189 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 91f7f94..9dbab2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dev = [ "setuptools-scm>=8.3.1", "tox>=4.25.0", "twine>=6.1.0", + "types-deprecated>=1.2.15.20250304", "types-pyyaml>=6.0.12.20250402", "types-seaborn>=0.13.2.20250111", ] diff --git a/src/soundscapy/plotting/iso_plot.py b/src/soundscapy/plotting/iso_plot.py index 4fa6f60..f622d62 100644 --- a/src/soundscapy/plotting/iso_plot.py +++ b/src/soundscapy/plotting/iso_plot.py @@ -339,10 +339,7 @@ def _check_data_x_y( ) data = pd.DataFrame({xcol: x, ycol: y}) - x = xcol - y = ycol - - return data, x, y + return data, xcol, ycol # If data is None, and x and y are strings: if isinstance(x, str) and isinstance(y, str): diff --git a/src/soundscapy/plotting/new/constants.py b/src/soundscapy/plotting/new/constants.py index 321b95f..cf91922 100644 --- a/src/soundscapy/plotting/new/constants.py +++ b/src/soundscapy/plotting/new/constants.py @@ -43,4 +43,4 @@ }, "ha": "center", "va": "center", -} \ No newline at end of file +} diff --git a/src/soundscapy/plotting/new/iso_plot.py b/src/soundscapy/plotting/new/iso_plot.py index ab584ee..ca4269a 100644 --- a/src/soundscapy/plotting/new/iso_plot.py +++ b/src/soundscapy/plotting/new/iso_plot.py @@ -19,9 +19,12 @@ ... columns=['ISOPleasant', 'ISOEventful'] ... ) >>> # Create a plot and add a scatter layer ->>> plot = ISOPlot(data=data).create_subplots() ->>> plot.add_scatter() ->>> plot.apply_styling() +>>> plot = ( +... ISOPlot(data=data) +... .create_subplots() +... .add_scatter() +... .apply_styling() +... ) >>> plot.show() >>> isinstance(plot, ISOPlot) True @@ -31,11 +34,14 @@ >>> # Add a group column to the data >>> data['Group'] = rng.integers(1, 3, 100) >>> # Create a plot with subplots by group ->>> plot = (ISOPlot(data=data, hue='Group') +>>> plot = ( +... ISOPlot(data=data, hue='Group') ... .create_subplots() ... .add_scatter() ... .add_simple_density(fill=False) -... .apply_styling()) +... .apply_styling() +... ) +>>> plot.show() >>> isinstance(plot, ISOPlot) True @@ -43,7 +49,6 @@ from __future__ import annotations -import functools from typing import TYPE_CHECKING, Any, Literal import matplotlib.pyplot as plt @@ -310,9 +315,12 @@ def create_subplots( ncols: int = 1, figsize: tuple[float, float] | None = None, subplot_by: str | None = None, + subplot_titles: list[str] | None = None, *, sharex: bool | Literal["none", "all", "row", "col"] = True, sharey: bool | Literal["none", "all", "row", "col"] = True, + auto_allocate_axes: bool = False, + adjust_figsize: bool = True, **kwargs: Any, ) -> ISOPlot: """ @@ -328,11 +336,18 @@ def create_subplots( Figure size, by default None subplot_by : str | None, optional Column to create subplots by, by default None + subplot_titles : list[str] | None, optional + Custom titles for subplots, by default None * sharex : bool | Literal["none", "all", "row", "col"], optional Whether to share x-axis, by default True sharey : bool | Literal["none", "all", "row", "col"], optional Whether to share y-axis, by default True + auto_allocate_axes : bool, optional + Whether to automatically determine nrows/ncols based on data, + by default False + adjust_figsize : bool, optional + Whether to adjust the figure size based on nrows/ncols, by default True **kwargs : Any Additional parameters for subplots @@ -361,7 +376,7 @@ def create_subplots( >>> data['Group'] = rng.integers(1, 3, 100) >>> plot = ISOPlot(data=data) - >>> plot = plot.create_subplots(subplot_by='Group') + >>> plot = plot.create_subplots(subplot_by='Group', auto_allocate_axes=True) >>> len(plot.subplot_contexts) 2 >>> plot.subplot_contexts[0].title is not None @@ -375,6 +390,9 @@ def create_subplots( sharex=sharex, sharey=sharey, subplot_by=subplot_by, + subplot_titles=subplot_titles, + auto_allocate_axes=auto_allocate_axes, + adjust_figsize=adjust_figsize, **kwargs, ) @@ -383,6 +401,7 @@ def add_scatter( data: pd.DataFrame | None = None, *, on_axis: int | tuple[int, int] | list[int] | None = None, + subplot_by: str | None = None, **params: Any, ) -> ISOPlot: """ @@ -394,6 +413,10 @@ def add_scatter( Custom data for this layer, by default None on_axis : int | tuple[int, int] | list[int] | None, optional Target specific axis/axes, by default None + subplot_by : str | None, optional + Column to split data across existing subplots, by default None. + If provided, the data will be split based on unique values in this column + and rendered on the corresponding subplots. **params : Any Additional parameters for the scatter layer @@ -413,32 +436,48 @@ def add_scatter( ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), ... columns=['ISOPleasant', 'ISOEventful'] ... ) - >>> plot = ISOPlot(data=data).create_subplots() - >>> plot = plot.add_scatter() - >>> len(plot.subplot_contexts[0].layers) + >>> plot1 = ( + ... ISOPlot(data=data) + ... .create_subplots() + ... .add_scatter() + ... .apply_styling() + ... ) + >>> plot1.show() + >>> len(plot1.subplot_contexts[0].layers) 1 Add a scatter layer with custom parameters: - >>> plot = ISOPlot(data=data).create_subplots() - >>> plot = plot.add_scatter(s=50, alpha=0.5, color='red') - >>> len(plot.subplot_contexts[0].layers) + >>> plot2 = ( + ... ISOPlot(data=data) + ... .create_subplots() + ... .add_scatter(s=50, alpha=0.5, color='red') + ... .apply_styling() + ... ) + >>> plot2.show() + >>> len(plot2.subplot_contexts[0].layers) 1 Add a scatter layer to a specific subplot: - >>> plot = ISOPlot(data=data) - >>> plot = plot.create_subplots(nrows=2,ncols=2) - >>> plot = plot.add_scatter(on_axis=0) - >>> len(plot.subplot_contexts[0].layers) + >>> plot3 = ( + ... ISOPlot(data=data) + ... .create_subplots(nrows=2,ncols=2) + ... .add_scatter(on_axis=0) + ... .apply_styling() + ... ) + >>> plot3.show() + >>> len(plot3.subplot_contexts[0].layers) 1 - >>> len(plot.subplot_contexts[1].layers) + >>> len(plot3.subplot_contexts[1].layers) 0 """ - return self.layer_mgr.add_scatter( + return self.layer_mgr.add_layer( + "scatter", data=data, on_axis=on_axis, + subplot_by=subplot_by, **params, ) @@ -447,6 +486,7 @@ def add_density( data: pd.DataFrame | None = None, *, on_axis: int | tuple[int, int] | list[int] | None = None, + subplot_by: str | None = None, **params: Any, ) -> ISOPlot: """ @@ -458,6 +498,10 @@ def add_density( Custom data for this layer, by default None on_axis : int | tuple[int, int] | list[int] | None, optional Target specific axis/axes, by default None + subplot_by : str | None, optional + Column to split data across existing subplots, by default None. + If provided, the data will be split based on unique values in this column + and rendered on the corresponding subplots. **params : Any Additional parameters for the density layer @@ -477,32 +521,46 @@ def add_density( ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), ... columns=['ISOPleasant', 'ISOEventful'] ... ) - >>> plot = ISOPlot(data=data).create_subplots() - >>> plot = plot.add_density() + >>> plot = ( + ... ISOPlot(data=data).create_subplots() + ... .add_density() + ... .apply_styling() + ... ) + >>> plot.show() >>> len(plot.subplot_contexts[0].layers) 1 Add a density layer with custom parameters: - >>> plot = ISOPlot(data=data).create_subplots() - >>> plot = plot.add_density(fill=False, levels=5, alpha=0.7) + >>> plot = ( + ... ISOPlot(data=data).create_subplots() + ... .add_density(fill=False, levels=5, alpha=0.7) + ... .apply_styling() + ... ) + >>> plot.show() >>> len(plot.subplot_contexts[0].layers) 1 Add a density layer to a specific subplot: - >>> plot = ISOPlot(data=data) - >>> plot = plot.create_subplots(nrows=2,ncols=2) - >>> plot = plot.add_density(on_axis=1) + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots(nrows=2,ncols=2) + ... .add_density(on_axis=1) + ... .apply_styling() + ... ) + >>> plot.show() >>> len(plot.subplot_contexts[0].layers) 0 >>> len(plot.subplot_contexts[1].layers) 1 """ - return self.layer_mgr.add_density( + return self.layer_mgr.add_layer( + "density", data=data, on_axis=on_axis, + subplot_by=subplot_by, **params, ) @@ -541,23 +599,35 @@ def add_simple_density( ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), ... columns=['ISOPleasant', 'ISOEventful'] ... ) - >>> plot = ISOPlot(data=data).create_subplots() - >>> plot = plot.add_simple_density() + >>> plot = ( + ... ISOPlot(data=data).create_subplots() + ... .add_simple_density() + ... .apply_styling() + ... ) + >>> plot.show() >>> len(plot.subplot_contexts[0].layers) 1 Add a simple density layer with custom parameters: - >>> plot = ISOPlot(data=data).create_subplots() - >>> plot = plot.add_simple_density(fill=False, thresh=0.3) + >>> plot = ( + ... ISOPlot(data=data).create_subplots() + ... .add_simple_density(fill=False, thresh=0.3) + ... .apply_styling() + ... ) + >>> plot.show() >>> len(plot.subplot_contexts[0].layers) 1 Add a simple density layer to multiple subplots: - >>> plot = ISOPlot(data=data) - >>> plot = plot.create_subplots(nrows=2,ncols=2) - >>> plot = plot.add_simple_density(on_axis=[0, 2]) + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots(nrows=2,ncols=2) + ... .add_simple_density(on_axis=[0, 2]) + ... .apply_styling() + ... ) + >>> plot.show() >>> len(plot.subplot_contexts[0].layers) 1 >>> len(plot.subplot_contexts[1].layers) @@ -596,8 +666,55 @@ def add_spi_simple( ISOPlot The current plot instance for chaining + Examples + -------- + Add a SPI layer to all subplots: + + >>> import pandas as pd + >>> import numpy as np + >>> from soundscapy.spi import DirectParams + >>> rng = np.random.default_rng(42) + >>> # Create a DataFrame with random data + >>> data = pd.DataFrame( + ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), + ... columns=['ISOPleasant', 'ISOEventful'] + ... ) + >>> # Define MSN parameters for the SPI target + >>> msn_params = DirectParams( + ... xi=np.array([0.5, 0.7]), + ... omega=np.array([[0.1, 0.05], [0.05, 0.1]]), + ... alpha=np.array([0, -5]), + ... ) + >>> # Create the plot with only an SPI layer + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots() + ... .add_scatter() + ... .add_spi_simple(msn_params=msn_params) + ... .apply_styling() + ... ) + >>> plot.show() + >>> len(plot.subplot_contexts[0].layers) == 2 + True + >>> plot.close() # Clean up + + Add an SPI layer over top of 'real' data: + >>> plot = ( + ... ISOPlot(data=data) + ... .create_subplots() + ... .add_scatter() + ... .add_density() + ... .add_spi_simple(msn_params=msn_params, show_score="on axis") + ... .apply_styling() + ... ) + >>> plot.show() + >>> len(plot.subplot_contexts[0].layers) == 3 + True + >>> plot.close() # Clean up + """ - return self.layer_mgr.add_spi_simple( + return self.layer_mgr.add_layer( + "spi_simple", data=data, on_axis=on_axis, **params, @@ -609,71 +726,48 @@ def add_layer( data: pd.DataFrame | None = None, *, on_axis: AxisSpec | None = None, + subplot_by: str | None = None, **params: Any, ) -> ISOPlot: """ - Add a visualization layer, optionally targeting specific subplot(s). + Add a new layer to the plot based on the provided specifications and parameters. + + This function integrates a new visual layer into the current plot by specifying + what and how to display data. It uses the `layer_mgr.add_layer` method to handle + the layer addition. The layer can be associated with a specific axis, use a subset + of data, and be customized with additional parameters. Parameters ---------- - layer_class : Layer subclass - The type of layer to add - data : pd.DataFrame | None, optional - Custom data for this layer, by default None - on_axis : int | tuple[int, int] | list[int] | None, optional - Target specific axis/axes, by default None + layer_spec : LayerSpec + The specification of the layer to be added. Defines the functionality, + appearance, and behavior of the layer. + data : pd.DataFrame, optional + A Pandas DataFrame holding the data to be used by the layer. If not + provided, the layer might use existing data from the plot or handle + plotting differently based on its implementation. + on_axis : AxisSpec, optional + Specifies the axis where the layer should be drawn. If omitted, defaults + to the plot's primary or default axis. + subplot_by : str, optional + Key or column in the data to use for creating subplots. Used for splitting + data and visualizing it in separate subplots, if required. **params : Any - Additional parameters for the layer + Additional custom parameters that may be required by the layer. These + parameters are passed through to configure the layer further. Returns ------- ISOPlot - The current plot instance for chaining - - Examples - -------- - Add a scatter layer using the generic add_layer method: - - >>> import pandas as pd - >>> import numpy as np - >>> from soundscapy.plotting.new import ScatterLayer - >>> rng = np.random.default_rng(42) - >>> data = pd.DataFrame( - ... rng.multivariate_normal([0.2, 0.15], [[0.1, 0], [0, 0.2]], 100), - ... columns=['ISOPleasant', 'ISOEventful'] - ... ) - >>> plot = ISOPlot(data=data).create_subplots() - >>> plot = plot.add_layer(ScatterLayer) - >>> len(plot.subplot_contexts[0].layers) - 1 - >>> isinstance(plot.subplot_contexts[0].layers[0], ScatterLayer) - True - - Add a density layer to a specific subplot: - - >>> plot = ISOPlot(data=data) - >>> plot = plot.create_subplots(nrows=2,ncols=2) - >>> plot = plot.add_layer("density", on_axis=3, fill=False) - >>> len(plot.subplot_contexts[3].layers) - 1 - >>> from soundscapy.plotting.new import DensityLayer - >>> isinstance(plot.subplot_contexts[3].layers[0], DensityLayer) - True - - Add a layer with custom data: - - >>> custom_data = pd.DataFrame({ - ... 'ISOPleasant': rng.normal(0.5, 0.1, 50), - ... 'ISOEventful': rng.normal(0.5, 0.1, 50), - ... }) - >>> plot = ISOPlot(data=data).create_subplots() - >>> plot = plot.add_layer(ScatterLayer, data=custom_data, color='red') - >>> len(plot.subplot_contexts[0].layers) - 1 + The updated instance of the plot with the new layer added. """ return self.layer_mgr.add_layer( - layer_spec=layer_spec, data=data, on_axis=on_axis, **params + layer_spec=layer_spec, + data=data, + on_axis=on_axis, + subplot_by=subplot_by, + **params, ) def apply_styling( @@ -711,6 +805,7 @@ def apply_styling( >>> plot = ISOPlot(data=data).create_subplots() >>> plot = plot.add_scatter() >>> plot = plot.apply_styling() + >>> plot.show() >>> isinstance(plot, ISOPlot) True @@ -726,6 +821,7 @@ def apply_styling( ... primary_lines=True, ... diagonal_lines=True ... ) + >>> plot.show() >>> isinstance(plot, ISOPlot) True @@ -735,6 +831,7 @@ def apply_styling( >>> plot = plot.create_subplots(nrows=2,ncols=2) >>> plot = plot.add_scatter() >>> plot = plot.apply_styling(on_axis=0, title="Subplot 0") + >>> plot.show() >>> isinstance(plot, ISOPlot) True @@ -744,7 +841,6 @@ def apply_styling( **style_params, ) - @functools.wraps(plt.show) def show(self) -> None: """ Display the plot. @@ -763,12 +859,11 @@ def show(self) -> None: >>> plot = ISOPlot(data=data).create_subplots() >>> plot.add_scatter() >>> plot.apply_styling() - >>> # plot.show() # Uncomment to display the plot + >>> plot.show() """ plt.show() - @functools.wraps(plt.close) def close(self) -> None: """ Close the plot. @@ -787,7 +882,7 @@ def close(self) -> None: >>> plot = ISOPlot(data=data).create_subplots() >>> plot.add_scatter() >>> plot.apply_styling() - >>> # plot.show() # Uncomment to display the plot + >>> plot.show() >>> plot.close() # Close the plot """ diff --git a/src/soundscapy/plotting/new/layer.py b/src/soundscapy/plotting/new/layer.py index 939d9c5..f69fb72 100644 --- a/src/soundscapy/plotting/new/layer.py +++ b/src/soundscapy/plotting/new/layer.py @@ -16,6 +16,9 @@ import seaborn as sns from soundscapy.plotting.new.constants import RECOMMENDED_MIN_SAMPLES +from soundscapy.plotting.new.parameter_models import ( + _LayerParamsT, +) from soundscapy.sspylogging import get_logger if TYPE_CHECKING: @@ -28,7 +31,7 @@ SimpleDensityParams, SPISimpleDensityParams, ) - from soundscapy.plotting.new.protocols import PlotContext + from soundscapy.plotting.new.plot_context import PlotContext from soundscapy.spi.msn import ( CentredParams, DirectParams, @@ -121,7 +124,7 @@ def render(self, context: PlotContext) -> None: # Render the layer self._render_implementation(data, context, context.ax, params) - def _get_params_from_context(self, context: PlotContext) -> BaseParams: + def _get_params_from_context(self, context: PlotContext) -> _LayerParamsT: """ Get parameters from context and apply overrides. @@ -138,19 +141,25 @@ def _get_params_from_context(self, context: PlotContext) -> BaseParams: """ # Get parameters from context based on layer type params = context.get_params_for_layer(type(self)) + if not isinstance(params, _LayerParamsT): + msg = ( + f"Invalid parameters for layer type {type(self).__name__}: " + f"expected {type(_LayerParamsT)}, got {type(params)}" + ) + raise TypeError(msg) # Apply overrides if self.param_overrides: params.update(**self.param_overrides) - return cast("BaseParams", params) + return params def _render_implementation( self, data: pd.DataFrame, context: PlotContext, ax: Axes, - params: BaseParams, + params: _LayerParamsT, ) -> None: """ Implement actual rendering (to be overridden by subclasses). @@ -228,7 +237,7 @@ def _render_implementation( data: pd.DataFrame, context: PlotContext, ax: Axes, - params: BaseParams, + params: ScatterParams, # type: ignore[override] ) -> None: """ Render a scatter plot. @@ -245,11 +254,8 @@ def _render_implementation( The parameters for this layer """ - # Cast params to the correct type - scatter_params = cast("ScatterParams", params) - # Create a copy of the parameters with data - kwargs = scatter_params.as_seaborn_kwargs() + kwargs = params.as_seaborn_kwargs() kwargs["data"] = data # Ensure x and y are set correctly @@ -270,7 +276,7 @@ def _render_implementation( data: pd.DataFrame, context: PlotContext, ax: Axes, - params: BaseParams, + params: DensityParams, # type: ignore[override] ) -> None: """ Render a density plot. @@ -296,11 +302,8 @@ def _render_implementation( stacklevel=2, ) - # Cast params to the correct type - density_params = cast("DensityParams", params) - # Create a copy of the parameters with data - kwargs = density_params.as_seaborn_kwargs() + kwargs = params.as_seaborn_kwargs() kwargs["data"] = data # Ensure x and y are set correctly @@ -321,7 +324,7 @@ def _render_implementation( data: pd.DataFrame, context: PlotContext, ax: Axes, - params: BaseParams, + params: SimpleDensityParams, # type: ignore[override] ) -> None: """ Render a simple density plot. @@ -565,7 +568,7 @@ def _add_score_as_text(ax: Axes, spi_score: int, **text_kwargs: Any) -> None: """ from soundscapy.plotting.new.constants import DEFAULT_SPI_TEXT_KWARGS - text_kwargs_copy = DEFAULT_SPI_TEXT_KWARGS.copy() + text_kwargs_copy: dict[str, Any] = DEFAULT_SPI_TEXT_KWARGS.copy() text_kwargs_copy.update(**text_kwargs) text_kwargs_copy["s"] = f"SPI: {spi_score}" @@ -807,7 +810,7 @@ def _render_implementation( data: pd.DataFrame, context: PlotContext, ax: Axes, - params: BaseParams, + params: SPISimpleDensityParams, # type: ignore[override] ) -> None: """ Render an SPI simple density plot. @@ -824,8 +827,7 @@ def _render_implementation( The parameters for this layer """ - # Cast params to the correct type - spi_params = cast("SPISimpleDensityParams", params) + spi_params = params # Create a copy of the parameters with data kwargs = spi_params.as_seaborn_kwargs() @@ -858,3 +860,6 @@ def _render_implementation( ax=ax, axis_text_kwargs=spi_params.axis_text_kw or {}, ) + + +_LayerT = Layer | ScatterLayer | DensityLayer | SPILayer | SPISimpleLayer diff --git a/src/soundscapy/plotting/new/managers.py b/src/soundscapy/plotting/new/managers.py index 212b0fb..40e075d 100644 --- a/src/soundscapy/plotting/new/managers.py +++ b/src/soundscapy/plotting/new/managers.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar, overload +from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast import matplotlib.pyplot as plt import numpy as np @@ -27,7 +27,7 @@ try: # Python 3.13 made a @deprecated decorator available - from warnings import deprecated + from warnings import deprecated # type: ignore[attr-defined] except ImportError: # Fall back to using specific module @@ -35,8 +35,7 @@ if TYPE_CHECKING: from soundscapy.plotting.new import ISOPlot - - _ISOPlotT = TypeVar("_ISOPlotT", bound="ISOPlot") + from soundscapy.plotting.new.parameter_models import StyleParams logger = get_logger() @@ -44,8 +43,7 @@ # Type definitions AxisSpec = int | tuple[int, int] | list[int] LayerType = Literal["scatter", "density", "simple_density", "spi_simple"] -LayerClass = type[Layer] -LayerSpec = LayerType | LayerClass | Layer +LayerSpec = LayerType | type[Layer] | Layer class LayerManager: @@ -78,33 +76,39 @@ def __init__(self, plot: Any) -> None: """ self.plot = plot - # Pass a fully instantiated layer - @overload - def add_layer( - self, layer: Layer, *, on_axis: AxisSpec | None = None - ) -> _ISOPlotT: ... - - # Pass an uninstantiated layer class - @overload - def add_layer( - self, - layer_class: LayerClass, - data: pd.DataFrame | None = None, - *, - on_axis: AxisSpec | None = None, - **params: Any, - ) -> _ISOPlotT: ... - - # Pass a string of the layer type name - @overload - def add_layer( - self, - layer_type: LayerType, - data: pd.DataFrame | None = None, - *, - on_axis: AxisSpec | None = None, - **params: Any, - ) -> _ISOPlotT: ... + # # Pass a fully instantiated layer + # @overload + # def add_layer( + # self, + # layer_spec: Layer, + # *, + # on_axis: AxisSpec | None = None, + # subplot_by: str | None = None, + # ) -> ISOPlot: ... + # + # # Pass an uninstantiated layer class + # @overload + # def add_layer( + # self, + # layer_spec: LayerClass, + # data: pd.DataFrame | None = None, + # *, + # on_axis: AxisSpec | None = None, + # subplot_by: str | None = None, + # **params: Any, + # ) -> ISOPlot: ... + # + # # Pass a string of the layer type name + # @overload + # def add_layer( + # self, + # layer_spec: LayerType, + # data: pd.DataFrame | None = None, + # *, + # on_axis: AxisSpec | None = None, + # subplot_by: str | None = None, + # **params: Any, + # ) -> ISOPlot: ... def add_layer( self, @@ -112,8 +116,9 @@ def add_layer( data: pd.DataFrame | None = None, *, on_axis: AxisSpec | None = None, + subplot_by: str | None = None, **params: Any, - ) -> _ISOPlotT: + ) -> ISOPlot: """ Add a visualization layer, optionally targeting specific subplot(s). @@ -129,6 +134,12 @@ def add_layer( Ignored if layer_spec is a Layer instance. on_axis : AxisSpec, optional Target specific axis/axes, by default None + subplot_by : str | None, optional + Column to split data across existing subplots, by default None. + If provided, the data will be split based on unique values in this column + and rendered on the corresponding subplots. + Note: This is different from the subplot_by parameter in create_subplots, + which creates new subplots based on unique values. **params : Any Additional parameters for the layer. Special parameters: - msn_params: Used only for "spi_simple" layer type @@ -164,11 +175,19 @@ def add_layer( # Handle the case when an instantiated Layer is provided if isinstance(layer_spec, Layer): # Use the provided layer directly - return self._render_layer(layer_spec, on_axis=on_axis, **params) + return self._render_layer( + layer_spec, on_axis=on_axis, subplot_by=subplot_by, **params + ) # Get the layer class from either the class or layer_class = self._resolve_layer_class(layer_spec) + # If subplot_by is provided, we need to handle it specially + if subplot_by is not None: + return self._render_layer_with_subplot_by( + layer_class, data, subplot_by, on_axis, **params + ) + # Create the layer instance is_spi = "spi" in layer_class.__name__.lower() if is_spi: @@ -227,9 +246,106 @@ def _resolve_layer_class(self, layer_spec: str | type[Layer]) -> type[Layer]: ) raise TypeError(msg) + def _render_layer_with_subplot_by( + self, + layer_class: type[Layer], + data: pd.DataFrame | None, + subplot_by: str, + on_axis: AxisSpec | None = None, + **params: Any, + ) -> ISOPlot: + """ + Render a layer with data split across subplots based on a grouping variable. + + Parameters + ---------- + layer_class : type[Layer] + The layer class to instantiate + data : pd.DataFrame | None + The data to split and render + subplot_by : str + Column to split data by + on_axis : AxisSpec | None, optional + Target specific axis/axes, by default None + **params : Any + Additional parameters for the layer + + Returns + ------- + ISOPlot + The parent plot instance for chaining + + Raises + ------ + ValueError + If the subplot_by column doesn't exist in the data + RuntimeError + If no subplots have been created yet + + """ + # If no subplots created yet, raise an error + if self.plot.figure is None or self.plot.axes is None: + msg = "Cannot add layer to main context before creating subplots." + raise RuntimeError(msg) + + # If no data provided, use the main context data + if data is None: + data = self.plot.main_context.data + + # If still no data, raise an error + if data is None: + msg = "No data provided for layer and no data in main context." + raise ValueError(msg) + + # Validate that the subplot_by column exists in the data + if subplot_by not in data.columns: + msg = ( + f"Invalid subplot_by column '{subplot_by}'. " + f"Available columns are: {data.columns.tolist()}" + ) + raise ValueError(msg) + + # Get the unique values in the subplot_by column + unique_values = data[subplot_by].unique() + + # Get the target contexts + target_contexts = PlotContext.get_contexts_by_spec(self.plot, on_axis) + + # Check if we have enough subplots + if len(unique_values) > len(target_contexts): + msg = ( + f"Not enough subplots for all unique values in '{subplot_by}'. " + f"Got {len(unique_values)} unique values and {len(target_contexts)} subplots." + ) + raise ValueError(msg) + + # For each unique value, create a layer with the filtered data and render it + for i, value in enumerate(unique_values): + # Filter the data for this value + filtered_data = data[data[subplot_by] == value] + + # Create the layer instance + is_spi = "spi" in layer_class.__name__.lower() + if is_spi: + layer = layer_class(spi_target_data=filtered_data, **params) + else: + layer = layer_class(custom_data=filtered_data, **params) + + # Add the layer to the context and render it + context = target_contexts[i] + context.layers.append(layer) + layer.render(context) + + return self.plot + def _render_layer( - self, layer: Layer, *, on_axis: AxisSpec | None = None - ) -> _ISOPlotT: + self, + layer: Layer, + *, + on_axis: AxisSpec | None = None, + subplot_by: str | None = None, + **params: Any, + ) -> ISOPlot: """ Render a layer on the appropriate axes. @@ -239,6 +355,10 @@ def _render_layer( The layer to render on_axis : AxisSpec | None, optional Target specific axis/axes, by default None + subplot_by : str | None, optional + Column to split data by, by default None + **params : Any + Additional parameters for the layer Returns ------- @@ -255,6 +375,15 @@ def _render_layer( msg = "Cannot add layer to main context before creating subplots." raise RuntimeError(msg) + # If subplot_by is provided, we need to handle it specially + if subplot_by is not None: + # We can't handle subplot_by for an already instantiated layer + msg = ( + "Cannot use subplot_by with an already instantiated layer. " + "Please provide the layer class and data instead." + ) + raise ValueError(msg) + # Handle various axis targeting options # TODO: If feels like this should be done via # self.plot.get_contexts_by_spec(on_axis), rather than a classmethod @@ -268,19 +397,50 @@ def _render_layer( return self.plot @deprecated() - def add_scatter(self, data=None, *, on_axis=None, **params) -> _ISOPlotT: # noqa: ANN001 + def add_scatter( + self, + data=None, # noqa: ANN001 + *, + on_axis=None, # noqa: ANN001 + subplot_by=None, # noqa: ANN001 + **params, + ) -> ISOPlot: """Legacy method that forwards to add_layer(layer_type="scatter", ...).""" - return self.add_layer("scatter", data=data, on_axis=on_axis, **params) + return self.add_layer( + "scatter", data=data, on_axis=on_axis, subplot_by=subplot_by, **params + ) @deprecated() - def add_density(self, data=None, *, on_axis=None, **params) -> _ISOPlotT: # noqa: ANN001 + def add_density( + self, + data=None, # noqa: ANN001 + *, + on_axis=None, # noqa: ANN001 + subplot_by=None, # noqa: ANN001 + **params, + ) -> ISOPlot: """Legacy method that forwards to add_layer(layer_type="density", ...).""" - return self.add_layer("density", data=data, on_axis=on_axis, **params) + return self.add_layer( + "density", data=data, on_axis=on_axis, subplot_by=subplot_by, **params + ) @deprecated() - def add_simple_density(self, data=None, *, on_axis=None, **params) -> _ISOPlotT: # noqa: ANN001 + def add_simple_density( + self, + data=None, # noqa: ANN001 + *, + on_axis=None, # noqa: ANN001 + subplot_by=None, # noqa: ANN001 + **params, + ) -> ISOPlot: """Legacy method that forwards to add_layer(layer_type="simple_density",...).""" - return self.add_layer("simple_density", data=data, on_axis=on_axis, **params) + return self.add_layer( + "simple_density", + data=data, + on_axis=on_axis, + subplot_by=subplot_by, + **params, + ) @deprecated() def add_spi_simple( @@ -289,11 +449,17 @@ def add_spi_simple( *, msn_params=None, # noqa: ANN001 on_axis=None, # noqa: ANN001 + subplot_by=None, # noqa: ANN001 **params, - ) -> _ISOPlotT: + ) -> ISOPlot: """Legacy method that forwards to add_layer(layer_type="spi_simple", ...).""" return self.add_layer( - "spi_simple", data=data, on_axis=on_axis, msn_params=msn_params, **params + "spi_simple", + data=data, + on_axis=on_axis, + subplot_by=subplot_by, + msn_params=msn_params, + **params, ) @@ -374,7 +540,7 @@ def _apply_styling_to_context(self, context: PlotContext) -> None: return # Get style parameters - style_params = context.get_params("style") + style_params = cast("StyleParams", context.get_params("style")) # Apply style_mgr to the axes ax = context.ax @@ -575,6 +741,10 @@ def create_subplots( sharex: bool | Literal["none", "all", "row", "col"] = True, sharey: bool | Literal["none", "all", "row", "col"] = True, subplot_by: str | None = None, + subplot_titles: list[str] | None = None, + *, + auto_allocate_axes: bool = False, + adjust_figsize: bool = True, **kwargs: Any, ) -> Any: """ @@ -594,6 +764,13 @@ def create_subplots( Whether to share y-axis, by default True subplot_by : str | None, optional Column to create subplots by, by default None + subplot_titles : list[str] | None, optional + Custom titles for subplots, by default None + auto_allocate_axes : bool, optional + Whether to automatically determine nrows/ncols based on data, + by default False + adjust_figsize : bool, optional + Whether to adjust the figure size based on nrows/ncols, by default True **kwargs : Any Additional parameters for subplots @@ -611,6 +788,9 @@ def create_subplots( sharex=sharex, sharey=sharey, subplot_by=subplot_by, + subplot_titles=subplot_titles, + auto_allocate_axes=auto_allocate_axes, + adjust_figsize=adjust_figsize, **kwargs, ) @@ -665,6 +845,33 @@ def _create_subplot_contexts(self) -> None: # Add to subplot contexts self.plot.subplot_contexts.append(context) + def _allocate_subplot_axes(self, n_groups: int) -> tuple[int, int]: + """ + Allocate the subplot axes based on the number of data subsets. + + Parameters + ---------- + n_groups : int + Number of groups to allocate subplots for + + Returns + ------- + tuple[int, int] + Tuple of (nrows, ncols) + + """ + import warnings + + msg = ( + "This is an experimental feature. " + "The number of rows and columns may not be optimal." + ) + warnings.warn(msg, UserWarning, stacklevel=2) + + ncols = int(np.ceil(np.sqrt(n_groups))) + nrows = int(np.ceil(n_groups / ncols)) + return nrows, ncols + def _create_subplots_by_group(self) -> None: """ Create subplots by grouping the custom_data. @@ -679,18 +886,74 @@ def _create_subplots_by_group(self) -> None: if subplot_by is None or self.plot.main_context.data is None: return - # Get unique values in the subplot_by column + # Validate that the subplot_by column exists in the data data = self.plot.main_context.data + if subplot_by not in data.columns: + msg = ( + f"Invalid subplot_by column '{subplot_by}'. " + f"Available columns are: {data.columns.tolist()}" + ) + raise ValueError(msg) + + # Get unique values in the subplot_by column groups = data[subplot_by].unique() + n_subplots_by = len(groups) + + # Warn if there are few unique values + if n_subplots_by < 2: # noqa: PLR2004 + import warnings + + warnings.warn( + f"Only {n_subplots_by} unique values found in '{subplot_by}'. " + "Subplots may not be meaningful.", + UserWarning, + stacklevel=2, + ) # Limit to the number of subplots if specified if params.n_subplots_by > 0: groups = groups[: params.n_subplots_by] + n_subplots_by = len(groups) + + # Auto-allocate axes if requested + if params.auto_allocate_axes: + nrows, ncols = self._allocate_subplot_axes(n_subplots_by) + # Update the subplot parameters + self.plot.subplots_params.update(nrows=nrows, ncols=ncols) + # Recreate the figure and axes with the new dimensions + self.plot.figure, self.plot.axes = plt.subplots( + **self.plot.subplots_params.as_plt_subplots_args() + ) # Check if we have enough subplots - if len(groups) > params.n_subplots: - msg = f"Not enough subplots for all groups: {len(groups)} groups, {params.n_subplots} subplots" + if n_subplots_by > params.n_subplots: + msg = f"Not enough subplots for all groups: {n_subplots_by} groups, {params.n_subplots} subplots" + raise ValueError(msg) + + # Handle subplot titles + subplot_titles = params.subplot_titles + if subplot_titles is None: + # Create subplot titles based on the unique values + subplot_titles = [str(value) for value in groups] + elif len(subplot_titles) != n_subplots_by: + # Validate that the number of titles matches the number of unique values + msg = ( + "Number of subplot titles must match the number of unique values " + f"for '{subplot_by}'. Got {len(subplot_titles)} titles and " + f"{n_subplots_by} unique values." + ) raise ValueError(msg) + else: + # Warn if custom titles are provided with subplot_by + import warnings + + warnings.warn( + "Not recommended to provide separate subplot titles when using " + "subplot_by. Consider using the default titles based on unique values. " + "Manual subplot_titles may not be in the same order as the data.", + UserWarning, + stacklevel=2, + ) # Get axes array axes = self.plot.axes @@ -717,11 +980,11 @@ def _create_subplots_by_group(self) -> None: # Filter custom_data for this group group_data = data[data[subplot_by] == group] - # Create a title for this subplot + # Create a title for this subplot using the custom title or group value title = ( - f"{group}" + subplot_titles[i] if self.plot.main_context.title is None - else f"{self.plot.main_context.title}: {group}" + else f"{self.plot.main_context.title}: {subplot_titles[i]}" ) # Create a child context for this subplot diff --git a/src/soundscapy/plotting/new/parameter_models.py b/src/soundscapy/plotting/new/parameter_models.py index 03f9025..44b8552 100644 --- a/src/soundscapy/plotting/new/parameter_models.py +++ b/src/soundscapy/plotting/new/parameter_models.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Any, Literal, Self, TypeAlias +from typing import Any, Literal, Self, TypeAlias, TypeGuard import numpy as np import pandas as pd @@ -310,9 +310,7 @@ class JointPlotParams(BaseParams): class StyleParams(BaseParams): - """ - Configuration options for style_mgr circumplex plots. - """ + """Configuration options for style_mgr circumplex plots.""" xlim: tuple[float, float] = DEFAULT_XLIM ylim: tuple[float, float] = DEFAULT_YLIM @@ -342,6 +340,7 @@ class SubplotsParams(BaseParams): sharey: bool | Literal["none", "all", "row", "col"] = True subplot_by: str | None = None n_subplots_by: int = -1 + subplot_titles: list[str] | None = None auto_allocate_axes: bool = False adjust_figsize: bool = True @@ -372,7 +371,52 @@ def as_plt_subplots_args(self) -> dict[str, Any]: exclude={ "subplot_by", "n_subplots_by", + "subplot_titles", "auto_allocate_axes", "adjust_figsize", } ) + + +_ParamModels = ( + BaseParams + | SeabornParams + | ScatterParams + | DensityParams + | SimpleDensityParams + | SPISimpleDensityParams +) + +_LayerParamsT = ( + SeabornParams + | ScatterParams + | DensityParams + | SimpleDensityParams + | SPISimpleDensityParams +) +"""Type alias for layer parameter models.""" + + +def is_layer_params(param_mod: BaseParams) -> TypeGuard[_LayerParamsT]: + """ + Determine whether the given parameters are of a specific layer parameters type. + + This function checks if the provided `params` object is an instance of the + specified `_LayerParamsT` type. It is used to provide type narrowing and helps + ensure that the `params` object adheres to the desired type constraints. It + returns a type guard, which allows type checkers to infer the type of `params` + as `_LayerParamsT` within a guarded code block. + + Parameters + ---------- + params : BaseParams + The parameters to be evaluated for the specific layer type. + + Returns + ------- + TypeGuard[_LayerParamsT] + A boolean-like value indicating whether the `params` object is of the type + `_LayerParamsT`. + + """ + return isinstance(param_mod, _LayerParamsT) diff --git a/src/soundscapy/plotting/new/plot_context.py b/src/soundscapy/plotting/new/plot_context.py index 5c540c1..788382f 100644 --- a/src/soundscapy/plotting/new/plot_context.py +++ b/src/soundscapy/plotting/new/plot_context.py @@ -8,18 +8,18 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, TypeVar, cast +from typing import TYPE_CHECKING, Any import matplotlib.pyplot as plt from soundscapy.plotting.new.constants import DEFAULT_XCOL, DEFAULT_YCOL from soundscapy.plotting.new.parameter_models import ( - BaseParams, DensityParams, ScatterParams, SimpleDensityParams, SPISimpleDensityParams, StyleParams, + _ParamModels, ) from soundscapy.sspylogging import get_logger @@ -28,9 +28,9 @@ from matplotlib.axes import Axes from soundscapy.plotting.new.iso_plot import ISOPlot - from soundscapy.plotting.new.protocols import ParamModel, RenderableLayer + from soundscapy.plotting.new.layer import Layer + from soundscapy.plotting.new.protocols import RenderableLayer - _ISOPlotT = TypeVar("_ISOPlotT", bound="ISOPlot") logger = get_logger() @@ -103,7 +103,7 @@ def __init__( self.parent: PlotContext | None = None # Parameter models for different layer types - self._param_models: dict[str, BaseParams] = {} + self._param_models: dict[str, _ParamModels] = {} # Initialize default parameter models self._init_param_models() @@ -131,7 +131,7 @@ def _init_param_models(self) -> None: ylabel=r"$E_{ISO}$", ) - def get_params(self, param_type: str) -> BaseParams: + def get_params(self, param_type: str) -> _ParamModels: """ Get parameters for a specific type. @@ -157,7 +157,7 @@ def get_params(self, param_type: str) -> BaseParams: return self._param_models[param_type] - def get_params_for_layer(self, layer_type: type[RenderableLayer]) -> ParamModel: + def get_params_for_layer(self, layer_type: type[Layer]) -> _ParamModels: """ Get parameters appropriate for a specific layer type. @@ -179,22 +179,22 @@ def get_params_for_layer(self, layer_type: type[RenderableLayer]) -> ParamModel: layer_name = layer_type.__name__.lower() if "scatter" in layer_name: - return cast("ParamModel", self.get_params("scatter")) - if "simpledensity" in layer_name: + return self.get_params("scatter") + if "simple" in layer_name: if "spi" in layer_name: - return cast("ParamModel", self.get_params("spi_simple_density")) - return cast("ParamModel", self.get_params("simple_density")) + return self.get_params("spi_simple_density") + return self.get_params("simple_density") if "density" in layer_name: - return cast("ParamModel", self.get_params("density")) + return self.get_params("density") # Default to scatter parameters if no match logger.warning( f"No specific parameters for layer type {layer_type.__name__}, " # noqa: G004 "using scatter parameters" ) - return cast("ParamModel", self.get_params("scatter")) + return self.get_params("scatter") - def update_params(self, param_type: str, **kwargs: Any) -> BaseParams: + def update_params(self, param_type: str, **kwargs: Any) -> _ParamModels: """ Update parameters for a specific type. @@ -257,7 +257,7 @@ def create_child( return child - def ensure_axes_exist(self, plot: _ISOPlotT) -> None: + def ensure_axes_exist(self, plot: ISOPlot) -> None: """ Check if we have axes to render on, create if needed. @@ -276,7 +276,7 @@ def ensure_axes_exist(self, plot: _ISOPlotT) -> None: plot.figure, plot.axes = plt.subplots(figsize=(5, 5)) def get_axes_by_spec( - self, plot: _ISOPlotT, spec: int | tuple[int, int] | list[int] | None + self, plot: ISOPlot, spec: int | tuple[int, int] | list[int] | None ) -> list[Axes]: """ Get axes based on specification. @@ -306,7 +306,7 @@ def get_axes_by_spec( @classmethod def get_contexts_by_spec( - cls, plot: _ISOPlotT, spec: int | tuple[int, int] | list[int] | None + cls, plot: ISOPlot, spec: int | tuple[int, int] | list[int] | None ) -> list[PlotContext]: """ Resolve which subplot contexts to target based on axis specification. @@ -348,7 +348,7 @@ def get_contexts_by_spec( @staticmethod def resolve_axis_indices( - plot: _ISOPlotT, spec: int | tuple[int, int] | list[int] + plot: ISOPlot, spec: int | tuple[int, int] | list[int] ) -> list[int]: """ Convert axis specification to list of indices. diff --git a/src/soundscapy/plotting/new/protocols.py b/src/soundscapy/plotting/new/protocols.py index 313d5d8..e651729 100644 --- a/src/soundscapy/plotting/new/protocols.py +++ b/src/soundscapy/plotting/new/protocols.py @@ -15,6 +15,9 @@ import pandas as pd from matplotlib.axes import Axes + from soundscapy.plotting.new.parameter_models import _LayerParamsT + + # Type variable for generic parameter models P = TypeVar("P", bound="ParamModel") @@ -100,7 +103,7 @@ class PlotContext(Protocol): title: str | None layers: list[RenderableLayer] - def get_params_for_layer(self, layer_type: type[RenderableLayer]) -> ParamModel: + def get_params_for_layer(self, layer_type: type[RenderableLayer]) -> _LayerParamsT: """ Get parameters appropriate for a specific layer type. diff --git a/uv.lock b/uv.lock index 08bb50a..daa9f4a 100644 --- a/uv.lock +++ b/uv.lock @@ -3660,6 +3660,7 @@ dev = [ { name = "setuptools-scm" }, { name = "tox" }, { name = "twine" }, + { name = "types-deprecated" }, { name = "types-pyyaml" }, { name = "types-seaborn" }, ] @@ -3718,6 +3719,7 @@ dev = [ { name = "setuptools-scm", specifier = ">=8.3.1" }, { name = "tox", specifier = ">=4.25.0" }, { name = "twine", specifier = ">=6.1.0" }, + { name = "types-deprecated", specifier = ">=1.2.15.20250304" }, { name = "types-pyyaml", specifier = ">=6.0.12.20250402" }, { name = "types-seaborn", specifier = ">=0.13.2.20250111" }, ] @@ -3932,6 +3934,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7c/b6/74e927715a285743351233f33ea3c684528a0d374d2e43ff9ce9585b73fe/twine-6.1.0-py3-none-any.whl", hash = "sha256:a47f973caf122930bf0fbbf17f80b83bc1602c9ce393c7845f289a3001dc5384", size = 40791, upload-time = "2025-01-21T18:45:24.584Z" }, ] +[[package]] +name = "types-deprecated" +version = "1.2.15.20250304" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0e/67/eeefaaabb03b288aad85483d410452c8bbcbf8b2bd876b0e467ebd97415b/types_deprecated-1.2.15.20250304.tar.gz", hash = "sha256:c329030553029de5cc6cb30f269c11f4e00e598c4241290179f63cda7d33f719", size = 8015, upload-time = "2025-03-04T02:48:17.894Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4d/e3/c18aa72ab84e0bc127a3a94e93be1a6ac2cb281371d3a45376ab7cfdd31c/types_deprecated-1.2.15.20250304-py3-none-any.whl", hash = "sha256:86a65aa550ea8acf49f27e226b8953288cd851de887970fbbdf2239c116c3107", size = 8553, upload-time = "2025-03-04T02:48:16.666Z" }, +] + [[package]] name = "types-python-dateutil" version = "2.9.0.20241206" From a4fe71f009dc66c2df7bec79f2894ec51d6d5d60 Mon Sep 17 00:00:00 2001 From: Andrew Mitchell Date: Sun, 11 May 2025 00:50:37 +0100 Subject: [PATCH 7/8] Refactors plotting parameters and styling management - Replaces type casting with direct parameter access using new get() method - Introduces ParamModels dataclass for better parameter organization - Simplifies style management with dedicated methods for grid, labels and legends - Reduces SPI layer sample size default from 10000 to 1000 - Improves code maintainability and readability through better encapsulation --- src/soundscapy/plotting/new/layer.py | 34 +-- src/soundscapy/plotting/new/managers.py | 218 +++++++++++------- .../plotting/new/parameter_models.py | 19 ++ src/soundscapy/plotting/new/plot_context.py | 75 ++++-- src/soundscapy/plotting/plot_functions.py | 2 +- 5 files changed, 225 insertions(+), 123 deletions(-) diff --git a/src/soundscapy/plotting/new/layer.py b/src/soundscapy/plotting/new/layer.py index f69fb72..1f495c6 100644 --- a/src/soundscapy/plotting/new/layer.py +++ b/src/soundscapy/plotting/new/layer.py @@ -9,7 +9,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal import numpy as np import pandas as pd @@ -351,7 +351,7 @@ def _render_implementation( ) # Cast params to the correct type - simple_density_params = cast("SimpleDensityParams", params) + simple_density_params = params # Create a copy of the parameters with data kwargs = simple_density_params.as_seaborn_kwargs() @@ -363,7 +363,7 @@ def _render_implementation( # Set specific parameters for simple density kwargs["levels"] = simple_density_params.levels - kwargs["thresh"] = getattr(simple_density_params, "thresh", 0.05) + kwargs["thresh"] = simple_density_params.get("thresh", 0.05) # Render the simple density plot sns.kdeplot(ax=ax, **kwargs) @@ -379,7 +379,7 @@ def __init__( spi_target_data: pd.DataFrame | np.ndarray | None = None, *, msn_params: DirectParams | CentredParams | None = None, - n: int = 10000, + n: int = 1000, custom_data: pd.DataFrame | None = None, **params: Any, ) -> None: @@ -421,6 +421,8 @@ def __init__( spi_target_data, msn_params ) + # TODO: Should move this into the render stage. + # Need to get SPIParams from param_model. # Generate the SPI target data self.spi_data: pd.DataFrame = self._generate_spi_data( spi_target_data, self.spi_params, n @@ -456,6 +458,9 @@ def render(self, context: PlotContext) -> None: msg = "No data available for rendering SPI layer" raise ValueError(msg) + # assign the generated target data to context + # context.spi_data = target_data + # Get parameters from context params = self._get_params_from_context(context) @@ -484,10 +489,15 @@ def _render_implementation( The parameters for this layer """ - target_data = data[[context.x, context.y]] + target_data = ( + data[[context.x, context.y]] if data is not None else self.spi_data + ) # Get test data from context - test_data = context.data + test_data = ( + context.data[[context.x, context.y]] if context.data is not None else None + ) + if test_data is None: warnings.warn( "Cannot find data to test SPI against. Skipping this plot.", @@ -502,14 +512,10 @@ def _render_implementation( # Show the score self.show_score( spi_score, - show_score=params.show_score - if hasattr(params, "show_score") - else "under title", + show_score=params.get("show_score", "under title"), context=context, ax=ax, - axis_text_kwargs=params.axis_text_kw - if hasattr(params, "axis_text_kw") - else {}, + axis_text_kwargs=params.get("axis_text_kw", {}), ) def show_score( @@ -734,8 +740,8 @@ def _process_spi_data( """ params = self._get_params_from_context(context) - xcol = getattr(params, "x", context.x) - ycol = getattr(params, "y", context.y) + xcol = params.get("x", context.x) + ycol = params.get("y", context.y) # DataFrame handling if isinstance(spi_data, pd.DataFrame): diff --git a/src/soundscapy/plotting/new/managers.py b/src/soundscapy/plotting/new/managers.py index 40e075d..38805e7 100644 --- a/src/soundscapy/plotting/new/managers.py +++ b/src/soundscapy/plotting/new/managers.py @@ -12,8 +12,9 @@ import matplotlib.pyplot as plt import numpy as np -import pandas as pd -from matplotlib.axes import Axes +import seaborn as sns +from matplotlib import ticker +from matplotlib.artist import Artist from soundscapy.plotting.new.layer import ( DensityLayer, @@ -22,6 +23,7 @@ SimpleDensityLayer, SPISimpleLayer, ) +from soundscapy.plotting.new.parameter_models import MplLegendLocType from soundscapy.plotting.new.plot_context import PlotContext from soundscapy.sspylogging import get_logger @@ -34,9 +36,10 @@ from deprecated import deprecated if TYPE_CHECKING: - from soundscapy.plotting.new import ISOPlot - from soundscapy.plotting.new.parameter_models import StyleParams + import pandas as pd + from matplotlib.axes import Axes + from soundscapy.plotting.new import ISOPlot, StyleParams logger = get_logger() @@ -76,40 +79,6 @@ def __init__(self, plot: Any) -> None: """ self.plot = plot - # # Pass a fully instantiated layer - # @overload - # def add_layer( - # self, - # layer_spec: Layer, - # *, - # on_axis: AxisSpec | None = None, - # subplot_by: str | None = None, - # ) -> ISOPlot: ... - # - # # Pass an uninstantiated layer class - # @overload - # def add_layer( - # self, - # layer_spec: LayerClass, - # data: pd.DataFrame | None = None, - # *, - # on_axis: AxisSpec | None = None, - # subplot_by: str | None = None, - # **params: Any, - # ) -> ISOPlot: ... - # - # # Pass a string of the layer type name - # @overload - # def add_layer( - # self, - # layer_spec: LayerType, - # data: pd.DataFrame | None = None, - # *, - # on_axis: AxisSpec | None = None, - # subplot_by: str | None = None, - # **params: Any, - # ) -> ISOPlot: ... - def add_layer( self, layer_spec: LayerSpec, @@ -465,10 +434,7 @@ def add_spi_simple( class StyleManager: """ - Manages the style_mgr of plots. - - This class encapsulates the style_mgr-related functionality that was previously - implemented as a mixin in the ISOPlot class. + Manages the styling of plots. Attributes ---------- @@ -488,6 +454,7 @@ def __init__(self, plot: Any) -> None: """ self.plot = plot + self.param_overrides = {} def apply_styling( self, @@ -512,28 +479,32 @@ def apply_styling( """ # Update style parameters - self.plot.main_context.update_params("style", **style_params) + self.param_overrides = style_params + self._set_sns_style() + # TODO: Need to set suptitle at some point # If no subplots, apply to main context if not self.plot.subplot_contexts: + self.plot.main_context.update_params("style", **style_params) self._apply_styling_to_context(self.plot.main_context) return self.plot # Apply to specified subplots target_contexts = PlotContext.get_contexts_by_spec(self.plot, on_axis) for context in target_contexts: + context.update_params("style", **style_params) self._apply_styling_to_context(context) return self.plot def _apply_styling_to_context(self, context: PlotContext) -> None: """ - Apply style_mgr to a specific context. + Apply styling to a specific context. Parameters ---------- context : PlotContext - The context to apply style_mgr to + The context to apply styling to """ if context.ax is None: @@ -541,43 +512,88 @@ def _apply_styling_to_context(self, context: PlotContext) -> None: # Get style parameters style_params = cast("StyleParams", context.get_params("style")) - # Apply style_mgr to the axes ax = context.ax - # Set limits - if hasattr(style_params, "xlim"): - ax.set_xlim(style_params.xlim) - if hasattr(style_params, "ylim"): - ax.set_ylim(style_params.ylim) + self._circumplex_grid(ax, style_params) - # Set labels - if hasattr(style_params, "xlabel") and style_params.xlabel is not False: - xlabel = style_params.xlabel or context.x - ax.set_xlabel(xlabel, fontdict=style_params.prim_ax_fontdict) + self._set_axis_title(ax, context, style_params) + # Alternative? + # ax.set_title( + # context.title, fontsize=style_params.title_fontsize) - if hasattr(style_params, "ylabel") and style_params.ylabel is not False: - ylabel = style_params.ylabel or context.y - ax.set_ylabel(ylabel, fontdict=style_params.prim_ax_fontdict) + self._primary_labels(ax, context, style_params) + # Alternative? + # ax.set_xlabel(xlabel, fontdict=style_params.prim_ax_fontdict) # noqa: ERA001 + # ax.set_ylabel(ylabel, fontdict=style_params.prim_ax_fontdict) # noqa: ERA001 - # Set title - if context.title: - ax.set_title(context.title, fontsize=style_params.title_fontsize) - - # Add primary axes lines - if hasattr(style_params, "primary_lines") and style_params.primary_lines: + if style_params.get("primary_lines"): self._add_primary_lines(ax, style_params) - - # Add diagonal lines and labels - if hasattr(style_params, "diagonal_lines") and style_params.diagonal_lines: + if style_params.get("diagonal_lines"): self._add_diagonal_lines(ax, style_params) self._add_diagonal_labels(ax, style_params) + if style_params.get("legend_loc") is not False: + # ax.legend(loc=style_params.get("legend_loc")) # noqa: ERA001 + self._move_legend(ax, style_params.get("legend_loc")) + + @staticmethod + def _set_sns_style() -> None: + """Set the overall style for the plot.""" + sns.set_style({"xtick.direction": "in", "ytick.direction": "in"}) + + @staticmethod + def _circumplex_grid(axis: Axes, style_params: StyleParams) -> None: + """Add the circumplex grid to the plot.""" + axis.set_xlim(style_params.get("xlim")) + axis.set_ylim(style_params.get("ylim")) + axis.set_aspect("equal") + + axis.get_yaxis().set_minor_locator(ticker.AutoMinorLocator()) + axis.get_xaxis().set_minor_locator(ticker.AutoMinorLocator()) + + axis.grid(visible=True, which="major", color="grey", alpha=0.5) + axis.grid( + visible=True, + which="minor", + color="grey", + linestyle="dashed", + linewidth=0.5, + alpha=0.4, + zorder=style_params.get("prim_lines_zorder"), + ) - # Add legend if needed - if hasattr(style_params, "legend_loc") and style_params.legend_loc: - ax.legend(loc=style_params.legend_loc) + @staticmethod + def _set_axis_title( + ax: Axes, context: PlotContext, style_params: StyleParams + ) -> None: + if ax and context.title: + ax.set_title(context.title, fontsize=style_params.title_fontsize) - def _add_primary_lines(self, ax: Axes, style_params: Any) -> None: + @staticmethod + def _primary_labels( + axis: Axes, context: PlotContext, style_params: StyleParams + ) -> None: + """Handle the default labels for the x and y axes.""" + xlabel = style_params.get("xlabel") + ylabel = style_params.get("ylabel") + + xlabel = context.x if xlabel is None else xlabel + ylabel = context.y if ylabel is None else ylabel + fontdict = style_params.get("prim_ax_fontdict") + + # BUG: For some reason, this ruins the sharex and sharey + # functionality, but only when a layer is applied + # a specific subplot. + axis.set_xlabel( + xlabel, fontdict=fontdict + ) if xlabel is not False else axis.xaxis.label.set_visible(False) + + axis.set_ylabel( + ylabel, fontdict=fontdict + ) if ylabel is not False else axis.yaxis.label.set_visible(False) + + @staticmethod + def _add_primary_lines(axis: Axes, style_params: StyleParams) -> None: """ Add primary axes lines to the plot. @@ -590,19 +606,21 @@ def _add_primary_lines(self, ax: Axes, style_params: Any) -> None: """ # Add horizontal and vertical lines at 0 - ax.axhline( + axis.axhline( y=0, - color="black", - linestyle="-", - linewidth=style_params.linewidth, - zorder=style_params.prim_lines_zorder, + color="grey", + linestyle="dashed", + alpha=1, + lw=style_params.get("linewidth"), + zorder=style_params.get("prim_lines_zorder"), ) - ax.axvline( + axis.axvline( x=0, - color="black", - linestyle="-", - linewidth=style_params.linewidth, - zorder=style_params.prim_lines_zorder, + color="grey", + linestyle="dashed", + alpha=1, + lw=style_params.get("linewidth"), + zorder=style_params.get("prim_lines_zorder"), ) def _add_diagonal_lines(self, ax: Axes, style_params: Any) -> None: @@ -706,6 +724,35 @@ def _add_diagonal_labels(self, ax: Axes, style_params: Any) -> None: zorder=style_params.diag_labels_zorder, ) + @staticmethod + def _move_legend(axis: Axes, legend_loc: MplLegendLocType) -> None: + """Move the legend to the specified location.""" + old_legend = axis.get_legend() + if old_legend is None: + # logger.debug("_move_legend: No legend found for axis %s", i) + return + + # Get handles and filter out None values + handles = [ + h for h in old_legend.legend_handles if isinstance(h, Artist | tuple) + ] + # Skip if no valid handles remain + if not handles: + return + + labels = [t.get_text() for t in old_legend.get_texts()] + title = old_legend.get_title().get_text() + # Ensure labels and handles match in length + if len(handles) != len(labels): + labels = labels[: len(handles)] + + axis.legend( + handles, + labels, + loc=legend_loc, + title=title, + ) + class SubplotManager: """ @@ -830,11 +877,10 @@ def _create_subplot_contexts(self) -> None: # Create a context for each axis for i, ax in enumerate(axes.flatten()): # Create a title for this subplot - title = ( - f"Subplot {i + 1}" - if self.plot.main_context.title is None - else f"{self.plot.main_context.title} {i + 1}" - ) + if params.subplot_titles is not None: + title = params.subplot_titles[i] + else: + title = None # Create a child context for this subplot context = self.plot.main_context.create_child( diff --git a/src/soundscapy/plotting/new/parameter_models.py b/src/soundscapy/plotting/new/parameter_models.py index 44b8552..84b1d94 100644 --- a/src/soundscapy/plotting/new/parameter_models.py +++ b/src/soundscapy/plotting/new/parameter_models.py @@ -143,6 +143,25 @@ def get_changed_params(self) -> dict[str, Any]: """ return self.model_dump(exclude_unset=True) + def get(self, key: str, default: Any = None) -> Any: + """ + Get a parameter value with a default fallback. + + Parameters + ---------- + key : str + Name of the parameter + default : Any, optional + Default value if parameter doesn't exist + + Returns + ------- + Any + Parameter value or default + + """ + return getattr(self, key, default) + class AxisParams(BaseParams): """Parameters for axis configuration.""" diff --git a/src/soundscapy/plotting/new/plot_context.py b/src/soundscapy/plotting/new/plot_context.py index 788382f..b154306 100644 --- a/src/soundscapy/plotting/new/plot_context.py +++ b/src/soundscapy/plotting/new/plot_context.py @@ -8,6 +8,8 @@ from __future__ import annotations +import dataclasses +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any import matplotlib.pyplot as plt @@ -35,6 +37,42 @@ logger = get_logger() +@dataclass +class ParamModels: + """ + Container for parameter models with methods for retrieval and transformation. + + Attributes + ---------- + scatter : ScatterParams + density : DensityParams + simple_density : SimpleDensityParams + spi_simple_density : SPISimpleDensityParams + style : StyleParams + + """ + + scatter: ScatterParams = field(default_factory=ScatterParams) + density: DensityParams = field(default_factory=DensityParams) + simple_density: SimpleDensityParams = field(default_factory=SimpleDensityParams) + spi_simple_density: SPISimpleDensityParams = field( + default_factory=SPISimpleDensityParams + ) + style: StyleParams = field(default_factory=StyleParams) + + def get( + self, model_name: str, default: _ParamModels | None = None + ) -> _ParamModels | None: + return getattr(self, model_name, default) + + @property + def names(self) -> list[str]: + return list(self.__dict__.keys()) + + def as_dict(self) -> dict[str, _ParamModels]: + return self.__dict__ + + class PlotContext: """ Manages custom_data, state, and parameters for a plot or subplot. @@ -103,7 +141,7 @@ def __init__( self.parent: PlotContext | None = None # Parameter models for different layer types - self._param_models: dict[str, _ParamModels] = {} + self._param_models: ParamModels = field(default_factory=ParamModels) # Initialize default parameter models self._init_param_models() @@ -117,18 +155,12 @@ def _init_param_models(self) -> None: hue = self.hue # Create parameter models for different layer types - self._param_models["scatter"] = ScatterParams(data=data, x=x, y=y, hue=hue) - self._param_models["density"] = DensityParams(data=data, x=x, y=y, hue=hue) - self._param_models["simple_density"] = SimpleDensityParams(data=data, x=x, y=y) - self._param_models["spi_simple_density"] = SPISimpleDensityParams( - data=data, x=x, y=y, hue=hue - ) - self._param_models["style"] = StyleParams( - # TODO: Should not be setting defaults here! - xlim=(-1, 1), - ylim=(-1, 1), - xlabel=r"$P_{ISO}$", - ylabel=r"$E_{ISO}$", + self._param_models = ParamModels( + scatter=ScatterParams(data=data, x=x, y=y, hue=hue), + density=DensityParams(data=data, x=x, y=y, hue=hue), + simple_density=SimpleDensityParams(data=data, x=x, y=y, hue=hue), + spi_simple_density=SPISimpleDensityParams(data=data, x=x, y=y, hue=hue), + style=StyleParams(), ) def get_params(self, param_type: str) -> _ParamModels: @@ -151,11 +183,11 @@ def get_params(self, param_type: str) -> _ParamModels: If the parameter type is unknown """ - if param_type not in self._param_models: - msg = f"Unknown parameter type: {param_type}" - raise ValueError(msg) - - return self._param_models[param_type] + model = self._param_models.get(param_type) + if model is not None: + return model + msg = f"Unknown parameter type: {param_type}" + raise ValueError(msg) def get_params_for_layer(self, layer_type: type[Layer]) -> _ParamModels: """ @@ -247,17 +279,16 @@ def create_child( ax=ax, title=title, ) - # Copy parameter models from parent to child - for param_type, model in self._param_models.items(): - child._param_models[param_type] = model.model_copy() + child._param_models = dataclasses.replace(self._param_models) # Set parent reference child.parent = self return child - def ensure_axes_exist(self, plot: ISOPlot) -> None: + @staticmethod + def ensure_axes_exist(plot: ISOPlot) -> None: """ Check if we have axes to render on, create if needed. diff --git a/src/soundscapy/plotting/plot_functions.py b/src/soundscapy/plotting/plot_functions.py index e931a9c..200c289 100644 --- a/src/soundscapy/plotting/plot_functions.py +++ b/src/soundscapy/plotting/plot_functions.py @@ -178,7 +178,7 @@ def scatter( ... ax=ax[1], title="RegentsParkJapan" ... ) >>> plt.tight_layout() - >>> plt.show() + >>> plt.show() # xdoctest: +SKIP """ style_args = StyleParams().update(**kwargs, extra="ignore", na_rm=False) From cc2a04839285b42630a556643c42598a3c3e932f Mon Sep 17 00:00:00 2001 From: Andrew Mitchell Date: Sun, 11 May 2025 01:10:40 +0100 Subject: [PATCH 8/8] Fix import path for ISOPlot in iso_plot.py --- src/soundscapy/plotting/iso_plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/soundscapy/plotting/iso_plot.py b/src/soundscapy/plotting/iso_plot.py index f622d62..3d32b04 100644 --- a/src/soundscapy/plotting/iso_plot.py +++ b/src/soundscapy/plotting/iso_plot.py @@ -4,7 +4,7 @@ Example: ------- >>> from soundscapy import isd, surveys ->>> from soundscapy.plotting.iso_plot_new import ISOPlot +>>> from soundscapy.plotting.iso_plot import ISOPlot >>> df = isd.load() >>> df = surveys.add_iso_coords(df) >>> sub_df = isd.select_location_ids(df, ['CamdenTown', 'RegentsParkJapan'])