Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
ba6c91b
feat: interactive plot for pointcloud hypothesis space
ramyamounir Nov 28, 2025
9a91325
refactor!: alpha to mesh transparency
ramyamounir Dec 2, 2025
d5427f1
refactor!: rename sensor to agent
ramyamounir Dec 2, 2025
c85b16c
refactor!: rename buttons to "Name: [on|off]"
ramyamounir Dec 2, 2025
97faf91
chore: reposition some buttons
ramyamounir Dec 2, 2025
784d6fb
feat: add tbp color palette
ramyamounir Dec 2, 2025
b9aa913
refactor: change colors to use tbp palette
ramyamounir Dec 2, 2025
7a96a50
fix: minor typing fix
ramyamounir Dec 2, 2025
8495f5d
feat: increase figure dpi to 200
ramyamounir Dec 2, 2025
ec627f7
refactor: extract FONT constant
ramyamounir Dec 3, 2025
39ba30d
refactor: change scalarbar font and font sizes extraction
ramyamounir Dec 3, 2025
ee94128
feat: allow topics to expire from inbox
ramyamounir Dec 3, 2025
d960862
refactor: fix fonts and style
ramyamounir Dec 3, 2025
fd54907
fix: handle empty hyp spaces and warnings
ramyamounir Dec 3, 2025
b92de35
refactor!: plot removed hypotheses intuitively
ramyamounir Dec 3, 2025
e93bb61
feat: add mesh transparency slider
ramyamounir Dec 6, 2025
b5ad627
refactor: improve correlation plot resolution
ramyamounir Dec 6, 2025
68d1d5e
feat!: disable default keys and add quit listener
ramyamounir Dec 6, 2025
6dd2522
feat: add ScopesViewer to the interactive library
ramyamounir Dec 6, 2025
f7183e8
refactor: use ScopeViewer in the correlation plot
ramyamounir Dec 6, 2025
80a24ae
refactor: update sensor name to agent
ramyamounir Dec 7, 2025
8b1ffe2
feat: add support for events in Widgets
ramyamounir Dec 7, 2025
bd60d10
feat: add agent and patch paths on keypress events
ramyamounir Dec 7, 2025
f823fb3
feat: repurpose transparency and add top_k slider
ramyamounir Dec 7, 2025
cbc1b80
feat: add options for hyp path
ramyamounir Dec 7, 2025
f71204a
feat: add rectangles and plotting mods
ramyamounir Dec 8, 2025
6c4658b
Merge branch 'hyp_space_pointcloud' into correlation_mods
ramyamounir Dec 8, 2025
45e9214
feat: color future paths differently
ramyamounir Dec 9, 2025
2daeb65
refactor!: label renames for lineplot
ramyamounir Dec 12, 2025
4b40f10
Merge branch 'hyp_space_pointcloud' into correlation_mods
ramyamounir Dec 12, 2025
b123d6d
refactor: rename evidence slope to recent evidence growth
ramyamounir Dec 12, 2025
99afc18
fix: add empty info text when hypothesis space is empty
ramyamounir Dec 12, 2025
e026465
fix: some minor misc fixes and comments modifications
ramyamounir Dec 12, 2025
5451703
docs: change docstrings and fix typos
ramyamounir Dec 13, 2025
de62bdd
refactor: change current_episode default from -1 to None
ramyamounir Dec 13, 2025
b7f2e32
Merge branch 'hyp_space_pointcloud' into correlation_mods
ramyamounir Dec 14, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
"vedo",
"pypubsub",
"scipy>=1.16.2",
"torch-geometric>=2.7.0",
]
description = "A visualization tool for plotting tbp.monty visualizations."
dynamic = ["version"]
Expand Down
73 changes: 73 additions & 0 deletions src/tbp/interactive/colors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright 2025 Thousand Brains Project
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
from __future__ import annotations

from dataclasses import dataclass

Color = str | tuple[float, float, float] | tuple[float, float, float, float]


def hex_to_rgb(hex_str: str, alpha: float | None = None) -> Color:
hex_clean = hex_str.lstrip("#")
if len(hex_clean) != 6:
raise ValueError(f"Expected 6 hex digits, got: {hex_str!r}")

r = int(hex_clean[0:2], 16) / 255.0
g = int(hex_clean[2:4], 16) / 255.0
b = int(hex_clean[4:6], 16) / 255.0

if alpha is None:
return (r, g, b)
return (r, g, b, alpha)


@dataclass(frozen=True)
class Palette:
"""The TBP color palette.

If you request a color that doesn't exist, a KeyError is raised with a list of
available names.
"""

# Primary Colors
indigo: str = "#2f2b5c"
numenta_blue: str = "#00a0df"

# Secondary Colors
bossanova: str = "#5c315f"
vivid_violet: str = "#86308b"
blue_violet: str = "#655eb2"
amethyst: str = "#915acc"

# Accent Colors/Shades
rich_black: str = "#000000"
charcoal: str = "#3f3f3f"
link_water: str = "#dfe6f5"

@classmethod
def _validate(cls, name: str) -> str:
if not hasattr(cls, name):
available = [k for k in cls.__dict__.keys() if not k.startswith("_")]
msg = (
f"Color '{name}' is not defined in Palette.\n"
f"Available colors: {', '.join(available)}"
)
raise KeyError(msg)
return getattr(cls, name)

@classmethod
def as_hex(cls, name: str) -> Color:
"""Return the raw hex string for a color name."""
return cls._validate(name)

@classmethod
def as_rgb(cls, name: str, alpha: float | None = None) -> Color:
"""Return the color as an RGB(A) tuple in [0,1] range."""
hex_str = cls._validate(name)
return hex_to_rgb(hex_str, alpha=alpha)
83 changes: 82 additions & 1 deletion src/tbp/interactive/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@
from __future__ import annotations

import os
import pickle
import types
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
import torch
import trimesh
from vedo import Mesh
from vedo import Mesh, Points

from tbp.plot.plots.stats import deserialize_json_chunks

Expand Down Expand Up @@ -100,6 +103,84 @@ def create_mesh(self, obj_name: str) -> Mesh:
return obj


class _MontyShimUnpickler(pickle.Unpickler):
"""Unpickler that shims out ``tbp.monty`` classes with lightweight dummies.

Any class reference from a module path starting with ``tbp.monty`` is
redirected to a dynamically created dummy class. This allows deserialization
of checkpoints that reference ``tbp.monty`` types even when the actual
package is not installed, while still making object attributes
(e.g., ``.pos``, ``.x``) accessible.

Dummy classes are cached by ``(module, name)`` so that repeated lookups
return the same type.
"""

_cache = {} # cache dummy classes

def find_class(self, module, name):
if module.startswith("tbp.monty"):
key = (module, name)
cls = self._cache.get(key)
if cls is None:
cls = type(name, (), {})
self._cache[key] = cls
return cls
return super().find_class(module, name)


class PretrainedModelsLoader:
"""Load Monty pretrained Models as point cloud Vedo object."""

def __init__(self, data_path: str, lm_id: int = 0, input_channel: str = "patch"):
"""Initialize the loader.

Args:
data_path: Path to the `model.pt` file holding the pretrained models.
lm_id: Which learning module to use when extracting the pretrained graphs.
input_channel: Which channel to use for extracting the pretrained graphs.
"""
self.path = data_path
models = self._torch_load_with_optional_shim()["lm_dict"][lm_id]["graph_memory"]
self.graphs = {k: v[input_channel]._graph for k, v in models.items()}

def _torch_load_with_optional_shim(self, map_location="cpu"):
"""Load a torch checkpoint with optional fallback for tbp.monty shimming.

Try a standard torch.load first (weights_only=False because we need objects).
If tbp.monty isn't installed and the checkpoint references it, optionally
retry using a restricted Unpickler that only dummies tbp.monty.* symbols.

Args:
map_location: Device mapping passed to `torch.load` (default: "cpu").

Returns:
The deserialized checkpoint object.

Raises:
ModuleNotFoundError: If a missing module other than `tbp.monty` is required.
"""
try:
return torch.load(self.path, map_location=map_location, weights_only=False)
except ModuleNotFoundError as e:
# Only intercept the specific missing tbp.monty namespace
if "tbp.monty" not in str(e):
raise

shim_pickle_module = types.ModuleType("monty_shim_pickle")
shim_pickle_module.Unpickler = _MontyShimUnpickler

return torch.load(
self.path,
map_location=map_location,
weights_only=False,
pickle_module=shim_pickle_module,
)

def create_model(self, obj_name: str) -> Points:
return Points(self.graphs[obj_name].pos.numpy(), r=4, c="gray")


class DataParser:
"""Parser that navigates nested JSON-like data using a `DataLocator`.

Expand Down
27 changes: 27 additions & 0 deletions src/tbp/interactive/events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2025 Thousand Brains Project
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

from dataclasses import dataclass


@dataclass
class EventSpec:
"""Specification for an Event to be defined as a WidgetUpdater callback trigger.

Attributes:
trigger: Event trigger name (e.g., KeyPressed)
name: Event name field in Vedo `event.name` (e.g., keypress).
required: Whether this event is required for the callback trigger. If
True, the updater will not call the callback until a message for this
topic arrives.
"""

trigger: str
name: str
required: bool = True
102 changes: 102 additions & 0 deletions src/tbp/interactive/scopes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2025 Thousand Brains Project
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.


from __future__ import annotations

from typing import Any

from vedo import Plotter

from tbp.interactive.widgets import Widget


class ScopeViewer:
"""Controls widget visibility using numeric keypress scopes.

Behavior summary:
- Scope 0:
* If at least one widget is hidden -> show all widgets.
* Else (all visible) -> hide ALL widgets.
- Scope k (1..9):
* Toggle that scope on/off.
* During a toggle off, a widget may remain visible if any other active scope
includes it.

The widgets themselves decide how to hide/show internally using
their .on() / .off() visibility handlers.
"""

def __init__(self, plotter: Plotter, widgets: dict[str, Widget]):
self.plotter = plotter
self.widgets = widgets

self.scope_to_widgets: dict[int, set[str]] = {}

# Build scope map from each widget's `scopes` list.
for name, widget in widgets.items():
for s in widget.scopes:
if s not in self.scope_to_widgets:
self.scope_to_widgets[s] = set()
self.scope_to_widgets[s].add(name)

self.active_scopes: set[int] = set(self.scope_to_widgets.keys())
self.plotter.add_callback("KeyPress", self._on_keypress)

def _on_keypress(self, event: Any) -> None:
key = getattr(event, "keypress", None)
if not key or not key.isdigit():
return

self.toggle_scope(int(key))
self.plotter.render()

def toggle_scope(self, scope_id: int) -> None:
"""Toggles a specific scope by its id."""
if scope_id == 0:
return self._toggle_all()

if scope_id not in self.scope_to_widgets:
return None

if scope_id in self.active_scopes:
self.active_scopes.remove(scope_id)
else:
self.active_scopes.add(scope_id)

self._apply_scope_visibility()

def _toggle_all(self) -> None:
"""Toggles all widgets on/off."""
any_hidden = any(not w.is_visible for w in self.widgets.values())

if any_hidden:
for w in self.widgets.values():
w.on()
self.active_scopes = set(self.scope_to_widgets.keys())
else:
for w in self.widgets.values():
w.off()
self.active_scopes.clear()

def _apply_scope_visibility(self) -> None:
# If nothing is active, hide everything
if not self.active_scopes:
for w in self.widgets.values():
w.off()
return

for widget in self.widgets.values():
# Does this widget belong to ANY active scope?
belongs = any(s in self.active_scopes for s in widget.scopes)

if belongs and not widget.is_visible:
widget.on()
elif not belongs and widget.is_visible:
widget.off()
14 changes: 13 additions & 1 deletion src/tbp/interactive/widget_updaters.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Protocol, runtime_checkable

from tbp.interactive.events import EventSpec
from tbp.interactive.topics import TopicMessage, TopicSpec

if TYPE_CHECKING:
Expand Down Expand Up @@ -56,7 +57,7 @@ class WidgetUpdater[WidgetT]:
`topics`.
"""

topics: Iterable[TopicSpec]
topics: Iterable[TopicSpec | EventSpec]
callback: Callable[
[WidgetT | None, list[TopicMessage]], tuple[WidgetT | None, bool]
]
Expand Down Expand Up @@ -86,6 +87,17 @@ def accepts(self, msg: TopicMessage) -> bool:
"""
return any(spec.name == msg.name for spec in self.topics)

def expire_topic(self, topic_name: str) -> None:
"""Expire (remove) the stored message for a given topic.

After expiration, the updater may require a new message for that topic
before becoming ready again.

Args:
topic_name: The topic whose message should be invalidated.
"""
self._inbox.pop(topic_name, None)

def __call__(
self, widget: WidgetT | None, msg: TopicMessage
) -> tuple[WidgetT | None, bool]:
Expand Down
Loading