Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
"vedo",
"pypubsub",
"scipy>=1.16.2",
"torch-geometric>=2.7.0",
]
description = "A visualization tool for plotting tbp.monty visualizations."
dynamic = ["version"]
Expand Down
75 changes: 75 additions & 0 deletions src/tbp/interactive/colors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2025 Thousand Brains Project
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
from __future__ import annotations

from dataclasses import dataclass

Color = str | tuple[float, float, float] | tuple[float, float, float, float]


def hex_to_rgb(hex_str: str, alpha: float | None = None) -> Color:
hex_clean = hex_str.lstrip("#")
if len(hex_clean) != 6:
raise ValueError(f"Expected 6 hex digits, got: {hex_str!r}")

r = int(hex_clean[0:2], 16) / 255.0
g = int(hex_clean[2:4], 16) / 255.0
b = int(hex_clean[4:6], 16) / 255.0

if alpha is None:
return (r, g, b)
return (r, g, b, alpha)


@dataclass(frozen=True)
class Palette:
"""The TBP color palette.

If you request a color that doesn't exist, a KeyError is raised with a list of
available names.
"""

# Primary Colors
indigo: str = "#2f2b5c"
numenta_blue: str = "#00a0df"

# Secondary Colors
bossanova: str = "#5c315f"
vivid_violet: str = "#86308b"
blue_violet: str = "#655eb2"
amethyst: str = "#915acc"

# Accent Colors/Shades
rich_black: str = "#000000"
charcoal: str = "#3f3f3f"
link_water: str = "#dfe6f5"

# ---------- Internal helper ----------
@classmethod
def _validate(cls, name: str) -> str:
if not hasattr(cls, name):
available = [k for k in cls.__dict__.keys() if not k.startswith("_")]
msg = (
f"Color '{name}' is not defined in Palette.\n"
f"Available colors: {', '.join(available)}"
)
raise KeyError(msg)
return getattr(cls, name)

# ---------- Public API ----------
@classmethod
def as_hex(cls, name: str) -> Color:
"""Return the raw hex string for a color name."""
return cls._validate(name)

@classmethod
def as_rgb(cls, name: str, alpha: float | None = None) -> Color:
"""Return the color as an RGB(A) tuple in [0,1] range."""
hex_str = cls._validate(name)
return hex_to_rgb(hex_str, alpha=alpha)
83 changes: 82 additions & 1 deletion src/tbp/interactive/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@
from __future__ import annotations

import os
import pickle
import types
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
import torch
import trimesh
from vedo import Mesh
from vedo import Mesh, Points

from tbp.plot.plots.stats import deserialize_json_chunks

Expand Down Expand Up @@ -100,6 +103,84 @@ def create_mesh(self, obj_name: str) -> Mesh:
return obj


class _MontyShimUnpickler(pickle.Unpickler):
"""Unpickler that shims out ``tbp.monty`` classes with lightweight dummies.

Any class reference from a module path starting with ``tbp.monty`` is
redirected to a dynamically created dummy class. This allows deserialization
of checkpoints that reference ``tbp.monty`` types even when the actual
package is not installed, while still making object attributes
(e.g., ``.pos``, ``.x``) accessible.

Dummy classes are cached by ``(module, name)`` so that repeated lookups
return the same type.
"""

_cache = {} # cache dummy classes

def find_class(self, module, name):
if module.startswith("tbp.monty"):
key = (module, name)
cls = self._cache.get(key)
if cls is None:
cls = type(name, (), {})
self._cache[key] = cls
return cls
return super().find_class(module, name)


class PretrainedModelsLoader:
"""Load Monty pretrained Models as point cloud Vedo object."""

def __init__(self, data_path: str, lm_id: int = 0, input_channel: str = "patch"):
"""Initialize the loader.

Args:
data_path: Path to the `model.pt` file holding the pretrained models.
lm_id: Which learning module to use when extracting the pretrained graphs.
input_channel: Which channel to use for extracting the pretrained graphs.
"""
self.path = data_path
models = self._torch_load_with_optional_shim()["lm_dict"][lm_id]["graph_memory"]
self.graphs = {k: v[input_channel]._graph for k, v in models.items()}

def _torch_load_with_optional_shim(self, map_location="cpu"):
"""Load a torch checkpoint with optional fallback for tbp.monty shimming.

Try a standard torch.load first (weights_only=False because we need objects).
If tbp.monty isn't installed and the checkpoint references it, optionally
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason tbp.monty wouldn't be installed if using these plotting tools? Was this something you came across?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, tbp.plot uses a different uv environment that does not contain tbp.monty in it. One reason is the mismatch in python versions, tbp.plot uses python 3.13 (not 3.8). But also, ideally, we wouldn't want to depend on the tbp.monty codebase since everything we need is the log files and pretrained models.

In the future, I imagine it would be really cool to have visualizations that actually run a Monty episode step by step from Vedo. This would allow us to change parameters, swap/manipulate objects, move sensors, etc., between steps. At this point, we would have to install tbp.monty.

retry using a restricted Unpickler that only dummies tbp.monty.* symbols.

Args:
map_location: Device mapping passed to `torch.load` (default: "cpu").

Returns:
The deserialized checkpoint object.

Raises:
ModuleNotFoundError: If a missing module other than `tbp.monty` is required.
"""
try:
return torch.load(self.path, map_location=map_location, weights_only=False)
except ModuleNotFoundError as e:
# Only intercept the specific missing tbp.monty namespace
if "tbp.monty" not in str(e):
raise

shim_pickle_module = types.ModuleType("monty_shim_pickle")
shim_pickle_module.Unpickler = _MontyShimUnpickler

return torch.load(
self.path,
map_location=map_location,
weights_only=False,
pickle_module=shim_pickle_module,
)

def create_model(self, obj_name: str) -> Points:
return Points(self.graphs[obj_name].pos.numpy(), r=4, c="gray")


class DataParser:
"""Parser that navigates nested JSON-like data using a `DataLocator`.

Expand Down
20 changes: 18 additions & 2 deletions src/tbp/interactive/widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,32 @@
)


def extract_slider_state(widget: Slider2D) -> int:
def extract_button_state(widget: Button) -> str:
"""Read the Button state.

Args:
widget: The Vedo Button.

Returns:
The current button state as a string.
"""
return widget.status()


def extract_slider_state(widget: Slider2D, round_value: bool = True) -> int:
"""Read the slider state and round it to an integer value.

Args:
widget: The Vedo slider.
round_value: Whether to round the value to an integer or keep it as float.

Returns:
The current slider value rounded to the nearest integer.
"""
return round(widget.GetRepresentation().GetValue())
value = widget.GetRepresentation().GetValue()
if round_value:
value = round(value)
return value


def set_slider_state(widget: Slider2D, value: Any) -> None:
Expand Down
Loading