5
5
import plotly .graph_objects as go
6
6
7
7
from optimagic .config import IS_MATPLOTLIB_INSTALLED
8
- from optimagic .exceptions import NotInstalledError
8
+ from optimagic .exceptions import InvalidPlottingBackendError , NotInstalledError
9
9
from optimagic .visualization .plotting_utilities import LineData
10
10
11
11
if IS_MATPLOTLIB_INSTALLED :
12
12
import matplotlib as mpl
13
13
import matplotlib .pyplot as plt
14
14
15
+ # Handle the case where matplotlib is used in notebooks (inline backend)
16
+ # to ensure that interactive mode is disabled to avoid double plotting.
17
+ # (See: https://github.com/matplotlib/matplotlib/issues/26221)
18
+ if mpl .get_backend () == "module://matplotlib_inline.backend_inline" :
19
+ plt .install_repl_displayhook ()
20
+ plt .ioff ()
21
+
15
22
16
23
class PlotBackend (abc .ABC ):
24
+ is_available : bool
17
25
default_template : str
18
- default_palette : list
26
+
27
+ @classmethod
28
+ @abc .abstractmethod
29
+ def get_default_palette (cls ) -> list :
30
+ pass
19
31
20
32
@abc .abstractmethod
21
33
def __init__ (self , template : str | None ):
@@ -39,8 +51,12 @@ def set_legend_properties(self, legend_properties: dict[str, Any]) -> None:
39
51
40
52
41
53
class PlotlyBackend (PlotBackend ):
54
+ is_available : bool = True
42
55
default_template : str = "simple_white"
43
- default_palette : list = px .colors .qualitative .Set2
56
+
57
+ @classmethod
58
+ def get_default_palette (cls ) -> list :
59
+ return px .colors .qualitative .Set2
44
60
45
61
def __init__ (self , template : str | None ):
46
62
super ().__init__ (template )
@@ -68,41 +84,39 @@ def set_legend_properties(self, legend_properties: dict[str, Any]) -> None:
68
84
self ._fig .update_layout (legend = legend_properties )
69
85
70
86
71
- if IS_MATPLOTLIB_INSTALLED :
87
+ class MatplotlibBackend (PlotBackend ):
88
+ is_available : bool = IS_MATPLOTLIB_INSTALLED
89
+ default_template : str = "default"
72
90
73
- class MatplotlibBackend (PlotBackend ):
74
- default_template : str = "default"
75
- default_palette : list = [
76
- mpl .colormaps ["Set2" ](i ) for i in range (mpl .colormaps ["Set2" ].N )
77
- ]
91
+ @classmethod
92
+ def get_default_palette (cls ) -> list :
93
+ return [mpl .colormaps ["Set2" ](i ) for i in range (mpl .colormaps ["Set2" ].N )]
78
94
79
- def __init__ (self , template : str | None ):
80
- super ().__init__ (template )
81
- plt .style .use (self .template )
82
- self ._fig , self ._ax = plt .subplots ()
83
- self .figure = self ._fig
95
+ def __init__ (self , template : str | None ):
96
+ super ().__init__ (template )
97
+ plt .style .use (self .template )
98
+ self ._fig , self ._ax = plt .subplots ()
99
+ self .figure = self ._fig
84
100
85
- def add_lines (self , lines : list [LineData ]) -> None :
86
- for line in lines :
87
- self ._ax .plot (
88
- line .x ,
89
- line .y ,
90
- color = line .color ,
91
- label = line .name if line .show_in_legend else None ,
92
- )
101
+ def add_lines (self , lines : list [LineData ]) -> None :
102
+ for line in lines :
103
+ self ._ax .plot (
104
+ line .x ,
105
+ line .y ,
106
+ color = line .color ,
107
+ label = line .name if line .show_in_legend else None ,
108
+ )
93
109
94
- def set_labels (
95
- self , xlabel : str | None = None , ylabel : str | None = None
96
- ) -> None :
97
- self ._ax .set (xlabel = xlabel , ylabel = ylabel )
110
+ def set_labels (self , xlabel : str | None = None , ylabel : str | None = None ) -> None :
111
+ self ._ax .set (xlabel = xlabel , ylabel = ylabel )
98
112
99
- def set_legend_properties (self , legend_properties : dict [str , Any ]) -> None :
100
- self ._ax .legend (** legend_properties )
113
+ def set_legend_properties (self , legend_properties : dict [str , Any ]) -> None :
114
+ self ._ax .legend (** legend_properties )
101
115
102
116
103
117
PLOT_BACKEND_CLASSES = {
104
118
"plotly" : PlotlyBackend ,
105
- "matplotlib" : MatplotlibBackend if IS_MATPLOTLIB_INSTALLED else None ,
119
+ "matplotlib" : MatplotlibBackend ,
106
120
}
107
121
108
122
@@ -112,15 +126,15 @@ def get_plot_backend_class(backend_name: str) -> type[PlotBackend]:
112
126
f"Invalid backend name '{ backend_name } '. "
113
127
f"Supported backends are: { ', ' .join (PLOT_BACKEND_CLASSES .keys ())} ."
114
128
)
115
- raise ValueError (msg )
129
+ raise InvalidPlottingBackendError (msg )
116
130
117
131
return _get_backend_if_installed (backend_name )
118
132
119
133
120
134
def _get_backend_if_installed (backend_name : str ) -> type [PlotBackend ]:
121
135
plot_cls = PLOT_BACKEND_CLASSES [backend_name ]
122
136
123
- if plot_cls is None :
137
+ if not plot_cls . is_available :
124
138
msg = (
125
139
f"The '{ backend_name } ' backend is not installed. "
126
140
f"Install the package using either 'pip install { backend_name } ' or "
0 commit comments