Skip to content

Implementation of matplotlib backend for criterion_plot() #599

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .tools/envs/testenv-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies:
- numpy >= 2 # run, tests
- pandas # run, tests
- plotly>=6.2 # run, tests
- matplotlib # tests
- pybaum>=0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy # run, tests
Expand Down
1 change: 1 addition & 0 deletions .tools/envs/testenv-nevergrad.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- numpy >= 2 # run, tests
- pandas # run, tests
- plotly>=6.2 # run, tests
- matplotlib # tests
- pybaum>=0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy # run, tests
Expand Down
1 change: 1 addition & 0 deletions .tools/envs/testenv-numpy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- cloudpickle # run, tests
- joblib # run, tests
- plotly>=6.2 # run, tests
- matplotlib # tests
- pybaum>=0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy # run, tests
Expand Down
1 change: 1 addition & 0 deletions .tools/envs/testenv-others.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- numpy >= 2 # run, tests
- pandas # run, tests
- plotly>=6.2 # run, tests
- matplotlib # tests
- pybaum>=0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy # run, tests
Expand Down
1 change: 1 addition & 0 deletions .tools/envs/testenv-pandas.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- cloudpickle # run, tests
- joblib # run, tests
- plotly>=6.2 # run, tests
- matplotlib # tests
- pybaum>=0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy # run, tests
Expand Down
1 change: 1 addition & 0 deletions .tools/envs/testenv-plotly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- joblib # run, tests
- numpy >= 2 # run, tests
- pandas # run, tests
- matplotlib # tests
- pybaum>=0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy # run, tests
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies:
- numpy >= 2 # run, tests
- pandas # run, tests
- plotly>=6.2 # run, tests
- matplotlib # tests
- pybaum>=0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy # run, tests
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ module = [

"optimagic.visualization",
"optimagic.visualization.convergence_plot",
"optimagic.visualization.backends",
"optimagic.visualization.deviation_plot",
"optimagic.visualization.history_plots",
"optimagic.visualization.plotting_utilities",
Expand Down Expand Up @@ -346,6 +347,8 @@ module = [
"plotly.graph_objects",
"plotly.express",
"plotly.subplots",
"matplotlib",
"matplotlib.pyplot",
"cyipopt",
"nlopt",
"bokeh",
Expand Down
8 changes: 7 additions & 1 deletion src/optimagic/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _is_installed(module_name: str) -> bool:


# ======================================================================================
# Check Available Packages
# Check Available Optimization Packages
# ======================================================================================

IS_PETSC4PY_INSTALLED = _is_installed("petsc4py")
Expand All @@ -40,6 +40,12 @@ def _is_installed(module_name: str) -> bool:
IS_NEVERGRAD_INSTALLED = _is_installed("nevergrad")
IS_BAYESOPT_INSTALLED = _is_installed("bayes_opt")

# ======================================================================================
# Check Available Visualization Packages
# ======================================================================================

IS_MATPLOTLIB_INSTALLED = _is_installed("matplotlib")


# ======================================================================================
# Check if pandas version is newer or equal to version 2.1.0
Expand Down
4 changes: 4 additions & 0 deletions src/optimagic/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ class InvalidAlgoInfoError(OptimagicError):
"""Exception for invalid user provided algorithm information."""


class InvalidPlottingBackendError(OptimagicError):
"""Exception for invalid user provided plotting backend."""


class StopOptimizationError(OptimagicError):
def __init__(self, message, current_status):
super().__init__(message)
Expand Down
145 changes: 145 additions & 0 deletions src/optimagic/visualization/backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import abc
from typing import Any

import plotly.express as px
import plotly.graph_objects as go

from optimagic.config import IS_MATPLOTLIB_INSTALLED
from optimagic.exceptions import InvalidPlottingBackendError, NotInstalledError
from optimagic.visualization.plotting_utilities import LineData

if IS_MATPLOTLIB_INSTALLED:
import matplotlib as mpl
import matplotlib.pyplot as plt

# Handle the case where matplotlib is used in notebooks (inline backend)
# to ensure that interactive mode is disabled to avoid double plotting.
# (See: https://github.com/matplotlib/matplotlib/issues/26221)
if mpl.get_backend() == "module://matplotlib_inline.backend_inline":
plt.install_repl_displayhook()
plt.ioff()


class PlotBackend(abc.ABC):
is_available: bool
default_template: str

@classmethod
@abc.abstractmethod
def get_default_palette(cls) -> list:
pass

@abc.abstractmethod
def __init__(self, template: str | None):
if template is None:
template = self.default_template

self.template = template
self.figure: Any = None

@abc.abstractmethod
def add_lines(self, lines: list[LineData]) -> None:
pass

@abc.abstractmethod
def set_labels(self, xlabel: str | None = None, ylabel: str | None = None) -> None:
pass

@abc.abstractmethod
def set_legend_properties(self, legend_properties: dict[str, Any]) -> None:
pass


class PlotlyBackend(PlotBackend):
is_available: bool = True
default_template: str = "simple_white"

@classmethod
def get_default_palette(cls) -> list:
return px.colors.qualitative.Set2

def __init__(self, template: str | None):
super().__init__(template)
self._fig = go.Figure()
self._fig.update_layout(template=self.template)
self.figure = self._fig

def add_lines(self, lines: list[LineData]) -> None:
for line in lines:
trace = go.Scatter(
x=line.x,
y=line.y,
name=line.name,
mode="lines",
line_color=line.color,
showlegend=line.show_in_legend,
connectgaps=True,
)
self._fig.add_trace(trace)

def set_labels(self, xlabel: str | None = None, ylabel: str | None = None) -> None:
self._fig.update_layout(xaxis_title_text=xlabel, yaxis_title_text=ylabel)

def set_legend_properties(self, legend_properties: dict[str, Any]) -> None:
self._fig.update_layout(legend=legend_properties)


class MatplotlibBackend(PlotBackend):
is_available: bool = IS_MATPLOTLIB_INSTALLED
default_template: str = "default"

@classmethod
def get_default_palette(cls) -> list:
return [mpl.colormaps["Set2"](i) for i in range(mpl.colormaps["Set2"].N)]

def __init__(self, template: str | None):
super().__init__(template)
plt.style.use(self.template)
self._fig, self._ax = plt.subplots()
self.figure = self._fig

def add_lines(self, lines: list[LineData]) -> None:
for line in lines:
self._ax.plot(
line.x,
line.y,
color=line.color,
label=line.name if line.show_in_legend else None,
)

def set_labels(self, xlabel: str | None = None, ylabel: str | None = None) -> None:
self._ax.set(xlabel=xlabel, ylabel=ylabel)

def set_legend_properties(self, legend_properties: dict[str, Any]) -> None:
self._ax.legend(**legend_properties)


PLOT_BACKEND_CLASSES = {
"plotly": PlotlyBackend,
"matplotlib": MatplotlibBackend,
}


def get_plot_backend_class(backend_name: str) -> type[PlotBackend]:
if backend_name not in PLOT_BACKEND_CLASSES:
msg = (
f"Invalid backend name '{backend_name}'. "
f"Supported backends are: {', '.join(PLOT_BACKEND_CLASSES.keys())}."
)
raise InvalidPlottingBackendError(msg)

return _get_backend_if_installed(backend_name)


def _get_backend_if_installed(backend_name: str) -> type[PlotBackend]:
plot_cls = PLOT_BACKEND_CLASSES[backend_name]

if not plot_cls.is_available:
msg = (
f"The '{backend_name}' backend is not installed. "
f"Install the package using either 'pip install {backend_name}' or "
f"'conda install -c conda-forge {backend_name}'"
)
raise NotInstalledError(msg)

return plot_cls
Loading
Loading