-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add marimo progress bar for pymc sampler #7883
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | ||||
# limitations under the License. | ||||
from collections.abc import Iterable | ||||
from typing import TYPE_CHECKING, Literal | ||||
from typing import TYPE_CHECKING, Literal, Protocol, cast | ||||
|
||||
from rich.box import SIMPLE_HEAD | ||||
from rich.console import Console | ||||
|
@@ -192,6 +192,143 @@ def callbacks(self, task: "Task"): | |||
self.finished_style = self.default_finished_style | ||||
|
||||
|
||||
class ProgressTask(Protocol): | ||||
"""A protocol for a task in a progress bar. | ||||
This protocol defines the expected interface for tasks that can be added to a progress bar. | ||||
""" | ||||
|
||||
@property | ||||
def elapsed(self): | ||||
"""Get the elapsed time for this task.""" | ||||
|
||||
|
||||
class ProgressBar(Protocol): | ||||
@property | ||||
def tasks(self) -> list[ProgressTask]: | ||||
"""Get the tasks in the progress bar.""" | ||||
|
||||
def add_task(self, *args, **kwargs) -> ProgressTask | None: | ||||
"""Add a task to the progress bar.""" | ||||
|
||||
def update(self, task_id, **kwargs): | ||||
"""Update the task with the given ID with the provided keyword arguments.""" | ||||
|
||||
def __enter__(self): | ||||
"""Enter the context manager.""" | ||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb): | ||||
"""Exit the context manager.""" | ||||
|
||||
|
||||
def compute_draw_speed(elapsed, draws): | ||||
speed = draws / max(elapsed, 1e-6) | ||||
|
||||
if speed > 1 or speed == 0: | ||||
unit = "draws/s" | ||||
else: | ||||
unit = "s/draws" | ||||
speed = 1 / speed | ||||
|
||||
return speed, unit | ||||
|
||||
|
||||
def create_rich_progress_bar(full_stats, step_columns, progressbar, progressbar_theme): | ||||
columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))] | ||||
|
||||
if full_stats: | ||||
columns += step_columns | ||||
|
||||
columns += [ | ||||
TextColumn( | ||||
"{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}", | ||||
table_column=Column("Sampling Speed", ratio=1), | ||||
), | ||||
TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)), | ||||
TimeRemainingColumn(table_column=Column("Remaining", ratio=1)), | ||||
] | ||||
|
||||
return CustomProgress( | ||||
RecolorOnFailureBarColumn( | ||||
table_column=Column("Progress", ratio=2), | ||||
failing_color="tab:red", | ||||
complete_style=Style.parse("rgb(31,119,180)"), # tab:blue | ||||
finished_style=Style.parse("rgb(31,119,180)"), # tab:blue | ||||
), | ||||
*columns, | ||||
console=Console(theme=progressbar_theme), | ||||
disable=not progressbar, | ||||
include_headers=True, | ||||
) | ||||
|
||||
|
||||
class MarimoProgressTask: | ||||
def __init__(self, *args, **kwargs): | ||||
self.args = args | ||||
self.kwargs = kwargs | ||||
|
||||
@property | ||||
def chain_idx(self) -> int: | ||||
return self.kwargs.get("chain_idx", 0) | ||||
|
||||
@property | ||||
def total(self): | ||||
return self.kwargs.get("total", 0) | ||||
|
||||
@property | ||||
def elapsed(self): | ||||
return self.kwargs.get("elapsed", 0) | ||||
|
||||
|
||||
class MarimoProgressBar: | ||||
def __init__(self) -> None: | ||||
self.tasks: list[ProgressTask] = [] | ||||
self.divergences: dict[int, int] = {} | ||||
|
||||
def __enter__(self): | ||||
"""Enter the context manager.""" | ||||
import marimo as mo | ||||
|
||||
total_draws = (self.tasks[0].total + 1) * len(self.tasks) | ||||
|
||||
self.bar = mo.status.progress_bar(total=total_draws, title="Sampling PyMC model") | ||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb): | ||||
"""Exit the context manager.""" | ||||
self.bar._finish() | ||||
|
||||
def add_task(self, *args, **kwargs): | ||||
"""Add a task to the progress bar.""" | ||||
task = MarimoProgressTask(*args, **kwargs) | ||||
self.tasks.append(task) | ||||
return task | ||||
|
||||
def update(self, task_id, **kwargs): | ||||
"""Update the task with the given ID with the provided keyword arguments.""" | ||||
if self.bar.progress.current >= cast(int, self.bar.progress.total): | ||||
return | ||||
|
||||
self.divergences[task_id.chain_idx] = kwargs.get("divergences", 0) | ||||
|
||||
total_divergences = sum(self.divergences.values()) | ||||
|
||||
update_kwargs = {} | ||||
if total_divergences > 0: | ||||
word = "draws" if total_divergences > 1 else "draw" | ||||
update_kwargs["subtitle"] = f"{total_divergences} diverging {word}" | ||||
|
||||
self.bar.progress.update(**update_kwargs) | ||||
Comment on lines
+306
to
+320
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We changed the logic to not be obsessed with NUTS. There is now a This reverts back to over-specializing for nuts There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does failing report the number of divergences? I can remove this logic and just not have a subtitle which would be specific to NUTS There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Failing is just binary, we use for coloring. The relevant summary stats are reported as part of Line 517 in 886b584
|
||||
|
||||
|
||||
def in_marimo_notebook() -> bool: | ||||
try: | ||||
import marimo as mo | ||||
|
||||
return mo.running_in_notebook() | ||||
except ImportError: | ||||
return False | ||||
|
||||
|
||||
class ProgressBarManager: | ||||
"""Manage progress bars displayed during sampling.""" | ||||
|
||||
|
@@ -275,11 +412,16 @@ def __init__( | |||
|
||||
progress_columns, progress_stats = step_method._progressbar_config(chains) | ||||
|
||||
self._progress = self.create_progress_bar( | ||||
progress_columns, | ||||
progressbar=progressbar, | ||||
progressbar_theme=progressbar_theme, | ||||
) | ||||
if in_marimo_notebook(): | ||||
self.combined_progress = False | ||||
self._progress = MarimoProgressBar() | ||||
else: | ||||
self._progress = create_rich_progress_bar( | ||||
self.full_stats, | ||||
progress_columns, | ||||
progressbar=progressbar, | ||||
progressbar_theme=progressbar_theme, | ||||
) | ||||
self.progress_stats = progress_stats | ||||
self.update_stats_functions = step_method._make_progressbar_update_functions() | ||||
|
||||
|
@@ -331,18 +473,6 @@ def _initialize_tasks(self): | |||
for chain_idx in range(self.chains) | ||||
] | ||||
|
||||
@staticmethod | ||||
def compute_draw_speed(elapsed, draws): | ||||
speed = draws / max(elapsed, 1e-6) | ||||
|
||||
if speed > 1 or speed == 0: | ||||
unit = "draws/s" | ||||
else: | ||||
unit = "s/draws" | ||||
speed = 1 / speed | ||||
|
||||
return speed, unit | ||||
|
||||
def update(self, chain_idx, is_last, draw, tuning, stats): | ||||
if not self._show_progress: | ||||
return | ||||
|
@@ -353,7 +483,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats): | |||
chain_idx = 0 | ||||
|
||||
elapsed = self._progress.tasks[chain_idx].elapsed | ||||
speed, unit = self.compute_draw_speed(elapsed, draw) | ||||
speed, unit = compute_draw_speed(elapsed, draw) | ||||
|
||||
failing = False | ||||
all_step_stats = {} | ||||
|
@@ -395,31 +525,3 @@ def update(self, chain_idx, is_last, draw, tuning, stats): | |||
**all_step_stats, | ||||
refresh=True, | ||||
) | ||||
|
||||
def create_progress_bar(self, step_columns, progressbar, progressbar_theme): | ||||
columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))] | ||||
|
||||
if self.full_stats: | ||||
columns += step_columns | ||||
|
||||
columns += [ | ||||
TextColumn( | ||||
"{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}", | ||||
table_column=Column("Sampling Speed", ratio=1), | ||||
), | ||||
TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)), | ||||
TimeRemainingColumn(table_column=Column("Remaining", ratio=1)), | ||||
] | ||||
|
||||
return CustomProgress( | ||||
RecolorOnFailureBarColumn( | ||||
table_column=Column("Progress", ratio=2), | ||||
failing_color="tab:red", | ||||
complete_style=Style.parse("rgb(31,119,180)"), # tab:blue | ||||
finished_style=Style.parse("rgb(31,119,180)"), # tab:blue | ||||
), | ||||
*columns, | ||||
console=Console(theme=progressbar_theme), | ||||
disable=not progressbar, | ||||
include_headers=True, | ||||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any reason why these aren't methods anymore? It's not like they are used outside of ProgressBarManager (compute_draw_speed was static method anyway)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ProgressBarManager seemed to be doing too much. The second one was specific to the rich implementation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This still seems like a function that would ever only be called by ProgressBarManager, the inputs are way too specific
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And the object that is created is also called very specifically by ProgressBarManager
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ideally it would be passed in as many of the init kwargs of the ProgressBarManager are specific to this type of ProgressBar
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What would be passed in? I was just arguing against this function being moved out of the class, because it seems tightly coupled to it anyway