diff --git a/src/tbp/interactive/colors.py b/src/tbp/interactive/colors.py index aea90a3..5e27a86 100644 --- a/src/tbp/interactive/colors.py +++ b/src/tbp/interactive/colors.py @@ -35,22 +35,27 @@ class Palette: available names. """ - # Primary Colors + # Primary colors indigo: str = "#2f2b5c" numenta_blue: str = "#00a0df" - # Secondary Colors + # Secondary colors bossanova: str = "#5c315f" vivid_violet: str = "#86308b" blue_violet: str = "#655eb2" amethyst: str = "#915acc" - # Accent Colors/Shades + # Accent colors/shades rich_black: str = "#000000" charcoal: str = "#3f3f3f" link_water: str = "#dfe6f5" - # ---------- Internal helper ---------- + # Scientific notations colors + pink: str = "#f737bd" + purple: str = "#5d11bf" + green: str = "#008e43" + gold: str = "#ffbe31" + @classmethod def _validate(cls, name: str) -> str: if not hasattr(cls, name): @@ -62,7 +67,6 @@ def _validate(cls, name: str) -> str: raise KeyError(msg) return getattr(cls, name) - # ---------- Public API ---------- @classmethod def as_hex(cls, name: str) -> Color: """Return the raw hex string for a color name.""" diff --git a/src/tbp/interactive/events.py b/src/tbp/interactive/events.py new file mode 100644 index 0000000..b9ed582 --- /dev/null +++ b/src/tbp/interactive/events.py @@ -0,0 +1,27 @@ +# Copyright 2025 Thousand Brains Project +# +# Copyright may exist in Contributors' modifications +# and/or contributions to the work. +# +# Use of this source code is governed by the MIT +# license that can be found in the LICENSE file or at +# https://opensource.org/licenses/MIT. + +from dataclasses import dataclass + + +@dataclass +class EventSpec: + """Specification for an Event to be defined as a WidgetUpdater callback trigger. + + Attributes: + trigger: Event trigger name (e.g., KeyPressed) + name: Event name field in Vedo `event.name` (e.g., keypress). + required: Whether this event is required for the callback trigger. If + True, the updater will not call the callback until a message for this + topic arrives. + """ + + trigger: str + name: str + required: bool = True diff --git a/src/tbp/interactive/scopes.py b/src/tbp/interactive/scopes.py new file mode 100644 index 0000000..4e121bb --- /dev/null +++ b/src/tbp/interactive/scopes.py @@ -0,0 +1,102 @@ +# Copyright 2025 Thousand Brains Project +# +# Copyright may exist in Contributors' modifications +# and/or contributions to the work. +# +# Use of this source code is governed by the MIT +# license that can be found in the LICENSE file or at +# https://opensource.org/licenses/MIT. + + +from __future__ import annotations + +from typing import Any + +from vedo import Plotter + +from tbp.interactive.widgets import Widget + + +class ScopeViewer: + """Controls widget visibility using numeric keypress scopes. + + Behavior summary: + - Scope 0: + * If at least one widget is hidden -> show all widgets. + * Else (all visible) -> hide ALL widgets. + - Scope k (1..9): + * Toggle that scope on/off. + * During a toggle off, a widget may remain visible if any other active scope + includes it. + + The widgets themselves decide how to hide/show internally using + their .on() / .off() visibility handlers. + """ + + def __init__(self, plotter: Plotter, widgets: dict[str, Widget]): + self.plotter = plotter + self.widgets = widgets + + self.scope_to_widgets: dict[int, set[str]] = {} + + # Build scope map from each widget's `scopes` list. + for name, widget in widgets.items(): + for s in widget.scopes: + if s not in self.scope_to_widgets: + self.scope_to_widgets[s] = set() + self.scope_to_widgets[s].add(name) + + self.active_scopes: set[int] = set(self.scope_to_widgets.keys()) + self.plotter.add_callback("KeyPress", self._on_keypress) + + def _on_keypress(self, event: Any) -> None: + key = getattr(event, "keypress", None) + if not key or not key.isdigit(): + return + + self.toggle_scope(int(key)) + self.plotter.render() + + def toggle_scope(self, scope_id: int) -> None: + """Toggles a specific scope by its id.""" + if scope_id == 0: + return self._toggle_all() + + if scope_id not in self.scope_to_widgets: + return None + + if scope_id in self.active_scopes: + self.active_scopes.remove(scope_id) + else: + self.active_scopes.add(scope_id) + + self._apply_scope_visibility() + + def _toggle_all(self) -> None: + """Toggles all widgets on/off.""" + any_hidden = any(not w.is_visible for w in self.widgets.values()) + + if any_hidden: + for w in self.widgets.values(): + w.on() + self.active_scopes = set(self.scope_to_widgets.keys()) + else: + for w in self.widgets.values(): + w.off() + self.active_scopes.clear() + + def _apply_scope_visibility(self) -> None: + # If nothing is active, hide everything + if not self.active_scopes: + for w in self.widgets.values(): + w.off() + return + + for widget in self.widgets.values(): + # Does this widget belong to ANY active scope? + belongs = any(s in self.active_scopes for s in widget.scopes) + + if belongs and not widget.is_visible: + widget.on() + elif not belongs and widget.is_visible: + widget.off() diff --git a/src/tbp/interactive/widget_updaters.py b/src/tbp/interactive/widget_updaters.py index 8e8a9b0..a5745b0 100644 --- a/src/tbp/interactive/widget_updaters.py +++ b/src/tbp/interactive/widget_updaters.py @@ -11,6 +11,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Protocol, runtime_checkable +from tbp.interactive.events import EventSpec from tbp.interactive.topics import TopicMessage, TopicSpec if TYPE_CHECKING: @@ -56,7 +57,7 @@ class WidgetUpdater[WidgetT]: `topics`. """ - topics: Iterable[TopicSpec] + topics: Iterable[TopicSpec | EventSpec] callback: Callable[ [WidgetT | None, list[TopicMessage]], tuple[WidgetT | None, bool] ] @@ -86,6 +87,17 @@ def accepts(self, msg: TopicMessage) -> bool: """ return any(spec.name == msg.name for spec in self.topics) + def expire_topic(self, topic_name: str) -> None: + """Expire (remove) the stored message for a given topic. + + After expiration, the updater may require a new message for that topic + before becoming ready again. + + Args: + topic_name: The topic whose message should be invalidated. + """ + self._inbox.pop(topic_name, None) + def __call__( self, widget: WidgetT | None, msg: TopicMessage ) -> tuple[WidgetT | None, bool]: diff --git a/src/tbp/interactive/widgets.py b/src/tbp/interactive/widgets.py index ad0737e..cf3481e 100644 --- a/src/tbp/interactive/widgets.py +++ b/src/tbp/interactive/widgets.py @@ -12,9 +12,10 @@ from typing import Any from pubsub.core import Publisher -from vedo import Button, Slider2D +from vedo import Button, Plotter, Slider2D -from tbp.interactive.topics import TopicMessage +from tbp.interactive.events import EventSpec +from tbp.interactive.topics import TopicMessage, TopicSpec from tbp.interactive.utils import VtkDebounceScheduler from tbp.interactive.widget_ops import ( HasStateToMessages, @@ -126,6 +127,7 @@ class Widget[WidgetT, StateT]: widget_ops: Composed functionality for get/set/add/remove operations. debounce_sec: Debounce delay in seconds for change publications. dedupe: If True, skip publishing unchanged values. + scopes: List of integer scopes this widget belongs to. Runtime Attributes: widget: The created widget instance. @@ -145,11 +147,13 @@ def __init__( | HasUpdaters[WidgetT] | WidgetOpsProto ), + scopes: list[int] | None, bus: Publisher, scheduler: VtkDebounceScheduler, debounce_sec: float = 0.25, dedupe: bool = True, ): + self.scopes = scopes or [] self.bus = bus self.scheduler = scheduler self.debounce_sec = debounce_sec @@ -160,11 +164,19 @@ def __init__( self.state: StateT | None = None self.last_published_state: StateT | None = None self._sched_key = object() # hashable unique key + self._visible: bool = True if isinstance(self.widget_ops, HasUpdaters): for topic in self.updater_topics: self.bus.subscribe(self._on_update_topic, topic) + for event in self.updater_events: + self.widget_ops.plotter.add_callback(event, self._on_event) + + @property + def is_visible(self) -> bool: + return self._visible + @property def updater_topics(self) -> set[str]: """Names of topics that can update this widget with `WidgetUpdater`. @@ -175,7 +187,29 @@ def updater_topics(self) -> set[str]: if not isinstance(self.widget_ops, HasUpdaters): return set() - return {t.name for u in self.widget_ops.updaters for t in u.topics} + return { + t.name + for u in self.widget_ops.updaters + for t in u.topics + if isinstance(t, TopicSpec) + } + + @property + def updater_events(self) -> set[str]: + """Names of events that can update this widget with `WidgetUpdater`. + + Returns: + A set of event triggers the widget listens to for updates. + """ + if not isinstance(self.widget_ops, HasUpdaters): + return set() + + return { + e.name + for u in self.widget_ops.updaters + for e in u.topics + if isinstance(e, EventSpec) + } def add(self) -> None: """Create the widget and register the debounce callback. @@ -239,6 +273,24 @@ def _on_change(self, widget: WidgetT, _event: str) -> None: self.state = self.extract_state() self.scheduler.schedule_once(self._sched_key, self.debounce_sec) + def _on_event(self, event) -> None: + if not isinstance(self.widget_ops, HasUpdaters): + return + + msg = TopicMessage(name=event.name, value=event) + + for updater in self.widget_ops.updaters: + self.widget, publish_state = updater(self.widget, msg) + if publish_state: + self.state = self.extract_state() + self.scheduler.schedule_once(self._sched_key, self.debounce_sec) + + if hasattr(self.widget_ops, "plotter") and isinstance( + self.widget_ops.plotter, Plotter + ): + self._set_visibility(self.is_visible) + self.widget_ops.plotter.render() + def _on_update_topic(self, msg: TopicMessage): if not isinstance(self.widget_ops, HasUpdaters): return @@ -249,6 +301,12 @@ def _on_update_topic(self, msg: TopicMessage): self.state = self.extract_state() self.scheduler.schedule_once(self._sched_key, self.debounce_sec) + if hasattr(self.widget_ops, "plotter") and isinstance( + self.widget_ops.plotter, Plotter + ): + self._set_visibility(self.is_visible) + self.widget_ops.plotter.render() + def _on_debounce_fire(self) -> None: """Handler fired by the scheduler to publish debounced state.""" self._publish(self.extract_state()) @@ -267,3 +325,47 @@ def _publish(self, state: StateT | None) -> None: self.bus.sendMessage(msg.name, msg=msg) self.last_published_state = state + + def on(self) -> None: + """Make this widget visually 'on'.""" + self._set_visibility(True) + + def off(self) -> None: + """Make this widget visually 'off'.""" + self._set_visibility(False) + + def _supports_toggle(self, obj: object) -> bool: + return hasattr(obj, "on") and hasattr(obj, "off") + + def _toggle_single(self, obj: object | None, visible: bool) -> None: + """Toggle a single object if it supports .on() / .off().""" + if obj is None: + return + if not self._supports_toggle(obj): + return + + if visible: + obj.on() + else: + obj.off() + + def _toggle_container(self, attr: object, visible: bool) -> None: + """Toggle first-level items inside common container types.""" + if isinstance(attr, (list, tuple, set)): + for item in attr: + self._toggle_single(item, visible) + + elif isinstance(attr, dict): + for item in attr.values(): + self._toggle_single(item, visible) + + def _set_visibility(self, visible: bool) -> None: + self._visible = visible + + # Iterate over widget_ops attributes to find vedo-style objects + # that also support .on() / .off(). This lets you attach extra + # meshes, paths, spheres, etc., and have them follow scope visibility. + self._toggle_single(self.widget, visible) + for attr in vars(self.widget_ops).values(): + self._toggle_single(attr, visible) + self._toggle_container(attr, visible) diff --git a/src/tbp/plot/plots/interactive_hypothesis_space_correlation.py b/src/tbp/plot/plots/interactive_hypothesis_space_correlation.py index 83cb288..6ef920d 100644 --- a/src/tbp/plot/plots/interactive_hypothesis_space_correlation.py +++ b/src/tbp/plot/plots/interactive_hypothesis_space_correlation.py @@ -19,16 +19,30 @@ import numpy as np import pandas as pd import seaborn as sns +import vedo from pandas import DataFrame, Series from pubsub.core import Publisher -from vedo import Button, Circle, Image, Line, Mesh, Plotter, Slider2D, Sphere, Text2D +from vedo import ( + Button, + Circle, + Image, + Line, + Mesh, + Plotter, + Slider2D, + Sphere, + Text2D, +) +from tbp.interactive.colors import Palette from tbp.interactive.data import ( DataLocator, DataLocatorStep, DataParser, YCBMeshLoader, ) +from tbp.interactive.events import EventSpec +from tbp.interactive.scopes import ScopeViewer from tbp.interactive.topics import TopicMessage, TopicSpec from tbp.interactive.utils import ( Bounds, @@ -49,15 +63,24 @@ logger = logging.getLogger(__name__) - -HUE_PALETTE = { - "Added": "#66c2a5", - "Removed": "#fc8d62", - "Maintained": "#8da0cb", - "Evidence": "#1f77b4", - "Slope": "#ff7f0e", +vedo.settings.enable_default_keyboard_callbacks = False + +COLOR_PALETTE = { + "Maintained": Palette.as_hex("numenta_blue"), + "Removed": Palette.as_hex("gold"), + "Added": Palette.as_hex("purple"), + "Selected": Palette.as_hex("pink"), + "Highlighted": Palette.as_hex("green"), + "Primary": Palette.as_hex("numenta_blue"), + "Secondary": Palette.as_hex("purple"), + "Accent": Palette.as_hex("charcoal"), + "Accent2": Palette.as_hex("link_water"), + "Accent3": Palette.as_hex("rich_black"), } +FONT = "Arial" +FONT_SIZE = 30 + class EpisodeSliderWidgetOps: """WidgetOps implementation for an Episode slider. @@ -84,6 +107,7 @@ def __init__(self, plotter: Plotter, data_parser: DataParser) -> None: "xmax": 10, "value": 0, "pos": [(0.1, 0.2), (0.7, 0.2)], + "font": FONT, "title": "Episode", } @@ -198,6 +222,7 @@ def __init__(self, plotter: Plotter, data_parser: DataParser) -> None: "xmax": 10, "value": 0, "pos": [(0.1, 0.1), (0.7, 0.1)], + "font": FONT, "title": "Step", } self._locators = self.create_locators() @@ -307,7 +332,6 @@ def update_slider_range( # set slider value back to zero self.set_state(widget, 0) - self.plotter.at(0).render() return widget, True @@ -336,23 +360,53 @@ def __init__( self.ycb_loader = ycb_loader self.updaters = [ WidgetUpdater( - topics=[TopicSpec("episode_number", required=True)], + topics=[ + TopicSpec("episode_number", required=True), + ], callback=self.update_mesh, ), WidgetUpdater( topics=[ TopicSpec("episode_number", required=True), TopicSpec("step_number", required=True), + EventSpec("KeyPressed", "KeyPressEvent", required=False), + ], + callback=self.update_agent, + ), + WidgetUpdater( + topics=[ + EventSpec("KeyPressed", "KeyPressEvent", required=True), ], - callback=self.update_sensor, + callback=self.update_transparency, ), ] self._locators = self.create_locators() self.gaze_line: Line | None = None - self.sensor_sphere: Sphere | None = None + self.agent_sphere: Sphere | None = None + self.text_label: Text2D = Text2D( + txt="Ground Truth", pos="top-center", font=FONT + ) + + # Path visibility flags + self.mesh_transparency: float = 0.0 + self.show_agent_past: bool = False + self.show_agent_future: bool = False + self.show_patch_past: bool = False + self.show_patch_future: bool = False + + # Path geometry + self.agent_past_spheres: list[Sphere] = [] + self.agent_past_line: Line | None = None + self.agent_future_spheres: list[Sphere] = [] + self.agent_future_line: Line | None = None + + self.patch_past_spheres: list[Sphere] = [] + self.patch_past_line: Line | None = None + self.patch_future_spheres: list[Sphere] = [] + self.patch_future_line: Line | None = None - self.plotter.at(1).add(Text2D(txt="Ground Truth", pos="top-center")) + self.plotter.at(1).add(self.text_label) def create_locators(self) -> dict[str, DataLocator]: """Create and return data locators used by this widget. @@ -376,7 +430,7 @@ def create_locators(self) -> dict[str, DataLocator]: ] ) - locators["sensor_location"] = DataLocator( + locators["agent_location"] = DataLocator( path=[ DataLocatorStep.key(name="episode"), DataLocatorStep.key(name="system", value="motor_system"), @@ -410,6 +464,42 @@ def remove(self, widget: Mesh) -> None: self.plotter.at(1).remove(widget) self.plotter.at(1).render() + def _clear_agent_paths(self) -> None: + for s in self.agent_past_spheres: + self.plotter.at(1).remove(s) + for s in self.agent_future_spheres: + self.plotter.at(1).remove(s) + + self.agent_past_spheres.clear() + self.agent_future_spheres.clear() + + if self.agent_past_line is not None: + self.plotter.at(1).remove(self.agent_past_line) + self.agent_past_line = None + if self.agent_future_line is not None: + self.plotter.at(1).remove(self.agent_future_line) + self.agent_future_line = None + + def _clear_patch_paths(self) -> None: + for s in self.patch_past_spheres: + self.plotter.at(1).remove(s) + for s in self.patch_future_spheres: + self.plotter.at(1).remove(s) + + self.patch_past_spheres.clear() + self.patch_future_spheres.clear() + + if self.patch_past_line is not None: + self.plotter.at(1).remove(self.patch_past_line) + self.patch_past_line = None + if self.patch_future_line is not None: + self.plotter.at(1).remove(self.patch_future_line) + self.patch_future_line = None + + def _clear_all_paths(self) -> None: + self._clear_agent_paths() + self._clear_patch_paths() + def update_mesh(self, widget: Mesh, msgs: list[TopicMessage]) -> tuple[Mesh, bool]: """Update the target mesh when the episode changes. @@ -440,15 +530,13 @@ def update_mesh(self, widget: Mesh, msgs: list[TopicMessage]) -> tuple[Mesh, boo widget.rotate_y(target_rot[1]) widget.rotate_z(target_rot[2]) widget.shift(*target_pos) + widget.alpha(1.0 - self.mesh_transparency) self.plotter.at(1).add(widget) - self.plotter.at(1).render() return widget, False - def update_sensor( - self, widget: None, msgs: list[TopicMessage] - ) -> tuple[None, bool]: + def update_agent(self, widget: None, msgs: list[TopicMessage]) -> tuple[None, bool]: msgs_dict = {msg.name: msg.value for msg in msgs} episode_number = msgs_dict["episode_number"] step_number = msgs_dict["step_number"] @@ -458,8 +546,8 @@ def update_sensor( ) mapping = np.flatnonzero(steps_mask) - sensor_pos = self.data_parser.extract( - self._locators["sensor_location"], + agent_pos = self.data_parser.extract( + self._locators["agent_location"], episode=str(episode_number), sm_step=int(mapping[step_number]), ) @@ -470,17 +558,162 @@ def update_sensor( step=step_number, ) - if self.sensor_sphere is None: - self.sensor_sphere = Sphere(pos=sensor_pos, r=0.002) - self.plotter.at(1).add(self.sensor_sphere) - self.sensor_sphere.pos(sensor_pos) + if self.agent_sphere is None: + self.agent_sphere = Sphere( + pos=agent_pos, + r=0.004, + c=COLOR_PALETTE["Secondary"], + ) + + self.plotter.at(1).add(self.agent_sphere) + self.agent_sphere.pos(agent_pos) if self.gaze_line is None: - self.gaze_line = Line(sensor_pos, patch_pos, c="black", lw=2) + self.gaze_line = Line( + agent_pos, patch_pos, c=COLOR_PALETTE["Accent3"], lw=4 + ) self.plotter.at(1).add(self.gaze_line) - self.gaze_line.points = [sensor_pos, patch_pos] + self.gaze_line.points = [agent_pos, patch_pos] + + self._clear_all_paths() + key_event = msgs_dict.get("KeyPressEvent", None) + if key_event is not None and getattr(key_event, "at", None) == 1: + key = getattr(key_event, "keypress", None) + + if key == "a": + self.show_agent_past = not self.show_agent_past + elif key == "A": + self.show_agent_future = not self.show_agent_future + elif key == "s": + self.show_patch_past = not self.show_patch_past + elif key == "S": + self.show_patch_future = not self.show_patch_future + elif key == "d": + self.show_agent_past = False + self.show_agent_future = False + self.show_patch_past = False + self.show_patch_future = False + + # expire the event so it only affects this call + self.updaters[1].expire_topic("KeyPressEvent") + + max_idx = len(mapping) - 1 + curr_idx = int(np.clip(step_number, 0, max_idx)) + + if self.show_agent_past or self.show_agent_future: + self._rebuild_agent_paths(episode_number, mapping, curr_idx) + + if self.show_patch_past or self.show_patch_future: + self._rebuild_patch_paths(episode_number, len(mapping), curr_idx) + + return widget, False + + def _rebuild_agent_paths( + self, + episode_number: int, + mapping: np.ndarray, + curr_idx: int, + ) -> None: + """Rebuild past/future agent paths.""" + # Collect all agent positions + agent_positions: list[np.ndarray] = [] + for k in range(len(mapping)): + pos = self.data_parser.extract( + self._locators["agent_location"], + episode=str(episode_number), + sm_step=int(mapping[k]), + ) + agent_positions.append(pos) + + if self.show_agent_past and agent_positions: + past_pts = agent_positions[: curr_idx + 1] + for p in past_pts: + s = Sphere(pos=p, r=0.002, c=COLOR_PALETTE["Secondary"]) + self.plotter.at(1).add(s) + self.agent_past_spheres.append(s) + if len(past_pts) >= 2: + self.agent_past_line = Line( + past_pts, c=COLOR_PALETTE["Secondary"], lw=1 + ) + self.plotter.at(1).add(self.agent_past_line) + + if ( + self.show_agent_future + and agent_positions + and curr_idx < len(agent_positions) - 1 + ): + future_pts = agent_positions[curr_idx + 1 :] + for p in future_pts: + s = Sphere(pos=p, r=0.002, c=COLOR_PALETTE["Secondary"]) + self.plotter.at(1).add(s) + self.agent_future_spheres.append(s) + if len(future_pts) >= 2: + self.agent_future_line = Line( + future_pts, c=COLOR_PALETTE["Secondary"], lw=1 + ) + self.plotter.at(1).add(self.agent_future_line) + + def _rebuild_patch_paths( + self, + episode_number: int, + num_steps: int, + curr_idx: int, + ) -> None: + """Rebuild past/future patch (sensor) paths.""" + patch_positions: list[np.ndarray] = [] + for k in range(num_steps): + pos = self.data_parser.extract( + self._locators["patch_location"], + episode=str(episode_number), + step=k, + ) + patch_positions.append(pos) + + if self.show_patch_past and patch_positions: + past_pts = patch_positions[: curr_idx + 1] + for p in past_pts: + s = Sphere(pos=p, r=0.002, c=COLOR_PALETTE["Accent3"]) + self.plotter.at(1).add(s) + self.patch_past_spheres.append(s) + if len(past_pts) >= 2: + self.patch_past_line = Line(past_pts, c=COLOR_PALETTE["Accent3"], lw=1) + self.plotter.at(1).add(self.patch_past_line) + + if ( + self.show_patch_future + and patch_positions + and curr_idx < len(patch_positions) - 1 + ): + future_pts = patch_positions[curr_idx + 1 :] + for p in future_pts: + s = Sphere(pos=p, r=0.002, c=COLOR_PALETTE["Accent2"]) + self.plotter.at(1).add(s) + self.patch_future_spheres.append(s) + if len(future_pts) >= 2: + self.patch_future_line = Line( + future_pts, c=COLOR_PALETTE["Accent2"], lw=1 + ) + self.plotter.at(1).add(self.patch_future_line) + + def update_transparency( + self, widget: None, msgs: list[TopicMessage] + ) -> tuple[None, bool]: + msgs_dict = {msg.name: msg.value for msg in msgs} + + key_event = msgs_dict.get("KeyPressEvent", None) + if key_event is not None and getattr(key_event, "at", None) == 1: + key = getattr(key_event, "keypress", None) - self.plotter.at(1).render() + if key == "Left": + self.mesh_transparency -= 0.5 + elif key == "Right": + self.mesh_transparency += 0.5 + + self.mesh_transparency = float(np.clip(self.mesh_transparency, 0.0, 1.0)) + if widget is not None: + widget.alpha(1.0 - self.mesh_transparency) + + self.updaters[2].expire_topic("KeyPressEvent") return widget, False @@ -500,13 +733,13 @@ def __init__(self, plotter: Plotter): self.plotter = plotter self._add_kwargs = { - "pos": (0.85, 0.2), + "pos": (0.9, 0.2), "states": ["Primary Target"], - "c": "w", - "bc": "dg", - "size": 30, - "font": "Calco", - "bold": True, + "c": ["w"], + "bc": [COLOR_PALETTE["Primary"]], + "size": FONT_SIZE, + "font": FONT, + "bold": False, } def add(self, callback: Callable) -> Button: @@ -552,13 +785,13 @@ def __init__(self, plotter: Plotter): self.plotter = plotter self._add_kwargs = { - "pos": (0.83, 0.13), + "pos": (0.88, 0.13), "states": ["<"], "c": ["w"], - "bc": ["dg"], - "size": 30, - "font": "Calco", - "bold": True, + "bc": [COLOR_PALETTE["Primary"]], + "size": FONT_SIZE, + "font": FONT, + "bold": False, } def add(self, callback: Callable) -> Button: @@ -604,13 +837,13 @@ def __init__(self, plotter: Plotter): self.plotter = plotter self._add_kwargs = { - "pos": (0.88, 0.13), + "pos": (0.93, 0.13), "states": [">"], "c": ["w"], - "bc": ["dg"], - "size": 30, - "font": "Calco", - "bold": True, + "bc": [COLOR_PALETTE["Primary"]], + "size": FONT_SIZE, + "font": FONT, + "bold": False, } def add(self, callback: Callable) -> Button: @@ -657,6 +890,7 @@ def __init__(self, plotter: Plotter) -> None: "xmax": 10, "value": 0, "pos": [(0.05, 0.01), (0.05, 0.3)], + "font": FONT, "title": "Age", } @@ -705,6 +939,40 @@ def state_to_messages(self, state: int) -> Iterable[TopicMessage]: return [TopicMessage(name="age_threshold", value=state)] +class TopKSliderWidgetOps: + """WidgetOps implementation for the TopK slider. + + This widget provides a slider to control the number of top-k highlighted hypotheses. + It publishes on the topic `top_k` an int value between 0 and 5. + """ + + def __init__(self, plotter: Plotter) -> None: + self.plotter = plotter + + self._add_kwargs = { + "xmin": 0, + "xmax": 5, + "value": 0, + "pos": [(0.77, 0.01), (0.77, 0.3)], + "title": "Top-K Highlighted", + "font": FONT, + } + + def add(self, callback: Callable) -> Slider2D: + widget = self.plotter.at(0).add_slider(callback, **self._add_kwargs) + self.plotter.at(0).render() + return widget + + def extract_state(self, widget: Slider2D) -> int: + return extract_slider_state(widget) + + def set_state(self, widget: Slider2D, value: int) -> None: + set_slider_state(widget, value) + + def state_to_messages(self, state: int) -> Iterable[TopicMessage]: + return [TopicMessage(name="top_k", value=state)] + + class CurrentObjectWidgetOps: """Tracks and publishes the currently selected object label. @@ -987,7 +1255,6 @@ def on_left_click(self, event): cam.SetFocalPoint(self.cam_dict["focal_point"]) cam.SetViewUp((0, 1, 0)) cam.SetClippingRange((0.01, 1000.01)) - self.plotter.at(0).render() elif event.at == 1: cam_clicked = self.plotter.renderers[1].GetActiveCamera() cam_copy = self.plotter.renderers[2].GetActiveCamera() @@ -1027,6 +1294,7 @@ def __init__(self, plotter: Plotter, data_parser: DataParser) -> None: TopicSpec("step_number", required=True), TopicSpec("current_object", required=True), TopicSpec("age_threshold", required=True), + TopicSpec("top_k", required=True), ], callback=self.update_plot, ), @@ -1046,7 +1314,7 @@ def __init__(self, plotter: Plotter, data_parser: DataParser) -> None: self.df: DataFrame self.selected_hypothesis: Series | None = None self.highlight_circle: Circle | None = None - self.mlh_circle: Circle | None = None + self.mlh_circles: list[Circle] = [] self.info_widget: Text2D | None = None def create_locators(self) -> dict[str, DataLocator]: @@ -1119,6 +1387,41 @@ def increment_step(self, episode: int, step: int) -> tuple[int, int]: return episode, step + 1 return episode + 1, 0 + def decrement_step(self, episode: int, step: int) -> tuple[int, int]: + """Compute the previous `(episode, step)` pair. + + If the current pair is the very beginning `(0, 0)`, return `(0, 0)`. + + Args: + episode: Current episode index. + step: Current step index. + + Returns: + A tuple `(prev_episode, prev_step)`. + """ + # Already at earliest step + if episode == 0 and step == 0: + return 0, 0 + + # If we're not at the start of the episode, just decrement the step + if step > 0: + return episode, step - 1 + + # If step is 0, we need to go to previous episode + prev_episode = episode - 1 + + # Find last step index in previous episode + prev_last_step = ( + len( + self.data_parser.query( + self._locators["channel"], episode=str(prev_episode) + ) + ) + - 1 + ) + + return prev_episode, prev_last_step + def generate_df(self, episode: int, step: int, graph_id: str) -> DataFrame: """Build a DataFrame of hypotheses and their stats. @@ -1148,6 +1451,7 @@ def generate_df(self, episode: int, step: int, graph_id: str) -> DataFrame: all_dfs: list[DataFrame] = [] for input_channel in input_channels: + # Current timestep data channel_data = self.data_parser.extract( self._locators["channel"], episode=str(episode), @@ -1162,36 +1466,59 @@ def generate_df(self, episode: int, step: int, graph_id: str) -> DataFrame: obj=graph_id, channel=input_channel, ) - inc_episode, inc_step = self.increment_step(episode, step) - inc_updater_data = self.data_parser.extract( + + # Previous timestep data + dec_episode, dec_step = self.decrement_step(episode, step) + dec_channel_data = self.data_parser.extract( + self._locators["channel"], + episode=str(dec_episode), + step=dec_step, + obj=graph_id, + channel=input_channel, + ) + dec_updater_data = self.data_parser.extract( self._locators["updater"], - episode=str(inc_episode), - step=inc_step, + episode=str(dec_episode), + step=dec_step, obj=graph_id, channel=input_channel, ) # Removed hypotheses - removed_ids = inc_updater_data.get("removed_ids", []) + removed_ids = updater_data.get("removed_ids", []) if len(removed_ids) > 0: df_removed = DataFrame( { "id": removed_ids, + "episode": dec_episode, + "step": dec_step, "graph_id": graph_id, - "Evidence": np.array(channel_data["evidence"])[removed_ids], - "Evidence Slope": np.array(updater_data["evidence_slopes"])[ + "Evidence": np.array(dec_channel_data["evidence"])[removed_ids], + "Evidence Slope": np.array(dec_updater_data["evidence_slopes"])[ removed_ids ], - "Rot_x": np.array(channel_data["rotations"])[removed_ids][:, 0], - "Rot_y": np.array(channel_data["rotations"])[removed_ids][:, 1], - "Rot_z": np.array(channel_data["rotations"])[removed_ids][:, 2], - "Loc_x": np.array(channel_data["locations"])[removed_ids][:, 0], - "Loc_y": np.array(channel_data["locations"])[removed_ids][:, 1], - "Loc_z": np.array(channel_data["locations"])[removed_ids][:, 2], - "Pose Error": np.array(channel_data["pose_errors"])[ + "Rot_x": np.array(dec_channel_data["rotations"])[removed_ids][ + :, 0 + ], + "Rot_y": np.array(dec_channel_data["rotations"])[removed_ids][ + :, 1 + ], + "Rot_z": np.array(dec_channel_data["rotations"])[removed_ids][ + :, 2 + ], + "Loc_x": np.array(dec_channel_data["locations"])[removed_ids][ + :, 0 + ], + "Loc_y": np.array(dec_channel_data["locations"])[removed_ids][ + :, 1 + ], + "Loc_z": np.array(dec_channel_data["locations"])[removed_ids][ + :, 2 + ], + "Pose Error": np.array(dec_channel_data["pose_errors"])[ removed_ids ], - "age": np.array(updater_data["ages"])[removed_ids], + "age": np.array(dec_updater_data["ages"])[removed_ids], "kind": "Removed", "input_channel": input_channel, } @@ -1200,11 +1527,12 @@ def generate_df(self, episode: int, step: int, graph_id: str) -> DataFrame: # Added hypotheses added_ids = updater_data.get("added_ids", []) - added_ids = sorted(set(added_ids) - set(removed_ids)) if added_ids: df_added = DataFrame( { "id": added_ids, + "episode": episode, + "step": step, "graph_id": graph_id, "Evidence": np.array(channel_data["evidence"])[added_ids], "Evidence Slope": np.array(updater_data["evidence_slopes"])[ @@ -1226,11 +1554,13 @@ def generate_df(self, episode: int, step: int, graph_id: str) -> DataFrame: # Maintained hypotheses total_ids = list(range(len(updater_data["evidence_slopes"]))) - maintained_ids = sorted(set(total_ids) - set(added_ids) - set(removed_ids)) + maintained_ids = sorted(set(total_ids) - set(added_ids)) if maintained_ids: df_maintained = DataFrame( { "id": maintained_ids, + "episode": episode, + "step": step, "graph_id": graph_id, "Evidence": np.array(channel_data["evidence"])[maintained_ids], "Evidence Slope": np.array(updater_data["evidence_slopes"])[ @@ -1264,6 +1594,29 @@ def generate_df(self, episode: int, step: int, graph_id: str) -> DataFrame: ) all_dfs.append(df_maintained) + if not all_dfs: + # No hypotheses for any input_channel at this (episode, step, graph_id) + return DataFrame( + columns=[ + "id", + "episode", + "step", + "graph_id", + "Evidence", + "Evidence Slope", + "Rot_x", + "Rot_y", + "Rot_z", + "Loc_x", + "Loc_y", + "Loc_z", + "Pose Error", + "age", + "kind", + "input_channel", + ] + ) + return pd.concat(all_dfs, ignore_index=True) def add_correlation_figure(self, x="Evidence Slope", y="Pose Error") -> Image: @@ -1277,53 +1630,58 @@ def add_correlation_figure(self, x="Evidence Slope", y="Pose Error") -> Image: The Image widget for the correlation plot. """ g = sns.JointGrid(data=self.df, x=x, y=y, height=6) + g.figure.set_dpi(400) + + if not self.df.empty: + sns.scatterplot( + data=self.df, + x=x, + y=y, + hue="kind", + ax=g.ax_joint, + s=8, + alpha=0.8, + palette=COLOR_PALETTE, + ) + + sns.kdeplot( + data=self.df, + x=x, + hue="kind", + ax=g.ax_marg_x, + fill=True, + alpha=0.2, + common_norm=False, + palette=COLOR_PALETTE, + legend=False, + warn_singular=False, + ) + + sns.kdeplot( + data=self.df, + y=y, + hue="kind", + ax=g.ax_marg_y, + fill=True, + alpha=0.2, + common_norm=False, + palette=COLOR_PALETTE, + legend=False, + warn_singular=False, + ) - sns.scatterplot( - data=self.df, - x=x, - y=y, - hue="kind", - ax=g.ax_joint, - s=8, - alpha=0.8, - palette=HUE_PALETTE, - ) - - sns.kdeplot( - data=self.df, - x=x, - hue="kind", - ax=g.ax_marg_x, - fill=True, - alpha=0.2, - common_norm=False, - palette=HUE_PALETTE, - legend=False, - ) - - sns.kdeplot( - data=self.df, - y=y, - hue="kind", - ax=g.ax_marg_y, - fill=True, - alpha=0.2, - common_norm=False, - palette=HUE_PALETTE, - legend=False, - ) - - legend = g.ax_joint.get_legend() - if legend: - legend.set_title(None) + legend = g.ax_joint.get_legend() + if legend: + legend.set_title(None) g.ax_joint.set_xlim(-2.0, 2.0) g.ax_joint.set_ylim(0, 3.25) - g.ax_joint.set_xlabel(x, labelpad=10) + g.ax_joint.set_xlabel("Recent Evidence Change", labelpad=10) g.ax_joint.set_ylabel(y, labelpad=10) g.figure.tight_layout() widget = Image(g.figure) + widget.scale(0.25) plt.close(g.figure) self.plotter.at(0).add(widget) return widget @@ -1351,57 +1709,76 @@ def get_closest_row(self, df: DataFrame, slope: float, error: float) -> Series: ) return df.loc[distances.idxmin()] - def add_info_text(self) -> None: + def add_info_text(self, obj) -> None: """Summarize hypotheses statistics from a dataframe and add to plot.""" if self.info_widget is not None: self.plotter.at(0).remove(self.info_widget) if self.df.empty: - return - - # Assume all rows share the same object name - graph_id = self.df["graph_id"].iloc[0] + text = ( + f"Object: {obj}\n" + f"Total Existing Hypotheses: 0\n" + f"Added Hypotheses: 0\n" + f"Removed Hypotheses: 0" + ) + else: + # Assume all rows share the same object name + graph_id = self.df["graph_id"].iloc[0] - # Count per kind - kind_counts = self.df["kind"].value_counts() + # Count per kind + kind_counts = self.df["kind"].value_counts() - total = len(self.df) - added = kind_counts.get("Added", 0) - removed = kind_counts.get("Removed", 0) + added = kind_counts.get("Added", 0) + removed = kind_counts.get("Removed", 0) + total = len(self.df) - removed - text = ( - f"Object: {graph_id}\n" - f"Total Existing Hypotheses: {total}\n" - f"Added Hypotheses: {added}\n" - f"To be removed Hypotheses: {removed}" - ) + text = ( + f"Object: {graph_id}\n" + f"Total Existing Hypotheses: {total}\n" + f"Added Hypotheses: {added}\n" + f"Removed Hypotheses: {removed}" + ) - self.info_widget = Text2D(txt=text, pos="top-left") + self.info_widget = Text2D(txt=text, pos="top-left", font=FONT) self.plotter.at(0).add(self.info_widget) - def add_mlh_circle(self): - """Adds the circle marker for the MLH.""" - if self.mlh_circle is not None: - self.plotter.at(0).remove(self.mlh_circle) + def add_mlh_circles(self, top_k: int) -> None: + """Adds the circle markers for the MLH.""" + for c in self.mlh_circles: + self.plotter.at(0).remove(c) + self.mlh_circles.clear() - if self.df.empty: + if self.df is None or self.df.empty: return - (slope, error) = tuple( - self.df.loc[self.df["Evidence"].idxmax(), ["Evidence Slope", "Pose Error"]] - ) + df_valid = self.df[self.df["kind"] != "Removed"].copy() + if df_valid.empty: + return + + df_valid.sort_values("Evidence", ascending=False, inplace=True) - if pd.isna(slope): + # Clamp to [0, len(df_valid)] + k = int(max(0, min(top_k, len(df_valid)))) + if k == 0: return - # Map location back to a Location3D in GUI Space - gui_location = self._coordinate_mapper.map_data_coords_to_world( - Location2D(slope, error) - ).to_3d(z=0.05) + top_rows = df_valid.head(k) + + for _, row in top_rows.iterrows(): + slope = row["Evidence Slope"] + error = row["Pose Error"] + + if pd.isna(slope) or pd.isna(error): + continue - self.mlh_circle = Circle(pos=gui_location.to_numpy(), r=3.0, res=16) - self.mlh_circle.c("green") - self.plotter.at(0).add(self.mlh_circle) + gui_location = self._coordinate_mapper.map_data_coords_to_world( + Location2D(float(slope), float(error)) + ).to_3d(z=0.05) + + circle = Circle(pos=gui_location.to_numpy(), r=3.0, res=16) + circle.c(COLOR_PALETTE["Highlighted"]) + self.plotter.at(0).add(circle) + self.mlh_circles.append(circle) def add_highlight_circle(self, gui_location: Location3D): """Adds the circle marker for the selected hypothesis. @@ -1413,7 +1790,7 @@ def add_highlight_circle(self, gui_location: Location3D): self.plotter.at(0).remove(self.highlight_circle) self.highlight_circle = Circle(pos=gui_location.to_numpy(), r=3.0, res=16) - self.highlight_circle.c("red") + self.highlight_circle.c(COLOR_PALETTE["Selected"]) self.plotter.at(0).add(self.highlight_circle) def update_plot(self, widget: Image, msgs: list[TopicMessage]) -> tuple[Any, bool]: @@ -1451,12 +1828,11 @@ def update_plot(self, widget: Image, msgs: list[TopicMessage]) -> tuple[Any, boo widget = self.add_correlation_figure() # Add info text to scene - self.add_info_text() + self.add_info_text(obj=msgs_dict["current_object"]) # Add mlh circle to scene - self.add_mlh_circle() + self.add_mlh_circles(msgs_dict["top_k"]) - self.plotter.at(0).render() return widget, True def update_selection( @@ -1505,7 +1881,6 @@ def update_selection( # Add the selected hypothesis marker self.add_highlight_circle(gui_location) - self.plotter.at(0).render() return widget, True @@ -1520,13 +1895,17 @@ class HypothesisMeshWidgetOps: Attributes: plotter: A `vedo.Plotter` used to add and remove actors. + data_parser: Parser that extracts entries from the JSON log. ycb_loader: Loader that returns a textured `vedo.Mesh` for a YCB object. updaters: Two `WidgetUpdater`s for clear and update actions. info_widget: The text panel shown alongside the mesh. """ - def __init__(self, plotter: Plotter, ycb_loader: YCBMeshLoader) -> None: + def __init__( + self, plotter: Plotter, data_parser: DataParser, ycb_loader: YCBMeshLoader + ) -> None: self.plotter = plotter + self.data_parser = data_parser self.ycb_loader = ycb_loader self.updaters = [ WidgetUpdater( @@ -1534,24 +1913,210 @@ def __init__(self, plotter: Plotter, ycb_loader: YCBMeshLoader) -> None: callback=self.clear_mesh, ), WidgetUpdater( - topics=[TopicSpec("selected_hypothesis", required=True)], + topics=[ + TopicSpec("selected_hypothesis", required=True), + ], callback=self.update_mesh, ), + WidgetUpdater( + topics=[ + EventSpec("KeyPressed", "KeyPressEvent", required=True), + ], + callback=self.update_transparency, + ), + WidgetUpdater( + topics=[ + TopicSpec("selected_hypothesis", required=True), + EventSpec("KeyPressed", "KeyPressEvent", required=False), + ], + callback=self.update_paths, + ), ] + self.mesh_transparency: float = 0.0 self.default_object_position = (0, 1.5, 0) self.sensor_sphere: Sphere | None = None + self.text_label: Text2D = Text2D( + txt="Selected Hypothesis", pos="top-center", font=FONT + ) + + # Path visibility states + self.show_past_path: bool = False + self.show_future_path: bool = False + + self.past_path_spheres: list[Sphere] = [] + self.future_path_spheres: list[Sphere] = [] + self.past_path_line: Line | None = None + self.future_path_line: Line | None = None + + self._locators = self.create_locators() + + self.plotter.at(2).add(self.text_label) + + def create_locators(self) -> dict[str, DataLocator]: + """Returns data locators needed to trace the hypothesis.""" + locators: dict[str, DataLocator] = {} + + locators["episode"] = DataLocator(path=[DataLocatorStep.key(name="episode")]) + + locators["step"] = locators["episode"].extend( + steps=[ + DataLocatorStep.key(name="lm", value="LM_0"), + DataLocatorStep.key( + name="telemetry", value="hypotheses_updater_telemetry" + ), + DataLocatorStep.index(name="step"), + ] + ) + + locators["channel"] = locators["step"].extend( + steps=[ + DataLocatorStep.key(name="obj"), + DataLocatorStep.key(name="channel", value="patch"), + ] + ) - self.plotter.at(2).add(Text2D(txt="Selected Hypothesis", pos="top-center")) + locators["updater"] = locators["channel"].extend( + steps=[DataLocatorStep.key(name="stat", value="hypotheses_updater")] + ) + + return locators + + def _extract_channel_data( + self, episode: str, step: int, obj: str + ) -> dict[str, Iterable]: + channel_data = self.data_parser.extract( + self._locators["channel"], episode=episode, step=step, obj=obj + ) + updater_data = self.data_parser.extract( + self._locators["updater"], episode=episode, step=step, obj=obj + ) + return { + "locations": channel_data["locations"], + "evidence": channel_data["evidence"], + "evidence_slopes": updater_data["evidence_slopes"], + } + + def _extract_ids_at_step( + self, episode: str, step: int, obj: str + ) -> tuple[Iterable[int], Iterable[int]]: + """Return added and removed hypothesis ids.""" + updater_data = self.data_parser.extract( + self._locators["updater"], episode=episode, step=step, obj=obj + ) + return updater_data["added_ids"], updater_data["removed_ids"] + + def _num_steps_in_episode(self, episode: str) -> int: + return len(self.data_parser.query(self._locators["step"], episode=episode)) + + def _num_episodes(self) -> int: + return len(self.data_parser.query(self._locators["episode"])) + + def _increment_location_pair( + self, episode: str, step: int + ) -> tuple[str | None, int | None]: + ep = int(episode) + num_episodes = self._num_episodes() + steps_here = self._num_steps_in_episode(episode) + + if step < steps_here - 1: + return episode, step + 1 + + if ep < num_episodes - 1: + return str(ep + 1), 0 + + return None, None + + def _decrement_location_pair( + self, episode: str, step: int + ) -> tuple[str | None, int | None]: + ep = int(episode) + if step > 0: + return episode, step - 1 + + if ep > 0: + prev_ep = str(ep - 1) + return prev_ep, self._num_steps_in_episode(prev_ep) - 1 + + return None, None + + def _trace_hypothesis_positions( + self, episode: str, step: int, obj: str, ix: int + ) -> DataFrame: + # Current row + row_data = self._extract_channel_data(episode, step, obj) + rows: list[dict] = [ + { + "Episode": int(episode), + "Step": int(step), + "Loc_x": row_data["locations"][ix][0], + "Loc_y": row_data["locations"][ix][1], + "Loc_z": row_data["locations"][ix][2], + } + ] + + # Backward + rows_back: list[dict] = [] + episode_b, step_b, idx_b = episode, step, ix + while True: + added_ids, removed_ids = self._extract_ids_at_step(episode_b, step_b, obj) + idx_prev = trace_hypothesis_backward( + idx_b, removed_ids=sorted(removed_ids), added_ids=sorted(added_ids) + ) + if idx_prev is None: + break + + episode_prev, step_prev = self._decrement_location_pair(episode_b, step_b) + if episode_prev is None or step_prev is None: + break + + episode_b, step_b, idx_b = episode_prev, step_prev, idx_prev + row_data = self._extract_channel_data(episode_b, step_b, obj) + rows_back.append( + { + "Episode": int(episode_b), + "Step": int(step_b), + "Loc_x": row_data["locations"][idx_b][0], + "Loc_y": row_data["locations"][idx_b][1], + "Loc_z": row_data["locations"][idx_b][2], + } + ) + + # Forward + rows_forward: list[dict] = [] + episode_f, step_f, idx_f = episode, step, ix + while True: + episode_next, step_next = self._increment_location_pair(episode_f, step_f) + if episode_next is None or step_next is None: + break + + _, removed_ids = self._extract_ids_at_step(episode_next, step_next, obj) + idx_next = trace_hypothesis_forward(idx_f, removed_ids=sorted(removed_ids)) + if idx_next is None: + break + + episode_f, step_f, idx_f = episode_next, step_next, idx_next + row_data = self._extract_channel_data(episode_f, step_f, obj) + rows_forward.append( + { + "Episode": int(episode_f), + "Step": int(step_f), + "Loc_x": row_data["locations"][idx_f][0], + "Loc_y": row_data["locations"][idx_f][1], + "Loc_z": row_data["locations"][idx_f][2], + } + ) + + rows_back.reverse() + all_rows = rows_back + rows + rows_forward + return DataFrame( + all_rows, columns=["Episode", "Step", "Loc_x", "Loc_y", "Loc_z"] + ) def clear_mesh( self, widget: Mesh | None, msgs: list[TopicMessage] ) -> tuple[Any, bool]: """Clear the mesh and info panel if present. - Args: - widget: Current mesh object, if any. - msgs: Unused. Present for the updater interface. - Returns: `(widget, False)` to indicate no publish should occur. """ @@ -1562,7 +2127,9 @@ def clear_mesh( self.plotter.at(2).remove(self.sensor_sphere) self.sensor_sphere = None - self.plotter.at(2).render() + self._clear_paths() + + self.updaters[3].expire_topic("selected_hypothesis") return widget, False def update_mesh( @@ -1593,15 +2160,132 @@ def update_mesh( widget.rotate_y(hypothesis["Rot_y"]) widget.rotate_z(hypothesis["Rot_z"]) widget.shift(self.default_object_position) + widget.alpha(1.0 - self.mesh_transparency) self.plotter.at(2).add(widget) # Add sphere for sensor's hypothesized location sensor_pos = (hypothesis["Loc_x"], hypothesis["Loc_y"], hypothesis["Loc_z"]) - self.sensor_sphere = Sphere(pos=sensor_pos, r=0.002).c("green") - self.sensor_sphere.pos(sensor_pos) + self.sensor_sphere = Sphere(pos=sensor_pos, r=0.003).c(COLOR_PALETTE["Primary"]) self.plotter.at(2).add(self.sensor_sphere) - self.plotter.at(2).render() + self.updaters[1].expire_topic("selected_hypothesis") + + return widget, False + + def update_transparency( + self, widget: None, msgs: list[TopicMessage] + ) -> tuple[None, bool]: + msgs_dict = {msg.name: msg.value for msg in msgs} + + key_event = msgs_dict.get("KeyPressEvent", None) + if key_event is not None and getattr(key_event, "at", None) == 2: + key = getattr(key_event, "keypress", None) + + if key == "Left": + self.mesh_transparency -= 0.5 + elif key == "Right": + self.mesh_transparency += 0.5 + + self.mesh_transparency = float(np.clip(self.mesh_transparency, 0.0, 1.0)) + if widget is not None: + widget.alpha(1.0 - self.mesh_transparency) + + self.updaters[2].expire_topic("KeyPressEvent") + + return widget, False + + def _clear_paths(self) -> None: + for s in self.past_path_spheres: + self.plotter.at(2).remove(s) + for s in self.future_path_spheres: + self.plotter.at(2).remove(s) + + self.past_path_spheres.clear() + self.future_path_spheres.clear() + + if self.past_path_line is not None: + self.plotter.at(2).remove(self.past_path_line) + self.past_path_line = None + if self.future_path_line is not None: + self.plotter.at(2).remove(self.future_path_line) + self.future_path_line = None + + def _rebuild_paths( + self, + episode: str, + step: int, + hyp: Series, + ) -> None: + self._clear_paths() + + if not (self.show_past_path or self.show_future_path): + return + + df = self._trace_hypothesis_positions( + episode=episode, + step=step, + obj=hyp["graph_id"], + ix=int(hyp["id"]), + ) + if df.empty: + return + + df = df.reset_index(drop=True) + mask_current = (df["Episode"] == int(episode)) & (df["Step"] == step) + idx_list = df.index[mask_current].tolist() + current_idx = idx_list[0] + + if self.show_past_path: + past_pts = df.loc[:current_idx, ["Loc_x", "Loc_y", "Loc_z"]].to_numpy() + self._build_path_geometry(past_pts, past=True) + + if self.show_future_path and current_idx < len(df) - 1: + future_pts = df.loc[ + current_idx + 1 :, ["Loc_x", "Loc_y", "Loc_z"] + ].to_numpy() + self._build_path_geometry(future_pts, past=False) + + def _build_path_geometry(self, points: np.ndarray, past: bool) -> None: + if points.size == 0: + return + + spheres_list = self.past_path_spheres if past else self.future_path_spheres + line_attr = "past_path_line" if past else "future_path_line" + color = COLOR_PALETTE["Primary"] if past else COLOR_PALETTE["Accent2"] + + for p in points: + s = Sphere(pos=p, r=0.002, c=color) + self.plotter.at(2).add(s) + spheres_list.append(s) + + if len(points) >= 2: + line = Line(points, c=color, lw=1) + setattr(self, line_attr, line) + self.plotter.at(2).add(line) + + def update_paths( + self, widget: Mesh | None, msgs: list[TopicMessage] + ) -> tuple[Mesh | None, bool]: + msgs_dict = {msg.name: msg.value for msg in msgs} + + hyp: Series = msgs_dict["selected_hypothesis"] + episode = str(hyp["episode"]) + step = int(hyp["step"]) + + key_event = msgs_dict.get("KeyPressEvent", None) + if key_event is not None and getattr(key_event, "at", None) == 2: + key = getattr(key_event, "keypress", None) + if key == "s": + self.show_past_path = not self.show_past_path + elif key == "S": + self.show_future_path = not self.show_future_path + elif key == "d": + self.show_past_path = False + self.show_future_path = False + + self._rebuild_paths(episode=episode, step=step, hyp=hyp) + self.updaters[3].expire_topic("KeyPressEvent") + return widget, False @@ -1735,12 +2419,29 @@ def add_hyp_space_size_figure( fig, ax = plt.subplots(figsize=(6, 3)) sns.lineplot( - ax=ax, data=merged, x="step", y="idx_current", label=str(current_object) + ax=ax, + data=merged, + x="step", + y="idx_current", + label=str(current_object), + color=COLOR_PALETTE["Primary"], ) sns.lineplot( - ax=ax, data=merged, x="step", y="idx_others", label="Other objects" + ax=ax, + data=merged, + x="step", + y="idx_others", + label="other objects", + color=COLOR_PALETTE["Secondary"], + ) + sns.lineplot( + ax=ax, + data=merged, + x="step", + y="idx_total", + label="total", + color=COLOR_PALETTE["Accent3"], ) - sns.lineplot(ax=ax, data=merged, x="step", y="idx_total", label="Total") ax.set_xlabel("Step") ax.set_ylabel("% change from step 0") @@ -1776,7 +2477,6 @@ def update_plot( widget = self.add_hyp_space_size_figure( hyp_size_df, msgs_dict["current_object"] ) - self.plotter.at(0).render() return widget, False @@ -1865,7 +2565,7 @@ def _extract_channel_data( "evidence_slopes": updater_data["evidence_slopes"], } - def _extract_mod_ids( + def _extract_ids_at_step( self, episode: str, step: int, obj: str ) -> tuple[Iterable[int], Iterable[int]]: """Returns the added and removed ids.""" @@ -1981,7 +2681,7 @@ def _trace_hypothesis( rows_back: list[dict] = [] episode_b, step_b, idx_b = episode, step, ix while True: - added_ids, removed_ids = self._extract_mod_ids(episode_b, step_b, obj) + added_ids, removed_ids = self._extract_ids_at_step(episode_b, step_b, obj) idx_prev = trace_hypothesis_backward( idx_b, removed_ids=sorted(removed_ids), added_ids=sorted(added_ids) ) @@ -2018,7 +2718,7 @@ def _trace_hypothesis( if episode_next is None or step_next is None: break - _, removed_ids = self._extract_mod_ids(episode_next, step_next, obj) + _, removed_ids = self._extract_ids_at_step(episode_next, step_next, obj) idx_next = trace_hypothesis_forward(idx_f, removed_ids=sorted(removed_ids)) # Hypothesis is deleted @@ -2086,10 +2786,10 @@ def _add_lifespan_figure( marker="o", markersize=4, linewidth=1.2, - color=HUE_PALETTE["Evidence"], + color=COLOR_PALETTE["Primary"], label="Evidence", ) - ax1.set_xlabel("Episode / Step") + ax1.set_xlabel("Time") ax1.set_ylabel("Evidence") # Evidence slopes plot on right axis @@ -2102,13 +2802,15 @@ def _add_lifespan_figure( marker="o", markersize=4, linewidth=1.2, - color=HUE_PALETTE["Slope"], + color=COLOR_PALETTE["Secondary"], label="Evidence Slope", ) - ax2.set_ylabel("Evidence Slope") + ax2.set_ylabel("Recent Evidence Change") # Setting ticks on x-axis x_min, x_max = df["x"].min(), df["x"].max() + x_min = min(x_min, x_current) + x_max = max(x_max, x_current) ax1.set_xlim(x_min - 0.5, x_max + 0.5) major_locs_all = [ (ep, episode_offsets[ep]) for ep in range(start_episode, end_episode + 1) @@ -2130,17 +2832,20 @@ def _add_lifespan_figure( ax1.axvline(x=x_current, linestyle="--", linewidth=1.0, color="0.2") # Legend for both axes + label_renames = { + "Evidence": "Evidence", + "Evidence Slope": "Recent Evidence Change", + } lines, labels = [], [] for ax in (ax1, ax2): line, label = ax.get_legend_handles_labels() if line: - lines += line - labels += label + lines.append(line[0]) + labels.append(label_renames.get(label[0], label)) ax.legend_.remove() if lines: ax1.legend(lines, labels, loc="best", frameon=True) - ax1.set_title("Hypothesis Lifespan") fig.tight_layout() widget = Image(fig) @@ -2154,10 +2859,10 @@ def _add_info_text(self, hyp: Series): info = ( f"Age: {hyp['age']}\n" + f"Evidence: {hyp['Evidence']:.2f}\n" - + f"Evidence Slope: {hyp['Evidence Slope']:.2f}\n" + + f"Recent Evidence Change: {hyp['Evidence Slope']:.2f}\n" + f"Pose Error: {hyp['Pose Error']:.2f}" ) - self.info_widget = Text2D(txt=info, pos="top-right") + self.info_widget = Text2D(txt=info, pos="top-right", font=FONT) self.plotter.at(0).add(self.info_widget) def update_plot( @@ -2179,15 +2884,20 @@ def update_plot( step = int(msgs_dict["step_number"]) hyp = msgs_dict["selected_hypothesis"] - df = self._trace_hypothesis(episode, step, hyp["graph_id"], hyp["id"]) + df = self._trace_hypothesis( + str(hyp["episode"]), + int(hyp["step"]), + hyp["graph_id"], + hyp["id"], + ) + widget = self._add_lifespan_figure( df, current_episode=int(episode), current_step=step ) self._add_info_text(hyp) - self.plotter.at(0).render() - + self.updaters[0].expire_topic("selected_hypothesis") return widget, False def clear_plot( @@ -2209,8 +2919,6 @@ def clear_plot( self.plotter.at(0).remove(self.info_widget) self.info_widget = None - self.plotter.at(0).render() - return None, False @@ -2269,6 +2977,10 @@ def __init__( w.add() self._widgets["episode_slider"].set_state(0) self._widgets["age_threshold"].set_state(0) + self._widgets["topk_slider"].set_state(0) + + self.scope_viewer = ScopeViewer(self.plotter, self._widgets) + self.plotter.add_callback("KeyPress", self._on_keypress_quit) self.plotter.at(0).show( camera=deepcopy(self.cam_dict), @@ -2297,6 +3009,7 @@ def create_widgets(self): plotter=self.plotter, data_parser=self.data_parser, ), + scopes=[1, 2, 3], bus=self.event_bus, scheduler=self.scheduler, debounce_sec=0.5, @@ -2308,6 +3021,7 @@ def create_widgets(self): plotter=self.plotter, data_parser=self.data_parser, ), + scopes=[1, 2, 3], bus=self.event_bus, scheduler=self.scheduler, debounce_sec=0.5, @@ -2320,6 +3034,7 @@ def create_widgets(self): data_parser=self.data_parser, ycb_loader=self.ycb_loader, ), + scopes=[1], bus=self.event_bus, scheduler=self.scheduler, debounce_sec=0.5, @@ -2328,6 +3043,7 @@ def create_widgets(self): widgets["primary_button"] = Widget[Button, str]( widget_ops=PrimaryButtonWidgetOps(plotter=self.plotter), + scopes=[2, 3], bus=self.event_bus, scheduler=self.scheduler, debounce_sec=0.2, @@ -2336,6 +3052,7 @@ def create_widgets(self): widgets["prev_button"] = Widget[Button, str]( widget_ops=PrevButtonWidgetOps(plotter=self.plotter), + scopes=[2, 3], bus=self.event_bus, scheduler=self.scheduler, debounce_sec=0.2, @@ -2344,6 +3061,7 @@ def create_widgets(self): widgets["next_button"] = Widget[Button, str]( widget_ops=NextButtonWidgetOps(plotter=self.plotter), + scopes=[2, 3], bus=self.event_bus, scheduler=self.scheduler, debounce_sec=0.2, @@ -2352,14 +3070,27 @@ def create_widgets(self): widgets["age_threshold"] = Widget[Slider2D, int]( widget_ops=AgeThresholdWidgetOps(plotter=self.plotter), + scopes=[2], bus=self.event_bus, scheduler=self.scheduler, debounce_sec=0.5, dedupe=True, ) + widgets["topk_slider"] = Widget[Slider2D, int]( + widget_ops=TopKSliderWidgetOps( + plotter=self.plotter, + ), + scopes=[2], + bus=self.event_bus, + scheduler=self.scheduler, + debounce_sec=0.1, + dedupe=True, + ) + widgets["current_object"] = Widget[None, str]( widget_ops=CurrentObjectWidgetOps(data_parser=self.data_parser), + scopes=[2, 3], bus=self.event_bus, scheduler=self.scheduler, debounce_sec=0.2, @@ -2370,6 +3101,7 @@ def create_widgets(self): widget_ops=ClickWidgetOps( plotter=self.plotter, cam_dict=deepcopy(self.cam_dict) ), + scopes=[1, 2, 3], bus=self.event_bus, scheduler=self.scheduler, debounce_sec=0.1, @@ -2380,6 +3112,7 @@ def create_widgets(self): widget_ops=CorrelationPlotWidgetOps( plotter=self.plotter, data_parser=self.data_parser ), + scopes=[2], bus=self.event_bus, scheduler=self.scheduler, debounce_sec=0.3, @@ -2389,8 +3122,10 @@ def create_widgets(self): widgets["hypothesis_mesh"] = Widget[Mesh, None]( widget_ops=HypothesisMeshWidgetOps( plotter=self.plotter, + data_parser=self.data_parser, ycb_loader=self.ycb_loader, ), + scopes=[3], bus=self.event_bus, scheduler=self.scheduler, debounce_sec=0.0, @@ -2402,6 +3137,7 @@ def create_widgets(self): plotter=self.plotter, data_parser=self.data_parser, ), + scopes=[2], bus=self.event_bus, scheduler=self.scheduler, debounce_sec=0.0, @@ -2413,6 +3149,7 @@ def create_widgets(self): plotter=self.plotter, data_parser=self.data_parser, ), + scopes=[3], bus=self.event_bus, scheduler=self.scheduler, debounce_sec=0.0, @@ -2421,6 +3158,11 @@ def create_widgets(self): return widgets + def _on_keypress_quit(self, event): + key = getattr(event, "keypress", None) + if key is not None and key.lower() == "q": + self.plotter.interactor.ExitCallback() + @register( "interactive_hypothesis_space_correlation", diff --git a/src/tbp/plot/plots/interactive_hypothesis_space_pointcloud.py b/src/tbp/plot/plots/interactive_hypothesis_space_pointcloud.py index ec209c4..a0a9892 100644 --- a/src/tbp/plot/plots/interactive_hypothesis_space_pointcloud.py +++ b/src/tbp/plot/plots/interactive_hypothesis_space_pointcloud.py @@ -57,6 +57,20 @@ FONT_SIZE = 25 +COLOR_PALETTE = { + "Blue": Palette.as_hex("numenta_blue"), + "Pink": Palette.as_hex("pink"), + "Purple": Palette.as_hex("purple"), + "Gold": Palette.as_hex("gold"), + "Green": Palette.as_hex("green"), + "Primary": Palette.as_hex("numenta_blue"), + "Secondary": Palette.as_hex("purple"), + "Accent": Palette.as_hex("charcoal"), + "Accent2": Palette.as_hex("link_water"), + "Accent3": Palette.as_hex("rich_black"), +} + + class StepMapper: """Bidirectional mapping between global step indices and (episode, local_step). @@ -446,14 +460,14 @@ def update_agent( if self.agent_sphere is None: self.agent_sphere = Sphere( - pos=agent_pos, r=0.004, c=Palette.as_hex("vivid_violet") + pos=agent_pos, r=0.004, c=COLOR_PALETTE["Secondary"] ) self.plotter.at(1).add(self.agent_sphere) self.agent_sphere.pos(agent_pos) if self.gaze_line is None: self.gaze_line = Line( - agent_pos, patch_pos, c=Palette.as_hex("rich_black"), lw=4 + agent_pos, patch_pos, c=COLOR_PALETTE["Accent3"], lw=4 ) self.plotter.at(1).add(self.gaze_line) self.gaze_line.points = [agent_pos, patch_pos] @@ -506,15 +520,13 @@ def update_agent_path( # Create small spheres at each position for p in points: - sphere = Sphere(pos=p, r=0.002, c=Palette.as_hex("vivid_violet")) + sphere = Sphere(pos=p, r=0.002, c=COLOR_PALETTE["Secondary"]) self.plotter.at(1).add(sphere) self.agent_path_spheres.append(sphere) # Create a polyline connecting all points if len(points) >= 2: - self.agent_path_line = Line( - points, c=Palette.as_hex("vivid_violet"), lw=1 - ) + self.agent_path_line = Line(points, c=COLOR_PALETTE["Secondary"], lw=1) self.plotter.at(1).add(self.agent_path_line) self.plotter.at(1).render() @@ -554,15 +566,13 @@ def update_patch_path( # Create small black spheres at each patch position for p in points: - sphere = Sphere(pos=p, r=0.002, c=Palette.as_hex("rich_black")) + sphere = Sphere(pos=p, r=0.002, c=COLOR_PALETTE["Accent3"]) self.plotter.at(1).add(sphere) self.patch_path_spheres.append(sphere) # Create a thin black polyline connecting all patch positions if len(points) >= 2: - self.patch_path_line = Line( - points, c=Palette.as_hex("rich_black"), lw=1 - ) + self.patch_path_line = Line(points, c=COLOR_PALETTE["Accent3"], lw=1) self.plotter.at(1).add(self.patch_path_line) self.plotter.at(1).render() @@ -632,7 +642,7 @@ def __init__(self, plotter: Plotter): "pos": (0.16, 0.98), "states": ["Agent Path: On", "Agent Path: Off"], "c": ["w", "w"], - "bc": [Palette.as_hex("numenta_blue"), Palette.as_hex("vivid_violet")], + "bc": [COLOR_PALETTE["Primary"], COLOR_PALETTE["Secondary"]], "size": FONT_SIZE, "font": FONT, "bold": False, @@ -671,7 +681,7 @@ def __init__(self, plotter: Plotter): "pos": (0.37, 0.98), "states": ["Patch Path: On", "Patch Path: Off"], "c": ["w", "w"], - "bc": [Palette.as_hex("numenta_blue"), Palette.as_hex("vivid_violet")], + "bc": [COLOR_PALETTE["Primary"], COLOR_PALETTE["Secondary"]], "size": FONT_SIZE, "font": FONT, "bold": False, @@ -840,7 +850,7 @@ def _create_hyp_space( evidences, locations, pose_errors, ages, slopes = self._extract_obj_telemetry( episode_number, step_number, curr_object ) - pts = Points(np.array(locations), r=6, c=Palette.as_hex("vivid_violet")) + pts = Points(np.array(locations), r=6, c=COLOR_PALETTE["Secondary"]) if hyp_color_button == "Evidence": pts.cmap("viridis", evidences, vmin=0.0) @@ -852,7 +862,7 @@ def _create_hyp_space( mlh_sphere = Sphere( pos=locations[int(np.argmax(evidences))], r=0.002, - c=Palette.as_hex("numenta_blue"), + c=COLOR_PALETTE["Primary"], ) self.mlh_sphere = mlh_sphere self.plotter.at(2).add(mlh_sphere) @@ -903,7 +913,7 @@ def __init__(self, plotter: Plotter): "pos": (0.63, 0.98), "states": ["Pretrained Model: On", "Pretrained Model: Off"], "c": ["w", "w"], - "bc": [Palette.as_hex("numenta_blue"), Palette.as_hex("vivid_violet")], + "bc": [COLOR_PALETTE["Primary"], COLOR_PALETTE["Secondary"]], "size": FONT_SIZE, "font": FONT, "bold": False, @@ -937,7 +947,7 @@ def __init__(self, plotter: Plotter): "pos": (0.85, 0.98), "states": ["Hypotheses: On", "Hypotheses: Off"], "c": ["w", "w"], - "bc": [Palette.as_hex("numenta_blue"), Palette.as_hex("vivid_violet")], + "bc": [COLOR_PALETTE["Primary"], COLOR_PALETTE["Secondary"]], "size": FONT_SIZE, "font": FONT, "bold": False, @@ -977,12 +987,12 @@ def __init__(self, plotter: Plotter): "states": ["None", "Evidence", "MLH", "Pose Error", "Slope", "Ages"], "c": ["w", "w", "w", "w", "w", "w"], "bc": [ - Palette.as_hex("link_water"), - Palette.as_hex("numenta_blue"), - Palette.as_hex("vivid_violet"), - Palette.as_hex("bossanova"), - Palette.as_hex("charcoal"), - Palette.as_hex("amethyst"), + COLOR_PALETTE["Accent"], + COLOR_PALETTE["Primary"], + COLOR_PALETTE["Secondary"], + COLOR_PALETTE["Pink"], + COLOR_PALETTE["Gold"], + COLOR_PALETTE["Green"], ], "size": FONT_SIZE - 5, "font": FONT, @@ -1216,7 +1226,7 @@ def _create_burst_figure(self, global_step: int) -> plt.Figure: ax_left.plot( x[start_idx:], slopes[start_idx:], - color=Palette.as_hex("numenta_blue"), + color=COLOR_PALETTE["Primary"], label="Max slope", ) @@ -1224,7 +1234,7 @@ def _create_burst_figure(self, global_step: int) -> plt.Figure: x, hyp_space_sizes, linestyle="--", - color=Palette.as_hex("numenta_blue"), + color=COLOR_PALETTE["Primary"], label="Hypothesis space size", ) @@ -1234,7 +1244,7 @@ def _create_burst_figure(self, global_step: int) -> plt.Figure: hyp_space_sizes, 0.0, alpha=0.15, - color=Palette.as_hex("numenta_blue"), + color=COLOR_PALETTE["Primary"], zorder=0, ) @@ -1250,7 +1260,7 @@ def _create_burst_figure(self, global_step: int) -> plt.Figure: add_idx, ymin, ymax, - colors=Palette.as_hex("vivid_violet"), + colors=COLOR_PALETTE["Secondary"], linestyles="--", alpha=1.0, linewidth=1.0, @@ -1272,7 +1282,7 @@ def _create_burst_figure(self, global_step: int) -> plt.Figure: label="Current step", ) - ax_left.set_ylabel("Max Recent Evidence Growth") + ax_left.set_ylabel("Max Recent Evidence Change") ax_right.set_ylabel("Hyp space size") # Collect legend entries @@ -1289,7 +1299,7 @@ def _create_burst_figure(self, global_step: int) -> plt.Figure: time_labels: list[str] = [] label_renames = { - "Max slope": "Max Recent Evidence Growth", + "Max slope": "Max Recent Evidence Change", "Hypothesis space size": "Hypothesis Space Size", "Burst": "Sampling Burst", "Current step": "Current Step",