Skip to content

Commit f15d982

Browse files
committed
Enhance availability check for backends. Fix issues with matplotlib in interactive mode.
1 parent e167b49 commit f15d982

File tree

4 files changed

+62
-34
lines changed

4 files changed

+62
-34
lines changed

src/optimagic/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ class InvalidAlgoInfoError(OptimagicError):
7979
"""Exception for invalid user provided algorithm information."""
8080

8181

82+
class InvalidPlottingBackendError(OptimagicError):
83+
"""Exception for invalid user provided plotting backend."""
84+
85+
8286
class StopOptimizationError(OptimagicError):
8387
def __init__(self, message, current_status):
8488
super().__init__(message)

src/optimagic/visualization/backends.py

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,29 @@
55
import plotly.graph_objects as go
66

77
from optimagic.config import IS_MATPLOTLIB_INSTALLED
8-
from optimagic.exceptions import NotInstalledError
8+
from optimagic.exceptions import InvalidPlottingBackendError, NotInstalledError
99
from optimagic.visualization.plotting_utilities import LineData
1010

1111
if IS_MATPLOTLIB_INSTALLED:
1212
import matplotlib as mpl
1313
import matplotlib.pyplot as plt
1414

15+
# Handle the case where matplotlib is used in notebooks (inline backend)
16+
# to ensure that interactive mode is disabled to avoid double plotting.
17+
# (See: https://github.com/matplotlib/matplotlib/issues/26221)
18+
if mpl.get_backend() == "module://matplotlib_inline.backend_inline":
19+
plt.install_repl_displayhook()
20+
plt.ioff()
21+
1522

1623
class PlotBackend(abc.ABC):
24+
is_available: bool
1725
default_template: str
18-
default_palette: list
26+
27+
@classmethod
28+
@abc.abstractmethod
29+
def get_default_palette(cls) -> list:
30+
pass
1931

2032
@abc.abstractmethod
2133
def __init__(self, template: str | None):
@@ -39,8 +51,12 @@ def set_legend_properties(self, legend_properties: dict[str, Any]) -> None:
3951

4052

4153
class PlotlyBackend(PlotBackend):
54+
is_available: bool = True
4255
default_template: str = "simple_white"
43-
default_palette: list = px.colors.qualitative.Set2
56+
57+
@classmethod
58+
def get_default_palette(cls) -> list:
59+
return px.colors.qualitative.Set2
4460

4561
def __init__(self, template: str | None):
4662
super().__init__(template)
@@ -68,41 +84,39 @@ def set_legend_properties(self, legend_properties: dict[str, Any]) -> None:
6884
self._fig.update_layout(legend=legend_properties)
6985

7086

71-
if IS_MATPLOTLIB_INSTALLED:
87+
class MatplotlibBackend(PlotBackend):
88+
is_available: bool = IS_MATPLOTLIB_INSTALLED
89+
default_template: str = "default"
7290

73-
class MatplotlibBackend(PlotBackend):
74-
default_template: str = "default"
75-
default_palette: list = [
76-
mpl.colormaps["Set2"](i) for i in range(mpl.colormaps["Set2"].N)
77-
]
91+
@classmethod
92+
def get_default_palette(cls) -> list:
93+
return [mpl.colormaps["Set2"](i) for i in range(mpl.colormaps["Set2"].N)]
7894

79-
def __init__(self, template: str | None):
80-
super().__init__(template)
81-
plt.style.use(self.template)
82-
self._fig, self._ax = plt.subplots()
83-
self.figure = self._fig
95+
def __init__(self, template: str | None):
96+
super().__init__(template)
97+
plt.style.use(self.template)
98+
self._fig, self._ax = plt.subplots()
99+
self.figure = self._fig
84100

85-
def add_lines(self, lines: list[LineData]) -> None:
86-
for line in lines:
87-
self._ax.plot(
88-
line.x,
89-
line.y,
90-
color=line.color,
91-
label=line.name if line.show_in_legend else None,
92-
)
101+
def add_lines(self, lines: list[LineData]) -> None:
102+
for line in lines:
103+
self._ax.plot(
104+
line.x,
105+
line.y,
106+
color=line.color,
107+
label=line.name if line.show_in_legend else None,
108+
)
93109

94-
def set_labels(
95-
self, xlabel: str | None = None, ylabel: str | None = None
96-
) -> None:
97-
self._ax.set(xlabel=xlabel, ylabel=ylabel)
110+
def set_labels(self, xlabel: str | None = None, ylabel: str | None = None) -> None:
111+
self._ax.set(xlabel=xlabel, ylabel=ylabel)
98112

99-
def set_legend_properties(self, legend_properties: dict[str, Any]) -> None:
100-
self._ax.legend(**legend_properties)
113+
def set_legend_properties(self, legend_properties: dict[str, Any]) -> None:
114+
self._ax.legend(**legend_properties)
101115

102116

103117
PLOT_BACKEND_CLASSES = {
104118
"plotly": PlotlyBackend,
105-
"matplotlib": MatplotlibBackend if IS_MATPLOTLIB_INSTALLED else None,
119+
"matplotlib": MatplotlibBackend,
106120
}
107121

108122

@@ -112,15 +126,15 @@ def get_plot_backend_class(backend_name: str) -> type[PlotBackend]:
112126
f"Invalid backend name '{backend_name}'. "
113127
f"Supported backends are: {', '.join(PLOT_BACKEND_CLASSES.keys())}."
114128
)
115-
raise ValueError(msg)
129+
raise InvalidPlottingBackendError(msg)
116130

117131
return _get_backend_if_installed(backend_name)
118132

119133

120134
def _get_backend_if_installed(backend_name: str) -> type[PlotBackend]:
121135
plot_cls = PLOT_BACKEND_CLASSES[backend_name]
122136

123-
if plot_cls is None:
137+
if not plot_cls.is_available:
124138
msg = (
125139
f"The '{backend_name}' backend is not installed. "
126140
f"Install the package using either 'pip install {backend_name}' or "

src/optimagic/visualization/history_plots.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import itertools
33
from dataclasses import dataclass
44
from pathlib import Path
5-
from typing import Any
5+
from typing import Any, Literal
66

77
import numpy as np
88
import plotly.graph_objects as go
@@ -38,7 +38,7 @@ def criterion_plot(
3838
results: ResultOrPath | list[ResultOrPath] | dict[str, ResultOrPath],
3939
names: list[str] | str | None = None,
4040
max_evaluations: int | None = None,
41-
backend: str = "plotly",
41+
backend: Literal["plotly", "matplotlib"] = "plotly",
4242
template: str | None = None,
4343
palette: list[str] | str | None = None,
4444
stack_multistart: bool = False,
@@ -75,7 +75,7 @@ def criterion_plot(
7575
# Process inputs
7676

7777
if palette is None:
78-
palette = plot_cls.default_palette
78+
palette = plot_cls.get_default_palette()
7979
palette_cycle = get_palette_cycle(palette)
8080

8181
dict_of_optimize_results_or_paths = _harmonize_inputs_to_dict(results, names)

tests/optimagic/visualization/test_history_plots.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from numpy.testing import assert_array_equal
77

88
import optimagic as om
9+
from optimagic.exceptions import InvalidPlottingBackendError
910
from optimagic.logging import SQLiteLogOptions
1011
from optimagic.optimization.optimize import minimize
1112
from optimagic.parameters.bounds import Bounds
@@ -144,6 +145,15 @@ def test_criterion_plot_wrong_inputs():
144145
with pytest.raises(ValueError):
145146
criterion_plot(["bla", "bla"], names="blub")
146147

148+
with pytest.raises(InvalidPlottingBackendError):
149+
criterion_plot("bla", backend="blub")
150+
151+
152+
@pytest.mark.parametrize("backend", ["plotly", "matplotlib"])
153+
def test_criterion_plot_different_backends(minimize_result, backend):
154+
res = minimize_result[False][0]
155+
criterion_plot(res, backend=backend)
156+
147157

148158
def test_harmonize_inputs_to_dict_single_result():
149159
res = minimize(fun=lambda x: x @ x, params=np.arange(5), algorithm="scipy_lbfgsb")

0 commit comments

Comments
 (0)