Skip to content

Add marimo progress bar for pymc sampler #7883

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 149 additions & 47 deletions pymc/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, Protocol, cast

from rich.box import SIMPLE_HEAD
from rich.console import Console
Expand Down Expand Up @@ -192,6 +192,143 @@ def callbacks(self, task: "Task"):
self.finished_style = self.default_finished_style


class ProgressTask(Protocol):
    """Structural type for a single task tracked by a progress bar.

    Any object that exposes an ``elapsed`` property satisfies this protocol.
    """

    @property
    def elapsed(self):
        """Return the elapsed time recorded for this task."""
        ...


class ProgressBar(Protocol):
    """Structural type for a progress bar.

    Implementations expose a list of tasks, can register and update tasks,
    and can be used as a context manager.
    """

    @property
    def tasks(self) -> list[ProgressTask]:
        """Return the tasks registered with this progress bar."""
        ...

    def add_task(self, *args, **kwargs) -> ProgressTask | None:
        """Register a new task with the progress bar."""
        ...

    def update(self, task_id, **kwargs):
        """Apply the given keyword arguments to the task identified by ``task_id``."""
        ...

    def __enter__(self):
        """Enter the context manager."""
        ...

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the context manager."""
        ...


def compute_draw_speed(elapsed, draws):
    """Return the sampling rate together with the unit it is expressed in.

    The rate is ``draws`` divided by ``elapsed``, where ``elapsed`` is floored
    at ``1e-6`` to avoid dividing by zero. Rates above one draw per time unit
    (and a rate of exactly zero) are reported as ``"draws/s"``; every other
    rate is inverted and reported as ``"s/draws"``.

    Returns
    -------
    tuple
        ``(speed, unit)`` where ``unit`` is ``"draws/s"`` or ``"s/draws"``.
    """
    rate = draws / max(elapsed, 1e-6)

    report_as_rate = rate > 1 or rate == 0
    if not report_as_rate:
        # Slow sampling reads better as seconds per draw.
        return 1 / rate, "s/draws"

    return rate, "draws/s"


def create_rich_progress_bar(full_stats, step_columns, progressbar, progressbar_theme):
Comment on lines +224 to +236
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason why these aren't methods anymore? It's not like they are used outside of ProgressBarManager (compute_draw_speed was a static method anyway).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ProgressBarManager seemed to be doing too much. The second one was specific to the rich implementation

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This still seems like a function that would only ever be called by ProgressBarManager; the inputs are way too specific.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And the object that is created is also called very specifically by ProgressBarManager

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally it would be passed in, as many of the init kwargs of the ProgressBarManager are specific to this type of ProgressBar.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would be passed in? I was just arguing against this function being moved out of the class, because it seems tightly coupled to it anyway

# Body of create_rich_progress_bar: assemble the table columns and build the
# rich-based progress display.
# The first column always shows the task's ``draws`` field.
columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))]

# Step-specific columns are only shown when full statistics were requested.
if full_stats:
columns += step_columns

# Trailing columns: sampling speed (value and unit come from task fields),
# elapsed time and remaining time.
columns += [
TextColumn(
"{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}",
table_column=Column("Sampling Speed", ratio=1),
),
TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)),
TimeRemainingColumn(table_column=Column("Remaining", ratio=1)),
]

# The bar column comes first, followed by the columns built above. The display
# is disabled when ``progressbar`` is falsy and the console uses the given theme.
return CustomProgress(
RecolorOnFailureBarColumn(
table_column=Column("Progress", ratio=2),
failing_color="tab:red",
complete_style=Style.parse("rgb(31,119,180)"), # tab:blue
finished_style=Style.parse("rgb(31,119,180)"), # tab:blue
),
*columns,
console=Console(theme=progressbar_theme),
disable=not progressbar,
include_headers=True,
)


class MarimoProgressTask:
    """Lightweight task record used by the marimo progress bar.

    The positional and keyword arguments given at construction are stored
    unchanged in ``args`` and ``kwargs``. The properties below read selected
    keyword arguments and fall back to ``0`` when the keyword is absent.
    """

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

    def _kwarg_or_zero(self, key):
        """Return ``self.kwargs[key]``, or ``0`` when ``key`` was not supplied."""
        return self.kwargs.get(key, 0)

    @property
    def chain_idx(self) -> int:
        """Value of the ``chain_idx`` keyword argument (``0`` if missing)."""
        return self._kwarg_or_zero("chain_idx")

    @property
    def total(self):
        """Value of the ``total`` keyword argument (``0`` if missing)."""
        return self._kwarg_or_zero("total")

    @property
    def elapsed(self):
        """Value of the ``elapsed`` keyword argument (``0`` if missing)."""
        return self._kwarg_or_zero("elapsed")


class MarimoProgressBar:
    """Progress bar that reports sampling progress through marimo.

    Tasks are registered with ``add_task`` before the bar is entered. Entering
    the context creates a single ``marimo.status.progress_bar`` whose total is
    derived from the registered tasks; each ``update`` call then advances that
    bar and, when any divergences have been reported, shows their sum as the
    subtitle.
    """

    def __init__(self) -> None:
        self.tasks: list[ProgressTask] = []
        # Most recent divergence count reported for each chain index.
        self.divergences: dict[int, int] = {}
        # Created in ``__enter__``; ``None`` until the context is entered so
        # that ``update``/``__exit__`` can be called safely beforehand.
        self.bar = None

    def __enter__(self):
        """Create the underlying marimo progress bar and return ``self``."""
        # Imported lazily so marimo stays an optional dependency.
        import marimo as mo

        # The total is taken from the first task and applied to every task;
        # with no tasks registered the bar simply has a total of zero.
        per_task_total = self.tasks[0].total + 1 if self.tasks else 0
        total_draws = per_task_total * len(self.tasks)

        self.bar = mo.status.progress_bar(total=total_draws, title="Sampling PyMC model")
        # Return the manager itself so ``with bar as b`` binds a usable object.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Finish the underlying marimo progress bar, if one was created."""
        if self.bar is None:
            return
        # NOTE(review): ``_finish`` is a private marimo method; confirm there
        # is no public equivalent before relying on it long term.
        self.bar._finish()

    def add_task(self, *args, **kwargs):
        """Create a task from the given arguments, register it and return it."""
        task = MarimoProgressTask(*args, **kwargs)
        self.tasks.append(task)
        return task

    def update(self, task_id, **kwargs):
        """Advance the bar and refresh the divergence subtitle.

        Parameters
        ----------
        task_id
            The task being updated; its ``chain_idx`` attribute keys the
            divergence bookkeeping.
        **kwargs
            Only ``divergences`` is read (defaulting to ``0``); any other
            keyword is ignored.
        """
        if self.bar is None:
            # Not entered yet: there is nothing to draw on.
            return

        progress = self.bar.progress
        if progress.current >= cast(int, progress.total):
            # The bar is already complete; do not advance it any further.
            return

        self.divergences[task_id.chain_idx] = kwargs.get("divergences", 0)

        total_divergences = sum(self.divergences.values())

        update_kwargs = {}
        if total_divergences > 0:
            word = "draws" if total_divergences > 1 else "draw"
            update_kwargs["subtitle"] = f"{total_divergences} diverging {word}"

        progress.update(**update_kwargs)
Comment on lines +306 to +320
Copy link
Member

@ricardoV94 ricardoV94 Jul 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We changed the logic to not be obsessed with NUTS.

There is now a failing variable that the step methods report back, and that's passed to update to indicate sampling is failing. Also each step decides what is a meaningful stat to report.

This reverts to over-specializing for NUTS.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does failing report the number of divergences? I can remove this logic and just not have a subtitle which would be specific to NUTS

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Failing is just binary, we use for coloring. The relevant summary stats are reported as part of all_step_stats, we don't treat them specially.

**all_step_stats,



def in_marimo_notebook() -> bool:
    """Report whether the code is currently running inside a marimo notebook.

    Returns ``False`` when marimo cannot be imported; otherwise returns
    whatever ``marimo.running_in_notebook()`` reports.
    """
    try:
        import marimo

        inside_notebook = marimo.running_in_notebook()
    except ImportError:
        inside_notebook = False
    return inside_notebook


class ProgressBarManager:
"""Manage progress bars displayed during sampling."""

Expand Down Expand Up @@ -275,11 +412,16 @@ def __init__(

progress_columns, progress_stats = step_method._progressbar_config(chains)

self._progress = self.create_progress_bar(
progress_columns,
progressbar=progressbar,
progressbar_theme=progressbar_theme,
)
if in_marimo_notebook():
self.combined_progress = False
self._progress = MarimoProgressBar()
else:
self._progress = create_rich_progress_bar(
self.full_stats,
progress_columns,
progressbar=progressbar,
progressbar_theme=progressbar_theme,
)
self.progress_stats = progress_stats
self.update_stats_functions = step_method._make_progressbar_update_functions()

Expand Down Expand Up @@ -331,18 +473,6 @@ def _initialize_tasks(self):
for chain_idx in range(self.chains)
]

@staticmethod
def compute_draw_speed(elapsed, draws):
speed = draws / max(elapsed, 1e-6)

if speed > 1 or speed == 0:
unit = "draws/s"
else:
unit = "s/draws"
speed = 1 / speed

return speed, unit

def update(self, chain_idx, is_last, draw, tuning, stats):
if not self._show_progress:
return
Expand All @@ -353,7 +483,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
chain_idx = 0

elapsed = self._progress.tasks[chain_idx].elapsed
speed, unit = self.compute_draw_speed(elapsed, draw)
speed, unit = compute_draw_speed(elapsed, draw)

failing = False
all_step_stats = {}
Expand Down Expand Up @@ -395,31 +525,3 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
**all_step_stats,
refresh=True,
)

def create_progress_bar(self, step_columns, progressbar, progressbar_theme):
    """Build the rich progress display used for sampling.

    Parameters
    ----------
    step_columns
        Extra columns appended after the draws column; only used when
        ``self.full_stats`` is truthy.
    progressbar
        When falsy, the returned progress display is created disabled.
    progressbar_theme
        Theme handed to the ``Console`` that renders the display.

    Returns
    -------
    CustomProgress
        Progress display with a bar column followed by the text columns.
    """
    tab_blue = "rgb(31,119,180)"

    columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))]

    if self.full_stats:
        columns.extend(step_columns)

    speed_column = TextColumn(
        "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}",
        table_column=Column("Sampling Speed", ratio=1),
    )
    columns.append(speed_column)
    columns.append(TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)))
    columns.append(TimeRemainingColumn(table_column=Column("Remaining", ratio=1)))

    # The bar switches to the failing colour when sampling is flagged as failing.
    bar_column = RecolorOnFailureBarColumn(
        table_column=Column("Progress", ratio=2),
        failing_color="tab:red",
        complete_style=Style.parse(tab_blue),
        finished_style=Style.parse(tab_blue),
    )

    return CustomProgress(
        bar_column,
        *columns,
        console=Console(theme=progressbar_theme),
        disable=not progressbar,
        include_headers=True,
    )
Loading