diff --git a/src/ess/livedata/dashboard/plot_configuration_adapter.py b/src/ess/livedata/dashboard/plot_configuration_adapter.py new file mode 100644 index 000000000..7a8c4d495 --- /dev/null +++ b/src/ess/livedata/dashboard/plot_configuration_adapter.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +from __future__ import annotations + +from typing import Any + +import pydantic + +from ess.livedata.config.workflow_spec import JobNumber +from ess.livedata.dashboard.configuration_adapter import ConfigurationAdapter +from ess.livedata.dashboard.plotting import PlotterSpec +from ess.livedata.dashboard.plotting_controller import PlottingController + + +class PlotConfigurationAdapter(ConfigurationAdapter): + """Adapter for plot configuration modal.""" + + def __init__( + self, + job_number: JobNumber, + output_name: str | None, + plot_spec: PlotterSpec, + available_sources: list[str], + plotting_controller: PlottingController, + success_callback, + ): + self._job_number = job_number + self._output_name = output_name + self._plot_spec = plot_spec + self._available_sources = available_sources + self._plotting_controller = plotting_controller + self._success_callback = success_callback + + self._persisted_config = ( + self._plotting_controller.get_persistent_plotter_config( + job_number=self._job_number, + output_name=self._output_name, + plot_name=self._plot_spec.name, + ) + ) + + @property + def title(self) -> str: + return f"Configure {self._plot_spec.title}" + + @property + def description(self) -> str: + return self._plot_spec.description + + def model_class(self) -> type[pydantic.BaseModel] | None: + return self._plot_spec.params + + @property + def source_names(self) -> list[str]: + return self._available_sources + + @property + def initial_source_names(self) -> list[str]: + if self._persisted_config is not None: + # Filter persisted source names to only include those still available + persisted_sources = [ + name + for name in self._persisted_config.source_names + if name in self._available_sources + ] + return persisted_sources if persisted_sources else self._available_sources + return self._available_sources + + @property + def initial_parameter_values(self) -> dict[str, Any]: + if self._persisted_config is not None: + return self._persisted_config.config.params + return {} + + def start_action( + self, + selected_sources: list[str], + parameter_values: Any, + ) -> None: + """Create the plot and call the success callback with the result.""" + plot = self._plotting_controller.create_plot( + job_number=self._job_number, + source_names=selected_sources, + output_name=self._output_name, + plot_name=self._plot_spec.name, + params=parameter_values, + ) + self._success_callback(plot, selected_sources) diff --git a/src/ess/livedata/dashboard/reduction.py b/src/ess/livedata/dashboard/reduction.py index f6126e7aa..4501b9277 100644 --- a/src/ess/livedata/dashboard/reduction.py +++ b/src/ess/livedata/dashboard/reduction.py @@ -10,7 +10,7 @@ from .dashboard import DashboardBase from .widgets.log_producer_widget import LogProducerWidget -pn.extension('holoviews', 'modal', template='material') +pn.extension('holoviews', 'modal', notifications=True, template='material') hv.extension('bokeh') diff --git a/src/ess/livedata/dashboard/widgets/configuration_widget.py b/src/ess/livedata/dashboard/widgets/configuration_widget.py index 104c32a28..e99040e9d 100644 --- a/src/ess/livedata/dashboard/widgets/configuration_widget.py +++ b/src/ess/livedata/dashboard/widgets/configuration_widget.py @@ -113,11 +113,11 @@ def _on_aux_source_changed(self, event) -> None: break if widget_index is not None: - self._widget.objects = ( - self._widget.objects[:widget_index] - + [self._model_widget.widget] - + self._widget.objects[widget_index + 1 :] - ) + self._widget.objects = [ + *self._widget.objects[:widget_index], + self._model_widget.widget, + *self._widget.objects[widget_index + 1 :], + ] def _create_widget(self) -> pn.Column: """Create the main configuration widget.""" @@ -198,121 +198,91 @@ def clear_validation_errors(self) -> None: self._model_widget.clear_validation_errors() -class ConfigurationModal: - """Generic modal dialog for configuration.""" +class ConfigurationPanel: + """Reusable configuration panel with validation and action execution.""" def __init__( self, config: ConfigurationAdapter, - start_button_text: str = "Start", - success_callback: Callable[[], None] | None = None, - error_callback: Callable[[str], None] | None = None, ) -> None: """ - Initialize generic configuration modal. + Initialize configuration panel. Parameters ---------- config Configuration adapter providing data and callbacks - start_button_text - Text for the start button - success_callback - Called when action completes successfully - error_callback - Called when an error occurs """ self._config = config self._config_widget = ConfigurationWidget(config) - self._success_callback = success_callback - self._error_callback = error_callback self._error_pane = pn.pane.HTML("", sizing_mode='stretch_width') - self._modal = self._create_modal(start_button_text) self._logger = logging.getLogger(__name__) + self._panel = self._create_panel() - def _create_modal(self, start_button_text: str) -> pn.Modal: - """Create the modal dialog.""" - start_button = pn.widgets.Button(name=start_button_text, button_type="primary") - start_button.on_click(self._on_start_action) - - cancel_button = pn.widgets.Button(name="Cancel", button_type="light") - cancel_button.on_click(self._on_cancel) - - content = pn.Column( + def _create_panel(self) -> pn.Column: + """Create the configuration panel.""" + return pn.Column( self._config_widget.widget, self._error_pane, - pn.Row(pn.Spacer(), cancel_button, start_button, margin=(10, 0)), - ) - - modal = pn.Modal( - content, - name=f"Configure {self._config.title}", - margin=20, - width=800, - height=800, ) - # Watch for modal close events to clean up - modal.param.watch(self._on_modal_closed, 'open') - - return modal - - def _on_cancel(self, event) -> None: - """Handle cancel button click.""" - self._modal.open = False - - def _on_modal_closed(self, event) -> None: - """Handle modal being closed (cleanup).""" - if not event.new: # Modal was closed - # Remove modal from its parent container after a short delay - # to allow the close animation to complete - def cleanup(): - try: - if hasattr(self._modal, '_parent') and self._modal._parent: - self._modal._parent.remove(self._modal) - except Exception: # noqa: S110 - pass # Ignore cleanup errors - - pn.state.add_periodic_callback(cleanup, period=100, count=1) + def validate(self) -> tuple[bool, list[str]]: + """ + Validate configuration and show errors inline. - def _on_start_action(self, event) -> None: - """Handle start action button click.""" - # Clear previous errors + Returns + ------- + : + Tuple of (is_valid, list_of_error_messages) + """ self._config_widget.clear_validation_errors() self._error_pane.object = "" - # Validate configuration is_valid, errors = self._config_widget.validate_configuration() if not is_valid: self._show_validation_errors(errors) - return - # Execute the start action and handle any exceptions + return is_valid, errors + + def execute_action(self) -> bool: + """ + Execute the configuration action. + + Assumes validation has already passed. If validation is needed, + use validate() first or use validate_and_execute(). + + Returns + ------- + : + True if action succeeded, False if action raised error + """ try: self._config.start_action( self._config_widget.selected_sources, self._config_widget.parameter_values, ) except Exception as e: - # Log the full exception with stack trace self._logger.exception("Error starting '%s'", self._config.title) - - # Show user-friendly error message error_message = f"Error starting '{self._config.title}': {e!s}" self._show_action_error(error_message) + return False - # Notify error callback if provided - if self._error_callback: - self._error_callback(error_message) + return True - # Keep modal open so user can correct the issue or see the error - return + def validate_and_execute(self) -> bool: + """ + Convenience method: validate then execute if valid. - # Success - close modal and notify success callback - self._modal.open = False - if self._success_callback: - self._success_callback() + Returns + ------- + : + True if both validation and execution succeeded, False otherwise + """ + is_valid, _ = self.validate() + if not is_valid: + return False + return self.execute_action() def _show_validation_errors(self, errors: list[str]) -> None: """Show validation errors inline.""" @@ -339,6 +309,102 @@ def _show_action_error(self, message: str) -> None: ) self._error_pane.object = error_html + @property + def panel(self) -> pn.Column: + """Get the panel widget.""" + return self._panel + + +class ConfigurationModal: + """Modal wrapper around ConfigurationPanel with action buttons.""" + + def __init__( + self, + config: ConfigurationAdapter, + start_button_text: str = "Start", + success_callback: Callable[[], None] | None = None, + ) -> None: + """ + Initialize configuration modal. + + Parameters + ---------- + config + Configuration adapter providing data and callbacks + start_button_text + Text for the start button + success_callback + Called when action completes successfully + """ + self._config = config + self._success_callback = success_callback + + # Create panel + self._panel = ConfigurationPanel(config=config) + + # Create action buttons + self._start_button = pn.widgets.Button( + name=start_button_text, button_type="primary" + ) + self._start_button.on_click(self._on_start_clicked) + + self._cancel_button = pn.widgets.Button(name="Cancel", button_type="light") + self._cancel_button.on_click(self._on_cancel_clicked) + + # Create modal with panel + buttons + self._modal = self._create_modal() + + def _create_modal(self) -> pn.Modal: + """Create the modal dialog.""" + # Combine panel with buttons + content = pn.Column( + self._panel.panel, + pn.Row( + pn.Spacer(), + self._cancel_button, + self._start_button, + margin=(10, 0), + ), + ) + + modal = pn.Modal( + content, + name=f"Configure {self._config.title}", + margin=20, + width=800, + height=800, + ) + + # Watch for modal close events to clean up + modal.param.watch(self._on_modal_closed, 'open') + + return modal + + def _on_start_clicked(self, event) -> None: + """Handle start button click.""" + if self._panel.validate_and_execute(): + self._modal.open = False + if self._success_callback: + self._success_callback() + + def _on_cancel_clicked(self, event) -> None: + """Handle cancel button click.""" + self._modal.open = False + + def _on_modal_closed(self, event) -> None: + """Handle modal being closed (cleanup).""" + if not event.new: # Modal was closed + # Remove modal from its parent container after a short delay + # to allow the close animation to complete + def cleanup(): + try: + if hasattr(self._modal, '_parent') and self._modal._parent: + self._modal._parent.remove(self._modal) + except Exception: # noqa: S110 + pass # Ignore cleanup errors + + pn.state.add_periodic_callback(cleanup, period=100, count=1) + def show(self) -> None: """Show the modal dialog.""" self._modal.open = True diff --git a/src/ess/livedata/dashboard/widgets/job_plotter_selection_modal.py b/src/ess/livedata/dashboard/widgets/job_plotter_selection_modal.py new file mode 100644 index 000000000..3438c3d94 --- /dev/null +++ b/src/ess/livedata/dashboard/widgets/job_plotter_selection_modal.py @@ -0,0 +1,560 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +from __future__ import annotations + +import logging +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import pandas as pd +import panel as pn + +from ess.livedata.config.workflow_spec import JobNumber +from ess.livedata.dashboard.job_service import JobService +from ess.livedata.dashboard.plotting import PlotterSpec +from ess.livedata.dashboard.plotting_controller import PlottingController + +from ..plot_configuration_adapter import PlotConfigurationAdapter +from .configuration_widget import ConfigurationPanel +from .wizard import Wizard, WizardStep + + +@dataclass +class JobOutputSelection: + """Output from job/output selection step.""" + + job: JobNumber + output: str | None + + +@dataclass +class PlotterSelection: + """Output from plotter selection step.""" + + job: JobNumber + output: str | None + plot_name: str + + +@dataclass +class PlotResult: + """Output from configuration step (final result).""" + + plot: Any + selected_sources: list[str] + + +class JobOutputSelectionStep(WizardStep[None, JobOutputSelection]): + """ + Step 1: Job and output selection. + + This is mostly copied from the legacy PlotCreationWidget but is considered legacy. + The contents of this widget will be fully replaced. + """ + + def __init__( + self, + job_service: JobService, + ) -> None: + """ + Initialize job/output selection step. + + Parameters + ---------- + job_service: + Service for accessing job data + """ + super().__init__() + self._job_service = job_service + self._table = self._create_job_output_table() + self._selected_job: JobNumber | None = None + self._selected_output: str | None = None + + # Set up selection watcher + self._table.param.watch(self._on_table_selection_change, 'selection') + + @property + def name(self) -> str: + """Display name for this step.""" + return "Select Job and Output" + + @property + def description(self) -> str | None: + """Description text for this step.""" + return "Choose the job and output you want to visualize." + + def _create_job_output_table(self) -> pn.widgets.Tabulator: + """Create job and output selection table with grouping.""" + return pn.widgets.Tabulator( + name="Available Jobs and Outputs", + pagination='remote', + page_size=15, + sizing_mode='stretch_width', + selectable=1, + disabled=True, + height=400, + groupby=['workflow_name', 'job_number'], + configuration={ + 'columns': [ + {'title': 'Job Number', 'field': 'job_number', 'width': 100}, + {'title': 'Workflow', 'field': 'workflow_name', 'width': 100}, + {'title': 'Output Name', 'field': 'output_name', 'width': 200}, + {'title': 'Source Names', 'field': 'source_names', 'width': 500}, + ], + }, + ) + + def _update_job_output_table(self) -> None: + """Update the job and output table with current job data.""" + job_output_data = [] + for job_number, workflow_id in self._job_service.job_info.items(): + job_data = self._job_service.job_data.get(job_number, {}) + sources = list(job_data.keys()) + + # Get output names from any source (they all have the same outputs per + # backend guarantee) + output_names = set() + for source_data in job_data.values(): + if isinstance(source_data, dict): + output_names.update(source_data.keys()) + break # Since all sources have same outputs, we only check one + + # If no outputs found, create a row with empty output name + if not output_names: + job_output_data.append( + { + 'output_name': '', + 'source_names': ', '.join(sources), + 'workflow_name': workflow_id.name, + 'job_number': job_number.hex, + } + ) + else: + # Create one row per output name + job_output_data.extend( + [ + { + 'output_name': output_name, + 'source_names': ', '.join(sources), + 'workflow_name': workflow_id.name, + 'job_number': job_number.hex, + } + for output_name in sorted(output_names) + ] + ) + + if job_output_data: + df = pd.DataFrame(job_output_data) + else: + df = pd.DataFrame( + columns=['job_number', 'workflow_name', 'output_name', 'source_names'] + ) + self._table.value = df + + def _on_table_selection_change(self, event) -> None: + """Handle job and output selection change.""" + selection = event.new + if len(selection) != 1: + self._selected_job = None + self._selected_output = None + self._notify_ready_changed(False) + return + + # Get selected job number and output name from index + selected_row = selection[0] + job_number_str = self._table.value['job_number'].iloc[selected_row] + output_name = self._table.value['output_name'].iloc[selected_row] + + self._selected_job = JobNumber(job_number_str) + self._selected_output = output_name if output_name else None + self._notify_ready_changed(True) + + def is_valid(self) -> bool: + """Whether a valid job/output selection has been made.""" + return self._selected_job is not None + + def commit(self) -> JobOutputSelection | None: + """Commit the selected job and output.""" + if self._selected_job is None: + return None + return JobOutputSelection(job=self._selected_job, output=self._selected_output) + + def render_content(self) -> pn.Column: + """Render job/output selection table.""" + return pn.Column( + self._table, + sizing_mode='stretch_width', + ) + + def on_enter(self, input_data: None) -> None: + """Update table data when step becomes active.""" + self._update_job_output_table() + + +class PlotterSelectionStep(WizardStep[JobOutputSelection, PlotterSelection]): + """Step 2: Plotter type selection.""" + + def __init__( + self, + plotting_controller: PlottingController, + logger: logging.Logger, + ) -> None: + """ + Initialize plotter selection step. + + Parameters + ---------- + plotting_controller: + Controller for determining available plotters + logger: + Logger instance for error reporting + """ + super().__init__() + self._plotting_controller = plotting_controller + self._logger = logger + self._radio_group: pn.widgets.RadioButtonGroup | None = None + self._content_container = pn.Column(sizing_mode='stretch_width') + self._job_output: JobOutputSelection | None = None + self._selected_plot_name: str | None = None + + @property + def name(self) -> str: + """Display name for this step.""" + return "Select Plotter Type" + + @property + def description(self) -> str | None: + """Description text for this step.""" + return "Choose the type of plot you want to create." + + def is_valid(self) -> bool: + """Step is valid when a plotter has been selected.""" + return self._selected_plot_name is not None + + def commit(self) -> PlotterSelection | None: + """Commit the job, output, and selected plotter.""" + if self._job_output is None or self._selected_plot_name is None: + return None + return PlotterSelection( + job=self._job_output.job, + output=self._job_output.output, + plot_name=self._selected_plot_name, + ) + + def render_content(self) -> pn.Column: + """Render plotter selection radio buttons.""" + return self._content_container + + def on_enter(self, input_data: JobOutputSelection) -> None: + """Update available plotters when step becomes active.""" + self._job_output = input_data + self._update_plotter_selection() + + def _update_plotter_selection(self) -> None: + """Update plotter selection based on job and output selection.""" + self._content_container.clear() + + if self._job_output is None: + self._content_container.append(pn.pane.Markdown("*No job selected*")) + self._radio_group = None + self._notify_ready_changed(False) + return + + available_plots = self._plotting_controller.get_available_plotters( + self._job_output.job, self._job_output.output + ) + if available_plots: + self._create_radio_buttons(available_plots) + else: + self._content_container.append( + pn.pane.Markdown("*No plotters available for this selection*") + ) + self._radio_group = None + self._notify_ready_changed(False) + + def _create_radio_buttons(self, available_plots: dict[str, PlotterSpec]) -> None: + """Create radio button group for plotter selection.""" + # Build mapping from display title to plot name. + # RadioButtonGroup displays keys (titles) and stores values (plot names). + # Handle potential duplicate titles by making them unique. + self._plot_name_map = self._make_unique_title_mapping(available_plots) + options = self._plot_name_map + + # Select first option by default + initial_value = ( + next(iter(self._plot_name_map.keys())) if self._plot_name_map else None + ) + + self._radio_group = pn.widgets.RadioButtonGroup( + name="Plotter Type", + options=options, + value=initial_value, + button_type="primary", + button_style="solid", + sizing_mode='stretch_width', + ) + self._radio_group.param.watch(self._on_plotter_selection_change, 'value') + self._content_container.append(self._radio_group) + + # Initialize with the selected value + if initial_value is not None: + self._selected_plot_name = self._plot_name_map[initial_value] + self._notify_ready_changed(True) + + def _make_unique_title_mapping( + self, available_plots: dict[str, PlotterSpec] + ) -> dict[str, str]: + """Create mapping from unique display titles to internal plot names.""" + title_counts: dict[str, int] = {} + result: dict[str, str] = {} + + # Sort alphabetically by title for better UX + sorted_plots = sorted(available_plots.items(), key=lambda x: x[1].title) + + for name, spec in sorted_plots: + title = spec.title + count = title_counts.get(title, 0) + title_counts[title] = count + 1 + + # Make title unique if we've seen it before + unique_title = f"{title} ({count + 1})" if count > 0 else title + result[unique_title] = name + + return result + + def _on_plotter_selection_change(self, event) -> None: + """Handle plotter selection change.""" + if event.new is not None: + self._selected_plot_name = self._plot_name_map[event.new] + self._notify_ready_changed(True) + else: + self._selected_plot_name = None + self._notify_ready_changed(False) + + +class ConfigurationStep(WizardStep[PlotterSelection, PlotResult]): + """Step 3: Plot configuration.""" + + def __init__( + self, + job_service: JobService, + plotting_controller: PlottingController, + logger: logging.Logger, + ) -> None: + """ + Initialize configuration step. + + Parameters + ---------- + job_service: + Service for accessing job data + plotting_controller: + Controller for plot creation + logger: + Logger instance for error reporting + """ + super().__init__() + self._job_service = job_service + self._plotting_controller = plotting_controller + self._logger = logger + self._config_panel: ConfigurationPanel | None = None + self._panel_container = pn.Column(sizing_mode='stretch_width') + self._plotter_selection: PlotterSelection | None = None + # Track last configuration to detect when panel needs recreation + self._last_job: JobNumber | None = None + self._last_output: str | None = None + self._last_plot_name: str | None = None + # Store result from callback + self._last_plot_result: PlotResult | None = None + + @property + def name(self) -> str: + """Display name for this step.""" + return "Configure Plot" + + def is_valid(self) -> bool: + """Step is valid when configuration is valid.""" + if self._config_panel is None: + return False + is_valid, _ = self._config_panel.validate() + return is_valid + + def commit(self) -> PlotResult | None: + """Commit the plot configuration and create the plot.""" + if self._config_panel is None or self._plotter_selection is None: + return None + + # Clear previous result + self._last_plot_result = None + + # Execute action (which calls adapter, which calls our callback) + success = self._config_panel.execute_action() + if not success: + return None + + # Result was captured by callback + return self._last_plot_result + + def render_content(self) -> pn.Column: + """Render configuration panel.""" + return self._panel_container + + def on_enter(self, input_data: PlotterSelection) -> None: + """Create or recreate configuration panel when selection changes.""" + self._plotter_selection = input_data + + # Check if the configuration has changed + if ( + input_data.job != self._last_job + or input_data.output != self._last_output + or input_data.plot_name != self._last_plot_name + ): + # Recreate panel with new configuration + self._create_config_panel() + # Track new values + self._last_job = input_data.job + self._last_output = input_data.output + self._last_plot_name = input_data.plot_name + + def _create_config_panel(self) -> None: + """Create the configuration panel for the selected plotter.""" + if self._plotter_selection is None: + return + + job_data = self._job_service.job_data.get(self._plotter_selection.job, {}) + available_sources = list(job_data.keys()) + + if not available_sources: + self._show_error('No sources available for selected job') + return + + try: + plot_spec = self._plotting_controller.get_spec( + self._plotter_selection.plot_name + ) + except Exception as e: + self._logger.exception("Error getting plot spec") + self._show_error(f'Error getting plot spec: {e}') + return + + config_adapter = PlotConfigurationAdapter( + job_number=self._plotter_selection.job, + output_name=self._plotter_selection.output, + plot_spec=plot_spec, + available_sources=available_sources, + plotting_controller=self._plotting_controller, + success_callback=self._on_plot_created, + ) + + self._config_panel = ConfigurationPanel(config=config_adapter) + + self._panel_container.clear() + self._panel_container.append(self._config_panel.panel) + + def _on_plot_created(self, plot, selected_sources: list[str]) -> None: + """Callback from adapter - store result for execute() to return.""" + self._last_plot_result = PlotResult( + plot=plot, selected_sources=selected_sources + ) + + def _show_error(self, message: str) -> None: + """Display an error notification.""" + if pn.state.notifications is not None: + pn.state.notifications.error(message, duration=3000) + + +class JobPlotterSelectionModal: + """ + Three-step wizard modal for selecting job/output, plotter type, and configuration. + + The modal guides the user through: + 1. Job and output selection from available data + 2. Plotter type selection based on compatibility with selected job/output + 3. Plotter configuration (source selection and parameters) + + Parameters + ---------- + job_service: + Service for accessing job data and information + plotting_controller: + Controller for determining available plotters + success_callback: + Called with (plot, selected_sources) when user completes configuration + cancel_callback: + Called when modal is closed or cancelled + """ + + def __init__( + self, + job_service: JobService, + plotting_controller: PlottingController, + success_callback: Callable, + cancel_callback: Callable[[], None], + ) -> None: + self._success_callback = success_callback + self._cancel_callback = cancel_callback + self._logger = logging.getLogger(__name__) + + # Create steps + step1 = JobOutputSelectionStep(job_service=job_service) + + step2 = PlotterSelectionStep( + plotting_controller=plotting_controller, logger=self._logger + ) + + step3 = ConfigurationStep( + job_service=job_service, + plotting_controller=plotting_controller, + logger=self._logger, + ) + + # Create wizard + self._wizard = Wizard( + steps=[step1, step2, step3], + on_complete=self._on_wizard_complete, + on_cancel=self._on_wizard_cancel, + action_button_label="Create Plot", + ) + + # Create modal wrapping the wizard + self._modal = pn.Modal( + self._wizard.render(), + name="Select Job and Plotter", + margin=20, + width=900, + height=700, + ) + + # Watch for modal close events (X button or ESC key). + # Panel's Modal widget uses 'open' as a boolean state property: + # when it transitions to False, the modal is closed. + self._modal.param.watch(self._on_modal_closed, 'open') + + def _on_wizard_complete(self, result: PlotResult) -> None: + """Handle wizard completion - close modal and call success callback.""" + self._modal.open = False + self._success_callback(result.plot, result.selected_sources) + + def _on_wizard_cancel(self) -> None: + """Handle wizard cancellation - close modal and call cancel callback.""" + self._modal.open = False + self._cancel_callback() + + def _on_modal_closed(self, event) -> None: + """Handle modal being closed via X button or ESC key.""" + if not event.new: # Modal was closed + # Only call cancel callback if wizard wasn't already completed/cancelled + if not self._wizard.is_finished(): + self._cancel_callback() + + def show(self) -> None: + """Show the modal dialog.""" + # Reset wizard and show modal + self._wizard.reset() + self._modal.open = True + + @property + def modal(self) -> pn.Modal: + """Get the modal widget.""" + return self._modal diff --git a/src/ess/livedata/dashboard/widgets/plot_creation_widget.py b/src/ess/livedata/dashboard/widgets/plot_creation_widget.py index d1b65a3ec..3ebefeca1 100644 --- a/src/ess/livedata/dashboard/widgets/plot_creation_widget.py +++ b/src/ess/livedata/dashboard/widgets/plot_creation_widget.py @@ -2,98 +2,20 @@ # Copyright (c) 2025 Scipp contributors (https://github.com/scipp) from __future__ import annotations -from typing import Any - import holoviews as hv import pandas as pd import panel as pn -import pydantic from ess.livedata.config.workflow_spec import JobNumber -from ess.livedata.dashboard.configuration_adapter import ConfigurationAdapter from ess.livedata.dashboard.job_controller import JobController from ess.livedata.dashboard.job_service import JobService -from ess.livedata.dashboard.plotting import PlotterSpec from ess.livedata.dashboard.plotting_controller import PlottingController from ess.livedata.dashboard.workflow_controller import WorkflowController +from ..plot_configuration_adapter import PlotConfigurationAdapter from .configuration_widget import ConfigurationModal from .job_status_widget import JobStatusListWidget - - -class PlotConfigurationAdapter(ConfigurationAdapter): - """Adapter for plot configuration modal.""" - - def __init__( - self, - job_number: JobNumber, - output_name: str | None, - plot_spec: PlotterSpec, - available_sources: list[str], - plotting_controller: PlottingController, - success_callback, - ): - self._job_number = job_number - self._output_name = output_name - self._plot_spec = plot_spec - self._available_sources = available_sources - self._plotting_controller = plotting_controller - self._success_callback = success_callback - - self._persisted_config = ( - self._plotting_controller.get_persistent_plotter_config( - job_number=self._job_number, - output_name=self._output_name, - plot_name=self._plot_spec.name, - ) - ) - - @property - def title(self) -> str: - return f"Configure {self._plot_spec.title}" - - @property - def description(self) -> str: - return self._plot_spec.description - - def model_class(self) -> type[pydantic.BaseModel] | None: - return self._plot_spec.params - - @property - def source_names(self) -> list[str]: - return self._available_sources - - @property - def initial_source_names(self) -> list[str]: - if self._persisted_config is not None: - # Filter persisted source names to only include those still available - persisted_sources = [ - name - for name in self._persisted_config.source_names - if name in self._available_sources - ] - return persisted_sources if persisted_sources else self._available_sources - return self._available_sources - - @property - def initial_parameter_values(self) -> dict[str, Any]: - if self._persisted_config is not None: - return self._persisted_config.config.params - return {} - - def start_action( - self, - selected_sources: list[str], - parameter_values: Any, - ) -> None: - plot = self._plotting_controller.create_plot( - job_number=self._job_number, - source_names=selected_sources, - output_name=self._output_name, - plot_name=self._plot_spec.name, - params=parameter_values, - ) - self._success_callback(plot, selected_sources) +from .plot_grid_tab import PlotGridTab class PlotCreationWidget: @@ -133,6 +55,12 @@ def __init__( self._job_status_widget = JobStatusListWidget( job_service=job_service, job_controller=job_controller ) + # PlotCreationWidget is legacy; PlotGridTab placement here is temporary + self._plot_grid_tab = PlotGridTab( + job_service=job_service, + job_controller=job_controller, + plotting_controller=plotting_controller, + ) self._job_output_table = self._create_job_output_table() self._plot_selector = self._create_plot_selector() self._create_button = self._create_plot_button() @@ -154,6 +82,7 @@ def __init__( self._main_tabs = pn.Tabs( ("Jobs", self._job_status_widget.panel()), ("Create Plot", self._creation_tab), + ("Plot Grid", self._plot_grid_tab.widget), ("Plots", self._plot_tabs), sizing_mode='stretch_width', closable=False, diff --git a/src/ess/livedata/dashboard/widgets/plot_grid.py b/src/ess/livedata/dashboard/widgets/plot_grid.py new file mode 100644 index 000000000..1415dc154 --- /dev/null +++ b/src/ess/livedata/dashboard/widgets/plot_grid.py @@ -0,0 +1,458 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import holoviews as hv +import panel as pn + + +@dataclass(frozen=True) +class _CellStyles: + """Styling constants for PlotGrid cells.""" + + # Colors + PRIMARY_BLUE = '#007bff' + LIGHT_GRAY = '#dee2e6' + LIGHT_RED = '#ffe6e6' + LIGHT_BLUE = '#e7f3ff' + VERY_LIGHT_GRAY = '#f8f9fa' + MEDIUM_GRAY = '#6c757d' + MUTED_GRAY = '#adb5bd' + DANGER_RED = '#dc3545' + + # Dimensions + CELL_MIN_HEIGHT_PX = 100 + CELL_BORDER_WIDTH_NORMAL = 1 + CELL_BORDER_WIDTH_HIGHLIGHTED = 3 + CELL_MARGIN = 2 + CLOSE_BUTTON_SIZE = 40 + CLOSE_BUTTON_TOP_OFFSET = '5px' + CLOSE_BUTTON_RIGHT_OFFSET = '5px' + CLOSE_BUTTON_Z_INDEX = '1000' + + # Typography + FONT_SIZE_LARGE = '24px' + FONT_SIZE_CLOSE_BUTTON = '20px' + + +def _normalize_region(r1: int, c1: int, r2: int, c2: int) -> tuple[int, int, int, int]: + """ + Normalize region coordinates to (row_start, col_start, row_end, col_end). + + Parameters + ---------- + r1: + First row coordinate. + c1: + First column coordinate. + r2: + Second row coordinate. + c2: + Second column coordinate. + + Returns + ------- + : + Tuple of (row_start, col_start, row_end, col_end) where + row_start <= row_end and col_start <= col_end. + """ + return min(r1, r2), min(c1, c2), max(r1, r2), max(c1, c2) + + +def _calculate_region_span( + row_start: int, row_end: int, col_start: int, col_end: int +) -> tuple[int, int]: + """ + Calculate the span dimensions of a region. + + Parameters + ---------- + row_start: + Starting row (inclusive). + row_end: + Ending row (inclusive). + col_start: + Starting column (inclusive). + col_end: + Ending column (inclusive). + + Returns + ------- + : + Tuple of (row_span, col_span). + """ + return row_end - row_start + 1, col_end - col_start + 1 + + +def _format_region_label(row_span: int, col_span: int) -> str: + """ + Format a label describing region dimensions. + + Parameters + ---------- + row_span: + Number of rows in the region. + col_span: + Number of columns in the region. + + Returns + ------- + : + Formatted label string like "Click for 2x3 plot". + """ + return f'Click for {row_span}x{col_span} plot' + + +class PlotGrid: + """ + A grid widget for displaying multiple plots in a customizable layout. + + The PlotGrid allows users to select rectangular regions by clicking cells + and insert HoloViews plots into those regions. Each plot can be removed + via a close button. + + Parameters + ---------- + nrows: + Number of rows in the grid. + ncols: + Number of columns in the grid. + plot_request_callback: + Callback invoked when a region is selected. This callback will be + called asynchronously and should not return a value. The plot should + be inserted later via `insert_plot_deferred()`. + """ + + def __init__( + self, + nrows: int, + ncols: int, + plot_request_callback: Callable[[], None], + ) -> None: + self._nrows = nrows + self._ncols = ncols + self._plot_request_callback = plot_request_callback + + # State tracking + self._occupied_cells: dict[tuple[int, int, int, int], pn.Column] = {} + self._first_click: tuple[int, int] | None = None + self._highlighted_cell: pn.pane.HTML | None = None + self._pending_selection: tuple[int, int, int, int] | None = None + + # Create the grid + self._grid = pn.GridSpec( + sizing_mode='stretch_both', name='PlotGrid', min_height=600 + ) + + # Initialize empty cells + self._initialize_empty_cells() + + def _initialize_empty_cells(self) -> None: + """Populate the grid with empty clickable cells.""" + with pn.io.hold(): + for row in range(self._nrows): + for col in range(self._ncols): + self._grid[row, col] = self._create_empty_cell(row, col) + + def _create_empty_cell( + self, + row: int, + col: int, + highlighted: bool = False, + disabled: bool = False, + label: str | None = None, + large_font: bool = False, + ) -> pn.Column: + """Create an empty cell with placeholder text and click handler.""" + border_color = ( + _CellStyles.PRIMARY_BLUE if highlighted else _CellStyles.LIGHT_GRAY + ) + border_width = ( + _CellStyles.CELL_BORDER_WIDTH_HIGHLIGHTED + if highlighted + else _CellStyles.CELL_BORDER_WIDTH_NORMAL + ) + border_style = 'dashed' if highlighted else 'solid' + + if disabled: + background_color = _CellStyles.LIGHT_RED + text_color = _CellStyles.MUTED_GRAY + elif highlighted: + background_color = _CellStyles.LIGHT_BLUE + text_color = _CellStyles.MEDIUM_GRAY + else: + background_color = _CellStyles.VERY_LIGHT_GRAY + text_color = _CellStyles.MEDIUM_GRAY + + # Determine button label + if label is None: + label = '' if disabled else 'Click to add plot' + + # Font size - larger during selection process + # Use stylesheets to target the button element directly + if large_font: + stylesheets = [ + f""" + button {{ + font-size: {_CellStyles.FONT_SIZE_LARGE}; + font-weight: bold; + }} + """ + ] + else: + stylesheets = [] + + # Create a button that fills the cell + button = pn.widgets.Button( + name=label, + sizing_mode='stretch_both', + button_type='light', + disabled=disabled, + styles={ + 'background-color': background_color, + 'border': f'{border_width}px {border_style} {border_color}', + 'color': text_color, + 'min-height': f'{_CellStyles.CELL_MIN_HEIGHT_PX}px', + }, + stylesheets=stylesheets, + margin=_CellStyles.CELL_MARGIN, + ) + + # Attach click handler (even if disabled, for consistency) + def on_click(event: Any) -> None: + if not disabled: + self._on_cell_click(row, col) + + button.on_click(on_click) + + # Wrap in Column to allow for future expansion + return pn.Column(button, sizing_mode='stretch_both', margin=0) + + def _on_cell_click(self, row: int, col: int) -> None: + """Handle cell click for region selection.""" + if self._first_click is None: + # First click - start selection + self._first_click = (row, col) + self._refresh_all_cells() + else: + # Second click - complete selection + r1, c1 = self._first_click + r2, c2 = row, col + + # Normalize to get top-left and bottom-right corners + row_start, col_start, row_end, col_end = _normalize_region(r1, c1, r2, c2) + + # Calculate span + row_span, col_span = _calculate_region_span( + row_start, row_end, col_start, col_end + ) + + # Store selection for plot insertion + self._pending_selection = (row_start, col_start, row_span, col_span) + + # Clear selection highlight + self._clear_selection() + + # Request plot from callback (async, no return value) + self._plot_request_callback() + + def _is_cell_occupied(self, row: int, col: int) -> bool: + """Check if a specific cell is occupied by a plot.""" + for r, c, r_span, c_span in self._occupied_cells: + if r <= row < r + r_span and c <= col < c + c_span: + return True + return False + + def _is_region_available( + self, row_start: int, col_start: int, row_end: int, col_end: int + ) -> bool: + """Check if an entire region is available for plot insertion.""" + for row in range(row_start, row_end + 1): + for col in range(col_start, col_end + 1): + if self._is_cell_occupied(row, col): + return False + return True + + def _refresh_all_cells(self) -> None: + """Refresh all empty cells based on current selection state.""" + with pn.io.hold(): + for row in range(self._nrows): + for col in range(self._ncols): + if not self._is_cell_occupied(row, col): + # Delete the old cell first to avoid overlap warnings + try: + del self._grid[row, col] + except (KeyError, IndexError): + # Cell might not exist yet (during initialization) + pass + self._grid[row, col] = self._get_cell_for_state(row, col) + + def _get_cell_for_state(self, row: int, col: int) -> pn.Column: + """Get the appropriate cell widget based on current selection state.""" + if self._first_click is None: + # No selection in progress + return self._create_empty_cell(row, col) + + r1, c1 = self._first_click + + if row == r1 and col == c1: + # This is the first clicked cell - highlight it + return self._create_empty_cell( + row, + col, + highlighted=True, + label='Click again for 1x1 plot', + large_font=True, + ) + + # Check if this cell would create a valid region + row_start, col_start, row_end, col_end = _normalize_region(r1, c1, row, col) + + # Check if region is valid + is_valid = self._is_region_available(row_start, col_start, row_end, col_end) + + if not is_valid: + # Disable this cell + return self._create_empty_cell(row, col, disabled=True, large_font=True) + + # Calculate dimensions and format label + row_span, col_span = _calculate_region_span( + row_start, row_end, col_start, col_end + ) + label = _format_region_label(row_span, col_span) + + return self._create_empty_cell(row, col, label=label, large_font=True) + + def _clear_selection(self) -> None: + """Clear the current selection state.""" + self._first_click = None + self._highlighted_cell = None + self._refresh_all_cells() + + def _insert_plot(self, plot: hv.DynamicMap) -> None: + """Insert a plot into the grid at the pending selection.""" + if self._pending_selection is None: + return + + row, col, row_span, col_span = self._pending_selection + + # Create plot pane using the .layout pattern for DynamicMaps + plot_pane_wrapper = pn.pane.HoloViews(plot, sizing_mode='stretch_both') + plot_pane = plot_pane_wrapper.layout + + # Create close button with stylesheets for proper styling override + close_button = pn.widgets.Button( + name='\u00d7', # "X" multiplication sign + width=_CellStyles.CLOSE_BUTTON_SIZE, + height=_CellStyles.CLOSE_BUTTON_SIZE, + button_type='light', + sizing_mode='fixed', + margin=(_CellStyles.CELL_MARGIN, _CellStyles.CELL_MARGIN), + styles={ + 'position': 'absolute', + 'top': _CellStyles.CLOSE_BUTTON_TOP_OFFSET, + 'right': _CellStyles.CLOSE_BUTTON_RIGHT_OFFSET, + 'z-index': _CellStyles.CLOSE_BUTTON_Z_INDEX, + }, + stylesheets=[ + f""" + button {{ + background-color: transparent !important; + border: none !important; + color: {_CellStyles.DANGER_RED} !important; + font-weight: bold !important; + font-size: {_CellStyles.FONT_SIZE_CLOSE_BUTTON} !important; + padding: 0 !important; + }} + button:hover {{ + background-color: rgba(220, 53, 69, 0.1) !important; + }} + """ + ], + ) + + def on_close(event: Any) -> None: + self._remove_plot(row, col, row_span, col_span) + + close_button.on_click(on_close) + + container = pn.Column( + close_button, + plot_pane, + sizing_mode='stretch_both', + margin=2, + styles={'position': 'relative'}, + ) + + with pn.io.hold(): + # Delete existing cells in the region to avoid overlap warnings + for r in range(row, row + row_span): + for c in range(col, col + col_span): + try: + del self._grid[r, c] + except (KeyError, IndexError): + pass + + # Insert into grid + self._grid[row : row + row_span, col : col + col_span] = container + + # Track occupation + self._occupied_cells[(row, col, row_span, col_span)] = container + + # Clear pending selection + self._pending_selection = None + + def _remove_plot(self, row: int, col: int, row_span: int, col_span: int) -> None: + """Remove a plot from the grid and restore empty cells.""" + # Remove from tracking + key = (row, col, row_span, col_span) + if key in self._occupied_cells: + del self._occupied_cells[key] + + # Restore empty cells + with pn.io.hold(): + for r in range(row, row + row_span): + for c in range(col, col + col_span): + self._grid[r, c] = self._create_empty_cell(r, c) + + def _show_error(self, message: str) -> None: + """Display a temporary error notification.""" + if pn.state.notifications is not None: + pn.state.notifications.error(message, duration=3000) + + def insert_plot_deferred(self, plot: hv.DynamicMap) -> None: + """ + Complete plot insertion after async workflow. + + This method should be called after the plot request callback completes + successfully. It inserts the plot at the pending selection location + and clears the in-flight state. + + Parameters + ---------- + plot: + The HoloViews DynamicMap to insert into the grid. + """ + if self._pending_selection is None: + self._show_error('No pending selection to insert plot into') + return + + self._insert_plot(plot) + + def cancel_pending_selection(self) -> None: + """ + Abort the current plot creation workflow and reset state. + + This method should be called when the plot request callback is cancelled + or fails. It clears the pending selection. + """ + self._pending_selection = None + + @property + def panel(self) -> pn.viewable.Viewable: + """Get the Panel viewable object for this widget.""" + return self._grid diff --git a/src/ess/livedata/dashboard/widgets/plot_grid_tab.py b/src/ess/livedata/dashboard/widgets/plot_grid_tab.py new file mode 100644 index 000000000..3806fa1e2 --- /dev/null +++ b/src/ess/livedata/dashboard/widgets/plot_grid_tab.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +from __future__ import annotations + +import holoviews as hv +import panel as pn + +from ess.livedata.dashboard.job_controller import JobController +from ess.livedata.dashboard.job_service import JobService +from ess.livedata.dashboard.plotting_controller import PlottingController + +from .job_plotter_selection_modal import JobPlotterSelectionModal +from .plot_grid import PlotGrid + + +class PlotGridTab: + """Tab widget that orchestrates PlotGrid with modal workflow for plot creation.""" + + def __init__( + self, + *, + job_service: JobService, + job_controller: JobController, + plotting_controller: PlottingController, + ) -> None: + """ + Initialize PlotGridTab. + + Parameters + ---------- + job_service: + Service for accessing job data + job_controller: + Controller for job operations + plotting_controller: + Controller for creating plotters + """ + self._job_service = job_service + self._job_controller = job_controller + self._plotting_controller = plotting_controller + + # Create PlotGrid (3x3 fixed for now; configurable tabs may be added later) + self._plot_grid = PlotGrid( + nrows=3, ncols=3, plot_request_callback=self._on_plot_requested + ) + + # Modal container for lifecycle management. + # Using pn.Row with height=0 ensures the modal is part of the component tree + # (required for rendering) but doesn't compete with the grid for vertical space. + # The modal itself renders as an overlay when opened. + self._modal_container = pn.Row(height=0, sizing_mode='stretch_width') + + # State for tracking current workflow + self._current_modal: JobPlotterSelectionModal | None = None + + # Create main widget - grid with zero-height modal container + self._widget = pn.Column( + self._plot_grid.panel, + self._modal_container, + sizing_mode='stretch_both', + ) + + def _on_plot_requested(self) -> None: + """Handle plot request from PlotGrid (user completed region selection).""" + # Create and show JobPlotterSelectionModal (now includes all 3 steps) + self._current_modal = JobPlotterSelectionModal( + job_service=self._job_service, + plotting_controller=self._plotting_controller, + success_callback=self._on_plot_created, + cancel_callback=self._on_modal_cancelled, + ) + + # Add modal to zero-height container so it renders but doesn't affect layout + self._modal_container.clear() + self._modal_container.append(self._current_modal.modal) + self._current_modal.show() + + def _on_plot_created( + self, plot: hv.DynamicMap, selected_sources: list[str] + ) -> None: + """Handle successful plot creation from configuration modal.""" + # Clear references BEFORE inserting plot to prevent cancellation on modal close + self._current_modal = None + + # Insert plot into grid using deferred insertion + self._plot_grid.insert_plot_deferred(plot) + + def _on_modal_cancelled(self) -> None: + """Handle modal cancellation.""" + # Cancel pending selection in PlotGrid + self._plot_grid.cancel_pending_selection() + + # Clear references + self._current_modal = None + + @property + def widget(self) -> pn.Column: + """Get the Panel widget.""" + return self._widget diff --git a/src/ess/livedata/dashboard/widgets/wizard.py b/src/ess/livedata/dashboard/widgets/wizard.py new file mode 100644 index 000000000..8758f8eae --- /dev/null +++ b/src/ess/livedata/dashboard/widgets/wizard.py @@ -0,0 +1,318 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +"""Generic multi-step wizard component.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any, Generic, TypeVar + +import panel as pn + +TInput = TypeVar('TInput') +TOutput = TypeVar('TOutput') + + +class WizardStep(ABC, Generic[TInput, TOutput]): + """ + Base class for wizard step components. + + Each step transforms input from the previous step into output for the next step. + The first step receives None as input. + + Type Parameters + ---------------- + TInput: + Type of input data from previous step (None for first step) + TOutput: + Type of output data to pass to next step + """ + + def __init__(self) -> None: + self._on_ready_changed: Callable[[bool], None] | None = None + self._step_number: int | None = None + + def on_ready_changed(self, callback: Callable[[bool], None]) -> None: + """Register callback to be notified when ready state changes.""" + self._on_ready_changed = callback + + def _notify_ready_changed(self, is_ready: bool) -> None: + """Notify wizard of ready state change.""" + if self._on_ready_changed: + self._on_ready_changed(is_ready) + + @property + @abstractmethod + def name(self) -> str: + """Display name for this step (e.g., 'Select Job and Output').""" + + @property + def description(self) -> str | None: + """Optional description text shown below the step header.""" + return None + + def render(self, step_number: int) -> pn.Column: + """ + Render the step's UI with automatic header generation. + + Parameters + ---------- + step_number: + The 1-based step number to display in the header + + Returns + ------- + : + Column containing header and step content + """ + self._step_number = step_number + + # Build header + header_parts = [f"
{self.description}
") + + return pn.Column( + pn.pane.HTML("".join(header_parts)), + self.render_content(), + sizing_mode='stretch_width', + ) + + @abstractmethod + def render_content(self) -> pn.Column | pn.viewable.Viewable: + """Render the step's content (without header).""" + + @abstractmethod + def is_valid(self) -> bool: + """Whether step data allows advancement.""" + + @abstractmethod + def commit(self) -> TOutput | None: + """ + Commit this step's data for the pipeline. + + Called when the user advances from this step. This method should package + the step's current state into output data for the next step. For the final + step, this may also trigger side effects (e.g., creating a plot). + + Returns + ------- + : + Output data to pass to next step, or None if commit failed + """ + + @abstractmethod + def on_enter(self, input_data: TInput) -> None: + """ + Called when step becomes active. + + Parameters + ---------- + input_data: + Output from the previous step (None for first step) + """ + + +class Wizard: + """ + Generic multi-step wizard component. + + The wizard manages navigation between steps, threading data from each step's + execution to the next step's input. Each step transforms input data to output + data, creating a pipeline of transformations. + + Parameters + ---------- + steps: + List of wizard steps to display in sequence + on_complete: + Called with final step's output when wizard completes successfully + on_cancel: + Called when wizard is cancelled + action_button_label: + Optional label for the action button on the last step (e.g., "Create Plot"). + If None, no action button is shown on the last step. + """ + + def __init__( + self, + steps: list[WizardStep[Any, Any]], + on_complete: Callable[[Any], None], + on_cancel: Callable[[], None], + action_button_label: str | None = None, + ) -> None: + self._steps = steps + self._on_complete = on_complete + self._on_cancel = on_cancel + self._action_button_label = action_button_label + + # State tracking + self._current_step_index = 0 + self._finished = False + self._step_results: list[Any] = [] # Results from executed steps + + # Navigation buttons + self._back_button = pn.widgets.Button( + name="Back", + button_type="light", + sizing_mode='fixed', + width=100, + ) + self._back_button.on_click(self._on_back_clicked) + + self._next_button = pn.widgets.Button( + name="Next", + button_type="primary", + sizing_mode='fixed', + width=120, + ) + self._next_button.on_click(self._on_next_clicked) + + self._cancel_button = pn.widgets.Button( + name="Cancel", + button_type="light", + sizing_mode='fixed', + width=100, + ) + self._cancel_button.on_click(self._on_cancel_clicked) + + # Content container + self._content = pn.Column(sizing_mode='stretch_width') + + def advance(self) -> None: + """Move to next step if current step is valid.""" + if not self._current_step.is_valid(): + return + + # Commit current step and get result + result = self._current_step.commit() + if result is None: + return # Commit failed, don't advance + + # Store result for this step + if self._current_step_index < len(self._step_results): + self._step_results[self._current_step_index] = result + else: + self._step_results.append(result) + + if self._current_step_index < len(self._steps) - 1: + # Move to next step + self._current_step_index += 1 + self._update_content() + else: + # Last step completed - pass result to completion callback + self.complete(result) + + def back(self) -> None: + """Go to previous step.""" + if self._current_step_index > 0: + self._current_step_index -= 1 + self._update_content() + + def complete(self, result: Any) -> None: + """ + Complete wizard successfully. + + Parameters + ---------- + result: + Output from the final step + """ + self._finished = True + self._on_complete(result) + + def cancel(self) -> None: + """Cancel wizard.""" + self._finished = True + self._on_cancel() + + def is_finished(self) -> bool: + """Whether wizard has completed or been cancelled.""" + return self._finished + + def reset(self) -> None: + """Reset wizard to first step.""" + self._current_step_index = 0 + self._finished = False + self._step_results = [] + self._update_content() + + def render(self) -> pn.Column: + """Render the wizard content.""" + return self._content + + @property + def _current_step(self) -> WizardStep[Any, Any]: + """Get the current step.""" + return self._steps[self._current_step_index] + + @property + def _is_first_step(self) -> bool: + """Check if on first step.""" + return self._current_step_index == 0 + + @property + def _is_last_step(self) -> bool: + """Check if on last step.""" + return self._current_step_index == len(self._steps) - 1 + + def _on_step_ready_changed(self, is_ready: bool) -> None: + """Handle step ready state change.""" + self._next_button.disabled = not is_ready + + def _update_content(self) -> None: + """Update modal content for current step.""" + self._current_step.on_ready_changed(self._on_step_ready_changed) + + # Get input for this step: None for first step, otherwise previous step's result + if self._current_step_index == 0: + input_data = None + else: + input_data = self._step_results[self._current_step_index - 1] + + self._current_step.on_enter(input_data) + self._render_step() + + def _render_step(self) -> None: + """Render current step with navigation buttons.""" + self._content.clear() + + # Add step content with 1-based step number + self._content.append(self._current_step.render(self._current_step_index + 1)) + + # Add vertical spacer to push buttons to bottom + self._content.append(pn.layout.VSpacer()) + + # Update next button state based on step validity + self._next_button.disabled = not self._current_step.is_valid() + + # Build navigation row with standard order: Cancel | Spacer | Back | Next + nav_buttons = [self._cancel_button, pn.layout.HSpacer()] + + if not self._is_first_step: + nav_buttons.append(self._back_button) + + # Show Next/Action button based on step + if self._is_last_step and self._action_button_label: + self._next_button.name = self._action_button_label + nav_buttons.append(self._next_button) + elif not self._is_last_step: + self._next_button.name = "Next" + nav_buttons.append(self._next_button) + + self._content.append( + pn.Row(*nav_buttons, sizing_mode='stretch_width', margin=(10, 0)) + ) + + def _on_next_clicked(self, event) -> None: + """Handle next button click.""" + self.advance() + + def _on_back_clicked(self, event) -> None: + """Handle back button click.""" + self.back() + + def _on_cancel_clicked(self, event) -> None: + """Handle cancel button click.""" + self.cancel() diff --git a/tests/dashboard/widgets/plot_grid_test.py b/tests/dashboard/widgets/plot_grid_test.py new file mode 100644 index 000000000..907af51ec --- /dev/null +++ b/tests/dashboard/widgets/plot_grid_test.py @@ -0,0 +1,405 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +import holoviews as hv +import numpy as np +import panel as pn +import pytest + +from ess.livedata.dashboard.widgets.plot_grid import PlotGrid + + +class FakeCallback: + """Fake callback for testing plot requests.""" + + def __init__(self, side_effect: Exception | None = None) -> None: + self.call_count = 0 + self.calls: list = [] + self._side_effect = side_effect + + def __call__(self, *args, **kwargs) -> None: + self.call_count += 1 + self.calls.append((args, kwargs)) + if self._side_effect is not None: + raise self._side_effect + + def reset(self) -> None: + self.call_count = 0 + self.calls.clear() + + def assert_called_once(self) -> None: + assert self.call_count == 1, f"Expected 1 call, got {self.call_count}" + + def assert_not_called(self) -> None: + assert self.call_count == 0, f"Expected 0 calls, got {self.call_count}" + + +@pytest.fixture +def mock_plot() -> hv.DynamicMap: + """Create a mock HoloViews DynamicMap for testing.""" + + def create_curve(x_range): + x = np.linspace(x_range[0], x_range[1], 100) + y = np.sin(x) + return hv.Curve((x, y)) + + return hv.DynamicMap(create_curve, kdims=['x_range']).redim.range(x_range=(0, 10)) + + +@pytest.fixture +def mock_callback() -> FakeCallback: + """Create a fake callback for plot requests.""" + return FakeCallback() + + +def get_cell_button(grid: PlotGrid, row: int, col: int) -> pn.widgets.Button | None: + """ + Get the empty cell button widget from a grid cell. + + Returns None if cell is not a simple empty cell (e.g., contains a plot). + """ + try: + cell = grid.panel[row, col] # type: ignore[index] + if isinstance(cell, pn.Column) and len(cell) > 0: + first_item = cell[0] + if isinstance(first_item, pn.widgets.Button): + # Check if this is the close button (multiplication sign character) + if first_item.name == '\u00d7': + # This is a plot cell with a close button + return None + # This is an empty cell button + return first_item + except (KeyError, IndexError): + pass + return None + + +def simulate_click(grid: PlotGrid, row: int, col: int) -> None: + """Simulate a user clicking on a grid cell by triggering button's click event. + + This simulates a standard left-click interaction with the button. + """ + button = get_cell_button(grid, row, col) + if button is None: + msg = f"Cannot click cell ({row}, {col}): no clickable button found" + raise ValueError(msg) + if button.disabled: # type: ignore[truthy-bool] + msg = f"Cannot click cell ({row}, {col}): button is disabled" + raise ValueError(msg) + # Trigger the click event by incrementing clicks parameter + button.param.trigger('clicks') + + +def is_cell_occupied(grid: PlotGrid, row: int, col: int) -> bool: + """ + Check if a cell contains a plot (observable behavior). + + A cell is considered occupied if it doesn't have a simple button widget. + """ + return get_cell_button(grid, row, col) is None + + +def find_close_button(grid: PlotGrid, row: int, col: int) -> pn.widgets.Button | None: + """Find the close button within a plot cell.""" + try: + cell = grid.panel[row, col] # type: ignore[index] + if isinstance(cell, pn.Column): + for item in cell: + if isinstance(item, pn.widgets.Button) and item.name == '\u00d7': + return item + except (KeyError, IndexError): + pass + return None + + +def count_occupied_cells(grid: PlotGrid) -> int: + """Count how many cell positions contain plots.""" + count = 0 + # We need to know grid dimensions - we can infer from the panel + # Panel GridSpec doesn't expose nrows/ncols directly, but we can check + # We'll need to iterate over what we expect based on initialization + # This is a bit tricky without accessing private attributes + # For now, let's just try reasonable ranges + for row in range(10): # Assume max 10 rows + for col in range(10): # Assume max 10 cols + if is_cell_occupied(grid, row, col): + count += 1 + return count + + +class TestPlotGridInitialization: + def test_grid_has_panel_property(self, mock_callback: FakeCallback) -> None: + grid = PlotGrid(nrows=2, ncols=2, plot_request_callback=mock_callback) + assert grid.panel is not None + assert isinstance(grid.panel, pn.GridSpec) + + def test_grid_starts_with_empty_clickable_cells( + self, mock_callback: FakeCallback + ) -> None: + grid = PlotGrid(nrows=2, ncols=2, plot_request_callback=mock_callback) + + # All cells should have clickable buttons + for row in range(2): + for col in range(2): + button = get_cell_button(grid, row, col) + assert button is not None, f"Cell ({row}, {col}) should have a button" + assert ( + not button.disabled # type: ignore[truthy-bool] + ), f"Cell ({row}, {col}) should be enabled" + + +class TestCellSelection: + def test_single_cell_selection_triggers_callback( + self, mock_callback: FakeCallback, mock_plot: hv.DynamicMap + ) -> None: + grid = PlotGrid(nrows=3, ncols=3, plot_request_callback=mock_callback) + + # First click should not trigger callback + simulate_click(grid, 1, 1) + mock_callback.assert_not_called() + + # Second click on same cell should trigger callback + simulate_click(grid, 1, 1) + mock_callback.assert_called_once() + + # Complete the deferred insertion + grid.insert_plot_deferred(mock_plot) + + # Cell should now contain a plot + assert is_cell_occupied(grid, 1, 1) + + def test_rectangular_region_selection( + self, mock_callback: FakeCallback, mock_plot: hv.DynamicMap + ) -> None: + grid = PlotGrid(nrows=4, ncols=4, plot_request_callback=mock_callback) + + # Click two corners of a region + simulate_click(grid, 0, 0) + simulate_click(grid, 1, 2) + + mock_callback.assert_called_once() + + # Complete the deferred insertion + grid.insert_plot_deferred(mock_plot) + + # All cells in the 2x3 region should be occupied + assert is_cell_occupied(grid, 0, 0) + assert is_cell_occupied(grid, 0, 1) + assert is_cell_occupied(grid, 0, 2) + assert is_cell_occupied(grid, 1, 0) + assert is_cell_occupied(grid, 1, 1) + assert is_cell_occupied(grid, 1, 2) + + # Cells outside region should be empty + assert not is_cell_occupied(grid, 2, 0) + assert not is_cell_occupied(grid, 0, 3) + + def test_selection_works_regardless_of_click_order( + self, mock_callback: FakeCallback, mock_plot: hv.DynamicMap + ) -> None: + grid = PlotGrid(nrows=4, ncols=4, plot_request_callback=mock_callback) + + # Click bottom-right first, then top-left + simulate_click(grid, 2, 2) + simulate_click(grid, 1, 1) + + grid.insert_plot_deferred(mock_plot) + + # Should still create a 2x2 region + assert is_cell_occupied(grid, 1, 1) + assert is_cell_occupied(grid, 1, 2) + assert is_cell_occupied(grid, 2, 1) + assert is_cell_occupied(grid, 2, 2) + + def test_first_click_changes_cell_appearance( + self, mock_callback: FakeCallback + ) -> None: + grid = PlotGrid(nrows=3, ncols=3, plot_request_callback=mock_callback) + + # Get initial button state + button_before = get_cell_button(grid, 0, 0) + initial_label = button_before.name if button_before else None + + # Click the cell + simulate_click(grid, 0, 0) + + # Button should now show different label + button_after = get_cell_button(grid, 0, 0) + new_label = button_after.name if button_after else None + + assert initial_label != new_label + assert new_label is not None + assert '1x1' in new_label # type: ignore[operator] + + +class TestPlotInsertion: + def test_multiple_plots_can_be_inserted( + self, mock_callback: FakeCallback, mock_plot: hv.DynamicMap + ) -> None: + grid = PlotGrid(nrows=3, ncols=3, plot_request_callback=mock_callback) + + # Insert first plot + simulate_click(grid, 0, 0) + simulate_click(grid, 0, 0) + grid.insert_plot_deferred(mock_plot) + + # Insert second plot + simulate_click(grid, 2, 2) + simulate_click(grid, 2, 2) + grid.insert_plot_deferred(mock_plot) + + # Both cells should be occupied + assert is_cell_occupied(grid, 0, 0) + assert is_cell_occupied(grid, 2, 2) + + def test_inserted_plot_has_close_button( + self, mock_callback: FakeCallback, mock_plot: hv.DynamicMap + ) -> None: + grid = PlotGrid(nrows=3, ncols=3, plot_request_callback=mock_callback) + + simulate_click(grid, 1, 1) + simulate_click(grid, 1, 1) + grid.insert_plot_deferred(mock_plot) + + # Should be able to find a close button + close_button = find_close_button(grid, 1, 1) + assert close_button is not None + + +class TestPlotRemoval: + def test_clicking_close_button_removes_plot( + self, mock_callback: FakeCallback, mock_plot: hv.DynamicMap + ) -> None: + grid = PlotGrid(nrows=3, ncols=3, plot_request_callback=mock_callback) + + # Insert plot + simulate_click(grid, 0, 0) + simulate_click(grid, 1, 1) + grid.insert_plot_deferred(mock_plot) + + # Verify plot is there + assert is_cell_occupied(grid, 0, 0) + + # Click close button + close_button = find_close_button(grid, 0, 0) + assert close_button is not None + close_button.param.trigger('clicks') + + # Cells should now be empty and clickable again + assert not is_cell_occupied(grid, 0, 0) + assert not is_cell_occupied(grid, 1, 1) + assert get_cell_button(grid, 0, 0) is not None + assert get_cell_button(grid, 1, 1) is not None + + def test_removed_cells_become_selectable_again( + self, mock_callback: FakeCallback, mock_plot: hv.DynamicMap + ) -> None: + grid = PlotGrid(nrows=3, ncols=3, plot_request_callback=mock_callback) + + # Insert and remove plot + simulate_click(grid, 1, 1) + simulate_click(grid, 1, 1) + grid.insert_plot_deferred(mock_plot) + + close_button = find_close_button(grid, 1, 1) + assert close_button is not None + close_button.param.trigger('clicks') + + mock_callback.reset() + + # Should be able to select the cell again + simulate_click(grid, 1, 1) + simulate_click(grid, 1, 1) + + assert mock_callback.call_count == 1 + + grid.insert_plot_deferred(mock_plot) + assert is_cell_occupied(grid, 1, 1) + + +class TestOverlapPrevention: + def test_cannot_select_overlapping_region( + self, mock_callback: FakeCallback, mock_plot: hv.DynamicMap + ) -> None: + grid = PlotGrid(nrows=4, ncols=4, plot_request_callback=mock_callback) + + # Insert plot at (1, 1) to (2, 2) + simulate_click(grid, 1, 1) + simulate_click(grid, 2, 2) + grid.insert_plot_deferred(mock_plot) + + # Start new selection at (0, 0) + simulate_click(grid, 0, 0) + + # Cell (1, 1) should now be disabled (since it would overlap) + button = get_cell_button(grid, 1, 1) + # Button should be None because that cell is occupied + assert button is None + + def test_non_overlapping_regions_can_be_selected( + self, mock_callback: FakeCallback, mock_plot: hv.DynamicMap + ) -> None: + grid = PlotGrid(nrows=4, ncols=4, plot_request_callback=mock_callback) + + # Insert plot at (1, 1) to (2, 2) + simulate_click(grid, 1, 1) + simulate_click(grid, 2, 2) + grid.insert_plot_deferred(mock_plot) + + mock_callback.reset() + + # Should be able to select non-overlapping regions + simulate_click(grid, 0, 0) + simulate_click(grid, 0, 0) + mock_callback.assert_called_once() + + grid.insert_plot_deferred(mock_plot) + assert is_cell_occupied(grid, 0, 0) + + +class TestSelectionCancellation: + def test_cancel_pending_selection_clears_state( + self, mock_callback: FakeCallback + ) -> None: + grid = PlotGrid(nrows=3, ncols=3, plot_request_callback=mock_callback) + + # Start a selection + simulate_click(grid, 0, 0) + simulate_click(grid, 1, 1) + + # Cancel it + grid.cancel_pending_selection() + + # Should be able to start a new selection + mock_callback.reset() + simulate_click(grid, 2, 2) + simulate_click(grid, 2, 2) + mock_callback.assert_called_once() + + +class TestErrorHandling: + def test_insert_without_pending_selection_shows_error( + self, mock_callback: FakeCallback, mock_plot: hv.DynamicMap + ) -> None: + grid = PlotGrid(nrows=3, ncols=3, plot_request_callback=mock_callback) + + # Try to insert without making a selection + # This should handle gracefully (no crash) + grid.insert_plot_deferred(mock_plot) + + # No cells should be occupied + assert not is_cell_occupied(grid, 0, 0) + assert not is_cell_occupied(grid, 1, 1) + + def test_callback_error_prevents_plot_insertion(self) -> None: + error_callback = FakeCallback(side_effect=ValueError('Test error')) + grid = PlotGrid(nrows=3, ncols=3, plot_request_callback=error_callback) + + simulate_click(grid, 0, 0) + + # Second click raises error, but grid should handle it + with pytest.raises(ValueError, match='Test error'): + simulate_click(grid, 0, 0) + + # Grid should still be in a usable state + # We never called insert_plot_deferred, so cell should still be empty + assert not is_cell_occupied(grid, 0, 0) diff --git a/tests/dashboard/widgets/wizard_test.py b/tests/dashboard/widgets/wizard_test.py new file mode 100644 index 000000000..5b48a8d4e --- /dev/null +++ b/tests/dashboard/widgets/wizard_test.py @@ -0,0 +1,778 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +from typing import Any + +import panel as pn + +from ess.livedata.dashboard.widgets.wizard import Wizard, WizardStep + + +class FakeWizardStep(WizardStep[Any, Any]): + """Test implementation of WizardStep.""" + + def __init__( + self, + step_name: str = "test_step", + valid: bool = True, + can_execute: bool = True, + step_description: str | None = None, + return_value: Any = None, + ) -> None: + super().__init__() + self._name = step_name + self._description = step_description + self._valid = valid + self._can_execute = can_execute + self._return_value = return_value + self.enter_called = False + self.execute_called = False + self.received_input: Any = None + + @property + def name(self) -> str: + """Display name for this step.""" + return self._name + + @property + def description(self) -> str | None: + """Optional description text.""" + return self._description + + def render_content(self) -> pn.Column: + """Render step content.""" + return pn.Column(pn.pane.Markdown(f"Content for {self.name}")) + + def is_valid(self) -> bool: + """Whether step is valid.""" + return self._valid + + def on_enter(self, input_data: Any) -> None: + """Called when step becomes active.""" + self.enter_called = True + self.received_input = input_data + + def commit(self) -> Any: + """Commit step data and return result.""" + self.execute_called = True + if not self._can_execute: + return None + return ( + self._return_value if self._return_value is not None else {"result": "ok"} + ) + + def set_valid(self, valid: bool) -> None: + """Change validity and notify wizard.""" + self._valid = valid + self._notify_ready_changed(valid) + + +class TestWizardStep: + """Tests for WizardStep base class.""" + + def test_can_register_ready_callback(self): + step = FakeWizardStep() + callback_called = False + + def callback(is_ready: bool) -> None: + nonlocal callback_called + callback_called = True + + step.on_ready_changed(callback) + step._notify_ready_changed(True) + + assert callback_called + + def test_ready_callback_receives_correct_value(self): + step = FakeWizardStep() + received_value = None + + def callback(is_ready: bool) -> None: + nonlocal received_value + received_value = is_ready + + step.on_ready_changed(callback) + step._notify_ready_changed(True) + + assert received_value is True + + def test_notify_without_callback_does_not_raise(self): + step = FakeWizardStep() + # Should not raise even without callback registered + step._notify_ready_changed(True) + + +class TestWizardInitialization: + """Tests for Wizard initialization.""" + + def test_initial_state_is_not_finished(self): + steps = [FakeWizardStep()] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + assert not wizard.is_finished() + + def test_starts_at_first_step(self): + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + assert wizard._current_step_index == 0 + assert wizard._current_step is steps[0] + + def test_stores_action_button_label(self): + steps = [FakeWizardStep()] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + action_button_label="Create", + ) + + assert wizard._action_button_label == "Create" + + def test_action_button_label_defaults_to_none(self): + steps = [FakeWizardStep()] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + assert wizard._action_button_label is None + + +class TestWizardNavigation: + """Tests for wizard navigation.""" + + def test_advance_moves_to_next_step(self): + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard.advance() + + assert wizard._current_step_index == 1 + assert wizard._current_step is steps[1] + + def test_advance_does_not_move_if_step_invalid(self): + steps = [FakeWizardStep("step1", valid=False), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard.advance() + + assert wizard._current_step_index == 0 + + def test_advance_on_last_step_completes_wizard(self): + steps = [FakeWizardStep("step1")] + completed = False + + def on_complete(ctx: Any) -> None: + nonlocal completed + completed = True + + wizard = Wizard( + steps=steps, + on_complete=on_complete, + on_cancel=lambda: None, + ) + + wizard.advance() + + assert completed + assert wizard.is_finished() + + def test_advance_on_last_step_calls_execute_if_present(self): + step = FakeWizardStep("step1", can_execute=True) + wizard = Wizard( + steps=[step], + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard.advance() + + assert step.execute_called + + def test_advance_does_not_complete_if_execute_fails(self): + step = FakeWizardStep("step1", can_execute=False) + completed = False + + def on_complete(ctx: Any) -> None: + nonlocal completed + completed = True + + wizard = Wizard( + steps=[step], + on_complete=on_complete, + on_cancel=lambda: None, + ) + + wizard.advance() + + assert not completed + assert not wizard.is_finished() + + def test_back_moves_to_previous_step(self): + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard.advance() + wizard.back() + + assert wizard._current_step_index == 0 + + def test_back_on_first_step_does_nothing(self): + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard.back() + + assert wizard._current_step_index == 0 + + def test_on_enter_called_when_advancing(self): + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard._update_content() + wizard.advance() + + assert steps[1].enter_called + + def test_on_enter_called_when_going_back(self): + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard._update_content() + wizard.advance() + steps[0].enter_called = False # Reset flag + wizard.back() + + assert steps[0].enter_called + + +class TestWizardCompletion: + """Tests for wizard completion and cancellation.""" + + def test_complete_marks_wizard_as_finished(self): + steps = [FakeWizardStep()] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard.complete(None) + + assert wizard.is_finished() + + def test_complete_calls_on_complete_callback(self): + steps = [FakeWizardStep(return_value={"foo": "bar"})] + received_result = None + + def on_complete(result: Any) -> None: + nonlocal received_result + received_result = result + + wizard = Wizard( + steps=steps, + on_complete=on_complete, + on_cancel=lambda: None, + ) + + wizard.complete({"foo": "bar"}) + + assert received_result == {"foo": "bar"} + + def test_cancel_marks_wizard_as_finished(self): + steps = [FakeWizardStep()] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard.cancel() + + assert wizard.is_finished() + + def test_cancel_calls_on_cancel_callback(self): + steps = [FakeWizardStep()] + cancelled = False + + def on_cancel() -> None: + nonlocal cancelled + cancelled = True + + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=on_cancel, + ) + + wizard.cancel() + + assert cancelled + + +class TestWizardReset: + """Tests for wizard reset functionality.""" + + def test_reset_returns_to_first_step(self): + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard.advance() + wizard.reset() + + assert wizard._current_step_index == 0 + + def test_reset_clears_finished_flag(self): + steps = [FakeWizardStep()] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard.complete(None) + wizard.reset() + + assert not wizard.is_finished() + + +class TestWizardRendering: + """Tests for wizard rendering.""" + + def test_render_returns_same_content_container(self): + steps = [FakeWizardStep()] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + content1 = wizard.render() + content2 = wizard.render() + + assert content1 is content2 + + def test_render_step_includes_navigation_buttons(self): + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard._render_step() + + # Last item should be a Row containing buttons + assert isinstance(wizard._content[-1], pn.Row) + + def test_first_step_does_not_show_back_button(self): + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard._render_step() + button_row = wizard._content[-1] + + # Back button should not be in the row + assert wizard._back_button not in button_row + + def test_middle_step_shows_back_button(self): + steps = [ + FakeWizardStep("step1"), + FakeWizardStep("step2"), + FakeWizardStep("step3"), + ] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard.advance() + wizard._render_step() + button_row = wizard._content[-1] + + # Back button should be in the row + assert wizard._back_button in button_row + + def test_non_last_step_shows_next_button(self): + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard._render_step() + button_row = wizard._content[-1] + + # Next button should be in the row + assert wizard._next_button in button_row + assert wizard._next_button.name == "Next" + + def test_last_step_shows_action_button_when_label_provided(self): + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + action_button_label="Create Plot", + ) + + wizard.advance() + wizard._render_step() + button_row = wizard._content[-1] + + # Next button should be shown with custom label + assert wizard._next_button in button_row + assert wizard._next_button.name == "Create Plot" + + def test_last_step_hides_button_when_no_action_label(self): + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + action_button_label=None, + ) + + wizard.advance() + wizard._render_step() + button_row = wizard._content[-1] + + # Next button should not be shown on last step without action label + assert wizard._next_button not in button_row + + def test_cancel_button_always_shown(self): + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + # Check first step + wizard._render_step() + assert wizard._cancel_button in wizard._content[-1] + + # Check last step + wizard.advance() + wizard._render_step() + assert wizard._cancel_button in wizard._content[-1] + + +class TestWizardButtonState: + """Tests for wizard button state management.""" + + def test_next_button_enabled_when_step_valid(self): + steps = [FakeWizardStep("step1", valid=True)] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard._render_step() + + assert not wizard._next_button.disabled + + def test_next_button_disabled_when_step_invalid(self): + steps = [FakeWizardStep("step1", valid=False)] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard._render_step() + + assert wizard._next_button.disabled + + def test_step_ready_changed_updates_next_button(self): + step = FakeWizardStep("step1", valid=False) + wizard = Wizard( + steps=[step], + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard._update_content() + assert wizard._next_button.disabled + + # Simulate step becoming valid + step.set_valid(True) + + assert not wizard._next_button.disabled + + +class TestWizardButtonCallbacks: + """Tests for wizard button click callbacks.""" + + def test_next_button_calls_advance(self): + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard._on_next_clicked(None) + + assert wizard._current_step_index == 1 + + def test_back_button_calls_back(self): + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard.advance() + wizard._on_back_clicked(None) + + assert wizard._current_step_index == 0 + + def test_cancel_button_calls_cancel(self): + steps = [FakeWizardStep()] + cancelled = False + + def on_cancel() -> None: + nonlocal cancelled + cancelled = True + + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=on_cancel, + ) + + wizard._on_cancel_clicked(None) + + assert cancelled + + +class TestWizardProperties: + """Tests for wizard properties.""" + + def test_is_first_step_true_on_first_step(self): + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + assert wizard._is_first_step + + def test_is_first_step_false_on_second_step(self): + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard.advance() + + assert not wizard._is_first_step + + def test_is_last_step_false_on_first_step(self): + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + assert not wizard._is_last_step + + def test_is_last_step_true_on_last_step(self): + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard.advance() + + assert wizard._is_last_step + + +class TestWizardIntegration: + """Integration tests for complete wizard workflows.""" + + def test_complete_wizard_flow(self): + """Test a complete wizard flow from start to finish.""" + step1 = FakeWizardStep("step1") + step2 = FakeWizardStep("step2") + step3 = FakeWizardStep("step3", return_value={"final": "result"}) + completed = False + received_result = None + + def on_complete(result: Any) -> None: + nonlocal completed, received_result + completed = True + received_result = result + + wizard = Wizard( + steps=[step1, step2, step3], + on_complete=on_complete, + on_cancel=lambda: None, + ) + + # Start wizard + wizard._update_content() + assert wizard._current_step_index == 0 + assert step1.enter_called + + # Advance to step 2 + wizard.advance() + assert wizard._current_step_index == 1 + assert step2.enter_called + + # Go back to step 1 + wizard.back() + assert wizard._current_step_index == 0 + + # Advance through all steps + wizard.advance() + wizard.advance() + assert wizard._current_step_index == 2 + assert step3.enter_called + + # Complete wizard + wizard.advance() + assert completed + assert received_result == {"final": "result"} + assert wizard.is_finished() + + def test_wizard_cancellation_flow(self): + """Test wizard cancellation at different steps.""" + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + cancelled = False + + def on_cancel() -> None: + nonlocal cancelled + cancelled = True + + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=on_cancel, + ) + + wizard._update_content() + wizard.advance() + wizard.cancel() + + assert cancelled + assert wizard.is_finished() + + def test_wizard_with_invalid_step(self): + """Test that wizard cannot advance past invalid step.""" + step1 = FakeWizardStep("step1", valid=True) + step2 = FakeWizardStep("step2", valid=False) + step3 = FakeWizardStep("step3", valid=True) + + wizard = Wizard( + steps=[step1, step2, step3], + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + wizard._update_content() + wizard.advance() + assert wizard._current_step_index == 1 + + # Try to advance past invalid step + wizard.advance() + assert wizard._current_step_index == 1 # Should not advance + + # Make step valid and try again + step2.set_valid(True) + wizard.advance() + assert wizard._current_step_index == 2 # Should advance now + + def test_wizard_reset_after_completion(self): + """Test resetting wizard after completion.""" + steps = [FakeWizardStep("step1"), FakeWizardStep("step2")] + wizard = Wizard( + steps=steps, + on_complete=lambda result: None, + on_cancel=lambda: None, + ) + + # Complete wizard + wizard._update_content() + wizard.advance() + wizard.advance() + assert wizard.is_finished() + + # Reset wizard + wizard.reset() + assert not wizard.is_finished() + assert wizard._current_step_index == 0 + + def test_wizard_with_action_button_execution(self): + """Test wizard with action button that executes on last step.""" + step = FakeWizardStep("step1", can_execute=True) + completed = False + + def on_complete(ctx: Any) -> None: + nonlocal completed + completed = True + + wizard = Wizard( + steps=[step], + on_complete=on_complete, + on_cancel=lambda: None, + action_button_label="Execute", + ) + + wizard._update_content() + wizard.advance() + + assert step.execute_called + assert completed