diff --git a/.tools/envs/testenv-linux.yml b/.tools/envs/testenv-linux.yml index 398c56cce..691ea8375 100644 --- a/.tools/envs/testenv-linux.yml +++ b/.tools/envs/testenv-linux.yml @@ -19,6 +19,7 @@ dependencies: - numpy >= 2 # run, tests - pandas # run, tests - plotly>=6.2 # run, tests + - matplotlib # tests - pybaum>=0.1.2 # run, tests - scipy>=1.2.1 # run, tests - sqlalchemy # run, tests diff --git a/.tools/envs/testenv-nevergrad.yml b/.tools/envs/testenv-nevergrad.yml index 874b9fa5e..4fbd6fef1 100644 --- a/.tools/envs/testenv-nevergrad.yml +++ b/.tools/envs/testenv-nevergrad.yml @@ -17,6 +17,7 @@ dependencies: - numpy >= 2 # run, tests - pandas # run, tests - plotly>=6.2 # run, tests + - matplotlib # tests - pybaum>=0.1.2 # run, tests - scipy>=1.2.1 # run, tests - sqlalchemy # run, tests diff --git a/.tools/envs/testenv-numpy.yml b/.tools/envs/testenv-numpy.yml index c54dc010f..5bdc48eeb 100644 --- a/.tools/envs/testenv-numpy.yml +++ b/.tools/envs/testenv-numpy.yml @@ -17,6 +17,7 @@ dependencies: - cloudpickle # run, tests - joblib # run, tests - plotly>=6.2 # run, tests + - matplotlib # tests - pybaum>=0.1.2 # run, tests - scipy>=1.2.1 # run, tests - sqlalchemy # run, tests diff --git a/.tools/envs/testenv-others.yml b/.tools/envs/testenv-others.yml index 308d142aa..131016aa2 100644 --- a/.tools/envs/testenv-others.yml +++ b/.tools/envs/testenv-others.yml @@ -17,6 +17,7 @@ dependencies: - numpy >= 2 # run, tests - pandas # run, tests - plotly>=6.2 # run, tests + - matplotlib # tests - pybaum>=0.1.2 # run, tests - scipy>=1.2.1 # run, tests - sqlalchemy # run, tests diff --git a/.tools/envs/testenv-pandas.yml b/.tools/envs/testenv-pandas.yml index bccee25c6..9dde57231 100644 --- a/.tools/envs/testenv-pandas.yml +++ b/.tools/envs/testenv-pandas.yml @@ -17,6 +17,7 @@ dependencies: - cloudpickle # run, tests - joblib # run, tests - plotly>=6.2 # run, tests + - matplotlib # tests - pybaum>=0.1.2 # run, tests - scipy>=1.2.1 # run, tests - sqlalchemy # run, tests diff --git a/.tools/envs/testenv-plotly.yml b/.tools/envs/testenv-plotly.yml index eccdf512d..75c68c8b5 100644 --- a/.tools/envs/testenv-plotly.yml +++ b/.tools/envs/testenv-plotly.yml @@ -17,6 +17,7 @@ dependencies: - joblib # run, tests - numpy >= 2 # run, tests - pandas # run, tests + - matplotlib # tests - pybaum>=0.1.2 # run, tests - scipy>=1.2.1 # run, tests - sqlalchemy # run, tests diff --git a/environment.yml b/environment.yml index 6bb4f01db..e8f538f95 100644 --- a/environment.yml +++ b/environment.yml @@ -21,6 +21,7 @@ dependencies: - numpy >= 2 # run, tests - pandas # run, tests - plotly>=6.2 # run, tests + - matplotlib # tests - pybaum>=0.1.2 # run, tests - scipy>=1.2.1 # run, tests - sqlalchemy # run, tests diff --git a/pyproject.toml b/pyproject.toml index c74752252..e7d2370d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -290,6 +290,7 @@ module = [ "optimagic.visualization", "optimagic.visualization.convergence_plot", + "optimagic.visualization.backends", "optimagic.visualization.deviation_plot", "optimagic.visualization.history_plots", "optimagic.visualization.plotting_utilities", @@ -346,6 +347,8 @@ module = [ "plotly.graph_objects", "plotly.express", "plotly.subplots", + "matplotlib", + "matplotlib.pyplot", "cyipopt", "nlopt", "bokeh", diff --git a/src/optimagic/config.py b/src/optimagic/config.py index ce6cd4d60..00902a32a 100644 --- a/src/optimagic/config.py +++ b/src/optimagic/config.py @@ -23,7 +23,7 @@ def _is_installed(module_name: str) -> bool: # ====================================================================================== -# Check Available Packages +# Check Available Optimization Packages # ====================================================================================== IS_PETSC4PY_INSTALLED = _is_installed("petsc4py") @@ -40,6 +40,12 @@ def _is_installed(module_name: str) -> bool: IS_NEVERGRAD_INSTALLED = _is_installed("nevergrad") IS_BAYESOPT_INSTALLED = _is_installed("bayes_opt") +# ====================================================================================== +# Check Available Visualization Packages +# ====================================================================================== + +IS_MATPLOTLIB_INSTALLED = _is_installed("matplotlib") + # ====================================================================================== # Check if pandas version is newer or equal to version 2.1.0 diff --git a/src/optimagic/exceptions.py b/src/optimagic/exceptions.py index 7a7dfb75d..560b47b84 100644 --- a/src/optimagic/exceptions.py +++ b/src/optimagic/exceptions.py @@ -79,6 +79,10 @@ class InvalidAlgoInfoError(OptimagicError): """Exception for invalid user provided algorithm information.""" +class InvalidPlottingBackendError(OptimagicError): + """Exception for invalid user provided plotting backend.""" + + class StopOptimizationError(OptimagicError): def __init__(self, message, current_status): super().__init__(message) diff --git a/src/optimagic/visualization/backends.py b/src/optimagic/visualization/backends.py new file mode 100644 index 000000000..a2c96422c --- /dev/null +++ b/src/optimagic/visualization/backends.py @@ -0,0 +1,145 @@ +import abc +from typing import Any + +import plotly.express as px +import plotly.graph_objects as go + +from optimagic.config import IS_MATPLOTLIB_INSTALLED +from optimagic.exceptions import InvalidPlottingBackendError, NotInstalledError +from optimagic.visualization.plotting_utilities import LineData + +if IS_MATPLOTLIB_INSTALLED: + import matplotlib as mpl + import matplotlib.pyplot as plt + + # Handle the case where matplotlib is used in notebooks (inline backend) + # to ensure that interactive mode is disabled to avoid double plotting. + # (See: https://github.com/matplotlib/matplotlib/issues/26221) + if mpl.get_backend() == "module://matplotlib_inline.backend_inline": + plt.install_repl_displayhook() + plt.ioff() + + +class PlotBackend(abc.ABC): + is_available: bool + default_template: str + + @classmethod + @abc.abstractmethod + def get_default_palette(cls) -> list: + pass + + @abc.abstractmethod + def __init__(self, template: str | None): + if template is None: + template = self.default_template + + self.template = template + self.figure: Any = None + + @abc.abstractmethod + def add_lines(self, lines: list[LineData]) -> None: + pass + + @abc.abstractmethod + def set_labels(self, xlabel: str | None = None, ylabel: str | None = None) -> None: + pass + + @abc.abstractmethod + def set_legend_properties(self, legend_properties: dict[str, Any]) -> None: + pass + + +class PlotlyBackend(PlotBackend): + is_available: bool = True + default_template: str = "simple_white" + + @classmethod + def get_default_palette(cls) -> list: + return px.colors.qualitative.Set2 + + def __init__(self, template: str | None): + super().__init__(template) + self._fig = go.Figure() + self._fig.update_layout(template=self.template) + self.figure = self._fig + + def add_lines(self, lines: list[LineData]) -> None: + for line in lines: + trace = go.Scatter( + x=line.x, + y=line.y, + name=line.name, + mode="lines", + line_color=line.color, + showlegend=line.show_in_legend, + connectgaps=True, + ) + self._fig.add_trace(trace) + + def set_labels(self, xlabel: str | None = None, ylabel: str | None = None) -> None: + self._fig.update_layout(xaxis_title_text=xlabel, yaxis_title_text=ylabel) + + def set_legend_properties(self, legend_properties: dict[str, Any]) -> None: + self._fig.update_layout(legend=legend_properties) + + +class MatplotlibBackend(PlotBackend): + is_available: bool = IS_MATPLOTLIB_INSTALLED + default_template: str = "default" + + @classmethod + def get_default_palette(cls) -> list: + return [mpl.colormaps["Set2"](i) for i in range(mpl.colormaps["Set2"].N)] + + def __init__(self, template: str | None): + super().__init__(template) + plt.style.use(self.template) + self._fig, self._ax = plt.subplots() + self.figure = self._fig + + def add_lines(self, lines: list[LineData]) -> None: + for line in lines: + self._ax.plot( + line.x, + line.y, + color=line.color, + label=line.name if line.show_in_legend else None, + ) + + def set_labels(self, xlabel: str | None = None, ylabel: str | None = None) -> None: + self._ax.set(xlabel=xlabel, ylabel=ylabel) + + def set_legend_properties(self, legend_properties: dict[str, Any]) -> None: + self._ax.legend(**legend_properties) + + +PLOT_BACKEND_CLASSES = { + "plotly": PlotlyBackend, + "matplotlib": MatplotlibBackend, +} + + +def get_plot_backend_class(backend_name: str) -> type[PlotBackend]: + if backend_name not in PLOT_BACKEND_CLASSES: + msg = ( + f"Invalid backend name '{backend_name}'. " + f"Supported backends are: {', '.join(PLOT_BACKEND_CLASSES.keys())}." + ) + raise InvalidPlottingBackendError(msg) + + return _get_backend_if_installed(backend_name) + + +def _get_backend_if_installed(backend_name: str) -> type[PlotBackend]: + plot_cls = PLOT_BACKEND_CLASSES[backend_name] + + if not plot_cls.is_available: + msg = ( + f"The '{backend_name}' backend is not installed. " + f"Install the package using either 'pip install {backend_name}' or " + f"'conda install -c conda-forge {backend_name}'" + ) + raise NotInstalledError(msg) + + return plot_cls diff --git a/src/optimagic/visualization/history_plots.py b/src/optimagic/visualization/history_plots.py index 696ac0c1d..fcd8bf586 100644 --- a/src/optimagic/visualization/history_plots.py +++ b/src/optimagic/visualization/history_plots.py @@ -2,42 +2,57 @@ import itertools from dataclasses import dataclass from pathlib import Path -from typing import Any +from typing import Any, Literal import numpy as np import plotly.graph_objects as go from pybaum import leaf_names, tree_flatten, tree_just_flatten, tree_unflatten -from optimagic.config import PLOTLY_PALETTE, PLOTLY_TEMPLATE +from optimagic.config import PLOTLY_TEMPLATE from optimagic.logging.logger import LogReader, SQLiteLogOptions from optimagic.optimization.algorithm import Algorithm from optimagic.optimization.history import History from optimagic.optimization.optimize_result import OptimizeResult from optimagic.parameters.tree_registry import get_registry from optimagic.typing import IterationHistory, PyTree +from optimagic.visualization.backends import get_plot_backend_class +from optimagic.visualization.plotting_utilities import LineData, get_palette_cycle -OptimizeResultOrPath = OptimizeResult | str | Path +BACKEND_TO_CRITERION_PLOT_LEGEND_PROPERTIES: dict[str, dict[str, Any]] = { + "plotly": { + "yanchor": "top", + "xanchor": "right", + "y": 0.95, + "x": 0.95, + }, + "matplotlib": { + "loc": "upper right", + }, +} + + +ResultOrPath = OptimizeResult | str | Path def criterion_plot( - results: OptimizeResultOrPath - | list[OptimizeResultOrPath] - | dict[str, OptimizeResultOrPath], + results: ResultOrPath | list[ResultOrPath] | dict[str, ResultOrPath], names: list[str] | str | None = None, max_evaluations: int | None = None, - template: str = PLOTLY_TEMPLATE, - palette: list[str] | str = PLOTLY_PALETTE, + backend: Literal["plotly", "matplotlib"] = "plotly", + template: str | None = None, + palette: list[str] | str | None = None, stack_multistart: bool = False, monotone: bool = False, show_exploration: bool = False, -) -> go.Figure: +) -> Any: """Plot the criterion history of an optimization. Args: - results: A (list or dict of) optimization results with collected history. - If dict, then the key is used as the name in a legend. - names: Names corresponding to res or entries in res. + results: An optimization result (or list of, or dict of results) with collected + history, or path(s) to it. If dict, then the key is used as the name in the + legend. max_evaluations: Clip the criterion history after that many entries. + backend: The backend to use for plotting. Default is "plotly". template: The template for the figure. Default is "plotly_white". palette: The coloring palette for traces. Default is "qualitative.Set2". stack_multistart: Whether to combine multistart histories into a single history. @@ -51,12 +66,17 @@ def criterion_plot( The figure object containing the criterion plot. """ + # ================================================================================== + # Get Plot Backend class + + plot_cls = get_plot_backend_class(backend) + # ================================================================================== # Process inputs - if not isinstance(palette, list): - palette = [palette] - palette_cycle = itertools.cycle(palette) + if palette is None: + palette = plot_cls.get_default_palette() + palette_cycle = get_palette_cycle(palette) dict_of_optimize_results_or_paths = _harmonize_inputs_to_dict(results, names) @@ -78,21 +98,19 @@ def criterion_plot( ) # ================================================================================== - # Generate the plotly figure + # Generate the figure - plot_config = PlotConfig( - template=template, - legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95}, - ) + plot = plot_cls(template) - fig = _plotly_line_plot(lines + multistart_lines, plot_config) - return fig + plot.add_lines(lines + multistart_lines) + plot.set_labels(xlabel="No. of criterion evaluations", ylabel="Criterion value") + plot.set_legend_properties(BACKEND_TO_CRITERION_PLOT_LEGEND_PROPERTIES[backend]) + + return plot.figure def _harmonize_inputs_to_dict( - results: OptimizeResultOrPath - | list[OptimizeResultOrPath] - | dict[str, OptimizeResultOrPath], + results: ResultOrPath | list[ResultOrPath] | dict[str, ResultOrPath], names: list[str] | str | None, ) -> dict[str, OptimizeResult | str | Path]: """Convert all valid inputs for results and names to dict[str, OptimizeResult].""" @@ -462,26 +480,6 @@ def _get_stacked_local_histories( ) -@dataclass(frozen=True) -class LineData: - """Data of a single line. - - Attributes: - x: The x-coordinates of the points. - y: The y-coordinates of the points. - color: The color of the line. Default is None. - name: The name of the line. Default is None. - show_in_legend: Whether to show the line in the legend. Default is True. - - """ - - x: np.ndarray - y: np.ndarray - color: str | None = None - name: str | None = None - show_in_legend: bool = True - - def _extract_criterion_plot_lines( data: list[_PlottingMultistartHistory], max_evaluations: int | None, @@ -543,69 +541,13 @@ def _extract_criterion_plot_lines( if max_evaluations is not None and len(history) > max_evaluations: history = history[:max_evaluations] - _color = next(palette_cycle) - if not isinstance(_color, str): - msg = "highlight_palette needs to be a string or list of strings, but its " - f"entry is of type {type(_color)}." - raise TypeError(msg) - line_data = LineData( x=np.arange(len(history)), y=history, - color=_color, + color=next(palette_cycle), name="best result" if plot_multistart else _data.name, show_in_legend=not plot_multistart, ) lines.append(line_data) return lines, multistart_lines - - -@dataclass(frozen=True) -class PlotConfig: - """Configuration settings for figure. - - Attributes: - template: The template for the figure. - legend: Configuration for the legend. - - """ - - template: str - legend: dict[str, Any] - - -def _plotly_line_plot(lines: list[LineData], plot_config: PlotConfig) -> go.Figure: - """Create a plotly line plot from the given lines and plot configuration. - - Args: - lines: Data for lines to be plotted. - plot_config: Configuration for the plot. - - Returns: - The figure object containing the lines. - - """ - - fig = go.Figure() - - for line in lines: - trace = go.Scatter( - x=line.x, - y=line.y, - name=line.name, - mode="lines", - line_color=line.color, - showlegend=line.show_in_legend, - connectgaps=True, - ) - fig.add_trace(trace) - - fig.update_layout( - template=plot_config.template, - xaxis_title_text="No. of criterion evaluations", - yaxis_title_text="Criterion value", - legend=plot_config.legend, - ) - - return fig diff --git a/src/optimagic/visualization/plotting_utilities.py b/src/optimagic/visualization/plotting_utilities.py index 3d37c3b97..f6c770533 100644 --- a/src/optimagic/visualization/plotting_utilities.py +++ b/src/optimagic/visualization/plotting_utilities.py @@ -2,6 +2,7 @@ import collections.abc import itertools from copy import deepcopy +from dataclasses import dataclass from typing import Any import numpy as np @@ -11,6 +12,26 @@ from optimagic.config import PLOTLY_TEMPLATE +@dataclass(frozen=True) +class LineData: + """Data of a single line. + + Attributes: + x: The x-coordinates of the points. + y: The y-coordinates of the points. + color: The color of the line. Default is None. + name: The name of the line. Default is None. + show_in_legend: Whether to show the line in the legend. Default is True. + + """ + + x: np.ndarray + y: np.ndarray + color: str | None = None + name: str | None = None + show_in_legend: bool = True + + def combine_plots( plots, plots_per_row=2, @@ -364,3 +385,9 @@ def _ensure_array_from_plotly_data(data: Any) -> np.ndarray: def _decode_base64_data(b64data: str, dtype: str) -> np.ndarray: decoded = base64.b64decode(b64data) return np.frombuffer(decoded, dtype=np.dtype(dtype)) + + +def get_palette_cycle(palette: list[str] | str) -> "itertools.cycle[str]": + if not isinstance(palette, list): + palette = [palette] + return itertools.cycle(palette) diff --git a/tests/optimagic/visualization/test_history_plots.py b/tests/optimagic/visualization/test_history_plots.py index 28cdceb6e..c0bae33d3 100644 --- a/tests/optimagic/visualization/test_history_plots.py +++ b/tests/optimagic/visualization/test_history_plots.py @@ -6,6 +6,7 @@ from numpy.testing import assert_array_equal import optimagic as om +from optimagic.exceptions import InvalidPlottingBackendError from optimagic.logging import SQLiteLogOptions from optimagic.optimization.optimize import minimize from optimagic.parameters.bounds import Bounds @@ -144,6 +145,15 @@ def test_criterion_plot_wrong_inputs(): with pytest.raises(ValueError): criterion_plot(["bla", "bla"], names="blub") + with pytest.raises(InvalidPlottingBackendError): + criterion_plot("bla", backend="blub") + + +@pytest.mark.parametrize("backend", ["plotly", "matplotlib"]) +def test_criterion_plot_different_backends(minimize_result, backend): + res = minimize_result[False][0] + criterion_plot(res, backend=backend) + def test_harmonize_inputs_to_dict_single_result(): res = minimize(fun=lambda x: x @ x, params=np.arange(5), algorithm="scipy_lbfgsb")