From aa0544819a4c5955219c9da7992b2bdd059c940e Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sun, 29 Mar 2026 16:49:24 +0100 Subject: [PATCH 1/4] Memory2 record/replay functionality --- .gitignore | 3 + dimos/core/blueprints.py | 98 ++++ dimos/core/global_config.py | 1 + dimos/core/module_coordinator.py | 8 +- dimos/memory2/backend.py | 12 + dimos/memory2/codecs/base.py | 4 +- dimos/memory2/observationstore/base.py | 4 + dimos/memory2/observationstore/memory.py | 8 + dimos/memory2/observationstore/sqlite.py | 16 + dimos/memory2/stream.py | 6 + dimos/protocol/pubsub/impl/lcmpubsub.py | 5 +- dimos/protocol/pubsub/spec.py | 2 +- dimos/record/__init__.py | 17 + dimos/record/record_replay.py | 389 ++++++++++++++++ dimos/record/test_record_replay.py | 255 +++++++++++ dimos/robot/cli/dimos.py | 57 ++- dimos/types/__init__.py | 0 dimos/utils/cli/recorder/__init__.py | 0 dimos/utils/cli/recorder/run_recorder.py | 542 +++++++++++++++++++++++ flake.nix | 1 + pyproject.toml | 2 + uv.lock | 2 + 22 files changed, 1402 insertions(+), 30 deletions(-) create mode 100644 dimos/record/__init__.py create mode 100644 dimos/record/record_replay.py create mode 100644 dimos/record/test_record_replay.py create mode 100644 dimos/types/__init__.py create mode 100644 dimos/utils/cli/recorder/__init__.py create mode 100644 dimos/utils/cli/recorder/run_recorder.py diff --git a/.gitignore b/.gitignore index 4045db012e..4c34c551f4 100644 --- a/.gitignore +++ b/.gitignore @@ -77,3 +77,6 @@ CLAUDE.MD htmlcov/ .coverage .coverage.* + +# Created from simulation +MUJOCO_LOG.TXT diff --git a/dimos/core/blueprints.py b/dimos/core/blueprints.py index cac8507881..fa18cbe83e 100644 --- a/dimos/core/blueprints.py +++ b/dimos/core/blueprints.py @@ -30,6 +30,8 @@ from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In, Out from dimos.core.transport import LCMTransport, PubSubTransport, pLCMTransport +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.record.record_replay import RecordReplay from dimos.spec.utils import Spec, is_spec, spec_annotation_compliance, spec_structural_compliance from dimos.utils.generic import short_id from dimos.utils.logging_config import setup_logger @@ -471,6 +473,79 @@ def _connect_rpc_methods(self, module_coordinator: ModuleCoordinator) -> None: requested_method_name, rpc_methods_dot[requested_method_name] ) + def replay( + self, + recording: RecordReplay | str, + *, + speed: float = 1.0, + cli_config_overrides: Mapping[str, Any] | None = None, + ) -> ModuleCoordinator: + """Build the blueprint with a recording providing some module outputs. + + Modules whose OUT streams are fully covered by the recording are + disabled — their data comes from the recording instead. All other + modules run normally. + + Args: + recording: A :class:`RecordReplay` instance, or a str + to a ``.db`` recording file. + speed: Playback speed multiplier (1.0 = realtime). + cli_config_overrides: Extra global config overrides. + + Returns: + The running :class:`ModuleCoordinator`. + """ + if isinstance(recording, str): + recording = RecordReplay(recording) + + recorded_streams = set(recording.store.list_streams()) + if not recorded_streams: + raise ValueError("Recording is empty — no streams to replay") + + # Find modules whose OUTs overlap with the recording. + # If ANY OUTs are covered, disable the module — the recording + # replaces it. Uncovered OUTs (e.g. on SHM, or never published) + # are simply absent during replay; downstream modules that need + # them won't receive data, which is the expected degraded mode. + modules_to_disable: list[type[ModuleBase]] = [] + for bp in self.blueprints: + out_names = {conn.name for conn in bp.streams if conn.direction == "out"} + if not out_names: + continue + covered = out_names & recorded_streams + if covered: + modules_to_disable.append(bp.module) + uncovered = out_names - covered + if uncovered: + logger.warning( + "Replay: disabling %s (partial coverage: replaying %s, missing %s)", + bp.module.__name__, + covered, + uncovered, + ) + else: + logger.info( + "Replay: disabling %s (all OUTs covered)", + bp.module.__name__, + ) + + if not modules_to_disable: + logger.warning( + "Replay: no modules disabled — recording streams %s " + "don't match any module OUT names", + recorded_streams, + ) + + patched = self.disabled_modules(*modules_to_disable) + coordinator = patched.build(cli_config_overrides) + + # Start playback in background — publishes to LCM so other modules receive data + lcm = LCM() + lcm.start() + recording.play(pubsub=lcm, speed=speed) + + return coordinator + def build( self, cli_config_overrides: Mapping[str, Any] | None = None, @@ -480,6 +555,29 @@ def build( if cli_config_overrides: global_config.update(**dict(cli_config_overrides)) + # Auto-replay if --replay-file is set in global config + replay_file = global_config.replay_file + if replay_file: + logger.info("Auto-replay from %s", replay_file) + # Strip replay_file from all override sources so the nested + # build() inside replay() does not re-enter this branch. + global_config.replay_file = None + clean_cli = ( + {k: v for k, v in cli_config_overrides.items() if k != "replay_file"} + if cli_config_overrides + else None + ) + clean_bp = replace( + self, + global_config_overrides=MappingProxyType( + {k: v for k, v in self.global_config_overrides.items() if k != "replay_file"} + ), + ) + return clean_bp.replay( + replay_file, + cli_config_overrides=clean_cli, + ) + self._run_configurators() self._check_requirements() self._verify_no_name_conflicts() diff --git a/dimos/core/global_config.py b/dimos/core/global_config.py index 60072ae7fd..3e2c700d58 100644 --- a/dimos/core/global_config.py +++ b/dimos/core/global_config.py @@ -33,6 +33,7 @@ class GlobalConfig(BaseSettings): simulation: bool = False replay: bool = False replay_dir: str = "go2_sf_office" + replay_file: str | None = None new_memory: bool = False viewer: ViewerBackend = "rerun" n_workers: int = 2 diff --git a/dimos/core/module_coordinator.py b/dimos/core/module_coordinator.py index 10227eae93..e048c3cfd2 100644 --- a/dimos/core/module_coordinator.py +++ b/dimos/core/module_coordinator.py @@ -14,8 +14,8 @@ from __future__ import annotations +import asyncio from concurrent.futures import ThreadPoolExecutor -import threading from typing import TYPE_CHECKING, Any from dimos.core.global_config import GlobalConfig, global_config @@ -154,10 +154,10 @@ def start_all_modules(self) -> None: def get_instance(self, module: type[ModuleBase]) -> ModuleProxy: return self._deployed_modules.get(module) # type: ignore[return-value, no-any-return] - def loop(self) -> None: - stop = threading.Event() + async def loop(self) -> None: + stop = asyncio.Event() try: - stop.wait() + await stop.wait() except KeyboardInterrupt: return finally: diff --git a/dimos/memory2/backend.py b/dimos/memory2/backend.py index c861993de9..b990700d46 100644 --- a/dimos/memory2/backend.py +++ b/dimos/memory2/backend.py @@ -16,6 +16,7 @@ from __future__ import annotations +from contextlib import suppress from dataclasses import replace from typing import TYPE_CHECKING, Any, Generic, TypeVar @@ -220,6 +221,17 @@ def _iterate_live( finally: sub.dispose() + def delete_range(self, t1: float, t2: float) -> int: + """Delete observations in [t1, t2] from all stores. Returns count deleted.""" + ids = self.metadata_store.delete_range(t1, t2) + for obs_id in ids: + if self.blob_store is not None: + with suppress(KeyError): + self.blob_store.delete(self.name, obs_id) + if self.vector_store is not None: + self.vector_store.delete(self.name, obs_id) + return len(ids) + def count(self, query: StreamQuery) -> int: if query.search_vec: return sum(1 for _ in self.iterate(query)) diff --git a/dimos/memory2/codecs/base.py b/dimos/memory2/codecs/base.py index 821b36b60f..01447de852 100644 --- a/dimos/memory2/codecs/base.py +++ b/dimos/memory2/codecs/base.py @@ -17,6 +17,8 @@ import importlib from typing import Any, Protocol, TypeVar, runtime_checkable +from dimos.msgs.sensor_msgs.Image import Image + T = TypeVar("T") @@ -33,8 +35,6 @@ def codec_for(payload_type: type[Any] | None = None) -> Codec[Any]: from dimos.memory2.codecs.pickle import PickleCodec if payload_type is not None: - from dimos.msgs.sensor_msgs.Image import Image - if issubclass(payload_type, Image): from dimos.memory2.codecs.jpeg import JpegCodec diff --git a/dimos/memory2/observationstore/base.py b/dimos/memory2/observationstore/base.py index 4d94889fb0..abe2f50e8e 100644 --- a/dimos/memory2/observationstore/base.py +++ b/dimos/memory2/observationstore/base.py @@ -69,5 +69,9 @@ def fetch_by_ids(self, ids: list[int]) -> list[Observation[T]]: """Batch fetch by id (for vector search results).""" ... + def delete_range(self, t1: float, t2: float) -> list[int]: + """Delete observations with ts in [t1, t2]. Returns deleted IDs.""" + raise NotImplementedError + def serialize(self) -> dict[str, Any]: return {"class": qual(type(self)), "config": self.config.model_dump()} diff --git a/dimos/memory2/observationstore/memory.py b/dimos/memory2/observationstore/memory.py index 529cd06394..19c90eb109 100644 --- a/dimos/memory2/observationstore/memory.py +++ b/dimos/memory2/observationstore/memory.py @@ -78,3 +78,11 @@ def fetch_by_ids(self, ids: list[int]) -> list[Observation[T]]: id_set = set(ids) with self._lock: return [obs for obs in self._observations if obs.id in id_set] + + def delete_range(self, t1: float, t2: float) -> list[int]: + """Delete observations with ts in [t1, t2]. Returns deleted IDs.""" + with self._lock: + to_delete = [obs for obs in self._observations if t1 <= obs.ts <= t2] + ids = [obs.id for obs in to_delete] + self._observations = [obs for obs in self._observations if not (t1 <= obs.ts <= t2)] + return ids diff --git a/dimos/memory2/observationstore/sqlite.py b/dimos/memory2/observationstore/sqlite.py index 5d680c540a..0146117ea8 100644 --- a/dimos/memory2/observationstore/sqlite.py +++ b/dimos/memory2/observationstore/sqlite.py @@ -440,5 +440,21 @@ def fetch_by_ids(self, ids: list[int]) -> list[Observation[T]]: rows = self._conn.execute(sql, ids).fetchall() return [self._row_to_obs(r, has_blob=join) for r in rows] + def delete_range(self, t1: float, t2: float) -> list[int]: + """Delete observations with ts in [t1, t2]. Returns deleted IDs.""" + with self._lock: + rows = self._conn.execute( + f'SELECT id FROM "{self._name}" WHERE ts >= ? AND ts <= ?', (t1, t2) + ).fetchall() + ids = [r[0] for r in rows] + if ids: + placeholders = ",".join("?" * len(ids)) + self._conn.execute(f'DELETE FROM "{self._name}" WHERE id IN ({placeholders})', ids) + self._conn.execute( + f'DELETE FROM "{self._name}_rtree" WHERE id IN ({placeholders})', ids + ) + self._conn.commit() + return ids + def stop(self) -> None: super().stop() diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py index 545d387c32..b3b828d3d7 100644 --- a/dimos/memory2/stream.py +++ b/dimos/memory2/stream.py @@ -335,6 +335,12 @@ def subscribe( on_completed=on_completed, ) + def delete_range(self, t1: float, t2: float) -> int: + """Delete all observations with timestamps in [t1, t2]. Returns count deleted.""" + if isinstance(self._source, Stream): + raise TypeError("Cannot delete from a transform stream.") + return self._source.delete_range(t1, t2) + def append( self, payload: T, diff --git a/dimos/protocol/pubsub/impl/lcmpubsub.py b/dimos/protocol/pubsub/impl/lcmpubsub.py index 50c7c49f2f..7933c2503f 100644 --- a/dimos/protocol/pubsub/impl/lcmpubsub.py +++ b/dimos/protocol/pubsub/impl/lcmpubsub.py @@ -18,7 +18,6 @@ from dataclasses import dataclass import re import threading -from typing import Any from dimos.msgs.protocol import DimosMsg from dimos.protocol.pubsub.encoders import ( @@ -73,7 +72,7 @@ def from_channel_str(channel: str, default_lcm_type: type[DimosMsg] | None = Non return Topic(topic=topic_str, lcm_type=lcm_type or default_lcm_type) -class LCMPubSubBase(LCMService, AllPubSub[Topic, Any]): +class LCMPubSubBase(LCMService, AllPubSub[Topic, bytes]): """LCM-based PubSub with native regex subscription support. LCM natively supports regex patterns in subscribe(), so we implement @@ -92,7 +91,7 @@ def publish(self, topic: Topic | str, message: bytes) -> None: topic_str = str(topic) if isinstance(topic, Topic) else topic self.l.publish(topic_str, message) - def subscribe_all(self, callback: Callable[[bytes, Topic], Any]) -> Callable[[], None]: + def subscribe_all(self, callback: Callable[[bytes, Topic], None]) -> Callable[[], None]: return self.subscribe(Topic(re.compile(".*")), callback) # type: ignore[arg-type] def subscribe( diff --git a/dimos/protocol/pubsub/spec.py b/dimos/protocol/pubsub/spec.py index fe979fce82..0e292cfcb7 100644 --- a/dimos/protocol/pubsub/spec.py +++ b/dimos/protocol/pubsub/spec.py @@ -186,6 +186,6 @@ class SubscribeAllCapable(Protocol[MsgT_co, TopicT_co]): Both AllPubSub (native) and DiscoveryPubSub (synthesized) satisfy this. """ - def subscribe_all(self, callback: Callable[[Any, Any], Any]) -> Callable[[], None]: + def subscribe_all(self, callback: Callable[[MsgT_co, TopicT_co], None]) -> Callable[[], None]: """Subscribe to all messages on all topics.""" ... diff --git a/dimos/record/__init__.py b/dimos/record/__init__.py new file mode 100644 index 0000000000..1bb0bfbcc9 --- /dev/null +++ b/dimos/record/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.record.record_replay import RecordReplay + +__all__ = ("RecordReplay",) diff --git a/dimos/record/record_replay.py b/dimos/record/record_replay.py new file mode 100644 index 0000000000..4ef31e7511 --- /dev/null +++ b/dimos/record/record_replay.py @@ -0,0 +1,389 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RecordReplay — record and replay pub/sub topics using memory2 stores. + +Usage:: + + from dimos.record import RecordReplay + from dimos.protocol.pubsub.impl.lcmpubsub import LCM + + # Record from live LCM traffic + rec = RecordReplay(Path("my_recording.db")) + rec.start_recording([LCM()]) + # ... robot runs ... + rec.stop_recording() + + # Replay into LCM (viewable via rerun-bridge) + rec.play(pubsub=LCM(), speed=1.0) + + # Timeline editing + rec.trim(start=2.0, end=30.0) + rec.delete_range(start=10.0, end=12.0) + + # Query/filter via memory2 streams + rec.stream("lidar").time_range(0, 5).count() +""" + +import asyncio +from collections.abc import Callable, Collection, Container +from contextlib import suppress +import heapq +import logging +import math +import re +import sys +import time +from typing import Any, TypedDict + +from dimos.memory2.codecs.base import _resolve_payload_type +from dimos.memory2.store.sqlite import SqliteStore +from dimos.protocol.pubsub.impl.lcmpubsub import LCMPubSubBase, Topic +from dimos.protocol.pubsub.spec import PubSub + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing import Any as Self + +logger = logging.getLogger(__name__) + +_SANITIZE_RE = re.compile(r"[^A-Za-z0-9_]") + + +def topic_to_stream_name(channel: str) -> str: + """Convert a raw LCM channel/topic pattern to a safe stream name.""" + name = channel.split("#")[0].lstrip("/") + name = _SANITIZE_RE.sub("_", name) + if name and name[0].isdigit(): + name = f"_{name}" + return name or "_unknown" + + +class StreamInfo(TypedDict): + name: str + count: int + start: float + end: float + duration: float + type: str + + +class RecordReplay: + """Record and replay pub/sub topics using memory2 stores. + + A recording is a single SQLite file containing one stream per topic. + Supports recording from any ``SubscribeAllCapable`` pubsub, playback + with speed control, seeking, and timeline editing (trim/delete). + + The underlying :class:`SqliteStore` is fully accessible for advanced + queries via :attr:`store`, :attr:`streams`, and :meth:`stream`. + """ + + def __init__(self, path: str) -> None: + self._store = SqliteStore(path=path) + + self._recording = False + self._unsubscribes: list[Callable[[], None]] = [] + self._topic_filter: Container[str] | None = None + + self._resume = asyncio.Event() + self._play_task: asyncio.Task | None = None + self._play_speed = 1.0 + self._position = 0.0 + self._pubsub = None + + @property + def store(self) -> SqliteStore: + """The underlying store.""" + return self._store + + @property + def path(self) -> str: + """Path to the recording file.""" + return self._store.config.path + + @property + def is_recording(self) -> bool: + return self._recording + + @property + def is_playing(self) -> bool: + return self._play_task is not None and not self._play_task.done() + + @property + def is_paused(self) -> bool: + return self.is_playing and not self._resume.is_set() + + def start_recording( + self, + pubsubs: Collection[LCMPubSubBase], + topic_filter: Container[str] | None = None, + ) -> None: + """Start recording messages from the given pubsubs. + + Each pubsub is subscribed via ``subscribe_all()``. Messages are + stored in per-topic streams with automatic codec selection. + + Args: + pubsubs: List of pubsubs to subscribe to. + topic_filter: If provided, only record topics whose sanitized + stream name is in this set. If ``None``, record everything. + """ + if self._recording: + raise RuntimeError("Already recording") + self._recording = True + self._topic_filter = topic_filter + + for pubsub in pubsubs: + pubsub.start() + unsub = pubsub.subscribe_all(self._on_message) + self._unsubscribes.append(unsub) + + logger.info("Recording started on %d pubsub(s)", len(pubsubs)) + + def stop_recording(self) -> None: + """Stop recording.""" + if not self._recording: + return + self._recording = False + for unsub in self._unsubscribes: + unsub() + self._unsubscribes.clear() + logger.info("Recording stopped") + + def _on_message(self, msg: bytes, topic: Topic) -> None: + """Handle incoming message during recording.""" + stream_name = topic_to_stream_name(topic.pattern) + + if self._topic_filter is not None and stream_name not in self._topic_filter: + return + + msg_type = type(msg) + + s = self._store.stream(stream_name, msg_type) + s.append(msg, ts=time.time()) + + @property + def duration(self) -> float: + """Total duration of the recording in seconds.""" + t_min, t_max = self.time_range + return t_max - t_min + + @property + def time_range(self) -> tuple[float, float]: + """Absolute (min_ts, max_ts) across all streams.""" + streams = self._store.list_streams() + if not streams: + return (0.0, 0.0) + t_min = math.inf + t_max = -math.inf + for name in streams: + s = self._store.stream(name) + if s.exists(): + t0, t1 = s.get_time_range() + t_min = min(t_min, t0) + t_max = max(t_max, t1) + if t_min is math.inf: + return (0.0, 0.0) + return (t_min, t_max) + + @property + def position(self) -> float: + """Current playback position in seconds from recording start.""" + return self._position + + def stream_info(self) -> tuple[StreamInfo, ...]: + """Return per-stream metadata: name, count, time range, type.""" + result = [] + for name in self._store.list_streams(): + s = self._store.stream(name) + info: StreamInfo = {"name": name, "count": s.count()} + if info["count"] > 0: + t0, t1 = s.get_time_range() + info["start"] = t0 + info["end"] = t1 + info["duration"] = t1 - t0 + # Get payload type from registry + reg = self._store._registry.get(name) + if reg: + info["type"] = reg.get("payload_module", "unknown") + result.append(info) + return tuple(result) + + def play( + self, + pubsub: PubSub[Any, Any] | None = None, + speed: float = 1.0, + ) -> None: + """Start playback in a background thread. + + Args: + pubsub: If provided, messages are published to this pubsub. + Use LCM() to make them visible to rerun-bridge. + speed: Playback speed multiplier (1.0 = realtime, 2.0 = 2x, etc). + """ + if self.is_playing: + raise RuntimeError("Already playing") + + self._play_speed = speed + self._pubsub = pubsub + # Set resume so playback starts, this is cleared to pause playback. + self._resume.set() + self._play_task = asyncio.create_task(self._playback_loop(pubsub)) + + async def _playback_loop(self, pubsub: PubSub[Any, Any] | None) -> None: + """Main playback loop — merges all streams by timestamp.""" + t_min, t_max = self.time_range + if t_min >= t_max: + return + + # Build iterators for each stream, starting from seek position + start_ts = t_min + self._position + heap: list[tuple[float, int, str, Any]] = [] + counter = 0 # tiebreaker for heapq + + for name in self._store.list_streams(): + s = self._store.stream(name) + it = iter(s.after(start_ts - 0.001)) + try: + obs = next(it) + except StopIteration: + continue + heapq.heappush(heap, (obs.ts, counter, name, (obs, it))) + counter += 1 + + if not heap: + return + + topic_map = {} if pubsub is None else self._build_topic_map(pubsub) + wall_start = time.monotonic() + rec_start = heap[0][0] # earliest observation timestamp + + while heap: + if not self._resume.is_set(): + pause_start = time.monotonic() + await self._resume.wait() + wall_start += time.monotonic() - pause_start + + ts, _, stream_name, (obs, it) = heapq.heappop(heap) + + # Wait for correct playback time + elapsed_rec = ts - rec_start + target_wall = wall_start + (elapsed_rec / self._play_speed) + sleep_time = target_wall - time.monotonic() + await asyncio.sleep(sleep_time) + + self._position = ts - t_min + if pubsub is not None and stream_name in topic_map: + pubsub.publish(topic_map[stream_name], obs.data) + + try: + next_obs = next(it) + except StopIteration: + continue + heapq.heappush(heap, (next_obs.ts, counter, stream_name, (next_obs, it))) + counter += 1 + + def _build_topic_map(self, pubsub: PubSub[Any, Any]) -> dict[str, Topic]: + """Build stream_name -> Topic mapping for publishing.""" + topic_map: dict[str, Topic] = {} + + for name in self._store.list_streams(): + reg = self._store._registry.get(name) + if reg is None: + continue + + payload_module = reg.get("payload_module") + lcm_type = None + if payload_module: + lcm_type = _resolve_payload_type(payload_module) + + topic_map[name] = Topic(f"/{name}", lcm_type=lcm_type) + + return topic_map + + def pause(self) -> None: + """Pause playback. Resume with :meth:`resume`.""" + self._resume.clear() + + def resume(self) -> None: + """Resume paused playback.""" + self._resume.set() + + async def stop_playback(self) -> None: + """Stop playback.""" + self._resume.set() + if self._play_task is not None: + self._play_task.cancel() + with suppress(asyncio.CancelledError): + await self._play_task + self._play_task = None + + def seek(self, seconds: float) -> None: + """Set playback position in seconds from recording start. + + Takes effect immediately if playing (restarts playback loop). + If not playing, sets the position for the next :meth:`play`. + """ + self._position = max(0.0, min(seconds, self.duration)) + if self.is_playing: + self.stop_playback() + assert self._pubsub is not None + self.play(pubsub=self._pubsub, speed=self._play_speed) + + def delete_range(self, start: float, end: float) -> int: + """Delete observations in [start, end] seconds from recording start. + + Returns total count of deleted observations across all streams. + """ + t_min, _ = self.time_range + abs_start = t_min + start + abs_end = t_min + end + + total = 0 + for name in self._store.list_streams(): + s = self._store.stream(name) + total += s.delete_range(abs_start, abs_end) + return total + + def trim(self, start: float, end: float) -> int: + """Keep only [start, end] seconds, delete everything else. + + Returns total count of deleted observations. + """ + t_min, t_max = self.time_range + total = 0 + if start > 0: + total += self.delete_range(0, start - 0.0001) + if end < (t_max - t_min): + total += self.delete_range(end + 0.0001, t_max - t_min + 1) + return total + + async def close(self) -> None: + """Stop recording/playback and close the store.""" + self.stop_recording() + await self.stop_playback() + self._store.stop() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, *exc: Any) -> None: + await self.close() + + def __repr__(self) -> str: + streams = self._store.list_streams() + dur = self.duration + return f"RecordReplay({self._store.config.path!r}, streams={len(streams)}, duration={dur:.1f}s)" diff --git a/dimos/record/test_record_replay.py b/dimos/record/test_record_replay.py new file mode 100644 index 0000000000..a14d1ef3cd --- /dev/null +++ b/dimos/record/test_record_replay.py @@ -0,0 +1,255 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for RecordReplay.""" + +import asyncio +from collections.abc import Callable +from contextlib import suppress +from pathlib import Path +import threading +from typing import Any + +import pytest + +from dimos.record import RecordReplay + + +class FakeTopic: + """Minimal topic for testing.""" + + def __init__(self, name: str) -> None: + self.topic = name + self.lcm_type = None + + @property + def pattern(self) -> str: + return self.topic + + def __str__(self) -> str: + return self.topic + + +class FakePubSub: + """Minimal PubSub that supports subscribe_all for testing.""" + + def __init__(self) -> None: + self._subscribers: list[Callable[[Any, Any], Any]] = [] + self._lock = threading.Lock() + self._started = False + + def start(self) -> None: + self._started = True + + def stop(self) -> None: + self._started = False + + def publish(self, topic: Any, message: Any) -> None: + # Not needed for recording tests + pass + + def subscribe_all(self, callback: Callable[[Any, Any], Any]) -> Callable[[], None]: + with self._lock: + self._subscribers.append(callback) + + def unsub() -> None: + with self._lock: + with suppress(ValueError): + self._subscribers.remove(callback) + + return unsub + + def emit(self, topic_name: str, msg: Any) -> None: + """Test helper: simulate a message arriving.""" + topic = FakeTopic(topic_name) + with self._lock: + subs = list(self._subscribers) + for cb in subs: + cb(msg, topic) + + +class SimpleMsg: + """A simple test message (not LCM, uses pickle codec).""" + + def __init__(self, value: float) -> None: + self.value = value + + def __eq__(self, other: object) -> bool: + return isinstance(other, SimpleMsg) and self.value == other.value + + +@pytest.fixture +def tmp_db(tmp_path: Path) -> str: + return str(tmp_path / "test_recording.db") + + +class TestRecordReplay: + async def test_record_and_list_streams(self, tmp_db: str) -> None: + pubsub = FakePubSub() + async with RecordReplay(tmp_db) as rec: + rec.start_recording([pubsub]) + assert rec.is_recording + + pubsub.emit("/sensor/lidar", SimpleMsg(1.0)) + pubsub.emit("/sensor/odom", SimpleMsg(2.0)) + pubsub.emit("/sensor/lidar", SimpleMsg(3.0)) + await asyncio.sleep(0.05) # let timestamps diverge + + rec.stop_recording() + assert not rec.is_recording + + streams = rec.store.list_streams() + assert "sensor_lidar" in streams + assert "sensor_odom" in streams + + async def test_record_and_query(self, tmp_db: str) -> None: + pubsub = FakePubSub() + async with RecordReplay(tmp_db) as rec: + rec.start_recording([pubsub]) + + for i in range(10): + pubsub.emit("/data", SimpleMsg(float(i))) + await asyncio.sleep(0.01) + + rec.stop_recording() + + s = rec.store.stream("data") + assert s.count() == 10 + first = s.first() + assert isinstance(first.data, SimpleMsg) + assert first.data.value == 0.0 + + async def test_duration(self, tmp_db: str) -> None: + pubsub = FakePubSub() + async with RecordReplay(tmp_db) as rec: + rec.start_recording([pubsub]) + pubsub.emit("/a", SimpleMsg(0.0)) + await asyncio.sleep(0.1) + pubsub.emit("/a", SimpleMsg(1.0)) + rec.stop_recording() + + assert rec.duration >= 0.05 # at least some duration + + async def test_stream_info(self, tmp_db: str) -> None: + pubsub = FakePubSub() + async with RecordReplay(tmp_db) as rec: + rec.start_recording([pubsub]) + for i in range(5): + pubsub.emit("/sensor", SimpleMsg(float(i))) + await asyncio.sleep(0.01) + rec.stop_recording() + + infos = rec.stream_info() + assert len(infos) == 1 + assert infos[0]["name"] == "sensor" + assert infos[0]["count"] == 5 + + async def test_delete_range(self, tmp_db: str) -> None: + pubsub = FakePubSub() + async with RecordReplay(tmp_db) as rec: + rec.start_recording([pubsub]) + for i in range(20): + pubsub.emit("/data", SimpleMsg(float(i))) + await asyncio.sleep(0.01) + rec.stop_recording() + + before = rec.store.stream("data").count() + assert before == 20 + + dur = rec.duration + # Delete middle third + deleted = rec.delete_range(dur / 3, 2 * dur / 3) + assert deleted > 0 + + after = rec.store.stream("data").count() + assert after < before + + async def test_trim(self, tmp_db: str) -> None: + pubsub = FakePubSub() + async with RecordReplay(tmp_db) as rec: + rec.start_recording([pubsub]) + for i in range(30): + pubsub.emit("/data", SimpleMsg(float(i))) + await asyncio.sleep(0.01) + rec.stop_recording() + + before = rec.store.stream("data").count() + dur = rec.duration + # Trim to middle third + rec.trim(dur / 3, 2 * dur / 3) + + after = rec.store.stream("data").count() + assert after < before + + async def test_repr(self, tmp_db: str) -> None: + async with RecordReplay(tmp_db) as rec: + r = repr(rec) + assert "RecordReplay" in r + assert "streams=0" in r + + async def test_playback_runs(self, tmp_db: str) -> None: + """Test that playback task starts and finishes.""" + pubsub = FakePubSub() + async with RecordReplay(tmp_db) as rec: + rec.start_recording([pubsub]) + for i in range(5): + pubsub.emit("/data", SimpleMsg(float(i))) + await asyncio.sleep(0.01) + rec.stop_recording() + + rec.play(speed=100.0) # very fast + assert rec.is_playing + async with asyncio.timeout(0.1): + await rec._play_task + assert not rec.is_playing + + async def test_stop_playback(self, tmp_db: str) -> None: + pubsub = FakePubSub() + async with RecordReplay(tmp_db) as rec: + rec.start_recording([pubsub]) + for i in range(100): + pubsub.emit("/data", SimpleMsg(float(i))) + await asyncio.sleep(0.005) + rec.stop_recording() + + rec.play(speed=0.1) # slow + await asyncio.sleep(0.1) + assert rec.is_playing + await rec.stop_playback() + assert not rec.is_playing + + async def test_seek(self, tmp_db: str) -> None: + pubsub = FakePubSub() + async with RecordReplay(tmp_db) as rec: + rec.start_recording([pubsub]) + for i in range(10): + pubsub.emit("/data", SimpleMsg(float(i))) + await asyncio.sleep(0.01) + rec.stop_recording() + + rec.seek(0.05) + assert rec.position == pytest.approx(0.05, abs=0.01) + + async def test_multiple_pubsubs(self, tmp_db: str) -> None: + ps1 = FakePubSub() + ps2 = FakePubSub() + async with RecordReplay(tmp_db) as rec: + rec.start_recording([ps1, ps2]) + ps1.emit("/from_ps1", SimpleMsg(1.0)) + ps2.emit("/from_ps2", SimpleMsg(2.0)) + rec.stop_recording() + + streams = rec.store.list_streams() + assert "from_ps1" in streams + assert "from_ps2" in streams diff --git a/dimos/robot/cli/dimos.py b/dimos/robot/cli/dimos.py index 1137a612f3..17614034ac 100644 --- a/dimos/robot/cli/dimos.py +++ b/dimos/robot/cli/dimos.py @@ -14,6 +14,7 @@ from __future__ import annotations +import asyncio from datetime import datetime, timezone import inspect import json @@ -28,9 +29,21 @@ import typer from dimos.agents.mcp.mcp_adapter import McpAdapter, McpError +from dimos.core.blueprints import autoconnect from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.run_registry import get_most_recent, is_pid_alive, stop_entry -from dimos.utils.logging_config import setup_logger +from dimos.core.run_registry import ( + LOG_BASE_DIR, + RunEntry, + check_port_conflicts, + cleanup_stale, + generate_run_id, + get_most_recent, + is_pid_alive, + stop_entry, +) +from dimos.robot.get_all_blueprints import get_by_name, get_module_by_name +from dimos.utils.cli.recorder.run_recorder import main as recorder_main +from dimos.utils.logging_config import set_run_log_dir, setup_exception_handler, setup_logger logger = setup_logger() @@ -108,27 +121,12 @@ def callback(**kwargs) -> None: # type: ignore[no-untyped-def] main.callback()(create_dynamic_callback()) # type: ignore[no-untyped-call] -@main.command() -def run( +async def _run( ctx: typer.Context, robot_types: list[str] = typer.Argument(..., help="Blueprints or modules to run"), daemon: bool = typer.Option(False, "--daemon", "-d", help="Run in background"), disable: list[str] = typer.Option([], "--disable", help="Module names to disable"), ) -> None: - """Start a robot blueprint""" - logger.info("Starting DimOS") - - from dimos.core.blueprints import autoconnect - from dimos.core.run_registry import ( - LOG_BASE_DIR, - RunEntry, - check_port_conflicts, - cleanup_stale, - generate_run_id, - ) - from dimos.robot.get_all_blueprints import get_by_name, get_module_by_name - from dimos.utils.logging_config import set_run_log_dir, setup_exception_handler - setup_exception_handler() cli_config_overrides: dict[str, Any] = ctx.obj @@ -203,7 +201,7 @@ def run( ) entry.save() install_signal_handlers(entry, coordinator) - coordinator.loop() + await coordinator.loop() else: entry = RunEntry( run_id=run_id, @@ -217,11 +215,23 @@ def run( ) entry.save() try: - coordinator.loop() + await coordinator.loop() finally: entry.remove() +@main.command() +def run( + ctx: typer.Context, + robot_types: list[str] = typer.Argument(..., help="Blueprints or modules to run"), + daemon: bool = typer.Option(False, "--daemon", "-d", help="Run in background"), + disable: list[str] = typer.Option([], "--disable", help="Module names to disable"), +) -> None: + """Start a robot blueprint""" + logger.info("Starting DimOS") + asyncio.run(_run(ctx, robot_types, daemon, disable)) + + @main.command() def status() -> None: """Show the running DimOS instance.""" @@ -523,6 +533,13 @@ def top(ctx: typer.Context) -> None: dtop_main() +@main.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) +def recorder(ctx: typer.Context) -> None: + """Record and replay tool — terminal VLC for dimos recordings.""" + sys.argv = ["recorder", *ctx.args] + recorder_main() + + topic_app = typer.Typer(help="Topic commands for pub/sub") main.add_typer(topic_app, name="topic") diff --git a/dimos/types/__init__.py b/dimos/types/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dimos/utils/cli/recorder/__init__.py b/dimos/utils/cli/recorder/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dimos/utils/cli/recorder/run_recorder.py b/dimos/utils/cli/recorder/run_recorder.py new file mode 100644 index 0000000000..4827abf6f2 --- /dev/null +++ b/dimos/utils/cli/recorder/run_recorder.py @@ -0,0 +1,542 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""recorder — Terminal VLC for dimos recordings. + +Record from live LCM traffic, play back recordings, trim, seek. +Run ``rerun-bridge`` in another terminal to visualize playback. + +Usage:: + + recorder # interactive — record from LCM + recorder play my_recording.db # play an existing recording + recorder --help +""" + +from __future__ import annotations + +from collections import deque +import time +from typing import Any + +from rich.text import Text +from textual.app import App, ComposeResult +from textual.color import Color +from textual.containers import Horizontal +from textual.widgets import DataTable, Footer, Header, Static + +from dimos.utils.cli import theme + +# Braille sparkline constants (same as dtop) +_BRAILLE_BASE = 0x2800 +_LDOTS = (0x00, 0x40, 0x44, 0x46, 0x47) +_RDOTS = (0x00, 0x80, 0xA0, 0xB0, 0xB8) +_SPARK_WIDTH = 16 + +from dimos.record.record_replay import topic_to_stream_name + + +def _heat(ratio: float) -> str: + """Map 0..1 to cyan -> yellow -> red.""" + cyan = Color.parse(theme.CYAN) + yellow = Color.parse(theme.YELLOW) + red = Color.parse(theme.RED) + if ratio <= 0.5: + return cyan.blend(yellow, ratio * 2).hex + return yellow.blend(red, (ratio - 0.5) * 2).hex + + +def _spark(history: deque[float], max_val: float, width: int = _SPARK_WIDTH) -> Text: + """Braille sparkline from a deque of values.""" + n = width * 2 + vals = list(history) + if len(vals) < n: + vals = [0.0] * (n - len(vals)) + vals + else: + vals = vals[-n:] + result = Text() + if max_val <= 0: + max_val = 1.0 + for i in range(0, n, 2): + lv = min(vals[i] / max_val, 1.0) + rv = min(vals[i + 1] / max_val, 1.0) + li = min(int(lv * 4 + 0.5), 4) + ri = min(int(rv * 4 + 0.5), 4) + ch = chr(_BRAILLE_BASE | _LDOTS[li] | _RDOTS[ri]) + result.append(ch, style=_heat(max(lv, rv))) + return result + + +def _fmt_time(seconds: float) -> str: + """Format seconds as MM:SS.s""" + if seconds < 0: + seconds = 0 + m = int(seconds) // 60 + s = seconds - m * 60 + return f"{m:02d}:{s:05.2f}" + + +def _progress_bar(position: float, duration: float, width: int = 40) -> Text: + """Render a progress bar with position indicator.""" + if duration <= 0: + return Text("░" * width, style=theme.DIM) + ratio = min(position / duration, 1.0) + filled = int(ratio * width) + result = Text() + result.append("█" * filled, style=theme.CYAN) + if filled < width: + result.append("▓", style=theme.BRIGHT_CYAN) + result.append("░" * (width - filled - 1), style=theme.DIM) + return result + + +def _short_type(channel: str) -> str: + """Extract the short type name from a channel string.""" + if "#" not in channel: + return "" + return channel.rsplit("#", 1)[-1].rsplit(".", 1)[-1] + + +class RecorderApp(App[None]): + """Terminal VLC for dimos recordings. + + Shows all live LCM topics (like lcmspy). Select topics then press + ``r`` to record, ``space`` to play back, arrow keys to seek, etc. + """ + + CSS_PATH = "../dimos.tcss" + + CSS = f""" + Screen {{ + layout: vertical; + background: {theme.BACKGROUND}; + }} + #streams {{ + height: 1fr; + border: solid {theme.BORDER}; + background: {theme.BG}; + scrollbar-size: 0 0; + }} + #streams > .datatable--header {{ + color: {theme.ACCENT}; + background: transparent; + }} + #streams > .datatable--cursor {{ + background: {theme.BRIGHT_BLACK}; + }} + #timeline {{ + height: 5; + padding: 1 2; + background: {theme.BG}; + border-top: solid {theme.DIM}; + }} + #controls {{ + height: 3; + padding: 0 2; + background: {theme.BG}; + border-top: solid {theme.DIM}; + }} + #status-left {{ + width: 1fr; + }} + #status-right {{ + width: auto; + }} + """ + + BINDINGS = [ + ("q", "quit", "Quit"), + ("space", "toggle_select", "Toggle"), + ("a", "select_all", "All"), + ("n", "select_none", "None"), + ("r", "toggle_record", "Rec"), + ("p", "toggle_play", "Play"), + ("s", "stop_all", "Stop"), + ("left", "seek_back", "-5s"), + ("right", "seek_fwd", "+5s"), + ("shift+left", "seek_back_big", "-30s"), + ("shift+right", "seek_fwd_big", "+30s"), + ("[", "mark_trim_start", "In"), + ("]", "mark_trim_end", "Out"), + ("t", "do_trim", "Trim"), + ("d", "do_delete", "Del"), + ] + + def __init__( + self, + db_path: str | None = None, + autoplay: bool = False, + ) -> None: + super().__init__() + from dimos.protocol.service.lcmservice import autoconf + + autoconf(check_only=True) + + if db_path is None: + ts = time.strftime("%Y%m%d_%H%M%S") + self._db_path = f"recording_{ts}.db" + else: + self._db_path = db_path + self._autoplay = autoplay + self._recorder: RecordReplay | None = None + self._lcm: Any = None + self._spy: Any = None + + # Per-stream sparkline history keyed by stream_name + self._freq_history: dict[str, deque[float]] = {} + # Set of selected stream names (for recording) + self._selected: set[str] = set() + + self._trim_in: float | None = None + self._trim_out: float | None = None + + self._table: DataTable[Any] | None = None + + def compose(self) -> ComposeResult: + yield Header(show_clock=True) + self._table = DataTable(zebra_stripes=False, cursor_type="row") + self._table.id = "streams" + self._table.add_column("REC", key="sel", width=5) + self._table.add_column("Topic", key="topic") + self._table.add_column("Type", key="type") + self._table.add_column("Freq", key="freq") + self._table.add_column("Bandwidth", key="bw") + self._table.add_column("Recorded", key="rec") + self._table.add_column("Activity", key="activity") + yield self._table + yield Static(id="timeline") + with Horizontal(id="controls"): + yield Static(id="status-left") + yield Static(id="status-right") + yield Footer() + + def on_mount(self) -> None: + from dimos.protocol.pubsub.impl.lcmpubsub import LCM + from dimos.record import RecordReplay + from dimos.utils.cli.lcmspy.lcmspy import GraphLCMSpy + + self._lcm = LCM() + + # Live topic discovery via LCM spy (same as lcmspy tool) + self._spy = GraphLCMSpy(graph_log_window=0.5) + self._spy.start() + + if self._db_path: + self._recorder = RecordReplay(self._db_path) + else: + self._recorder = RecordReplay() + + self.title = f"recorder — {self._recorder.path}" + self.set_interval(0.5, self._refresh) + + if self._autoplay and self._db_path: + self._start_playback() + + async def on_unmount(self) -> None: + if self._recorder: + await self._recorder.close() + if self._spy: + self._spy.stop() + if self._lcm and hasattr(self._lcm, "stop"): + self._lcm.stop() + + # ------------------------------------------------------------------ + # Actions + # ------------------------------------------------------------------ + + def action_toggle_select(self) -> None: + """Toggle selection on the row under the cursor.""" + if self._table is None or self._table.row_count == 0: + return + row_key, _ = self._table.coordinate_to_cell_key(self._table.cursor_coordinate) + name = str(row_key.value) + if name in self._selected: + self._selected.discard(name) + else: + self._selected.add(name) + + def action_select_all(self) -> None: + """Select all visible topics.""" + if self._spy is None: + return + with self._spy._topic_lock: + channels = list(self._spy.topic.keys()) + for ch in channels: + self._selected.add(topic_to_stream_name(ch)) + # Also include any already-recorded streams + if self._recorder: + self._selected.update(self._recorder.store.list_streams()) + + def action_select_none(self) -> None: + self._selected.clear() + + def action_toggle_play(self) -> None: + if self._recorder is None: + return + if self._recorder.is_recording: + return + if self._recorder.is_playing: + if self._recorder.is_paused: + self._recorder.resume() + else: + self._recorder.pause() + else: + self._start_playback() + + def action_toggle_record(self) -> None: + if self._recorder is None: + return + if self._recorder.is_playing: + return + if self._recorder.is_recording: + self._recorder.stop_recording() + else: + filt = self._selected if self._selected else None + self._recorder.start_recording([self._lcm], topic_filter=filt) + + async def action_stop_all(self) -> None: + if self._recorder is None: + return + self._recorder.stop_recording() + await self._recorder.stop_playback() + + def action_seek_back(self) -> None: + self._seek_relative(-5.0) + + def action_seek_fwd(self) -> None: + self._seek_relative(5.0) + + def action_seek_back_big(self) -> None: + self._seek_relative(-30.0) + + def action_seek_fwd_big(self) -> None: + self._seek_relative(30.0) + + def action_mark_trim_start(self) -> None: + if self._recorder: + self._trim_in = self._recorder.position + + def action_mark_trim_end(self) -> None: + if self._recorder: + self._trim_out = self._recorder.position + + def action_do_trim(self) -> None: + if self._recorder and self._trim_in is not None and self._trim_out is not None: + self._recorder.stop_playback() + lo, hi = sorted((self._trim_in, self._trim_out)) + self._recorder.trim(lo, hi) + self._trim_in = self._trim_out = None + + def action_do_delete(self) -> None: + if self._recorder and self._trim_in is not None and self._trim_out is not None: + self._recorder.stop_playback() + lo, hi = sorted((self._trim_in, self._trim_out)) + self._recorder.delete_range(lo, hi) + self._trim_in = self._trim_out = None + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _start_playback(self) -> None: + if self._recorder is None or self._lcm is None: + return + if hasattr(self._lcm, "start"): + self._lcm.start() + self._recorder.play(pubsub=self._lcm, speed=1.0) + + def _seek_relative(self, delta: float) -> None: + if self._recorder: + self._recorder.seek(self._recorder.position + delta) + + # ------------------------------------------------------------------ + # Refresh + # ------------------------------------------------------------------ + + def _refresh(self) -> None: + if self._table is None: + return + rec = self._recorder + spy = self._spy + + # ---- Build unified row list: live topics + recorded-only streams ---- + # Each row: (stream_name, channel, spy_topic_or_None) + rows: dict[str, tuple[str, Any]] = {} # stream_name -> (channel, spy_topic) + + if spy is not None: + with spy._topic_lock: + live_topics: dict[str, Any] = dict(spy.topic) # channel -> GraphTopic + for channel, spy_topic in live_topics.items(): + sname = topic_to_stream_name(channel) + rows[sname] = (channel, spy_topic) + + # Add streams that exist in the recording but are not live + recorded_streams = set(rec.store.list_streams()) if rec else set() + for sname in recorded_streams: + if sname not in rows: + rows[sname] = (sname, None) + + # ---- Render table ---- + # Remember cursor position so we can restore it + cursor_row = self._table.cursor_coordinate.row if self._table.row_count > 0 else 0 + self._table.clear(columns=False) + + sorted_names = sorted(rows.keys()) + for sname in sorted_names: + channel, spy_topic = rows[sname] + + # Selection marker + is_sel = sname in self._selected + if is_sel: + sel = Text(" [●] ", style=f"bold {theme.BRIGHT_GREEN}") + else: + sel = Text(" [ ] ", style=theme.DIM) + + # Topic name — green when actively recording, bright when selected + if rec and rec.is_recording and sname in (rec.store.list_streams()): + topic_style = f"bold {theme.BRIGHT_GREEN}" + elif is_sel: + topic_style = theme.BRIGHT_WHITE + else: + topic_style = theme.FOREGROUND + + # Type + type_str = _short_type(channel) if "#" in channel else "" + + # Live freq / bandwidth from spy + if spy_topic is not None: + freq = spy_topic.freq(5.0) + freq_text = Text(f"{freq:.1f} Hz", style=_heat(min(freq / 30.0, 1.0))) + kbps = spy_topic.kbps(5.0) + bw_text = Text(spy_topic.kbps_hr(5.0), style=_heat(min(kbps / 3072, 1.0))) + else: + freq_text = Text("—", style=theme.DIM) + bw_text = Text("—", style=theme.DIM) + + # Recorded count + if sname in recorded_streams: + try: + count = rec.stream(sname).count() + rec_text = Text(str(count), style=theme.YELLOW) + except Exception: + rec_text = Text("—", style=theme.DIM) + else: + rec_text = Text("", style=theme.DIM) + + # Sparkline from spy freq history + if sname not in self._freq_history: + self._freq_history[sname] = deque(maxlen=_SPARK_WIDTH * 2) + if spy_topic is not None: + self._freq_history[sname].append(spy_topic.freq(0.5)) + else: + self._freq_history[sname].append(0.0) + max_f = max(self._freq_history[sname]) if self._freq_history[sname] else 1.0 + activity = _spark(self._freq_history[sname], max_f) + + self._table.add_row( + sel, + Text(sname, style=topic_style), + Text(type_str, style=theme.BLUE), + freq_text, + bw_text, + rec_text, + activity, + key=sname, + ) + + # Restore cursor + if self._table.row_count > 0: + row = min(cursor_row, self._table.row_count - 1) + self._table.move_cursor(row=row) + + # ---- Timeline ---- + duration = rec.duration if rec else 0.0 + position = rec.position if rec else 0.0 + + timeline = Text() + timeline.append(" ") + timeline.append(_fmt_time(position), style=theme.BRIGHT_WHITE) + timeline.append(" ", style=theme.DIM) + timeline.append_text(_progress_bar(position, duration, width=50)) + timeline.append(" ", style=theme.DIM) + timeline.append(_fmt_time(duration), style=theme.FOREGROUND) + + if self._trim_in is not None or self._trim_out is not None: + timeline.append("\n ") + timeline.append("[", style=theme.YELLOW) + timeline.append( + _fmt_time(self._trim_in) if self._trim_in is not None else "--:--", + style=theme.YELLOW if self._trim_in is not None else theme.DIM, + ) + timeline.append(" .. ", style=theme.DIM) + timeline.append( + _fmt_time(self._trim_out) if self._trim_out is not None else "--:--", + style=theme.YELLOW if self._trim_out is not None else theme.DIM, + ) + timeline.append("]", style=theme.YELLOW) + + self.query_one("#timeline", Static).update(timeline) + + # ---- Status bar ---- + status = Text() + if rec and rec.is_recording: + status.append(" ● REC ", style=f"bold on {theme.RED}") + elif rec and rec.is_paused: + status.append(" ❚❚ PAUSED ", style=theme.YELLOW) + elif rec and rec.is_playing: + status.append(" ▶ PLAYING ", style=theme.BRIGHT_GREEN) + else: + status.append(" ■ STOPPED ", style=theme.DIM) + + n_live = len([r for r in rows.values() if r[1] is not None]) + n_sel = len(self._selected) + status.append(f" {n_live} live", style=theme.FOREGROUND) + if n_sel: + status.append(f" {n_sel} selected", style=theme.BRIGHT_GREEN) + if recorded_streams: + status.append(f" {len(recorded_streams)} recorded", style=theme.YELLOW) + + # Contextual hint + if not (rec and (rec.is_recording or rec.is_playing)): + if n_live > 0 and n_sel == 0: + status.append(" SPACE select, A all, R rec", style=theme.DIM) + elif n_sel > 0: + status.append(" R to record selected", style=theme.DIM) + + self.query_one("#status-left", Static).update(status) + + rhs = Text() + if rec: + rhs.append(f"{rec.path} ", style=theme.DIM) + self.query_one("#status-right", Static).update(rhs) + + +def main() -> None: + import sys + + db_path: Path | None = None + autoplay = False + + args = sys.argv[1:] + if args and args[0] == "play" and len(args) > 1: + db_path = Path(args[1]) + autoplay = True + elif args and not args[0].startswith("-"): + db_path = Path(args[0]) + + RecorderApp(db_path=db_path, autoplay=autoplay).run() + + +if __name__ == "__main__": + main() diff --git a/flake.nix b/flake.nix index 68dbf0ee8c..c8c4039e05 100644 --- a/flake.nix +++ b/flake.nix @@ -42,6 +42,7 @@ { vals.pkg=pkgs.opensshWithKerberos;flags={}; } { vals.pkg=pkgs.unixtools.ifconfig; flags={}; } { vals.pkg=pkgs.unixtools.netstat; flags={}; } + { vals.pkg=pkgs.uv; flags={}; } # when pip packages call cc with -I/usr/include, that causes problems on some machines, this swaps that out for the nix cc headers # this is only necessary for pip packages from venv, pip packages from nixpkgs.python312Packages.* already have "-I/usr/include" patched with the nix equivalent diff --git a/pyproject.toml b/pyproject.toml index 1fbd29f86f..3f0833a588 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ dependencies = [ "psutil>=7.0.0", "sqlite-vec>=0.1.6", "lz4>=4.4.5", + "pytest-asyncio>=0.26.0", ] @@ -99,6 +100,7 @@ dimos = "dimos.robot.cli.dimos:main" rerun-bridge = "dimos.visualization.rerun.bridge:app" doclinks = "dimos.utils.docs.doclinks:main" dtop = "dimos.utils.cli.dtop:main" +recorder = "dimos.utils.cli.recorder.run_recorder:main" [project.urls] Homepage = "https://dimensionalos.com" diff --git a/uv.lock b/uv.lock index 0d6a3a88ab..aba544ae53 100644 --- a/uv.lock +++ b/uv.lock @@ -1700,6 +1700,7 @@ dependencies = [ { name = "psutil" }, { name = "pydantic" }, { name = "pydantic-settings" }, + { name = "pytest-asyncio" }, { name = "python-dotenv" }, { name = "pyturbojpeg" }, { name = "reactivex" }, @@ -2075,6 +2076,7 @@ requires-dist = [ { name = "pymavlink", marker = "extra == 'drone'" }, { name = "pyrealsense2", marker = "sys_platform != 'darwin' and extra == 'manipulation'" }, { name = "pytest", marker = "extra == 'dev'", specifier = "==8.3.5" }, + { name = "pytest-asyncio", specifier = ">=0.26.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = "==0.26.0" }, { name = "pytest-env", marker = "extra == 'dev'", specifier = "==1.1.5" }, { name = "pytest-mock", marker = "extra == 'dev'", specifier = "==3.15.0" }, From 8f763c1a82e6d5c06d6eac98c5f2fd75437c3db8 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Mon, 30 Mar 2026 19:47:06 +0100 Subject: [PATCH 2/4] Fix async --- dimos/record/record_replay.py | 4 ++-- dimos/record/test_record_replay.py | 2 +- dimos/utils/cli/recorder/run_recorder.py | 28 ++++++++++++------------ 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/dimos/record/record_replay.py b/dimos/record/record_replay.py index 4ef31e7511..888f28fd1a 100644 --- a/dimos/record/record_replay.py +++ b/dimos/record/record_replay.py @@ -331,7 +331,7 @@ async def stop_playback(self) -> None: await self._play_task self._play_task = None - def seek(self, seconds: float) -> None: + async def seek(self, seconds: float) -> None: """Set playback position in seconds from recording start. Takes effect immediately if playing (restarts playback loop). @@ -339,7 +339,7 @@ def seek(self, seconds: float) -> None: """ self._position = max(0.0, min(seconds, self.duration)) if self.is_playing: - self.stop_playback() + await self.stop_playback() assert self._pubsub is not None self.play(pubsub=self._pubsub, speed=self._play_speed) diff --git a/dimos/record/test_record_replay.py b/dimos/record/test_record_replay.py index a14d1ef3cd..f75ef121a2 100644 --- a/dimos/record/test_record_replay.py +++ b/dimos/record/test_record_replay.py @@ -238,7 +238,7 @@ async def test_seek(self, tmp_db: str) -> None: await asyncio.sleep(0.01) rec.stop_recording() - rec.seek(0.05) + await rec.seek(0.05) assert rec.position == pytest.approx(0.05, abs=0.01) async def test_multiple_pubsubs(self, tmp_db: str) -> None: diff --git a/dimos/utils/cli/recorder/run_recorder.py b/dimos/utils/cli/recorder/run_recorder.py index 4827abf6f2..4f47cf1ce3 100644 --- a/dimos/utils/cli/recorder/run_recorder.py +++ b/dimos/utils/cli/recorder/run_recorder.py @@ -311,17 +311,17 @@ async def action_stop_all(self) -> None: self._recorder.stop_recording() await self._recorder.stop_playback() - def action_seek_back(self) -> None: - self._seek_relative(-5.0) + async def action_seek_back(self) -> None: + await self._seek_relative(-5.0) - def action_seek_fwd(self) -> None: - self._seek_relative(5.0) + async def action_seek_fwd(self) -> None: + await self._seek_relative(5.0) - def action_seek_back_big(self) -> None: - self._seek_relative(-30.0) + async def action_seek_back_big(self) -> None: + await self._seek_relative(-30.0) - def action_seek_fwd_big(self) -> None: - self._seek_relative(30.0) + async def action_seek_fwd_big(self) -> None: + await self._seek_relative(30.0) def action_mark_trim_start(self) -> None: if self._recorder: @@ -331,16 +331,16 @@ def action_mark_trim_end(self) -> None: if self._recorder: self._trim_out = self._recorder.position - def action_do_trim(self) -> None: + async def action_do_trim(self) -> None: if self._recorder and self._trim_in is not None and self._trim_out is not None: - self._recorder.stop_playback() + await self._recorder.stop_playback() lo, hi = sorted((self._trim_in, self._trim_out)) self._recorder.trim(lo, hi) self._trim_in = self._trim_out = None - def action_do_delete(self) -> None: + async def action_do_delete(self) -> None: if self._recorder and self._trim_in is not None and self._trim_out is not None: - self._recorder.stop_playback() + await self._recorder.stop_playback() lo, hi = sorted((self._trim_in, self._trim_out)) self._recorder.delete_range(lo, hi) self._trim_in = self._trim_out = None @@ -356,9 +356,9 @@ def _start_playback(self) -> None: self._lcm.start() self._recorder.play(pubsub=self._lcm, speed=1.0) - def _seek_relative(self, delta: float) -> None: + async def _seek_relative(self, delta: float) -> None: if self._recorder: - self._recorder.seek(self._recorder.position + delta) + await self._recorder.seek(self._recorder.position + delta) # ------------------------------------------------------------------ # Refresh From 4e646d6d7537e96786ecd9ea3e6fa2c446875f2d Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Tue, 31 Mar 2026 19:52:01 +0100 Subject: [PATCH 3/4] Load in viewer as separate source --- dimos/record/record_replay.py | 90 +++++++++++++++-------------- dimos/visualization/rerun/bridge.py | 2 +- 2 files changed, 47 insertions(+), 45 deletions(-) diff --git a/dimos/record/record_replay.py b/dimos/record/record_replay.py index 888f28fd1a..d769f039c8 100644 --- a/dimos/record/record_replay.py +++ b/dimos/record/record_replay.py @@ -47,10 +47,11 @@ import time from typing import Any, TypedDict -from dimos.memory2.codecs.base import _resolve_payload_type +import rerun as rr + from dimos.memory2.store.sqlite import SqliteStore from dimos.protocol.pubsub.impl.lcmpubsub import LCMPubSubBase, Topic -from dimos.protocol.pubsub.spec import PubSub +from dimos.visualization.rerun.bridge import RerunConvertible, is_rerun_multi if sys.version_info >= (3, 11): from typing import Self @@ -164,17 +165,21 @@ def stop_recording(self) -> None: logger.info("Recording stopped") def _on_message(self, msg: bytes, topic: Topic) -> None: - """Handle incoming message during recording.""" stream_name = topic_to_stream_name(topic.pattern) if self._topic_filter is not None and stream_name not in self._topic_filter: return - msg_type = type(msg) - - s = self._store.stream(stream_name, msg_type) + s = self._store.stream(stream_name, type(msg)) s.append(msg, ts=time.time()) + # Persist the full channel string (with #type) in the registry + # so playback can reconstruct the lcm_type for decoding. + reg = self._store._registry.get(stream_name) + if reg and "channel" not in reg: + reg["channel"] = str(topic) + self._store._registry.put(stream_name, reg) + @property def duration(self) -> float: """Total duration of the recording in seconds.""" @@ -222,34 +227,40 @@ def stream_info(self) -> tuple[StreamInfo, ...]: result.append(info) return tuple(result) - def play( - self, - pubsub: PubSub[Any, Any] | None = None, - speed: float = 1.0, - ) -> None: - """Start playback in a background thread. + def play(self, speed: float = 1.0) -> None: + """Start playback as a separate Rerun recording. - Args: - pubsub: If provided, messages are published to this pubsub. - Use LCM() to make them visible to rerun-bridge. - speed: Playback speed multiplier (1.0 = realtime, 2.0 = 2x, etc). + Connects to the running Rerun viewer and logs decoded messages + under a recording called ``"playback"``, so it appears alongside + (but separate from) any live data. """ if self.is_playing: raise RuntimeError("Already playing") self._play_speed = speed - self._pubsub = pubsub # Set resume so playback starts, this is cleared to pause playback. self._resume.set() - self._play_task = asyncio.create_task(self._playback_loop(pubsub)) + self._play_task = asyncio.create_task(self._playback_loop()) - async def _playback_loop(self, pubsub: PubSub[Any, Any] | None) -> None: - """Main playback loop — merges all streams by timestamp.""" + async def _playback_loop(self) -> None: t_min, t_max = self.time_range if t_min >= t_max: return - # Build iterators for each stream, starting from seek position + # Separate Rerun recording so playback appears as its own source + rec = rr.RecordingStream("playback", make_default=False) + rec.connect_grpc() + + # Build topic map for decoding raw bytes -> DimosMsg + topic_map: dict[str, Topic] = {} + for name in self._store.list_streams(): + reg = self._store._registry.get(name) + if reg: + channel = reg.get("channel") + if channel: + topic_map[name] = Topic.from_channel_str(channel) + + # Merge-sort all streams by timestamp start_ts = t_min + self._position heap: list[tuple[float, int, str, Any]] = [] counter = 0 # tiebreaker for heapq @@ -267,7 +278,6 @@ async def _playback_loop(self, pubsub: PubSub[Any, Any] | None) -> None: if not heap: return - topic_map = {} if pubsub is None else self._build_topic_map(pubsub) wall_start = time.monotonic() rec_start = heap[0][0] # earliest observation timestamp @@ -282,12 +292,22 @@ async def _playback_loop(self, pubsub: PubSub[Any, Any] | None) -> None: # Wait for correct playback time elapsed_rec = ts - rec_start target_wall = wall_start + (elapsed_rec / self._play_speed) - sleep_time = target_wall - time.monotonic() - await asyncio.sleep(sleep_time) + await asyncio.sleep(target_wall - time.monotonic()) self._position = ts - t_min - if pubsub is not None and stream_name in topic_map: - pubsub.publish(topic_map[stream_name], obs.data) + + # Decode raw bytes -> DimosMsg -> Rerun archetype + topic = topic_map.get(stream_name) + if topic is not None and topic.lcm_type is not None: + msg = topic.lcm_type.lcm_decode(obs.data) + if isinstance(msg, RerunConvertible): + entity_path = f"world/{stream_name}" + rerun_data = msg.to_rerun() + if is_rerun_multi(rerun_data): + for path, archetype in rerun_data: + rec.log(path, archetype) + else: + rec.log(entity_path, rerun_data) try: next_obs = next(it) @@ -296,24 +316,6 @@ async def _playback_loop(self, pubsub: PubSub[Any, Any] | None) -> None: heapq.heappush(heap, (next_obs.ts, counter, stream_name, (next_obs, it))) counter += 1 - def _build_topic_map(self, pubsub: PubSub[Any, Any]) -> dict[str, Topic]: - """Build stream_name -> Topic mapping for publishing.""" - topic_map: dict[str, Topic] = {} - - for name in self._store.list_streams(): - reg = self._store._registry.get(name) - if reg is None: - continue - - payload_module = reg.get("payload_module") - lcm_type = None - if payload_module: - lcm_type = _resolve_payload_type(payload_module) - - topic_map[name] = Topic(f"/{name}", lcm_type=lcm_type) - - return topic_map - def pause(self) -> None: """Pause playback. Resume with :meth:`resume`.""" self._resume.clear() diff --git a/dimos/visualization/rerun/bridge.py b/dimos/visualization/rerun/bridge.py index 8b1cda443c..f793b9966c 100644 --- a/dimos/visualization/rerun/bridge.py +++ b/dimos/visualization/rerun/bridge.py @@ -307,7 +307,7 @@ def start(self) -> None: ) rr.spawn(connect=True, memory_limit=self.config.memory_limit) elif self.config.viewer_mode == "web": - server_uri = rr.serve_grpc() + server_uri = rr.serve_grpc(grpc_port=RERUN_GRPC_PORT) rr.serve_web_viewer(connect_to=server_uri, open_browser=False) elif self.config.viewer_mode == "connect": rr.connect_grpc(self.config.connect_url) From 6c7ff5dc48295aa93826bef0bbd77536a886b300 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Wed, 1 Apr 2026 00:30:38 +0100 Subject: [PATCH 4/4] Fix playback --- dimos/utils/cli/recorder/run_recorder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dimos/utils/cli/recorder/run_recorder.py b/dimos/utils/cli/recorder/run_recorder.py index 4f47cf1ce3..a1f84d4bbe 100644 --- a/dimos/utils/cli/recorder/run_recorder.py +++ b/dimos/utils/cli/recorder/run_recorder.py @@ -222,11 +222,11 @@ def compose(self) -> ComposeResult: yield Footer() def on_mount(self) -> None: - from dimos.protocol.pubsub.impl.lcmpubsub import LCM + from dimos.protocol.pubsub.impl.lcmpubsub import LCMPubSubBase from dimos.record import RecordReplay from dimos.utils.cli.lcmspy.lcmspy import GraphLCMSpy - self._lcm = LCM() + self._lcm = LCMPubSubBase() # Live topic discovery via LCM spy (same as lcmspy tool) self._spy = GraphLCMSpy(graph_log_window=0.5) @@ -354,7 +354,7 @@ def _start_playback(self) -> None: return if hasattr(self._lcm, "start"): self._lcm.start() - self._recorder.play(pubsub=self._lcm, speed=1.0) + self._recorder.play(speed=1.0) async def _seek_relative(self, delta: float) -> None: if self._recorder: