diff --git a/pymc/progress_bar.py b/pymc/progress_bar.py index 729958430..3f7639fa5 100644 --- a/pymc/progress_bar.py +++ b/pymc/progress_bar.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Iterable -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, Protocol, cast from rich.box import SIMPLE_HEAD from rich.console import Console @@ -192,6 +192,143 @@ def callbacks(self, task: "Task"): self.finished_style = self.default_finished_style +class ProgressTask(Protocol): + """A protocol for a task in a progress bar. + + This protocol defines the expected interface for tasks that can be added to a progress bar. + """ + + @property + def elapsed(self): + """Get the elapsed time for this task.""" + + +class ProgressBar(Protocol): + @property + def tasks(self) -> list[ProgressTask]: + """Get the tasks in the progress bar.""" + + def add_task(self, *args, **kwargs) -> ProgressTask | None: + """Add a task to the progress bar.""" + + def update(self, task_id, **kwargs): + """Update the task with the given ID with the provided keyword arguments.""" + + def __enter__(self): + """Enter the context manager.""" + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit the context manager.""" + + +def compute_draw_speed(elapsed, draws): + speed = draws / max(elapsed, 1e-6) + + if speed > 1 or speed == 0: + unit = "draws/s" + else: + unit = "s/draws" + speed = 1 / speed + + return speed, unit + + +def create_rich_progress_bar(full_stats, step_columns, progressbar, progressbar_theme): + columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))] + + if full_stats: + columns += step_columns + + columns += [ + TextColumn( + "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}", + table_column=Column("Sampling Speed", ratio=1), + ), + TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)), + TimeRemainingColumn(table_column=Column("Remaining", ratio=1)), + ] + + return CustomProgress( + RecolorOnFailureBarColumn( + table_column=Column("Progress", ratio=2), + failing_color="tab:red", + complete_style=Style.parse("rgb(31,119,180)"), # tab:blue + finished_style=Style.parse("rgb(31,119,180)"), # tab:blue + ), + *columns, + console=Console(theme=progressbar_theme), + disable=not progressbar, + include_headers=True, + ) + + +class MarimoProgressTask: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + @property + def chain_idx(self) -> int: + return self.kwargs.get("chain_idx", 0) + + @property + def total(self): + return self.kwargs.get("total", 0) + + @property + def elapsed(self): + return self.kwargs.get("elapsed", 0) + + +class MarimoProgressBar: + def __init__(self) -> None: + self.tasks: list[ProgressTask] = [] + self.divergences: dict[int, int] = {} + + def __enter__(self): + """Enter the context manager.""" + import marimo as mo + + total_draws = (self.tasks[0].total + 1) * len(self.tasks) + + self.bar = mo.status.progress_bar(total=total_draws, title="Sampling PyMC model") + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit the context manager.""" + self.bar._finish() + + def add_task(self, *args, **kwargs): + """Add a task to the progress bar.""" + task = MarimoProgressTask(*args, **kwargs) + self.tasks.append(task) + return task + + def update(self, task_id, **kwargs): + """Update the task with the given ID with the provided keyword arguments.""" + if self.bar.progress.current >= cast(int, self.bar.progress.total): + return + + self.divergences[task_id.chain_idx] = kwargs.get("divergences", 0) + + total_divergences = sum(self.divergences.values()) + + update_kwargs = {} + if total_divergences > 0: + word = "draws" if total_divergences > 1 else "draw" + update_kwargs["subtitle"] = f"{total_divergences} diverging {word}" + + self.bar.progress.update(**update_kwargs) + + +def in_marimo_notebook() -> bool: + try: + import marimo as mo + + return mo.running_in_notebook() + except ImportError: + return False + + class ProgressBarManager: """Manage progress bars displayed during sampling.""" @@ -275,11 +412,16 @@ def __init__( progress_columns, progress_stats = step_method._progressbar_config(chains) - self._progress = self.create_progress_bar( - progress_columns, - progressbar=progressbar, - progressbar_theme=progressbar_theme, - ) + if in_marimo_notebook(): + self.combined_progress = False + self._progress = MarimoProgressBar() + else: + self._progress = create_rich_progress_bar( + self.full_stats, + progress_columns, + progressbar=progressbar, + progressbar_theme=progressbar_theme, + ) self.progress_stats = progress_stats self.update_stats_functions = step_method._make_progressbar_update_functions() @@ -331,18 +473,6 @@ def _initialize_tasks(self): for chain_idx in range(self.chains) ] - @staticmethod - def compute_draw_speed(elapsed, draws): - speed = draws / max(elapsed, 1e-6) - - if speed > 1 or speed == 0: - unit = "draws/s" - else: - unit = "s/draws" - speed = 1 / speed - - return speed, unit - def update(self, chain_idx, is_last, draw, tuning, stats): if not self._show_progress: return @@ -353,7 +483,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats): chain_idx = 0 elapsed = self._progress.tasks[chain_idx].elapsed - speed, unit = self.compute_draw_speed(elapsed, draw) + speed, unit = compute_draw_speed(elapsed, draw) failing = False all_step_stats = {} @@ -395,31 +525,3 @@ def update(self, chain_idx, is_last, draw, tuning, stats): **all_step_stats, refresh=True, ) - - def create_progress_bar(self, step_columns, progressbar, progressbar_theme): - columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))] - - if self.full_stats: - columns += step_columns - - columns += [ - TextColumn( - "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}", - table_column=Column("Sampling Speed", ratio=1), - ), - TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)), - TimeRemainingColumn(table_column=Column("Remaining", ratio=1)), - ] - - return CustomProgress( - RecolorOnFailureBarColumn( - table_column=Column("Progress", ratio=2), - failing_color="tab:red", - complete_style=Style.parse("rgb(31,119,180)"), # tab:blue - finished_style=Style.parse("rgb(31,119,180)"), # tab:blue - ), - *columns, - console=Console(theme=progressbar_theme), - disable=not progressbar, - include_headers=True, - )