diff --git a/data/.lfs/smartnav_paths.tar.gz b/data/.lfs/smartnav_paths.tar.gz new file mode 100644 index 0000000000..150d528185 --- /dev/null +++ b/data/.lfs/smartnav_paths.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0ab6939a8bcd589a4eba4355bc82f662654eadc758512c74a535adace40425e +size 1291310 diff --git a/dimos/agents_deprecated/agent.py b/dimos/agents_deprecated/agent.py index 1d48ce2fa4..4515cd5bfb 100644 --- a/dimos/agents_deprecated/agent.py +++ b/dimos/agents_deprecated/agent.py @@ -897,5 +897,3 @@ def stream_query(self, query_text: str) -> Observable: # type: ignore[type-arg] return create( lambda observer, _: self._observable_query(observer, incoming_query=query_text) # type: ignore[arg-type] ) - - diff --git a/dimos/agents_deprecated/memory/spatial_vector_db.py b/dimos/agents_deprecated/memory/spatial_vector_db.py index 7d0c8eb2f7..e93b3fab8d 100644 --- a/dimos/agents_deprecated/memory/spatial_vector_db.py +++ b/dimos/agents_deprecated/memory/spatial_vector_db.py @@ -227,8 +227,8 @@ def _process_query_results(self, results) -> list[dict]: # type: ignore[no-unty ) # Get the image from visual memory - #image = self.visual_memory.get(lookup_id) - #result["image"] = image + # image = self.visual_memory.get(lookup_id) + # result["image"] = image processed_results.append(result) diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index 74471f34d5..2925548a33 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -56,6 +56,7 @@ class MyCppModule(NativeModule): from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig +from dimos.utils.change_detect import PathEntry, did_change from dimos.utils.logging_config import setup_logger if sys.version_info < (3, 13): @@ -81,9 +82,10 @@ class NativeModuleConfig(ModuleConfig): extra_env: dict[str, str] = Field(default_factory=dict) shutdown_timeout: float = 10.0 log_format: LogFormat = LogFormat.TEXT + rebuild_on_change: 
list[PathEntry] | None = None # Override in subclasses to exclude fields from CLI arg generation - cli_exclude: frozenset[str] = frozenset() + cli_exclude: frozenset[str] = frozenset({"rebuild_on_change"}) def to_cli_args(self) -> list[str]: """Auto-convert subclass config fields to CLI args. @@ -132,17 +134,31 @@ class NativeModule(Module[_NativeConfig]): _process: subprocess.Popen[bytes] | None = None _watchdog: threading.Thread | None = None _stopping: bool = False - _last_stderr_lines: collections.deque[str] + _stderr_tail: list[str] + _stdout_tail: list[str] + _tail_lock: threading.Lock + + @property + def _mod_label(self) -> str: + """Short human-readable label: ClassName(executable_basename).""" + exe = Path(self.config.executable).name if self.config.executable else "?" + return f"{type(self).__name__}({exe})" def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - self._last_stderr_lines = collections.deque(maxlen=50) + self._stderr_tail: collections.deque[str] = collections.deque(maxlen=50) + self._stdout_tail: collections.deque[str] = collections.deque(maxlen=50) + self._tail_lock = threading.Lock() self._resolve_paths() @rpc def start(self) -> None: if self._process is not None and self._process.poll() is None: - logger.warning("Native process already running", pid=self._process.pid) + logger.warning( + "Native process already running", + module=self._mod_label, + pid=self._process.pid, + ) return self._maybe_build() @@ -158,10 +174,14 @@ def start(self) -> None: env = {**os.environ, **self.config.extra_env} cwd = self.config.cwd or str(Path(self.config.executable).resolve().parent) - module_name = type(self).__name__ + # Reset tail buffers for this run. 
+ with self._tail_lock: + self._stderr_tail.clear() + self._stdout_tail.clear() + logger.info( - f"Starting native process: {module_name}", - module=module_name, + "Starting native process", + module=self._mod_label, cmd=" ".join(cmd), cwd=cwd, ) @@ -173,26 +193,36 @@ def start(self) -> None: stderr=subprocess.PIPE, ) logger.info( - f"Native process started: {module_name}", - module=module_name, + "Native process started", + module=self._mod_label, pid=self._process.pid, ) self._stopping = False - self._watchdog = threading.Thread(target=self._watch_process, daemon=True) + self._watchdog = threading.Thread( + target=self._watch_process, + daemon=True, + name=f"native-watchdog-{self._mod_label}", + ) self._watchdog.start() @rpc def stop(self) -> None: self._stopping = True if self._process is not None and self._process.poll() is None: - logger.info("Stopping native process", pid=self._process.pid) + logger.info( + "Stopping native process", + module=self._mod_label, + pid=self._process.pid, + ) self._process.send_signal(signal.SIGTERM) try: self._process.wait(timeout=self.config.shutdown_timeout) except subprocess.TimeoutExpired: logger.warning( - "Native process did not exit, sending SIGKILL", pid=self._process.pid + "Native process did not exit, sending SIGKILL", + module=self._mod_label, + pid=self._process.pid, ) self._process.kill() self._process.wait(timeout=5) @@ -207,57 +237,110 @@ def _watch_process(self) -> None: if self._process is None: return - stdout_t = self._start_reader(self._process.stdout, "info") - stderr_t = self._start_reader(self._process.stderr, "warning") + stdout_t = self._start_reader(self._process.stdout, "info", self._stdout_tail) + stderr_t = self._start_reader(self._process.stderr, "warning", self._stderr_tail) rc = self._process.wait() stdout_t.join(timeout=2) stderr_t.join(timeout=2) if self._stopping: + logger.info( + "Native process exited (expected)", + module=self._mod_label, + pid=self._process.pid, + returncode=rc, + ) return - 
module_name = type(self).__name__ - exe_name = Path(self.config.executable).name if self.config.executable else "unknown" - - # Use buffered stderr lines from the reader thread for the crash report. - last_stderr = "\n".join(self._last_stderr_lines) + # Grab the tail for diagnostics. + with self._tail_lock: + stderr_snapshot = list(self._stderr_tail) + stdout_snapshot = list(self._stdout_tail) logger.error( - f"Native process crashed: {module_name} ({exe_name})", - module=module_name, - executable=exe_name, + "Native process died unexpectedly", + module=self._mod_label, pid=self._process.pid, returncode=rc, - last_stderr=last_stderr[:500] if last_stderr else None, ) + + # Log the last stderr/stdout lines so the cause is visible. + if stderr_snapshot: + logger.error( + f"Last {len(stderr_snapshot)} stderr lines from {self._mod_label}:", + module=self._mod_label, + pid=self._process.pid, + ) + for line in stderr_snapshot: + logger.error(f" stderr| {line}", module=self._mod_label) + + if stdout_snapshot and not stderr_snapshot: + # Only dump stdout if stderr was empty (avoid double-noise). 
+ logger.error( + f"Last {len(stdout_snapshot)} stdout lines from {self._mod_label}:", + module=self._mod_label, + pid=self._process.pid, + ) + for line in stdout_snapshot: + logger.error(f" stdout| {line}", module=self._mod_label) + + if not stderr_snapshot and not stdout_snapshot: + logger.error( + "No output captured from native process — " + "binary may have crashed before producing any output", + module=self._mod_label, + pid=self._process.pid, + ) + self.stop() - def _start_reader(self, stream: IO[bytes] | None, level: str) -> threading.Thread: + def _start_reader( + self, + stream: IO[bytes] | None, + level: str, + tail_buf: list[str], + ) -> threading.Thread: """Spawn a daemon thread that pipes a subprocess stream through the logger.""" - t = threading.Thread(target=self._read_log_stream, args=(stream, level), daemon=True) + t = threading.Thread( + target=self._read_log_stream, + args=(stream, level, tail_buf), + daemon=True, + name=f"native-reader-{level}-{self._mod_label}", + ) t.start() return t - def _read_log_stream(self, stream: IO[bytes] | None, level: str) -> None: + def _read_log_stream( + self, + stream: IO[bytes] | None, + level: str, + tail_buf: list[str], + ) -> None: if stream is None: return log_fn = getattr(logger, level) - is_stderr = level == "warning" for raw in stream: line = raw.decode("utf-8", errors="replace").rstrip() if not line: continue - if is_stderr: - self._last_stderr_lines.append(line) + + # Keep a rolling tail buffer for crash diagnostics. 
+ with self._tail_lock: + tail_buf.append(line) + if self.config.log_format == LogFormat.JSON: try: data = json.loads(line) event = data.pop("event", line) - log_fn(event, **data) + log_fn(event, module=self._mod_label, **data) continue except (json.JSONDecodeError, TypeError): - logger.warning("malformed JSON from native module", raw=line) - log_fn(line, pid=self._process.pid if self._process else None) + logger.warning( + "malformed JSON from native module", + module=self._mod_label, + raw=line, + ) + log_fn(line, module=self._mod_label, pid=self._process.pid if self._process else None) stream.close() def _resolve_paths(self) -> None: @@ -269,18 +352,39 @@ def _resolve_paths(self) -> None: if not Path(self.config.executable).is_absolute() and self.config.cwd is not None: self.config.executable = str(Path(self.config.cwd) / self.config.executable) + def _build_cache_name(self) -> str: + """Return a stable, unique cache name for this module's build state.""" + source_file = Path(inspect.getfile(type(self))).resolve() + return f"native_{source_file}" + def _maybe_build(self) -> None: - """Run ``build_command`` if the executable does not exist.""" + """Run ``build_command`` if the executable does not exist or sources changed.""" exe = Path(self.config.executable) - if exe.exists(): + + # Check if rebuild needed due to source changes + needs_rebuild = False + if self.config.rebuild_on_change and exe.exists(): + if did_change( + self._build_cache_name(), self.config.rebuild_on_change, cwd=self.config.cwd + ): + logger.info("Source files changed, triggering rebuild", executable=str(exe)) + needs_rebuild = True + + if exe.exists() and not needs_rebuild: return + if self.config.build_command is None: raise FileNotFoundError( - f"Executable not found: {exe}. " + f"[{self._mod_label}] Executable not found: {exe}. " "Set build_command in config to auto-build, or build it manually." 
) + + # Don't unlink the exe before rebuilding — the build command is + # responsible for replacing it. For nix builds the exe lives inside + # a read-only store; `nix build -o` atomically swaps the output + # symlink without touching store contents. logger.info( - "Executable not found, running build", + "Rebuilding" if needs_rebuild else "Executable not found, building", executable=str(exe), build_command=self.config.build_command, ) @@ -293,25 +397,36 @@ def _maybe_build(self) -> None: stderr=subprocess.PIPE, ) stdout, stderr = proc.communicate() - for line in stdout.decode("utf-8", errors="replace").splitlines(): + + stdout_lines = stdout.decode("utf-8", errors="replace").splitlines() + stderr_lines = stderr.decode("utf-8", errors="replace").splitlines() + + for line in stdout_lines: if line.strip(): - logger.info(line) - for line in stderr.decode("utf-8", errors="replace").splitlines(): + logger.info(line, module=self._mod_label) + for line in stderr_lines: if line.strip(): - logger.warning(line) + logger.warning(line, module=self._mod_label) + if proc.returncode != 0: - stderr_tail = stderr.decode("utf-8", errors="replace").strip()[-1000:] + # Include the last stderr lines in the exception for RPC callers. + tail = [l for l in stderr_lines if l.strip()][-20:] + tail_str = "\n".join(tail) if tail else "(no stderr output)" raise RuntimeError( - f"Build command failed (exit {proc.returncode}): {self.config.build_command}\n" - f"stderr: {stderr_tail}" + f"[{self._mod_label}] Build command failed " + f"(exit {proc.returncode}): {self.config.build_command}\n" + f"--- last stderr ---\n{tail_str}" ) if not exe.exists(): raise FileNotFoundError( - f"Build command succeeded but executable still not found: {exe}\n" - f"Build output may have been written to a different path. " - f"Check that build_command produces the executable at the expected location." 
+ f"[{self._mod_label}] Build command succeeded but executable still not found: {exe}" ) + # Seed the cache after a successful build so the next check has a baseline + # (needed for the initial build when the pre-build change check was skipped) + if self.config.rebuild_on_change: + did_change(self._build_cache_name(), self.config.rebuild_on_change, cwd=self.config.cwd) + def _collect_topics(self) -> dict[str, str]: """Extract LCM topic strings from blueprint-assigned stream transports.""" topics: dict[str, str] = {} diff --git a/dimos/core/test_native_rebuild.py b/dimos/core/test_native_rebuild.py new file mode 100644 index 0000000000..6f8a68b9aa --- /dev/null +++ b/dimos/core/test_native_rebuild.py @@ -0,0 +1,140 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for NativeModule rebuild-on-change integration.""" + +from __future__ import annotations + +from pathlib import Path +import stat + +import pytest + +from dimos.core.native_module import NativeModule, NativeModuleConfig +from dimos.utils.change_detect import PathEntry + + +@pytest.fixture(autouse=True) +def _use_tmp_cache(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """Redirect the change-detection cache to a temp dir for every test.""" + monkeypatch.setattr( + "dimos.utils.change_detect._get_cache_dir", + lambda: tmp_path / "cache", + ) + + +@pytest.fixture() +def build_env(tmp_path: Path) -> dict[str, Path]: + """Set up a temp directory with a source file, executable path, and marker path.""" + src = tmp_path / "src" + src.mkdir() + (src / "main.c").write_text("int main() { return 0; }") + + exe = tmp_path / "mybin" + marker = tmp_path / "build_ran.marker" + + # Build script: create the executable and a marker file + build_script = tmp_path / "build.sh" + build_script.write_text(f"#!/bin/sh\ntouch {exe}\nchmod +x {exe}\ntouch {marker}\n") + build_script.chmod(build_script.stat().st_mode | stat.S_IEXEC) + + return {"src": src, "exe": exe, "marker": marker, "build_script": build_script} + + +class _RebuildConfig(NativeModuleConfig): + executable: str = "" + rebuild_on_change: list[PathEntry] | None = None + + +class _RebuildModule(NativeModule[_RebuildConfig]): + default_config = _RebuildConfig + + +def _make_module(build_env: dict[str, Path]) -> _RebuildModule: + """Create a _RebuildModule pointing at the temp build env.""" + return _RebuildModule( + executable=str(build_env["exe"]), + build_command=f"sh {build_env['build_script']}", + rebuild_on_change=[str(build_env["src"])], + cwd=str(build_env["src"]), + ) + + +def test_rebuild_on_change_triggers_build(build_env: dict[str, Path]) -> None: + """When source files change, the build_command should re-run.""" + mod = _make_module(build_env) + try: + exe = build_env["exe"] + marker = 
build_env["marker"] + + # First build: exe doesn't exist → build runs + mod._maybe_build() + assert exe.exists() + assert marker.exists() + marker.unlink() + + # No change → build should NOT run + mod._maybe_build() + assert not marker.exists() + + # Modify source → build SHOULD run + (build_env["src"] / "main.c").write_text("int main() { return 1; }") + mod._maybe_build() + assert marker.exists(), "Build should have re-run after source change" + finally: + mod.stop() + + +def test_no_change_skips_rebuild(build_env: dict[str, Path]) -> None: + """When sources haven't changed, build_command must not run again.""" + mod = _make_module(build_env) + try: + marker = build_env["marker"] + + # Initial build + mod._maybe_build() + assert marker.exists() + marker.unlink() + + # Second call — nothing changed + mod._maybe_build() + assert not marker.exists(), "Build should have been skipped (no source changes)" + finally: + mod.stop() + + +def test_rebuild_on_change_none_skips_check(build_env: dict[str, Path]) -> None: + """When rebuild_on_change is None, no change detection happens at all.""" + exe = build_env["exe"] + marker = build_env["marker"] + + mod = _RebuildModule( + executable=str(exe), + build_command=f"sh {build_env['build_script']}", + rebuild_on_change=None, + cwd=str(build_env["src"]), + ) + try: + # Initial build + mod._maybe_build() + assert exe.exists() + assert marker.exists() + marker.unlink() + + # Modify source — but rebuild_on_change is None, so no rebuild + (build_env["src"] / "main.c").write_text("int main() { return 1; }") + mod._maybe_build() + assert not marker.exists(), "Should not rebuild when rebuild_on_change is None" + finally: + mod.stop() diff --git a/dimos/e2e_tests/test_smartnav_replay.py b/dimos/e2e_tests/test_smartnav_replay.py new file mode 100644 index 0000000000..e103b9e3cb --- /dev/null +++ b/dimos/e2e_tests/test_smartnav_replay.py @@ -0,0 +1,227 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration test for the unitree_go2_smartnav blueprint using replay data. + +Builds the smartnav pipeline (GO2Connection → OdomAdapter → PGO → CostMapper → +ReplanningAStarPlanner) in replay mode and verifies that data flows end-to-end: + - PGO receives scans and odom, publishes corrected_odometry + global_map + - CostMapper receives global_map, publishes global_costmap +""" + +from __future__ import annotations + +import threading +import time + +import pytest + +from dimos.core.blueprints import autoconnect +from dimos.core.global_config import global_config +from dimos.mapping.costmapper import CostMapper +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.navigation.smartnav.modules.odom_adapter.odom_adapter import OdomAdapter +from dimos.navigation.smartnav.modules.pgo.pgo import PGO +from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import unitree_go2_basic +from dimos.robot.unitree.go2.connection import GO2Connection + + +@pytest.fixture(autouse=True) +def _ci_env(monkeypatch): + monkeypatch.setenv("CI", "1") + + +@pytest.fixture() +def smartnav_coordinator(): + """Build the smartnav blueprint in replay mode (no planner — just PGO + CostMapper).""" + global_config.update( + viewer="none", + replay=True, 
+ replay_dir="go2_sf_office", + n_workers=1, + ) + + # Minimal pipeline: GO2Connection → OdomAdapter → PGO → CostMapper + # Skip ReplanningAStarPlanner and WavefrontFrontierExplorer to avoid + # needing a goal and cmd_vel sink. + bp = ( + autoconnect( + unitree_go2_basic, + PGO.blueprint(), + OdomAdapter.blueprint(), + CostMapper.blueprint(), + ) + .global_config( + n_workers=1, + robot_model="unitree_go2", + ) + .remappings( + [ + (GO2Connection, "lidar", "registered_scan"), + (GO2Connection, "odom", "raw_odom"), + ] + ) + ) + + coord = bp.build() + yield coord + coord.stop() + + +class _StreamCollector: + """Subscribe to a transport and collect messages in a list.""" + + def __init__(self) -> None: + self.messages: list = [] + self._lock = threading.Lock() + self._event = threading.Event() + + def callback(self, msg): # type: ignore[no-untyped-def] + with self._lock: + self.messages.append(msg) + self._event.set() + + def wait(self, count: int = 1, timeout: float = 30.0) -> bool: + deadline = time.monotonic() + timeout + while True: + with self._lock: + if len(self.messages) >= count: + return True + remaining = deadline - time.monotonic() + if remaining <= 0: + return False + self._event.wait(timeout=min(remaining, 0.5)) + self._event.clear() + + +@pytest.mark.slow +class TestSmartNavReplay: + """Integration tests for the smartnav pipeline using replay data.""" + + def test_pgo_produces_corrected_odometry(self, smartnav_coordinator): + """PGO should receive odom+scans via OdomAdapter and publish corrected_odometry.""" + coord = smartnav_coordinator + + # Find the PGO module instance + pgo_mod = None + for mod in coord.all_modules: + if isinstance(mod, PGO): + pgo_mod = mod + break + assert pgo_mod is not None, "PGO module not found in coordinator" + + # Subscribe to corrected_odometry output + collector = _StreamCollector() + pgo_mod.corrected_odometry._transport.subscribe(collector.callback) + + # Start the system — replay data flows automatically + 
coord.start() + + # Wait for PGO to produce at least 3 corrected odometry messages + assert collector.wait(count=3, timeout=30), ( + f"PGO did not produce enough corrected_odometry messages " + f"(got {len(collector.messages)})" + ) + + # Verify the messages are Odometry with reasonable values + msg = collector.messages[0] + assert isinstance(msg, Odometry), f"Expected Odometry, got {type(msg)}" + assert msg.frame_id == "map" + + def test_pgo_produces_global_map(self, smartnav_coordinator): + """PGO should accumulate keyframes and publish a global map.""" + coord = smartnav_coordinator + + pgo_mod = None + for mod in coord.all_modules: + if isinstance(mod, PGO): + pgo_mod = mod + break + assert pgo_mod is not None + + collector = _StreamCollector() + pgo_mod.global_map._transport.subscribe(collector.callback) + + coord.start() + + # Global map publishes less frequently — wait longer + assert collector.wait(count=1, timeout=60), ( + f"PGO did not produce a global_map (got {len(collector.messages)})" + ) + + msg = collector.messages[0] + assert isinstance(msg, PointCloud2), f"Expected PointCloud2, got {type(msg)}" + pts, _ = msg.as_numpy() + assert len(pts) > 0, "Global map should contain points" + + def test_costmapper_produces_costmap(self, smartnav_coordinator): + """CostMapper should receive global_map from PGO and produce a costmap.""" + coord = smartnav_coordinator + + from dimos.mapping.costmapper import CostMapper + + cm_mod = None + for mod in coord.all_modules: + if isinstance(mod, CostMapper): + cm_mod = mod + break + assert cm_mod is not None, "CostMapper module not found in coordinator" + + collector = _StreamCollector() + cm_mod.global_costmap._transport.subscribe(collector.callback) + + coord.start() + + assert collector.wait(count=1, timeout=60), ( + f"CostMapper did not produce a global_costmap (got {len(collector.messages)})" + ) + + msg = collector.messages[0] + assert isinstance(msg, OccupancyGrid), f"Expected OccupancyGrid, got {type(msg)}" + + 
def test_odom_adapter_converts_bidirectionally(self, smartnav_coordinator): + """OdomAdapter should convert PoseStamped→Odometry and Odometry→PoseStamped.""" + coord = smartnav_coordinator + + from dimos.navigation.smartnav.modules.odom_adapter.odom_adapter import OdomAdapter + + adapter = None + for mod in coord.all_modules: + if isinstance(mod, OdomAdapter): + adapter = mod + break + assert adapter is not None, "OdomAdapter not found in coordinator" + + # Collect outputs from both directions + odom_out = _StreamCollector() + ps_out = _StreamCollector() + adapter.odometry._transport.subscribe(odom_out.callback) + adapter.odom._transport.subscribe(ps_out.callback) + + coord.start() + + # OdomAdapter.odometry (PoseStamped→Odometry) should fire from replay odom + assert odom_out.wait(count=3, timeout=30), ( + f"OdomAdapter did not produce Odometry output (got {len(odom_out.messages)})" + ) + assert isinstance(odom_out.messages[0], Odometry) + + # OdomAdapter.odom (Odometry→PoseStamped) fires when PGO publishes corrected_odometry + assert ps_out.wait(count=1, timeout=30), ( + f"OdomAdapter did not produce PoseStamped output (got {len(ps_out.messages)})" + ) + assert isinstance(ps_out.messages[0], PoseStamped) diff --git a/dimos/hardware/sensors/camera/module.py b/dimos/hardware/sensors/camera/module.py index b8165658d9..a02f947af1 100644 --- a/dimos/hardware/sensors/camera/module.py +++ b/dimos/hardware/sensors/camera/module.py @@ -22,6 +22,7 @@ from dimos.agents.annotation import skill from dimos.core.blueprints import autoconnect from dimos.core.core import rpc +from dimos.core.global_config import global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out from dimos.hardware.sensors.camera.spec import CameraHardware @@ -32,7 +33,7 @@ from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier from dimos.spec import perception -from dimos.visualization.rerun.bridge 
import RerunBridgeModule +from dimos.visualization.vis_module import vis_module def default_transform() -> Transform: @@ -120,5 +121,5 @@ def stop(self) -> None: demo_camera = autoconnect( CameraModule.blueprint(), - RerunBridgeModule.blueprint(), + vis_module(viewer_backend=global_config.viewer), ) diff --git a/dimos/hardware/sensors/lidar/fastlio2/cpp/main.cpp b/dimos/hardware/sensors/lidar/fastlio2/cpp/main.cpp index 60b8d9cdb2..e1151aec70 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/cpp/main.cpp +++ b/dimos/hardware/sensors/lidar/fastlio2/cpp/main.cpp @@ -65,6 +65,47 @@ static std::string g_frame_id = "map"; static std::string g_child_frame_id = "body"; static float g_frequency = 10.0f; +// Initial pose offset (applied to all SLAM outputs) +// Position offset +static double g_init_x = 0.0; +static double g_init_y = 0.0; +static double g_init_z = 0.0; +// Orientation offset as quaternion (identity = no rotation) +static double g_init_qx = 0.0; +static double g_init_qy = 0.0; +static double g_init_qz = 0.0; +static double g_init_qw = 1.0; + +// Helper: quaternion multiply (Hamilton product) q_out = q1 * q2 +static void quat_mul(double ax, double ay, double az, double aw, + double bx, double by, double bz, double bw, + double& ox, double& oy, double& oz, double& ow) { + ow = aw*bw - ax*bx - ay*by - az*bz; + ox = aw*bx + ax*bw + ay*bz - az*by; + oy = aw*by - ax*bz + ay*bw + az*bx; + oz = aw*bz + ax*by - ay*bx + az*bw; +} + +// Helper: rotate a vector by a quaternion v_out = q * v * q_inv +static void quat_rotate(double qx, double qy, double qz, double qw, + double vx, double vy, double vz, + double& ox, double& oy, double& oz) { + // t = 2 * cross(q_xyz, v) + double tx = 2.0 * (qy*vz - qz*vy); + double ty = 2.0 * (qz*vx - qx*vz); + double tz = 2.0 * (qx*vy - qy*vx); + // v_out = v + qw*t + cross(q_xyz, t) + ox = vx + qw*tx + (qy*tz - qz*ty); + oy = vy + qw*ty + (qz*tx - qx*tz); + oz = vz + qw*tz + (qx*ty - qy*tx); +} + +// Check if initial pose is 
non-identity +static bool has_init_pose() { + return g_init_x != 0.0 || g_init_y != 0.0 || g_init_z != 0.0 || + g_init_qx != 0.0 || g_init_qy != 0.0 || g_init_qz != 0.0 || g_init_qw != 1.0; +} + // Frame accumulator (Livox SDK raw → CustomMsg) static std::mutex g_pc_mutex; static std::vector g_accumulated_points; @@ -126,11 +167,32 @@ static void publish_lidar(PointCloudXYZI::Ptr cloud, double timestamp, pc.data_length = pc.row_step; pc.data.resize(pc.data_length); + // Apply only the ROTATION part of init_pose to point clouds (not translation). + // FAST-LIO's get_world_cloud() places points in the SLAM map frame, which + // starts at origin. If the lidar is mounted upside-down, the whole map is + // inverted — rotation fixes that. But the translation component (e.g. z=1.2 + // for mount height) should NOT be added to points; it only offsets the + // odometry origin so downstream modules know the sensor height. Adding it + // to points would shift the ground plane away from z≈0. + // + // Note: init_pose globals are set once in main() before the processing loop + // and never modified, so this check is safe to hoist outside the loop. 
+ const bool apply_rotation = has_init_pose(); for (int i = 0; i < num_points; ++i) { float* dst = reinterpret_cast(pc.data.data() + i * 16); - dst[0] = cloud->points[i].x; - dst[1] = cloud->points[i].y; - dst[2] = cloud->points[i].z; + if (apply_rotation) { + double rx, ry, rz; + quat_rotate(g_init_qx, g_init_qy, g_init_qz, g_init_qw, + cloud->points[i].x, cloud->points[i].y, cloud->points[i].z, + rx, ry, rz); + dst[0] = static_cast(rx); + dst[1] = static_cast(ry); + dst[2] = static_cast(rz); + } else { + dst[0] = cloud->points[i].x; + dst[1] = cloud->points[i].y; + dst[2] = cloud->points[i].z; + } dst[3] = cloud->points[i].intensity; } @@ -148,14 +210,38 @@ static void publish_odometry(const custom_messages::Odometry& odom, double times msg.header = make_header(g_frame_id, timestamp); msg.child_frame_id = g_child_frame_id; - // Pose - msg.pose.pose.position.x = odom.pose.pose.position.x; - msg.pose.pose.position.y = odom.pose.pose.position.y; - msg.pose.pose.position.z = odom.pose.pose.position.z; - msg.pose.pose.orientation.x = odom.pose.pose.orientation.x; - msg.pose.pose.orientation.y = odom.pose.pose.orientation.y; - msg.pose.pose.orientation.z = odom.pose.pose.orientation.z; - msg.pose.pose.orientation.w = odom.pose.pose.orientation.w; + // Pose (apply initial pose offset: p_out = R_init * p_slam + t_init) + if (has_init_pose()) { + double rx, ry, rz; + quat_rotate(g_init_qx, g_init_qy, g_init_qz, g_init_qw, + odom.pose.pose.position.x, + odom.pose.pose.position.y, + odom.pose.pose.position.z, + rx, ry, rz); + msg.pose.pose.position.x = rx + g_init_x; + msg.pose.pose.position.y = ry + g_init_y; + msg.pose.pose.position.z = rz + g_init_z; + + double ox, oy, oz, ow; + quat_mul(g_init_qx, g_init_qy, g_init_qz, g_init_qw, + odom.pose.pose.orientation.x, + odom.pose.pose.orientation.y, + odom.pose.pose.orientation.z, + odom.pose.pose.orientation.w, + ox, oy, oz, ow); + msg.pose.pose.orientation.x = ox; + msg.pose.pose.orientation.y = oy; + 
msg.pose.pose.orientation.z = oz; + msg.pose.pose.orientation.w = ow; + } else { + msg.pose.pose.position.x = odom.pose.pose.position.x; + msg.pose.pose.position.y = odom.pose.pose.position.y; + msg.pose.pose.position.z = odom.pose.pose.position.z; + msg.pose.pose.orientation.x = odom.pose.pose.orientation.x; + msg.pose.pose.orientation.y = odom.pose.pose.orientation.y; + msg.pose.pose.orientation.z = odom.pose.pose.orientation.z; + msg.pose.pose.orientation.w = odom.pose.pose.orientation.w; + } // Covariance (fixed-size double[36]) for (int i = 0; i < 36; ++i) { @@ -340,7 +426,29 @@ int main(int argc, char** argv) { ports.host_imu_data = mod.arg_int("host_imu_data_port", port_defaults.host_imu_data); ports.host_log_data = mod.arg_int("host_log_data_port", port_defaults.host_log_data); + // Initial pose offset [x, y, z, qx, qy, qz, qw] + { + std::string init_str = mod.arg("init_pose", ""); + if (!init_str.empty()) { + double vals[7] = {0, 0, 0, 0, 0, 0, 1}; + int n = 0; + size_t pos = 0; + while (pos < init_str.size() && n < 7) { + size_t comma = init_str.find(',', pos); + if (comma == std::string::npos) comma = init_str.size(); + vals[n++] = std::stod(init_str.substr(pos, comma - pos)); + pos = comma + 1; + } + g_init_x = vals[0]; g_init_y = vals[1]; g_init_z = vals[2]; + g_init_qx = vals[3]; g_init_qy = vals[4]; g_init_qz = vals[5]; g_init_qw = vals[6]; + } + } + printf("[fastlio2] Starting FAST-LIO2 + Livox Mid-360 native module\n"); + if (has_init_pose()) { + printf("[fastlio2] init_pose: xyz=(%.3f, %.3f, %.3f) quat=(%.4f, %.4f, %.4f, %.4f)\n", + g_init_x, g_init_y, g_init_z, g_init_qx, g_init_qy, g_init_qz, g_init_qw); + } printf("[fastlio2] lidar topic: %s\n", g_lidar_topic.empty() ? 
"(disabled)" : g_lidar_topic.c_str()); printf("[fastlio2] odometry topic: %s\n", diff --git a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py index f3de842b46..b39dd7bcec 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py +++ b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py @@ -15,36 +15,45 @@ from dimos.core.blueprints import autoconnect from dimos.hardware.sensors.lidar.fastlio2.module import FastLio2 from dimos.mapping.voxels import VoxelGridMapper -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.vis_module import vis_module voxel_size = 0.05 mid360_fastlio = autoconnect( FastLio2.blueprint(voxel_size=voxel_size, map_voxel_size=voxel_size, map_freq=-1), - RerunBridgeModule.blueprint( - visual_override={ - "world/lidar": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), - } + vis_module( + "rerun", + rerun_config={ + "visual_override": { + "world/lidar": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), + }, + }, ), ).global_config(n_workers=2, robot_model="mid360_fastlio2") mid360_fastlio_voxels = autoconnect( FastLio2.blueprint(), VoxelGridMapper.blueprint(publish_interval=1.0, voxel_size=voxel_size, carve_columns=False), - RerunBridgeModule.blueprint( - visual_override={ - "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), - "world/lidar": None, - } + vis_module( + "rerun", + rerun_config={ + "visual_override": { + "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), + "world/lidar": None, + }, + }, ), ).global_config(n_workers=3, robot_model="mid360_fastlio2_voxels") mid360_fastlio_voxels_native = autoconnect( FastLio2.blueprint(voxel_size=voxel_size, map_voxel_size=voxel_size, map_freq=3.0), - RerunBridgeModule.blueprint( - visual_override={ - "world/lidar": None, - "world/global_map": lambda grid: 
grid.to_rerun(voxel_size=voxel_size, mode="boxes"), - } + vis_module( + "rerun", + rerun_config={ + "visual_override": { + "world/lidar": None, + "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), + }, + }, ), ).global_config(n_workers=2, robot_model="mid360_fastlio2") diff --git a/dimos/hardware/sensors/lidar/fastlio2/module.py b/dimos/hardware/sensors/lidar/fastlio2/module.py index cdce59bd81..72775e3c05 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/module.py +++ b/dimos/hardware/sensors/lidar/fastlio2/module.py @@ -30,9 +30,12 @@ from __future__ import annotations +import ipaddress from pathlib import Path +import socket from typing import TYPE_CHECKING, Annotated +from pydantic import field_validator from pydantic.experimental.pipeline import validate_as from dimos.core.native_module import NativeModule, NativeModuleConfig @@ -52,8 +55,55 @@ from dimos.msgs.nav_msgs.Odometry import Odometry from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.spec import mapping, perception +from dimos.utils.logging_config import setup_logger _CONFIG_DIR = Path(__file__).parent / "config" +_logger = setup_logger() + + +def _get_local_ips() -> list[str]: + """Return all IPv4 addresses assigned to local interfaces.""" + ips: list[str] = [] + try: + for info in socket.getaddrinfo(socket.gethostname(), None, socket.AF_INET): + addr = str(info[4][0]) + if addr not in ips: + ips.append(addr) + except socket.gaierror: + pass + # Also grab addresses via DGRAM trick for interfaces without DNS + try: + import subprocess + + out = subprocess.check_output( + ["ip", "-4", "-o", "addr", "show"], + timeout=5, + stderr=subprocess.DEVNULL, + ).decode() + for line in out.splitlines(): + # e.g. "2: eth0 inet 192.168.123.5/24 ..." 
+ parts = line.split() + for i, p in enumerate(parts): + if p == "inet" and i + 1 < len(parts): + addr = parts[i + 1].split("/")[0] + if addr not in ips: + ips.append(addr) + except Exception: + pass + return ips + + +def _find_candidate_ips(lidar_ip: str, local_ips: list[str]) -> list[str]: + """Suggest local IPs on the same subnet as the lidar.""" + candidates: list[str] = [] + try: + lidar_net = ipaddress.IPv4Network(f"{lidar_ip}/24", strict=False) + for ip in local_ips: + if ipaddress.IPv4Address(ip) in lidar_net: + candidates.append(ip) + except (ValueError, TypeError): + pass + return candidates class FastLio2Config(NativeModuleConfig): @@ -68,6 +118,20 @@ class FastLio2Config(NativeModuleConfig): lidar_ip: str = "192.168.1.155" frequency: float = 10.0 + # Initial pose offset [x, y, z, qx, qy, qz, qw] applied to all SLAM outputs. + # Set z to sensor mount height above ground for correct terrain analysis. + # Quaternion (qx, qy, qz, qw) for angled mounts; identity = [0,0,0, 0,0,0,1]. 
+ init_pose: list[float] = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] + + @field_validator("init_pose") + @classmethod + def _check_init_pose_length(cls, v: list[float]) -> list[float]: + if len(v) != 7: + raise ValueError( + f"init_pose must have exactly 7 elements [x,y,z,qx,qy,qz,qw], got {len(v)}" + ) + return v + # Frame IDs for output messages frame_id: str = "map" child_frame_id: str = "body" @@ -108,12 +172,21 @@ class FastLio2Config(NativeModuleConfig): host_imu_data_port: int = SDK_HOST_IMU_DATA_PORT host_log_data_port: int = SDK_HOST_LOG_DATA_PORT - # Resolved in __post_init__, passed as --config_path to the binary + # Passed as --config_path to the binary (resolved from ``config`` in post-init) config_path: str | None = None - # config is not a CLI arg (config_path is) + # config is not a CLI arg (config_path is the resolved version) cli_exclude: frozenset[str] = frozenset({"config"}) + def model_post_init(self, __context: object) -> None: + """Resolve config_path from the config YAML field.""" + super().model_post_init(__context) + # The validate_as pipeline may not fire for defaults, so resolve here. + cfg = self.config + if not cfg.is_absolute(): + cfg = _CONFIG_DIR / cfg + self.config_path = str(cfg.resolve()) + class FastLio2( NativeModule[FastLio2Config], perception.Lidar, perception.Odometry, mapping.GlobalPointcloud @@ -131,6 +204,75 @@ class FastLio2( odometry: Out[Odometry] global_map: Out[PointCloud2] + def __init__(self, **kwargs: object) -> None: + super().__init__(**kwargs) + self._validate_network() + + def _validate_network(self) -> None: + """Pre-flight check: verify host_ip is reachable and suggest alternatives.""" + host_ip = self.config.host_ip + lidar_ip = self.config.lidar_ip + local_ips = _get_local_ips() + + _logger.info( + "FastLio2 network check", + host_ip=host_ip, + lidar_ip=lidar_ip, + local_ips=local_ips, + ) + + # Check if host_ip is actually assigned to this machine. 
+ if host_ip not in local_ips: + same_subnet = _find_candidate_ips(lidar_ip, local_ips) + + if same_subnet: + picked = same_subnet[0] + _logger.warning( + f"FastLio2: host_ip={host_ip!r} not found locally. " + f"Auto-correcting to {picked!r} (same subnet as lidar {lidar_ip}).", + configured_ip=host_ip, + corrected_ip=picked, + lidar_ip=lidar_ip, + local_ips=local_ips, + ) + self.config.host_ip = picked + host_ip = picked + else: + subnet_prefix = ".".join(lidar_ip.split(".")[:3]) + msg = ( + f"FastLio2: host_ip={host_ip!r} is not assigned to any local interface.\n" + f" Lidar IP: {lidar_ip}\n" + f" Local IPs found: {', '.join(local_ips) or '(none)'}\n" + f" No local IP found on the same subnet as lidar ({lidar_ip}).\n" + f" The lidar network interface may be down or unconfigured.\n" + f" → Check: ip addr | grep {subnet_prefix}\n" + f" → Or assign an IP: " + f"sudo ip addr add {subnet_prefix}.5/24 dev \n" + ) + _logger.error(msg) + raise RuntimeError(msg) + + # Check if we can bind a UDP socket on host_ip (port 0 = ephemeral). + try: + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock: + sock.bind((host_ip, 0)) + except OSError as e: + _logger.error( + f"FastLio2: Cannot bind UDP socket on host_ip={host_ip!r}: {e}\n" + f" Another process may be using the Livox SDK ports.\n" + f" → Check: ss -ulnp | grep {host_ip}" + ) + raise RuntimeError( + f"FastLio2: Cannot bind UDP on {host_ip}: {e}. " + f"Check if another Livox/FastLio2 process is running." 
+ ) from e + + _logger.info( + "FastLio2 network check passed", + host_ip=host_ip, + lidar_ip=lidar_ip, + ) + # Verify protocol port compliance (mypy will flag missing ports) if TYPE_CHECKING: diff --git a/dimos/hardware/sensors/lidar/livox/livox_blueprints.py b/dimos/hardware/sensors/lidar/livox/livox_blueprints.py index c8835b3e89..958af084e2 100644 --- a/dimos/hardware/sensors/lidar/livox/livox_blueprints.py +++ b/dimos/hardware/sensors/lidar/livox/livox_blueprints.py @@ -14,9 +14,9 @@ from dimos.core.blueprints import autoconnect from dimos.hardware.sensors.lidar.livox.module import Mid360 -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.vis_module import vis_module mid360 = autoconnect( Mid360.blueprint(), - RerunBridgeModule.blueprint(), + vis_module("rerun"), ).global_config(n_workers=2, robot_model="mid360") diff --git a/dimos/manipulation/blueprints.py b/dimos/manipulation/blueprints.py index 8f726fe173..880bdb81b9 100644 --- a/dimos/manipulation/blueprints.py +++ b/dimos/manipulation/blueprints.py @@ -46,7 +46,6 @@ from dimos.msgs.sensor_msgs.JointState import JointState from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule from dimos.robot.catalog.ufactory import xarm7 as _catalog_xarm7 -from dimos.robot.foxglove_bridge import FoxgloveBridge # TODO: migrate to rerun from dimos.utils.data import get_data @@ -405,7 +404,7 @@ def _make_piper_config( use_aabb=True, max_obstacle_width=0.06, ), - FoxgloveBridge.blueprint(), # TODO: migrate to rerun + vis_module("foxglove"), ) .transports( { diff --git a/dimos/manipulation/grasping/demo_grasping.py b/dimos/manipulation/grasping/demo_grasping.py index 782283029b..c210688f0e 100644 --- a/dimos/manipulation/grasping/demo_grasping.py +++ b/dimos/manipulation/grasping/demo_grasping.py @@ -14,15 +14,13 @@ # limitations under the License. 
from pathlib import Path -from dimos.agents.mcp.mcp_client import McpClient -from dimos.agents.mcp.mcp_server import McpServer from dimos.core.blueprints import autoconnect from dimos.hardware.sensors.camera.realsense.camera import RealSenseCamera from dimos.manipulation.grasping.graspgen_module import graspgen from dimos.manipulation.grasping.grasping import GraspingModule from dimos.perception.detection.detectors.yoloe import YoloePromptMode from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule -from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.visualization.vis_module import vis_module camera_module = RealSenseCamera.blueprint(enable_pointcloud=False) @@ -44,7 +42,7 @@ ("/tmp", "/tmp", "rw") ], # Grasp visualization debug standalone: python -m dimos.manipulation.grasping.visualize_grasps ), - FoxgloveBridge.blueprint(), + vis_module("foxglove"), McpServer.blueprint(), McpClient.blueprint(), ).global_config(viewer="foxglove") diff --git a/dimos/models/embedding/base.py b/dimos/models/embedding/base.py index 0f1b1cd37a..69900d28ea 100644 --- a/dimos/models/embedding/base.py +++ b/dimos/models/embedding/base.py @@ -166,5 +166,4 @@ def query( top_values, top_indices = similarities.topk(k=min(top_k, len(candidates))) return [(idx.item(), val.item()) for idx, val in zip(top_indices, top_values, strict=False)] - ... diff --git a/dimos/models/embedding/treid.py b/dimos/models/embedding/treid.py index 1e89a55116..b4f06453e7 100644 --- a/dimos/models/embedding/treid.py +++ b/dimos/models/embedding/treid.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import warnings from typing import overload +import warnings warnings.filterwarnings("ignore", message="Cython evaluation.*unavailable", category=UserWarning) diff --git a/dimos/models/qwen/video_query.py b/dimos/models/qwen/video_query.py index 05b93028d7..55f3b5334e 100644 --- a/dimos/models/qwen/video_query.py +++ b/dimos/models/qwen/video_query.py @@ -161,7 +161,8 @@ def query_single_frame( def get_bbox_from_qwen( - video_stream: Observable, object_name: str | None = None # type: ignore[type-arg] + video_stream: Observable, + object_name: str | None = None, # type: ignore[type-arg] ) -> tuple[BBox, float] | None: """Get bounding box coordinates from Qwen for a specific object or any object. diff --git a/dimos/models/segmentation/edge_tam.py b/dimos/models/segmentation/edge_tam.py index 91cdec661d..21e6184acb 100644 --- a/dimos/models/segmentation/edge_tam.py +++ b/dimos/models/segmentation/edge_tam.py @@ -79,15 +79,14 @@ def __init__( OmegaConf.update(cfg, key, value) if cfg.model._target_ != "sam2.sam2_video_predictor.SAM2VideoPredictor": - logger.warning( - f"Config target is {cfg.model._target_}, forcing SAM2VideoPredictor" - ) + logger.warning(f"Config target is {cfg.model._target_}, forcing SAM2VideoPredictor") cfg.model._target_ = "sam2.sam2_video_predictor.SAM2VideoPredictor" self._predictor = instantiate(cfg.model, _recursive_=True) # Suppress the per-frame "propagate in video" tqdm bar from sam2 import sam2.sam2_video_predictor as _svp + _svp.tqdm = lambda iterable, *a, **kw: iterable ckpt_path = str(get_data("models_edgetam") / "edgetam.pt") diff --git a/dimos/models/vl/create.py b/dimos/models/vl/create.py index 45dc5f935e..db4566b466 100644 --- a/dimos/models/vl/create.py +++ b/dimos/models/vl/create.py @@ -5,12 +5,15 @@ __all__ = ["VlModelName", "create"] + def create(name: VlModelName) -> VlModel[Any]: # This uses inline imports to only import what's needed. 
match name: case "qwen": from dimos.models.vl.qwen import QwenVlModel + return QwenVlModel() case "moondream": from dimos.models.vl.moondream import MoondreamVlModel + return MoondreamVlModel() diff --git a/dimos/models/vl/moondream_hosted.py b/dimos/models/vl/moondream_hosted.py index aad9fe514c..654512e7c4 100644 --- a/dimos/models/vl/moondream_hosted.py +++ b/dimos/models/vl/moondream_hosted.py @@ -58,7 +58,9 @@ def caption(self, image: Image | np.ndarray, length: str = "normal") -> str: # result = self._client.caption(pil_image, length=length) return result.get("caption", str(result)) # type: ignore[no-any-return] - def query_detections(self, image: Image, query: str, **kwargs) -> ImageDetections2D[Detection2DBBox]: # type: ignore[no-untyped-def] + def query_detections( + self, image: Image, query: str, **kwargs + ) -> ImageDetections2D[Detection2DBBox]: # type: ignore[no-untyped-def] """Detect objects using Moondream's hosted detect method. Args: @@ -148,4 +150,3 @@ def query_points( def stop(self) -> None: pass - diff --git a/dimos/models/vl/openai.py b/dimos/models/vl/openai.py index 0486bbdb30..22d0587a5e 100644 --- a/dimos/models/vl/openai.py +++ b/dimos/models/vl/openai.py @@ -30,7 +30,9 @@ def _client(self) -> OpenAI: return OpenAI(api_key=api_key) - def query(self, image: Image | np.ndarray, query: str, response_format: dict | None = None, **kwargs) -> str: # type: ignore[override, type-arg, no-untyped-def] + def query( + self, image: Image | np.ndarray, query: str, response_format: dict | None = None, **kwargs + ) -> str: # type: ignore[override, type-arg, no-untyped-def] if isinstance(image, np.ndarray): import warnings @@ -71,7 +73,11 @@ def query(self, image: Image | np.ndarray, query: str, response_format: dict | N return response.choices[0].message.content # type: ignore[return-value,no-any-return] def query_batch( - self, images: list[Image], query: str, response_format: dict[str, Any] | None = None, **kwargs: Any + self, + images: list[Image], + 
query: str, + response_format: dict[str, Any] | None = None, + **kwargs: Any, ) -> list[str]: # type: ignore[override] """Query VLM with multiple images using a single API call.""" if not images: @@ -80,7 +86,9 @@ def query_batch( content: list[dict[str, Any]] = [ { "type": "image_url", - "image_url": {"url": f"data:image/png;base64,{self._prepare_image(img)[0].to_base64()}"}, + "image_url": { + "url": f"data:image/png;base64,{self._prepare_image(img)[0].to_base64()}" + }, } for img in images ] @@ -100,4 +108,3 @@ def stop(self) -> None: """Release the OpenAI client.""" if "_client" in self.__dict__: del self.__dict__["_client"] - diff --git a/dimos/models/vl/qwen.py b/dimos/models/vl/qwen.py index 202ce6759e..9928f298df 100644 --- a/dimos/models/vl/qwen.py +++ b/dimos/models/vl/qwen.py @@ -68,17 +68,23 @@ def query(self, image: Image | np.ndarray, query: str) -> str: # type: ignore[o return response.choices[0].message.content # type: ignore[return-value] def query_batch( - self, images: list[Image], query: str, response_format: dict[str, Any] | None = None, **kwargs: Any + self, + images: list[Image], + query: str, + response_format: dict[str, Any] | None = None, + **kwargs: Any, ) -> list[str]: # type: ignore[override] """Query VLM with multiple images using a single API call.""" if not images: return [] content: list[dict[str, Any]] = [ - { - "type": "image_url", - "image_url": {"url": f"data:image/png;base64,{self._prepare_image(img)[0].to_base64()}"}, - } + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{self._prepare_image(img)[0].to_base64()}" + }, + } for img in images ] content.append({"type": "text", "text": query}) diff --git a/dimos/models/vl/test_vlm.py b/dimos/models/vl/test_vlm.py index f0fd3b8d5a..734553a7b3 100644 --- a/dimos/models/vl/test_vlm.py +++ b/dimos/models/vl/test_vlm.py @@ -35,7 +35,7 @@ @pytest.mark.slow @pytest.mark.skipif_in_ci def test_vlm_bbox_detections(model_class: "type[VlModel]", model_name: str) -> 
None: - if model_class is MoondreamHostedVlModel and 'MOONDREAM_API_KEY' not in os.environ: + if model_class is MoondreamHostedVlModel and "MOONDREAM_API_KEY" not in os.environ: pytest.skip("Need MOONDREAM_API_KEY to run") image = Image.from_file(get_data("cafe.jpg")).to_rgb() @@ -110,7 +110,7 @@ def test_vlm_bbox_detections(model_class: "type[VlModel]", model_name: str) -> N def test_vlm_point_detections(model_class: "type[VlModel]", model_name: str) -> None: """Test VLM point detection capabilities.""" - if model_class is MoondreamHostedVlModel and 'MOONDREAM_API_KEY' not in os.environ: + if model_class is MoondreamHostedVlModel and "MOONDREAM_API_KEY" not in os.environ: pytest.skip("Need MOONDREAM_API_KEY to run") image = Image.from_file(get_data("cafe.jpg")).to_rgb() diff --git a/dimos/navigation/demo_ros_navigation.py b/dimos/navigation/demo_ros_navigation.py index 0efa04cd44..6effb4c672 100644 --- a/dimos/navigation/demo_ros_navigation.py +++ b/dimos/navigation/demo_ros_navigation.py @@ -18,7 +18,7 @@ from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import Vector3 -from dimos.navigation import rosnav +from dimos.navigation import rosnav_legacy as rosnav from dimos.protocol.service.lcmservice import autoconf from dimos.utils.logging_config import setup_logger @@ -30,7 +30,7 @@ def main() -> None: dimos = ModuleCoordinator() dimos.start() - ros_nav = rosnav.deploy(dimos) + ros_nav = rosnav.deploy(dimos) # type: ignore[attr-defined] logger.info("\nTesting navigation in 2 seconds...") time.sleep(2) diff --git a/dimos/navigation/rosnav/Dockerfile b/dimos/navigation/rosnav/Dockerfile new file mode 100644 index 0000000000..bfa1df65d3 --- /dev/null +++ b/dimos/navigation/rosnav/Dockerfile @@ -0,0 +1,355 @@ +# syntax=docker/dockerfile:1 +# DimOS Navigation Docker Image +# +# Multi-stage build for ROS 2 navigation with SLAM support. 
+# Includes both arise_slam and FASTLIO2 - select at runtime via LOCALIZATION_METHOD. +# +# The ros-navigation-autonomy-stack repo is cloned at build time via SSH. +# Build with: docker build --ssh default ... +# + +# Build argument for ROS distribution (default: humble) +ARG ROS_DISTRO=humble +ARG TARGETARCH +# Pinned git ref for ros-navigation-autonomy-stack (branch, tag, or commit SHA) +ARG NAV_STACK_REF=fastlio2 + +# Platform-specific base images +# - amd64: Use osrf/ros desktop-full (includes Gazebo, full GUI) +# - arm64: Use ros-base (desktop-full not available for ARM) +FROM osrf/ros:${ROS_DISTRO}-desktop-full AS base-amd64 +FROM ros:${ROS_DISTRO}-ros-base AS base-arm64 + +# STAGE 1: Build Stage - compile all C++ dependencies +FROM base-${TARGETARCH} AS builder + +ARG ROS_DISTRO +ARG NAV_STACK_REF +ENV DEBIAN_FRONTEND=noninteractive +ENV ROS_DISTRO=${ROS_DISTRO} +ENV WORKSPACE=/ros2_ws + +# Install build dependencies only +RUN apt-get update && apt-get install -y --no-install-recommends \ + # Build tools + git \ + cmake \ + build-essential \ + python3-colcon-common-extensions \ + # SSH client for private repo clone + openssh-client \ + # Libraries needed for building + libpcl-dev \ + libgoogle-glog-dev \ + libgflags-dev \ + libatlas-base-dev \ + libeigen3-dev \ + libsuitesparse-dev \ + # ROS packages needed for build + ros-${ROS_DISTRO}-pcl-ros \ + ros-${ROS_DISTRO}-cv-bridge \ + && rm -rf /var/lib/apt/lists/* + +# On arm64, ros-base doesn't include rviz2 (unlike desktop-full on amd64) +# Install it separately for building rviz plugins +# Note: ARG must be re-declared after FROM; placed here to maximize layer caching above +ARG TARGETARCH +RUN if [ "${TARGETARCH}" = "arm64" ]; then \ + apt-get update && apt-get install -y --no-install-recommends \ + ros-${ROS_DISTRO}-rviz2 \ + && rm -rf /var/lib/apt/lists/*; \ + fi + +# On arm64, build open3d from source (no Linux aarch64 wheels on PyPI) +# Cached as a separate layer; the wheel is copied to the runtime stage 
+# mkdir runs unconditionally so COPY --from=builder works on all architectures +RUN mkdir -p /opt/open3d-wheel && \ + PYTHON_MINOR=$(python3 -c "import sys; print(sys.version_info.minor)") && \ + if [ "${TARGETARCH}" = "arm64" ] && [ "$PYTHON_MINOR" -ge 12 ]; then \ + echo "Building open3d from source for arm64 + Python 3.${PYTHON_MINOR} (no PyPI wheel)..." && \ + apt-get update && apt-get install -y --no-install-recommends \ + python3-dev \ + python3-pip \ + python3-setuptools \ + python3-wheel \ + libblas-dev \ + liblapack-dev \ + libgl1-mesa-dev \ + libglib2.0-dev \ + libxinerama-dev \ + libxcursor-dev \ + libxrandr-dev \ + libxi-dev \ + gfortran \ + && rm -rf /var/lib/apt/lists/* && \ + cd /tmp && \ + git clone --depth 1 --branch v0.19.0 https://github.com/isl-org/Open3D.git && \ + cd Open3D && \ + util/install_deps_ubuntu.sh assume-yes && \ + mkdir build && cd build && \ + cmake .. \ + -DCMAKE_BUILD_TYPE=Release \ + -DBUILD_CUDA_MODULE=OFF \ + -DBUILD_GUI=OFF \ + -DBUILD_TENSORFLOW_OPS=OFF \ + -DBUILD_PYTORCH_OPS=OFF \ + -DBUILD_UNIT_TESTS=OFF \ + -DBUILD_BENCHMARKS=OFF \ + -DBUILD_EXAMPLES=OFF \ + -DBUILD_WEBRTC=OFF && \ + make -j$(($(nproc) > 4 ? 4 : $(nproc))) && \ + make pip-package -j$(($(nproc) > 4 ? 4 : $(nproc))) && \ + mkdir -p /opt/open3d-wheel && \ + cp lib/python_package/pip_package/open3d*.whl /opt/open3d-wheel/ && \ + cd / && rm -rf /tmp/Open3D; \ + fi + +# On arm64, build or-tools from source (pre-built binaries are x86_64 only) +# This is cached as a separate layer since it takes significant time to build +ENV OR_TOOLS_VERSION=9.8 +RUN if [ "${TARGETARCH}" = "arm64" ]; then \ + echo "Building or-tools v${OR_TOOLS_VERSION} from source for arm64..." 
&& \ + apt-get update && apt-get install -y --no-install-recommends \ + lsb-release \ + wget \ + && rm -rf /var/lib/apt/lists/* && \ + cd /tmp && \ + wget -q https://github.com/google/or-tools/archive/refs/tags/v${OR_TOOLS_VERSION}.tar.gz && \ + tar xzf v${OR_TOOLS_VERSION}.tar.gz && \ + cd or-tools-${OR_TOOLS_VERSION} && \ + cmake -S . -B build \ + -DCMAKE_BUILD_TYPE=Release \ + -DBUILD_DEPS=ON \ + -DBUILD_SAMPLES=OFF \ + -DBUILD_EXAMPLES=OFF \ + -DBUILD_FLATZINC=OFF \ + -DUSE_SCIP=OFF \ + -DUSE_COINOR=OFF && \ + cmake --build build --config Release -j$(($(nproc) > 4 ? 4 : $(nproc))) && \ + cmake --install build --prefix /opt/or-tools && \ + rm -rf /tmp/or-tools-${OR_TOOLS_VERSION} /tmp/v${OR_TOOLS_VERSION}.tar.gz; \ + fi + +# Create workspace +RUN mkdir -p ${WORKSPACE}/src + +# Clone autonomy stack source +RUN git clone -b ${NAV_STACK_REF} --depth 1 \ + https://github.com/jeff-hykin/ros-navigation-autonomy-stack.git \ + ${WORKSPACE}/src/ros-navigation-autonomy-stack + +# On arm64, replace pre-built x86_64 or-tools with arm64 built version +RUN if [ "${TARGETARCH}" = "arm64" ] && [ -d "/opt/or-tools" ]; then \ + echo "Replacing x86_64 or-tools with arm64 build..." 
&& \ + OR_TOOLS_DIR=${WORKSPACE}/src/ros-navigation-autonomy-stack/src/exploration_planner/tare_planner/or-tools && \ + rm -rf ${OR_TOOLS_DIR}/lib/*.so* ${OR_TOOLS_DIR}/lib/*.a && \ + cp -r /opt/or-tools/lib/* ${OR_TOOLS_DIR}/lib/ && \ + rm -rf ${OR_TOOLS_DIR}/include && \ + cp -r /opt/or-tools/include ${OR_TOOLS_DIR}/ && \ + ldconfig; \ + fi + +# Compatibility fix: In Humble, cv_bridge uses .h extension, but Jazzy uses .hpp +# Create a symlink so code written for Jazzy works on Humble +RUN if [ "${ROS_DISTRO}" = "humble" ]; then \ + CV_BRIDGE_DIR=$(find /opt/ros/humble/include -name "cv_bridge.h" -printf "%h\n" 2>/dev/null | head -1) && \ + if [ -n "$CV_BRIDGE_DIR" ]; then \ + ln -sf "$CV_BRIDGE_DIR/cv_bridge.h" "$CV_BRIDGE_DIR/cv_bridge.hpp"; \ + echo "Created cv_bridge.hpp symlink in $CV_BRIDGE_DIR"; \ + else \ + echo "Warning: cv_bridge.h not found, skipping symlink creation"; \ + fi; \ + fi + +# Build Livox-SDK2 +RUN cd ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/utilities/livox_ros_driver2/Livox-SDK2 && \ + mkdir -p build && cd build && \ + cmake .. && make -j$(nproc) && make install && ldconfig && \ + rm -rf ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/utilities/livox_ros_driver2/Livox-SDK2/build + +# Build Sophus +RUN cd ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/slam/dependency/Sophus && \ + mkdir -p build && cd build && \ + cmake .. -DBUILD_TESTS=OFF && make -j$(nproc) && make install && \ + rm -rf ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/slam/dependency/Sophus/build + +# Build Ceres Solver +RUN cd ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/slam/dependency/ceres-solver && \ + mkdir -p build && cd build && \ + cmake .. && make -j$(nproc) && make install && \ + rm -rf ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/slam/dependency/ceres-solver/build + +# Build GTSAM +RUN cd ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/slam/dependency/gtsam && \ + mkdir -p build && cd build && \ + cmake .. 
-DGTSAM_USE_SYSTEM_EIGEN=ON -DGTSAM_BUILD_WITH_MARCH_NATIVE=OFF && \ + make -j$(nproc) && make install && ldconfig && \ + rm -rf ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/slam/dependency/gtsam/build + +# Build ROS workspace with both SLAM systems (no --symlink-install for multi-stage build compatibility) +RUN /bin/bash -c "source /opt/ros/${ROS_DISTRO}/setup.bash && \ + cd ${WORKSPACE} && \ + echo 'Building with both arise_slam and FASTLIO2' && \ + colcon build --cmake-args -DCMAKE_BUILD_TYPE=Release" + +# STAGE 2: Runtime Stage - minimal image for running +ARG ROS_DISTRO +ARG TARGETARCH +FROM base-${TARGETARCH} AS runtime + +ARG ROS_DISTRO +ENV DEBIAN_FRONTEND=noninteractive +ENV ROS_DISTRO=${ROS_DISTRO} +ENV WORKSPACE=/ros2_ws +ENV DIMOS_PATH=/workspace/dimos +# LOCALIZATION_METHOD: arise_slam (default) or fastlio +ENV LOCALIZATION_METHOD=arise_slam + +# DDS Configuration - Use FastDDS (default ROS 2 middleware) +ENV RMW_IMPLEMENTATION=rmw_fastrtps_cpp +ENV FASTRTPS_DEFAULT_PROFILES_FILE=/ros2_ws/config/fastdds.xml + +# Install runtime dependencies only (no build tools) +RUN apt-get update && apt-get install -y --no-install-recommends \ + # ROS packages + ros-${ROS_DISTRO}-pcl-ros \ + ros-${ROS_DISTRO}-cv-bridge \ + ros-${ROS_DISTRO}-foxglove-bridge \ + ros-${ROS_DISTRO}-rviz2 \ + ros-${ROS_DISTRO}-rqt* \ + ros-${ROS_DISTRO}-joy \ + # DDS middleware (FastDDS is default, just ensure it's installed) + ros-${ROS_DISTRO}-rmw-fastrtps-cpp \ + # Runtime libraries + libpcl-dev \ + libgoogle-glog-dev \ + libgflags-dev \ + libatlas-base-dev \ + libeigen3-dev \ + libsuitesparse-dev \ + # X11 for GUI (minimal) + libx11-6 \ + libxext6 \ + libxrender1 \ + libgl1 \ + libglib2.0-0 \ + # Networking tools + iputils-ping \ + net-tools \ + iproute2 \ + # Serial/USB for hardware + usbutils \ + # Python (minimal) + python3-pip \ + python3-venv \ + # Joystick support + joystick \ + # Time sync for multi-computer setups + chrony \ + && rm -rf /var/lib/apt/lists/* + +# Copy 
installed libraries from builder +COPY --from=builder /usr/local/lib /usr/local/lib +COPY --from=builder /usr/local/include /usr/local/include + +RUN ldconfig + +# Copy built ROS workspace from builder +COPY --from=builder ${WORKSPACE}/install ${WORKSPACE}/install + +# Copy only config/rviz files from src (not the large dependency folders) +# These are needed if running without volume mount +COPY --from=builder ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/base_autonomy/vehicle_simulator/rviz ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/base_autonomy/vehicle_simulator/rviz +COPY --from=builder ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/route_planner/far_planner/rviz ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/route_planner/far_planner/rviz +COPY --from=builder ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/exploration_planner/tare_planner/rviz ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/exploration_planner/tare_planner/rviz +# Copy SLAM config files based on SLAM_TYPE +COPY --from=builder ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/utilities/livox_ros_driver2/config ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/utilities/livox_ros_driver2/config + +# Copy config files for both SLAM systems +RUN --mount=from=builder,source=${WORKSPACE}/src/ros-navigation-autonomy-stack/src,target=/tmp/src \ + echo "Copying arise_slam configs" && \ + mkdir -p ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/slam/arise_slam_mid360 && \ + cp -r /tmp/src/slam/arise_slam_mid360/config ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/slam/arise_slam_mid360/ 2>/dev/null || true && \ + echo "Copying FASTLIO2 configs" && \ + mkdir -p ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/slam/FASTLIO2_ROS2 && \ + for pkg in fastlio2 localizer pgo hba; do \ + if [ -d "/tmp/src/slam/FASTLIO2_ROS2/$pkg/config" ]; then \ + mkdir -p ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/slam/FASTLIO2_ROS2/$pkg && \ + cp -r 
/tmp/src/slam/FASTLIO2_ROS2/$pkg/config ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/slam/FASTLIO2_ROS2/$pkg/; \ + fi; \ + if [ -d "/tmp/src/slam/FASTLIO2_ROS2/$pkg/rviz" ]; then \ + cp -r /tmp/src/slam/FASTLIO2_ROS2/$pkg/rviz ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/slam/FASTLIO2_ROS2/$pkg/; \ + fi; \ + done + +# Copy simulation shell scripts (real robot mode uses volume mount) +COPY --from=builder ${WORKSPACE}/src/ros-navigation-autonomy-stack/system_simulation*.sh ${WORKSPACE}/src/ros-navigation-autonomy-stack/ + +# Create directories +RUN mkdir -p ${DIMOS_PATH} \ + ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/base_autonomy/vehicle_simulator/mesh/unity \ + ${WORKSPACE}/bagfiles \ + ${WORKSPACE}/logs \ + ${WORKSPACE}/config + +# Copy FastDDS configuration (single source: dimos/navigation/rosnav/fastdds.xml) +# At runtime the volume mount may overlay this, but the baked-in copy ensures +# the image works standalone. +COPY dimos/navigation/rosnav/fastdds.xml ${WORKSPACE}/config/fastdds.xml + +# Install portaudio for unitree-webrtc-connect (pyaudio dependency) +RUN apt-get update && apt-get install -y --no-install-recommends \ + portaudio19-dev \ + && rm -rf /var/lib/apt/lists/* + +# Create Python venv and install dependencies +RUN python3 -m venv /opt/dimos-venv && \ + /opt/dimos-venv/bin/pip install --no-cache-dir \ + pyyaml + +# On arm64, install open3d wheel built from source in the builder stage +COPY --from=builder /opt/open3d-wheel /opt/open3d-wheel +ARG TARGETARCH +RUN if [ "${TARGETARCH}" = "arm64" ] && ls /opt/open3d-wheel/open3d*.whl 1>/dev/null 2>&1; then \ + echo "Installing open3d from pre-built arm64 wheel..." 
&& \ + /opt/dimos-venv/bin/pip install --no-cache-dir /opt/open3d-wheel/open3d*.whl && \ + rm -rf /opt/open3d-wheel; \ + fi + +# Copy dimos source and install as editable package +# The volume mount at runtime will overlay /workspace/dimos, but the editable +# install creates a link that will use the volume-mounted files +COPY pyproject.toml /workspace/dimos/ +COPY dimos /workspace/dimos/dimos +RUN /opt/dimos-venv/bin/pip install --no-cache-dir -e "/workspace/dimos[unitree]" + +# Set up shell environment +RUN echo "source /opt/ros/${ROS_DISTRO}/setup.bash" >> ~/.bashrc && \ + echo "source ${WORKSPACE}/install/setup.bash" >> ~/.bashrc && \ + echo "source /opt/dimos-venv/bin/activate" >> ~/.bashrc && \ + echo "export RMW_IMPLEMENTATION=rmw_fastrtps_cpp" >> ~/.bashrc && \ + echo "export FASTRTPS_DEFAULT_PROFILES_FILE=/ros2_ws/config/fastdds.xml" >> ~/.bashrc + +# Copy helper scripts (paths relative to repo root build context) +COPY docker/navigation/foxglove_utility/twist_relay.py /usr/local/bin/twist_relay.py +COPY docker/navigation/foxglove_utility/goal_autonomy_relay.py /usr/local/bin/goal_autonomy_relay.py +COPY dimos/navigation/rosnav/entrypoint.sh /usr/local/bin/entrypoint.sh +RUN chmod +x /usr/local/bin/twist_relay.py /usr/local/bin/goal_autonomy_relay.py /usr/local/bin/entrypoint.sh + +# Set up udev rules for motor controller +RUN mkdir -p /etc/udev/rules.d && \ + echo 'SUBSYSTEM=="tty", ATTRS{idVendor}=="0483", ATTRS{idProduct}=="5740", MODE="0666", GROUP="dialout"' \ + > /etc/udev/rules.d/99-motor-controller.rules + +# Working directory +WORKDIR ${DIMOS_PATH} + +# Default entrypoint (overridden at docker run time by DockerModule) +ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] + +# Default command +CMD ["bash"] + +# DIMOS-MODULE-CONVERSION-427593ae-c6e8-4cf1-9b2d-ee81a420a5dc +# (sentinel prevents docker_build.py from appending redundant DimOS footer; +# dimos is already installed above via pip install -e) diff --git a/dimos/navigation/rosnav/entrypoint.sh 
b/dimos/navigation/rosnav/entrypoint.sh new file mode 100755 index 0000000000..875aaeed8d --- /dev/null +++ b/dimos/navigation/rosnav/entrypoint.sh @@ -0,0 +1,522 @@ +#!/bin/bash + +MODE="${MODE:-unity_sim}" +USE_ROUTE_PLANNER="${USE_ROUTE_PLANNER:-true}" +USE_RVIZ="${USE_RVIZ:-false}" +ENABLE_FOXGLOVE="${ENABLE_FOXGLOVE:-false}" +FOXGLOVE_PORT="${FOXGLOVE_PORT:-8765}" +LOCALIZATION_METHOD="${LOCALIZATION_METHOD:-arise_slam}" +BAGFILE_PATH="${BAGFILE_PATH:-}" + +UNITY_BRIDGE_CONNECT_TIMEOUT_SEC="${UNITY_BRIDGE_CONNECT_TIMEOUT_SEC:-25}" +UNITY_BRIDGE_RETRY_INTERVAL_SEC="${UNITY_BRIDGE_RETRY_INTERVAL_SEC:-2}" + +# Tune kernel TCP buffers for high-bandwidth data transmission (lidar, etc.) +sysctl -w net.core.rmem_max=67108864 net.core.rmem_default=67108864 2>/dev/null || true +sysctl -w net.core.wmem_max=67108864 net.core.wmem_default=67108864 2>/dev/null || true + +STACK_ROOT="/ros2_ws/src/ros-navigation-autonomy-stack" +UNITY_EXECUTABLE="${STACK_ROOT}/src/base_autonomy/vehicle_simulator/mesh/unity/environment/Model.x86_64" +UNITY_MESH_DIR="${STACK_ROOT}/src/base_autonomy/vehicle_simulator/mesh/unity" + +# +# Source +# +echo "[entrypoint] Sourcing ROS env..." +source /opt/ros/${ROS_DISTRO:-humble}/setup.bash +source /ros2_ws/install/setup.bash +source /opt/dimos-venv/bin/activate + +# +# cli helpers (when connecting to docker) +# + +# rosspy +cat > /usr/bin/rosspy <<'EOS' +#!/bin/bash + source /opt/ros/${ROS_DISTRO:-humble}/setup.bash + source /ros2_ws/install/setup.bash + source /opt/dimos-venv/bin/activate + exec python3 -m dimos.utils.cli.rosspy.run_rosspy "$@" +EOS +chmod +x /usr/bin/rosspy + +# x11_doctor +cat > /usr/bin/x11_doctor <<'EOS' +#!/usr/bin/env bash + ok=true + echo "=== X11 Doctor ===" + + # 1. DISPLAY + echo "" + echo "--- DISPLAY ---" + if [ -z "${DISPLAY:-}" ]; then + echo " FAIL DISPLAY is not set" + ok=false + else + echo " OK DISPLAY=${DISPLAY}" + fi + + # 2. 
X11 unix socket directory + echo "" + echo "--- /tmp/.X11-unix socket directory ---" + if [ ! -d /tmp/.X11-unix ]; then + echo " FAIL /tmp/.X11-unix does not exist (volume not mounted?)" + ok=false + else + sockets=$(ls /tmp/.X11-unix 2>/dev/null) + if [ -z "$sockets" ]; then + echo " WARN /tmp/.X11-unix exists but is empty (no display sockets)" + ok=false + else + echo " OK /tmp/.X11-unix contents: $sockets" + fi + fi + + # 3. Socket for the specific DISPLAY + echo "" + echo "--- Display socket for DISPLAY=${DISPLAY:-} ---" + if [ -n "${DISPLAY:-}" ]; then + display_num=$(echo "${DISPLAY}" | sed 's/.*:\([0-9]*\).*/\1/') + sock="/tmp/.X11-unix/X${display_num}" + if [ -S "$sock" ]; then + echo " OK $sock exists and is a socket" + ls -la "$sock" + else + echo " FAIL $sock not found or not a socket" + ok=false + fi + else + echo " SKIP (DISPLAY not set)" + fi + + # 4. XAUTHORITY file + echo "" + echo "--- XAUTHORITY ---" + xauth_file="${XAUTHORITY:-$HOME/.Xauthority}" + if [ -z "${XAUTHORITY:-}" ]; then + echo " WARN XAUTHORITY env var not set; defaulting to $xauth_file" + else + echo " OK XAUTHORITY=${XAUTHORITY}" + fi + if [ -f "$xauth_file" ]; then + echo " OK $xauth_file exists ($(wc -c < "$xauth_file") bytes)" + ls -la "$xauth_file" + else + echo " FAIL $xauth_file not found (Xauthority not mounted?)" + ok=false + fi + + # 5. xauth cookie list + echo "" + echo "--- xauth cookie entries ---" + if command -v xauth >/dev/null 2>&1; then + cookie_out=$(XAUTHORITY="$xauth_file" xauth list 2>&1) + if [ -z "$cookie_out" ]; then + echo " WARN xauth list returned no entries (cookie file empty or wrong display)" + ok=false + else + echo " OK cookies found:" + echo "$cookie_out" | sed 's/^/ /' + fi + else + echo " WARN xauth not installed; cannot check cookies" + fi + + # 6. 
Live connection test + echo "" + echo "--- Live connection test ---" + if command -v xdpyinfo >/dev/null 2>&1; then + if DISPLAY="${DISPLAY:-:0}" XAUTHORITY="$xauth_file" xdpyinfo >/dev/null 2>&1; then + echo " OK xdpyinfo connected to ${DISPLAY:-:0} successfully" + else + echo " FAIL xdpyinfo could not connect to ${DISPLAY:-:0}" + DISPLAY="${DISPLAY:-:0}" XAUTHORITY="$xauth_file" xdpyinfo 2>&1 | head -5 | sed 's/^/ /' + ok=false + fi + elif command -v xclock >/dev/null 2>&1; then + if DISPLAY="${DISPLAY:-:0}" XAUTHORITY="$xauth_file" xclock -display "${DISPLAY:-:0}" & + sleep 1 && kill %1 2>/dev/null; then + echo " OK xclock launched on ${DISPLAY:-:0}" + else + echo " FAIL xclock could not connect" + ok=false + fi + else + echo " SKIP neither xdpyinfo nor xclock installed; skipping live test" + echo " Install with: apt-get install -y x11-utils" + fi + + # 7. Summary + echo "" + echo "=== Summary ===" + if $ok; then + echo " All checks passed — X11 should work." + else + echo " One or more checks failed." + echo "" + echo " Common fixes:" + echo " • Mount the socket: -v /tmp/.X11-unix:/tmp/.X11-unix" + echo " • Mount the cookie: -v \${XAUTHORITY:-\$HOME/.Xauthority}:/tmp/.Xauthority:ro" + echo " • Set env vars: -e DISPLAY -e XAUTHORITY=/tmp/.Xauthority" + echo " • Allow local X: xhost +local: (run on host, less safe)" + fi + echo "" +EOS +chmod +x /usr/bin/x11_doctor + +# +# +# +# sanity checks and setup +# +# +# + +# +# dimos +# +if ! [ -d "/workspace/dimos" ]; then + echo "the dimos codebase must be mounted to /workspace/dimos for the codebase to work" + exit 1 +fi +export PYTHONPATH="/workspace/dimos:${PYTHONPATH:-}" +# start pip install in the background +pip_install_log_path="/tmp/dimos_pip_install.log" +pip install -e /workspace/dimos &>"$pip_install_log_path" & +PIP_INSTALL_PID=$! 
+ +# +# dds config +# +export RMW_IMPLEMENTATION=rmw_fastrtps_cpp +if [ -z "$FASTRTPS_DEFAULT_PROFILES_FILE" ]; then + if [ -f "/ros2_ws/config/custom_fastdds.xml" ]; then + export FASTRTPS_DEFAULT_PROFILES_FILE=/ros2_ws/config/custom_fastdds.xml + elif [ -f "/ros2_ws/config/fastdds.xml" ]; then + export FASTRTPS_DEFAULT_PROFILES_FILE=/ros2_ws/config/fastdds.xml + fi +fi +# ensure file exists +if ! [ -f "$FASTRTPS_DEFAULT_PROFILES_FILE" ]; then + echo "FASTRTPS_DEFAULT_PROFILES_FILE was set (or defaulted to) '$FASTRTPS_DEFAULT_PROFILES_FILE' but that file doesn't exist" + exit 4 +fi + +# +# launch helpers +# +# complicated because of retry system (needed as an alternative to "sleep 5" and praying its enough) +start_ros_nav_stack() { + setsid bash -c " + source /opt/ros/${ROS_DISTRO:-humble}/setup.bash + source /ros2_ws/install/setup.bash + cd ${STACK_ROOT} + echo '[entrypoint] running: ros2 launch vehicle_simulator ${LAUNCH_FILE} ${LAUNCH_ARGS}' + ros2 launch vehicle_simulator ${LAUNCH_FILE} ${LAUNCH_ARGS} + " & + ROS_NAV_PID=$! + echo "[entrypoint] ROS nav stack PID: $ROS_NAV_PID" +} + +stop_ros_nav_stack() { + if [ -n "$ROS_NAV_PID" ] && kill -0 "$ROS_NAV_PID" 2>/dev/null; then + kill -TERM "-$ROS_NAV_PID" 2>/dev/null || kill -TERM "$ROS_NAV_PID" 2>/dev/null || true + for _ in 1 2 3 4 5; do + kill -0 "$ROS_NAV_PID" 2>/dev/null || break + sleep 1 + done + kill -KILL "-$ROS_NAV_PID" 2>/dev/null || kill -KILL "$ROS_NAV_PID" 2>/dev/null || true + fi +} + +start_unity() { + if [ ! -f "$UNITY_EXECUTABLE" ]; then + echo "[entrypoint] ERROR: Unity executable not found: $UNITY_EXECUTABLE" + exit 1 + fi + + # These files are expected by CMU/TARE sim assets. Missing files usually + # indicate a bad mount and can break downstream map-dependent behavior. + for required in map.ply traversable_area.ply; do + if [ ! 
-f "$UNITY_MESH_DIR/$required" ]; then + echo "[entrypoint] WARNING: missing $UNITY_MESH_DIR/$required" + fi + done + + echo "[entrypoint] Starting Unity: $UNITY_EXECUTABLE" + "$UNITY_EXECUTABLE" & + UNITY_PID=$! + echo "[entrypoint] Unity PID: $UNITY_PID" +} + +has_established_bridge_tcp() { + # ss unavailable: the socket cannot be probed, so report ready (best effort). + if ! command -v ss >/dev/null 2>&1; then + return 0 + fi + ss -Htn state established '( sport = :10000 or dport = :10000 )' 2>/dev/null | grep -q . +} + +unity_topics_ready() { + local topics + topics="$(ros2 topic list 2>/dev/null || true)" + + echo "$topics" | grep -Eq '^/registered_scan$' || return 1 + echo "$topics" | grep -Eq '^/camera/image/compressed$' || return 1 + return 0 +} + +bridge_ready() { + # Check only that Unity has established the TCP connection to the bridge. + # unity_topics_ready (ros2 topic list) is intentionally skipped: DDS + # discovery is too slow/unreliable to use as a readiness gate inside the + # container — ros2 topic list consistently fails to see Unity bridge topics + # within any reasonable window even though the publishers ARE registered. + has_established_bridge_tcp || return 1 + return 0 +} + +launch_with_retry() { + local attempt=1 + + while true; do + echo "[entrypoint] Launch attempt ${attempt}: ros2 launch vehicle_simulator ${LAUNCH_FILE} ${LAUNCH_ARGS}" + start_ros_nav_stack + + local deadline=$((SECONDS + UNITY_BRIDGE_CONNECT_TIMEOUT_SEC)) + while [ "$SECONDS" -lt "$deadline" ]; do + if bridge_ready; then + echo "[entrypoint] Unity bridge ready (TCP connection check on port 10000 passed; ROS topics not verified)." + return 0 + fi + + if ! kill -0 "$ROS_NAV_PID" 2>/dev/null; then + echo "[entrypoint] ROS nav stack exited during bridge startup." 
+ break + fi + sleep 1 + done + + cat </dev/null || true + ip link set "${LIDAR_INTERFACE}" up 2>/dev/null || true + fi + + # Generate MID360_config.json so the Livox driver knows where to listen + if [ -n "${LIDAR_COMPUTER_IP}" ] && [ -n "${LIDAR_IP}" ]; then + MID360_SRC="${STACK_ROOT}/src/utilities/livox_ros_driver2/config/MID360_config.json" + MID360_INST="/ros2_ws/install/livox_ros_driver2/share/livox_ros_driver2/config/MID360_config.json" + echo "[entrypoint] Generating MID360_config.json (lidar=${LIDAR_IP}, host=${LIDAR_COMPUTER_IP})..." + cat > "${MID360_SRC}" </dev/null || true + fi + + start_ros_nav_stack + + # Start Unitree WebRTC control bridge (subscribes /cmd_vel, enables robot control). + # This is required for the robot connection; also publishes robot state to ROS. + if [[ "${ROBOT_CONFIG_PATH:-}" == *"unitree"* ]]; then + echo "[entrypoint] Starting Unitree WebRTC control (IP: ${UNITREE_IP:-192.168.12.1}, Method: ${UNITREE_CONN:-LocalAP})..." + ros2 launch unitree_webrtc_ros unitree_control.launch.py \ + robot_ip:="${UNITREE_IP:-192.168.12.1}" \ + connection_method:="${UNITREE_CONN:-LocalAP}" & + fi +elif [ "$MODE" = "bagfile" ]; then + if [ "$USE_ROUTE_PLANNER" = "true" ]; then + LAUNCH_FILE="system_bagfile_with_route_planner.launch.py" + else + LAUNCH_FILE="system_bagfile.launch.py" + fi + if [ ! 
-e "$BAGFILE_PATH" ]; then + echo "[entrypoint] ERROR: BAGFILE_PATH set but not found: $BAGFILE_PATH" + exit 1 + fi + echo "[entrypoint] Playing bag: ros2 bag play --clock $BAGFILE_PATH" + ros2 bag play "$BAGFILE_PATH" --clock & + start_ros_nav_stack +else + echo "MODE must be one of 'simulation', 'external_sim', 'hardware', 'bagfile' but got '$MODE'" + exit 19 +fi + + +# +# +# optional services +# +# +if [ "$USE_RVIZ" = "true" ]; then + if [ "$USE_ROUTE_PLANNER" = "true" ]; then + RVIZ_CFG="/ros2_ws/src/ros-navigation-autonomy-stack/src/route_planner/far_planner/rviz/default.rviz" + else + RVIZ_CFG="/ros2_ws/src/ros-navigation-autonomy-stack/src/base_autonomy/vehicle_simulator/rviz/vehicle_simulator.rviz" + fi + # check if file exists + if ! [ -f "$RVIZ_CFG" ]; then + echo "RVIZ_CFG was set to '$RVIZ_CFG' but that file doesn't exist" + exit 19 + fi + ros2 run rviz2 rviz2 -d "$RVIZ_CFG" & +elif ! [ "$USE_RVIZ" = "false" ]; then + echo "USE_RVIZ must be true or false but got: $USE_RVIZ" + exit 20 +fi + + +# Convert /foxglove_teleop Twist → /cmd_vel TwistStamped, goal relay, and Foxglove Bridge +if [ "$ENABLE_FOXGLOVE" = "true" ]; then + if [ -f "/usr/local/bin/twist_relay.py" ]; then + python3 /usr/local/bin/twist_relay.py & + else + echo "unable to start foxglove relay!" + exit 21 + fi + if [ -f "/usr/local/bin/goal_autonomy_relay.py" ]; then + python3 /usr/local/bin/goal_autonomy_relay.py & + fi + ros2 launch foxglove_bridge foxglove_bridge_launch.xml port:="${FOXGLOVE_PORT}" & +elif ! [ "$ENABLE_FOXGLOVE" = "false" ]; then + echo "ENABLE_FOXGLOVE must be true or false but got: $ENABLE_FOXGLOVE" + exit 22 +fi + +# start module (when being run from ) +if [ "$#" -gt 0 ]; then + + # make sure pip install went well + if ! 
wait "$PIP_INSTALL_PID"; then + cat "$pip_install_log_path" + echo "[entrypoint] WARNING: pip install -e failed; see $pip_install_log_path" + exit 29 + fi + + exec python -m dimos.core.docker_module run "$@" +fi + +# Otherwise keep container alive with the nav stack process. +wait "$ROS_NAV_PID" diff --git a/dimos/navigation/rosnav/fastdds.xml b/dimos/navigation/rosnav/fastdds.xml new file mode 100644 index 0000000000..ee054ed72b --- /dev/null +++ b/dimos/navigation/rosnav/fastdds.xml @@ -0,0 +1,43 @@ + + + + + + ros2_navigation_participant + + + SIMPLE + + 10 + 0 + + + 3 + 0 + + + + + 10485760 + 10485760 + true + + + + + + + udp_transport + UDPv4 + 10485760 + 10485760 + 65500 + + + + shm_transport + SHM + 10485760 + + + diff --git a/dimos/navigation/rosnav/fixtures/test_agentic_sim_navigate.json b/dimos/navigation/rosnav/fixtures/test_agentic_sim_navigate.json new file mode 100644 index 0000000000..cacb8c2afe --- /dev/null +++ b/dimos/navigation/rosnav/fixtures/test_agentic_sim_navigate.json @@ -0,0 +1,19 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "begin_exploration", + "args": {}, + "id": "call_explore_001", + "type": "tool_call" + } + ] + }, + { + "content": "I've started autonomous exploration. 
The robot is now moving around to map the environment.", + "tool_calls": [] + } + ] +} diff --git a/dimos/navigation/rosnav/fixtures/test_agentic_sim_stop.json b/dimos/navigation/rosnav/fixtures/test_agentic_sim_stop.json new file mode 100644 index 0000000000..dddc5d9f79 --- /dev/null +++ b/dimos/navigation/rosnav/fixtures/test_agentic_sim_stop.json @@ -0,0 +1,19 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "stop_navigation", + "args": {}, + "id": "call_stop_001", + "type": "tool_call" + } + ] + }, + { + "content": "I've stopped the robot.", + "tool_calls": [] + } + ] +} diff --git a/dimos/navigation/rosnav/fixtures/test_rosnav_agentic_goto.json b/dimos/navigation/rosnav/fixtures/test_rosnav_agentic_goto.json new file mode 100644 index 0000000000..0c87badc8c --- /dev/null +++ b/dimos/navigation/rosnav/fixtures/test_rosnav_agentic_goto.json @@ -0,0 +1,22 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "goto_global", + "args": { + "x": 2.0, + "y": 0.0 + }, + "id": "call_nav_001", + "type": "tool_call" + } + ] + }, + { + "content": "I've navigated the robot to map coordinates (2.0, 0.0). The robot has arrived at the target location.", + "tool_calls": [] + } + ] +} diff --git a/dimos/navigation/rosnav/rosnav_module.py b/dimos/navigation/rosnav/rosnav_module.py new file mode 100644 index 0000000000..b5f3d9af6c --- /dev/null +++ b/dimos/navigation/rosnav/rosnav_module.py @@ -0,0 +1,1062 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +NavBot class for navigation-related functionality. +Encapsulates ROS transport and topic remapping for Unitree robots. +""" + +from dataclasses import field +import logging +from pathlib import Path +import platform +import threading +import time +from typing import Any + +import cv2 +import numpy as np + +# ROS message imports: available inside the ROS2 container, but may be missing on the host +try: # pragma: no cover - import-time environment dependent + from geometry_msgs.msg import ( # type: ignore[attr-defined] + PointStamped as ROSPointStamped, + PoseStamped as ROSPoseStamped, + TwistStamped as ROSTwistStamped, + ) + from nav_msgs.msg import Odometry as ROSOdometry, Path as ROSPath # type: ignore[attr-defined] + from sensor_msgs.msg import ( # type: ignore[attr-defined] + CompressedImage as ROSCompressedImage, + Joy as ROSJoy, + PointCloud2 as ROSPointCloud2, + ) + from std_msgs.msg import ( # type: ignore[attr-defined] + Bool as ROSBool, + Int8 as ROSInt8, + ) + from tf2_msgs.msg import TFMessage as ROSTFMessage # type: ignore[attr-defined] +except ModuleNotFoundError: + # Running outside a ROS2 environment (e.g. host CLI without ROS Python packages). + # Define minimal placeholder types so blueprints can import without failing. 
+ class _Stub: # pragma: no cover - host-only stub + def __init__(self, *args: Any, **kwargs: Any) -> None: + pass + + ROSPointStamped = _Stub # type: ignore[assignment] + ROSPoseStamped = _Stub # type: ignore[assignment] + ROSTwistStamped = _Stub # type: ignore[assignment] + ROSOdometry = _Stub # type: ignore[assignment] + ROSPath = _Stub # type: ignore[assignment] + ROSCompressedImage = _Stub # type: ignore[assignment] + ROSJoy = _Stub # type: ignore[assignment] + ROSPointCloud2 = _Stub # type: ignore[assignment] + ROSBool = _Stub # type: ignore[assignment] + ROSInt8 = _Stub # type: ignore[assignment] + ROSTFMessage = _Stub # type: ignore[assignment] + +from dimos_lcm.std_msgs import Bool + +from dimos.agents.annotation import skill +from dimos.core.core import rpc +from dimos.core.docker_module import DockerModuleConfig +from dimos.core.module import Module +from dimos.core.stream import In, Out +from dimos.msgs.geometry_msgs.PointStamped import PointStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.nav_msgs.Path import Path as NavPath +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.tf2_msgs.TFMessage import TFMessage +from dimos.navigation.base import NavigationInterface, NavigationState +from dimos.utils.data import get_data +from dimos.utils.generic import is_jetson +from dimos.utils.logging_config import setup_logger +from dimos.utils.transform_utils import euler_to_quaternion + +logger = setup_logger(level=logging.INFO) + + +class ROSNavConfig(DockerModuleConfig): + # Module settings + local_pointcloud_freq: float = 2.0 + global_map_freq: float = 1.0 + 
sensor_to_base_link_transform: Transform = field( + default_factory=lambda: Transform(frame_id="sensor", child_frame_id="base_link") + ) + + # Docker settings + docker_restart_policy: str = "no" # Don't auto-restart; host process manages lifecycle + docker_startup_timeout: float = 180 + docker_image: str = "dimos_rosnav:humble" + docker_shm_size: str = "8g" + docker_entrypoint: str = "/usr/local/bin/entrypoint.sh" + docker_file: Path = Path(__file__).parent / "Dockerfile" + docker_build_context: Path = Path(__file__).parent.parent.parent.parent + docker_build_extra_args: list[str] = field(default_factory=lambda: ["--network", "host"]) + docker_build_args: dict[str, str] = field( + default_factory=lambda: { + "TARGETARCH": "arm64" if platform.machine() == "aarch64" else "amd64" + } + ) + docker_gpus: str | None = None if is_jetson() else "all" + docker_extra_args: list[str] = field( + default_factory=lambda: [ + "--cap-add=NET_ADMIN", + *(["--runtime=nvidia"] if is_jetson() else []), + ] + ) + docker_env: dict[str, str] = field( + default_factory=lambda: { + "ROS_DISTRO": "humble", + "ROS_DOMAIN_ID": "42", + "RMW_IMPLEMENTATION": "rmw_fastrtps_cpp", + "FASTRTPS_DEFAULT_PROFILES_FILE": "/ros2_ws/config/fastdds.xml", + "QT_X11_NO_MITSHM": "1", + "NVIDIA_VISIBLE_DEVICES": "all", + "NVIDIA_DRIVER_CAPABILITIES": "all", + # Give DDS topic discovery enough time after Unity registers publishers. + # Default in the entrypoint is 25s which is too short on some machines. + "UNITY_BRIDGE_CONNECT_TIMEOUT_SEC": "60", + } + ) + docker_volumes: list[tuple[str, str, str]] = field(default_factory=lambda: []) + docker_devices: list[str] = field( + default_factory=lambda: [ + "/dev/input:/dev/input", + *(["/dev/dri:/dev/dri"] if Path("/dev/dri").exists() else []), + ] + ) + + # Vehicle geometry + # Height of the robot's base_link above the ground plane (metres). 
+ # The CMU nav stack uses this to position the simulated sensor origin; + # it is forwarded to the ROS launch as the ``vehicleHeight`` parameter. + vehicle_height: float = 0.75 + + # Teleop override + # Seconds of silence after the last teleop cmd_vel before switching back + # to the ROS nav stack. At the end of the cooldown the module publishes + # a goal at the robot's current position so the nav stack re-engages at + # standstill instead of resuming the old goal. + teleop_cooldown_sec: float = 1.0 + + # Runtime mode settings + # mode controls which ROS launch file the entrypoint selects: + # "simulation" — system_simulation[_with_route_planner].launch.py + Unity if present + # "unity_sim" — same as simulation but hard-exits if Unity binary is missing + # "external_sim" — same launch as simulation, but no internal Unity (sensor data from LCM) + # "hardware" — system_real_robot[_with_route_planner].launch.py + # "bagfile" — system_bagfile[_with_route_planner].launch.py + use_sim_time + # Setting bagfile_path automatically forces mode to "bagfile". + mode: str = "hardware" + use_route_planner: bool = False + localization_method: str = "arise_slam" + robot_config_path: str = "unitree/unitree_g1" + robot_ip: str = "" + bagfile_path: str | Path = "" # host-side path to bag; remapped into container at runtime + + # use_rviz: whether to launch RViz2 inside the container. + # False (default) → RViz2 is not launched in any mode; set True to opt in. + # The value is forwarded unchanged to the entrypoint as USE_RVIZ ("true"/"false"). + use_rviz: bool = False + foxglove_port: int = 8765 + + # Hardware sensor / network settings (used when mode="hardware") + # lidar_interface: host ethernet interface connected to Mid-360 lidar (e.g. 
"eth0") + # lidar_computer_ip: IP to assign/use on that interface for lidar communication + # lidar_gateway: gateway IP for the lidar subnet + # lidar_ip: IP address of the Mid-360 lidar device itself + # unitree_ip: Unitree robot IP for WebRTC connection + # unitree_conn: WebRTC connection method — "LocalAP", "LocalSTA", or "Remote" + lidar_interface: str = "" + lidar_computer_ip: str = "" + lidar_gateway: str = "" + lidar_ip: str = "" + unitree_ip: str = "192.168.12.1" + unitree_conn: str = "LocalAP" + + # When True, download and mount sim assets (map.ply, traversable_area.ply) even + # in hardware mode. Used when running hardware-mode nav stack with an external + # simulator that still needs the pre-built map data. + # TODO: remove once the nav stack can build maps purely from incoming lidar. + mount_sim_assets: bool = False + + def model_post_init(self, __context: object) -> None: + import os + + effective_mode = "bagfile" if self.bagfile_path else self.mode + self.docker_env["MODE"] = effective_mode + + # Hardware sensor env vars — read by entrypoint.sh when MODE=hardware. + is_hardware = effective_mode == "hardware" + if is_hardware: + # Privileged mode is required for ip link/ip addr and sysctl inside the container. 
+ self.docker_privileged = True + self.docker_env["LIDAR_INTERFACE"] = self.lidar_interface + self.docker_env["LIDAR_COMPUTER_IP"] = self.lidar_computer_ip + self.docker_env["LIDAR_GATEWAY"] = self.lidar_gateway + self.docker_env["LIDAR_IP"] = self.lidar_ip + self.docker_env["UNITREE_IP"] = self.unitree_ip + self.docker_env["UNITREE_CONN"] = self.unitree_conn + + if self.bagfile_path: + bag_path = Path(self.bagfile_path).expanduser() + if bag_path.exists(): + bag_path = bag_path.resolve() + bag_dir = bag_path.parent + bag_name = bag_path.name + container_bag_dir = "/ros2_ws/bagfiles" + + self.docker_volumes.append((str(bag_dir), container_bag_dir, "rw")) + self.docker_env["BAGFILE_PATH"] = f"{container_bag_dir}/{bag_name}" + else: + self.docker_env["BAGFILE_PATH"] = str(self.bagfile_path) + + self.docker_env["USE_RVIZ"] = "true" if self.use_rviz else "false" + self.docker_env["FOXGLOVE_PORT"] = str(self.foxglove_port) + self.docker_env["USE_ROUTE_PLANNER"] = "true" if self.use_route_planner else "false" + self.docker_env["LOCALIZATION_METHOD"] = self.localization_method + self.docker_env["ROBOT_CONFIG_PATH"] = self.robot_config_path + self.docker_env["ROBOT_IP"] = self.robot_ip + self.docker_env["VEHICLE_HEIGHT"] = str(self.vehicle_height) + + # Pass host DISPLAY through for X11 forwarding (RViz, Unity) + if display := os.environ.get("DISPLAY", ":0"): + self.docker_env["DISPLAY"] = display + + self.docker_env["QT_X11_NO_MITSHM"] = "1" + + repo_root = Path(__file__).parent.parent.parent.parent + self.docker_volumes += [ + # X11 socket for display forwarding (RViz, Unity) + ("/tmp/.X11-unix", "/tmp/.X11-unix", "rw"), + # Mount live dimos source so the module is always up-to-date + (str(repo_root), "/workspace/dimos", "rw"), + # Mount DDS config (fastdds.xml) from host — single file mount + # avoids shadowing the entire /ros2_ws/config directory + (str(Path(__file__).parent / "fastdds.xml"), "/ros2_ws/config/fastdds.xml", "ro"), + # Note: most of the mounts below are 
only needed for development + # Mount entrypoint script so changes don't require a rebuild + ( + str(Path(__file__).parent / "entrypoint.sh"), + "/usr/local/bin/entrypoint.sh", + "ro", + ), + ] + + # Only download and mount sim assets for simulation modes (avoids slow LFS pull in hardware mode). + # mount_sim_assets overrides this for hardware mode with external sim. + # TODO: remove mount_sim_assets once nav stack can build maps from lidar alone. + if effective_mode in ("simulation", "unity_sim", "external_sim") or self.mount_sim_assets: + sim_data_dir = str(get_data("office_building_1")) + self.docker_volumes += [ + # Mount Unity sim (office_building_1) — downloaded via get_data / LFS + # Provides map.ply, traversable_area.ply and environment/Model.x86_64 + ( + sim_data_dir, + "/ros2_ws/src/ros-navigation-autonomy-stack/src/base_autonomy/vehicle_simulator/mesh/unity/", + "rw", + ), + # real_world uses the same sim data + ( + sim_data_dir, + "/ros2_ws/src/ros-navigation-autonomy-stack/src/base_autonomy/vehicle_simulator/mesh/real_world/", + "rw", + ), + # Some CMU stack nodes (e.g., visualizationTools.cpp) rewrite install paths + # to /ros2_ws/src/base_autonomy/... directly. Mirror the same sim asset + # directory at that legacy path to avoid "map.ply not found" errors. + ( + sim_data_dir, + "/ros2_ws/src/base_autonomy/vehicle_simulator/mesh/unity/", + "rw", + ), + ( + sim_data_dir, + "/ros2_ws/src/base_autonomy/vehicle_simulator/mesh/real_world/", + "rw", + ), + ] + + # Mount Xauthority cookie for X11 forwarding. + # Honour $XAUTHORITY on the host (falls back to ~/.Xauthority) and + # place it at /tmp/.Xauthority inside the container so it is + # accessible regardless of which user the container runs as. 
+ xauth_host = Path(os.environ.get("XAUTHORITY", str(Path.home() / ".Xauthority"))) + if xauth_host.exists(): + self.docker_volumes.append((str(xauth_host), "/tmp/.Xauthority", "ro")) + self.docker_env["XAUTHORITY"] = "/tmp/.Xauthority" + + +class ROSNav(Module, NavigationInterface): + config: ROSNavConfig + default_config = ROSNavConfig + + goal_request: In[PoseStamped] + clicked_point: In[PointStamped] + stop_explore_cmd: In[Bool] + tele_cmd_vel: In[Twist] + + # External sensor inputs — when connected, data is republished to ROS2 + # topics inside the Docker container so the nav stack can consume them + # (e.g. from an external simulator). + ext_registered_scan: In[PointCloud2] + ext_odometry: In[Odometry] + + lidar: Out[PointCloud2] + terrain_map: Out[PointCloud2] + global_pointcloud: Out[PointCloud2] + rosnav_overall_map: Out[PointCloud2] + odom: Out[PoseStamped] + goal_active: Out[PoseStamped] + goal_reached: Out[Bool] + path: Out[NavPath] + cmd_vel: Out[Twist] + + _current_position_running: bool = False + _spin_thread: threading.Thread | None = None + _goal_reach: bool | None = None + + # Navigation state tracking for NavigationInterface + _navigation_state: NavigationState = NavigationState.IDLE + _state_lock: threading.Lock + _navigation_thread: threading.Thread | None = None + _current_goal: PoseStamped | None = None + _goal_reached: bool = False + + # Teleop override state + _teleop_active: bool = False + _teleop_lock: threading.Lock + _teleop_timer: threading.Timer | None = None + _last_odom: PoseStamped | None = None + + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + import rclpy + from rclpy.node import Node + + # Initialize state tracking + self._state_lock = threading.Lock() + self._teleop_lock = threading.Lock() + self._navigation_state = NavigationState.IDLE + self._goal_reached = False + + if not rclpy.ok(): # type: ignore[attr-defined] + rclpy.init() + + self._node = 
Node("navigation_module") + + # ROS2 Publishers + self.goal_pose_pub = self._node.create_publisher(ROSPoseStamped, "/goal_pose", 10) + self.cancel_goal_pub = self._node.create_publisher(ROSBool, "/cancel_goal", 10) + self.soft_stop_pub = self._node.create_publisher(ROSInt8, "/stop", 10) + self.joy_pub = self._node.create_publisher(ROSJoy, "/joy", 10) + + # ROS2 Subscribers + self.goal_reached_sub = self._node.create_subscription( + ROSBool, "/goal_reached", self._on_ros_goal_reached, 10 + ) + from rclpy.qos import QoSProfile, ReliabilityPolicy # type: ignore[attr-defined] + + self.cmd_vel_sub = self._node.create_subscription( + ROSTwistStamped, + "/cmd_vel", + self._on_ros_cmd_vel, + QoSProfile(depth=10, reliability=ReliabilityPolicy.BEST_EFFORT), + ) + self.goal_waypoint_sub = self._node.create_subscription( + ROSPointStamped, "/way_point", self._on_ros_goal_waypoint, 10 + ) + self.registered_scan_sub = self._node.create_subscription( + ROSPointCloud2, "/registered_scan", self._on_ros_registered_scan, 10 + ) + + self.terrain_map_sub = self._node.create_subscription( + ROSPointCloud2, "/terrain_map", self._on_ros_terrain_map, 10 + ) + self.global_pointcloud_sub = self._node.create_subscription( + ROSPointCloud2, "/terrain_map_ext", self._on_ros_global_map, 10 + ) + + self.rosnav_overall_map_sub = self._node.create_subscription( + ROSPointCloud2, "/overall_map", self._on_ros_rosnav_overall_map, 10 + ) + + self.path_sub = self._node.create_subscription(ROSPath, "/path", self._on_ros_path, 10) + self.tf_sub = self._node.create_subscription(ROSTFMessage, "/tf", self._on_ros_tf, 10) + self.odom_sub = self._node.create_subscription( + ROSOdometry, "/state_estimation", self._on_ros_odom, 10 + ) + + # ROS2 publisher for external sensor data. + # When ext_registered_scan input is connected, incoming DimOS PointCloud2 + # messages are converted and republished on this ROS2 topic so the nav + # stack inside the container can consume them. 
+ self._ext_scan_pub = self._node.create_publisher(ROSPointCloud2, "/registered_scan", 10) + self._ext_odom_pub = self._node.create_publisher(ROSOdometry, "/state_estimation", 10) + + logger.info("NavigationModule initialized with ROS2 node") + + @rpc + def start(self) -> None: + try: + self._running = True + + # Create and start the spin thread for ROS2 node spinning + self._spin_thread = threading.Thread( + target=self._spin_node, daemon=True, name="ROS2SpinThread" + ) + self._spin_thread.start() + + self.goal_request.subscribe(self._on_goal_pose) + self.clicked_point.subscribe(lambda pt: self._on_goal_pose(pt.to_pose_stamped())) + self.stop_explore_cmd.subscribe(self._on_stop_cmd) + self.tele_cmd_vel.subscribe(self._on_tele_cmd_vel) + + # External sensor inputs — republish to ROS2 topics for the nav stack + self.ext_registered_scan.subscribe(self._on_ext_scan) + self.ext_odometry.subscribe(self._on_ext_odom) + + logger.info("NavigationModule started with ROS2 spinning") + except Exception as e: + logger.error(f"ROSNav start() failed: {e}", exc_info=True) + + def _spin_node(self) -> None: + import rclpy + + while self._running and rclpy.ok(): # type: ignore[attr-defined] + try: + rclpy.spin_once(self._node, timeout_sec=0.1) + except Exception as e: + if self._running: + logger.error(f"ROS2 spin error: {e}") + + def _on_ros_goal_reached(self, msg: ROSBool) -> None: + self._goal_reach = msg.data + self.goal_reached.publish(Bool(data=msg.data)) + if msg.data: + with self._state_lock: + self._goal_reached = True + self._navigation_state = NavigationState.IDLE + + def _on_ros_goal_waypoint(self, msg: ROSPointStamped) -> None: + dimos_pose = PoseStamped( + ts=time.time(), + frame_id=msg.header.frame_id, + position=Vector3(msg.point.x, msg.point.y, msg.point.z), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + ) + self.goal_active.publish(dimos_pose) + + def _on_ros_cmd_vel(self, msg: ROSTwistStamped) -> None: + if self._teleop_active: + return # Suppress nav stack 
cmd_vel during teleop override + self.cmd_vel.publish(_twist_from_ros(msg.twist)) + + def _on_ros_registered_scan(self, msg: ROSPointCloud2) -> None: + self.lidar.publish(_pc2_from_ros(msg)) + + def _on_ros_terrain_map(self, msg: "ROSPointCloud2") -> None: + self.terrain_map.publish(_pc2_from_ros(msg)) + + def _on_ros_global_map(self, msg: ROSPointCloud2) -> None: + self.global_pointcloud.publish(_pc2_from_ros(msg)) + + def _on_ros_rosnav_overall_map(self, msg: ROSPointCloud2) -> None: + # FIXME: disabling for now for perf onboard G1 (and cause we don't have an overall map rn) + # self.rosnav_overall_map.publish(_pc2_from_ros(msg)) + pass + + def _on_ros_path(self, msg: ROSPath) -> None: + dimos_path = _path_from_ros(msg) + # The CMU nav stack publishes the path in the "vehicle" frame which + # corresponds to "sensor" in the DimOS TF tree (map → sensor). + dimos_path.frame_id = "sensor" + self.path.publish(dimos_path) + + def _on_ros_odom(self, msg: "ROSOdometry") -> None: + ts = msg.header.stamp.sec + msg.header.stamp.nanosec / 1e9 + p = msg.pose.pose.position + o = msg.pose.pose.orientation + pose = PoseStamped( + ts=ts, + frame_id=msg.header.frame_id, + position=Vector3(p.x, p.y, p.z), + orientation=Quaternion(o.x, o.y, o.z, o.w), + ) + self._last_odom = pose + self.odom.publish(pose) + + def _on_ros_tf(self, msg: ROSTFMessage) -> None: + # In external_sim mode, the UnityBridgeModule owns the ground-truth + # transforms (map→sensor, map→world). Don't republish SLAM TF here + # or the two sources will fight and cause jitter. + if self.config.mode == "external_sim": + return + + ros_tf = _tfmessage_from_ros(msg) + + # In hardware/bagfile mode the SLAM initialises the sensor at the + # map-frame origin, placing the ground plane at z = −vehicleHeight. + # Shift the world frame down so that ground aligns with z = 0 in + # Rerun. In simulation the map frame already has ground at z = 0. 
+ is_sim = self.config.mode in ("simulation", "unity_sim") + z_offset = 0.0 if is_sim else -self.config.vehicle_height + + map_to_world_tf = Transform( + translation=Vector3(0.0, 0.0, z_offset), + rotation=euler_to_quaternion(Vector3(0.0, 0.0, 0.0)), + frame_id="map", + child_frame_id="world", + ts=time.time(), + ) + + self.tf.publish( + self.config.sensor_to_base_link_transform.now(), + map_to_world_tf, + *ros_tf.transforms, + ) + + def _on_goal_pose(self, msg: PoseStamped) -> None: + self.set_goal(msg) + + def _on_cancel_goal(self, msg: Bool) -> None: + if msg.data: + self.stop() + + def _on_stop_cmd(self, msg: Bool) -> None: + if not msg.data: + return + logger.info("Stop command received, cancelling navigation") + self.stop_navigation() + # Set goal to current position so the nav stack re-engages at standstill + if self._last_odom is not None: + self._set_autonomy_mode() + ros_pose = _pose_stamped_to_ros(self._last_odom) + self.goal_pose_pub.publish(ros_pose) + + def _on_tele_cmd_vel(self, msg: Twist) -> None: + with self._teleop_lock: + if not self._teleop_active: + self._teleop_active = True + self.stop_navigation() + logger.info("Teleop override: keyboard control active") + + # Cancel existing cooldown timer and start a new one + if self._teleop_timer is not None: + self._teleop_timer.cancel() + self._teleop_timer = threading.Timer( + self.config.teleop_cooldown_sec, + self._end_teleop_override, + ) + self._teleop_timer.daemon = True + self._teleop_timer.start() + + # Forward teleop command to output + self.cmd_vel.publish(msg) + + def _end_teleop_override(self) -> None: + with self._teleop_lock: + self._teleop_active = False + self._teleop_timer = None + + # Set goal to current position so the nav stack resumes at standstill + if self._last_odom is not None: + logger.info("Teleop cooldown expired, setting goal to current position") + self._set_autonomy_mode() + ros_pose = _pose_stamped_to_ros(self._last_odom) + self.goal_pose_pub.publish(ros_pose) + else: + 
logger.warning("Teleop cooldown expired but no odom available") + + # -- External sensor input callbacks -- + # Convert DimOS messages to ROS2 and republish on ROS2 topics. + + def _on_ext_scan(self, pc2: PointCloud2) -> None: + self._ext_scan_pub.publish(_pc2_to_ros(pc2)) + + def _on_ext_odom(self, odom: Odometry) -> None: + self._ext_odom_pub.publish(_odometry_to_ros(odom)) + + def _set_autonomy_mode(self) -> None: + joy_msg = ROSJoy() # type: ignore[no-untyped-call] + joy_msg.axes = [ + 0.0, # axis 0 + 0.0, # axis 1 + -1.0, # axis 2 + 0.0, # axis 3 + 1.0, # axis 4 + 1.0, # axis 5 + 0.0, # axis 6 + 0.0, # axis 7 + ] + joy_msg.buttons = [ + 0, # button 0 + 0, # button 1 + 0, # button 2 + 0, # button 3 + 0, # button 4 + 0, # button 5 + 0, # button 6 + 1, # button 7 - controls autonomy mode + 0, # button 8 + 0, # button 9 + 0, # button 10 + ] + self.joy_pub.publish(joy_msg) + logger.info("Setting autonomy mode via Joy message") + + @skill + def goto(self, x: float, y: float) -> str: + """ + move the robot in relative coordinates + x is forward, y is left + + goto(1, 0) will move the robot forward by 1 meter + """ + pose_to = PoseStamped( + position=Vector3(x, y, 0), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="base_link", + ts=time.time(), + ) + + self.navigate_to(pose_to) + return "arrived" + + @skill + def goto_global(self, x: float, y: float) -> str: + """ + go to coordinates x,y in the map frame + 0,0 is your starting position + """ + target = PoseStamped( + ts=time.time(), + frame_id="map", + position=Vector3(x, y, 0.0), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + ) + + self.navigate_to(target) + + return f"arrived to {x:.2f}, {y:.2f}" + + @rpc + def navigate_to(self, pose: PoseStamped, timeout: float = 60.0) -> bool: + """ + Navigate to a target pose by publishing to ROS topics. 
+ + Args: + pose: Target pose to navigate to + timeout: Maximum time to wait for goal (seconds) + + Returns: + True if navigation was successful + """ + logger.info( + f"Navigating to goal: ({pose.position.x:.2f}, {pose.position.y:.2f}, {pose.position.z:.2f} @ {pose.frame_id})" + ) + + self._goal_reach = None + self._set_autonomy_mode() + + # Enable soft stop (0 = enable) + soft_stop_msg = ROSInt8() # type: ignore[no-untyped-call] + soft_stop_msg.data = 0 + self.soft_stop_pub.publish(soft_stop_msg) + + ros_pose = _pose_stamped_to_ros(pose) + self.goal_pose_pub.publish(ros_pose) + + # Wait for goal to be reached + start_time = time.time() + while time.time() - start_time < timeout: + if self._goal_reach is not None: + soft_stop_msg.data = 2 + self.soft_stop_pub.publish(soft_stop_msg) + return self._goal_reach + time.sleep(0.1) + + self.stop_navigation() + logger.warning(f"Navigation timed out after {timeout} seconds") + return False + + @rpc + def stop_navigation(self) -> bool: + """ + Stop current navigation by publishing to ROS topics. + + Returns: + True if stop command was sent successfully + """ + logger.info("Stopping navigation") + + cancel_msg = ROSBool() # type: ignore[no-untyped-call] + cancel_msg.data = True + self.cancel_goal_pub.publish(cancel_msg) + + soft_stop_msg = ROSInt8() # type: ignore[no-untyped-call] + soft_stop_msg.data = 2 + self.soft_stop_pub.publish(soft_stop_msg) + + # Unblock any waiting navigate_to() call + self._goal_reach = False + + with self._state_lock: + self._navigation_state = NavigationState.IDLE + self._current_goal = None + self._goal_reached = False + + return True + + @rpc + def set_goal(self, goal: PoseStamped, timeout: float = 60.0) -> bool: + """Set a new navigation goal (non-blocking).""" + with self._state_lock: + self._current_goal = goal + self._goal_reached = False + self._navigation_state = NavigationState.FOLLOWING_PATH + + # Cancel previous navigation and wait for thread to exit. 
+ # stop_navigation() sets _goal_reach = False which unblocks navigate_to(). + if self._navigation_thread and self._navigation_thread.is_alive(): + logger.warning("Previous navigation still running, cancelling") + self.stop_navigation() + self._navigation_thread.join(timeout=2.0) + if self._navigation_thread.is_alive(): + logger.warning("Previous navigation thread did not exit in time, proceeding anyway") + + self._navigation_thread = threading.Thread( + target=self._navigate_to_goal_async, + args=(goal, timeout), + daemon=True, + name="ROSNavNavigationThread", + ) + self._navigation_thread.start() + + return True + + def _navigate_to_goal_async(self, goal: PoseStamped, timeout: float = 60.0) -> None: + """Internal method to handle navigation in a separate thread.""" + try: + result = self.navigate_to(goal, timeout=timeout) + with self._state_lock: + self._goal_reached = result + self._navigation_state = NavigationState.IDLE + except Exception as e: + logger.error(f"Navigation failed: {e}") + with self._state_lock: + self._goal_reached = False + self._navigation_state = NavigationState.IDLE + + @rpc + def get_state(self) -> NavigationState: + """Get the current state of the navigator.""" + with self._state_lock: + return self._navigation_state + + @rpc + def is_goal_reached(self) -> bool: + """Check if the current goal has been reached.""" + with self._state_lock: + return self._goal_reached + + @rpc + def cancel_goal(self) -> bool: + """Cancel the current navigation goal.""" + + with self._state_lock: + had_goal = self._current_goal is not None + + if had_goal: + self.stop_navigation() + + return had_goal + + @rpc + def stop(self) -> None: + """Stop the navigation module and clean up resources.""" + self.stop_navigation() + try: + self._running = False + + with self._teleop_lock: + if self._teleop_timer is not None: + self._teleop_timer.cancel() + self._teleop_timer = None + self._teleop_active = False + + if self._spin_thread and self._spin_thread.is_alive(): + 
self._spin_thread.join(timeout=1.0) + + if hasattr(self, "_node") and self._node: + self._node.destroy_node() # type: ignore[no-untyped-call] + + except Exception as e: + logger.error(f"Error during shutdown: {e}") + finally: + super().stop() + + +def _pose_stamped_to_ros(pose: PoseStamped) -> "ROSPoseStamped": + """Convert a DimOS PoseStamped to a ROS2 geometry_msgs/PoseStamped.""" + msg = ROSPoseStamped() + msg.header.frame_id = pose.frame_id + ts_sec = int(pose.ts) + msg.header.stamp.sec = ts_sec + msg.header.stamp.nanosec = int((pose.ts - ts_sec) * 1_000_000_000) + msg.pose.position.x = float(pose.position.x) + msg.pose.position.y = float(pose.position.y) + msg.pose.position.z = float(pose.position.z) + msg.pose.orientation.x = float(pose.orientation.x) + msg.pose.orientation.y = float(pose.orientation.y) + msg.pose.orientation.z = float(pose.orientation.z) + msg.pose.orientation.w = float(pose.orientation.w) + return msg + + +def _image_from_ros_compressed(msg: "ROSCompressedImage") -> Image: + """Convert a ROS2 sensor_msgs/CompressedImage to a DimOS Image.""" + ts = msg.header.stamp.sec + msg.header.stamp.nanosec / 1e9 + frame_id = msg.header.frame_id + arr = np.frombuffer(bytes(msg.data), dtype=np.uint8) + bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR) + if bgr is None: + return Image(frame_id=frame_id, ts=ts) + return Image(data=bgr, format=ImageFormat.BGR, frame_id=frame_id, ts=ts) + + +def _pc2_from_ros(msg: "ROSPointCloud2") -> PointCloud2: + """Convert a ROS2 sensor_msgs/PointCloud2 to a DimOS PointCloud2.""" + ts = msg.header.stamp.sec + msg.header.stamp.nanosec / 1e9 + frame_id = msg.header.frame_id + + if msg.width == 0 or msg.height == 0: + return PointCloud2(frame_id=frame_id, ts=ts) + + # ROS PointField datatype → (numpy dtype suffix, byte size) + _DTYPE_MAP = {1: "i1", 2: "u1", 3: "i2", 4: "u2", 5: "i4", 6: "u4", 7: "f4", 8: "f8"} + _SIZE_MAP = {1: 1, 2: 1, 3: 2, 4: 2, 5: 4, 6: 4, 7: 4, 8: 8} + + x_off = y_off = z_off = None + x_dt = y_dt = z_dt = 7 # 
default: FLOAT32 + for f in msg.fields: + if f.name == "x": + x_off, x_dt = f.offset, f.datatype + elif f.name == "y": + y_off, y_dt = f.offset, f.datatype + elif f.name == "z": + z_off, z_dt = f.offset, f.datatype + + if any(o is None for o in [x_off, y_off, z_off]): + raise ValueError("ROS PointCloud2 missing x/y/z fields") + + num_points = msg.width * msg.height + raw = bytes(msg.data) + step = msg.point_step + end = ">" if msg.is_bigendian else "<" + + # Fast path: float32 x/y/z at offsets 0/4/8 (little-endian) + if ( + x_off == 0 + and y_off == 4 + and z_off == 8 + and step >= 12 + and x_dt == 7 + and y_dt == 7 + and z_dt == 7 + and not msg.is_bigendian + ): + if step == 12: + points = np.frombuffer(raw, dtype=np.float32).reshape(-1, 3) + else: + dt = np.dtype([("x", "= 24 + and x_dt == 8 + and y_dt == 8 + and z_dt == 8 + and not msg.is_bigendian + ): + if step == 24: + points = np.frombuffer(raw, dtype=np.float64).reshape(-1, 3).astype(np.float32) + else: + dt = np.dtype([("x", " Twist: + """Convert a ROS2 geometry_msgs/Twist (the .twist field of TwistStamped) to DimOS Twist.""" + return Twist( + linear=Vector3(msg.linear.x, msg.linear.y, msg.linear.z), + angular=Vector3(msg.angular.x, msg.angular.y, msg.angular.z), + ) + + +def _path_from_ros(msg: "ROSPath") -> NavPath: + """Convert a ROS2 nav_msgs/Path to a DimOS Path.""" + ts = msg.header.stamp.sec + msg.header.stamp.nanosec / 1e9 + frame_id = msg.header.frame_id + poses = [] + for ps in msg.poses: + pose_ts = ps.header.stamp.sec + ps.header.stamp.nanosec / 1e9 + p = ps.pose.position + o = ps.pose.orientation + poses.append( + PoseStamped( + ts=pose_ts, + frame_id=ps.header.frame_id or frame_id, + position=Vector3(p.x, p.y, p.z), + orientation=Quaternion(o.x, o.y, o.z, o.w), + ) + ) + return NavPath(ts=ts, frame_id=frame_id, poses=poses) + + +def _tfmessage_from_ros(msg: "ROSTFMessage") -> TFMessage: + """Convert a ROS2 tf2_msgs/TFMessage to a DimOS TFMessage.""" + transforms = [] + for ts_msg in 
msg.transforms: + ts = ts_msg.header.stamp.sec + ts_msg.header.stamp.nanosec / 1e9 + t = ts_msg.transform.translation + r = ts_msg.transform.rotation + transforms.append( + Transform( + translation=Vector3(t.x, t.y, t.z), + rotation=Quaternion(r.x, r.y, r.z, r.w), + frame_id=ts_msg.header.frame_id, + child_frame_id=ts_msg.child_frame_id, + ts=ts, + ) + ) + return TFMessage(*transforms) + + +# -- DimOS → ROS2 conversion helpers (inverse of the from_ros functions above) -- + + +def _pc2_to_ros(pc2: PointCloud2) -> "ROSPointCloud2": + """Convert a DimOS PointCloud2 to a ROS2 sensor_msgs/PointCloud2. + + Includes a zero-filled ``intensity`` field because the CMU nav stack's + terrain analysis nodes require it (they filter on ``intensity``). + """ + from builtin_interfaces.msg import ( + Time as ROSTime, # type: ignore[attr-defined,import-not-found] + ) + from sensor_msgs.msg import PointField # type: ignore[attr-defined] + + points, _ = pc2.as_numpy() # (N, 3) float32 + n = points.shape[0] + # XYZI layout: 4 floats per point (intensity = 0) + xyzi = np.zeros((n, 4), dtype=np.float32) + xyzi[:, :3] = points.astype(np.float32) + + ros_msg = ROSPointCloud2() # type: ignore[no-untyped-call] + ros_msg.header.stamp = ROSTime(sec=int(pc2.ts), nanosec=int((pc2.ts % 1) * 1e9)) # type: ignore[no-untyped-call] + ros_msg.header.frame_id = pc2.frame_id or "sensor" + ros_msg.height = 1 + ros_msg.width = n + ros_msg.fields = [ + PointField(name="x", offset=0, datatype=7, count=1), # type: ignore[no-untyped-call] + PointField(name="y", offset=4, datatype=7, count=1), # type: ignore[no-untyped-call] + PointField(name="z", offset=8, datatype=7, count=1), # type: ignore[no-untyped-call] + PointField(name="intensity", offset=12, datatype=7, count=1), # type: ignore[no-untyped-call] + ] + ros_msg.is_bigendian = False + ros_msg.point_step = 16 + ros_msg.row_step = 16 * n + ros_msg.data = xyzi.tobytes() + ros_msg.is_dense = True + return ros_msg + + +def _odometry_to_ros(odom: Odometry) -> 
"ROSOdometry": + """Convert a DimOS Odometry to a ROS2 nav_msgs/Odometry.""" + from builtin_interfaces.msg import ( + Time as ROSTime, # type: ignore[attr-defined,import-not-found] + ) + from geometry_msgs.msg import ( # type: ignore[attr-defined] + Point as ROSPoint, + Pose as ROSPose, + Quaternion as ROSQuat, + Twist as ROSTwist, + Vector3 as ROSVector3, + ) + + ros_msg = ROSOdometry() # type: ignore[no-untyped-call] + ros_msg.header.stamp = ROSTime(sec=int(odom.ts), nanosec=int((odom.ts % 1) * 1e9)) # type: ignore[no-untyped-call] + ros_msg.header.frame_id = odom.frame_id or "map" + ros_msg.child_frame_id = odom.child_frame_id or "sensor" + ros_msg.pose.pose = ROSPose( # type: ignore[no-untyped-call] + position=ROSPoint( # type: ignore[no-untyped-call] + x=odom.pose.position.x, y=odom.pose.position.y, z=odom.pose.position.z + ), + orientation=ROSQuat( # type: ignore[no-untyped-call] + x=odom.pose.orientation.x, + y=odom.pose.orientation.y, + z=odom.pose.orientation.z, + w=odom.pose.orientation.w, + ), + ) + ros_msg.twist.twist = ROSTwist( # type: ignore[no-untyped-call] + linear=ROSVector3( # type: ignore[no-untyped-call] + x=odom.twist.linear.x, y=odom.twist.linear.y, z=odom.twist.linear.z + ), + angular=ROSVector3( # type: ignore[no-untyped-call] + x=odom.twist.angular.x, y=odom.twist.angular.y, z=odom.twist.angular.z + ), + ) + return ros_msg + + +__all__ = ["ROSNav"] + +if __name__ == "__main__": + ROSNav.blueprint().build() diff --git a/dimos/navigation/rosnav/test_rosnav_goal_navigation.py b/dimos/navigation/rosnav/test_rosnav_goal_navigation.py new file mode 100644 index 0000000000..1b3aab0e7b --- /dev/null +++ b/dimos/navigation/rosnav/test_rosnav_goal_navigation.py @@ -0,0 +1,264 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Integration test: send a goal point to ROSNav and verify the robot reaches it. + +Starts the navigation stack in simulation mode with Unity, waits for odom to +stabilise (robot has spawned), sends a global goal via the ``set_goal`` RPC, +and asserts that the robot's final position moves toward the target. + +Requires: + - Docker with BuildKit + - NVIDIA GPU with drivers + - X11 display (real or virtual) + +Run: + pytest dimos/navigation/rosnav/test_rosnav_goal_navigation.py -m slow -s +""" + +import math +import threading +import time +from typing import Any + +from dimos_lcm.std_msgs import Bool +import pytest + +from dimos.core.blueprints import autoconnect +from dimos.core.core import rpc +from dimos.core.module import Module +from dimos.core.stream import In +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.Path import Path as NavPath +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.navigation.rosnav.rosnav_module import ROSNav + +# How long to wait for the robot to move toward the goal (seconds). +GOAL_TIMEOUT_SEC = 120 + +# How long to wait for initial odom before sending a goal. +ODOM_WAIT_SEC = 30 + +# Seconds to wait after receiving first odom before sending the goal, +# letting the nav stack initialise fully. 
+WARMUP_SEC = 10 + +# Minimum displacement (metres) to consider the robot "moved". +MIN_DISPLACEMENT_M = 0.5 + +# Goal in the map frame — 3 metres forward from origin. +GOAL_X = 3.0 +GOAL_Y = 0.0 + + +class GoalTracker(Module): + """Subscribes to odom and goal_reached, records positions for assertions.""" + + color_image: In[Image] + lidar: In[PointCloud2] + global_pointcloud: In[PointCloud2] + odom: In[PoseStamped] + goal_active: In[PoseStamped] + goal_reached: In[Bool] + path: In[NavPath] + cmd_vel: In[Twist] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._lock = threading.Lock() + self._odom_history: list[PoseStamped] = [] + self._cmd_vel_count: int = 0 + self._nonzero_cmd_vel_count: int = 0 + self._goal_reached_flag = False + self._first_odom_event = threading.Event() + self._goal_reached_event = threading.Event() + self._moved_event = threading.Event() + self._start_pose: PoseStamped | None = None + self._unsub_fns: list[Any] = [] + + @rpc + def start(self) -> None: + self._unsub_fns.append(self.odom.subscribe(self._on_odom)) + self._unsub_fns.append(self.goal_reached.subscribe(self._on_goal_reached)) + self._unsub_fns.append(self.cmd_vel.subscribe(self._on_cmd_vel)) + + def _on_odom(self, msg: PoseStamped) -> None: + with self._lock: + self._odom_history.append(msg) + if len(self._odom_history) == 1: + self._first_odom_event.set() + # Check if the robot has moved significantly from its start pose + if self._start_pose is not None and not self._moved_event.is_set(): + dx = msg.position.x - self._start_pose.position.x + dy = msg.position.y - self._start_pose.position.y + if math.sqrt(dx * dx + dy * dy) > MIN_DISPLACEMENT_M: + self._moved_event.set() + + def _on_cmd_vel(self, msg: Twist) -> None: + with self._lock: + self._cmd_vel_count += 1 + if abs(msg.linear.x) > 0.01 or abs(msg.linear.y) > 0.01 or abs(msg.angular.z) > 0.01: + self._nonzero_cmd_vel_count += 1 + + def _on_goal_reached(self, msg: Bool) -> None: + if msg.data: 
+ with self._lock: + self._goal_reached_flag = True + self._goal_reached_event.set() + + @rpc + def wait_for_first_odom(self, timeout: float = ODOM_WAIT_SEC) -> bool: + return self._first_odom_event.wait(timeout) + + @rpc + def wait_for_movement(self, timeout: float = GOAL_TIMEOUT_SEC) -> bool: + return self._moved_event.wait(timeout) + + @rpc + def wait_for_goal_reached(self, timeout: float = GOAL_TIMEOUT_SEC) -> bool: + return self._goal_reached_event.wait(timeout) + + @rpc + def mark_start(self) -> None: + """Snapshot the current odom as the 'start' for displacement measurement.""" + with self._lock: + if self._odom_history: + self._start_pose = self._odom_history[-1] + + @rpc + def get_start_pose(self) -> PoseStamped | None: + with self._lock: + return self._start_pose + + @rpc + def get_latest_odom(self) -> PoseStamped | None: + with self._lock: + return self._odom_history[-1] if self._odom_history else None + + @rpc + def is_goal_reached(self) -> bool: + with self._lock: + return self._goal_reached_flag + + @rpc + def get_odom_count(self) -> int: + with self._lock: + return len(self._odom_history) + + @rpc + def get_cmd_vel_stats(self) -> tuple[int, int]: + with self._lock: + return self._cmd_vel_count, self._nonzero_cmd_vel_count + + @rpc + def stop(self) -> None: + for unsub in self._unsub_fns: + unsub() + self._unsub_fns.clear() + + +def _distance_2d(a: PoseStamped, b: PoseStamped) -> float: + """Euclidean distance in the XY plane.""" + return math.sqrt((a.position.x - b.position.x) ** 2 + (a.position.y - b.position.y) ** 2) + + +@pytest.mark.slow +def test_rosnav_goal_reached(): + """Send a navigation goal and verify the robot reaches it.""" + + coordinator = ( + autoconnect( + ROSNav.blueprint(mode="simulation"), + GoalTracker.blueprint(), + ) + .global_config(viewer="none") + .build() + ) + + try: + tracker = coordinator.get_instance(GoalTracker) + rosnav = coordinator.get_instance(ROSNav) + + # 1. 
Wait for odom — proves the sim is running and the robot has spawned. + assert tracker.wait_for_first_odom(ODOM_WAIT_SEC), ( + f"No odom received within {ODOM_WAIT_SEC}s — Unity sim may not be running." + ) + + # Let the nav stack fully initialise before sending a goal. + print(f" Odom received. Waiting {WARMUP_SEC}s for nav stack warmup...") + time.sleep(WARMUP_SEC) + + # Snapshot the current position as "start". + tracker.mark_start() + start_pose = tracker.get_start_pose() + assert start_pose is not None + print( + f" Robot start: ({start_pose.position.x:.2f}, " + f"{start_pose.position.y:.2f}, {start_pose.position.z:.2f})" + ) + + # 2. Send a goal in the map frame via set_goal (non-blocking). + goal = PoseStamped( + position=Vector3(GOAL_X, GOAL_Y, 0.0), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="map", + ts=time.time(), + ) + print(f" Sending set_goal({GOAL_X}, {GOAL_Y}) in map frame...") + rosnav.set_goal(goal) + + # 3. Wait for either goal_reached or significant movement. + moved = tracker.wait_for_movement(GOAL_TIMEOUT_SEC) + reached = tracker.is_goal_reached() + + end_pose = tracker.get_latest_odom() + assert end_pose is not None + + displacement = _distance_2d(start_pose, end_pose) + total_cmd, nonzero_cmd = tracker.get_cmd_vel_stats() + print( + f" Robot end: ({end_pose.position.x:.2f}, " + f"{end_pose.position.y:.2f}, {end_pose.position.z:.2f})" + ) + print(f" Displacement: {displacement:.2f}m (goal was {GOAL_X}m)") + print(f" Odom messages: {tracker.get_odom_count()}") + print(f" cmd_vel messages: {total_cmd} total, {nonzero_cmd} non-zero") + print(f" goal_reached: {reached}") + + # 4. Assert the robot moved. + assert moved or reached, ( + f"Robot did not move within {GOAL_TIMEOUT_SEC}s. " + f"Displacement: {displacement:.2f}m, cmd_vel: {total_cmd} total / {nonzero_cmd} non-zero. " + f"The nav stack may not be generating velocity commands." 
+ ) + + assert displacement > MIN_DISPLACEMENT_M, ( + f"Robot only moved {displacement:.2f}m toward goal at ({GOAL_X}, {GOAL_Y}). " + f"Expected at least {MIN_DISPLACEMENT_M}m." + ) + + if reached: + print(" ✅ goal_reached signal received") + else: + print( + f" ✅ Robot moved {displacement:.2f}m toward goal (goal_reached not yet received)" + ) + + finally: + coordinator.stop() diff --git a/dimos/navigation/rosnav/test_rosnav_simulation.py b/dimos/navigation/rosnav/test_rosnav_simulation.py new file mode 100644 index 0000000000..23688ed358 --- /dev/null +++ b/dimos/navigation/rosnav/test_rosnav_simulation.py @@ -0,0 +1,169 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Integration test for the ROSNav Docker module. + +Starts the navigation stack in simulation mode with Unity and verifies that +the ROS→DimOS bridge produces data on expected streams. Requires an X11 +display (real or virtual) for Unity to render. 
+ +Requires: + - Docker with BuildKit + - SSH key in agent for private repo clone (first build only) + - ~17 GB disk for the Docker image + +Run: + pytest dimos/navigation/rosnav/test_rosnav_simulation.py -m slow -s +""" + +import threading +import time + +from dimos_lcm.std_msgs import Bool +import pytest + +from dimos.core.blueprints import autoconnect +from dimos.core.core import rpc +from dimos.core.module import Module +from dimos.core.stream import In +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.nav_msgs.Path import Path as NavPath +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.navigation.rosnav.rosnav_module import ROSNav + +# Streams that should produce data in simulation mode without sending a goal. +# The nav stack publishes these as soon as the Unity sim is running. +EXPECTED_STREAMS = { + "odom", + "lidar", + "color_image", + "cmd_vel", + "path", +} + +# Streams that only produce data after a navigation goal is sent, +# or take a long time to appear. We report but don't assert. +OPTIONAL_STREAMS = { + "global_pointcloud", + "goal_active", + "goal_reached", +} + +# Total timeout for waiting for expected streams. 
+STREAM_TIMEOUT_SEC = 360 # 6 minutes + + +class StreamCollector(Module): + """Test module that subscribes to all ROSNav output streams and records arrivals.""" + + color_image: In[Image] + lidar: In[PointCloud2] + global_pointcloud: In[PointCloud2] + odom: In[PoseStamped] + goal_active: In[PoseStamped] + goal_reached: In[Bool] + path: In[NavPath] + cmd_vel: In[Twist] + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._received: dict[str, float] = {} + self._lock = threading.Lock() + self._unsub_fns: list = [] + + @rpc + def start(self) -> None: + for stream_name in ( + "color_image", + "lidar", + "global_pointcloud", + "odom", + "goal_active", + "goal_reached", + "path", + "cmd_vel", + ): + stream = getattr(self, stream_name) + unsub = stream.subscribe(self._make_callback(stream_name)) + if unsub is not None: + self._unsub_fns.append(unsub) + + def _make_callback(self, name: str): + def _cb(_msg): + with self._lock: + if name not in self._received: + self._received[name] = time.time() + + return _cb + + @rpc + def get_received(self) -> dict[str, float]: + with self._lock: + return dict(self._received) + + @rpc + def stop(self) -> None: + for unsub in self._unsub_fns: + unsub() + self._unsub_fns.clear() + + +@pytest.mark.slow +def test_rosnav_simulation_streams(): + """Start ROSNav in simulation mode and verify expected streams produce data.""" + + coordinator = ( + autoconnect( + ROSNav.blueprint(mode="simulation"), + StreamCollector.blueprint(), + ) + .global_config(viewer="none") + .build() + ) + + try: + collector = coordinator.get_instance(StreamCollector) + start = time.time() + missing = set(EXPECTED_STREAMS) + + while missing and (time.time() - start) < STREAM_TIMEOUT_SEC: + received = collector.get_received() + missing = EXPECTED_STREAMS - received.keys() + if missing: + time.sleep(2) + + received = collector.get_received() + arrived = set(received.keys()) + + for name in sorted(arrived): + elapsed = received[name] - start + print(f" 
stream '{name}' first message after {elapsed:.1f}s") + + missing_expected = EXPECTED_STREAMS - arrived + assert not missing_expected, ( + f"Timed out after {STREAM_TIMEOUT_SEC}s waiting for streams: {missing_expected}. " + f"Received: {sorted(arrived)}" + ) + + for name in sorted(OPTIONAL_STREAMS & arrived): + elapsed = received[name] - start + print(f" optional stream '{name}' arrived after {elapsed:.1f}s") + for name in sorted(OPTIONAL_STREAMS - arrived): + print(f" optional stream '{name}' did not produce data (expected without hardware)") + + finally: + coordinator.stop() diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav_legacy.py similarity index 98% rename from dimos/navigation/rosnav.py rename to dimos/navigation/rosnav_legacy.py index ef76539d5f..e8299de7a7 100644 --- a/dimos/navigation/rosnav.py +++ b/dimos/navigation/rosnav_legacy.py @@ -224,7 +224,7 @@ def goto(self, x: float, y: float) -> str: """ pose_to = PoseStamped( position=Vector3(x, y, 0), - orientation=Quaternion(0.0, 0.0, 0.0, 0.0), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), frame_id="base_link", ts=time.time(), ) @@ -242,7 +242,7 @@ def goto_global(self, x: float, y: float) -> str: ts=time.time(), frame_id="map", position=Vector3(x, y, 0.0), - orientation=Quaternion(0.0, 0.0, 0.0, 0.0), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), ) self.navigate_to(target) @@ -297,6 +297,9 @@ def stop_navigation(self) -> bool: self.ros_cancel_goal.publish(Bool(data=True)) self.ros_soft_stop.publish(Int8(data=2)) + # Unblock any waiting navigate_to() call + self._goal_reach = False + with self._state_lock: self._navigation_state = NavigationState.IDLE self._current_goal = None diff --git a/dimos/navigation/smartnav/.gitignore b/dimos/navigation/smartnav/.gitignore new file mode 100644 index 0000000000..3ef7d2e8a0 --- /dev/null +++ b/dimos/navigation/smartnav/.gitignore @@ -0,0 +1,2 @@ +# Nix build outputs (symlinks to /nix/store) +results/ diff --git 
a/dimos/navigation/smartnav/CMakeLists.txt b/dimos/navigation/smartnav/CMakeLists.txt new file mode 100644 index 0000000000..9380bf7e0f --- /dev/null +++ b/dimos/navigation/smartnav/CMakeLists.txt @@ -0,0 +1,147 @@ +cmake_minimum_required(VERSION 3.14) +project(smartnav_native CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") + +if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) + set(CMAKE_INSTALL_PREFIX "${CMAKE_SOURCE_DIR}/result" CACHE PATH "" FORCE) +endif() + +# Option: USE_PCL (default ON). When OFF, uses lightweight implementations. +option(USE_PCL "Use PCL for point cloud operations" ON) + +# Fetch dependencies +include(FetchContent) + +# dimos-lcm C++ message headers +FetchContent_Declare(dimos_lcm + GIT_REPOSITORY https://github.com/dimensionalOS/dimos-lcm.git + GIT_TAG main + GIT_SHALLOW TRUE +) +FetchContent_MakeAvailable(dimos_lcm) + +# LCM +find_package(PkgConfig REQUIRED) +pkg_check_modules(LCM REQUIRED lcm) + +# Eigen3 +find_package(Eigen3 REQUIRED) + +# PCL (optional) +if(USE_PCL) + find_package(PCL 1.8 REQUIRED COMPONENTS common filters kdtree) + add_definitions(-DUSE_PCL) +endif() + +# Common include directories +set(SMARTNAV_COMMON_INCLUDES + ${CMAKE_CURRENT_SOURCE_DIR}/common + ${dimos_lcm_SOURCE_DIR}/generated/cpp_lcm_msgs + ${LCM_INCLUDE_DIRS} + ${EIGEN3_INCLUDE_DIR} +) + +set(SMARTNAV_COMMON_LIBS + ${LCM_LIBRARIES} +) + +set(SMARTNAV_COMMON_LIB_DIRS + ${LCM_LIBRARY_DIRS} +) + +if(USE_PCL) + list(APPEND SMARTNAV_COMMON_INCLUDES ${PCL_INCLUDE_DIRS}) + list(APPEND SMARTNAV_COMMON_LIBS ${PCL_LIBRARIES}) +endif() + +# --- Terrain Analysis --- +add_executable(terrain_analysis + modules/terrain_analysis/main.cpp +) +target_include_directories(terrain_analysis PRIVATE ${SMARTNAV_COMMON_INCLUDES}) +target_link_libraries(terrain_analysis PRIVATE ${SMARTNAV_COMMON_LIBS}) +target_link_directories(terrain_analysis PRIVATE ${SMARTNAV_COMMON_LIB_DIRS}) + +# --- Local Planner --- 
+add_executable(local_planner + modules/local_planner/main.cpp +) +target_include_directories(local_planner PRIVATE ${SMARTNAV_COMMON_INCLUDES}) +target_link_libraries(local_planner PRIVATE ${SMARTNAV_COMMON_LIBS}) +target_link_directories(local_planner PRIVATE ${SMARTNAV_COMMON_LIB_DIRS}) + +# --- Path Follower --- +add_executable(path_follower + modules/path_follower/main.cpp +) +target_include_directories(path_follower PRIVATE ${SMARTNAV_COMMON_INCLUDES}) +target_link_libraries(path_follower PRIVATE ${SMARTNAV_COMMON_LIBS}) +target_link_directories(path_follower PRIVATE ${SMARTNAV_COMMON_LIB_DIRS}) + +# --- FAR Planner --- +add_executable(far_planner + modules/far_planner/main.cpp +) +target_include_directories(far_planner PRIVATE ${SMARTNAV_COMMON_INCLUDES}) +target_link_libraries(far_planner PRIVATE ${SMARTNAV_COMMON_LIBS}) +target_link_directories(far_planner PRIVATE ${SMARTNAV_COMMON_LIB_DIRS}) +# FAR planner uses OpenCV for contour detection +find_package(OpenCV QUIET COMPONENTS core imgproc) +if(OpenCV_FOUND) + target_include_directories(far_planner PRIVATE ${OpenCV_INCLUDE_DIRS}) + target_link_libraries(far_planner PRIVATE ${OpenCV_LIBS}) + target_compile_definitions(far_planner PRIVATE HAS_OPENCV) +endif() + +# --- PGO (Pose Graph Optimization) --- +find_package(GTSAM QUIET) +if(USE_PCL AND GTSAM_FOUND) + add_executable(pgo + modules/pgo/main.cpp + ) + target_include_directories(pgo PRIVATE ${SMARTNAV_COMMON_INCLUDES}) + target_link_libraries(pgo PRIVATE ${SMARTNAV_COMMON_LIBS} gtsam) + target_link_directories(pgo PRIVATE ${SMARTNAV_COMMON_LIB_DIRS}) + # PCL registration component needed for ICP + find_package(PCL 1.8 REQUIRED COMPONENTS registration) + target_include_directories(pgo PRIVATE ${PCL_INCLUDE_DIRS}) + target_link_libraries(pgo PRIVATE ${PCL_LIBRARIES}) +endif() + +# --- AriseSLAM --- +find_package(Ceres QUIET) +if(USE_PCL AND Ceres_FOUND) + add_executable(arise_slam + modules/arise_slam/main.cpp + ) + target_include_directories(arise_slam 
PRIVATE ${SMARTNAV_COMMON_INCLUDES}) + target_link_libraries(arise_slam PRIVATE ${SMARTNAV_COMMON_LIBS} Ceres::ceres) + target_link_directories(arise_slam PRIVATE ${SMARTNAV_COMMON_LIB_DIRS}) +endif() + +# --- TARE Planner --- +add_executable(tare_planner + modules/tare_planner/main.cpp +) +target_include_directories(tare_planner PRIVATE ${SMARTNAV_COMMON_INCLUDES}) +target_link_libraries(tare_planner PRIVATE ${SMARTNAV_COMMON_LIBS}) +target_link_directories(tare_planner PRIVATE ${SMARTNAV_COMMON_LIB_DIRS}) + +# Install all targets +set(SMARTNAV_INSTALL_TARGETS + terrain_analysis + local_planner + path_follower + far_planner + tare_planner +) +if(USE_PCL AND GTSAM_FOUND) + list(APPEND SMARTNAV_INSTALL_TARGETS pgo) +endif() +if(USE_PCL AND Ceres_FOUND) + list(APPEND SMARTNAV_INSTALL_TARGETS arise_slam) +endif() +install(TARGETS ${SMARTNAV_INSTALL_TARGETS} DESTINATION bin) diff --git a/dimos/navigation/smartnav/blueprints/_rerun_helpers.py b/dimos/navigation/smartnav/blueprints/_rerun_helpers.py new file mode 100644 index 0000000000..dec850ee9a --- /dev/null +++ b/dimos/navigation/smartnav/blueprints/_rerun_helpers.py @@ -0,0 +1,138 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Shared Rerun visual overrides for SmartNav blueprints.""" + +from __future__ import annotations + +from typing import Any + + +def sensor_scan_override(cloud: Any) -> Any: + """Render sensor_scan attached to the sensor TF frame so it moves with the robot.""" + import rerun as rr + + arch = cloud.to_rerun(colormap="turbo", size=0.02) + return [ + ("world/sensor_scan", rr.Transform3D(parent_frame="tf#/sensor")), + ("world/sensor_scan", arch), + ] + + +def global_map_override(cloud: Any) -> Any: + """Render accumulated global map — small grey/blue points for map context.""" + return cloud.to_rerun(colormap="cool", size=0.03) + + +def terrain_map_override(cloud: Any) -> Any: + """Render terrain_map: big green dots = traversable, red = obstacle. + + The terrain_analysis C++ module sets point intensity to the height + difference above the planar voxel ground. Low intensity → ground, + high intensity → obstacle. + """ + import numpy as np + import rerun as rr + + points, _ = cloud.as_numpy() + if len(points) == 0: + return None + + # Color by z-height: low = green (ground), high = red (obstacle) + z = points[:, 2] + z_min, z_max = z.min(), z.max() + z_norm = (z - z_min) / (z_max - z_min + 1e-8) + + colors = np.zeros((len(points), 3), dtype=np.uint8) + colors[:, 0] = (z_norm * 255).astype(np.uint8) # R + colors[:, 1] = ((1 - z_norm) * 200 + 55).astype(np.uint8) # G + colors[:, 2] = 30 + + return rr.Points3D(positions=points[:, :3], colors=colors, radii=0.08) + + +def terrain_map_ext_override(cloud: Any) -> Any: + """Render extended terrain map — persistent accumulated cloud.""" + return cloud.to_rerun(colormap="viridis", size=0.06) + + +def path_override(path_msg: Any) -> Any: + """Render path in vehicle frame by attaching to the sensor TF.""" + import rerun as rr + + if not path_msg.poses: + return None + + points = [[p.x, p.y, p.z + 0.3] for p in path_msg.poses] + return [ + ("world/nav_path", rr.Transform3D(parent_frame="tf#/sensor")), + ("world/nav_path", 
rr.LineStrips3D([points], colors=[(0, 255, 128)], radii=0.05)), + ] + + +def goal_path_override(path_msg: Any) -> Any: + """Render the goal line (robot→goal) as a bright dashed line in world frame.""" + import rerun as rr + + if not path_msg.poses or len(path_msg.poses) < 2: + return None + + points = [[p.x, p.y, p.z] for p in path_msg.poses] + return rr.LineStrips3D([points], colors=[(255, 100, 50)], radii=0.03) + + +def waypoint_override(msg: Any) -> Any: + """Render the current waypoint goal as a visible marker.""" + import math + + import rerun as rr + + if not all(math.isfinite(v) for v in (msg.x, msg.y, msg.z)): + return None + + return rr.Points3D( + positions=[[msg.x, msg.y, msg.z + 0.5]], + colors=[(255, 50, 50)], + radii=0.3, + ) + + +def static_robot(rr: Any) -> list[Any]: + """Static robot rectangle attached to the sensor TF frame. + + Renders a wireframe box roughly the size of the mecanum-wheel platform, + so you can see the robot's position and heading in the 3D view. + """ + return [ + rr.Boxes3D( + half_sizes=[0.25, 0.20, 0.15], # ~50x40x30 cm box (mecanum platform) + centers=[[0, 0, 0]], + colors=[(0, 255, 127)], + fill_mode="MajorWireframe", + ), + rr.Transform3D(parent_frame="tf#/sensor"), + ] + + +def static_floor(rr: Any) -> list[Any]: + """Static ground plane at z=0 as a solid textured quad.""" + + s = 50.0 # half-size + return [ + rr.Mesh3D( + vertex_positions=[[-s, -s, 0], [s, -s, 0], [s, s, 0], [-s, s, 0]], + triangle_indices=[[0, 1, 2], [0, 2, 3]], + vertex_colors=[[40, 40, 40, 120]] * 4, + ) + ] diff --git a/dimos/navigation/smartnav/blueprints/real_robot.py b/dimos/navigation/smartnav/blueprints/real_robot.py new file mode 100644 index 0000000000..a3a649d01a --- /dev/null +++ b/dimos/navigation/smartnav/blueprints/real_robot.py @@ -0,0 +1,94 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Real robot blueprint: runs on physical hardware with FastLio2 SLAM. + +Uses the existing dimos FastLio2 NativeModule for SLAM with a Livox Mid-360 +lidar, replacing the Unity simulator with real sensor data. + +FastLio2 outputs ``lidar`` (not ``registered_scan``), so we remap it. +No camera ports — lidar-only setup. +""" + +from __future__ import annotations + +from typing import Any + +from dimos.core.blueprints import autoconnect +from dimos.core.global_config import global_config +from dimos.hardware.sensors.lidar.fastlio2.module import FastLio2 +from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner +from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower +from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import ( + SensorScanGeneration, +) +from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis +from dimos.navigation.smartnav.modules.tui_control.tui_control import TUIControlModule +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.visualization.vis_module import vis_module + + +def _rerun_blueprint() -> Any: + """Rerun layout for lidar-only (no camera panel).""" + import rerun.blueprint as rrb + + return rrb.Blueprint( + rrb.Spatial3DView(origin="world", name="3D"), + ) + + +def _terrain_map_override(cloud: Any) -> Any: + """Render terrain_map colored by obstacle cost (intensity field).""" + return cloud.to_rerun(colormap="turbo", size=0.04) + + +rerun_config = { + "blueprint": _rerun_blueprint, 
+ "pubsubs": [LCM()], + "visual_override": { + "world/terrain_map": _terrain_map_override, + }, +} + + +def make_real_robot_blueprint( + host_ip: str = "192.168.1.5", + lidar_ip: str = "192.168.1.155", +): + """Create a real robot blueprint with configurable network settings.""" + return autoconnect( + FastLio2.blueprint(host_ip=host_ip, lidar_ip=lidar_ip), + SensorScanGeneration.blueprint(), + TerrainAnalysis.blueprint(), + LocalPlanner.blueprint(), + PathFollower.blueprint(), + TUIControlModule.blueprint(), + vis_module(viewer_backend=global_config.viewer, rerun_config=rerun_config), + ).remappings( + [ + # FastLio2 outputs "lidar"; SmartNav modules expect "registered_scan" + (FastLio2, "lidar", "registered_scan"), + ] + ) + + +real_robot_blueprint = make_real_robot_blueprint() + + +def main() -> None: + real_robot_blueprint.build().loop() + + +if __name__ == "__main__": + main() diff --git a/dimos/navigation/smartnav/blueprints/simulation.py b/dimos/navigation/smartnav/blueprints/simulation.py new file mode 100644 index 0000000000..2d0d72781b --- /dev/null +++ b/dimos/navigation/smartnav/blueprints/simulation.py @@ -0,0 +1,137 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Simulation blueprint: base autonomy with Unity vehicle simulator.""" + +from __future__ import annotations + +from typing import Any + +from dimos.core.blueprints import autoconnect +from dimos.core.global_config import global_config +from dimos.navigation.smartnav.blueprints._rerun_helpers import ( + global_map_override, + goal_path_override, + path_override, + sensor_scan_override, + static_floor, + static_robot, + terrain_map_ext_override, + terrain_map_override, + waypoint_override, +) +from dimos.navigation.smartnav.modules.click_to_goal.click_to_goal import ClickToGoal +from dimos.navigation.smartnav.modules.global_map.global_map import GlobalMap +from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner +from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower +from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import ( + SensorScanGeneration, +) +from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis +from dimos.navigation.smartnav.modules.terrain_map_ext.terrain_map_ext import TerrainMapExt +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.simulation.unity.module import UnityBridgeModule +from dimos.visualization.vis_module import vis_module + + +def _rerun_blueprint() -> Any: + import rerun.blueprint as rrb + + return rrb.Blueprint( + rrb.Vertical( + rrb.Spatial3DView(origin="world", name="3D"), + rrb.Spatial2DView(origin="world/color_image", name="Camera"), + row_shares=[2, 1], + ), + ) + + +rerun_config = { + "blueprint": _rerun_blueprint, + "pubsubs": [LCM()], + "min_interval_sec": 0.25, + "visual_override": { + "world/camera_info": UnityBridgeModule.rerun_suppress_camera_info, + "world/sensor_scan": sensor_scan_override, + "world/terrain_map": terrain_map_override, + "world/terrain_map_ext": terrain_map_ext_override, + "world/global_map": global_map_override, + "world/path": path_override, + 
"world/way_point": waypoint_override, + "world/goal_path": goal_path_override, + }, + "static": { + "world/color_image": UnityBridgeModule.rerun_static_pinhole, + "world/floor": static_floor, + "world/tf/robot": static_robot, + }, +} + +simulation_blueprint = autoconnect( + UnityBridgeModule.blueprint( + unity_binary="", + unity_scene="home_building_1", + ), + SensorScanGeneration.blueprint(), + TerrainAnalysis.blueprint( + extra_args=[ + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + ] + ), + TerrainMapExt.blueprint(), + LocalPlanner.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "2.0", + "--autonomySpeed", + "2.0", + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--minRelZ", + "-1.0", + ] + ), + PathFollower.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "2.0", + "--autonomySpeed", + "2.0", + "--maxAccel", + "4.0", + "--slowDwnDisThre", + "0.2", + ] + ), + ClickToGoal.blueprint(), + GlobalMap.blueprint(), + vis_module(viewer_backend=global_config.viewer, rerun_config=rerun_config), +) + + +def main() -> None: + simulation_blueprint.build({"n_workers": 8}).loop() + + +if __name__ == "__main__": + main() diff --git a/dimos/navigation/smartnav/blueprints/simulation_explore.py b/dimos/navigation/smartnav/blueprints/simulation_explore.py new file mode 100644 index 0000000000..db88328e22 --- /dev/null +++ b/dimos/navigation/smartnav/blueprints/simulation_explore.py @@ -0,0 +1,158 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simulation + TARE exploration planner blueprint. + +Usage: + python -m smartnav.blueprints.simulation_explore # default scene + python -m smartnav.blueprints.simulation_explore home_building_1 # specific scene +""" + +from __future__ import annotations + +import sys +from typing import Any + +from dimos.core.blueprints import autoconnect +from dimos.core.global_config import global_config +from dimos.navigation.smartnav.blueprints._rerun_helpers import ( + global_map_override, + goal_path_override, + path_override, + sensor_scan_override, + static_floor, + static_robot, + terrain_map_ext_override, + terrain_map_override, + waypoint_override, +) +from dimos.navigation.smartnav.modules.click_to_goal.click_to_goal import ClickToGoal +from dimos.navigation.smartnav.modules.global_map.global_map import GlobalMap +from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner +from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower +from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import ( + SensorScanGeneration, +) +from dimos.navigation.smartnav.modules.tare_planner.tare_planner import TarePlanner +from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis +from dimos.navigation.smartnav.modules.terrain_map_ext.terrain_map_ext import TerrainMapExt +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.simulation.unity.module import UnityBridgeModule +from dimos.visualization.vis_module import vis_module + + +def _rerun_blueprint() -> Any: + import rerun.blueprint as rrb + + return rrb.Blueprint( + rrb.Vertical( + rrb.Spatial3DView(origin="world", name="3D"), + rrb.Spatial2DView(origin="world/color_image", name="Camera"), + row_shares=[2, 1], + ), + ) + + +rerun_config = { + "blueprint": _rerun_blueprint, + "pubsubs": 
[LCM()], + "min_interval_sec": 0.25, + "visual_override": { + "world/camera_info": UnityBridgeModule.rerun_suppress_camera_info, + "world/terrain_map": terrain_map_override, + "world/sensor_scan": sensor_scan_override, + "world/terrain_map_ext": terrain_map_ext_override, + "world/global_map": global_map_override, + "world/path": path_override, + "world/way_point": waypoint_override, + "world/goal_path": goal_path_override, + }, + "static": { + "world/color_image": UnityBridgeModule.rerun_static_pinhole, + "world/floor": static_floor, + "world/tf/robot": static_robot, + }, +} + + +def make_explore_blueprint(scene: str = "home_building_1"): + """Create an exploration blueprint with the given Unity scene.""" + return autoconnect( + UnityBridgeModule.blueprint( + unity_binary="", + unity_scene=scene, + ), + SensorScanGeneration.blueprint(), + TerrainAnalysis.blueprint( + extra_args=[ + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + ] + ), + TerrainMapExt.blueprint(), + LocalPlanner.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "2.0", + "--autonomySpeed", + "2.0", + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--minRelZ", + "-1.0", + ] + ), + PathFollower.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "2.0", + "--autonomySpeed", + "2.0", + "--maxAccel", + "4.0", + "--slowDwnDisThre", + "0.2", + ] + ), + TarePlanner.blueprint(), + ClickToGoal.blueprint(), + GlobalMap.blueprint(), + vis_module(viewer_backend=global_config.viewer, rerun_config=rerun_config), + ).remappings( + [ + # In explore mode, only TarePlanner should drive way_point to LocalPlanner. + # Disconnect ClickToGoal's way_point so it doesn't conflict. 
+ (ClickToGoal, "way_point", "_click_way_point_unused"), + ] + ) + + +simulation_explore_blueprint = make_explore_blueprint() + + +def main() -> None: + scene = sys.argv[1] if len(sys.argv) > 1 else "home_building_1" + make_explore_blueprint(scene).build({"n_workers": 9}).loop() + + +if __name__ == "__main__": + main() diff --git a/dimos/navigation/smartnav/blueprints/simulation_pgo.py b/dimos/navigation/smartnav/blueprints/simulation_pgo.py new file mode 100644 index 0000000000..9c8fb6bdd7 --- /dev/null +++ b/dimos/navigation/smartnav/blueprints/simulation_pgo.py @@ -0,0 +1,146 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simulation + PGO blueprint: base autonomy with pose graph optimization. + +Replaces GlobalMap with PGO module. PGO provides loop-closure-corrected +odometry and accumulated global map from optimized keyframes. 
+ +Data flow: + UnityBridge → registered_scan + odometry + → PGO → corrected_odometry + global_map + → SensorScanGeneration → TerrainAnalysis → LocalPlanner → PathFollower +""" + +from __future__ import annotations + +from typing import Any + +from dimos.core.blueprints import autoconnect +from dimos.core.global_config import global_config +from dimos.navigation.smartnav.blueprints._rerun_helpers import ( + global_map_override, + goal_path_override, + path_override, + sensor_scan_override, + static_floor, + static_robot, + terrain_map_ext_override, + terrain_map_override, + waypoint_override, +) +from dimos.navigation.smartnav.modules.click_to_goal.click_to_goal import ClickToGoal +from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner +from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower +from dimos.navigation.smartnav.modules.pgo.pgo import PGO +from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import ( + SensorScanGeneration, +) +from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis +from dimos.navigation.smartnav.modules.terrain_map_ext.terrain_map_ext import TerrainMapExt +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.simulation.unity.module import UnityBridgeModule +from dimos.visualization.vis_module import vis_module + + +def _rerun_blueprint() -> Any: + import rerun.blueprint as rrb + + return rrb.Blueprint( + rrb.Vertical( + rrb.Spatial3DView(origin="world", name="3D"), + rrb.Spatial2DView(origin="world/color_image", name="Camera"), + row_shares=[2, 1], + ), + ) + + +rerun_config = { + "blueprint": _rerun_blueprint, + "pubsubs": [LCM()], + "min_interval_sec": 0.25, + "visual_override": { + "world/camera_info": UnityBridgeModule.rerun_suppress_camera_info, + "world/sensor_scan": sensor_scan_override, + "world/terrain_map": terrain_map_override, + "world/terrain_map_ext": terrain_map_ext_override, 
+ "world/global_map": global_map_override, + "world/path": path_override, + "world/way_point": waypoint_override, + "world/goal_path": goal_path_override, + }, + "static": { + "world/color_image": UnityBridgeModule.rerun_static_pinhole, + "world/floor": static_floor, + "world/tf/robot": static_robot, + }, +} + +simulation_pgo_blueprint = autoconnect( + UnityBridgeModule.blueprint( + unity_binary="", + unity_scene="home_building_1", + ), + SensorScanGeneration.blueprint(), + TerrainAnalysis.blueprint( + extra_args=[ + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + ] + ), + TerrainMapExt.blueprint(), + LocalPlanner.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "2.0", + "--autonomySpeed", + "2.0", + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--minRelZ", + "-1.0", + ] + ), + PathFollower.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "2.0", + "--autonomySpeed", + "2.0", + "--maxAccel", + "4.0", + "--slowDwnDisThre", + "0.2", + ] + ), + ClickToGoal.blueprint(), + PGO.blueprint(), + vis_module(viewer_backend=global_config.viewer, rerun_config=rerun_config), +) + + +def main() -> None: + simulation_pgo_blueprint.build({"n_workers": 8}).loop() + + +if __name__ == "__main__": + main() diff --git a/dimos/navigation/smartnav/blueprints/simulation_route.py b/dimos/navigation/smartnav/blueprints/simulation_route.py new file mode 100644 index 0000000000..ecd932a20a --- /dev/null +++ b/dimos/navigation/smartnav/blueprints/simulation_route.py @@ -0,0 +1,150 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simulation + FAR route planner blueprint. + +Data flow: + ClickToGoal.way_point → (remapped to "goal") → FarPlanner.goal + FarPlanner.way_point → LocalPlanner.way_point +""" + +from __future__ import annotations + +from typing import Any + +from dimos.core.blueprints import autoconnect +from dimos.core.global_config import global_config +from dimos.navigation.smartnav.blueprints._rerun_helpers import ( + global_map_override, + goal_path_override, + path_override, + sensor_scan_override, + static_floor, + static_robot, + terrain_map_ext_override, + terrain_map_override, + waypoint_override, +) +from dimos.navigation.smartnav.modules.click_to_goal.click_to_goal import ClickToGoal +from dimos.navigation.smartnav.modules.far_planner.far_planner import FarPlanner +from dimos.navigation.smartnav.modules.global_map.global_map import GlobalMap +from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner +from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower +from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import ( + SensorScanGeneration, +) +from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis +from dimos.navigation.smartnav.modules.terrain_map_ext.terrain_map_ext import TerrainMapExt +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.simulation.unity.module import UnityBridgeModule +from dimos.visualization.vis_module import vis_module + + +def _rerun_blueprint() -> Any: + import rerun.blueprint as rrb + + 
return rrb.Blueprint( + rrb.Vertical( + rrb.Spatial3DView(origin="world", name="3D"), + rrb.Spatial2DView(origin="world/color_image", name="Camera"), + row_shares=[2, 1], + ), + ) + + +rerun_config = { + "blueprint": _rerun_blueprint, + "pubsubs": [LCM()], + "min_interval_sec": 0.25, + "visual_override": { + "world/camera_info": UnityBridgeModule.rerun_suppress_camera_info, + "world/sensor_scan": sensor_scan_override, + "world/terrain_map": terrain_map_override, + "world/terrain_map_ext": terrain_map_ext_override, + "world/global_map": global_map_override, + "world/path": path_override, + "world/way_point": waypoint_override, + "world/goal_path": goal_path_override, + }, + "static": { + "world/color_image": UnityBridgeModule.rerun_static_pinhole, + "world/floor": static_floor, + "world/tf/robot": static_robot, + }, +} + +simulation_route_blueprint = autoconnect( + UnityBridgeModule.blueprint( + unity_binary="", + unity_scene="home_building_1", + ), + SensorScanGeneration.blueprint(), + TerrainAnalysis.blueprint( + extra_args=[ + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + ] + ), + TerrainMapExt.blueprint(), + LocalPlanner.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "2.0", + "--autonomySpeed", + "2.0", + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--minRelZ", + "-1.0", + ] + ), + PathFollower.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "2.0", + "--autonomySpeed", + "2.0", + "--maxAccel", + "4.0", + "--slowDwnDisThre", + "0.2", + ] + ), + FarPlanner.blueprint(), + ClickToGoal.blueprint(), + GlobalMap.blueprint(), + vis_module(viewer_backend=global_config.viewer, rerun_config=rerun_config), +).remappings( + [ + # In route mode, only FarPlanner should drive way_point to LocalPlanner. + # Disconnect ClickToGoal's way_point so it doesn't conflict/override. 
+ (ClickToGoal, "way_point", "_click_way_point_unused"), + ] +) + + +def main() -> None: + simulation_route_blueprint.build({"n_workers": 9}).loop() + + +if __name__ == "__main__": + main() diff --git a/dimos/navigation/smartnav/blueprints/simulation_slam.py b/dimos/navigation/smartnav/blueprints/simulation_slam.py new file mode 100644 index 0000000000..1082b853a1 --- /dev/null +++ b/dimos/navigation/smartnav/blueprints/simulation_slam.py @@ -0,0 +1,160 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simulation + AriseSLAM blueprint: base autonomy with LiDAR SLAM. + +Replaces UnityBridge's simulated odometry with actual SLAM-computed odometry. +AriseSLAM processes raw point clouds through feature extraction and +scan-to-map matching to produce world-frame registered scans and odometry. 
+ +Data flow: + UnityBridge → raw lidar cloud (body frame) + → AriseSLAM → registered_scan (world frame) + odometry + → SensorScanGeneration → TerrainAnalysis → LocalPlanner → PathFollower +""" + +from __future__ import annotations + +from typing import Any + +from dimos.core.blueprints import autoconnect +from dimos.core.global_config import global_config +from dimos.navigation.smartnav.blueprints._rerun_helpers import ( + global_map_override, + goal_path_override, + path_override, + sensor_scan_override, + static_floor, + static_robot, + terrain_map_ext_override, + terrain_map_override, + waypoint_override, +) +from dimos.navigation.smartnav.modules.arise_slam.arise_slam import AriseSLAM +from dimos.navigation.smartnav.modules.click_to_goal.click_to_goal import ClickToGoal +from dimos.navigation.smartnav.modules.global_map.global_map import GlobalMap +from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner +from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower +from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import ( + SensorScanGeneration, +) +from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis +from dimos.navigation.smartnav.modules.terrain_map_ext.terrain_map_ext import TerrainMapExt +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.simulation.unity.module import UnityBridgeModule +from dimos.visualization.vis_module import vis_module + + +def _rerun_blueprint() -> Any: + import rerun.blueprint as rrb + + return rrb.Blueprint( + rrb.Vertical( + rrb.Spatial3DView(origin="world", name="3D"), + rrb.Spatial2DView(origin="world/color_image", name="Camera"), + row_shares=[2, 1], + ), + ) + + +rerun_config = { + "blueprint": _rerun_blueprint, + "pubsubs": [LCM()], + "min_interval_sec": 0.25, + "visual_override": { + "world/camera_info": UnityBridgeModule.rerun_suppress_camera_info, + "world/sensor_scan": 
sensor_scan_override, + "world/terrain_map": terrain_map_override, + "world/terrain_map_ext": terrain_map_ext_override, + "world/global_map": global_map_override, + "world/path": path_override, + "world/way_point": waypoint_override, + "world/goal_path": goal_path_override, + }, + "static": { + "world/color_image": UnityBridgeModule.rerun_static_pinhole, + "world/floor": static_floor, + "world/tf/robot": static_robot, + }, +} + +simulation_slam_blueprint = autoconnect( + UnityBridgeModule.blueprint( + unity_binary="", + unity_scene="home_building_1", + ), + AriseSLAM.blueprint( + extra_args=[ + "--scanVoxelSize", + "0.1", + "--maxRange", + "50.0", + "--publishMap", + "true", + "--mapPublishRate", + "0.2", + ] + ), + SensorScanGeneration.blueprint(), + TerrainAnalysis.blueprint( + extra_args=[ + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + ] + ), + TerrainMapExt.blueprint(), + LocalPlanner.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "2.0", + "--autonomySpeed", + "2.0", + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--minRelZ", + "-1.0", + ] + ), + PathFollower.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "2.0", + "--autonomySpeed", + "2.0", + "--maxAccel", + "4.0", + "--slowDwnDisThre", + "0.2", + ] + ), + ClickToGoal.blueprint(), + GlobalMap.blueprint(), + vis_module(viewer_backend=global_config.viewer, rerun_config=rerun_config), +) + + +def main() -> None: + simulation_slam_blueprint.build({"n_workers": 9}).loop() + + +if __name__ == "__main__": + main() diff --git a/dimos/navigation/smartnav/common/dimos_native_module.hpp b/dimos/navigation/smartnav/common/dimos_native_module.hpp new file mode 100644 index 0000000000..e7fed34bdf --- /dev/null +++ b/dimos/navigation/smartnav/common/dimos_native_module.hpp @@ -0,0 +1,89 @@ +// SmartNav Native Module helpers. +// Re-exports dimos NativeModule patterns for CLI arg parsing and LCM helpers. 
+// Based on dimos/hardware/sensors/lidar/common/dimos_native_module.hpp + +#pragma once + +#include +#include +#include +#include + +#include "std_msgs/Header.hpp" +#include "std_msgs/Time.hpp" + +namespace dimos { + +class NativeModule { +public: + NativeModule(int argc, char** argv) { + for (int i = 1; i < argc; ++i) { + std::string arg(argv[i]); + if (arg.size() > 2 && arg[0] == '-' && arg[1] == '-' && i + 1 < argc) { + args_[arg.substr(2)] = argv[++i]; + } + } + } + + /// Get the full LCM channel string for a declared port. + const std::string& topic(const std::string& port) const { + auto it = args_.find(port); + if (it == args_.end()) { + throw std::runtime_error("NativeModule: no topic for port '" + port + "'"); + } + return it->second; + } + + /// Get a string arg value, or a default if not present. + std::string arg(const std::string& key, const std::string& default_val = "") const { + auto it = args_.find(key); + return it != args_.end() ? it->second : default_val; + } + + /// Get a float arg value, or a default if not present. + float arg_float(const std::string& key, float default_val = 0.0f) const { + auto it = args_.find(key); + return it != args_.end() ? std::stof(it->second) : default_val; + } + + /// Get an int arg value, or a default if not present. + int arg_int(const std::string& key, int default_val = 0) const { + auto it = args_.find(key); + return it != args_.end() ? std::stoi(it->second) : default_val; + } + + /// Get a bool arg value, or a default if not present. + bool arg_bool(const std::string& key, bool default_val = false) const { + auto it = args_.find(key); + if (it == args_.end()) return default_val; + return it->second == "true" || it->second == "1"; + } + + /// Check if a port/arg was provided. + bool has(const std::string& key) const { + return args_.count(key) > 0; + } + +private: + std::map args_; +}; + +/// Convert seconds (double) to a ROS-style Time message. 
+inline std_msgs::Time time_from_seconds(double t) { + std_msgs::Time ts; + ts.sec = static_cast(t); + ts.nsec = static_cast((t - ts.sec) * 1e9); + return ts; +} + +/// Build a stamped Header with auto-incrementing sequence number. +inline std_msgs::Header make_header(const std::string& frame_id, double ts) { + static std::atomic seq{0}; + std_msgs::Header h; + h.seq = seq.fetch_add(1, std::memory_order_relaxed); + h.stamp = time_from_seconds(ts); + h.frame_id = frame_id; + return h; +} + +} // namespace dimos diff --git a/dimos/navigation/smartnav/common/point_cloud_utils.hpp b/dimos/navigation/smartnav/common/point_cloud_utils.hpp new file mode 100644 index 0000000000..0970e1f8de --- /dev/null +++ b/dimos/navigation/smartnav/common/point_cloud_utils.hpp @@ -0,0 +1,170 @@ +// Point cloud utility functions for SmartNav native modules. +// Provides PointCloud2 building/parsing helpers that work with dimos-lcm types. +// When USE_PCL is defined, also provides PCL interop utilities. + +#pragma once + +#include +#include +#include + +#include "sensor_msgs/PointCloud2.hpp" +#include "sensor_msgs/PointField.hpp" +#include "std_msgs/Header.hpp" + +#include "dimos_native_module.hpp" + +#ifdef USE_PCL +#include +#include +#include +#endif + +namespace smartnav { + +// Simple XYZI point structure (no PCL dependency) +struct PointXYZI { + float x, y, z, intensity; +}; + +// Build PointCloud2 from vector of XYZI points +inline sensor_msgs::PointCloud2 build_pointcloud2( + const std::vector& points, + const std::string& frame_id, + double timestamp +) { + sensor_msgs::PointCloud2 pc; + pc.header = dimos::make_header(frame_id, timestamp); + pc.height = 1; + pc.width = static_cast(points.size()); + pc.is_bigendian = 0; + pc.is_dense = 1; + + // Fields: x, y, z, intensity (all float32) + pc.fields_length = 4; + pc.fields.resize(4); + auto make_field = [](const std::string& name, int32_t offset) { + sensor_msgs::PointField f; + f.name = name; + f.offset = offset; + f.datatype = 
sensor_msgs::PointField::FLOAT32; + f.count = 1; + return f; + }; + pc.fields[0] = make_field("x", 0); + pc.fields[1] = make_field("y", 4); + pc.fields[2] = make_field("z", 8); + pc.fields[3] = make_field("intensity", 12); + + pc.point_step = 16; + pc.row_step = pc.point_step * pc.width; + pc.data_length = pc.row_step; + pc.data.resize(pc.data_length); + + for (size_t i = 0; i < points.size(); ++i) { + float* dst = reinterpret_cast(pc.data.data() + i * 16); + dst[0] = points[i].x; + dst[1] = points[i].y; + dst[2] = points[i].z; + dst[3] = points[i].intensity; + } + + return pc; +} + +// Parse PointCloud2 into vector of XYZI points +inline std::vector parse_pointcloud2(const sensor_msgs::PointCloud2& pc) { + std::vector points; + if (pc.width == 0 || pc.height == 0) return points; + + int num_points = pc.width * pc.height; + points.reserve(num_points); + + // Find field offsets + int x_off = -1, y_off = -1, z_off = -1, i_off = -1; + for (const auto& f : pc.fields) { + if (f.name == "x") x_off = f.offset; + else if (f.name == "y") y_off = f.offset; + else if (f.name == "z") z_off = f.offset; + else if (f.name == "intensity") i_off = f.offset; + } + + if (x_off < 0 || y_off < 0 || z_off < 0) return points; + + for (int n = 0; n < num_points; ++n) { + if (static_cast((n + 1) * pc.point_step) > pc.data.size()) break; + const uint8_t* base = pc.data.data() + n * pc.point_step; + PointXYZI p; + std::memcpy(&p.x, base + x_off, sizeof(float)); + std::memcpy(&p.y, base + y_off, sizeof(float)); + std::memcpy(&p.z, base + z_off, sizeof(float)); + if (i_off >= 0) std::memcpy(&p.intensity, base + i_off, sizeof(float)); + else p.intensity = 0.0f; + points.push_back(p); + } + + return points; +} + +// Get timestamp from PointCloud2 header +inline double get_timestamp(const sensor_msgs::PointCloud2& pc) { + return pc.header.stamp.sec + pc.header.stamp.nsec / 1e9; +} + +#ifdef USE_PCL +// Convert dimos-lcm PointCloud2 to PCL point cloud +inline void to_pcl(const 
sensor_msgs::PointCloud2& pc, + pcl::PointCloud& cloud) { + auto points = parse_pointcloud2(pc); + cloud.clear(); + cloud.reserve(points.size()); + for (const auto& p : points) { + pcl::PointXYZI pt; + pt.x = p.x; + pt.y = p.y; + pt.z = p.z; + pt.intensity = p.intensity; + cloud.push_back(pt); + } + cloud.width = cloud.size(); + cloud.height = 1; + cloud.is_dense = true; +} + +// Convert PCL point cloud to dimos-lcm PointCloud2 +inline sensor_msgs::PointCloud2 from_pcl( + const pcl::PointCloud& cloud, + const std::string& frame_id, + double timestamp +) { + std::vector points; + points.reserve(cloud.size()); + for (const auto& pt : cloud) { + points.push_back({pt.x, pt.y, pt.z, pt.intensity}); + } + return build_pointcloud2(points, frame_id, timestamp); +} +#endif + +// Quaternion to RPY conversion +inline void quat_to_rpy(double qx, double qy, double qz, double qw, + double& roll, double& pitch, double& yaw) { + // Roll (x-axis rotation) + double sinr_cosp = 2.0 * (qw * qx + qy * qz); + double cosr_cosp = 1.0 - 2.0 * (qx * qx + qy * qy); + roll = std::atan2(sinr_cosp, cosr_cosp); + + // Pitch (y-axis rotation) + double sinp = 2.0 * (qw * qy - qz * qx); + if (std::abs(sinp) >= 1.0) + pitch = std::copysign(M_PI / 2, sinp); + else + pitch = std::asin(sinp); + + // Yaw (z-axis rotation) + double siny_cosp = 2.0 * (qw * qz + qx * qy); + double cosy_cosp = 1.0 - 2.0 * (qy * qy + qz * qz); + yaw = std::atan2(siny_cosp, cosy_cosp); +} + +} // namespace smartnav diff --git a/dimos/navigation/smartnav/flake.lock b/dimos/navigation/smartnav/flake.lock new file mode 100644 index 0000000000..4b35b3647e --- /dev/null +++ b/dimos/navigation/smartnav/flake.lock @@ -0,0 +1,79 @@ +{ + "nodes": { + "dimos-lcm": { + "flake": false, + "locked": { + "lastModified": 1769774949, + "narHash": "sha256-icRK7jerqNlwK1WZBrnIP04I2WozzFqTD7qsmnPxQuo=", + "owner": "dimensionalOS", + "repo": "dimos-lcm", + "rev": "0aa72b7b1bd3a65f50f5c03485ee9b728df56afe", + "type": "github" + }, + "original": { 
+ "owner": "dimensionalOS", + "ref": "main", + "repo": "dimos-lcm", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1773389992, + "narHash": "sha256-wvfdLLWJ2I9oEpDd9PfMA8osfIZicoQ5MT1jIwNs9Tk=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "c06b4ae3d6599a672a6210b7021d699c351eebda", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "dimos-lcm": "dimos-lcm", + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/dimos/navigation/smartnav/flake.nix b/dimos/navigation/smartnav/flake.nix new file mode 100644 index 0000000000..40eead92f1 --- /dev/null +++ b/dimos/navigation/smartnav/flake.nix @@ -0,0 +1,115 @@ +{ + description = "SmartNav native modules - autonomous navigation C++ components"; + + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + dimos-lcm = { + url = "github:dimensionalOS/dimos-lcm/main"; + flake = false; + }; + }; + + outputs = { self, nixpkgs, flake-utils, dimos-lcm, ... 
}: + flake-utils.lib.eachDefaultSystem (system: + let + pkgs = import nixpkgs { inherit system; }; + + commonBuildInputs = [ + pkgs.lcm + pkgs.glib + pkgs.eigen + pkgs.boost + ]; + + commonNativeBuildInputs = [ + pkgs.cmake + pkgs.pkg-config + ]; + + commonCmakeFlags = [ + "-DCMAKE_POLICY_VERSION_MINIMUM=3.5" + "-DFETCHCONTENT_SOURCE_DIR_DIMOS_LCM=${dimos-lcm}" + ]; + + # Full build with PCL + smartnav_native = pkgs.stdenv.mkDerivation { + pname = "smartnav-native"; + version = "0.1.0"; + src = ./.; + + nativeBuildInputs = commonNativeBuildInputs; + buildInputs = commonBuildInputs ++ [ + pkgs.pcl + pkgs.opencv + pkgs.ceres-solver + ]; + + cmakeFlags = commonCmakeFlags ++ [ + "-DUSE_PCL=ON" + ]; + }; + + # Lightweight build without PCL + smartnav_native_lite = pkgs.stdenv.mkDerivation { + pname = "smartnav-native-lite"; + version = "0.1.0"; + src = ./.; + + nativeBuildInputs = commonNativeBuildInputs; + buildInputs = commonBuildInputs ++ [ + pkgs.opencv + ]; + + cmakeFlags = commonCmakeFlags ++ [ + "-DUSE_PCL=OFF" + ]; + }; + + # Individual module builds + mkModule = name: extra_inputs: extra_flags: + pkgs.stdenv.mkDerivation { + pname = "smartnav-${name}"; + version = "0.1.0"; + src = ./.; + + nativeBuildInputs = commonNativeBuildInputs; + buildInputs = commonBuildInputs ++ extra_inputs; + + cmakeFlags = commonCmakeFlags ++ extra_flags; + + # Only build the specific target + buildPhase = '' + cmake --build . 
--target ${name} + ''; + + installPhase = '' + mkdir -p $out/bin + cp ${name} $out/bin/ + ''; + }; + + terrain_analysis = mkModule "terrain_analysis" [ pkgs.pcl ] [ "-DUSE_PCL=ON" ]; + local_planner = mkModule "local_planner" [ pkgs.pcl ] [ "-DUSE_PCL=ON" ]; + path_follower = mkModule "path_follower" [ pkgs.pcl ] [ "-DUSE_PCL=ON" ]; + far_planner = mkModule "far_planner" [ pkgs.pcl pkgs.opencv ] [ "-DUSE_PCL=ON" ]; + tare_planner = mkModule "tare_planner" [ pkgs.pcl ] [ "-DUSE_PCL=ON" ]; + pgo = mkModule "pgo" [ pkgs.pcl pkgs.gtsam ] [ "-DUSE_PCL=ON" ]; + arise_slam = mkModule "arise_slam" [ pkgs.pcl pkgs.ceres-solver ] [ "-DUSE_PCL=ON" ]; + in { + packages = { + default = smartnav_native; + inherit smartnav_native smartnav_native_lite; + inherit terrain_analysis local_planner path_follower far_planner tare_planner pgo arise_slam; + }; + + devShells.default = pkgs.mkShell { + buildInputs = commonBuildInputs ++ commonNativeBuildInputs ++ [ + pkgs.pcl + pkgs.opencv + pkgs.gtsam + pkgs.ceres-solver + ]; + }; + }); +} diff --git a/dimos/navigation/smartnav/modules/arise_sim_adapter.py b/dimos/navigation/smartnav/modules/arise_sim_adapter.py new file mode 100644 index 0000000000..406a359010 --- /dev/null +++ b/dimos/navigation/smartnav/modules/arise_sim_adapter.py @@ -0,0 +1,180 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AriseSimAdapter: adapts Unity sim data for AriseSLAM input. 
+ +AriseSLAM expects body-frame lidar (raw_points) and IMU data. +Unity provides world-frame registered_scan and ground-truth odometry. +This adapter: + 1. Transforms registered_scan from world-frame → body-frame using odom + 2. Synthesizes IMU (orientation + angular velocity + gravity) from odom + +This lets AriseSLAM run in simulation without real hardware. +""" + +from __future__ import annotations + +import threading +import time +from typing import Any + +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import In, Out +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.sensor_msgs.Imu import Imu +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 + + +class AriseSimAdapterConfig(ModuleConfig): + gravity: float = 9.80511 + imu_rate: float = 200.0 # Hz — AriseSLAM expects high-rate IMU + + +class AriseSimAdapter(Module[AriseSimAdapterConfig]): + """Adapts sim data (world-frame scans + odom) → AriseSLAM inputs (body-frame + IMU). + + Ports: + registered_scan (In[PointCloud2]): World-frame scan from simulator. + odometry (In[Odometry]): Ground-truth odom from simulator. + raw_points (Out[PointCloud2]): Body-frame scan for AriseSLAM. + imu (Out[Imu]): Synthetic IMU for AriseSLAM. 
+ """ + + default_config = AriseSimAdapterConfig + + registered_scan: In[PointCloud2] + odometry: In[Odometry] + raw_points: Out[PointCloud2] + imu: Out[Imu] + + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(**kwargs) + self._lock = threading.Lock() + self._running = False + self._thread: threading.Thread | None = None + self._latest_odom: Odometry | None = None + + def __getstate__(self) -> dict[str, Any]: + state = super().__getstate__() + state.pop("_lock", None) + state.pop("_thread", None) + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + super().__setstate__(state) + self._lock = threading.Lock() + self._thread = None + + def start(self) -> None: + self.odometry._transport.subscribe(self._on_odom) + self.registered_scan._transport.subscribe(self._on_scan) + self._running = True + self._thread = threading.Thread(target=self._imu_loop, daemon=True) + self._thread.start() + print("[AriseSimAdapter] Started — converting sim data for AriseSLAM") + + def stop(self) -> None: + self._running = False + if self._thread: + self._thread.join(timeout=2.0) + super().stop() + + def _on_odom(self, msg: Odometry) -> None: + with self._lock: + self._latest_odom = msg + + def _on_scan(self, cloud: PointCloud2) -> None: + """Transform world-frame scan → body-frame using latest odom.""" + with self._lock: + odom = self._latest_odom + if odom is None: + return + + try: + tf_map_to_sensor = Transform( + translation=Vector3(odom.x, odom.y, odom.z), + rotation=odom.orientation, + frame_id="map", + child_frame_id="sensor", + ) + tf_sensor_to_map = tf_map_to_sensor.inverse() + body_cloud = cloud.transform(tf_sensor_to_map) + body_cloud.frame_id = "sensor" + self.raw_points._transport.publish(body_cloud) + except Exception: + import traceback + + print(f"[AriseSimAdapter] scan transform failed: {traceback.format_exc()}") + + def _imu_loop(self) -> None: + """Publish synthetic IMU at high rate from latest odom.""" + dt = 
1.0 / self.config.imu_rate + g = self.config.gravity + + while self._running: + t0 = time.monotonic() + + with self._lock: + odom = self._latest_odom + + if odom is not None: + q = odom.pose.orientation + ang_vel = Vector3(0.0, 0.0, 0.0) + if odom.twist is not None: + ang_vel = Vector3( + odom.twist.angular.x, + odom.twist.angular.y, + odom.twist.angular.z, + ) + + # Rotate gravity [0, 0, g] into body frame + gx, gy, gz = _rotate_vec_by_quat_inv(0.0, 0.0, g, q.x, q.y, q.z, q.w) + + self.imu._transport.publish( + Imu( + angular_velocity=ang_vel, + linear_acceleration=Vector3(gx, gy, gz), + orientation=Quaternion(q.x, q.y, q.z, q.w), + ts=time.time(), + frame_id="sensor", + ) + ) + + elapsed = time.monotonic() - t0 + if dt - elapsed > 0: + time.sleep(dt - elapsed) + + +def _rotate_vec_by_quat_inv( + vx: float, + vy: float, + vz: float, + qx: float, + qy: float, + qz: float, + qw: float, +) -> tuple[float, float, float]: + """Rotate vector by the inverse of a unit quaternion.""" + nqx, nqy, nqz = -qx, -qy, -qz + tx = 2.0 * (nqy * vz - nqz * vy) + ty = 2.0 * (nqz * vx - nqx * vz) + tz = 2.0 * (nqx * vy - nqy * vx) + return ( + vx + qw * tx + (nqy * tz - nqz * ty), + vy + qw * ty + (nqz * tx - nqx * tz), + vz + qw * tz + (nqx * ty - nqy * tx), + ) diff --git a/dimos/navigation/smartnav/modules/arise_slam/arise_slam.py b/dimos/navigation/smartnav/modules/arise_slam/arise_slam.py new file mode 100644 index 0000000000..b392538d87 --- /dev/null +++ b/dimos/navigation/smartnav/modules/arise_slam/arise_slam.py @@ -0,0 +1,94 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AriseSLAM NativeModule: C++ LiDAR SLAM with feature-based scan matching. + +Ported from arise_slam_mid360. Performs curvature-based feature extraction +(edge + planar), scan-to-map matching via Ceres optimization, and optional +IMU preintegration for motion prediction. Publishes world-frame registered +point clouds and odometry. +""" + +from __future__ import annotations + +from dimos.core.native_module import NativeModule, NativeModuleConfig +from dimos.core.stream import In, Out +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.sensor_msgs.Imu import Imu +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 + + +class AriseSLAMConfig(NativeModuleConfig): + """Config for the AriseSLAM native module.""" + + cwd: str | None = "." 
+ executable: str = "result/bin/arise_slam" + build_command: str | None = ( + "nix build github:dimensionalOS/dimos-module-arise-slam/v0.1.0 --no-write-lock-file" + ) + + # Feature extraction + edge_threshold: float = 1.0 + surf_threshold: float = 0.1 + scan_voxel_size: float = 0.1 + + # Local map + line_res: float = 0.2 + plane_res: float = 0.4 + max_range: float = 100.0 + + # Scan matching + max_icp_iterations: int = 4 + max_lm_iterations: int = 15 + + # IMU + use_imu: bool = True + gravity: float = 9.80511 + + # Output + min_publish_interval: float = 0.05 + publish_map: bool = False + map_publish_rate: float = 0.2 + + # Initial pose + init_x: float = 0.0 + init_y: float = 0.0 + init_z: float = 0.0 + init_roll: float = 0.0 + init_pitch: float = 0.0 + init_yaw: float = 0.0 + + +class AriseSLAM(NativeModule): + """LiDAR SLAM module with feature-based scan-to-map matching. + + Processes raw LiDAR point clouds through curvature-based feature + extraction, matches against a rolling local map using Ceres + optimization, and publishes world-frame registered scans + odometry. + + Ports: + raw_points (In[PointCloud2]): Raw lidar point cloud (body frame). + imu (In[Imu]): IMU data for motion prediction. + registered_scan (Out[PointCloud2]): World-frame registered cloud. + odometry (Out[Odometry]): SLAM-estimated odometry. + local_map (Out[PointCloud2]): Local map visualization (optional). + """ + + default_config: type[AriseSLAMConfig] = AriseSLAMConfig # type: ignore[assignment] + + raw_points: In[PointCloud2] + imu: In[Imu] + registered_scan: Out[PointCloud2] + odometry: Out[Odometry] + local_map: Out[PointCloud2] diff --git a/dimos/navigation/smartnav/modules/arise_slam/test_arise_slam.py b/dimos/navigation/smartnav/modules/arise_slam/test_arise_slam.py new file mode 100644 index 0000000000..70daf645d2 --- /dev/null +++ b/dimos/navigation/smartnav/modules/arise_slam/test_arise_slam.py @@ -0,0 +1,101 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for AriseSLAM NativeModule wrapper.""" + +from pathlib import Path + +import pytest + +from dimos.navigation.smartnav.modules.arise_slam.arise_slam import AriseSLAM, AriseSLAMConfig + + +class TestAriseSLAMConfig: + """Test AriseSLAM configuration.""" + + def test_default_config(self): + config = AriseSLAMConfig() + assert config.edge_threshold == 1.0 + assert config.surf_threshold == 0.1 + assert config.max_icp_iterations == 4 + assert config.use_imu is True + + def test_cli_args_generation(self): + config = AriseSLAMConfig( + edge_threshold=2.0, + max_icp_iterations=8, + ) + args = config.to_cli_args() + assert "--edge_threshold" in args + assert "2.0" in args + assert "--max_icp_iterations" in args + assert "8" in args + + +class TestAriseSLAMModule: + """Test AriseSLAM module declaration.""" + + def test_ports_declared(self): + from typing import get_origin, get_type_hints + + from dimos.core.stream import In, Out + + hints = get_type_hints(AriseSLAM) + in_ports = {k for k, v in hints.items() if get_origin(v) is In} + out_ports = {k for k, v in hints.items() if get_origin(v) is Out} + + assert "raw_points" in in_ports + assert "imu" in in_ports + assert "registered_scan" in out_ports + assert "odometry" in out_ports + assert "local_map" in out_ports + + +@pytest.mark.skipif( + not Path(__file__).resolve().parent.joinpath("result", "bin").exists(), + reason="Native binary not built (run nix build first)", 
+) +class TestPathResolution: + """Verify native module paths resolve to real filesystem locations.""" + + def _make(self): + m = AriseSLAM() + m._resolve_paths() + return m + + def test_cwd_resolves_to_existing_directory(self): + m = self._make() + try: + assert Path(m.config.cwd).exists(), f"cwd does not exist: {m.config.cwd}" + assert Path(m.config.cwd).is_dir() + finally: + m.stop() + + def test_executable_exists(self): + m = self._make() + try: + exe = Path(m.config.executable) + assert exe.exists(), f"Binary not found: {exe}. Run nix build first." + finally: + m.stop() + + def test_cwd_resolves_to_smartnav_root(self): + """cwd should resolve to the smartnav root (where CMakeLists.txt lives).""" + m = self._make() + try: + cwd = Path(m.config.cwd).resolve() + assert (cwd / "CMakeLists.txt").exists(), f"cwd {cwd} is not the smartnav root" + assert (cwd / "flake.nix").exists() + finally: + m.stop() diff --git a/dimos/navigation/smartnav/modules/click_to_goal/click_to_goal.py b/dimos/navigation/smartnav/modules/click_to_goal/click_to_goal.py new file mode 100644 index 0000000000..ec38d15ba5 --- /dev/null +++ b/dimos/navigation/smartnav/modules/click_to_goal/click_to_goal.py @@ -0,0 +1,114 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ClickToGoal: forwards clicked_point to LocalPlanner's way_point. 
class ClickToGoal(Module[ModuleConfig]):
    """Relay clicked_point → way_point + goal for click-to-navigate.

    Each accepted click is republished unchanged on ``way_point`` and
    ``goal``. A two-pose ``goal_path`` running from the last cached robot
    position to the click is published as well.

    Ports:
        clicked_point (In[PointStamped]): Click from viewer.
        odometry (In[Odometry]): Vehicle pose for goal line rendering.
        way_point (Out[PointStamped]): Navigation waypoint for LocalPlanner.
        goal (Out[PointStamped]): Navigation goal for FarPlanner.
        goal_path (Out[Path]): Straight line from robot to goal for Rerun.
    """

    default_config = ModuleConfig

    clicked_point: In[PointStamped]
    odometry: In[Odometry]
    way_point: Out[PointStamped]
    goal: Out[PointStamped]
    goal_path: Out[Path]

    def __init__(self, **kwargs) -> None:  # type: ignore[no-untyped-def]
        super().__init__(**kwargs)
        # Protects the cached robot position: written by _on_odom, read by _on_click.
        self._lock = threading.Lock()
        self._robot_x = 0.0
        self._robot_y = 0.0
        self._robot_z = 0.0

    def __getstate__(self) -> dict[str, Any]:
        """Return the parent's state without the lock, which cannot be pickled."""
        state = super().__getstate__()
        state.pop("_lock", None)
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Restore the parent's state and create a fresh lock."""
        super().__setstate__(state)
        self._lock = threading.Lock()

    def start(self) -> None:
        """Subscribe the odometry and click callbacks to their transports."""
        self.odometry._transport.subscribe(self._on_odom)
        self.clicked_point._transport.subscribe(self._on_click)

    def _on_odom(self, msg: Odometry) -> None:
        """Cache the latest robot position from ``msg.pose.position``."""
        with self._lock:
            self._robot_x = msg.pose.position.x
            self._robot_y = msg.pose.position.y
            self._robot_z = msg.pose.position.z

    def _on_click(self, msg: PointStamped) -> None:
        """Validate a click, forward it as waypoint/goal and publish the goal line."""
        # Reject invalid clicks (sky/background gives inf or huge coords)
        if not all(math.isfinite(v) for v in (msg.x, msg.y, msg.z)):
            print(f"[click_to_goal] Ignored invalid click: ({msg.x:.1f}, {msg.y:.1f}, {msg.z:.1f})")
            return
        if abs(msg.x) > 500 or abs(msg.y) > 500 or abs(msg.z) > 50:
            print(
                f"[click_to_goal] Ignored out-of-range click: ({msg.x:.1f}, {msg.y:.1f}, {msg.z:.1f})"
            )
            return

        # Snapshot the position under the lock; publish without holding it.
        with self._lock:
            rx, ry, rz = self._robot_x, self._robot_y, self._robot_z

        print(f"[click_to_goal] Goal: ({msg.x:.1f}, {msg.y:.1f}, {msg.z:.1f})")
        self.way_point._transport.publish(msg)
        self.goal._transport.publish(msg)

        # Publish a straight-line path from robot to goal for visualization.
        # Both endpoints are raised by 0.3 in z (presumably so the line is
        # drawn above the ground -- confirm against the viewer).
        now = time.time()
        poses = [
            PoseStamped(
                ts=now, frame_id="map", position=[rx, ry, rz + 0.3], orientation=[0, 0, 0, 1]
            ),
            PoseStamped(
                ts=now,
                frame_id="map",
                position=[msg.x, msg.y, msg.z + 0.3],
                orientation=[0, 0, 0, 1],
            ),
        ]
        self.goal_path._transport.publish(Path(ts=now, frame_id="map", poses=poses))


class CmdVelMuxConfig(ModuleConfig):
    """Config for CmdVelMux."""

    # Delay passed to the cooldown timer that ends teleop priority.
    teleop_cooldown_sec: float = 1.0


class CmdVelMux(Module[CmdVelMuxConfig]):
    """Multiplexes nav_cmd_vel and tele_cmd_vel into a single cmd_vel output.

    Every teleop message is forwarded and (re)arms a cooldown timer. While
    the timer is pending, nav messages are dropped. When it expires, nav
    messages are forwarded again.

    Ports:
        nav_cmd_vel (In[Twist]): Velocity from the autonomous planner.
        tele_cmd_vel (In[Twist]): Velocity from keyboard/joystick teleop.
        cmd_vel (Out[Twist]): Merged output — teleop wins when active.
    """

    default_config = CmdVelMuxConfig

    nav_cmd_vel: In[Twist]
    tele_cmd_vel: In[Twist]
    cmd_vel: Out[Twist]

    def __init__(self, **kwargs) -> None:  # type: ignore[no-untyped-def]
        super().__init__(**kwargs)
        self._teleop_active = False
        # Incremented on every teleop message so that an expired timer can
        # tell whether it is still the most recent one.
        self._teleop_generation = 0
        self._lock = threading.Lock()
        self._timer: threading.Timer | None = None

    def __getstate__(self) -> dict[str, Any]:
        """Return the parent's state without the lock and the timer thread."""
        state = super().__getstate__()
        state.pop("_lock", None)
        state.pop("_timer", None)
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Restore the parent's state with a fresh lock and no pending timer."""
        super().__setstate__(state)
        self._lock = threading.Lock()
        self._timer = None
        # The timer is not carried over, so nothing could ever clear a
        # restored "active" flag; start in the nav-enabled state instead.
        self._teleop_active = False
        self._teleop_generation = 0

    def start(self) -> None:
        """Subscribe the nav and teleop callbacks to their transports."""
        self.nav_cmd_vel._transport.subscribe(self._on_nav)
        self.tele_cmd_vel._transport.subscribe(self._on_teleop)

    def stop(self) -> None:
        """Cancel any pending cooldown timer, then stop the module."""
        with self._lock:
            if self._timer is not None:
                self._timer.cancel()
                self._timer = None
        super().stop()

    def _on_nav(self, msg: Twist) -> None:
        """Forward a nav command unless teleop currently has priority."""
        with self._lock:
            if self._teleop_active:
                return
        self.cmd_vel._transport.publish(msg)

    def _on_teleop(self, msg: Twist) -> None:
        """Forward a teleop command and restart the cooldown timer."""
        with self._lock:
            self._teleop_active = True
            self._teleop_generation += 1
            if self._timer is not None:
                self._timer.cancel()
            timer = threading.Timer(
                self.config.teleop_cooldown_sec,
                self._end_teleop,
                args=(self._teleop_generation,),
            )
            timer.daemon = True
            self._timer = timer
            timer.start()
        self.cmd_vel._transport.publish(msg)

    def _end_teleop(self, generation: int | None = None) -> None:
        """Clear teleop priority when the cooldown expires.

        Args:
            generation: Token of the timer that fired. ``cancel()`` cannot
                stop a timer whose callback is already waiting for the lock,
                so a callback whose token is no longer current is ignored
                rather than ending a newer teleop session. ``None`` (direct
                call) ends teleop unconditionally.
        """
        with self._lock:
            if generation is not None and generation != self._teleop_generation:
                return
            self._teleop_active = False
            self._timer = None
class FarPlannerConfig(NativeModuleConfig):
    """Settings for the FAR planner native module: launch options plus planner knobs."""

    # --- how the native binary is located / built ---
    cwd: str | None = "."
    executable: str = "result/bin/far_planner"
    build_command: str | None = (
        "nix build github:dimensionalOS/dimos-module-far-planner/v0.1.0 --no-write-lock-file"
    )

    # --- planner parameters ---
    visibility_range: float = 15.0
    update_rate: float = 2.0
    robot_dim: float = 0.5
    sensor_range: float = 20.0


class FarPlanner(NativeModule):
    """Visibility-graph global route planner (FAR planner).

    Maintains a visibility graph built from registered point clouds and
    searches it for shortest paths to navigation goals, emitting
    intermediate waypoints for the local planner.

    Ports:
        registered_scan (In[PointCloud2]): World-frame point cloud for graph updates.
        odometry (In[Odometry]): Vehicle state.
        goal (In[PointStamped]): User-specified navigation goal.
        way_point (Out[PointStamped]): Intermediate waypoint for local planner.
    """

    default_config: type[FarPlannerConfig] = FarPlannerConfig  # type: ignore[assignment]

    # Inputs
    registered_scan: In[PointCloud2]
    odometry: In[Odometry]
    goal: In[PointStamped]

    # Output
    way_point: Out[PointStamped]
class TestFarPlannerConfig:
    """Defaults and CLI serialisation of ``FarPlannerConfig``."""

    def test_default_config(self):
        cfg = FarPlannerConfig()
        expected_defaults = {
            "visibility_range": 15.0,
            "update_rate": 2.0,
            "robot_dim": 0.5,
            "sensor_range": 20.0,
        }
        for field_name, value in expected_defaults.items():
            assert getattr(cfg, field_name) == value

    def test_cli_args_generation(self):
        cli = FarPlannerConfig(visibility_range=20.0, robot_dim=0.8).to_cli_args()
        for token in ("--visibility_range", "20.0", "--robot_dim", "0.8"):
            assert token in cli


class TestFarPlannerModule:
    """Port declarations on the ``FarPlanner`` class."""

    def test_ports_declared(self):
        from typing import get_origin, get_type_hints

        from dimos.core.stream import In, Out

        # Sort every annotated attribute into inputs / outputs by its generic origin.
        inputs, outputs = set(), set()
        for attr, hint in get_type_hints(FarPlanner).items():
            origin = get_origin(hint)
            if origin is In:
                inputs.add(attr)
            elif origin is Out:
                outputs.add(attr)

        for port in ("registered_scan", "odometry", "goal"):
            assert port in inputs
        assert "way_point" in outputs


@pytest.mark.skipif(
    not (Path(__file__).resolve().parent / "result" / "bin").exists(),
    reason="Native binary not built (run nix build first)",
)
class TestPathResolution:
    """Check that the resolved cwd/executable point at real filesystem entries."""

    def _make(self):
        module = FarPlanner()
        module._resolve_paths()
        return module

    def test_cwd_resolves_to_existing_directory(self):
        module = self._make()
        try:
            assert Path(module.config.cwd).exists(), f"cwd does not exist: {module.config.cwd}"
            assert Path(module.config.cwd).is_dir()
        finally:
            module.stop()

    def test_executable_exists(self):
        module = self._make()
        try:
            binary = Path(module.config.executable)
            assert binary.exists(), f"Binary not found: {binary}. Run nix build first."
        finally:
            module.stop()

    def test_cwd_resolves_to_smartnav_root(self):
        """cwd should resolve to the smartnav root (where CMakeLists.txt lives)."""
        module = self._make()
        try:
            root = Path(module.config.cwd).resolve()
            assert (root / "CMakeLists.txt").exists(), f"cwd {root} is not the smartnav root"
            assert (root / "flake.nix").exists()
        finally:
            module.stop()
class GlobalMapConfig(ModuleConfig):
    """Config for global map accumulator."""

    voxel_size: float = 0.15  # meters per voxel (fine enough for map detail)
    decay_time: float = 300.0  # seconds before points expire (5 min)
    publish_rate: float = 1.0  # Hz — keep low to avoid memory explosion
    max_range: float = 80.0  # max distance from robot to keep
    max_points: int = 500_000  # hard cap on published points
    height_min: float = -2.0  # clip floor noise
    height_max: float = 4.0  # clip ceiling


class GlobalMap(Module[GlobalMapConfig]):
    """Accumulated global point cloud from registered_scan.

    Voxelizes incoming scans and maintains a persistent map with
    time-based decay and range culling. Publishes the full accumulated
    cloud for Rerun visualization.

    Ports:
        registered_scan (In[PointCloud2]): World-frame lidar scan.
        odometry (In[Odometry]): Vehicle pose for range culling.
        global_map (Out[PointCloud2]): Accumulated voxelized cloud.
    """

    default_config = GlobalMapConfig

    registered_scan: In[PointCloud2]
    odometry: In[Odometry]
    global_map: Out[PointCloud2]

    def __init__(self, **kwargs) -> None:  # type: ignore[no-untyped-def]
        super().__init__(**kwargs)
        # Guards _voxels and the cached robot position.
        self._lock = threading.Lock()
        self._running = False
        self._thread: threading.Thread | None = None
        # Voxel storage: key=(ix,iy,iz) -> (x, y, z, timestamp)
        self._voxels: dict[tuple[int, int, int], tuple[float, float, float, float]] = {}
        self._robot_x = 0.0
        self._robot_y = 0.0
        self._robot_z = 0.0

    def __getstate__(self) -> dict[str, Any]:
        """Return the parent's state without the lock, thread and voxel store."""
        state = super().__getstate__()
        for k in ("_lock", "_thread", "_voxels"):
            state.pop(k, None)
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Restore the parent's state with a fresh lock and an empty map."""
        super().__setstate__(state)
        self._lock = threading.Lock()
        self._thread = None
        self._voxels = {}

    def start(self) -> None:
        """Subscribe to the inputs and launch the background publish thread."""
        self.registered_scan._transport.subscribe(self._on_scan)
        self.odometry._transport.subscribe(self._on_odom)
        self._running = True
        self._thread = threading.Thread(target=self._publish_loop, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Signal the publish thread to exit, wait up to 3 s for it, then stop."""
        self._running = False
        if self._thread:
            self._thread.join(timeout=3.0)
            self._thread = None
        super().stop()

    def _on_odom(self, msg: Odometry) -> None:
        """Cache the latest robot position from ``msg.pose.position``."""
        with self._lock:
            self._robot_x = msg.pose.position.x
            self._robot_y = msg.pose.position.y
            self._robot_z = msg.pose.position.z

    def _on_scan(self, cloud: PointCloud2) -> None:
        """Insert one scan into the voxel map.

        Points with non-finite coordinates or a z outside
        ``[height_min, height_max]`` are dropped. When several points of the
        scan fall into the same voxel, the last one wins.

        Raises:
            ValueError: If ``voxel_size`` is not positive.
        """
        points, _ = cloud.as_numpy()
        if len(points) == 0:
            return

        vs = self.config.voxel_size
        if vs <= 0:
            raise ValueError(f"voxel_size must be positive, got {vs}")
        now = time.time()

        # Columns 0..2 are read as x, y, z.
        xyz = np.asarray(points, dtype=np.float64)[:, :3]
        heights = xyz[:, 2]
        # NaN/inf coordinates cannot be converted to an integer voxel index.
        keep = (
            np.isfinite(xyz).all(axis=1)
            & (heights >= self.config.height_min)
            & (heights <= self.config.height_max)
        )
        xyz = xyz[keep]
        if len(xyz) == 0:
            return

        keys = np.floor(xyz / vs).astype(np.int64)
        # Built outside the lock; a later duplicate key overwrites an earlier one.
        fresh = {
            (ix, iy, iz): (x, y, z, now)
            for (ix, iy, iz), (x, y, z) in zip(keys.tolist(), xyz.tolist())
        }

        with self._lock:
            self._voxels.update(fresh)

    def _publish_loop(self) -> None:
        """Periodically prune the map and publish what is left.

        Runs until ``_running`` is cleared. Each cycle removes voxels older
        than ``decay_time`` or farther (in x/y) than ``max_range`` from the
        cached robot position, then publishes at most ``max_points`` points.
        """
        rate = self.config.publish_rate
        if rate <= 0:
            # A zero rate would divide by zero, a negative one would spin
            # without sleeping; treat both as "do not publish".
            print(f"[GlobalMap] publish_rate={rate} is not positive; publishing disabled")
            return
        dt = 1.0 / rate

        while self._running:
            t0 = time.monotonic()
            now = time.time()
            decay = self.config.decay_time
            max_r2 = self.config.max_range**2
            max_pts = self.config.max_points

            with self._lock:
                rx, ry = self._robot_x, self._robot_y
                # Expire old voxels and range-cull
                expired = []
                pts = []
                for k, (x, y, z, ts) in self._voxels.items():
                    if now - ts > decay:
                        expired.append(k)
                    elif (x - rx) ** 2 + (y - ry) ** 2 > max_r2:
                        expired.append(k)
                    else:
                        pts.append([x, y, z])
                for k in expired:
                    del self._voxels[k]

            if pts:
                # Cap total points to prevent memory explosion
                if len(pts) > max_pts:
                    pts = pts[:max_pts]
                arr = np.array(pts, dtype=np.float32)
                self.global_map._transport.publish(
                    PointCloud2.from_numpy(arr, frame_id="map", timestamp=now)
                )

            # Sleep for whatever remains of this cycle's period.
            elapsed = time.monotonic() - t0
            if elapsed < dt:
                time.sleep(dt - elapsed)
def _default_paths_dir() -> str:
    """Return the ``smartnav_paths`` data location from ``get_data`` as a string."""
    resolved = get_data("smartnav_paths")
    return str(resolved)


class LocalPlannerConfig(NativeModuleConfig):
    """Settings for the local planner native module."""

    # --- how the native binary is located / built ---
    cwd: str | None = "."
    executable: str = "result/bin/local_planner"
    build_command: str | None = (
        "nix build github:dimensionalOS/dimos-module-local-planner/v0.1.1 --no-write-lock-file"
    )

    # Path data directory; an empty value is replaced in model_post_init.
    paths_dir: str = ""

    # Vehicle config
    vehicle_config: str = "omniDir"  # "omniDir" for mecanum, "standard" for ackermann

    # Speed limits
    max_speed: float = 2.0
    autonomy_speed: float = 1.0

    # Obstacle detection
    obstacle_height_threshold: float = 0.15

    # Goal parameters
    goal_clearance: float = 0.5
    goal_x: float = 0.0
    goal_y: float = 0.0

    def model_post_init(self, __context: Any) -> None:
        """Fill in ``paths_dir`` from ``_default_paths_dir`` when it was left empty."""
        super().model_post_init(__context)
        if self.paths_dir:
            return
        self.paths_dir = _default_paths_dir()


class LocalPlanner(NativeModule):
    """Obstacle-avoiding local path planner.

    Scores pre-computed path sets against the current obstacle map and
    picks the best collision-free path toward the goal. Supports smart
    joystick, waypoint, and manual control modes.

    Ports:
        registered_scan (In[PointCloud2]): Obstacle point cloud.
        odometry (In[Odometry]): Vehicle state estimation.
        joy_cmd (In[Twist]): Joystick/teleop velocity commands.
        way_point (In[PointStamped]): Navigation goal waypoint.
        path (Out[Path]): Selected local path for path follower.
    """

    default_config: type[LocalPlannerConfig] = LocalPlannerConfig  # type: ignore[assignment]

    # Inputs
    registered_scan: In[PointCloud2]
    odometry: In[Odometry]
    joy_cmd: In[Twist]
    way_point: In[PointStamped]

    # Output
    path: Out[Path]
class TestLocalPlannerConfig:
    """Defaults and CLI serialisation of ``LocalPlannerConfig``."""

    def test_default_config(self):
        cfg = LocalPlannerConfig()
        expected_defaults = {
            "max_speed": 2.0,
            "autonomy_speed": 1.0,
            "vehicle_config": "omniDir",
            "obstacle_height_threshold": 0.15,
        }
        for field_name, value in expected_defaults.items():
            assert getattr(cfg, field_name) == value

    def test_cli_args_generation(self):
        cli = LocalPlannerConfig(max_speed=1.5, paths_dir="/custom/paths").to_cli_args()
        for token in ("--max_speed", "1.5", "--paths_dir", "/custom/paths"):
            assert token in cli


class TestLocalPlannerModule:
    """Port declarations on the ``LocalPlanner`` class."""

    def test_ports_declared(self):
        from typing import get_origin, get_type_hints

        from dimos.core.stream import In, Out

        # Sort every annotated attribute into inputs / outputs by its generic origin.
        inputs, outputs = set(), set()
        for attr, hint in get_type_hints(LocalPlanner).items():
            origin = get_origin(hint)
            if origin is In:
                inputs.add(attr)
            elif origin is Out:
                outputs.add(attr)

        for port in ("registered_scan", "odometry", "joy_cmd", "way_point"):
            assert port in inputs
        assert "path" in outputs


@pytest.mark.skipif(
    not (Path(__file__).resolve().parent / "result" / "bin").exists(),
    reason="Native binary not built (run nix build first)",
)
class TestPathResolution:
    """Check that the resolved cwd/executable/data paths point at real filesystem entries."""

    def _make(self):
        module = LocalPlanner()
        module._resolve_paths()
        return module

    def test_cwd_resolves_to_existing_directory(self):
        module = self._make()
        try:
            assert Path(module.config.cwd).exists(), f"cwd does not exist: {module.config.cwd}"
            assert Path(module.config.cwd).is_dir()
        finally:
            module.stop()

    def test_executable_exists(self):
        module = self._make()
        try:
            binary = Path(module.config.executable)
            assert binary.exists(), f"Binary not found: {binary}. Run nix build first."
        finally:
            module.stop()

    def test_cwd_resolves_to_smartnav_root(self):
        """cwd should resolve to the smartnav root (where CMakeLists.txt lives)."""
        module = self._make()
        try:
            root = Path(module.config.cwd).resolve()
            assert (root / "CMakeLists.txt").exists(), f"cwd {root} is not the smartnav root"
            assert (root / "flake.nix").exists()
        finally:
            module.stop()

    def test_data_files_exist(self):
        """Local planner needs path data files (pulled from LFS)."""
        from dimos.utils.data import get_data

        data_dir = get_data("smartnav_paths")
        assert data_dir.exists(), f"paths_dir not found: {data_dir}"
        for filename in ("startPaths.ply", "pathList.ply", "paths.ply"):
            assert (data_dir / filename).exists()
class OdomAdapter(Module[ModuleConfig]):
    """Bidirectional PoseStamped <-> Odometry adapter.

    ``raw_odom`` (PoseStamped) is republished as ``odometry`` (Odometry), and
    ``corrected_odometry`` (Odometry) is republished as ``odom`` (PoseStamped).
    Timestamp, frame id, position and orientation are copied across.
    """

    default_config = ModuleConfig

    # PoseStamped -> Odometry direction
    raw_odom: In[PoseStamped]
    odometry: Out[Odometry]

    # Odometry -> PoseStamped direction
    corrected_odometry: In[Odometry]
    odom: Out[PoseStamped]

    def start(self) -> None:
        """Subscribe both conversion callbacks to their input transports."""
        self.raw_odom._transport.subscribe(self._on_raw_odom)
        self.corrected_odometry._transport.subscribe(self._on_corrected_odom)
        print("[OdomAdapter] Started")

    def _on_raw_odom(self, msg: PoseStamped) -> None:
        """Wrap an incoming PoseStamped in an Odometry and publish it."""
        quat = msg.orientation
        pose = Pose(
            position=[msg.x, msg.y, msg.z],
            orientation=[quat.x, quat.y, quat.z, quat.w],
        )
        converted = Odometry(ts=msg.ts, frame_id=msg.frame_id, pose=pose)
        self.odometry._transport.publish(converted)

    def _on_corrected_odom(self, msg: Odometry) -> None:
        """Flatten an incoming Odometry into a PoseStamped and publish it."""
        quat = msg.orientation
        converted = PoseStamped(
            ts=msg.ts,
            frame_id=msg.frame_id,
            position=[msg.x, msg.y, msg.z],
            orientation=[quat.x, quat.y, quat.z, quat.w],
        )
        self.odom._transport.publish(converted)
class PathFollowerConfig(NativeModuleConfig):
    """Settings for the path follower native module."""

    # --- how the native binary is located / built ---
    cwd: str | None = "."
    executable: str = "result/bin/path_follower"
    build_command: str | None = (
        "nix build github:dimensionalOS/dimos-module-path-follower/v0.1.0 --no-write-lock-file"
    )

    # --- pure pursuit parameters ---
    look_ahead_distance: float = 0.5
    max_speed: float = 2.0
    max_yaw_rate: float = 1.5

    # --- goal tolerance ---
    goal_tolerance: float = 0.3

    # --- vehicle config ---
    vehicle_config: str = "omniDir"


class PathFollower(NativeModule):
    """Pure pursuit path follower with PID yaw control.

    Consumes a path from the local planner together with the current
    vehicle state and produces velocity commands that track the path.

    Ports:
        path (In[Path]): Local path to follow.
        odometry (In[Odometry]): Vehicle state estimation.
        cmd_vel (Out[Twist]): Velocity commands for the vehicle.
    """

    default_config: type[PathFollowerConfig] = PathFollowerConfig  # type: ignore[assignment]

    # Inputs
    path: In[Path]
    odometry: In[Odometry]

    # Output
    cmd_vel: Out[Twist]
class TestPathFollowerConfig:
    """Defaults and CLI serialisation of ``PathFollowerConfig``."""

    def test_default_config(self):
        cfg = PathFollowerConfig()
        expected_defaults = {
            "look_ahead_distance": 0.5,
            "max_speed": 2.0,
            "max_yaw_rate": 1.5,
            "goal_tolerance": 0.3,
        }
        for field_name, value in expected_defaults.items():
            assert getattr(cfg, field_name) == value

    def test_cli_args_generation(self):
        cli = PathFollowerConfig(look_ahead_distance=1.0, max_speed=1.0).to_cli_args()
        for flag in ("--look_ahead_distance", "--max_speed"):
            assert flag in cli


class TestPathFollowerModule:
    """Port declarations on the ``PathFollower`` class."""

    def test_ports_declared(self):
        from typing import get_origin, get_type_hints

        from dimos.core.stream import In, Out

        # Sort every annotated attribute into inputs / outputs by its generic origin.
        inputs, outputs = set(), set()
        for attr, hint in get_type_hints(PathFollower).items():
            origin = get_origin(hint)
            if origin is In:
                inputs.add(attr)
            elif origin is Out:
                outputs.add(attr)

        for port in ("path", "odometry"):
            assert port in inputs
        assert "cmd_vel" in outputs


@pytest.mark.skipif(
    not (Path(__file__).resolve().parent / "result" / "bin").exists(),
    reason="Native binary not built (run nix build first)",
)
class TestPathResolution:
    """Check that the resolved cwd/executable point at real filesystem entries."""

    def _make(self):
        module = PathFollower()
        module._resolve_paths()
        return module

    def test_cwd_resolves_to_existing_directory(self):
        module = self._make()
        try:
            assert Path(module.config.cwd).exists(), f"cwd does not exist: {module.config.cwd}"
            assert Path(module.config.cwd).is_dir()
        finally:
            module.stop()

    def test_executable_exists(self):
        module = self._make()
        try:
            binary = Path(module.config.executable)
            assert binary.exists(), f"Binary not found: {binary}. Run nix build first."
        finally:
            module.stop()

    def test_cwd_resolves_to_smartnav_root(self):
        """cwd should resolve to the smartnav root (where CMakeLists.txt lives)."""
        module = self._make()
        try:
            root = Path(module.config.cwd).resolve()
            assert (root / "CMakeLists.txt").exists(), f"cwd {root} is not the smartnav root"
            assert (root / "flake.nix").exists()
        finally:
            module.stop()
Global map assembly from corrected keyframes + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include "dimos_native_module.hpp" +#include "point_cloud_utils.hpp" + +#include "sensor_msgs/PointCloud2.hpp" +#include "nav_msgs/Odometry.hpp" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +using PointType = pcl::PointXYZI; +using CloudType = pcl::PointCloud; +using M3D = Eigen::Matrix3d; +using V3D = Eigen::Vector3d; +using M4F = Eigen::Matrix4f; + +// ─── Configuration ─────────────────────────────────────────────────────────── + +struct PGOConfig { + double key_pose_delta_trans = 0.5; + double key_pose_delta_deg = 10.0; + double loop_search_radius = 15.0; + double loop_time_thresh = 60.0; + double loop_score_thresh = 0.3; + int loop_submap_half_range = 5; + double submap_resolution = 0.1; + double min_loop_detect_duration = 5.0; + double global_map_publish_rate = 0.5; + double global_map_voxel_size = 0.15; + int max_icp_iterations = 50; + double max_icp_correspondence_dist = 10.0; +}; + +// ─── Keyframe storage ──────────────────────────────────────────────────────── + +struct KeyPoseWithCloud { + M3D r_local; + V3D t_local; + M3D r_global; + V3D t_global; + double time; + CloudType::Ptr body_cloud; +}; + +struct LoopPair { + size_t source_id; + size_t target_id; + M3D r_offset; + V3D t_offset; + double score; +}; + +// ─── SimplePGO core algorithm ──────────────────────────────────────────────── + +class SimplePGO { +public: + SimplePGO(const PGOConfig& config) : m_config(config) { + gtsam::ISAM2Params isam2_params; + isam2_params.relinearizeThreshold = 0.01; + isam2_params.relinearizeSkip = 1; + m_isam2 = std::make_shared(isam2_params); + m_initial_values.clear(); + m_graph.resize(0); + m_r_offset.setIdentity(); + m_t_offset.setZero(); + + 
m_icp.setMaximumIterations(config.max_icp_iterations); + m_icp.setMaxCorrespondenceDistance(config.max_icp_correspondence_dist); + m_icp.setTransformationEpsilon(1e-6); + m_icp.setEuclideanFitnessEpsilon(1e-6); + m_icp.setRANSACIterations(0); + } + + bool isKeyPose(const M3D& r, const V3D& t) { + if (m_key_poses.empty()) return true; + const auto& last = m_key_poses.back(); + double delta_trans = (t - last.t_local).norm(); + double delta_deg = Eigen::Quaterniond(r).angularDistance( + Eigen::Quaterniond(last.r_local)) * 180.0 / M_PI; + return (delta_trans > m_config.key_pose_delta_trans || + delta_deg > m_config.key_pose_delta_deg); + } + + bool addKeyPose(const M3D& r_local, const V3D& t_local, + double timestamp, CloudType::Ptr body_cloud) { + if (!isKeyPose(r_local, t_local)) return false; + + size_t idx = m_key_poses.size(); + M3D init_r = m_r_offset * r_local; + V3D init_t = m_r_offset * t_local + m_t_offset; + + // Add initial value + m_initial_values.insert(idx, gtsam::Pose3(gtsam::Rot3(init_r), gtsam::Point3(init_t))); + + if (idx == 0) { + // Prior factor on first pose + auto noise = gtsam::noiseModel::Diagonal::Variances( + gtsam::Vector6::Ones() * 1e-12); + m_graph.add(gtsam::PriorFactor( + idx, gtsam::Pose3(gtsam::Rot3(init_r), gtsam::Point3(init_t)), noise)); + } else { + // Odometry factor + const auto& last = m_key_poses.back(); + M3D r_between = last.r_local.transpose() * r_local; + V3D t_between = last.r_local.transpose() * (t_local - last.t_local); + auto noise = gtsam::noiseModel::Diagonal::Variances( + (gtsam::Vector(6) << 1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-6).finished()); + m_graph.add(gtsam::BetweenFactor( + idx - 1, idx, + gtsam::Pose3(gtsam::Rot3(r_between), gtsam::Point3(t_between)), + noise)); + } + + KeyPoseWithCloud item; + item.time = timestamp; + item.r_local = r_local; + item.t_local = t_local; + item.body_cloud = body_cloud; + item.r_global = init_r; + item.t_global = init_t; + m_key_poses.push_back(item); + return true; + } + + 
CloudType::Ptr getSubMap(int idx, int half_range, double resolution) { + int min_idx = std::max(0, idx - half_range); + int max_idx = std::min(static_cast(m_key_poses.size()) - 1, idx + half_range); + + CloudType::Ptr ret(new CloudType); + for (int i = min_idx; i <= max_idx; i++) { + CloudType::Ptr global_cloud(new CloudType); + pcl::transformPointCloud(*m_key_poses[i].body_cloud, *global_cloud, + m_key_poses[i].t_global.cast(), + Eigen::Quaternionf(m_key_poses[i].r_global.cast())); + *ret += *global_cloud; + } + if (resolution > 0 && ret->size() > 0) { + pcl::VoxelGrid voxel_grid; + voxel_grid.setLeafSize(resolution, resolution, resolution); + voxel_grid.setInputCloud(ret); + voxel_grid.filter(*ret); + } + return ret; + } + + void searchForLoopPairs() { + if (m_key_poses.size() < 10) return; + + // Rate-limit loop detection + if (m_config.min_loop_detect_duration > 0.0 && !m_history_pairs.empty()) { + double current_time = m_key_poses.back().time; + double last_time = m_key_poses[m_history_pairs.back().second].time; + if (current_time - last_time < m_config.min_loop_detect_duration) return; + } + + size_t cur_idx = m_key_poses.size() - 1; + const auto& last_item = m_key_poses.back(); + + // Build KD-tree of all previous keyframe positions + pcl::PointCloud::Ptr key_poses_cloud(new pcl::PointCloud); + for (size_t i = 0; i < m_key_poses.size() - 1; i++) { + pcl::PointXYZ pt; + pt.x = m_key_poses[i].t_global(0); + pt.y = m_key_poses[i].t_global(1); + pt.z = m_key_poses[i].t_global(2); + key_poses_cloud->push_back(pt); + } + + pcl::KdTreeFLANN kdtree; + kdtree.setInputCloud(key_poses_cloud); + + pcl::PointXYZ search_pt; + search_pt.x = last_item.t_global(0); + search_pt.y = last_item.t_global(1); + search_pt.z = last_item.t_global(2); + + std::vector ids; + std::vector sqdists; + int neighbors = kdtree.radiusSearch(search_pt, m_config.loop_search_radius, ids, sqdists); + if (neighbors == 0) return; + + // Find candidate far enough in time + int loop_idx = -1; + for 
(size_t i = 0; i < ids.size(); i++) { + int idx = ids[i]; + if (std::abs(last_item.time - m_key_poses[idx].time) > m_config.loop_time_thresh) { + loop_idx = idx; + break; + } + } + if (loop_idx == -1) return; + + // ICP verification + CloudType::Ptr target_cloud = getSubMap(loop_idx, m_config.loop_submap_half_range, + m_config.submap_resolution); + CloudType::Ptr source_cloud = getSubMap(m_key_poses.size() - 1, 0, + m_config.submap_resolution); + CloudType::Ptr align_cloud(new CloudType); + + m_icp.setInputSource(source_cloud); + m_icp.setInputTarget(target_cloud); + m_icp.align(*align_cloud); + + if (!m_icp.hasConverged() || m_icp.getFitnessScore() > m_config.loop_score_thresh) + return; + + M4F loop_transform = m_icp.getFinalTransformation(); + + LoopPair pair; + pair.source_id = cur_idx; + pair.target_id = loop_idx; + pair.score = m_icp.getFitnessScore(); + M3D r_refined = loop_transform.block<3,3>(0,0).cast() * m_key_poses[cur_idx].r_global; + V3D t_refined = loop_transform.block<3,3>(0,0).cast() * m_key_poses[cur_idx].t_global + + loop_transform.block<3,1>(0,3).cast(); + pair.r_offset = m_key_poses[loop_idx].r_global.transpose() * r_refined; + pair.t_offset = m_key_poses[loop_idx].r_global.transpose() * (t_refined - m_key_poses[loop_idx].t_global); + m_cache_pairs.push_back(pair); + m_history_pairs.emplace_back(pair.target_id, pair.source_id); + + printf("[PGO] Loop closure detected: %zu <-> %zu (score=%.4f)\n", + pair.target_id, pair.source_id, pair.score); + } + + void smoothAndUpdate() { + bool has_loop = !m_cache_pairs.empty(); + + // Add loop closure factors + if (has_loop) { + for (auto& pair : m_cache_pairs) { + m_graph.add(gtsam::BetweenFactor( + pair.target_id, pair.source_id, + gtsam::Pose3(gtsam::Rot3(pair.r_offset), gtsam::Point3(pair.t_offset)), + gtsam::noiseModel::Diagonal::Variances( + gtsam::Vector6::Ones() * pair.score))); + } + m_cache_pairs.clear(); + } + + // iSAM2 update + m_isam2->update(m_graph, m_initial_values); + m_isam2->update(); + 
if (has_loop) { + // Extra iterations for convergence after loop closure + m_isam2->update(); + m_isam2->update(); + m_isam2->update(); + m_isam2->update(); + } + m_graph.resize(0); + m_initial_values.clear(); + + // Update keyframe poses from optimized values + gtsam::Values estimates = m_isam2->calculateBestEstimate(); + for (size_t i = 0; i < m_key_poses.size(); i++) { + gtsam::Pose3 pose = estimates.at(i); + m_key_poses[i].r_global = pose.rotation().matrix(); + m_key_poses[i].t_global = pose.translation(); + } + + // Update offset for incoming poses + const auto& last = m_key_poses.back(); + m_r_offset = last.r_global * last.r_local.transpose(); + m_t_offset = last.t_global - m_r_offset * last.t_local; + } + + // Build global map from all corrected keyframes + CloudType::Ptr buildGlobalMap(double voxel_size) { + CloudType::Ptr global_map(new CloudType); + for (auto& kp : m_key_poses) { + CloudType::Ptr world_cloud(new CloudType); + pcl::transformPointCloud(*kp.body_cloud, *world_cloud, + kp.t_global.cast(), + Eigen::Quaternionf(kp.r_global.cast())); + *global_map += *world_cloud; + } + if (voxel_size > 0 && global_map->size() > 0) { + pcl::VoxelGrid voxel; + voxel.setLeafSize(voxel_size, voxel_size, voxel_size); + voxel.setInputCloud(global_map); + voxel.filter(*global_map); + } + return global_map; + } + + // Accessors + const std::vector& keyPoses() const { return m_key_poses; } + size_t numKeyPoses() const { return m_key_poses.size(); } + M3D offsetR() const { return m_r_offset; } + V3D offsetT() const { return m_t_offset; } + + // Get corrected pose for current local pose + void getCorrectedPose(const M3D& r_local, const V3D& t_local, + M3D& r_corrected, V3D& t_corrected) const { + r_corrected = m_r_offset * r_local; + t_corrected = m_r_offset * t_local + m_t_offset; + } + +private: + PGOConfig m_config; + std::vector m_key_poses; + std::vector> m_history_pairs; + std::vector m_cache_pairs; + M3D m_r_offset; + V3D m_t_offset; + std::shared_ptr m_isam2; + 
gtsam::Values m_initial_values; + gtsam::NonlinearFactorGraph m_graph; + pcl::IterativeClosestPoint m_icp; +}; + +// ─── LCM Handler ───────────────────────────────────────────────────────────── + +static std::atomic g_running{true}; +void signal_handler(int) { g_running = false; } + +struct PGOHandler { + lcm::LCM* lcm; + SimplePGO* pgo; + std::string topic_corrected_odom; + std::string topic_global_map; + PGOConfig config; + + std::mutex mtx; + M3D latest_r = M3D::Identity(); + V3D latest_t = V3D::Zero(); + double latest_time = 0.0; + bool has_odom = false; + + // Global map publishing state + double last_global_map_time = 0.0; + + void onOdometry(const lcm::ReceiveBuffer*, const std::string&, + const nav_msgs::Odometry* msg) { + std::lock_guard lock(mtx); + latest_t = V3D(msg->pose.pose.position.x, + msg->pose.pose.position.y, + msg->pose.pose.position.z); + Eigen::Quaterniond q(msg->pose.pose.orientation.w, + msg->pose.pose.orientation.x, + msg->pose.pose.orientation.y, + msg->pose.pose.orientation.z); + latest_r = q.toRotationMatrix(); + latest_time = msg->header.stamp.sec + msg->header.stamp.nsec / 1e9; + has_odom = true; + } + + void onRegisteredScan(const lcm::ReceiveBuffer*, const std::string&, + const sensor_msgs::PointCloud2* msg) { + std::lock_guard lock(mtx); + if (!has_odom) return; + + double scan_time = msg->header.stamp.sec + msg->header.stamp.nsec / 1e9; + + // Convert PointCloud2 to PCL (body frame) + CloudType::Ptr body_cloud(new CloudType); + smartnav::to_pcl(*msg, *body_cloud); + + if (body_cloud->empty()) return; + + // Downsample body cloud for storage + if (config.submap_resolution > 0) { + pcl::VoxelGrid voxel; + voxel.setLeafSize(config.submap_resolution, config.submap_resolution, + config.submap_resolution); + voxel.setInputCloud(body_cloud); + voxel.filter(*body_cloud); + } + + // Try to add as keyframe + bool added = pgo->addKeyPose(latest_r, latest_t, latest_time, body_cloud); + + if (added) { + pgo->searchForLoopPairs(); + 
pgo->smoothAndUpdate(); + printf("[PGO] Keyframe %zu added (%.1f, %.1f, %.1f)\n", + pgo->numKeyPoses(), latest_t(0), latest_t(1), latest_t(2)); + } + + // Publish corrected odometry + publishCorrectedOdometry(scan_time); + + // Publish global map at configured rate + double now = std::chrono::duration( + std::chrono::steady_clock::now().time_since_epoch()).count(); + double interval = (config.global_map_publish_rate > 0) ? + 1.0 / config.global_map_publish_rate : 2.0; + if (now - last_global_map_time > interval) { + publishGlobalMap(scan_time); + last_global_map_time = now; + } + } + + void publishCorrectedOdometry(double timestamp) { + M3D r_corrected; + V3D t_corrected; + pgo->getCorrectedPose(latest_r, latest_t, r_corrected, t_corrected); + + Eigen::Quaterniond q(r_corrected); + + nav_msgs::Odometry odom; + odom.header = dimos::make_header("map", timestamp); + odom.child_frame_id = "sensor"; + odom.pose.pose.position.x = t_corrected(0); + odom.pose.pose.position.y = t_corrected(1); + odom.pose.pose.position.z = t_corrected(2); + odom.pose.pose.orientation.x = q.x(); + odom.pose.pose.orientation.y = q.y(); + odom.pose.pose.orientation.z = q.z(); + odom.pose.pose.orientation.w = q.w(); + + lcm->publish(topic_corrected_odom, &odom); + } + + void publishGlobalMap(double timestamp) { + if (pgo->numKeyPoses() == 0) return; + + CloudType::Ptr global_map = pgo->buildGlobalMap(config.global_map_voxel_size); + + sensor_msgs::PointCloud2 pc = smartnav::from_pcl(*global_map, "map", timestamp); + lcm->publish(topic_global_map, &pc); + + printf("[PGO] Global map published: %zu points, %zu keyframes\n", + global_map->size(), pgo->numKeyPoses()); + } +}; + +// ─── Main ──────────────────────────────────────────────────────────────────── + +int main(int argc, char** argv) { + signal(SIGINT, signal_handler); + signal(SIGTERM, signal_handler); + + dimos::NativeModule mod(argc, argv); + + // Read config from CLI args + PGOConfig config; + config.key_pose_delta_trans = 
mod.arg_float("keyPoseDeltaTrans", 0.5f); + config.key_pose_delta_deg = mod.arg_float("keyPoseDeltaDeg", 10.0f); + config.loop_search_radius = mod.arg_float("loopSearchRadius", 15.0f); + config.loop_time_thresh = mod.arg_float("loopTimeThresh", 60.0f); + config.loop_score_thresh = mod.arg_float("loopScoreThresh", 0.3f); + config.loop_submap_half_range = mod.arg_int("loopSubmapHalfRange", 5); + config.submap_resolution = mod.arg_float("submapResolution", 0.1f); + config.min_loop_detect_duration = mod.arg_float("minLoopDetectDuration", 5.0f); + config.global_map_publish_rate = mod.arg_float("globalMapPublishRate", 0.5f); + config.global_map_voxel_size = mod.arg_float("globalMapVoxelSize", 0.15f); + config.max_icp_iterations = mod.arg_int("maxIcpIterations", 50); + config.max_icp_correspondence_dist = mod.arg_float("maxIcpCorrespondenceDist", 10.0f); + + printf("[PGO] Config: keyPoseDeltaTrans=%.2f keyPoseDeltaDeg=%.1f " + "loopSearchRadius=%.1f loopTimeThresh=%.1f loopScoreThresh=%.2f " + "globalMapVoxelSize=%.2f\n", + config.key_pose_delta_trans, config.key_pose_delta_deg, + config.loop_search_radius, config.loop_time_thresh, + config.loop_score_thresh, config.global_map_voxel_size); + + // Create PGO instance + SimplePGO pgo(config); + + // LCM setup + lcm::LCM lcm; + if (!lcm.good()) { + fprintf(stderr, "[PGO] LCM initialization failed\n"); + return 1; + } + + PGOHandler handler; + handler.lcm = &lcm; + handler.pgo = &pgo; + handler.topic_corrected_odom = mod.topic("corrected_odometry"); + handler.topic_global_map = mod.topic("global_map"); + handler.config = config; + + std::string topic_scan = mod.topic("registered_scan"); + std::string topic_odom = mod.topic("odometry"); + + lcm.subscribe(topic_odom, &PGOHandler::onOdometry, &handler); + lcm.subscribe(topic_scan, &PGOHandler::onRegisteredScan, &handler); + + printf("[PGO] Listening on: registered_scan=%s odometry=%s\n", + topic_scan.c_str(), topic_odom.c_str()); + printf("[PGO] Publishing: corrected_odometry=%s 
global_map=%s\n", + handler.topic_corrected_odom.c_str(), handler.topic_global_map.c_str()); + + while (g_running) { + lcm.handleTimeout(100); + } + + printf("[PGO] Shutting down. Total keyframes: %zu\n", pgo.numKeyPoses()); + return 0; +} diff --git a/dimos/navigation/smartnav/modules/pgo/pgo.py b/dimos/navigation/smartnav/modules/pgo/pgo.py new file mode 100644 index 0000000000..4ed494cab3 --- /dev/null +++ b/dimos/navigation/smartnav/modules/pgo/pgo.py @@ -0,0 +1,513 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PGO Module: Python pose graph optimization with loop closure. + +Ported from FASTLIO2_ROS2/pgo. Detects keyframes, performs loop closure +via ICP + KD-tree search, and optimizes the pose graph with GTSAM iSAM2. +Publishes corrected odometry and accumulated global map. + +Falls back from native C++ to pure Python when the native binary cannot +be built (e.g. missing GTSAM in nixpkgs). 
+""" + +from __future__ import annotations + +from dataclasses import dataclass +import threading +import time +from typing import Any + +import gtsam +import numpy as np +from scipy.spatial import KDTree + +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import In, Out +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 + + +class PGOConfig(ModuleConfig): + """Config for the PGO Python module.""" + + # Keyframe detection + key_pose_delta_trans: float = 0.5 + key_pose_delta_deg: float = 10.0 + + # Loop closure + loop_search_radius: float = 15.0 + loop_time_thresh: float = 60.0 + loop_score_thresh: float = 0.3 + loop_submap_half_range: int = 5 + submap_resolution: float = 0.1 + min_loop_detect_duration: float = 5.0 + + # Input mode + unregister_input: bool = True # Transform world-frame scans to body-frame using odom + + # Global map + global_map_publish_rate: float = 0.5 + global_map_voxel_size: float = 0.15 + + # ICP + max_icp_iterations: int = 50 + max_icp_correspondence_dist: float = 10.0 + + +# ─── Keyframe storage ──────────────────────────────────────────────────────── + + +@dataclass +class _KeyPose: + r_local: np.ndarray # 3x3 rotation in local/odom frame + t_local: np.ndarray # 3-vec translation in local/odom frame + r_global: np.ndarray # 3x3 corrected rotation + t_global: np.ndarray # 3-vec corrected translation + timestamp: float + body_cloud: np.ndarray # Nx3 points in body frame + + +# ─── Simple ICP (point-to-point, no PCL dependency) ───────────────────────── + + +def _icp( + source: np.ndarray, + target: np.ndarray, + max_iter: int = 50, + max_dist: float = 10.0, + tol: float = 1e-6, +) -> tuple[np.ndarray, float]: + """Simple point-to-point ICP. 
Returns (4x4 transform, fitness score).""" + if len(source) == 0 or len(target) == 0: + return np.eye(4), float("inf") + + tree = KDTree(target) + T = np.eye(4) + src = source.copy() + + for _ in range(max_iter): + dists, idxs = tree.query(src) + mask = dists < max_dist + if mask.sum() < 10: + return T, float("inf") + + p = src[mask] + q = target[idxs[mask]] + + cp = p.mean(axis=0) + cq = q.mean(axis=0) + H = (p - cp).T @ (q - cq) + + U, _, Vt = np.linalg.svd(H) + R = Vt.T @ U.T + if np.linalg.det(R) < 0: + Vt[-1, :] *= -1 + R = Vt.T @ U.T + t = cq - R @ cp + + dT = np.eye(4) + dT[:3, :3] = R + dT[:3, 3] = t + T = dT @ T + src = (R @ src.T).T + t + + if np.linalg.norm(t) < tol: + break + + # Fitness: mean squared distance of inliers + dists_final, _ = tree.query(src) + mask = dists_final < max_dist + fitness = float(np.mean(dists_final[mask] ** 2)) if mask.sum() > 0 else float("inf") + return T, fitness + + +def _voxel_downsample(pts: np.ndarray, voxel_size: float) -> np.ndarray: + """Voxel grid downsampling.""" + if len(pts) == 0 or voxel_size <= 0: + return pts + keys = np.floor(pts / voxel_size).astype(np.int32) + _, idx = np.unique(keys, axis=0, return_index=True) + return pts[idx] + + +# ─── SimplePGO core algorithm ──────────────────────────────────────────────── + + +class _SimplePGO: + """Python port of the C++ SimplePGO class.""" + + def __init__(self, config: PGOConfig) -> None: + self._cfg = config + self._key_poses: list[_KeyPose] = [] + self._history_pairs: list[tuple[int, int]] = [] + self._cache_pairs: list[dict] = [] + self._r_offset = np.eye(3) + self._t_offset = np.zeros(3) + + params = gtsam.ISAM2Params() + params.setRelinearizeThreshold(0.01) + params.relinearizeSkip = 1 + self._isam2 = gtsam.ISAM2(params) + self._graph = gtsam.NonlinearFactorGraph() + self._values = gtsam.Values() + + def is_key_pose(self, r: np.ndarray, t: np.ndarray) -> bool: + if not self._key_poses: + return True + last = self._key_poses[-1] + delta_trans = np.linalg.norm(t 
- last.t_local) + # Angular distance via quaternion dot product + from scipy.spatial.transform import Rotation + + q_cur = Rotation.from_matrix(r).as_quat() # [x,y,z,w] + q_last = Rotation.from_matrix(last.r_local).as_quat() + dot = abs(np.dot(q_cur, q_last)) + delta_deg = np.degrees(2.0 * np.arccos(min(dot, 1.0))) + return ( + delta_trans > self._cfg.key_pose_delta_trans or delta_deg > self._cfg.key_pose_delta_deg + ) + + def add_key_pose( + self, r_local: np.ndarray, t_local: np.ndarray, timestamp: float, body_cloud: np.ndarray + ) -> bool: + if not self.is_key_pose(r_local, t_local): + return False + + idx = len(self._key_poses) + init_r = self._r_offset @ r_local + init_t = self._r_offset @ t_local + self._t_offset + + pose = gtsam.Pose3(gtsam.Rot3(init_r), gtsam.Point3(init_t)) + self._values.insert(idx, pose) + + if idx == 0: + noise = gtsam.noiseModel.Diagonal.Variances(np.full(6, 1e-12)) + self._graph.add(gtsam.PriorFactorPose3(idx, pose, noise)) + else: + last = self._key_poses[-1] + r_between = last.r_local.T @ r_local + t_between = last.r_local.T @ (t_local - last.t_local) + noise = gtsam.noiseModel.Diagonal.Variances( + np.array([1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-6]) + ) + self._graph.add( + gtsam.BetweenFactorPose3( + idx - 1, idx, gtsam.Pose3(gtsam.Rot3(r_between), gtsam.Point3(t_between)), noise + ) + ) + + kp = _KeyPose( + r_local=r_local.copy(), + t_local=t_local.copy(), + r_global=init_r.copy(), + t_global=init_t.copy(), + timestamp=timestamp, + body_cloud=_voxel_downsample(body_cloud, self._cfg.submap_resolution), + ) + self._key_poses.append(kp) + return True + + def _get_submap(self, idx: int, half_range: int) -> np.ndarray: + lo = max(0, idx - half_range) + hi = min(len(self._key_poses) - 1, idx + half_range) + parts = [] + for i in range(lo, hi + 1): + kp = self._key_poses[i] + world = (kp.r_global @ kp.body_cloud.T).T + kp.t_global + parts.append(world) + if not parts: + return np.empty((0, 3)) + cloud = np.vstack(parts) + return 
_voxel_downsample(cloud, self._cfg.submap_resolution) + + def search_for_loops(self) -> None: + if len(self._key_poses) < 10: + return + + # Rate limit + if self._history_pairs: + cur_time = self._key_poses[-1].timestamp + last_time = self._key_poses[self._history_pairs[-1][1]].timestamp + if cur_time - last_time < self._cfg.min_loop_detect_duration: + return + + cur_idx = len(self._key_poses) - 1 + cur_kp = self._key_poses[-1] + + # Build KD-tree of previous keyframe positions + positions = np.array([kp.t_global for kp in self._key_poses[:-1]]) + tree = KDTree(positions) + + idxs = tree.query_ball_point(cur_kp.t_global, self._cfg.loop_search_radius) + if not idxs: + return + + # Find candidate far enough in time + loop_idx = -1 + for i in idxs: + if abs(cur_kp.timestamp - self._key_poses[i].timestamp) > self._cfg.loop_time_thresh: + loop_idx = i + break + if loop_idx == -1: + return + + # ICP verification + target = self._get_submap(loop_idx, self._cfg.loop_submap_half_range) + source = self._get_submap(cur_idx, 0) + + transform, fitness = _icp( + source, + target, + max_iter=self._cfg.max_icp_iterations, + max_dist=self._cfg.max_icp_correspondence_dist, + ) + if fitness > self._cfg.loop_score_thresh: + return + + # Compute relative pose + R_icp = transform[:3, :3] + t_icp = transform[:3, 3] + r_refined = R_icp @ cur_kp.r_global + t_refined = R_icp @ cur_kp.t_global + t_icp + r_offset = self._key_poses[loop_idx].r_global.T @ r_refined + t_offset = self._key_poses[loop_idx].r_global.T @ ( + t_refined - self._key_poses[loop_idx].t_global + ) + + self._cache_pairs.append( + { + "source": cur_idx, + "target": loop_idx, + "r_offset": r_offset, + "t_offset": t_offset, + "score": fitness, + } + ) + self._history_pairs.append((loop_idx, cur_idx)) + print(f"[PGO] Loop closure detected: {loop_idx} <-> {cur_idx} (score={fitness:.4f})") + + def smooth_and_update(self) -> None: + has_loop = bool(self._cache_pairs) + + for pair in self._cache_pairs: + noise = 
gtsam.noiseModel.Diagonal.Variances(np.full(6, pair["score"])) + self._graph.add( + gtsam.BetweenFactorPose3( + pair["target"], + pair["source"], + gtsam.Pose3(gtsam.Rot3(pair["r_offset"]), gtsam.Point3(pair["t_offset"])), + noise, + ) + ) + self._cache_pairs.clear() + + self._isam2.update(self._graph, self._values) + self._isam2.update() + if has_loop: + for _ in range(4): + self._isam2.update() + self._graph = gtsam.NonlinearFactorGraph() + self._values = gtsam.Values() + + estimates = self._isam2.calculateBestEstimate() + for i in range(len(self._key_poses)): + pose = estimates.atPose3(i) + self._key_poses[i].r_global = pose.rotation().matrix() + self._key_poses[i].t_global = pose.translation() + + last = self._key_poses[-1] + self._r_offset = last.r_global @ last.r_local.T + self._t_offset = last.t_global - self._r_offset @ last.t_local + + def get_corrected_pose( + self, r_local: np.ndarray, t_local: np.ndarray + ) -> tuple[np.ndarray, np.ndarray]: + return self._r_offset @ r_local, self._r_offset @ t_local + self._t_offset + + def build_global_map(self, voxel_size: float) -> np.ndarray: + if not self._key_poses: + return np.empty((0, 3), dtype=np.float32) + parts = [] + for kp in self._key_poses: + world = (kp.r_global @ kp.body_cloud.T).T + kp.t_global + parts.append(world) + cloud = np.vstack(parts).astype(np.float32) + return _voxel_downsample(cloud, voxel_size) + + @property + def num_key_poses(self) -> int: + return len(self._key_poses) + + +# ─── PGO Module ────────────────────────────────────────────────────────────── + + +class PGO(Module[PGOConfig]): + """Pose graph optimization with loop closure detection. + + Pure-Python implementation using GTSAM iSAM2 and scipy KDTree. + Detects keyframes from odometry, searches for loop closures, + optimizes with iSAM2, and publishes corrected poses + global map. + + Ports: + registered_scan (In[PointCloud2]): World-frame registered point cloud. + odometry (In[Odometry]): Current pose estimate from SLAM. 
+ corrected_odometry (Out[Odometry]): Loop-closure-corrected pose. + global_map (Out[PointCloud2]): Accumulated keyframe map. + """ + + default_config = PGOConfig + + registered_scan: In[PointCloud2] + odometry: In[Odometry] + corrected_odometry: Out[Odometry] + global_map: Out[PointCloud2] + + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(**kwargs) + self._lock = threading.Lock() + self._running = False + self._thread: threading.Thread | None = None + self._pgo: _SimplePGO | None = None + # Latest odom + self._latest_r = np.eye(3) + self._latest_t = np.zeros(3) + self._latest_time = 0.0 + self._has_odom = False + self._last_global_map_time = 0.0 + + def __getstate__(self) -> dict[str, Any]: + state = super().__getstate__() + for k in ("_lock", "_thread", "_pgo"): + state.pop(k, None) + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + super().__setstate__(state) + self._lock = threading.Lock() + self._thread = None + self._pgo = None + + def start(self) -> None: + self._pgo = _SimplePGO(self.config) + self.odometry._transport.subscribe(self._on_odom) + self.registered_scan._transport.subscribe(self._on_scan) + self._running = True + self._thread = threading.Thread(target=self._publish_loop, daemon=True) + self._thread.start() + print("[PGO] Python PGO module started (gtsam iSAM2)") + + def stop(self) -> None: + self._running = False + if self._thread: + self._thread.join(timeout=3.0) + super().stop() + + def _on_odom(self, msg: Odometry) -> None: + from scipy.spatial.transform import Rotation + + q = [ + msg.pose.orientation.x, + msg.pose.orientation.y, + msg.pose.orientation.z, + msg.pose.orientation.w, + ] + r = Rotation.from_quat(q).as_matrix() + t = np.array([msg.pose.position.x, msg.pose.position.y, msg.pose.position.z]) + with self._lock: + self._latest_r = r + self._latest_t = t + self._latest_time = msg.ts if msg.ts else time.time() + self._has_odom = True + + def _on_scan(self, cloud: 
PointCloud2) -> None: + points, _ = cloud.as_numpy() + if len(points) == 0: + return + + with self._lock: + if not self._has_odom: + return + r_local = self._latest_r.copy() + t_local = self._latest_t.copy() + ts = self._latest_time + + pgo = self._pgo + assert pgo is not None + + # Body-frame points + if self.config.unregister_input: + # registered_scan is world-frame, transform back to body-frame + body_pts = (r_local.T @ (points[:, :3].T - t_local[:, None])).T + else: + body_pts = points[:, :3] + + added = pgo.add_key_pose(r_local, t_local, ts, body_pts) + if added: + pgo.search_for_loops() + pgo.smooth_and_update() + print( + f"[PGO] Keyframe {pgo.num_key_poses} added " + f"({t_local[0]:.1f}, {t_local[1]:.1f}, {t_local[2]:.1f})" + ) + + # Publish corrected odometry + r_corr, t_corr = pgo.get_corrected_pose(r_local, t_local) + self._publish_corrected_odom(r_corr, t_corr, ts) + + def _publish_corrected_odom(self, r: np.ndarray, t: np.ndarray, ts: float) -> None: + from scipy.spatial.transform import Rotation as R + + from dimos.msgs.geometry_msgs.Pose import Pose + + q = R.from_matrix(r).as_quat() # [x,y,z,w] + + odom = Odometry( + ts=ts, + frame_id="map", + child_frame_id="sensor", + pose=Pose( + position=[float(t[0]), float(t[1]), float(t[2])], + orientation=[float(q[0]), float(q[1]), float(q[2]), float(q[3])], + ), + ) + self.corrected_odometry._transport.publish(odom) + + def _publish_loop(self) -> None: + """Periodically publish global map.""" + pgo = self._pgo + assert pgo is not None + rate = self.config.global_map_publish_rate + interval = 1.0 / rate if rate > 0 else 2.0 + + while self._running: + t0 = time.monotonic() + now = time.time() + + if now - self._last_global_map_time > interval and pgo.num_key_poses > 0: + cloud_np = pgo.build_global_map(self.config.global_map_voxel_size) + if len(cloud_np) > 0: + self.global_map._transport.publish( + PointCloud2.from_numpy(cloud_np, frame_id="map", timestamp=now) + ) + print( + f"[PGO] Global map published: 
{len(cloud_np)} points, " + f"{pgo.num_key_poses} keyframes" + ) + self._last_global_map_time = now + + elapsed = time.monotonic() - t0 + sleep_time = max(0.1, interval - elapsed) + time.sleep(sleep_time) diff --git a/dimos/navigation/smartnav/modules/pgo/pgo_reference.py b/dimos/navigation/smartnav/modules/pgo/pgo_reference.py new file mode 100644 index 0000000000..dd9d6fb7dd --- /dev/null +++ b/dimos/navigation/smartnav/modules/pgo/pgo_reference.py @@ -0,0 +1,359 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pure-Python reference implementation of the PGO algorithm. + +Uses scipy KDTree for neighbor search, open3d for ICP, and gtsam Python +bindings for pose graph optimization. Tests the LOGIC independent of the +C++ binary. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass + +import gtsam +import numpy as np +import open3d as o3d +from scipy.spatial import KDTree + + +@dataclass +class PGOConfig: + """PGO algorithm configuration.""" + + key_pose_delta_trans: float = 0.5 + key_pose_delta_deg: float = 10.0 + loop_search_radius: float = 15.0 + loop_time_thresh: float = 60.0 + loop_score_thresh: float = 0.3 + loop_submap_half_range: int = 5 + submap_resolution: float = 0.1 + min_loop_detect_duration: float = 5.0 + global_map_voxel_size: float = 0.15 + max_icp_iterations: int = 50 + max_icp_correspondence_dist: float = 10.0 + + +@dataclass +class KeyPose: + """Stored keyframe with local and global poses.""" + + r_local: np.ndarray # 3x3 rotation matrix (local/odometry frame) + t_local: np.ndarray # 3 translation vector (local/odometry frame) + r_global: np.ndarray # 3x3 rotation matrix (optimized global frame) + t_global: np.ndarray # 3 translation vector (optimized global frame) + time: float # timestamp + body_cloud: np.ndarray # Nx3 point cloud in body frame + + +@dataclass +class LoopPair: + """Detected loop closure between two keyframes.""" + + source_id: int + target_id: int + r_offset: np.ndarray # 3x3 relative rotation + t_offset: np.ndarray # 3 relative translation + score: float + + +def _rotation_to_quat(R: np.ndarray) -> np.ndarray: + """Convert 3x3 rotation matrix to quaternion [x,y,z,w].""" + from scipy.spatial.transform import Rotation + + return Rotation.from_matrix(R).as_quat() # [x,y,z,w] + + +def _angular_distance_deg(R1: np.ndarray, R2: np.ndarray) -> float: + """Compute angular distance in degrees between two rotation matrices.""" + R_diff = R1.T @ R2 + # Clamp to avoid numerical issues with arccos + trace = np.clip((np.trace(R_diff) - 1.0) / 2.0, -1.0, 1.0) + return np.degrees(np.arccos(trace)) + + +def _voxel_downsample(points: np.ndarray, voxel_size: float) -> np.ndarray: + """Voxel-grid downsample an Nx3 point cloud.""" + if 
len(points) == 0 or voxel_size <= 0: + return points + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points.astype(np.float64)) + pcd = pcd.voxel_down_sample(voxel_size) + return np.asarray(pcd.points) + + +class SimplePGOReference: + """Pure-Python reference implementation of SimplePGO. + + Mirrors the C++ SimplePGO class for testing purposes. + """ + + def __init__(self, config: PGOConfig | None = None) -> None: + self.config = config or PGOConfig() + self.key_poses: list[KeyPose] = [] + self.history_pairs: list[tuple[int, int]] = [] + self._cache_pairs: list[LoopPair] = [] + self._r_offset = np.eye(3) + self._t_offset = np.zeros(3) + + # GTSAM iSAM2 + params = gtsam.ISAM2Params() + params.setRelinearizeThreshold(0.01) + params.relinearizeSkip = 1 + self._isam2 = gtsam.ISAM2(params) + self._graph = gtsam.NonlinearFactorGraph() + self._initial_values = gtsam.Values() + + def is_key_pose(self, r: np.ndarray, t: np.ndarray) -> bool: + """Check if a pose qualifies as a new keyframe.""" + if len(self.key_poses) == 0: + return True + last = self.key_poses[-1] + delta_trans = np.linalg.norm(t - last.t_local) + delta_deg = _angular_distance_deg(last.r_local, r) + return ( + delta_trans > self.config.key_pose_delta_trans + or delta_deg > self.config.key_pose_delta_deg + ) + + def add_key_pose( + self, r_local: np.ndarray, t_local: np.ndarray, timestamp: float, body_cloud: np.ndarray + ) -> bool: + """Add a keyframe if it passes the keyframe test. 
Returns True if added.""" + if not self.is_key_pose(r_local, t_local): + return False + + idx = len(self.key_poses) + init_r = self._r_offset @ r_local + init_t = self._r_offset @ t_local + self._t_offset + + # Add initial value to GTSAM + pose = gtsam.Pose3(gtsam.Rot3(init_r), gtsam.Point3(init_t)) + self._initial_values.insert(idx, pose) + + if idx == 0: + # Prior factor + noise = gtsam.noiseModel.Diagonal.Variances(np.ones(6) * 1e-12) + self._graph.addPriorPose3(idx, pose, noise) + else: + # Odometry (between) factor + last = self.key_poses[-1] + r_between = last.r_local.T @ r_local + t_between = last.r_local.T @ (t_local - last.t_local) + noise = gtsam.noiseModel.Diagonal.Variances( + np.array([1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-6]) + ) + delta = gtsam.Pose3(gtsam.Rot3(r_between), gtsam.Point3(t_between)) + self._graph.add(gtsam.BetweenFactorPose3(idx - 1, idx, delta, noise)) + + kp = KeyPose( + r_local=r_local.copy(), + t_local=t_local.copy(), + r_global=init_r.copy(), + t_global=init_t.copy(), + time=timestamp, + body_cloud=body_cloud.copy() if len(body_cloud) > 0 else body_cloud, + ) + self.key_poses.append(kp) + return True + + def get_submap(self, idx: int, half_range: int, resolution: float) -> np.ndarray: + """Build a submap around a keyframe by transforming nearby body clouds.""" + min_idx = max(0, idx - half_range) + max_idx = min(len(self.key_poses) - 1, idx + half_range) + + all_pts = [] + for i in range(min_idx, max_idx + 1): + kp = self.key_poses[i] + if len(kp.body_cloud) == 0: + continue + # Transform body cloud to global frame + global_pts = (kp.r_global @ kp.body_cloud.T).T + kp.t_global + all_pts.append(global_pts) + + if not all_pts: + return np.zeros((0, 3)) + combined = np.vstack(all_pts) + if resolution > 0: + combined = _voxel_downsample(combined, resolution) + return combined + + def search_for_loop_pairs(self) -> None: + """Search for loop closure candidates using KD-tree radius search + ICP.""" + if len(self.key_poses) < 10: + return + + 
# Rate limiting + if self.config.min_loop_detect_duration > 0.0 and self.history_pairs: + current_time = self.key_poses[-1].time + last_time = self.key_poses[self.history_pairs[-1][1]].time + if current_time - last_time < self.config.min_loop_detect_duration: + return + + cur_idx = len(self.key_poses) - 1 + last_item = self.key_poses[-1] + + # Build KD-tree from all previous keyframe positions + positions = np.array([kp.t_global for kp in self.key_poses[:-1]]) + kdtree = KDTree(positions) + + # Radius search + indices = kdtree.query_ball_point(last_item.t_global, self.config.loop_search_radius) + if not indices: + return + + # Sort by distance + dists = [np.linalg.norm(last_item.t_global - positions[i]) for i in indices] + sorted_indices = [indices[i] for i in np.argsort(dists)] + + # Find candidate far enough in time + loop_idx = -1 + for idx in sorted_indices: + if abs(last_item.time - self.key_poses[idx].time) > self.config.loop_time_thresh: + loop_idx = idx + break + + if loop_idx == -1: + return + + # ICP verification + target_cloud = self.get_submap( + loop_idx, self.config.loop_submap_half_range, self.config.submap_resolution + ) + source_cloud = self.get_submap(cur_idx, 0, self.config.submap_resolution) + + if len(target_cloud) < 10 or len(source_cloud) < 10: + return + + transform, score = self._run_icp(source_cloud, target_cloud) + if score > self.config.loop_score_thresh: + return + + # Compute loop closure constraint + r_transform = transform[:3, :3] + t_transform = transform[:3, 3] + r_refined = r_transform @ self.key_poses[cur_idx].r_global + t_refined = r_transform @ self.key_poses[cur_idx].t_global + t_transform + r_offset = self.key_poses[loop_idx].r_global.T @ r_refined + t_offset = self.key_poses[loop_idx].r_global.T @ ( + t_refined - self.key_poses[loop_idx].t_global + ) + + pair = LoopPair( + source_id=cur_idx, + target_id=loop_idx, + r_offset=r_offset, + t_offset=t_offset, + score=score, + ) + self._cache_pairs.append(pair) + 
self.history_pairs.append((loop_idx, cur_idx)) + + def _run_icp(self, source: np.ndarray, target: np.ndarray) -> tuple[np.ndarray, float]: + """Run ICP between source and target point clouds. + + Returns (4x4 transform, fitness score). + """ + src_pcd = o3d.geometry.PointCloud() + src_pcd.points = o3d.utility.Vector3dVector(source.astype(np.float64)) + tgt_pcd = o3d.geometry.PointCloud() + tgt_pcd.points = o3d.utility.Vector3dVector(target.astype(np.float64)) + + result = o3d.pipelines.registration.registration_icp( + src_pcd, + tgt_pcd, + max_correspondence_distance=self.config.max_icp_correspondence_dist, + init=np.eye(4), + estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint(), + criteria=o3d.pipelines.registration.ICPConvergenceCriteria( + max_iteration=self.config.max_icp_iterations, + ), + ) + # Reject matches with zero/near-zero correspondences (fitness=0 means + # no points were within max_correspondence_distance). In this case + # inlier_rmse is 0.0 which would incorrectly pass the score threshold. 
+ if result.fitness < 0.05: + return result.transformation, float("inf") + return result.transformation, result.inlier_rmse + + def smooth_and_update(self) -> None: + """Run iSAM2 optimization and update keyframe poses.""" + has_loop = len(self._cache_pairs) > 0 + + # Add loop closure factors + if has_loop: + for pair in self._cache_pairs: + noise = gtsam.noiseModel.Diagonal.Variances(np.ones(6) * pair.score) + delta = gtsam.Pose3(gtsam.Rot3(pair.r_offset), gtsam.Point3(pair.t_offset)) + self._graph.add( + gtsam.BetweenFactorPose3(pair.target_id, pair.source_id, delta, noise) + ) + self._cache_pairs.clear() + + # iSAM2 update + self._isam2.update(self._graph, self._initial_values) + self._isam2.update() + if has_loop: + for _ in range(4): + self._isam2.update() + self._graph = gtsam.NonlinearFactorGraph() + self._initial_values = gtsam.Values() + + # Update keyframe poses from estimates + estimates = self._isam2.calculateBestEstimate() + for i in range(len(self.key_poses)): + pose = estimates.atPose3(i) + self.key_poses[i].r_global = pose.rotation().matrix() + self.key_poses[i].t_global = pose.translation() + + # Update offset + last = self.key_poses[-1] + self._r_offset = last.r_global @ last.r_local.T + self._t_offset = last.t_global - self._r_offset @ last.t_local + + def get_corrected_pose( + self, r_local: np.ndarray, t_local: np.ndarray + ) -> tuple[np.ndarray, np.ndarray]: + """Get corrected pose for a local pose.""" + r_corrected = self._r_offset @ r_local + t_corrected = self._r_offset @ t_local + self._t_offset + return r_corrected, t_corrected + + def build_global_map(self, voxel_size: float | None = None) -> np.ndarray: + """Build global map from all corrected keyframes.""" + if voxel_size is None: + voxel_size = self.config.global_map_voxel_size + + all_pts = [] + for kp in self.key_poses: + if len(kp.body_cloud) == 0: + continue + global_pts = (kp.r_global @ kp.body_cloud.T).T + kp.t_global + all_pts.append(global_pts) + + if not all_pts: + return 
np.zeros((0, 3)) + combined = np.vstack(all_pts) + if voxel_size > 0: + combined = _voxel_downsample(combined, voxel_size) + return combined + + @property + def r_offset(self) -> np.ndarray: + return self._r_offset + + @property + def t_offset(self) -> np.ndarray: + return self._t_offset diff --git a/dimos/navigation/smartnav/modules/pgo/test_pgo.py b/dimos/navigation/smartnav/modules/pgo/test_pgo.py new file mode 100644 index 0000000000..2ebe6c8f1a --- /dev/null +++ b/dimos/navigation/smartnav/modules/pgo/test_pgo.py @@ -0,0 +1,561 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for PGO (Pose Graph Optimization) module. 
+ +Tests the Python reference implementation of the PGO algorithm, covering: +- Keyframe detection +- Loop closure detection and correction +- Global map accumulation +- ICP matching +- Edge cases +""" + +from __future__ import annotations + +import math +import time + +import numpy as np +import pytest + +try: + import gtsam # noqa: F401 + import open3d as o3d + from scipy.spatial.transform import Rotation + + from dimos.navigation.smartnav.modules.pgo.pgo_reference import PGOConfig, SimplePGOReference + + _HAS_PGO_DEPS = True +except ImportError: + _HAS_PGO_DEPS = False + +pytestmark = pytest.mark.skipif(not _HAS_PGO_DEPS, reason="gtsam/open3d not installed") + +# ─── Helper functions ───────────────────────────────────────────────────────── + + +def make_rotation(yaw_deg: float) -> np.ndarray: + """Create a 3x3 rotation matrix from a yaw angle in degrees.""" + return Rotation.from_euler("z", yaw_deg, degrees=True).as_matrix() + + +def make_random_cloud( + center: np.ndarray, n_points: int = 200, spread: float = 1.0, seed: int | None = None +) -> np.ndarray: + """Create a random Nx3 point cloud around a center point.""" + rng = np.random.default_rng(seed) + return center + rng.normal(0, spread, (n_points, 3)) + + +def make_box_cloud( + center: np.ndarray, size: float = 2.0, n_points: int = 500, seed: int | None = None +) -> np.ndarray: + """Create a uniform-random box-shaped point cloud.""" + rng = np.random.default_rng(seed) + pts = rng.uniform(-size / 2, size / 2, (n_points, 3)) + return pts + center + + +def make_structured_cloud(center: np.ndarray, n_points: int = 500, seed: int = 42) -> np.ndarray: + """Create a structured point cloud (sphere surface) around a center.""" + rng = np.random.default_rng(seed) + phi = rng.uniform(0, 2 * np.pi, n_points) + theta = rng.uniform(0, np.pi, n_points) + r = 2.0 + x = r * np.sin(theta) * np.cos(phi) + center[0] + y = r * np.sin(theta) * np.sin(phi) + center[1] + z = r * np.cos(theta) + center[2] + return 
np.column_stack([x, y, z]) + + +# ─── Keyframe Detection Tests ──────────────────────────────────────────────── + + +class TestKeyframeDetection: + """Test keyframe selection logic.""" + + def test_first_pose_is_always_keyframe(self): + """The very first pose should always be accepted as a keyframe.""" + pgo = SimplePGOReference() + cloud = make_random_cloud(np.zeros(3), seed=0) + result = pgo.add_key_pose(np.eye(3), np.zeros(3), 0.0, cloud) + assert result is True + assert len(pgo.key_poses) == 1 + + def test_small_movement_not_keyframe(self): + """A pose very close to the last keyframe should be rejected.""" + pgo = SimplePGOReference(PGOConfig(key_pose_delta_trans=0.5, key_pose_delta_deg=10.0)) + cloud = make_random_cloud(np.zeros(3), seed=0) + + # Add first keyframe + pgo.add_key_pose(np.eye(3), np.zeros(3), 0.0, cloud) + pgo.smooth_and_update() + + # Try to add a pose with tiny movement (0.1m, 0 rotation) + result = pgo.add_key_pose(np.eye(3), np.array([0.1, 0.0, 0.0]), 1.0, cloud) + assert result is False + assert len(pgo.key_poses) == 1 + + def test_translation_threshold_triggers_keyframe(self): + """A pose exceeding the translation threshold should be a keyframe.""" + pgo = SimplePGOReference(PGOConfig(key_pose_delta_trans=0.5, key_pose_delta_deg=10.0)) + cloud = make_random_cloud(np.zeros(3), seed=0) + + pgo.add_key_pose(np.eye(3), np.zeros(3), 0.0, cloud) + pgo.smooth_and_update() + + # Move 0.6m (exceeds 0.5m threshold) + result = pgo.add_key_pose(np.eye(3), np.array([0.6, 0.0, 0.0]), 1.0, cloud) + assert result is True + assert len(pgo.key_poses) == 2 + + def test_rotation_threshold_triggers_keyframe(self): + """A pose exceeding the rotation threshold should be a keyframe.""" + pgo = SimplePGOReference(PGOConfig(key_pose_delta_trans=0.5, key_pose_delta_deg=10.0)) + cloud = make_random_cloud(np.zeros(3), seed=0) + + pgo.add_key_pose(np.eye(3), np.zeros(3), 0.0, cloud) + pgo.smooth_and_update() + + # Rotate 15 degrees (exceeds 10 degree threshold), no 
translation + r_rotated = make_rotation(15.0) + result = pgo.add_key_pose(r_rotated, np.zeros(3), 1.0, cloud) + assert result is True + assert len(pgo.key_poses) == 2 + + +# ─── Loop Closure Tests ────────────────────────────────────────────────────── + + +class TestLoopClosure: + """Test loop closure detection and correction.""" + + def _build_square_trajectory( + self, + pgo: SimplePGOReference, + side_length: float = 20.0, + step: float = 0.4, + time_per_step: float = 1.0, + ) -> None: + """Drive a square trajectory, returning to near the start. + + Generates keyframes along a square path with consistent point clouds + at each pose. Calls search_for_loop_pairs() on each keyframe. + """ + t = 0.0 + positions = [] + + # Generate waypoints along a square + for direction in range(4): + yaw = direction * 90.0 + r = make_rotation(yaw) + dx = step * math.cos(math.radians(yaw)) + dy = step * math.sin(math.radians(yaw)) + n_steps = int(side_length / step) + + for _s in range(n_steps): + if not positions: + pos = np.array([0.0, 0.0, 0.0]) + else: + pos = positions[-1] + np.array([dx, dy, 0.0]) + positions.append(pos) + + cloud = make_structured_cloud(np.zeros(3), n_points=300, seed=int(t) % 1000) + added = pgo.add_key_pose(r, pos, t, cloud) + if added: + pgo.search_for_loop_pairs() + pgo.smooth_and_update() + t += time_per_step + + def test_loop_closure_detected_on_revisit(self): + """Square trajectory returning to start should detect a loop closure.""" + config = PGOConfig( + key_pose_delta_trans=0.4, + key_pose_delta_deg=10.0, + loop_search_radius=15.0, + loop_time_thresh=30.0, + loop_score_thresh=1.0, # Relaxed for structured clouds + loop_submap_half_range=3, + submap_resolution=0.2, + min_loop_detect_duration=0.0, + max_icp_iterations=30, + max_icp_correspondence_dist=15.0, + ) + pgo = SimplePGOReference(config) + self._build_square_trajectory(pgo, side_length=20.0, step=0.4, time_per_step=1.0) + + # The robot should have gone around a 20m square and come back near 
start + # With ~200 keyframes and loop_time_thresh=30, the start keyframes + # are far enough in time. Loop closure should be detected. + assert len(pgo.history_pairs) > 0, ( + f"No loop closure detected with {len(pgo.key_poses)} keyframes. " + f"Start pos: {pgo.key_poses[0].t_global}, " + f"End pos: {pgo.key_poses[-1].t_global}" + ) + + def test_no_false_loop_closure(self): + """Straight-line trajectory should NOT detect any loop closures.""" + config = PGOConfig( + key_pose_delta_trans=0.4, + key_pose_delta_deg=10.0, + loop_search_radius=5.0, + loop_time_thresh=30.0, + loop_score_thresh=0.3, + min_loop_detect_duration=0.0, + ) + pgo = SimplePGOReference(config) + + # Drive in a straight line — no revisiting + r = np.eye(3) + for i in range(100): + pos = np.array([i * 0.5, 0.0, 0.0]) + cloud = make_random_cloud(np.zeros(3), n_points=100, seed=i) + added = pgo.add_key_pose(r, pos, float(i), cloud) + if added: + pgo.search_for_loop_pairs() + pgo.smooth_and_update() + + assert len(pgo.history_pairs) == 0, "False loop closure on straight line" + + def test_loop_closure_respects_time_threshold(self): + """Nearby poses that are close in TIME should NOT trigger loop closure.""" + config = PGOConfig( + key_pose_delta_trans=0.3, + key_pose_delta_deg=10.0, + loop_search_radius=20.0, + loop_time_thresh=60.0, # Very high time threshold + loop_score_thresh=1.0, + min_loop_detect_duration=0.0, + ) + pgo = SimplePGOReference(config) + + # Build a trajectory where robot goes and comes back quickly + # Time stamps are close together (1s apart), so loop_time_thresh=60 blocks detection + r = np.eye(3) + for i in range(20): + pos = np.array([i * 0.5, 0.0, 0.0]) + cloud = make_random_cloud(np.zeros(3), n_points=100, seed=i) + pgo.add_key_pose(r, pos, float(i), cloud) + pgo.smooth_and_update() + + # Come back to start + for i in range(20): + pos = np.array([(19 - i) * 0.5, 0.1, 0.0]) + cloud = make_random_cloud(np.zeros(3), n_points=100, seed=i + 100) + added = pgo.add_key_pose(r, pos, 
float(20 + i), cloud) + if added: + pgo.search_for_loop_pairs() + pgo.smooth_and_update() + + # Should NOT detect loop because total time ~40s < 60s threshold + assert len(pgo.history_pairs) == 0, "Loop closure triggered despite time threshold not met" + + def test_loop_closure_corrects_drift(self): + """After loop closure, corrected poses should be closer to ground truth.""" + config = PGOConfig( + key_pose_delta_trans=0.4, + key_pose_delta_deg=10.0, + loop_search_radius=15.0, + loop_time_thresh=20.0, + loop_score_thresh=2.0, # Very relaxed + loop_submap_half_range=3, + submap_resolution=0.2, + min_loop_detect_duration=0.0, + max_icp_iterations=30, + max_icp_correspondence_dist=20.0, + ) + pgo = SimplePGOReference(config) + + # Build a circular trajectory with drift + n_keyframes = 80 + radius = 10.0 + drift_per_step = np.array([0.01, 0.005, 0.0]) # Accumulated drift + + ground_truth_positions = [] + for i in range(n_keyframes): + angle = 2 * math.pi * i / n_keyframes + gt_x = radius * math.cos(angle) + gt_y = radius * math.sin(angle) + ground_truth_positions.append(np.array([gt_x, gt_y, 0.0])) + + # Add drift to odometry + drift = drift_per_step * i + drifted_pos = np.array([gt_x, gt_y, 0.0]) + drift + yaw = angle + math.pi / 2 # Tangent direction + r = Rotation.from_euler("z", yaw).as_matrix() + + cloud = make_structured_cloud( + np.zeros(3), n_points=200, seed=i % 50 + ) # Reuse clouds for loop match + t_sec = float(i) * 1.0 # 1 second per step + added = pgo.add_key_pose(r, drifted_pos, t_sec, cloud) + if added: + pgo.search_for_loop_pairs() + pgo.smooth_and_update() + + # Compute drift at end (before any correction) + start_pos = pgo.key_poses[0].t_global + end_pos = pgo.key_poses[-1].t_global + gt_start = ground_truth_positions[0] + gt_end = ground_truth_positions[-1] + + # The positions should be reasonably close to ground truth + # (exact correction depends on ICP quality, but optimization should help) + # At minimum, the system should have run without 
crashing + assert len(pgo.key_poses) > 0 + assert len(pgo.key_poses) >= 10 + + # If loop closure was detected, check that it improved things + if len(pgo.history_pairs) > 0: + # The start and end should be closer together after optimization + # (they're near the same ground-truth position on a circle) + dist_start_end = np.linalg.norm(end_pos - start_pos) + gt_dist = np.linalg.norm(gt_end - gt_start) + # After loop closure correction, distance should be reasonable + # (ICP on synthetic data can only do so much, relax threshold) + assert dist_start_end < 10.0, ( + f"After loop closure, start-end distance {dist_start_end:.2f}m " + f"is too large (gt: {gt_dist:.2f}m)" + ) + + +# ─── Global Map Tests ──────────────────────────────────────────────────────── + + +class TestGlobalMap: + """Test global map accumulation and publishing.""" + + def test_global_map_accumulates_keyframes(self): + """Global map should contain points from all keyframes.""" + pgo = SimplePGOReference( + PGOConfig( + key_pose_delta_trans=0.3, + global_map_voxel_size=0.0, # No downsampling + ) + ) + + n_keyframes = 5 + pts_per_frame = 50 + for i in range(n_keyframes): + pos = np.array([i * 1.0, 0.0, 0.0]) + cloud = make_random_cloud(np.zeros(3), n_points=pts_per_frame, seed=i) + pgo.add_key_pose(np.eye(3), pos, float(i), cloud) + pgo.smooth_and_update() + + assert len(pgo.key_poses) == n_keyframes + + global_map = pgo.build_global_map(voxel_size=0.0) + # Should have points from all keyframes + assert len(global_map) == n_keyframes * pts_per_frame + + def test_global_map_updates_after_loop_closure(self): + """After loop closure correction, global map positions should shift.""" + config = PGOConfig( + key_pose_delta_trans=0.3, + loop_search_radius=15.0, + loop_time_thresh=5.0, + loop_score_thresh=2.0, + min_loop_detect_duration=0.0, + global_map_voxel_size=0.0, + max_icp_correspondence_dist=20.0, + ) + pgo = SimplePGOReference(config) + + # Add enough keyframes for a trajectory + for i in range(15): + 
pos = np.array([i * 0.5, 0.0, 0.0]) + cloud = make_random_cloud(np.zeros(3), n_points=50, seed=i % 3) + pgo.add_key_pose(np.eye(3), pos, float(i), cloud) + pgo.smooth_and_update() + + map_before = pgo.build_global_map(voxel_size=0.0) + assert len(map_before) > 0 + + # Inject a synthetic loop closure factor between first and last keyframe + # to force the optimizer to shift poses + if len(pgo.key_poses) >= 2: + from dimos.navigation.smartnav.modules.pgo.pgo_reference import LoopPair + + pgo._cache_pairs.append( + LoopPair( + source_id=len(pgo.key_poses) - 1, + target_id=0, + r_offset=np.eye(3), + t_offset=np.zeros(3), + score=0.1, + ) + ) + pgo.smooth_and_update() + + map_after = pgo.build_global_map(voxel_size=0.0) + assert len(map_after) > 0 + # After loop closure, positions should have shifted + # (the optimizer pulls the last keyframe toward the first) + diff = np.abs(map_after - map_before).sum() + assert diff > 0.0, "Global map should change after loop closure" + + def test_global_map_is_published_as_pointcloud(self): + """Global map should produce a valid numpy array that can become PointCloud2.""" + from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 + + pgo = SimplePGOReference(PGOConfig(key_pose_delta_trans=0.3)) + + for i in range(3): + pos = np.array([i * 1.0, 0.0, 0.0]) + cloud = make_random_cloud(np.zeros(3), n_points=100, seed=i) + pgo.add_key_pose(np.eye(3), pos, float(i), cloud) + pgo.smooth_and_update() + + global_map = pgo.build_global_map() + assert len(global_map) > 0 + + # Convert to PointCloud2 — verify it's valid + pc2 = PointCloud2.from_numpy( + global_map.astype(np.float32), frame_id="map", timestamp=time.time() + ) + points_back, _ = pc2.as_numpy() + assert len(points_back) > 0 + assert points_back.shape[1] >= 3 + + +# ─── ICP Tests ──────────────────────────────────────────────────────────────── + + +class TestICP: + """Test ICP matching functionality.""" + + def test_icp_matches_identical_clouds(self): + """ICP between two 
identical clouds should return identity transform.""" + pgo = SimplePGOReference() + cloud = make_structured_cloud(np.zeros(3), n_points=500, seed=42) + + transform, score = pgo._run_icp(cloud, cloud) + # Transform should be near identity + np.testing.assert_allclose(transform[:3, :3], np.eye(3), atol=0.1) + np.testing.assert_allclose(transform[:3, 3], np.zeros(3), atol=0.1) + assert score < 0.1 + + def test_icp_matches_translated_cloud(self): + """ICP should find the correct translation between shifted clouds.""" + pgo = SimplePGOReference(PGOConfig(max_icp_correspondence_dist=5.0)) + cloud = make_structured_cloud(np.zeros(3), n_points=500, seed=42) + shifted = cloud + np.array([1.0, 0.0, 0.0]) + + transform, score = pgo._run_icp(shifted, cloud) + # The transform should move the shifted cloud back toward the original + estimated_translation = transform[:3, 3] + assert abs(estimated_translation[0] - (-1.0)) < 0.5, ( + f"Expected ~-1.0 x-translation, got {estimated_translation[0]:.3f}" + ) + + def test_icp_rejects_dissimilar_clouds(self): + """ICP between very different clouds should fail to match.""" + SimplePGOReference(PGOConfig(max_icp_correspondence_dist=2.0)) + + # Two clouds in completely different locations + cloud_a = make_structured_cloud(np.array([0.0, 0.0, 0.0]), n_points=200, seed=1) + cloud_b = make_structured_cloud(np.array([100.0, 100.0, 0.0]), n_points=200, seed=2) + + result = o3d.pipelines.registration.registration_icp( + o3d.geometry.PointCloud(o3d.utility.Vector3dVector(cloud_a)), + o3d.geometry.PointCloud(o3d.utility.Vector3dVector(cloud_b)), + max_correspondence_distance=2.0, + init=np.eye(4), + estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint(), + criteria=o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=30), + ) + # With max_correspondence_dist=2.0 and clouds 141m apart, + # O3D finds zero correspondences → fitness=0 + assert result.fitness == 0.0, ( + f"Expected zero fitness (no 
correspondences), got {result.fitness}" + ) + + +# ─── Edge Case Tests ───────────────────────────────────────────────────────── + + +class TestEdgeCases: + """Test edge cases and robustness.""" + + def test_empty_cloud_handled(self): + """Adding a keyframe with an empty cloud should not crash.""" + pgo = SimplePGOReference() + empty_cloud = np.zeros((0, 3)) + result = pgo.add_key_pose(np.eye(3), np.zeros(3), 0.0, empty_cloud) + assert result is True # First pose is always a keyframe + pgo.smooth_and_update() + + # Global map from empty keyframe + global_map = pgo.build_global_map() + assert len(global_map) == 0 + + def test_single_keyframe_no_crash(self): + """System should work with just a single keyframe, no crash.""" + pgo = SimplePGOReference() + cloud = make_random_cloud(np.zeros(3), n_points=100, seed=0) + pgo.add_key_pose(np.eye(3), np.zeros(3), 0.0, cloud) + pgo.smooth_and_update() + + # These should all work without crashing + assert len(pgo.key_poses) == 1 + global_map = pgo.build_global_map() + assert len(global_map) > 0 + r, t = pgo.get_corrected_pose(np.eye(3), np.zeros(3)) + np.testing.assert_allclose(r, np.eye(3), atol=1e-6) + np.testing.assert_allclose(t, np.zeros(3), atol=1e-6) + + # Loop search with single keyframe should not crash + pgo.search_for_loop_pairs() + assert len(pgo.history_pairs) == 0 + + +# ─── Python Wrapper Port Tests ─────────────────────────────────────────────── + + +class TestPGOWrapper: + """Test the Python NativeModule wrapper (port definitions).""" + + def test_pgo_module_has_correct_ports(self): + """PGO module should declare the right input/output ports.""" + from dimos.navigation.smartnav.modules.pgo.pgo import PGO + + # Check class annotations for port definitions + annotations = PGO.__annotations__ + assert "registered_scan" in annotations + assert "odometry" in annotations + assert "corrected_odometry" in annotations + assert "global_map" in annotations + + def test_pgo_config_defaults(self): + """PGO config should 
have sensible defaults.""" + from dimos.navigation.smartnav.modules.pgo.pgo import PGOConfig + + # NativeModuleConfig is Pydantic; check model_fields for defaults + fields = PGOConfig.model_fields + assert fields["key_pose_delta_trans"].default == 0.5 + assert fields["key_pose_delta_deg"].default == 10.0 + assert fields["loop_search_radius"].default == 15.0 + assert fields["loop_score_thresh"].default == 0.3 + assert fields["global_map_voxel_size"].default == 0.15 + assert "pgo" in fields["executable"].default + + def test_pgo_config_build_command(self): + """PGO config should specify nix build command.""" + from dimos.navigation.smartnav.modules.pgo.pgo import PGOConfig + + fields = PGOConfig.model_fields + assert fields["build_command"].default is not None + assert "nix build" in fields["build_command"].default + assert "pgo" in fields["build_command"].default diff --git a/dimos/navigation/smartnav/modules/sensor_scan_generation/sensor_scan_generation.py b/dimos/navigation/smartnav/modules/sensor_scan_generation/sensor_scan_generation.py new file mode 100644 index 0000000000..2c6dd161ce --- /dev/null +++ b/dimos/navigation/smartnav/modules/sensor_scan_generation/sensor_scan_generation.py @@ -0,0 +1,111 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SensorScanGeneration: transforms registered (world-frame) point cloud to sensor frame. + +Ported from sensorScanGeneration.cpp. 
Takes Odometry + PointCloud2 (world-frame), +computes inverse transform, outputs sensor-frame point cloud. +""" + +from __future__ import annotations + +import threading +import time +from typing import Any + +from dimos.core.module import Module +from dimos.core.stream import In, Out +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 + + +class SensorScanGeneration(Module): + """Transform registered world-frame point cloud into sensor frame. + + Ports: + registered_scan (In[PointCloud2]): World-frame registered point cloud from SLAM. + odometry (In[Odometry]): Vehicle state estimation. + sensor_scan (Out[PointCloud2]): Sensor-frame point cloud. + odometry_at_scan (Out[Odometry]): Odometry republished with scan timestamp. + """ + + registered_scan: In[PointCloud2] + odometry: In[Odometry] + sensor_scan: Out[PointCloud2] + odometry_at_scan: Out[Odometry] + + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(**kwargs) + self._latest_odom: Odometry | None = None + self._lock = threading.Lock() + + def __getstate__(self) -> dict[str, Any]: + state = super().__getstate__() + state.pop("_lock", None) + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + super().__setstate__(state) + self._lock = threading.Lock() + + def start(self) -> None: + self.odometry._transport.subscribe(self._on_odometry) + self.registered_scan._transport.subscribe(self._on_scan) + + def stop(self) -> None: + super().stop() + + def _on_odometry(self, odom: Odometry) -> None: + with self._lock: + self._latest_odom = odom + + def _on_scan(self, cloud: PointCloud2) -> None: + with self._lock: + odom = self._latest_odom + + if odom is None: + return + + try: + # Build transform from odometry (map -> sensor) + tf_map_to_sensor = Transform( + translation=Vector3(odom.x, 
odom.y, odom.z), + rotation=odom.orientation, + frame_id="map", + child_frame_id="sensor", + ) + + # Inverse transform: sensor -> map (transforms world points into sensor frame) + tf_sensor_to_map = tf_map_to_sensor.inverse() + + # Transform the point cloud into sensor frame + sensor_cloud = cloud.transform(tf_sensor_to_map) + sensor_cloud.frame_id = "sensor_at_scan" + + # Publish sensor-frame cloud + self.sensor_scan._transport.publish(sensor_cloud) + + # Republish odometry with scan timestamp + odom_at_scan = Odometry( + ts=cloud.ts if cloud.ts is not None else time.time(), + frame_id="map", + child_frame_id="sensor_at_scan", + pose=odom.pose, + twist=odom.twist, + ) + self.odometry_at_scan._transport.publish(odom_at_scan) + except Exception: + pass # Skip malformed messages silently diff --git a/dimos/navigation/smartnav/modules/sensor_scan_generation/test_sensor_scan_generation.py b/dimos/navigation/smartnav/modules/sensor_scan_generation/test_sensor_scan_generation.py new file mode 100644 index 0000000000..1a3732757b --- /dev/null +++ b/dimos/navigation/smartnav/modules/sensor_scan_generation/test_sensor_scan_generation.py @@ -0,0 +1,183 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class _MockTransport:
    """In-memory transport stand-in: records each message and fans it out to subscribers."""

    def __init__(self):
        self._messages = []
        self._subscribers = []

    def publish(self, msg):
        """Record ``msg`` and hand it to every registered callback."""
        self._messages.append(msg)
        for callback in self._subscribers:
            callback(msg)

    def broadcast(self, _stream, msg):
        """Alias of :meth:`publish`; the stream argument is ignored."""
        self.publish(msg)

    def subscribe(self, cb):
        """Register ``cb`` and return a callable that removes it again."""
        self._subscribers.append(cb)
        return lambda: self._subscribers.remove(cb)


def make_pointcloud(points: np.ndarray, frame_id: str = "map") -> PointCloud2:
    """Build a PointCloud2 from an Nx3 array, cast to float32 and stamped with the current time."""
    xyz = points.astype(np.float32)
    stamp = time.time()
    return PointCloud2.from_numpy(xyz, frame_id=frame_id, timestamp=stamp)


def make_odometry(x: float, y: float, z: float, yaw: float = 0.0) -> Odometry:
    """Build a map->sensor Odometry at ``(x, y, z)`` rotated by ``yaw`` about the z axis."""
    heading = Quaternion.from_euler(Vector3(0.0, 0.0, yaw))
    stamp = time.time()
    sensor_pose = Pose(
        position=[x, y, z],
        orientation=[heading.x, heading.y, heading.z, heading.w],
    )
    return Odometry(ts=stamp, frame_id="map", child_frame_id="sensor", pose=sensor_pose)


def _wire_transports(module):
    """Attach fresh mock transports to both output ports and return them (scan, odom)."""
    scan_sink, odom_sink = _MockTransport(), _MockTransport()
    module.sensor_scan._transport = scan_sink
    module.odometry_at_scan._transport = odom_sink
    return scan_sink, odom_sink


class TestSensorScanGeneration:
    """Exercise the world-frame to sensor-frame conversion of SensorScanGeneration."""

    @pytest.fixture(autouse=True)
    def _create_module(self):
        self.module = SensorScanGeneration()
        self.scan_t, self.odom_t = _wire_transports(self.module)
        yield
        self.module.stop()

    def _feed_scan(self, transport, cloud):
        """Run ``cloud`` through the module and return what ``transport`` published."""
        captured = []
        transport.subscribe(captured.append)
        self.module._on_scan(cloud)
        return captured

    def test_identity_transform(self):
        """With the vehicle at the origin and unrotated, points come out unchanged."""
        self.module._on_odometry(make_odometry(0.0, 0.0, 0.0, 0.0))

        world_points = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
        published = self._feed_scan(self.scan_t, make_pointcloud(world_points))

        assert len(published) == 1
        sensor_points, _ = published[0].as_numpy()
        np.testing.assert_allclose(sensor_points, world_points, atol=1e-4)

    def test_translation_transform(self):
        """A pure vehicle translation is subtracted from the points."""
        self.module._on_odometry(make_odometry(2.0, 3.0, 0.0, 0.0))

        world_points = np.array([[5.0, 3.0, 0.0]])
        published = self._feed_scan(self.scan_t, make_pointcloud(world_points))

        assert len(published) == 1
        sensor_points, _ = published[0].as_numpy()
        np.testing.assert_allclose(sensor_points[0], [3.0, 0.0, 0.0], atol=1e-4)

    def test_rotation_transform(self):
        """A pure vehicle yaw is undone on the points."""
        self.module._on_odometry(make_odometry(0.0, 0.0, 0.0, math.pi / 2))

        world_points = np.array([[1.0, 0.0, 0.0]])
        published = self._feed_scan(self.scan_t, make_pointcloud(world_points))

        assert len(published) == 1
        sensor_points, _ = published[0].as_numpy()
        np.testing.assert_allclose(sensor_points[0], [0.0, -1.0, 0.0], atol=1e-4)

    def test_no_odometry_no_output(self):
        """A scan that arrives before any odometry produces no output."""
        world_points = np.array([[1.0, 0.0, 0.0]])
        published = self._feed_scan(self.scan_t, make_pointcloud(world_points))

        assert len(published) == 0

    def test_empty_cloud(self):
        """An empty cloud is still published, and stays empty."""
        self.module._on_odometry(make_odometry(0.0, 0.0, 0.0))

        published = self._feed_scan(self.scan_t, make_pointcloud(np.zeros((0, 3))))

        assert len(published) == 1
        assert len(published[0]) == 0

    def test_odometry_at_scan_published(self):
        """Each processed scan also emits one odometry message in the scan frame."""
        self.module._on_odometry(make_odometry(1.0, 2.0, 3.0))

        cloud = make_pointcloud(np.array([[0.0, 0.0, 0.0]]))
        odom_results = self._feed_scan(self.odom_t, cloud)

        assert len(odom_results) == 1
        assert odom_results[0].frame_id == "map"
        assert odom_results[0].child_frame_id == "sensor_at_scan"
class TarePlannerConfig(NativeModuleConfig):
    """Settings for the native TARE planner executable."""

    # Where the binary runs, where it is found, and how it is built.
    cwd: str | None = "."
    executable: str = "result/bin/tare_planner"
    build_command: str | None = (
        "nix build github:dimensionalOS/dimos-module-tare-planner/v0.1.0"
        " --no-write-lock-file"
    )

    # Exploration parameters
    # NOTE(review): units are not stated here; presumably metres and Hz - confirm
    # against the native binary.
    exploration_range: float = 20.0
    update_rate: float = 1.0
    sensor_range: float = 20.0


class TarePlanner(NativeModule):
    """Python-side declaration of the native TARE exploration planner.

    This class only names the config type and the stream ports; the planning
    itself runs in the external executable configured by
    :class:`TarePlannerConfig` (described by this module as a frontier-based
    autonomous exploration planner).

    Ports:
        registered_scan (In[PointCloud2]): World-frame point cloud for coverage updates.
        odometry (In[Odometry]): Vehicle state.
        way_point (Out[PointStamped]): Exploration waypoint for local planner.
    """

    default_config: type[TarePlannerConfig] = TarePlannerConfig  # type: ignore[assignment]

    registered_scan: In[PointCloud2]
    odometry: In[Odometry]
    way_point: Out[PointStamped]
class TestTarePlannerConfig:
    """Check TarePlannerConfig defaults and CLI-argument generation."""

    def test_default_config(self):
        """Defaults match the values declared on the config class."""
        config = TarePlannerConfig()
        assert config.exploration_range == 20.0
        assert config.update_rate == 1.0
        assert config.sensor_range == 20.0

    def test_cli_args_generation(self):
        """Overridden fields show up as flag/value entries in the CLI args."""
        config = TarePlannerConfig(
            exploration_range=30.0,
            update_rate=2.0,
        )
        args = config.to_cli_args()
        assert "--exploration_range" in args
        assert "30.0" in args
        assert "--update_rate" in args
        assert "2.0" in args


class TestTarePlannerModule:
    """Check the ports declared on the TarePlanner class."""

    def test_ports_declared(self):
        """The class annotations expose the expected In/Out ports."""
        from typing import get_origin, get_type_hints

        from dimos.core.stream import In, Out

        hints = get_type_hints(TarePlanner)
        in_ports = {k for k, v in hints.items() if get_origin(v) is In}
        out_ports = {k for k, v in hints.items() if get_origin(v) is Out}

        assert "registered_scan" in in_ports
        assert "odometry" in in_ports
        assert "way_point" in out_ports


# NOTE(review): the skip condition looks for result/bin next to this test file,
# while the tests below expect cwd to be the smartnav root - confirm both refer
# to the same directory.
@pytest.mark.skipif(
    not Path(__file__).resolve().parent.joinpath("result", "bin").exists(),
    reason="Native binary not built (run nix build first)",
)
class TestPathResolution:
    """Verify native module paths resolve to real filesystem locations."""

    @pytest.fixture
    def resolved_module(self):
        """Yield a TarePlanner with resolved paths and always stop it afterwards.

        The module is stopped even when ``_resolve_paths()`` itself raises, which
        the previous per-test ``try/finally`` did not cover.
        """
        m = TarePlanner()
        try:
            m._resolve_paths()
            yield m
        finally:
            m.stop()

    def test_cwd_resolves_to_existing_directory(self, resolved_module):
        cwd = Path(resolved_module.config.cwd)
        assert cwd.exists(), f"cwd does not exist: {resolved_module.config.cwd}"
        assert cwd.is_dir()

    def test_executable_exists(self, resolved_module):
        exe = Path(resolved_module.config.executable)
        assert exe.exists(), f"Binary not found: {exe}. Run nix build first."

    def test_cwd_resolves_to_smartnav_root(self, resolved_module):
        """cwd should resolve to the smartnav root (where CMakeLists.txt lives)."""
        cwd = Path(resolved_module.config.cwd).resolve()
        assert (cwd / "CMakeLists.txt").exists(), f"cwd {cwd} is not the smartnav root"
        assert (cwd / "flake.nix").exists()
class TerrainAnalysisConfig(NativeModuleConfig):
    """Settings for the native terrain-analysis executable."""

    # Where the binary runs, where it is found, and how it is built.
    cwd: str | None = "."
    executable: str = "result/bin/terrain_analysis"
    build_command: str | None = (
        "nix build github:dimensionalOS/dimos-module-terrain-analysis/v0.1.0"
        " --no-write-lock-file"
    )

    # Terrain analysis parameters
    sensor_range: float = 20.0
    obstacle_height_threshold: float = 0.15
    ground_height_threshold: float = 0.1
    voxel_size: float = 0.05
    terrain_voxel_size: float = 1.0
    # The defaults satisfy width == 2 * half_width + 1.
    # NOTE(review): nothing here enforces that relation - keep the two in sync
    # if the native side relies on it.
    terrain_voxel_half_width: int = 10
    terrain_voxel_width: int = 21


class TerrainAnalysis(NativeModule):
    """Python-side declaration of the native terrain-analysis module.

    This class only names the config type and the stream ports; the processing
    runs in the external executable configured by
    :class:`TerrainAnalysisConfig` (described by this module as producing a
    terrain cost map with obstacle classification).

    Ports:
        registered_scan (In[PointCloud2]): World-frame registered point cloud.
        odometry (In[Odometry]): Vehicle state for local frame reference.
        terrain_map (Out[PointCloud2]): Terrain cost map (intensity=obstacle cost).
    """

    default_config: type[TerrainAnalysisConfig] = TerrainAnalysisConfig  # type: ignore[assignment]

    registered_scan: In[PointCloud2]
    odometry: In[Odometry]
    terrain_map: Out[PointCloud2]
class TestTerrainAnalysisConfig:
    """Check TerrainAnalysisConfig defaults and CLI-argument generation."""

    def test_default_config(self):
        """Default config should have sensible values."""
        config = TerrainAnalysisConfig()
        assert config.obstacle_height_threshold == 0.15
        assert config.voxel_size == 0.05
        assert config.sensor_range == 20.0

    def test_cli_args_generation(self):
        """Config should generate CLI args for the native binary."""
        config = TerrainAnalysisConfig(
            obstacle_height_threshold=0.2,
            voxel_size=0.1,
        )
        args = config.to_cli_args()
        assert "--obstacle_height_threshold" in args
        assert "0.2" in args
        assert "--voxel_size" in args
        assert "0.1" in args


class TestTerrainAnalysisModule:
    """Check the ports declared on the TerrainAnalysis class."""

    def test_ports_declared(self):
        """Module should declare the expected In/Out ports."""
        from typing import get_origin, get_type_hints

        from dimos.core.stream import In, Out

        hints = get_type_hints(TerrainAnalysis)
        in_ports = {k for k, v in hints.items() if get_origin(v) is In}
        out_ports = {k for k, v in hints.items() if get_origin(v) is Out}

        assert "registered_scan" in in_ports
        assert "odometry" in in_ports
        assert "terrain_map" in out_ports


# NOTE(review): the skip condition looks for result/bin next to this test file,
# while the tests below expect cwd to be the smartnav root - confirm both refer
# to the same directory.
@pytest.mark.skipif(
    not Path(__file__).resolve().parent.joinpath("result", "bin").exists(),
    reason="Native binary not built (run nix build first)",
)
class TestPathResolution:
    """Verify native module paths resolve to real filesystem locations."""

    @pytest.fixture
    def resolved_module(self):
        """Yield a TerrainAnalysis with resolved paths and always stop it afterwards.

        The module is stopped even when ``_resolve_paths()`` itself raises, which
        the previous per-test ``try/finally`` did not cover.
        """
        m = TerrainAnalysis()
        try:
            m._resolve_paths()
            yield m
        finally:
            m.stop()

    def test_cwd_resolves_to_existing_directory(self, resolved_module):
        cwd = Path(resolved_module.config.cwd)
        assert cwd.exists(), f"cwd does not exist: {resolved_module.config.cwd}"
        assert cwd.is_dir()

    def test_executable_exists(self, resolved_module):
        exe = Path(resolved_module.config.executable)
        assert exe.exists(), f"Binary not found: {exe}. Run nix build first."

    def test_cwd_resolves_to_smartnav_root(self, resolved_module):
        """cwd should resolve to the smartnav root (where CMakeLists.txt lives)."""
        cwd = Path(resolved_module.config.cwd).resolve()
        assert (cwd / "CMakeLists.txt").exists(), f"cwd {cwd} is not the smartnav root"
        assert (cwd / "flake.nix").exists()
class TerrainMapExtConfig(ModuleConfig):
    """Config for extended terrain map."""

    voxel_size: float = 0.4  # meters per voxel (coarser than local)
    decay_time: float = 8.0  # seconds before points expire
    publish_rate: float = 2.0  # Hz
    max_range: float = 40.0  # max distance from robot to keep


class TerrainMapExt(Module[TerrainMapExtConfig]):
    """Extended terrain map with time-decayed voxel accumulation.

    Incoming ``terrain_map`` points are bucketed into a voxel hash (one stored
    point per voxel, newest wins). A background thread periodically drops
    voxels older than ``decay_time`` or farther than ``max_range`` (in the x/y
    plane) from the last known robot position, then publishes the survivors.
    Nothing is published on a cycle in which no voxel survives.

    Ports:
        terrain_map (In[PointCloud2]): Local terrain from TerrainAnalysis.
        odometry (In[Odometry]): Vehicle pose for range culling.
        terrain_map_ext (Out[PointCloud2]): Extended accumulated terrain.
    """

    default_config = TerrainMapExtConfig

    terrain_map: In[PointCloud2]
    odometry: In[Odometry]
    terrain_map_ext: Out[PointCloud2]

    def __init__(self, **kwargs) -> None:  # type: ignore[no-untyped-def]
        super().__init__(**kwargs)
        self._lock = threading.Lock()
        self._running = False
        self._thread: threading.Thread | None = None
        # Voxel storage: key=(ix,iy,iz) -> (x, y, z, intensity, timestamp).
        # The intensity slot is always written as 0.0; the incoming cloud's
        # intensity is not carried over.
        self._voxels: dict[tuple[int, int, int], tuple[float, float, float, float, float]] = {}
        self._robot_x = 0.0
        self._robot_y = 0.0

    def __getstate__(self) -> dict[str, Any]:
        """Return the pickle state without the lock, thread and voxel store."""
        s = super().__getstate__()
        for k in ("_lock", "_thread", "_voxels"):
            s.pop(k, None)
        return s

    def __setstate__(self, s: dict) -> None:
        """Restore the pickled state with a fresh lock and an empty voxel store."""
        super().__setstate__(s)
        self._lock = threading.Lock()
        self._thread = None
        self._voxels = {}

    def start(self) -> None:
        """Subscribe to the inputs and launch the background publish thread."""
        self.terrain_map._transport.subscribe(self._on_terrain)
        self.odometry._transport.subscribe(self._on_odom)
        self._running = True
        self._thread = threading.Thread(target=self._publish_loop, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Signal the publish thread to exit, wait briefly for it, then stop the base."""
        self._running = False
        if self._thread:
            self._thread.join(timeout=3.0)
        super().stop()

    def _on_odom(self, msg: Odometry) -> None:
        """Track the robot's x/y position used for range culling."""
        with self._lock:
            self._robot_x = msg.pose.position.x
            self._robot_y = msg.pose.position.y

    def _on_terrain(self, cloud: PointCloud2) -> None:
        """Merge a terrain cloud into the voxel store.

        Points with a NaN or infinite coordinate are ignored; previously one
        such point raised while being converted to a voxel index and aborted
        the rest of the update.

        Raises:
            ValueError: if ``config.voxel_size`` is not positive.
        """
        points, _ = cloud.as_numpy()
        if len(points) == 0:
            return

        vs = self.config.voxel_size
        if vs <= 0:
            raise ValueError(f"voxel_size must be positive, got {vs}")

        # float64 matches the precision of the former per-point float() conversion,
        # so voxel indices are unchanged for finite points.
        xyz = np.asarray(points, dtype=np.float64)[:, :3]
        xyz = xyz[np.isfinite(xyz).all(axis=1)]
        if len(xyz) == 0:
            return

        now = time.time()
        keys = np.floor(xyz / vs).astype(np.int64)

        # Build the update outside the lock. Later points overwrite earlier ones
        # that fall into the same voxel, as with sequential assignment.
        updates = {
            (ix, iy, iz): (x, y, z, 0.0, now)
            for (ix, iy, iz), (x, y, z) in zip(keys.tolist(), xyz.tolist())
        }

        with self._lock:
            self._voxels.update(updates)

    def _publish_loop(self) -> None:
        """Thread body: prune stale/far voxels and publish the rest at ``publish_rate``."""
        dt = 1.0 / self.config.publish_rate
        while self._running:
            t0 = time.monotonic()
            now = time.time()
            decay = self.config.decay_time
            max_r2 = self.config.max_range**2

            with self._lock:
                rx, ry = self._robot_x, self._robot_y
                # Expire old voxels and range-cull (distance measured in x/y only).
                expired = []
                pts = []
                for k, (x, y, z, _intensity, ts) in self._voxels.items():
                    if now - ts > decay or (x - rx) ** 2 + (y - ry) ** 2 > max_r2:
                        expired.append(k)
                    else:
                        pts.append([x, y, z])
                for k in expired:
                    del self._voxels[k]

            # Publish after releasing the lock so subscribers cannot stall the callbacks.
            if pts:
                arr = np.array(pts, dtype=np.float32)
                self.terrain_map_ext._transport.publish(
                    PointCloud2.from_numpy(arr, frame_id="map", timestamp=now)
                )

            # Sleep off the remainder of the period, if any.
            elapsed = time.monotonic() - t0
            if elapsed < dt:
                time.sleep(dt - elapsed)
class TestCmdVelMux:
    """Teleop/nav priority checks run against CmdVelMux instances built without __init__."""

    @staticmethod
    def _bare_mux(**attrs):
        """Create a CmdVelMux via __new__ and seed its instance dict directly."""
        mux = CmdVelMux.__new__(CmdVelMux)
        for name, value in attrs.items():
            mux.__dict__[name] = value
        return mux

    def test_teleop_initially_inactive(self) -> None:
        # NOTE(review): this only reads back the value seeded just above; it
        # does not exercise CmdVelMux's own initialisation.
        mux = self._bare_mux(_teleop_active=False)
        assert not mux._teleop_active

    def test_end_teleop_clears_flag(self) -> None:
        import threading

        mux = self._bare_mux(_teleop_active=True, _timer=None, _lock=threading.Lock())
        mux._end_teleop()
        assert not mux._teleop_active

    def test_nav_suppressed_when_teleop_active(self) -> None:
        """When _teleop_active is True, _on_nav returns early (no publish)."""
        import threading

        mux = self._bare_mux(_teleop_active=True, _lock=threading.Lock())
        # _on_nav should return before reaching cmd_vel._transport.publish
        # If it didn't return early, it would crash since cmd_vel has no transport
        from dimos.msgs.geometry_msgs.Twist import Twist
        from dimos.msgs.geometry_msgs.Vector3 import Vector3

        nav_cmd = Twist(linear=Vector3(1, 0, 0), angular=Vector3(0, 0, 0))
        mux._on_nav(nav_cmd)
        assert mux._teleop_active  # Still active, nav was suppressed

    def test_cooldown_default(self) -> None:
        from dimos.navigation.smartnav.modules.cmd_vel_mux import CmdVelMuxConfig

        assert CmdVelMuxConfig().teleop_cooldown_sec == 1.0
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for TUIControlModule.""" + +import pytest + +from dimos.navigation.smartnav.modules.tui_control.tui_control import TUIControlModule + + +class _MockTransport: + """Lightweight mock transport that captures published messages.""" + + def __init__(self): + self._messages = [] + self._subscribers = [] + + def publish(self, msg): + self._messages.append(msg) + for cb in self._subscribers: + cb(msg) + + def broadcast(self, _stream, msg): + self.publish(msg) + + def subscribe(self, cb): + self._subscribers.append(cb) + + def unsub(): + self._subscribers.remove(cb) + + return unsub + + +class TestTUIControl: + """Test TUI controller key handling and output.""" + + @pytest.fixture(autouse=True) + def _create_module(self): + self.module = TUIControlModule(max_speed=2.0, max_yaw_rate=1.5) + yield + self.module.stop() + + def test_initial_state_zero(self): + """All velocities should start at zero.""" + module = self.module + assert module._fwd == 0.0 + assert module._left == 0.0 + assert module._yaw == 0.0 + + def test_forward_key(self): + """'w' key should set forward motion.""" + module = self.module + module._handle_key("w") + assert module._fwd == 1.0 + assert module._left == 0.0 + assert module._yaw == 0.0 + + def test_backward_key(self): + """'s' key should set backward motion.""" + module = self.module + module._handle_key("s") + assert module._fwd == -1.0 + + def test_strafe_left_key(self): + """'a' key should set left strafe.""" + module = self.module + module._handle_key("a") + assert module._left == 1.0 + assert module._fwd == 0.0 + + def 
test_strafe_right_key(self): + """'d' key should set right strafe.""" + module = self.module + module._handle_key("d") + assert module._left == -1.0 + + def test_rotate_left_key(self): + """'q' key should set left rotation.""" + module = self.module + module._handle_key("q") + assert module._yaw == 1.0 + assert module._fwd == 0.0 + assert module._left == 0.0 + + def test_rotate_right_key(self): + """'e' key should set right rotation.""" + module = self.module + module._handle_key("e") + assert module._yaw == -1.0 + + def test_stop_key(self): + """Space should stop all motion.""" + module = self.module + module._handle_key("w") + assert module._fwd == 1.0 + module._handle_key(" ") + assert module._fwd == 0.0 + assert module._left == 0.0 + assert module._yaw == 0.0 + + def test_speed_increase(self): + """'+' key should increase speed scale.""" + module = self.module + # First decrease from the default (1.0) so there is room to increase + module._handle_key("-") + lowered_scale = module._speed_scale + module._handle_key("+") + assert module._speed_scale > lowered_scale + + def test_speed_decrease(self): + """'-' key should decrease speed scale.""" + module = self.module + module._handle_key("-") + assert module._speed_scale < 1.0 + + def test_speed_scale_bounds(self): + """Speed scale should be bounded [0.1, 1.0].""" + module = self.module + # Try to go below minimum + for _ in range(20): + module._handle_key("-") + assert module._speed_scale >= 0.1 + + # Try to go above maximum + for _ in range(20): + module._handle_key("+") + assert module._speed_scale <= 1.0 + + def test_waypoint_publish(self): + """send_waypoint should publish a PointStamped message.""" + module = self.module + + # Wire a mock transport onto the way_point output port + wp_transport = _MockTransport() + module.way_point._transport = wp_transport + + results = [] + wp_transport.subscribe(lambda msg: results.append(msg)) + + module.send_waypoint(5.0, 10.0, 0.0) + + assert len(results) == 1 + assert 
results[0].x == 5.0 + assert results[0].y == 10.0 + assert results[0].frame_id == "map" diff --git a/dimos/navigation/smartnav/modules/tui_control/tui_control.py b/dimos/navigation/smartnav/modules/tui_control/tui_control.py new file mode 100644 index 0000000000..dc7776c75f --- /dev/null +++ b/dimos/navigation/smartnav/modules/tui_control/tui_control.py @@ -0,0 +1,217 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TUIControlModule: terminal-based teleop controller. + +Provides arrow-key control for the vehicle and mode switching. +""" + +from __future__ import annotations + +import threading +import time +from typing import Any + +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import Out +from dimos.msgs.geometry_msgs.PointStamped import PointStamped +from dimos.msgs.geometry_msgs.Twist import Twist + + +class TUIControlConfig(ModuleConfig): + """Configuration for the TUI controller.""" + + max_speed: float = 2.0 + max_yaw_rate: float = 1.5 + speed_step: float = 0.1 + publish_rate: float = 20.0 # Hz + + +class TUIControlModule(Module[TUIControlConfig]): + """Terminal-based teleop controller with arrow key input. + + Ports: + cmd_vel (Out[Twist]): Velocity commands from keyboard. + way_point (Out[PointStamped]): Waypoint commands (typed coordinates). 
+ """ + + default_config = TUIControlConfig + + cmd_vel: Out[Twist] + way_point: Out[PointStamped] + + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(**kwargs) + self._lock = threading.Lock() + self._fwd = 0.0 + self._left = 0.0 + self._yaw = 0.0 + self._speed_scale = 1.0 + self._running = False + self._publish_thread: threading.Thread | None = None + self._input_thread: threading.Thread | None = None + + def __getstate__(self) -> dict[str, Any]: + state = super().__getstate__() + state.pop("_lock", None) + state.pop("_publish_thread", None) + state.pop("_input_thread", None) + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + super().__setstate__(state) + self._lock = threading.Lock() + self._publish_thread = None + self._input_thread = None + + def start(self) -> None: + self._running = True + self._publish_thread = threading.Thread(target=self._publish_loop, daemon=True) + self._publish_thread.start() + self._input_thread = threading.Thread(target=self._input_loop, daemon=True) + self._input_thread.start() + + def stop(self) -> None: + self._running = False + if self._publish_thread: + self._publish_thread.join(timeout=2.0) + super().stop() + + def _publish_loop(self) -> None: + """Publish current velocity at fixed rate.""" + dt = 1.0 / self.config.publish_rate + while self._running: + with self._lock: + fwd = self._fwd + left = self._left + yaw = self._yaw + scale = self._speed_scale + twist = Twist( + linear=[ + fwd * scale * self.config.max_speed, + left * scale * self.config.max_speed, + 0.0, + ], + angular=[ + 0.0, + 0.0, + yaw * scale * self.config.max_yaw_rate, + ], + ) + self.cmd_vel._transport.publish(twist) + time.sleep(dt) + + def _input_loop(self) -> None: + """Read keyboard input for teleop control. 
+ + Controls: + w/up: forward, s/down: backward + a/left: strafe left, d/right: strafe right + q: rotate left, e: rotate right + +/-: increase/decrease speed + space: stop + Ctrl+C: quit + """ + try: + import sys + import termios + import tty + + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + + print("\n--- SmartNav TUI Controller ---") + print("w/s: fwd/back | a/d: strafe | q/e: rotate") + print("+/-: speed | g: waypoint | space: stop") + print("Ctrl+C: quit") + print("-------------------------------\n") + + try: + tty.setraw(fd) + while self._running: + ch = sys.stdin.read(1) + if ch == "\x03": # Ctrl+C + self._running = False + break + self._handle_key(ch) + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + except Exception: + # Not a terminal (e.g., running in a worker process, piped stdin, etc.) + while self._running: + time.sleep(1.0) + + def _handle_key(self, ch: str) -> None: + """Process a single keypress.""" + with self._lock: + if ch in ("w", "W"): + self._fwd = 1.0 + self._left = 0.0 + self._yaw = 0.0 + elif ch in ("s", "S"): + self._fwd = -1.0 + self._left = 0.0 + self._yaw = 0.0 + elif ch in ("a", "A"): + self._fwd = 0.0 + self._left = 1.0 + self._yaw = 0.0 + elif ch in ("d", "D"): + self._fwd = 0.0 + self._left = -1.0 + self._yaw = 0.0 + elif ch in ("q", "Q"): + self._fwd = 0.0 + self._left = 0.0 + self._yaw = 1.0 + elif ch in ("e", "E"): + self._fwd = 0.0 + self._left = 0.0 + self._yaw = -1.0 + elif ch == " ": + self._fwd = 0.0 + self._left = 0.0 + self._yaw = 0.0 + elif ch == "+" or ch == "=": + self._speed_scale = min(self._speed_scale + 0.1, 1.0) + elif ch == "-": + self._speed_scale = max(self._speed_scale - 0.1, 0.1) + if ch == "\x1b": + import sys + + seq1 = sys.stdin.read(1) + if seq1 == "[": + seq2 = sys.stdin.read(1) + with self._lock: + if seq2 == "A": # Up + self._fwd = 1.0 + self._left = 0.0 + self._yaw = 0.0 + elif seq2 == "B": # Down + self._fwd = -1.0 + self._left = 0.0 + self._yaw = 0.0 + elif seq2 
== "C": # Right + self._fwd = 0.0 + self._left = -1.0 + self._yaw = 0.0 + elif seq2 == "D": # Left + self._fwd = 0.0 + self._left = 1.0 + self._yaw = 0.0 + + def send_waypoint(self, x: float, y: float, z: float = 0.0) -> None: + """Programmatically send a waypoint.""" + wp = PointStamped(x=x, y=y, z=z, frame_id="map") + self.way_point._transport.publish(wp) diff --git a/dimos/navigation/smartnav/tests/test_explore_movement.py b/dimos/navigation/smartnav/tests/test_explore_movement.py new file mode 100644 index 0000000000..be5585871c --- /dev/null +++ b/dimos/navigation/smartnav/tests/test_explore_movement.py @@ -0,0 +1,361 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration test: verify exploration planner produces movement. + +Validates the complete explore pipeline: + [MockVehicle] → registered_scan + odometry + → [SensorScanGeneration] → sensor_scan + → [TerrainAnalysis] → terrain_map + → [TarePlanner] → way_point (exploration waypoints) + → [LocalPlanner] → path (autonomyMode=true) + → [PathFollower] → cmd_vel + → [MockVehicle] (tracks position changes) + +Requires built C++ native binaries (nix build). 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +import math +from pathlib import Path +import platform +import threading +import time +from typing import Any + +import numpy as np +import pytest + +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import In, Out +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 + +_NATIVE_DIR = Path(__file__).resolve().parent.parent +_REQUIRED_BINARIES = [ + ("result-terrain-analysis", "terrain_analysis"), + ("result-local-planner", "local_planner"), + ("result-path-follower", "path_follower"), + ("result-tare-planner", "tare_planner"), +] +_HAS_BINARIES = all((_NATIVE_DIR / d / "bin" / name).exists() for d, name in _REQUIRED_BINARIES) +_IS_LINUX_X86 = platform.system() == "Linux" and platform.machine() in ("x86_64", "AMD64") + +pytestmark = [ + pytest.mark.slow, + pytest.mark.skipif(not _IS_LINUX_X86, reason="Native modules require Linux x86_64"), + pytest.mark.skipif( + not _HAS_BINARIES, + reason="Native binaries not built (run: cd smartnav/native && nix build)", + ), +] + + +# Helpers (must be at module level for pickling) + + +def _make_room_cloud( + robot_x: float, + robot_y: float, + room_size: float = 20.0, + wall_height: float = 2.5, + ground_z: float = 0.0, + density: float = 0.3, +) -> np.ndarray: + """Generate a room point cloud: flat ground + walls on 4 sides. + + Returns Nx3 array [x, y, z] (PointCloud2.from_numpy expects Nx3). 
+ """ + pts = [] + + step = 1.0 / density + half = room_size / 2 + xs = np.arange(robot_x - half, robot_x + half, step) + ys = np.arange(robot_y - half, robot_y + half, step) + xx, yy = np.meshgrid(xs, ys) + ground = np.column_stack( + [ + xx.ravel(), + yy.ravel(), + np.full(xx.size, ground_z), + ] + ) + pts.append(ground) + + wall_step = 0.5 + for wall_x in [robot_x - half, robot_x + half]: + wy = np.arange(robot_y - half, robot_y + half, wall_step) + wz = np.arange(ground_z, ground_z + wall_height, wall_step) + wyy, wzz = np.meshgrid(wy, wz) + wall = np.column_stack( + [ + np.full(wyy.size, wall_x), + wyy.ravel(), + wzz.ravel(), + ] + ) + pts.append(wall) + + for wall_y in [robot_y - half, robot_y + half]: + wx = np.arange(robot_x - half, robot_x + half, wall_step) + wz = np.arange(ground_z, ground_z + wall_height, wall_step) + wxx, wzz = np.meshgrid(wx, wz) + wall = np.column_stack( + [ + wxx.ravel(), + np.full(wxx.size, wall_y), + wzz.ravel(), + ] + ) + pts.append(wall) + + return np.concatenate(pts, axis=0).astype(np.float32) + + +class MockVehicleConfig(ModuleConfig): + rate: float = 10.0 + sim_rate: float = 50.0 + + +class MockVehicle(Module[MockVehicleConfig]): + """Publishes sensor data and integrates cmd_vel for position tracking.""" + + default_config = MockVehicleConfig + + cmd_vel: In[Twist] + registered_scan: Out[PointCloud2] + odometry: Out[Odometry] + + def __init__(self, **kwargs): # type: ignore[no-untyped-def] + super().__init__(**kwargs) + self._x = 0.0 + self._y = 0.0 + self._z = 0.75 + self._yaw = 0.0 + self._fwd = 0.0 + self._left = 0.0 + self._yaw_rate = 0.0 + self._cmd_lock = threading.Lock() + self._running = False + self._sensor_thread: threading.Thread | None = None + self._sim_thread: threading.Thread | None = None + + def __getstate__(self) -> dict[str, Any]: + state = super().__getstate__() + state.pop("_cmd_lock", None) + state.pop("_sensor_thread", None) + state.pop("_sim_thread", None) + return state + + def __setstate__(self, 
state: dict[str, Any]) -> None: + super().__setstate__(state) + self._cmd_lock = threading.Lock() + self._sensor_thread = None + self._sim_thread = None + + def start(self) -> None: + self.cmd_vel._transport.subscribe(self._on_cmd_vel) + self._running = True + self._sim_thread = threading.Thread(target=self._sim_loop, daemon=True) + self._sim_thread.start() + self._sensor_thread = threading.Thread(target=self._sensor_loop, daemon=True) + self._sensor_thread.start() + + def stop(self) -> None: + self._running = False + if self._sim_thread: + self._sim_thread.join(timeout=3.0) + if self._sensor_thread: + self._sensor_thread.join(timeout=3.0) + super().stop() + + def _on_cmd_vel(self, twist: Twist) -> None: + with self._cmd_lock: + self._fwd = twist.linear.x + self._left = twist.linear.y + self._yaw_rate = twist.angular.z + + def _sim_loop(self) -> None: + dt = 1.0 / self.config.sim_rate + while self._running: + t0 = time.monotonic() + with self._cmd_lock: + fwd, left, yr = self._fwd, self._left, self._yaw_rate + + self._yaw += dt * yr + cy, sy = math.cos(self._yaw), math.sin(self._yaw) + self._x += dt * (cy * fwd - sy * left) + self._y += dt * (sy * fwd + cy * left) + + now = time.time() + quat = Quaternion.from_euler(Vector3(0.0, 0.0, self._yaw)) + self.odometry._transport.publish( + Odometry( + ts=now, + frame_id="map", + child_frame_id="sensor", + pose=Pose( + position=[self._x, self._y, self._z], + orientation=[quat.x, quat.y, quat.z, quat.w], + ), + twist=Twist( + linear=[fwd, left, 0.0], + angular=[0.0, 0.0, yr], + ), + ) + ) + self.tf.publish( + Transform( + translation=Vector3(self._x, self._y, self._z), + rotation=quat, + frame_id="map", + child_frame_id="sensor", + ts=now, + ), + ) + + elapsed = time.monotonic() - t0 + if elapsed < dt: + time.sleep(dt - elapsed) + + def _sensor_loop(self) -> None: + dt = 1.0 / self.config.rate + while self._running: + now = time.time() + cloud_data = _make_room_cloud(self._x, self._y) + 
self.registered_scan._transport.publish( + PointCloud2.from_numpy(cloud_data, frame_id="map", timestamp=now) + ) + time.sleep(dt) + + +@dataclass +class Collector: + """Thread-safe message collector.""" + + waypoints: list = field(default_factory=list) + paths: list = field(default_factory=list) + cmd_vels: list = field(default_factory=list) + terrain_maps: list = field(default_factory=list) + lock: threading.Lock = field(default_factory=threading.Lock) + + +# Test + + +def test_explore_produces_movement(): + """End-to-end: TARE planner drives robot movement via full pipeline.""" + from dimos.core.blueprints import autoconnect + from dimos.msgs.geometry_msgs.PointStamped import PointStamped + from dimos.msgs.nav_msgs.Path import Path as NavPath + from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner + from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower + from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import ( + SensorScanGeneration, + ) + from dimos.navigation.smartnav.modules.tare_planner.tare_planner import TarePlanner + from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis + + collector = Collector() + + blueprint = autoconnect( + MockVehicle.blueprint(), + SensorScanGeneration.blueprint(), + TerrainAnalysis.blueprint(), + LocalPlanner.blueprint( + extra_args=["--autonomyMode", "true"], + ), + PathFollower.blueprint( + extra_args=["--autonomyMode", "true"], + ), + TarePlanner.blueprint(), + ) + + coordinator = blueprint.build() + + # Subscribe to outputs + tare = coordinator.get_instance(TarePlanner) + planner = coordinator.get_instance(LocalPlanner) + follower = coordinator.get_instance(PathFollower) + coordinator.get_instance(MockVehicle) + terrain = coordinator.get_instance(TerrainAnalysis) + + def _on_wp(msg: PointStamped) -> None: + with collector.lock: + collector.waypoints.append((msg.x, msg.y, msg.z)) + + def 
_on_terrain(msg: PointCloud2) -> None: + with collector.lock: + collector.terrain_maps.append(True) + + def _on_path(msg: NavPath) -> None: + with collector.lock: + collector.paths.append(msg) + + def _on_cmd(msg: Twist) -> None: + with collector.lock: + collector.cmd_vels.append((msg.linear.x, msg.linear.y, msg.angular.z)) + + tare.way_point._transport.subscribe(_on_wp) + planner.path._transport.subscribe(_on_path) + follower.cmd_vel._transport.subscribe(_on_cmd) + terrain.terrain_map._transport.subscribe(_on_terrain) + + try: + coordinator.start() + + # Wait for pipeline outputs — TARE needs several scan cycles + deadline = time.monotonic() + 30.0 + while time.monotonic() < deadline: + with collector.lock: + has_terrain = len(collector.terrain_maps) > 0 + has_waypoints = len(collector.waypoints) > 0 + has_paths = len(collector.paths) > 0 + has_cmds = len(collector.cmd_vels) > 0 + if has_terrain and has_waypoints and has_paths and has_cmds: + break + time.sleep(0.5) + + # Let movement accumulate + time.sleep(5.0) + + # -- Assertions -- + with collector.lock: + assert len(collector.terrain_maps) > 0, "TerrainAnalysis never produced terrain_map" + + assert len(collector.waypoints) > 0, "TarePlanner never produced a waypoint" + + assert len(collector.paths) > 0, ( + "LocalPlanner never produced a path — check that autonomyMode=true is being passed" + ) + + nonzero_cmds = [ + (vx, vy, wz) + for vx, vy, wz in collector.cmd_vels + if abs(vx) > 0.01 or abs(vy) > 0.01 or abs(wz) > 0.01 + ] + assert len(nonzero_cmds) > 0, ( + f"PathFollower produced {len(collector.cmd_vels)} cmd_vels " + f"but ALL were zero — robot is not moving" + ) + + finally: + coordinator.stop() diff --git a/dimos/navigation/smartnav/tests/test_full_nav_loop.py b/dimos/navigation/smartnav/tests/test_full_nav_loop.py new file mode 100644 index 0000000000..706456483d --- /dev/null +++ b/dimos/navigation/smartnav/tests/test_full_nav_loop.py @@ -0,0 +1,210 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration test: full navigation closed loop. + +Verifies that synthetic lidar + odometry data flows through the entire +SmartNav pipeline and produces autonomous navigation output: + + [MockSensor] → registered_scan + odometry + → [SensorScanGeneration] → sensor_scan + → [TerrainAnalysis] → terrain_map + → [LocalPlanner] → path + → [PathFollower] → cmd_vel + +Requires built C++ native binaries (nix build). +""" + +from __future__ import annotations + +from pathlib import Path +import platform +import threading +import time +from typing import Any + +import numpy as np +import pytest + +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import Out +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 + +# Skip conditions +_NATIVE_DIR = Path(__file__).resolve().parent.parent +_HAS_BINARIES = all( + (_NATIVE_DIR / d / "bin" / name).exists() + for d, name in [ + ("result-terrain-analysis", "terrain_analysis"), + ("result-local-planner", "local_planner"), + ("result-path-follower", "path_follower"), + ] +) +_IS_LINUX_X86 = platform.system() == "Linux" and platform.machine() 
in ("x86_64", "AMD64") + +pytestmark = [ + pytest.mark.slow, + pytest.mark.skipif(not _IS_LINUX_X86, reason="Native modules require Linux x86_64"), + pytest.mark.skipif(not _HAS_BINARIES, reason="Native binaries not built"), +] + + +def _make_flat_ground_cloud() -> np.ndarray: + """Nx3 flat ground cloud around origin.""" + step = 2.0 + xs = np.arange(-10, 10, step) + ys = np.arange(-10, 10, step) + xx, yy = np.meshgrid(xs, ys) + return np.column_stack([xx.ravel(), yy.ravel(), np.zeros(xx.size)]).astype(np.float32) + + +class MockSensorConfig(ModuleConfig): + rate: float = 5.0 + + +class MockSensor(Module[MockSensorConfig]): + """Publishes synthetic lidar + odometry at fixed rate.""" + + default_config = MockSensorConfig + registered_scan: Out[PointCloud2] + odometry: Out[Odometry] + + def __init__(self, **kwargs): # type: ignore[no-untyped-def] + super().__init__(**kwargs) + self._running = False + self._thread: threading.Thread | None = None + + def __getstate__(self) -> dict[str, Any]: + state = super().__getstate__() + state.pop("_thread", None) + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + super().__setstate__(state) + self._thread = None + + def start(self) -> None: + self._running = True + self._thread = threading.Thread(target=self._loop, daemon=True) + self._thread.start() + + def stop(self) -> None: + self._running = False + if self._thread: + self._thread.join(timeout=3.0) + super().stop() + + def _loop(self) -> None: + dt = 1.0 / self.config.rate + while self._running: + now = time.time() + self.registered_scan._transport.publish( + PointCloud2.from_numpy(_make_flat_ground_cloud(), frame_id="map", timestamp=now) + ) + quat = Quaternion(0.0, 0.0, 0.0, 1.0) + self.odometry._transport.publish( + Odometry( + ts=now, + frame_id="map", + child_frame_id="sensor", + pose=Pose( + position=[0.0, 0.0, 0.75], + orientation=[quat.x, quat.y, quat.z, quat.w], + ), + twist=Twist(linear=[0.0, 0.0, 0.0], angular=[0.0, 0.0, 0.0]), + ) + ) + 
self.tf.publish( + Transform( + translation=Vector3(0.0, 0.0, 0.75), + rotation=quat, + frame_id="map", + child_frame_id="sensor", + ts=now, + ), + ) + time.sleep(dt) + + +def test_full_nav_closed_loop(): + """End-to-end: synthetic data -> terrain_map + path + cmd_vel produced.""" + from dimos.core.blueprints import autoconnect + from dimos.msgs.geometry_msgs.PointStamped import PointStamped + from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner + from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower + from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import ( + SensorScanGeneration, + ) + from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis + + terrain_maps: list = [] + paths: list = [] + cmd_vels: list = [] + lock = threading.Lock() + + blueprint = autoconnect( + MockSensor.blueprint(), + SensorScanGeneration.blueprint(), + TerrainAnalysis.blueprint(), + LocalPlanner.blueprint(extra_args=["--autonomyMode", "true"]), + PathFollower.blueprint(extra_args=["--autonomyMode", "true"]), + ) + + coordinator = blueprint.build() + + terrain = coordinator.get_instance(TerrainAnalysis) + planner = coordinator.get_instance(LocalPlanner) + follower = coordinator.get_instance(PathFollower) + + terrain.terrain_map._transport.subscribe( + lambda m: (lock.acquire(), terrain_maps.append(m), lock.release()) + ) + planner.path._transport.subscribe(lambda m: (lock.acquire(), paths.append(m), lock.release())) + follower.cmd_vel._transport.subscribe( + lambda m: (lock.acquire(), cmd_vels.append(m), lock.release()) + ) + + # Send waypoint after warmup + def _send_waypoint() -> None: + time.sleep(3.0) + lp = coordinator.get_instance(LocalPlanner) + wp = PointStamped(x=5.0, y=0.0, z=0.0, frame_id="map") + lp.way_point._transport.publish(wp) + + wp_thread = threading.Thread(target=_send_waypoint, daemon=True) + wp_thread.start() + + try: + 
coordinator.start() + + deadline = time.monotonic() + 30.0 + while time.monotonic() < deadline: + with lock: + done = len(terrain_maps) > 0 and len(paths) > 0 and len(cmd_vels) > 0 + if done: + break + time.sleep(0.5) + + with lock: + assert len(terrain_maps) > 0, "TerrainAnalysis produced no terrain_map" + assert len(paths) > 0, "LocalPlanner produced no path" + assert len(cmd_vels) > 0, "PathFollower produced no cmd_vel" + finally: + coordinator.stop() diff --git a/dimos/navigation/smartnav/tests/test_nav_loop.py b/dimos/navigation/smartnav/tests/test_nav_loop.py new file mode 100644 index 0000000000..6d105a645e --- /dev/null +++ b/dimos/navigation/smartnav/tests/test_nav_loop.py @@ -0,0 +1,177 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration test: verify blueprint construction and autoconnect wiring. 
+ +Tests the real blueprint.build() path which involves: +- Module pickling across worker processes +- Transport assignment via autoconnect +- Stream wiring by name+type matching +""" + +import time + +import numpy as np + +from dimos.core.blueprints import autoconnect +from dimos.core.transport import LCMTransport +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import ( + SensorScanGeneration, +) +from dimos.navigation.smartnav.modules.tui_control.tui_control import TUIControlModule +from dimos.simulation.unity.module import UnityBridgeModule + + +class TestBlueprintConstruction: + """Test that autoconnect produces a valid blueprint without errors.""" + + def test_python_modules_autoconnect(self): + """autoconnect on Python-only modules should not raise.""" + bp = autoconnect( + UnityBridgeModule.blueprint(sim_rate=10.0), + SensorScanGeneration.blueprint(), + TUIControlModule.blueprint(publish_rate=1.0), + ) + # Should have 3 module atoms + assert len(bp.blueprints) == 3 + + def test_full_blueprint_autoconnect(self): + """Full simulation blueprint including NativeModules should not raise.""" + from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner + from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower + from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import ( + TerrainAnalysis, + ) + + bp = autoconnect( + UnityBridgeModule.blueprint(sim_rate=10.0), + SensorScanGeneration.blueprint(), + TerrainAnalysis.blueprint(), + LocalPlanner.blueprint(), + PathFollower.blueprint(), + TUIControlModule.blueprint(publish_rate=1.0), + ) + assert len(bp.blueprints) == 6 + + def 
test_no_type_conflicts(self): + """Blueprint should detect no type conflicts among streams.""" + from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner + from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower + from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import ( + TerrainAnalysis, + ) + + bp = autoconnect( + UnityBridgeModule.blueprint(sim_rate=10.0), + SensorScanGeneration.blueprint(), + TerrainAnalysis.blueprint(), + LocalPlanner.blueprint(), + PathFollower.blueprint(), + TUIControlModule.blueprint(publish_rate=1.0), + ) + # _verify_no_name_conflicts is called during build() -- test it directly + bp._verify_no_name_conflicts() # should not raise + + +class TestEndToEndDataFlow: + """Test data flowing through real LCM transports between modules.""" + + def test_odom_flows_from_sim_to_scan_gen(self): + """Odometry published by UnityBridge should reach SensorScanGeneration.""" + sim = UnityBridgeModule(sim_rate=200.0) + scan_gen = SensorScanGeneration() + + # Shared transport (simulates what autoconnect does) + odom_transport = LCMTransport("/e2e_odom", Odometry) + sim.odometry._transport = odom_transport + scan_gen.odometry._transport = odom_transport + + # Wire dummy transports for other ports so start() doesn't fail + scan_gen.registered_scan._transport = LCMTransport("/e2e_regscan", PointCloud2) + scan_gen.sensor_scan._transport = LCMTransport("/e2e_sensorscan", PointCloud2) + scan_gen.odometry_at_scan._transport = LCMTransport("/e2e_odom_at_scan", Odometry) + + # Start scan gen (subscribes to odom transport) + scan_gen.start() + + # Publish odometry through sim's transport + quat = Quaternion.from_euler(Vector3(0.0, 0.0, 0.0)) + odom = Odometry( + ts=time.time(), + frame_id="map", + child_frame_id="sensor", + pose=Pose( + position=[5.0, 3.0, 0.75], + orientation=[quat.x, quat.y, quat.z, quat.w], + ), + ) + odom_transport.publish(odom) + time.sleep(0.1) + + # 
SensorScanGeneration should have received it + assert scan_gen._latest_odom is not None + assert abs(scan_gen._latest_odom.x - 5.0) < 0.01 + + def test_full_scan_transform_chain(self): + """Odom + cloud in -> sensor-frame cloud out, all via transports.""" + scan_gen = SensorScanGeneration() + + odom_t = LCMTransport("/chain_odom", Odometry) + regscan_t = LCMTransport("/chain_regscan", PointCloud2) + sensorscan_t = LCMTransport("/chain_sensorscan", PointCloud2) + odom_at_t = LCMTransport("/chain_odom_at", Odometry) + + scan_gen.odometry._transport = odom_t + scan_gen.registered_scan._transport = regscan_t + scan_gen.sensor_scan._transport = sensorscan_t + scan_gen.odometry_at_scan._transport = odom_at_t + + results = [] + sensorscan_t.subscribe(lambda msg: results.append(msg)) + + scan_gen.start() + + # Publish odometry at (2, 0, 0), no rotation + quat = Quaternion.from_euler(Vector3(0.0, 0.0, 0.0)) + odom_t.publish( + Odometry( + ts=time.time(), + frame_id="map", + child_frame_id="sensor", + pose=Pose( + position=[2.0, 0.0, 0.0], + orientation=[quat.x, quat.y, quat.z, quat.w], + ), + ) + ) + time.sleep(0.05) + + # Publish a world-frame cloud with a point at (5, 0, 0) + cloud = PointCloud2.from_numpy( + np.array([[5.0, 0.0, 0.0]], dtype=np.float32), + frame_id="map", + timestamp=time.time(), + ) + regscan_t.publish(cloud) + time.sleep(0.2) + + # In sensor frame, (5,0,0) - (2,0,0) = (3,0,0) + assert len(results) >= 1 + pts, _ = results[0].as_numpy() + assert abs(pts[0][0] - 3.0) < 0.1 diff --git a/dimos/navigation/smartnav/tests/test_nav_loop_drive.py b/dimos/navigation/smartnav/tests/test_nav_loop_drive.py new file mode 100644 index 0000000000..7e5a371e97 --- /dev/null +++ b/dimos/navigation/smartnav/tests/test_nav_loop_drive.py @@ -0,0 +1,332 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration test: robot navigates a multi-waypoint loop. + +Sends waypoints in a square pattern and verifies the robot actually +moves toward each one. Prints detailed odometry + cmd_vel diagnostics. + +This is the definitive test that the nav stack works end-to-end. +""" + +from __future__ import annotations + +import math +from pathlib import Path +import platform +import threading +import time +from typing import Any + +import numpy as np +import pytest + +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import In, Out +from dimos.msgs.geometry_msgs.PointStamped import PointStamped +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 + +_NATIVE_DIR = Path(__file__).resolve().parent.parent +_HAS_BINARIES = all( + (_NATIVE_DIR / d / "bin" / name).exists() + for d, name in [ + ("result-terrain-analysis", "terrain_analysis"), + ("result-local-planner", "local_planner"), + ("result-path-follower", "path_follower"), + ] +) +_IS_LINUX_X86 = platform.system() == "Linux" and platform.machine() in ("x86_64", "AMD64") + +pytestmark = [ + pytest.mark.slow, + pytest.mark.skipif(not _IS_LINUX_X86, reason="Native modules require Linux x86_64"), + pytest.mark.skipif(not _HAS_BINARIES, reason="Native binaries not 
built"), +] + + +def _make_ground(rx: float, ry: float) -> np.ndarray: + """Flat ground cloud around robot. Nx3.""" + step = 1.5 + xs = np.arange(rx - 15, rx + 15, step) + ys = np.arange(ry - 15, ry + 15, step) + xx, yy = np.meshgrid(xs, ys) + return np.column_stack([xx.ravel(), yy.ravel(), np.zeros(xx.size)]).astype(np.float32) + + +class VehicleConfig(ModuleConfig): + sensor_rate: float = 5.0 + sim_rate: float = 50.0 + + +class Vehicle(Module[VehicleConfig]): + """Kinematic sim vehicle with public position for test inspection.""" + + default_config = VehicleConfig + cmd_vel: In[Twist] + registered_scan: Out[PointCloud2] + odometry: Out[Odometry] + + def __init__(self, **kw): # type: ignore[no-untyped-def] + super().__init__(**kw) + self.x = 0.0 + self.y = 0.0 + self.z = 0.75 + self.yaw = 0.0 + self._fwd = 0.0 + self._left = 0.0 + self._yr = 0.0 + self._lock = threading.Lock() + self._running = False + self._threads: list[threading.Thread] = [] + + def __getstate__(self) -> dict[str, Any]: + s = super().__getstate__() + for k in ("_lock", "_threads"): + s.pop(k, None) + return s + + def __setstate__(self, s: dict) -> None: + super().__setstate__(s) + self._lock = threading.Lock() + self._threads = [] + + def start(self) -> None: + self.cmd_vel._transport.subscribe(self._on_cmd) + self._running = True + for fn in (self._sim_loop, self._sensor_loop): + t = threading.Thread(target=fn, daemon=True) + t.start() + self._threads.append(t) + + def stop(self) -> None: + self._running = False + for t in self._threads: + t.join(timeout=3) + super().stop() + + def _on_cmd(self, tw: Twist) -> None: + with self._lock: + self._fwd = tw.linear.x + self._left = tw.linear.y + self._yr = tw.angular.z + + def _sim_loop(self) -> None: + dt = 1.0 / self.config.sim_rate + while self._running: + t0 = time.monotonic() + with self._lock: + fwd, left, yr = self._fwd, self._left, self._yr + self.yaw += dt * yr + cy, sy = math.cos(self.yaw), math.sin(self.yaw) + self.x += dt * (cy * fwd - sy 
def test_multi_waypoint_loop():
    """Send 4 waypoints in a square, verify robot moves toward each.

    Builds the Vehicle -> SensorScanGeneration -> TerrainAnalysis ->
    LocalPlanner -> PathFollower pipeline, publishes the four corners of a
    square one after another, and gives the robot up to 8 s per waypoint to
    get within a distance of 1.0 of it.  Missing a waypoint is only logged.

    The hard assertions are that cmd_vel messages were published, that at
    least one of them was non-zero, and that the odometry track spans more
    than 1.0 along x or y.
    """
    from dimos.core.blueprints import autoconnect
    from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner
    from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower
    from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import (
        SensorScanGeneration,
    )
    from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis

    # Collect cmd_vel to verify non-zero commands
    cmd_log: list[tuple[float, float, float]] = []
    cmd_lock = threading.Lock()

    # Also track path sizes to diagnose stop paths
    path_sizes: list[int] = []
    path_lock = threading.Lock()

    # We can't access vehicle._x directly (Actor proxy blocks private attrs).
    # Instead, subscribe to odometry and track position ourselves.
    positions: list[tuple[float, float]] = []
    pos_lock = threading.Lock()

    def _on_cmd(msg) -> None:
        # `with` releases the lock even if reading the message raises.
        with cmd_lock:
            cmd_log.append((msg.linear.x, msg.linear.y, msg.angular.z))

    def _on_path(msg) -> None:
        with path_lock:
            path_sizes.append(len(msg.poses))

    def _on_odom(msg: Odometry) -> None:
        with pos_lock:
            positions.append((msg.pose.position.x, msg.pose.position.y))

    def _last_position() -> tuple[float, float]:
        """Return the newest odometry (x, y), or the origin if none arrived yet."""
        with pos_lock:
            return positions[-1] if positions else (0.0, 0.0)

    # Both native modules get the same autonomy/speed overrides.
    speed_args = ["--autonomyMode", "true", "--maxSpeed", "2.0", "--autonomySpeed", "2.0"]

    blueprint = autoconnect(
        Vehicle.blueprint(),
        SensorScanGeneration.blueprint(),
        TerrainAnalysis.blueprint(),
        LocalPlanner.blueprint(extra_args=list(speed_args)),
        PathFollower.blueprint(
            extra_args=[*speed_args, "--maxAccel", "4.0", "--slowDwnDisThre", "0.2"]
        ),
    )
    coord = blueprint.build()

    planner = coord.get_instance(LocalPlanner)
    follower = coord.get_instance(PathFollower)
    vehicle_actor = coord.get_instance(Vehicle)

    follower.cmd_vel._transport.subscribe(_on_cmd)
    planner.path._transport.subscribe(_on_path)
    vehicle_actor.odometry._transport.subscribe(_on_odom)

    waypoints = [(5.0, 0.0), (5.0, 5.0), (0.0, 5.0), (0.0, 0.0)]

    try:
        # Started inside the try so a failing start still reaches coord.stop().
        coord.start()

        # Wait for C++ modules to initialize
        print("[test] Waiting 3s for modules to start...")
        time.sleep(3.0)

        for i, (wx, wy) in enumerate(waypoints):
            wp = PointStamped(x=wx, y=wy, z=0.0, frame_id="map")
            planner.way_point._transport.publish(wp)
            print(f"[test] Sent waypoint {i}: ({wx}, {wy})")

            # Drive toward waypoint for up to 8 seconds
            t0 = time.monotonic()
            while time.monotonic() - t0 < 8.0:
                time.sleep(0.5)
                cx, cy = _last_position()
                dist = math.hypot(cx - wx, cy - wy)
                if dist < 1.0:
                    print(f"[test] Reached wp{i} at ({cx:.2f}, {cy:.2f}), dist={dist:.2f}")
                    break
            else:
                # Loop ran out of time without a break: report where we ended up.
                cx, cy = _last_position()
                dist = math.hypot(cx - wx, cy - wy)
                print(f"[test] Timeout wp{i}: pos=({cx:.2f}, {cy:.2f}), dist={dist:.2f}")

        # Final position summary
        fx, fy = _last_position()
        print(f"[test] Final position: ({fx:.2f}, {fy:.2f})")

        # Check we actually moved
        with pos_lock:
            all_x = [p[0] for p in positions]
            all_y = [p[1] for p in positions]
        if all_x:
            x_range = max(all_x) - min(all_x)
            y_range = max(all_y) - min(all_y)
            print(
                f"[test] Position range: x=[{min(all_x):.2f}, {max(all_x):.2f}] y=[{min(all_y):.2f}, {max(all_y):.2f}]"
            )
        else:
            # min()/max() of an empty list raise ValueError, which would hide
            # the informative assertion messages below.
            x_range = y_range = 0.0
            print("[test] Position range: no odometry received")

        with cmd_lock:
            total_cmds = len(cmd_log)
            nonzero = sum(
                1 for vx, vy, wz in cmd_log if abs(vx) > 0.01 or abs(vy) > 0.01 or abs(wz) > 0.01
            )
        print(f"[test] cmd_vel: {total_cmds} total, {nonzero} non-zero")

        with path_lock:
            n_paths = len(path_sizes)
            stop_paths = sum(1 for s in path_sizes if s <= 1)
            real_paths = sum(1 for s in path_sizes if s > 1)
            avg_len = sum(path_sizes) / len(path_sizes) if path_sizes else 0
        print(
            f"[test] paths: {n_paths} total, {real_paths} real (>1 pose), {stop_paths} stop (<=1 pose), avg_len={avg_len:.1f}"
        )

        # Hard assertions
        assert total_cmds > 0, "No cmd_vel messages at all"
        assert nonzero > 0, f"All {total_cmds} cmd_vel were zero — autonomyMode not working"
        assert x_range > 1.0 or y_range > 1.0, (
            f"Robot barely moved: x_range={x_range:.2f}, y_range={y_range:.2f}. "
            f"Non-zero cmds: {nonzero}/{total_cmds}"
        )

    finally:
        coord.stop()
class TestAllNativeModulePaths:
    """Every NativeModule in smartnav must have valid, existing paths."""

    @pytest.fixture(
        params=[
            "terrain_analysis",
            "local_planner",
            "path_follower",
            "far_planner",
            "tare_planner",
            "arise_slam",
        ]
    )
    def native_module(self, request):
        """Parametrized fixture that yields each native module class.

        Imports ``dimos.navigation.smartnav.modules.<name>.<name>`` and returns
        the first attribute (in ``dir()`` order) that is a strict subclass of
        ``NativeModule``; fails the test if the module exposes none.
        """
        name = request.param
        mod = importlib.import_module(f"dimos.navigation.smartnav.modules.{name}.{name}")
        # The class name varies; find the NativeModule subclass
        for attr_name in dir(mod):
            attr = getattr(mod, attr_name)
            if (
                isinstance(attr, type)
                and issubclass(attr, NativeModule)
                and attr is not NativeModule
            ):
                return attr
        pytest.fail(f"No NativeModule subclass found in {name}")

    @staticmethod
    def _check_resolved(native_module, check):
        """Instantiate the module, resolve its paths, run ``check(config)``.

        ``_resolve_paths()`` runs inside the ``try`` so the instance is stopped
        even when path resolution itself raises.
        """
        m = native_module()
        try:
            m._resolve_paths()
            check(m.config)
        finally:
            m.stop()

    def test_cwd_exists(self, native_module):
        """The resolved working directory exists on disk."""

        def check(config):
            assert Path(config.cwd).exists()

        self._check_resolved(native_module, check)

    def test_executable_exists(self, native_module):
        """The resolved executable exists on disk."""

        def check(config):
            assert Path(config.executable).exists()

        self._check_resolved(native_module, check)

    def test_cwd_is_smartnav_root(self, native_module):
        """The resolved working directory contains a CMakeLists.txt."""

        def check(config):
            cwd = Path(config.cwd).resolve()
            assert (cwd / "CMakeLists.txt").exists()

        self._check_resolved(native_module, check)


class TestDataFiles:
    """The path-library data files must be present in the fetched data set."""

    def test_path_data_exists(self):
        """``get_data("smartnav_paths")`` provides all three ``.ply`` files."""
        from dimos.utils.data import get_data

        data = get_data("smartnav_paths")
        for f in ["startPaths.ply", "pathList.ply", "paths.ply"]:
            assert (data / f).exists(), f"Missing data file: {data / f}"


class TestBlueprintImport:
    """Importing the navigation blueprints must succeed and yield an object."""

    def test_g1_nav_sim_blueprint_importable(self):
        from dimos.robot.unitree.g1.blueprints.navigation.unitree_g1_nav_sim import (
            unitree_g1_nav_sim,
        )

        assert unitree_g1_nav_sim is not None

    def test_simulation_blueprint_importable(self):
        from dimos.navigation.smartnav.blueprints.simulation import simulation_blueprint

        assert simulation_blueprint is not None
+""" + +from __future__ import annotations + +import math +import time + +import numpy as np +import pytest +from scipy.spatial.transform import Rotation + +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 + +try: + from dimos.navigation.smartnav.modules.pgo.pgo_reference import PGOConfig, SimplePGOReference + + _HAS_PGO_DEPS = True +except ImportError: + _HAS_PGO_DEPS = False + +pytestmark = pytest.mark.skipif(not _HAS_PGO_DEPS, reason="gtsam not installed") + +# ─── Helpers ───────────────────────────────────────────────────────────────── + + +def make_rotation(yaw_deg: float) -> np.ndarray: + return Rotation.from_euler("z", yaw_deg, degrees=True).as_matrix() + + +def make_structured_cloud(center: np.ndarray, n_points: int = 500, seed: int = 42) -> np.ndarray: + """Create a sphere-surface point cloud around a center.""" + rng = np.random.default_rng(seed) + phi = rng.uniform(0, 2 * np.pi, n_points) + theta = rng.uniform(0, np.pi, n_points) + r = 2.0 + x = r * np.sin(theta) * np.cos(phi) + center[0] + y = r * np.sin(theta) * np.sin(phi) + center[1] + z = r * np.cos(theta) + center[2] + return np.column_stack([x, y, z]) + + +def make_random_cloud( + center: np.ndarray, n_points: int = 200, spread: float = 1.0, seed: int | None = None +) -> np.ndarray: + rng = np.random.default_rng(seed) + return center + rng.normal(0, spread, (n_points, 3)) + + +def drive_trajectory( + pgo: SimplePGOReference, + waypoints: list[np.ndarray], + step: float = 0.4, + time_per_step: float = 1.0, + cloud_seed_base: int = 0, +) -> None: + """Drive a trajectory through a list of waypoints, adding keyframes.""" + t = 0.0 + pos = waypoints[0].copy() + for i in range(1, len(waypoints)): + direction = waypoints[i] - waypoints[i - 1] + dist = np.linalg.norm(direction) + if dist < 1e-6: + continue + direction_norm = direction / dist + yaw = math.degrees(math.atan2(direction_norm[1], direction_norm[0])) + r = make_rotation(yaw) + n_steps = int(dist / step) + + for s in range(n_steps): + 
pos = waypoints[i - 1] + direction_norm * step * (s + 1) + cloud = make_structured_cloud( + np.zeros(3), n_points=200, seed=(cloud_seed_base + int(t)) % 10000 + ) + added = pgo.add_key_pose(r, pos, t, cloud) + if added: + pgo.search_for_loop_pairs() + pgo.smooth_and_update() + t += time_per_step + + +# ─── Global Map Accumulation Tests ─────────────────────────────────────────── + + +class TestGlobalMapAccumulation: + """Test that PGO produces a valid global map from keyframes.""" + + def test_global_map_contains_all_keyframes(self): + """Global map should contain transformed points from every keyframe.""" + config = PGOConfig( + key_pose_delta_trans=0.3, + global_map_voxel_size=0.0, # No downsampling + ) + pgo = SimplePGOReference(config) + + n_keyframes = 10 + pts_per_frame = 100 + for i in range(n_keyframes): + pos = np.array([i * 1.0, 0.0, 0.0]) + cloud = make_random_cloud(np.zeros(3), n_points=pts_per_frame, seed=i) + pgo.add_key_pose(np.eye(3), pos, float(i), cloud) + pgo.smooth_and_update() + + assert len(pgo.key_poses) == n_keyframes + global_map = pgo.build_global_map(voxel_size=0.0) + assert len(global_map) == n_keyframes * pts_per_frame, ( + f"Expected {n_keyframes * pts_per_frame} points, got {len(global_map)}" + ) + + def test_global_map_points_are_in_world_frame(self): + """Points in the global map should be transformed to world coordinates.""" + config = PGOConfig( + key_pose_delta_trans=0.3, + global_map_voxel_size=0.0, + ) + pgo = SimplePGOReference(config) + + # Add keyframe at origin with cloud centered at body origin + cloud_body = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) + pgo.add_key_pose(np.eye(3), np.array([10.0, 20.0, 0.0]), 0.0, cloud_body) + pgo.smooth_and_update() + + global_map = pgo.build_global_map(voxel_size=0.0) + + # Points should be shifted by the keyframe position (10, 20, 0) + expected = cloud_body + np.array([10.0, 20.0, 0.0]) + np.testing.assert_allclose(global_map, expected, atol=1e-6) + + def 
test_global_map_with_rotation(self): + """Global map should correctly rotate body-frame points.""" + config = PGOConfig( + key_pose_delta_trans=0.3, + global_map_voxel_size=0.0, + ) + pgo = SimplePGOReference(config) + + # 90 degree yaw rotation + r_90 = make_rotation(90.0) + cloud_body = np.array([[1.0, 0.0, 0.0]]) # Point along body x-axis + pgo.add_key_pose(r_90, np.zeros(3), 0.0, cloud_body) + pgo.smooth_and_update() + + global_map = pgo.build_global_map(voxel_size=0.0) + + # After 90 deg yaw, body x-axis → world y-axis + np.testing.assert_allclose(global_map[0, 0], 0.0, atol=1e-6) + np.testing.assert_allclose(global_map[0, 1], 1.0, atol=1e-6) + np.testing.assert_allclose(global_map[0, 2], 0.0, atol=1e-6) + + def test_global_map_grows_with_trajectory(self): + """Global map should grow as more keyframes are added.""" + config = PGOConfig(key_pose_delta_trans=0.3, global_map_voxel_size=0.0) + pgo = SimplePGOReference(config) + + sizes = [] + for i in range(20): + pos = np.array([i * 0.5, 0.0, 0.0]) + cloud = make_random_cloud(np.zeros(3), n_points=50, seed=i) + pgo.add_key_pose(np.eye(3), pos, float(i), cloud) + pgo.smooth_and_update() + sizes.append(len(pgo.build_global_map(voxel_size=0.0))) + + # Map should be monotonically growing + for j in range(1, len(sizes)): + assert sizes[j] >= sizes[j - 1], f"Map shrunk: {sizes[j]} < {sizes[j - 1]} at step {j}" + + def test_global_map_voxel_downsampling(self): + """Downsampled global map should have fewer points.""" + config = PGOConfig(key_pose_delta_trans=0.3) + pgo = SimplePGOReference(config) + + for i in range(10): + pos = np.array([i * 1.0, 0.0, 0.0]) + cloud = make_random_cloud(np.zeros(3), n_points=200, seed=i) + pgo.add_key_pose(np.eye(3), pos, float(i), cloud) + pgo.smooth_and_update() + + map_full = pgo.build_global_map(voxel_size=0.0) + map_ds = pgo.build_global_map(voxel_size=0.5) + + assert len(map_ds) < len(map_full), ( + f"Downsampled map ({len(map_ds)}) should be smaller than full ({len(map_full)})" + ) 
+ assert len(map_ds) > 0 + + +# ─── Loop Closure Global Map Tests ────────────────────────────────────────── + + +class TestLoopClosureGlobalMap: + """Test that loop closure correctly updates the global map.""" + + def test_global_map_updates_after_loop_closure(self): + """After loop closure, global map positions should be corrected.""" + config = PGOConfig( + key_pose_delta_trans=0.4, + key_pose_delta_deg=10.0, + loop_search_radius=15.0, + loop_time_thresh=30.0, + loop_score_thresh=2.0, # Very relaxed for synthetic data + loop_submap_half_range=3, + submap_resolution=0.2, + min_loop_detect_duration=0.0, + global_map_voxel_size=0.0, + max_icp_iterations=30, + max_icp_correspondence_dist=20.0, + ) + pgo = SimplePGOReference(config) + + # Drive a square trajectory + side = 20.0 + waypoints = [ + np.array([0.0, 0.0, 0.0]), + np.array([side, 0.0, 0.0]), + np.array([side, side, 0.0]), + np.array([0.0, side, 0.0]), + np.array([0.0, 0.0, 0.0]), # Return to start + ] + drive_trajectory(pgo, waypoints, step=0.4, time_per_step=1.0) + + # Should have accumulated keyframes + assert len(pgo.key_poses) > 20 + + # Build global map + global_map = pgo.build_global_map(voxel_size=0.0) + assert len(global_map) > 0 + + # If loop closure detected, verify map is consistent + if len(pgo.history_pairs) > 0: + # The start and end keyframe positions should be close + start_pos = pgo.key_poses[0].t_global + end_pos = pgo.key_poses[-1].t_global + # After loop closure correction + dist = np.linalg.norm(end_pos - start_pos) + assert dist < 15.0, f"After loop closure, start-end distance {dist:.2f}m is too large" + + def test_global_map_all_keyframes_present_after_loop(self): + """After loop closure, ALL keyframes should still be in the map.""" + config = PGOConfig( + key_pose_delta_trans=0.3, + loop_search_radius=15.0, + loop_time_thresh=20.0, + loop_score_thresh=2.0, + min_loop_detect_duration=0.0, + global_map_voxel_size=0.0, + max_icp_correspondence_dist=20.0, + ) + pgo = 
SimplePGOReference(config) + + pts_per_frame = 50 + n_poses = 0 + for i in range(40): + pos = np.array([i * 0.5, 0.0, 0.0]) + cloud = make_random_cloud(np.zeros(3), n_points=pts_per_frame, seed=i % 5) + added = pgo.add_key_pose(np.eye(3), pos, float(i), cloud) + if added: + pgo.smooth_and_update() + n_poses += 1 + + global_map = pgo.build_global_map(voxel_size=0.0) + expected_points = n_poses * pts_per_frame + assert len(global_map) == expected_points, ( + f"Expected {expected_points} points from {n_poses} keyframes, got {len(global_map)}" + ) + + +# ─── PointCloud2 Export Tests ──────────────────────────────────────────────── + + +class TestGlobalMapExport: + """Test that global map can be exported as valid PointCloud2.""" + + def test_export_as_pointcloud2(self): + """Global map numpy array should convert to valid PointCloud2.""" + config = PGOConfig(key_pose_delta_trans=0.3, global_map_voxel_size=0.0) + pgo = SimplePGOReference(config) + + for i in range(5): + pos = np.array([i * 1.0, 0.0, 0.0]) + cloud = make_random_cloud(np.zeros(3), n_points=100, seed=i) + pgo.add_key_pose(np.eye(3), pos, float(i), cloud) + pgo.smooth_and_update() + + global_map = pgo.build_global_map(voxel_size=0.1) + assert len(global_map) > 0 + + # Convert to PointCloud2 + pc2 = PointCloud2.from_numpy( + global_map.astype(np.float32), + frame_id="map", + timestamp=time.time(), + ) + + # Verify round-trip + points_back, _ = pc2.as_numpy() + assert points_back.shape[0] > 0 + assert points_back.shape[1] >= 3 + + def test_export_empty_map(self): + """Exporting an empty global map should not crash.""" + pgo = SimplePGOReference() + global_map = pgo.build_global_map() + assert len(global_map) == 0 + + def test_export_large_map(self): + """Test export with a larger accumulated map (many keyframes).""" + config = PGOConfig( + key_pose_delta_trans=0.3, + global_map_voxel_size=0.2, + ) + pgo = SimplePGOReference(config) + + for i in range(50): + pos = np.array([i * 0.5, 0.0, 0.0]) + cloud = 
make_random_cloud(np.zeros(3), n_points=200, seed=i) + pgo.add_key_pose(np.eye(3), pos, float(i), cloud) + pgo.smooth_and_update() + + global_map = pgo.build_global_map() + assert len(global_map) > 0 + + # Should be downsampled (less than 50 * 200 = 10000) + assert len(global_map) < 10000 + + # Convert to PointCloud2 + pc2 = PointCloud2.from_numpy( + global_map.astype(np.float32), + frame_id="map", + timestamp=time.time(), + ) + points_back, _ = pc2.as_numpy() + assert len(points_back) == len(global_map) + + def test_global_map_spatial_extent(self): + """Global map should span the spatial extent of the trajectory.""" + config = PGOConfig( + key_pose_delta_trans=0.3, + global_map_voxel_size=0.0, + ) + pgo = SimplePGOReference(config) + + # Drive 10 meters in x direction + for i in range(30): + pos = np.array([i * 0.5, 0.0, 0.0]) + cloud = make_random_cloud(np.zeros(3), n_points=50, spread=0.5, seed=i) + pgo.add_key_pose(np.eye(3), pos, float(i), cloud) + pgo.smooth_and_update() + + global_map = pgo.build_global_map(voxel_size=0.0) + + # Map x-range should roughly span trajectory + x_min = global_map[:, 0].min() + x_max = global_map[:, 0].max() + x_span = x_max - x_min + + # Should span close to the trajectory length (15m) +/- cloud spread + assert x_span > 10.0, f"X-span {x_span:.1f}m too narrow for 15m trajectory" + assert x_span < 25.0, f"X-span {x_span:.1f}m too wide" diff --git a/dimos/navigation/smartnav/tests/test_sim_pipeline.py b/dimos/navigation/smartnav/tests/test_sim_pipeline.py new file mode 100644 index 0000000000..819a4550fa --- /dev/null +++ b/dimos/navigation/smartnav/tests/test_sim_pipeline.py @@ -0,0 +1,229 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class TestModulePickling:
    """Every module must survive pickle round-trip (the deployment path)."""

    def test_sensor_scan_generation_pickles(self):
        m = SensorScanGeneration()
        m2 = pickle.loads(pickle.dumps(m))
        assert hasattr(m2, "_lock")
        assert m2._latest_odom is None

    def test_unity_bridge_pickles(self):
        m = UnityBridgeModule(sim_rate=200.0)
        m2 = pickle.loads(pickle.dumps(m))
        assert hasattr(m2, "_cmd_lock")
        assert m2._running is False

    def test_tui_control_pickles(self):
        m = TUIControlModule(max_speed=2.0)
        m2 = pickle.loads(pickle.dumps(m))
        assert hasattr(m2, "_lock")
        assert m2._fwd == 0.0

    def test_all_native_modules_pickle(self):
        """NativeModule wrappers must also pickle cleanly."""
        from dimos.navigation.smartnav.modules.far_planner.far_planner import FarPlanner
        from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner
        from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower
        from dimos.navigation.smartnav.modules.tare_planner.tare_planner import TarePlanner
        from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import (
            TerrainAnalysis,
        )

        for cls in [TerrainAnalysis, LocalPlanner, PathFollower, FarPlanner, TarePlanner]:
            m = cls()
            m2 = pickle.loads(pickle.dumps(m))
            assert type(m2) is cls, f"{cls.__name__} failed pickle round-trip"


class TestTransportWiring:
    """Test that modules publish/subscribe through real LCM transports."""

    def test_unity_bridge_publishes_odometry_via_transport(self):
        """UnityBridge sim loop should publish through _transport, not .publish()."""
        m = UnityBridgeModule(sim_rate=200.0)

        # Wire a real LCM transport to the odometry output
        transport = LCMTransport("/_test/smartnav/odom", Odometry)
        m.odometry._transport = transport

        received = []
        transport.subscribe(received.append)

        # Simulate one odometry publish (same code path as _sim_loop)
        quat = Quaternion.from_euler(Vector3(0.0, 0.0, 0.0))
        odom = Odometry(
            ts=time.time(),
            frame_id="map",
            child_frame_id="sensor",
            pose=Pose(
                position=[1.0, 2.0, 0.75],
                orientation=[quat.x, quat.y, quat.z, quat.w],
            ),
        )
        m.odometry._transport.publish(odom)

        # LCM transport delivers asynchronously -- give it a moment
        time.sleep(0.1)
        assert len(received) >= 1
        assert abs(received[0].x - 1.0) < 0.01

    def test_sensor_scan_subscribes_and_publishes_via_transport(self):
        """SensorScanGeneration should work entirely through transports."""
        m = SensorScanGeneration()

        # Wire transports (topic string must NOT include #type suffix -- type is the 2nd arg)
        odom_transport = LCMTransport("/_test/smartnav/scan_gen/odom", Odometry)
        scan_in_transport = LCMTransport("/_test/smartnav/scan_gen/registered_scan", PointCloud2)
        scan_out_transport = LCMTransport("/_test/smartnav/scan_gen/sensor_scan", PointCloud2)
        odom_out_transport = LCMTransport("/_test/smartnav/scan_gen/odom_at_scan", Odometry)

        m.odometry._transport = odom_transport
        m.registered_scan._transport = scan_in_transport
        m.sensor_scan._transport = scan_out_transport
        m.odometry_at_scan._transport = odom_out_transport

        # Start the module (subscribes via transport)
        m.start()
        try:
            # Collect outputs
            scan_results = []
            scan_out_transport.subscribe(scan_results.append)

            # Publish odometry
            quat = Quaternion.from_euler(Vector3(0.0, 0.0, 0.0))
            odom = Odometry(
                ts=time.time(),
                frame_id="map",
                child_frame_id="sensor",
                pose=Pose(
                    position=[0.0, 0.0, 0.0],
                    orientation=[quat.x, quat.y, quat.z, quat.w],
                ),
            )
            odom_transport.publish(odom)
            time.sleep(0.05)

            # Publish a point cloud
            points = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=np.float32)
            cloud = PointCloud2.from_numpy(points, frame_id="map", timestamp=time.time())
            scan_in_transport.publish(cloud)
            time.sleep(0.2)

            assert len(scan_results) >= 1
            assert scan_results[0].frame_id == "sensor_at_scan"
        finally:
            # Stop the started module so its subscriptions do not outlive this
            # test, whether or not the assertions passed.
            m.stop()

    def test_tui_publishes_twist_via_transport(self):
        """TUI module should publish cmd_vel through its transport."""
        m = TUIControlModule(max_speed=2.0, publish_rate=50.0)

        transport = LCMTransport("/_test/smartnav/tui/cmd_vel", Twist)
        m.cmd_vel._transport = transport

        # Also wire way_point so it doesn't error
        from dimos.msgs.geometry_msgs.PointStamped import PointStamped

        wp_transport = LCMTransport("/_test/smartnav/tui/way_point", PointStamped)
        m.way_point._transport = wp_transport

        received = []
        transport.subscribe(received.append)

        m._handle_key("w")  # forward
        m.start()
        try:
            time.sleep(0.15)  # let publish loop run a few times
        finally:
            m.stop()

        assert len(received) >= 1
        assert received[-1].linear.x > 0  # forward velocity


class TestPortTypeCompatibility:
    """Verify that module port types are compatible for autoconnect."""

    def test_all_stream_types_match(self):
        """Ports that share a name across modules must carry the same message type."""
        from typing import get_args, get_origin, get_type_hints

        from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner
        from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower
        from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import (
            SensorScanGeneration,
        )
        from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import (
            TerrainAnalysis,
        )
        from dimos.simulation.unity.module import UnityBridgeModule

        def get_streams(cls):
            """Map each In[...]/Out[...] annotated attribute to (direction, msg_type)."""
            hints = get_type_hints(cls)
            streams = {}
            for name, hint in hints.items():
                origin = get_origin(hint)
                if origin in (In, Out):
                    direction = "in" if origin is In else "out"
                    msg_type = get_args(hint)[0]
                    streams[name] = (direction, msg_type)
            return streams

        sim = get_streams(UnityBridgeModule)
        scan = get_streams(SensorScanGeneration)
        terrain = get_streams(TerrainAnalysis)
        planner = get_streams(LocalPlanner)
        follower = get_streams(PathFollower)

        # Odometry types must match across all consumers
        odom_type = sim["odometry"][1]
        assert scan["odometry"][1] == odom_type
        assert terrain["odometry"][1] == odom_type
        assert planner["odometry"][1] == odom_type
        assert follower["odometry"][1] == odom_type

        # Path: planner out == follower in
        assert planner["path"][1] == follower["path"][1]

        # cmd_vel: follower out == sim in
        assert follower["cmd_vel"][1] == sim["cmd_vel"][1]

        # registered_scan: all consumers match
        pc_type = scan["registered_scan"][1]
        assert terrain["registered_scan"][1] == pc_type
        assert planner["registered_scan"][1] == pc_type
--git a/dimos/navigation/smartnav/tests/test_waypoint_nav.py b/dimos/navigation/smartnav/tests/test_waypoint_nav.py new file mode 100644 index 0000000000..675f74d58c --- /dev/null +++ b/dimos/navigation/smartnav/tests/test_waypoint_nav.py @@ -0,0 +1,271 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration test: waypoint navigation produces path + movement. + +Sets a waypoint at (10, 0) and verifies: +1. TerrainAnalysis produces terrain_map +2. LocalPlanner produces a path toward the goal +3. PathFollower produces non-zero cmd_vel +4. Robot position moves toward the waypoint + +This is the core nav stack test without any exploration planner. 
def _make_ground_cloud(rx: float, ry: float) -> np.ndarray:
    """Flat ground + obstacle wall at x=5 to test path planning around it.

    The ground is a 1.0-spaced grid at z=0 covering 12 units either side of
    ``(rx, ry)``.  The wall is fixed in the map (it does not follow the
    robot): x=5, y in [-2, 2), z in [0, 1), sampled every 0.3.  Returns an
    ``(N, 3)`` float32 array, ground points first.
    """
    step = 1.0
    # Ground plane
    pts = [
        [x, y, 0.0]
        for x in np.arange(rx - 12, rx + 12, step)
        for y in np.arange(ry - 12, ry + 12, step)
    ]
    # Wall obstacle at x=5, y=-2..2, z=0..1 (partial blockage)
    pts.extend(
        [5.0, y, z] for y in np.arange(-2, 2, 0.3) for z in np.arange(0, 1.0, 0.3)
    )
    return np.array(pts, dtype=np.float32)


class SimVehicleConfig(ModuleConfig):
    # Publish rates used as 1/rate loop periods by SimVehicle.
    sensor_rate: float = 5.0
    sim_rate: float = 50.0


class SimVehicle(Module[SimVehicleConfig]):
    """Kinematic vehicle sim: publishes lidar + odom, integrates cmd_vel."""

    default_config = SimVehicleConfig
    cmd_vel: In[Twist]
    registered_scan: Out[PointCloud2]
    odometry: Out[Odometry]

    def __init__(self, **kwargs):  # type: ignore[no-untyped-def]
        super().__init__(**kwargs)
        # Pose in the "map" frame; z is held constant.
        self.x = 0.0
        self.y = 0.0
        self.z = 0.75
        self.yaw = 0.0
        # Latest commanded body velocities, guarded by _lock.
        self._fwd = 0.0
        self._left = 0.0
        self._yr = 0.0
        self._lock = threading.Lock()
        self._running = False
        self._threads: list[threading.Thread] = []

    def __getstate__(self) -> dict[str, Any]:
        """Drop the lock and thread list from the pickled state."""
        s = super().__getstate__()
        for k in ("_lock", "_threads"):
            s.pop(k, None)
        return s

    def __setstate__(self, s: dict) -> None:
        """Restore state and recreate the members dropped by __getstate__."""
        super().__setstate__(s)
        self._lock = threading.Lock()
        self._threads = []

    def start(self) -> None:
        """Subscribe to cmd_vel and launch the sim and sensor daemon threads."""
        self.cmd_vel._transport.subscribe(self._on_cmd)
        self._running = True
        for fn in (self._sim_loop, self._sensor_loop):
            t = threading.Thread(target=fn, daemon=True)
            t.start()
            self._threads.append(t)

    def stop(self) -> None:
        """Signal both loops to exit, wait up to 3 s for each, then stop the module."""
        self._running = False
        for t in self._threads:
            t.join(timeout=3)
        super().stop()

    def _on_cmd(self, tw: Twist) -> None:
        """Store the latest commanded velocities."""
        with self._lock:
            self._fwd = tw.linear.x
            self._left = tw.linear.y
            self._yr = tw.angular.z

    def _sim_loop(self) -> None:
        """Integrate the commanded velocities and publish odometry + TF."""
        dt = 1.0 / self.config.sim_rate
        while self._running:
            t0 = time.monotonic()
            with self._lock:
                fwd, left, yr = self._fwd, self._left, self._yr
            self.yaw += dt * yr
            cy, sy = math.cos(self.yaw), math.sin(self.yaw)
            # Rotate body-frame velocities into the map frame.
            self.x += dt * (cy * fwd - sy * left)
            self.y += dt * (sy * fwd + cy * left)
            now = time.time()
            q = Quaternion.from_euler(Vector3(0.0, 0.0, self.yaw))
            self.odometry._transport.publish(
                Odometry(
                    ts=now,
                    frame_id="map",
                    child_frame_id="sensor",
                    pose=Pose(position=[self.x, self.y, self.z], orientation=[q.x, q.y, q.z, q.w]),
                    twist=Twist(linear=[fwd, left, 0], angular=[0, 0, yr]),
                )
            )
            self.tf.publish(
                Transform(
                    translation=Vector3(self.x, self.y, self.z),
                    rotation=q,
                    frame_id="map",
                    child_frame_id="sensor",
                    ts=now,
                )
            )
            # Sleep only for what is left of this period.
            sl = dt - (time.monotonic() - t0)
            if sl > 0:
                time.sleep(sl)

    def _sensor_loop(self) -> None:
        """Publish a synthetic ground + wall cloud centred on the current pose."""
        dt = 1.0 / self.config.sensor_rate
        while self._running:
            now = time.time()
            cloud = _make_ground_cloud(self.x, self.y)
            self.registered_scan._transport.publish(
                PointCloud2.from_numpy(cloud, frame_id="map", timestamp=now)
            )
            time.sleep(dt)


def test_waypoint_nav_produces_path_and_movement():
    """Send waypoint at (10,0), verify terrain_map + path + non-zero cmd_vel.

    Waits up to 20 s for every pipeline stage to publish at least once, lets
    the stack run 5 s longer, then asserts on the message counts.
    """
    from dimos.core.blueprints import autoconnect
    from dimos.msgs.geometry_msgs.PointStamped import PointStamped
    from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner
    from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower
    from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import (
        SensorScanGeneration,
    )
    from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis

    terrain_msgs: list = []
    path_msgs: list = []
    cmd_msgs: list[tuple] = []
    lock = threading.Lock()

    def _on_terrain(_msg) -> None:
        # `with` releases the lock even if the body raises.
        with lock:
            terrain_msgs.append(1)

    def _on_path(_msg) -> None:
        with lock:
            path_msgs.append(1)

    def _on_cmd(msg) -> None:
        with lock:
            cmd_msgs.append((msg.linear.x, msg.linear.y, msg.angular.z))

    blueprint = autoconnect(
        SimVehicle.blueprint(),
        SensorScanGeneration.blueprint(),
        TerrainAnalysis.blueprint(),
        LocalPlanner.blueprint(extra_args=["--autonomyMode", "true"]),
        PathFollower.blueprint(extra_args=["--autonomyMode", "true"]),
    )
    coordinator = blueprint.build()

    terrain = coordinator.get_instance(TerrainAnalysis)
    planner = coordinator.get_instance(LocalPlanner)
    follower = coordinator.get_instance(PathFollower)

    terrain.terrain_map._transport.subscribe(_on_terrain)
    planner.path._transport.subscribe(_on_path)
    follower.cmd_vel._transport.subscribe(_on_cmd)

    # Send waypoint after modules warm up
    def _send_wp():
        time.sleep(2.0)
        wp = PointStamped(x=10.0, y=0.0, z=0.0, frame_id="map")
        planner.way_point._transport.publish(wp)
        print("[test] Sent waypoint (10, 0)")

    try:
        coordinator.start()
        # Launched after start() returns, so the 2 s warm-up is measured from
        # when the modules are running and a slow start cannot eat into it.
        threading.Thread(target=_send_wp, daemon=True).start()

        # Wait up to 20s for all pipeline stages
        deadline = time.monotonic() + 20.0
        while time.monotonic() < deadline:
            with lock:
                ok = len(terrain_msgs) > 0 and len(path_msgs) > 0 and len(cmd_msgs) > 0
            if ok:
                break
            time.sleep(0.5)

        # Let movement accumulate
        time.sleep(5.0)

        with lock:
            n_terrain = len(terrain_msgs)
            n_path = len(path_msgs)
            n_cmd = len(cmd_msgs)
            nonzero = [
                (vx, vy, wz)
                for vx, vy, wz in cmd_msgs
                if abs(vx) > 0.01 or abs(vy) > 0.01 or abs(wz) > 0.01
            ]

        print(
            f"[test] terrain_map: {n_terrain}, path: {n_path}, "
            f"cmd_vel: {n_cmd} (nonzero: {len(nonzero)})"
        )

        assert n_terrain > 0, "TerrainAnalysis produced no terrain_map"
        assert n_path > 0, "LocalPlanner produced no path"
        assert n_cmd > 0, "PathFollower produced no cmd_vel"
        assert len(nonzero) > 0, f"All {n_cmd} cmd_vel messages were zero — robot not moving"

    finally:
        coordinator.stop()
-from dimos.agents.mcp.mcp_client import McpClient -from dimos.agents.mcp.mcp_server import McpServer +from dimos.agents.agent import Agent from dimos.core.blueprints import autoconnect from dimos.hardware.sensors.camera.realsense.camera import RealSenseCamera from dimos.hardware.sensors.camera.zed.compat import ZEDCamera from dimos.perception.detection.detectors.yoloe import YoloePromptMode from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule -from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.visualization.vis_module import vis_module camera_choice = "zed" @@ -34,7 +33,6 @@ demo_object_scene_registration = autoconnect( camera_module, ObjectSceneRegistrationModule.blueprint(target_frame="world", prompt_mode=YoloePromptMode.LRPC), - FoxgloveBridge.blueprint(), - McpServer.blueprint(), - McpClient.blueprint(), + vis_module("foxglove"), + Agent.blueprint(), ).global_config(viewer="foxglove") diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 00690d514f..970f49bc4a 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -56,6 +56,10 @@ "mid360-fastlio": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio", "mid360-fastlio-voxels": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio_voxels", "mid360-fastlio-voxels-native": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio_voxels_native", + "simulation-blueprint": "dimos.navigation.smartnav.blueprints.simulation:simulation_blueprint", + "simulation-pgo-blueprint": "dimos.navigation.smartnav.blueprints.simulation_pgo:simulation_pgo_blueprint", + "simulation-route-blueprint": "dimos.navigation.smartnav.blueprints.simulation_route:simulation_route_blueprint", + "simulation-slam-blueprint": "dimos.navigation.smartnav.blueprints.simulation_slam:simulation_slam_blueprint", "teleop-phone": "dimos.teleop.phone.blueprints:teleop_phone", "teleop-phone-go2": 
"dimos.teleop.phone.blueprints:teleop_phone_go2", "teleop-phone-go2-fleet": "dimos.teleop.phone.blueprints:teleop_phone_go2_fleet", @@ -73,6 +77,18 @@ "unitree-g1-detection": "dimos.robot.unitree.g1.blueprints.perceptive.unitree_g1_detection:unitree_g1_detection", "unitree-g1-full": "dimos.robot.unitree.g1.blueprints.agentic.unitree_g1_full:unitree_g1_full", "unitree-g1-joystick": "dimos.robot.unitree.g1.blueprints.basic.unitree_g1_joystick:unitree_g1_joystick", + "unitree-g1-nav-arise-onboard": "dimos.robot.unitree.g1.blueprints.navigation.unitree_g1_nav_arise_onboard:unitree_g1_nav_arise_onboard", + "unitree-g1-nav-arise-sim": "dimos.robot.unitree.g1.blueprints.navigation.unitree_g1_nav_arise_sim:unitree_g1_nav_arise_sim", + "unitree-g1-nav-basic-onboard": "dimos.robot.unitree.g1.blueprints.navigation.unitree_g1_nav_basic_onboard:unitree_g1_nav_basic_onboard", + "unitree-g1-nav-basic-sim": "dimos.robot.unitree.g1.blueprints.navigation.unitree_g1_nav_basic_sim:unitree_g1_nav_basic_sim", + "unitree-g1-nav-explore-onboard": "dimos.robot.unitree.g1.blueprints.navigation.unitree_g1_nav_explore_onboard:unitree_g1_nav_explore_onboard", + "unitree-g1-nav-explore-sim": "dimos.robot.unitree.g1.blueprints.navigation.unitree_g1_nav_explore_sim:unitree_g1_nav_explore_sim", + "unitree-g1-nav-far-onboard": "dimos.robot.unitree.g1.blueprints.navigation.unitree_g1_nav_far_onboard:unitree_g1_nav_far_onboard", + "unitree-g1-nav-onboard": "dimos.robot.unitree.g1.blueprints.navigation.unitree_g1_nav_onboard:unitree_g1_nav_onboard", + "unitree-g1-nav-pgo-onboard": "dimos.robot.unitree.g1.blueprints.navigation.unitree_g1_nav_pgo_onboard:unitree_g1_nav_pgo_onboard", + "unitree-g1-nav-sim": "dimos.robot.unitree.g1.blueprints.navigation.unitree_g1_nav_sim:unitree_g1_nav_sim", + "unitree-g1-rosnav-onboard": "dimos.robot.unitree.g1.blueprints.perceptive.unitree_g1_rosnav_onboard:unitree_g1_rosnav_onboard", + "unitree-g1-rosnav-sim": 
"dimos.robot.unitree.g1.blueprints.perceptive.unitree_g1_rosnav_sim:unitree_g1_rosnav_sim", "unitree-g1-shm": "dimos.robot.unitree.g1.blueprints.perceptive.unitree_g1_shm:unitree_g1_shm", "unitree-g1-sim": "dimos.robot.unitree.g1.blueprints.perceptive.unitree_g1_sim:unitree_g1_sim", "unitree-go2": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2:unitree_go2", @@ -84,6 +100,7 @@ "unitree-go2-detection": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_detection:unitree_go2_detection", "unitree-go2-fleet": "dimos.robot.unitree.go2.blueprints.basic.unitree_go2_fleet:unitree_go2_fleet", "unitree-go2-ros": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_ros:unitree_go2_ros", + "unitree-go2-smartnav": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_smartnav:unitree_go2_smartnav", "unitree-go2-spatial": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_spatial:unitree_go2_spatial", "unitree-go2-temporal-memory": "dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_temporal_memory:unitree_go2_temporal_memory", "unitree-go2-vlm-stream-test": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_vlm_stream_test:unitree_go2_vlm_stream_test", @@ -98,11 +115,15 @@ all_modules = { + "arise-sim-adapter": "dimos.navigation.smartnav.modules.arise_sim_adapter.AriseSimAdapter", + "arise-slam": "dimos.navigation.smartnav.modules.arise_slam.arise_slam.AriseSLAM", "arm-teleop-module": "dimos.teleop.quest.quest_extensions.ArmTeleopModule", "b-box-navigation-module": "dimos.navigation.bbox_navigation.BBoxNavigationModule", "b1-connection-module": "dimos.robot.unitree.b1.connection.B1ConnectionModule", "camera-module": "dimos.hardware.sensors.camera.module.CameraModule", "cartesian-motion-controller": "dimos.manipulation.control.servo_control.cartesian_motion_controller.CartesianMotionController", + "click-to-goal": "dimos.navigation.smartnav.modules.click_to_goal.click_to_goal.ClickToGoal", + "cmd-vel-mux": 
"dimos.navigation.smartnav.modules.cmd_vel_mux.CmdVelMux", "control-coordinator": "dimos.control.coordinator.ControlCoordinator", "cost-mapper": "dimos.mapping.costmapper.CostMapper", "demo-calculator-skill": "dimos.agents.skills.demo_calculator_skill.DemoCalculatorSkill", @@ -114,11 +135,15 @@ "drone-tracking-module": "dimos.robot.drone.drone_tracking_module.DroneTrackingModule", "embedding-memory": "dimos.memory.embedding.EmbeddingMemory", "emitter-module": "dimos.utils.demo_image_encoding.EmitterModule", + "far-planner": "dimos.navigation.smartnav.modules.far_planner.far_planner.FarPlanner", "fast-lio2": "dimos.hardware.sensors.lidar.fastlio2.module.FastLio2", "foxglove-bridge": "dimos.robot.foxglove_bridge.FoxgloveBridge", "g1-connection": "dimos.robot.unitree.g1.connection.G1Connection", "g1-connection-base": "dimos.robot.unitree.g1.connection.G1ConnectionBase", - "g1-sim-connection": "dimos.robot.unitree.g1.sim.G1SimConnection", + "g1-high-level-dds-sdk": "dimos.robot.unitree.g1.effectors.high_level.dds_sdk.G1HighLevelDdsSdk", + "g1-high-level-web-rtc": "dimos.robot.unitree.g1.effectors.high_level.webrtc.G1HighLevelWebRtc", + "g1-sim-connection": "dimos.robot.unitree.g1.mujoco_sim.G1SimConnection", + "global-map": "dimos.navigation.smartnav.modules.global_map.global_map.GlobalMap", "go2-connection": "dimos.robot.unitree.go2.connection.GO2Connection", "go2-fleet-connection": "dimos.robot.unitree.go2.fleet_connection.Go2FleetConnection", "google-maps-skill-container": "dimos.agents.skills.google_maps_skill_container.GoogleMapsSkillContainer", @@ -130,6 +155,7 @@ "joystick-module": "dimos.robot.unitree.b1.joystick_module.JoystickModule", "keyboard-teleop": "dimos.robot.unitree.keyboard_teleop.KeyboardTeleop", "keyboard-teleop-module": "dimos.teleop.keyboard.keyboard_teleop_module.KeyboardTeleopModule", + "local-planner": "dimos.navigation.smartnav.modules.local_planner.local_planner.LocalPlanner", "manipulation-module": 
"dimos.manipulation.manipulation_module.ManipulationModule", "map": "dimos.robot.unitree.type.map.Map", "mcp-client": "dimos.agents.mcp.mcp_client.McpClient", @@ -144,11 +170,14 @@ "object-tracker2-d": "dimos.perception.object_tracker_2d.ObjectTracker2D", "object-tracker3-d": "dimos.perception.object_tracker_3d.ObjectTracker3D", "object-tracking": "dimos.perception.object_tracker.ObjectTracking", + "odom-adapter": "dimos.navigation.smartnav.modules.odom_adapter.odom_adapter.OdomAdapter", "osm-skill": "dimos.agents.skills.osm.OsmSkill", + "path-follower": "dimos.navigation.smartnav.modules.path_follower.path_follower.PathFollower", "patrolling-module": "dimos.navigation.patrolling.module.PatrollingModule", "perceive-loop-skill": "dimos.perception.perceive_loop_skill.PerceiveLoopSkill", "person-follow-skill-container": "dimos.agents.skills.person_follow.PersonFollowSkillContainer", "person-tracker": "dimos.perception.detection.person_tracker.PersonTracker", + "pgo": "dimos.navigation.smartnav.modules.pgo.pgo.PGO", "phone-teleop-module": "dimos.teleop.phone.phone_teleop_module.PhoneTeleopModule", "pick-and-place-module": "dimos.manipulation.pick_and_place_module.PickAndPlaceModule", "quest-teleop-module": "dimos.teleop.quest.quest_teleop_module.QuestTeleopModule", @@ -157,11 +186,17 @@ "reid-module": "dimos.perception.detection.reid.module.ReidModule", "replanning-a-star-planner": "dimos.navigation.replanning_a_star.module.ReplanningAStarPlanner", "rerun-bridge-module": "dimos.visualization.rerun.bridge.RerunBridgeModule", - "ros-nav": "dimos.navigation.rosnav.ROSNav", + "rerun-web-socket-server": "dimos.visualization.rerun.websocket_server.RerunWebSocketServer", + "ros-nav": "dimos.navigation.rosnav_legacy.ROSNav", + "sensor-scan-generation": "dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation.SensorScanGeneration", "simple-phone-teleop": "dimos.teleop.phone.phone_extensions.SimplePhoneTeleop", "spatial-memory": 
"dimos.perception.spatial_perception.SpatialMemory", "speak-skill": "dimos.agents.skills.speak_skill.SpeakSkill", + "tare-planner": "dimos.navigation.smartnav.modules.tare_planner.tare_planner.TarePlanner", "temporal-memory": "dimos.perception.experimental.temporal_memory.temporal_memory.TemporalMemory", + "terrain-analysis": "dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis.TerrainAnalysis", + "terrain-map-ext": "dimos.navigation.smartnav.modules.terrain_map_ext.terrain_map_ext.TerrainMapExt", + "tui-control-module": "dimos.navigation.smartnav.modules.tui_control.tui_control.TUIControlModule", "twist-teleop-module": "dimos.teleop.quest.quest_extensions.TwistTeleopModule", "unitree-g1-skill-container": "dimos.robot.unitree.g1.skill_container.UnitreeG1SkillContainer", "unitree-skill-container": "dimos.robot.unitree.unitree_skill_container.UnitreeSkillContainer", diff --git a/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_basic_sim.py b/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_basic_sim.py index 3294da1772..9166a4de6e 100644 --- a/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_basic_sim.py +++ b/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_basic_sim.py @@ -20,7 +20,7 @@ from dimos.robot.unitree.g1.blueprints.primitive.uintree_g1_primitive_no_nav import ( uintree_g1_primitive_no_nav, ) -from dimos.robot.unitree.g1.sim import G1SimConnection +from dimos.robot.unitree.g1.mujoco_sim import G1SimConnection unitree_g1_basic_sim = autoconnect( uintree_g1_primitive_no_nav, diff --git a/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_arise_onboard.py b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_arise_onboard.py new file mode 100644 index 0000000000..75b21ef722 --- /dev/null +++ b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_arise_onboard.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""G1 with AriseSLAM on real hardware. + +Uses the C++ AriseSLAM module (feature-based LiDAR-IMU SLAM) instead of +FastLio2. The raw Mid-360 driver provides body-frame point clouds and IMU +data; AriseSLAM produces world-frame registered scans and odometry that feed +the rest of the SmartNav stack. + +Data flow: + Mid360 → raw lidar (body frame) + imu + → AriseSLAM → registered_scan (world frame) + odometry + → SensorScanGeneration → TerrainAnalysis → LocalPlanner → PathFollower + → G1HighLevelDdsSdk +""" + +from __future__ import annotations + +import os +from typing import Any + +from dimos.core.blueprints import autoconnect +from dimos.hardware.sensors.lidar.livox.module import Mid360 +from dimos.navigation.smartnav.blueprints._rerun_helpers import ( + global_map_override, + goal_path_override, + path_override, + sensor_scan_override, + static_floor, + static_robot, + terrain_map_ext_override, + terrain_map_override, + waypoint_override, +) +from dimos.navigation.smartnav.modules.arise_slam.arise_slam import AriseSLAM +from dimos.navigation.smartnav.modules.click_to_goal.click_to_goal import ClickToGoal +from dimos.navigation.smartnav.modules.global_map.global_map import GlobalMap +from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner +from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower +from 
dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import ( + SensorScanGeneration, +) +from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis +from dimos.navigation.smartnav.modules.terrain_map_ext.terrain_map_ext import TerrainMapExt +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.robot.unitree.g1.effectors.high_level.dds_sdk import G1HighLevelDdsSdk +from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode + + +def _rerun_blueprint() -> Any: + import rerun.blueprint as rrb + + return rrb.Blueprint( + rrb.Spatial3DView(origin="world", name="3D"), + ) + + +_rerun_config = { + "blueprint": _rerun_blueprint, + "pubsubs": [LCM()], + "min_interval_sec": 0.25, + "visual_override": { + "world/sensor_scan": sensor_scan_override, + "world/terrain_map": terrain_map_override, + "world/terrain_map_ext": terrain_map_ext_override, + "world/global_map": global_map_override, + "world/path": path_override, + "world/way_point": waypoint_override, + "world/goal_path": goal_path_override, + }, + "static": { + "world/floor": static_floor, + "world/tf/robot": static_robot, + }, +} + +unitree_g1_nav_arise_onboard = ( + autoconnect( + Mid360.blueprint( + host_ip=os.getenv("LIDAR_HOST_IP", "192.168.123.164"), + lidar_ip=os.getenv("LIDAR_IP", "192.168.123.120"), + enable_imu=True, + ), + AriseSLAM.blueprint( + extra_args=[ + "--scanVoxelSize", + "0.1", + "--maxRange", + "50.0", + ] + ), + SensorScanGeneration.blueprint(), + TerrainAnalysis.blueprint( + extra_args=[ + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--vehicleHeight", + "1.2", + ] + ), + TerrainMapExt.blueprint(), + LocalPlanner.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "1.0", + "--autonomySpeed", + "1.0", + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--minRelZ", + "-1.5", + ] + ), + PathFollower.blueprint( + extra_args=[ + "--autonomyMode", + "true", 
+ "--maxSpeed", + "1.0", + "--autonomySpeed", + "1.0", + "--maxAccel", + "2.0", + "--slowDwnDisThre", + "0.2", + ] + ), + ClickToGoal.blueprint(), + GlobalMap.blueprint(), + G1HighLevelDdsSdk.blueprint(), + RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **_rerun_config), + ) + .remappings( + [ + # Mid360 outputs "lidar" (body frame); AriseSLAM expects "raw_points" + (Mid360, "lidar", "raw_points"), + ] + ) + .global_config(n_workers=8, robot_model="unitree_g1") +) + + +def main() -> None: + unitree_g1_nav_arise_onboard.build().loop() + + +__all__ = ["unitree_g1_nav_arise_onboard"] + +if __name__ == "__main__": + main() diff --git a/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_arise_sim.py b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_arise_sim.py new file mode 100644 index 0000000000..92c4134439 --- /dev/null +++ b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_arise_sim.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""G1 nav sim with AriseSLAM — tests SLAM in simulation. + +Instead of using Unity's ground-truth odometry, this blueprint feeds +the sim's lidar + synthetic IMU into AriseSLAM, which estimates the +pose via scan-to-map matching. This lets you test and tune SLAM +without real hardware. + +AriseSimAdapter handles both: + 1. Transforming world-frame scans → body-frame using Unity's odom + 2. 
Synthesizing IMU from Unity's odom (orientation + angular vel + gravity) + +Data flow: + Unity → registered_scan + odometry → AriseSimAdapter → raw_points + imu + → AriseSLAM → registered_scan + odometry → nav stack + +Note: AriseSLAM's odometry replaces Unity's ground-truth, so navigation +accuracy depends on how well SLAM tracks. Any drift is real SLAM drift. +""" + +from __future__ import annotations + +from typing import Any + +from dimos.core.blueprints import autoconnect +from dimos.core.global_config import global_config +from dimos.navigation.smartnav.blueprints._rerun_helpers import ( + goal_path_override, + path_override, + static_floor, + static_robot, + terrain_map_ext_override, + terrain_map_override, + waypoint_override, +) +from dimos.navigation.smartnav.modules.arise_sim_adapter import AriseSimAdapter +from dimos.navigation.smartnav.modules.arise_slam.arise_slam import AriseSLAM +from dimos.navigation.smartnav.modules.click_to_goal.click_to_goal import ClickToGoal +from dimos.navigation.smartnav.modules.cmd_vel_mux import CmdVelMux +from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner +from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower +from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis +from dimos.navigation.smartnav.modules.terrain_map_ext.terrain_map_ext import TerrainMapExt +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.simulation.unity.module import UnityBridgeModule +from dimos.visualization.vis_module import vis_module + + +def _rerun_blueprint() -> Any: + import rerun.blueprint as rrb + + return rrb.Blueprint( + rrb.Vertical( + rrb.Spatial3DView(origin="world", name="3D"), + rrb.Spatial2DView(origin="world/color_image", name="Camera"), + row_shares=[2, 1], + ), + ) + + +_vis = vis_module( + viewer_backend=global_config.viewer, + rerun_config={ + "blueprint": _rerun_blueprint, + "pubsubs": [LCM()], + 
"min_interval_sec": 0.25, + "visual_override": { + "world/camera_info": UnityBridgeModule.rerun_suppress_camera_info, + "world/terrain_map": terrain_map_override, + "world/terrain_map_ext": terrain_map_ext_override, + "world/path": path_override, + "world/way_point": waypoint_override, + "world/goal_path": goal_path_override, + }, + "static": { + "world/color_image": UnityBridgeModule.rerun_static_pinhole, + "world/floor": static_floor, + "world/tf/robot": static_robot, + }, + }, +) + +unitree_g1_nav_arise_sim = ( + autoconnect( + # Simulator — provides ground-truth registered_scan and odometry + UnityBridgeModule.blueprint( + unity_binary="", + unity_scene="home_building_1", + vehicle_height=1.24, + ), + # Adapter: transforms scan to body-frame + synthesizes IMU from odom + AriseSimAdapter.blueprint(), + # SLAM — estimates pose from body-frame lidar + synthetic IMU + AriseSLAM.blueprint(use_imu=True), + # Nav stack — uses SLAM's odometry + registered_scan (NOT Unity's) + TerrainAnalysis.blueprint(extra_args=["--obstacleHeightThre", "0.2", "--maxRelZ", "1.5"]), + TerrainMapExt.blueprint(), + LocalPlanner.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "2.0", + "--autonomySpeed", + "2.0", + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--minRelZ", + "-1.0", + ] + ), + PathFollower.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "2.0", + "--autonomySpeed", + "2.0", + "--maxAccel", + "4.0", + "--slowDwnDisThre", + "0.2", + ] + ), + ClickToGoal.blueprint(), + CmdVelMux.blueprint(), + _vis, + ) + .remappings( + [ + (PathFollower, "cmd_vel", "nav_cmd_vel"), + (UnityBridgeModule, "terrain_map", "terrain_map_ext"), + # Rename Unity's outputs so they don't collide with AriseSLAM's. + # The adapter reads sim_* and AriseSLAM outputs the canonical names. 
+ (UnityBridgeModule, "registered_scan", "sim_registered_scan"), + (UnityBridgeModule, "odometry", "sim_odometry"), + (AriseSimAdapter, "registered_scan", "sim_registered_scan"), + (AriseSimAdapter, "odometry", "sim_odometry"), + ] + ) + .global_config(n_workers=8, robot_model="unitree_g1", simulation=True) +) + + +def main() -> None: + unitree_g1_nav_arise_sim.build().loop() + + +__all__ = ["unitree_g1_nav_arise_sim"] + +if __name__ == "__main__": + main() diff --git a/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_basic_onboard.py b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_basic_onboard.py new file mode 100644 index 0000000000..19af404d46 --- /dev/null +++ b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_basic_onboard.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""G1 basic nav onboard — local planner + path follower only (no FAR/PGO). + +Lightweight navigation stack for real hardware: uses SmartNav C++ native +modules for terrain analysis, local planning, and path following. +FastLio2 provides SLAM from a Livox Mid-360 lidar. No global route +planner (FAR) or loop closure (PGO). For the full stack, use +unitree_g1_nav_onboard. 
+""" + +from __future__ import annotations + +import os +from typing import Any + +from dimos.core.blueprints import autoconnect +from dimos.hardware.sensors.lidar.fastlio2.module import FastLio2 +from dimos.navigation.smartnav.blueprints._rerun_helpers import ( + goal_path_override, + path_override, + sensor_scan_override, + static_floor, + static_robot, + terrain_map_ext_override, + terrain_map_override, + waypoint_override, +) +from dimos.navigation.smartnav.modules.click_to_goal.click_to_goal import ClickToGoal +from dimos.navigation.smartnav.modules.cmd_vel_mux import CmdVelMux +from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner +from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower +from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import ( + SensorScanGeneration, +) +from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis +from dimos.navigation.smartnav.modules.terrain_map_ext.terrain_map_ext import TerrainMapExt +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.robot.unitree.g1.effectors.high_level.dds_sdk import G1HighLevelDdsSdk +from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode +from dimos.visualization.rerun.websocket_server import RerunWebSocketServer + + +def _rerun_blueprint() -> Any: + import rerun.blueprint as rrb + + return rrb.Blueprint( + rrb.Spatial3DView(origin="world", name="3D"), + ) + + +_rerun_config = { + "blueprint": _rerun_blueprint, + "pubsubs": [LCM()], + "min_interval_sec": 0.25, + "visual_override": { + "world/sensor_scan": sensor_scan_override, + "world/terrain_map": terrain_map_override, + "world/terrain_map_ext": terrain_map_ext_override, + "world/path": path_override, + "world/way_point": waypoint_override, + "world/goal_path": goal_path_override, + }, + "static": { + "world/floor": static_floor, + "world/tf/robot": static_robot, + }, +} 
+ +unitree_g1_nav_basic_onboard = ( + autoconnect( + FastLio2.blueprint( + host_ip=os.getenv("LIDAR_HOST_IP", "192.168.123.164"), + lidar_ip=os.getenv("LIDAR_IP", "192.168.123.120"), + # G1 lidar mount: 1.2m height, 180° around X (upside-down mount) + # [x, y, z, qx, qy, qz, qw] — quaternion (1,0,0,0) = 180° X rotation + init_pose=[0.0, 0.0, 1.2, 1.0, 0.0, 0.0, 0.0], + map_freq=1.0, # Publish global map at 1 Hz + ), + SensorScanGeneration.blueprint(), + TerrainAnalysis.blueprint( + extra_args=[ + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--vehicleHeight", + "1.2", + ] + ), + TerrainMapExt.blueprint(), + LocalPlanner.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "1.0", + "--autonomySpeed", + "1.0", + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--minRelZ", + "-1.5", + ] + ), + PathFollower.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "1.0", + "--autonomySpeed", + "1.0", + "--maxAccel", + "2.0", + "--slowDwnDisThre", + "0.2", + ] + ), + ClickToGoal.blueprint(), + CmdVelMux.blueprint(), + G1HighLevelDdsSdk.blueprint(), + RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **_rerun_config), + RerunWebSocketServer.blueprint(), + ) + .remappings( + [ + # FastLio2 outputs "lidar"; SmartNav modules expect "registered_scan" + (FastLio2, "lidar", "registered_scan"), + # PathFollower cmd_vel → CmdVelMux nav input (avoid name collision with mux output) + (PathFollower, "cmd_vel", "nav_cmd_vel"), + ] + ) + .global_config(n_workers=8, robot_model="unitree_g1") +) + + +def main() -> None: + unitree_g1_nav_basic_onboard.build().loop() + + +__all__ = ["unitree_g1_nav_basic_onboard"] + +if __name__ == "__main__": + main() diff --git a/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_basic_sim.py b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_basic_sim.py new file mode 100644 index 0000000000..cd9a5a9cf0 --- /dev/null +++ 
b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_basic_sim.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""G1 basic nav sim — reactive local planner only (no global route planning). + +Click-to-navigate sends waypoints directly to the local planner, which +reactively avoids obstacles. Good for short-range navigation but can get +stuck in dead ends. For global route planning, use unitree-g1-nav-sim. 
+""" + +from __future__ import annotations + +from typing import Any + +from dimos.core.blueprints import autoconnect +from dimos.core.global_config import global_config +from dimos.navigation.smartnav.blueprints._rerun_helpers import ( + goal_path_override, + path_override, + sensor_scan_override, + static_floor, + static_robot, + terrain_map_ext_override, + terrain_map_override, + waypoint_override, +) +from dimos.navigation.smartnav.modules.click_to_goal.click_to_goal import ClickToGoal +from dimos.navigation.smartnav.modules.cmd_vel_mux import CmdVelMux +from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner +from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower +from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import ( + SensorScanGeneration, +) +from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis +from dimos.navigation.smartnav.modules.terrain_map_ext.terrain_map_ext import TerrainMapExt +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.simulation.unity.module import UnityBridgeModule +from dimos.visualization.vis_module import vis_module + + +def _rerun_blueprint() -> Any: + import rerun.blueprint as rrb + + return rrb.Blueprint( + rrb.Vertical( + rrb.Spatial3DView(origin="world", name="3D"), + rrb.Spatial2DView(origin="world/color_image", name="Camera"), + row_shares=[2, 1], + ), + ) + + +_vis = vis_module( + viewer_backend=global_config.viewer, + rerun_config={ + "blueprint": _rerun_blueprint, + "pubsubs": [LCM()], + "min_interval_sec": 0.25, + "visual_override": { + "world/camera_info": UnityBridgeModule.rerun_suppress_camera_info, + "world/sensor_scan": sensor_scan_override, + "world/terrain_map": terrain_map_override, + "world/terrain_map_ext": terrain_map_ext_override, + "world/path": path_override, + "world/way_point": waypoint_override, + "world/goal_path": goal_path_override, + }, + 
"static": { + "world/color_image": UnityBridgeModule.rerun_static_pinhole, + "world/floor": static_floor, + "world/tf/robot": static_robot, + }, + }, +) + +unitree_g1_nav_basic_sim = ( + autoconnect( + UnityBridgeModule.blueprint( + unity_binary="", + unity_scene="home_building_1", + vehicle_height=1.24, + ), + SensorScanGeneration.blueprint(), + TerrainAnalysis.blueprint( + extra_args=[ + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + ] + ), + TerrainMapExt.blueprint(), + LocalPlanner.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "2.0", + "--autonomySpeed", + "2.0", + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--minRelZ", + "-1.0", + ] + ), + PathFollower.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "2.0", + "--autonomySpeed", + "2.0", + "--maxAccel", + "4.0", + "--slowDwnDisThre", + "0.2", + ] + ), + ClickToGoal.blueprint(), + CmdVelMux.blueprint(), + _vis, + ) + .remappings( + [ + # PathFollower cmd_vel → CmdVelMux nav input (avoid name collision with mux output) + (PathFollower, "cmd_vel", "nav_cmd_vel"), + # Unity needs the extended (persistent) terrain map for Z-height, not the local one + (UnityBridgeModule, "terrain_map", "terrain_map_ext"), + ] + ) + .global_config(n_workers=8, robot_model="unitree_g1", simulation=True) +) + + +def main() -> None: + unitree_g1_nav_basic_sim.build().loop() + + +__all__ = ["unitree_g1_nav_basic_sim"] + +if __name__ == "__main__": + main() diff --git a/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_explore_onboard.py b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_explore_onboard.py new file mode 100644 index 0000000000..d6adb988fe --- /dev/null +++ b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_explore_onboard.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""G1 with TARE autonomous exploration on real hardware. + +Zero-ROS navigation stack: TARE frontier-based exploration drives the robot +autonomously through the environment without a user-specified goal. ClickToGoal +is present for visualization but its waypoint output is disconnected so TARE +has exclusive control of LocalPlanner's waypoint input. + +Data flow: + FastLio2 → registered_scan + odometry + TarePlanner → way_point → LocalPlanner → PathFollower → G1HighLevelDdsSdk +""" + +from __future__ import annotations + +import os +from typing import Any + +from dimos.core.blueprints import autoconnect +from dimos.hardware.sensors.lidar.fastlio2.module import FastLio2 +from dimos.navigation.smartnav.blueprints._rerun_helpers import ( + global_map_override, + goal_path_override, + path_override, + sensor_scan_override, + static_floor, + static_robot, + terrain_map_ext_override, + terrain_map_override, + waypoint_override, +) +from dimos.navigation.smartnav.modules.click_to_goal.click_to_goal import ClickToGoal +from dimos.navigation.smartnav.modules.global_map.global_map import GlobalMap +from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner +from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower +from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import ( + SensorScanGeneration, +) +from 
dimos.navigation.smartnav.modules.tare_planner.tare_planner import TarePlanner +from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis +from dimos.navigation.smartnav.modules.terrain_map_ext.terrain_map_ext import TerrainMapExt +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.robot.unitree.g1.effectors.high_level.dds_sdk import G1HighLevelDdsSdk +from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode + + +def _rerun_blueprint() -> Any: + import rerun.blueprint as rrb + + return rrb.Blueprint( + rrb.Spatial3DView(origin="world", name="3D"), + ) + + +_rerun_config = { + "blueprint": _rerun_blueprint, + "pubsubs": [LCM()], + "min_interval_sec": 0.25, + "visual_override": { + "world/sensor_scan": sensor_scan_override, + "world/terrain_map": terrain_map_override, + "world/terrain_map_ext": terrain_map_ext_override, + "world/global_map": global_map_override, + "world/path": path_override, + "world/way_point": waypoint_override, + "world/goal_path": goal_path_override, + }, + "static": { + "world/floor": static_floor, + "world/tf/robot": static_robot, + }, +} + +unitree_g1_nav_explore_onboard = ( + autoconnect( + FastLio2.blueprint( + host_ip=os.getenv("LIDAR_HOST_IP", "192.168.123.164"), + lidar_ip=os.getenv("LIDAR_IP", "192.168.123.120"), + # G1 lidar mount: 1.2m height, 180° around X (upside-down mount) + init_pose=[0.0, 0.0, 1.2, 1.0, 0.0, 0.0, 0.0], + map_freq=0.0, # GlobalMap handles accumulation + ), + SensorScanGeneration.blueprint(), + TerrainAnalysis.blueprint( + extra_args=[ + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--vehicleHeight", + "1.2", + ] + ), + TerrainMapExt.blueprint(), + TarePlanner.blueprint(), + LocalPlanner.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "1.0", + "--autonomySpeed", + "1.0", + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--minRelZ", + "-1.5", + ] + ), + PathFollower.blueprint( + 
extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "1.0", + "--autonomySpeed", + "1.0", + "--maxAccel", + "2.0", + "--slowDwnDisThre", + "0.2", + ] + ), + ClickToGoal.blueprint(), + GlobalMap.blueprint(), + G1HighLevelDdsSdk.blueprint(), + RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **_rerun_config), + ) + .remappings( + [ + # FastLio2 outputs "lidar"; SmartNav modules expect "registered_scan" + (FastLio2, "lidar", "registered_scan"), + # TarePlanner drives way_point to LocalPlanner. + # Disconnect ClickToGoal's way_point so it doesn't conflict. + (ClickToGoal, "way_point", "_click_way_point_unused"), + ] + ) + .global_config(n_workers=8, robot_model="unitree_g1") +) + + +def main() -> None: + unitree_g1_nav_explore_onboard.build().loop() + + +__all__ = ["unitree_g1_nav_explore_onboard"] + +if __name__ == "__main__": + main() diff --git a/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_explore_sim.py b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_explore_sim.py new file mode 100644 index 0000000000..c6d8ec30fe --- /dev/null +++ b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_explore_sim.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""G1 autonomous exploration sim — TARE frontier-based exploration. 
+ +The robot autonomously explores the environment by detecting frontiers +(boundaries between explored and unexplored space) and planning paths +to maximize coverage. No click needed — just launch and watch it go. + +TARE emits waypoints that guide the local planner through unexplored areas. +Keyboard teleop (WASD) overrides exploration when active. + +Data flow: + TarePlanner → way_point → LocalPlanner → path → PathFollower + → nav_cmd_vel → CmdVelMux → cmd_vel → UnityBridgeModule +""" + +from __future__ import annotations + +from typing import Any + +from dimos.core.blueprints import autoconnect +from dimos.core.global_config import global_config +from dimos.navigation.smartnav.blueprints._rerun_helpers import ( + path_override, + sensor_scan_override, + static_floor, + static_robot, + terrain_map_ext_override, + terrain_map_override, + waypoint_override, +) +from dimos.navigation.smartnav.modules.cmd_vel_mux import CmdVelMux +from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner +from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower +from dimos.navigation.smartnav.modules.pgo.pgo import PGO +from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import ( + SensorScanGeneration, +) +from dimos.navigation.smartnav.modules.tare_planner.tare_planner import TarePlanner +from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis +from dimos.navigation.smartnav.modules.terrain_map_ext.terrain_map_ext import TerrainMapExt +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.simulation.unity.module import UnityBridgeModule +from dimos.visualization.vis_module import vis_module + + +def _rerun_blueprint() -> Any: + import rerun.blueprint as rrb + + return rrb.Blueprint( + rrb.Vertical( + rrb.Spatial3DView(origin="world", name="3D"), + rrb.Spatial2DView(origin="world/color_image", name="Camera"), + row_shares=[2, 1], + ), + ) + 
+ +_vis = vis_module( + viewer_backend=global_config.viewer, + rerun_config={ + "blueprint": _rerun_blueprint, + "pubsubs": [LCM()], + "min_interval_sec": 0.25, + "visual_override": { + "world/camera_info": UnityBridgeModule.rerun_suppress_camera_info, + "world/sensor_scan": sensor_scan_override, + "world/terrain_map": terrain_map_override, + "world/terrain_map_ext": terrain_map_ext_override, + "world/path": path_override, + "world/way_point": waypoint_override, + }, + "static": { + "world/color_image": UnityBridgeModule.rerun_static_pinhole, + "world/floor": static_floor, + "world/tf/robot": static_robot, + }, + }, +) + +unitree_g1_nav_explore_sim = ( + autoconnect( + UnityBridgeModule.blueprint( + unity_binary="", + unity_scene="home_building_1", + vehicle_height=1.24, + ), + SensorScanGeneration.blueprint(), + TerrainAnalysis.blueprint( + extra_args=[ + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + ] + ), + TerrainMapExt.blueprint(), + TarePlanner.blueprint( + sensor_range=30.0, + ), + LocalPlanner.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "2.0", + "--autonomySpeed", + "2.0", + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--minRelZ", + "-1.0", + ] + ), + PathFollower.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "2.0", + "--autonomySpeed", + "2.0", + "--maxAccel", + "4.0", + "--slowDwnDisThre", + "0.2", + ] + ), + PGO.blueprint(), + CmdVelMux.blueprint(), + _vis, + ) + .remappings( + [ + (PathFollower, "cmd_vel", "nav_cmd_vel"), + (UnityBridgeModule, "terrain_map", "terrain_map_ext"), + # TARE plans at global scale — needs PGO-corrected odometry + (TarePlanner, "odometry", "corrected_odometry"), + (TerrainAnalysis, "odometry", "corrected_odometry"), + ] + ) + .global_config(n_workers=8, robot_model="unitree_g1", simulation=True) +) + + +def main() -> None: + unitree_g1_nav_explore_sim.build().loop() + + +__all__ = ["unitree_g1_nav_explore_sim"] + +if __name__ == "__main__": + 
main() diff --git a/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_far_onboard.py b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_far_onboard.py new file mode 100644 index 0000000000..4afb2b588e --- /dev/null +++ b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_far_onboard.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""G1 with FAR global route planner on real hardware. + +Zero-ROS navigation stack: SmartNav C++ modules for terrain analysis, +local planning, and path following. FAR planner builds a visibility-graph +route to a clicked goal and feeds intermediate waypoints to LocalPlanner. 
+ +Data flow: + FastLio2 → registered_scan + odometry + ClickToGoal.goal → FarPlanner → way_point → LocalPlanner → PathFollower + → G1HighLevelDdsSdk +""" + +from __future__ import annotations + +import os +from typing import Any + +from dimos.core.blueprints import autoconnect +from dimos.hardware.sensors.lidar.fastlio2.module import FastLio2 +from dimos.navigation.smartnav.blueprints._rerun_helpers import ( + global_map_override, + goal_path_override, + path_override, + sensor_scan_override, + static_floor, + static_robot, + terrain_map_ext_override, + terrain_map_override, + waypoint_override, +) +from dimos.navigation.smartnav.modules.click_to_goal.click_to_goal import ClickToGoal +from dimos.navigation.smartnav.modules.far_planner.far_planner import FarPlanner +from dimos.navigation.smartnav.modules.global_map.global_map import GlobalMap +from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner +from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower +from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import ( + SensorScanGeneration, +) +from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis +from dimos.navigation.smartnav.modules.terrain_map_ext.terrain_map_ext import TerrainMapExt +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.robot.unitree.g1.effectors.high_level.dds_sdk import G1HighLevelDdsSdk +from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode + + +def _rerun_blueprint() -> Any: + import rerun.blueprint as rrb + + return rrb.Blueprint( + rrb.Spatial3DView(origin="world", name="3D"), + ) + + +_rerun_config = { + "blueprint": _rerun_blueprint, + "pubsubs": [LCM()], + "min_interval_sec": 0.25, + "visual_override": { + "world/sensor_scan": sensor_scan_override, + "world/terrain_map": terrain_map_override, + "world/terrain_map_ext": terrain_map_ext_override, + 
"world/global_map": global_map_override, + "world/path": path_override, + "world/way_point": waypoint_override, + "world/goal_path": goal_path_override, + }, + "static": { + "world/floor": static_floor, + "world/tf/robot": static_robot, + }, +} + +unitree_g1_nav_far_onboard = ( + autoconnect( + FastLio2.blueprint( + host_ip=os.getenv("LIDAR_HOST_IP", "192.168.123.164"), + lidar_ip=os.getenv("LIDAR_IP", "192.168.123.120"), + # G1 lidar mount: 1.2m height, 180° around X (upside-down mount) + init_pose=[0.0, 0.0, 1.2, 1.0, 0.0, 0.0, 0.0], + map_freq=0.0, # GlobalMap handles accumulation + ), + SensorScanGeneration.blueprint(), + TerrainAnalysis.blueprint( + extra_args=[ + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--vehicleHeight", + "1.2", + ] + ), + TerrainMapExt.blueprint(), + FarPlanner.blueprint(), + LocalPlanner.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "1.0", + "--autonomySpeed", + "1.0", + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--minRelZ", + "-1.5", + ] + ), + PathFollower.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "1.0", + "--autonomySpeed", + "1.0", + "--maxAccel", + "2.0", + "--slowDwnDisThre", + "0.2", + ] + ), + ClickToGoal.blueprint(), + GlobalMap.blueprint(), + G1HighLevelDdsSdk.blueprint(), + RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **_rerun_config), + ) + .remappings( + [ + # FastLio2 outputs "lidar"; SmartNav modules expect "registered_scan" + (FastLio2, "lidar", "registered_scan"), + # FarPlanner drives way_point to LocalPlanner. + # Disconnect ClickToGoal's way_point so it doesn't conflict. 
+ (ClickToGoal, "way_point", "_click_way_point_unused"), + ] + ) + .global_config(n_workers=8, robot_model="unitree_g1") +) + + +def main() -> None: + unitree_g1_nav_far_onboard.build().loop() + + +__all__ = ["unitree_g1_nav_far_onboard"] + +if __name__ == "__main__": + main() diff --git a/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_onboard.py b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_onboard.py new file mode 100644 index 0000000000..0af024ece7 --- /dev/null +++ b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_onboard.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""G1 nav onboard — FAR planner + PGO loop closure + local obstacle avoidance. + +Full navigation stack on real hardware with: +- FAR visibility-graph global route planner +- PGO pose graph optimization with loop closure detection (GTSAM iSAM2) +- Local planner for reactive obstacle avoidance +- Path follower for velocity control +- FastLio2 SLAM from Livox Mid-360 lidar +- G1HighLevelDdsSdk for robot velocity commands + +Odometry routing (per CMU ICRA 2022 Fig. 11): +- Local path modules (LocalPlanner, PathFollower, SensorScanGen): + use raw odometry — they follow paths in the local odometry frame. 
+- Global/terrain modules (FarPlanner, ClickToGoal, TerrainAnalysis): + use PGO corrected_odometry — they need globally consistent positions + for terrain classification, visibility graphs, and goal coordinates. + +Data flow: + Click → ClickToGoal (corrected_odom) → goal → FarPlanner (corrected_odom) + → way_point → LocalPlanner (raw odom) → path → PathFollower (raw odom) + → nav_cmd_vel → CmdVelMux → cmd_vel → G1HighLevelDdsSdk + + registered_scan + odometry → PGO → corrected_odometry + global_map +""" + +from __future__ import annotations + +import os +from typing import Any + +from dimos.core.blueprints import autoconnect +from dimos.hardware.sensors.lidar.fastlio2.module import FastLio2 +from dimos.navigation.smartnav.blueprints._rerun_helpers import ( + goal_path_override, + path_override, + sensor_scan_override, + static_floor, + static_robot, + terrain_map_ext_override, + terrain_map_override, + waypoint_override, +) +from dimos.navigation.smartnav.modules.click_to_goal.click_to_goal import ClickToGoal +from dimos.navigation.smartnav.modules.cmd_vel_mux import CmdVelMux +from dimos.navigation.smartnav.modules.far_planner.far_planner import FarPlanner +from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner +from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower +from dimos.navigation.smartnav.modules.pgo.pgo import PGO +from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import ( + SensorScanGeneration, +) +from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis +from dimos.navigation.smartnav.modules.terrain_map_ext.terrain_map_ext import TerrainMapExt +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.robot.unitree.g1.effectors.high_level.dds_sdk import G1HighLevelDdsSdk +from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode +from dimos.visualization.rerun.websocket_server 
import RerunWebSocketServer + + +def _rerun_blueprint() -> Any: + import rerun.blueprint as rrb + + return rrb.Blueprint( + rrb.Spatial3DView(origin="world", name="3D"), + ) + + +_rerun_config = { + "blueprint": _rerun_blueprint, + "pubsubs": [LCM()], + "min_interval_sec": 0.25, + "visual_override": { + "world/sensor_scan": sensor_scan_override, + "world/terrain_map": terrain_map_override, + "world/terrain_map_ext": terrain_map_ext_override, + "world/path": path_override, + "world/way_point": waypoint_override, + "world/goal_path": goal_path_override, + }, + "static": { + "world/floor": static_floor, + "world/tf/robot": static_robot, + }, +} + +unitree_g1_nav_onboard = ( + autoconnect( + FastLio2.blueprint( + host_ip=os.getenv("LIDAR_HOST_IP", "192.168.123.164"), + lidar_ip=os.getenv("LIDAR_IP", "192.168.123.120"), + # G1 lidar mount: 1.2m height, 180° around X (upside-down mount) + # [x, y, z, qx, qy, qz, qw] — quaternion (1,0,0,0) = 180° X rotation + init_pose=[0.0, 0.0, 1.2, 1.0, 0.0, 0.0, 0.0], + map_freq=1.0, + ), + SensorScanGeneration.blueprint(), + TerrainAnalysis.blueprint( + extra_args=[ + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--vehicleHeight", + "1.2", + ] + ), + TerrainMapExt.blueprint(), + FarPlanner.blueprint( + sensor_range=30.0, + visibility_range=25.0, + ), + LocalPlanner.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "1.0", + "--autonomySpeed", + "1.0", + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--minRelZ", + "-1.5", + ] + ), + PathFollower.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "1.0", + "--autonomySpeed", + "1.0", + "--maxAccel", + "2.0", + "--slowDwnDisThre", + "0.2", + ] + ), + PGO.blueprint(), + ClickToGoal.blueprint(), + CmdVelMux.blueprint(), + G1HighLevelDdsSdk.blueprint(), + RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **_rerun_config), + RerunWebSocketServer.blueprint(), + ) + .remappings( + [ + # FastLio2 outputs 
"lidar"; SmartNav modules expect "registered_scan" + (FastLio2, "lidar", "registered_scan"), + # PathFollower cmd_vel → CmdVelMux nav input (avoid name collision with mux output) + (PathFollower, "cmd_vel", "nav_cmd_vel"), + # Global-scale planners use PGO-corrected odometry (per CMU ICRA 2022): + # "Loop closure adjustments are used by the high-level planners since + # they are in charge of planning at the global scale. Modules such as + # local planner and terrain analysis only care about the local + # environment surrounding the vehicle and work in the odometry frame." + (FarPlanner, "odometry", "corrected_odometry"), + (ClickToGoal, "odometry", "corrected_odometry"), + (TerrainAnalysis, "odometry", "corrected_odometry"), # unlike the quote; see docstring + ] + ) + .global_config(n_workers=8, robot_model="unitree_g1") +) + + +def main() -> None: + unitree_g1_nav_onboard.build().loop() + + +__all__ = ["unitree_g1_nav_onboard"] + +if __name__ == "__main__": + main() diff --git a/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_pgo_onboard.py b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_pgo_onboard.py new file mode 100644 index 0000000000..1ddfa5b494 --- /dev/null +++ b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_pgo_onboard.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""G1 with PGO (pose graph optimization) on real hardware. 
+ +Adds loop-closure-corrected mapping on top of the base SmartNav navigation +stack. PGO accumulates registered scans as keyframes and runs iSAM2 loop +closure to produce a globally consistent global map. The corrected_odometry +output can be monitored in Rerun for drift comparison. + +Data flow: + FastLio2 → registered_scan + odometry + → PGO → corrected_odometry (visualization) + global_map (accumulated) + FastLio2.odometry → SensorScanGeneration → TerrainAnalysis → LocalPlanner + → PathFollower → G1HighLevelDdsSdk +""" + +from __future__ import annotations + +import os +from typing import Any + +from dimos.core.blueprints import autoconnect +from dimos.hardware.sensors.lidar.fastlio2.module import FastLio2 +from dimos.navigation.smartnav.blueprints._rerun_helpers import ( + global_map_override, + goal_path_override, + path_override, + sensor_scan_override, + static_floor, + static_robot, + terrain_map_ext_override, + terrain_map_override, + waypoint_override, +) +from dimos.navigation.smartnav.modules.click_to_goal.click_to_goal import ClickToGoal +from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner +from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower +from dimos.navigation.smartnav.modules.pgo.pgo import PGO +from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import ( + SensorScanGeneration, +) +from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis +from dimos.navigation.smartnav.modules.terrain_map_ext.terrain_map_ext import TerrainMapExt +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.robot.unitree.g1.effectors.high_level.dds_sdk import G1HighLevelDdsSdk +from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode + + +def _rerun_blueprint() -> Any: + import rerun.blueprint as rrb + + return rrb.Blueprint( + rrb.Spatial3DView(origin="world", name="3D"), + ) + + 
+_rerun_config = { + "blueprint": _rerun_blueprint, + "pubsubs": [LCM()], + "min_interval_sec": 0.25, + "visual_override": { + "world/sensor_scan": sensor_scan_override, + "world/terrain_map": terrain_map_override, + "world/terrain_map_ext": terrain_map_ext_override, + "world/global_map": global_map_override, + "world/path": path_override, + "world/way_point": waypoint_override, + "world/goal_path": goal_path_override, + }, + "static": { + "world/floor": static_floor, + "world/tf/robot": static_robot, + }, +} + +unitree_g1_nav_pgo_onboard = ( + autoconnect( + FastLio2.blueprint( + host_ip=os.getenv("LIDAR_HOST_IP", "192.168.123.164"), + lidar_ip=os.getenv("LIDAR_IP", "192.168.123.120"), + # G1 lidar mount: 1.2m height, 180° around X (upside-down mount) + init_pose=[0.0, 0.0, 1.2, 1.0, 0.0, 0.0, 0.0], + map_freq=0.0, # PGO provides the global map + ), + PGO.blueprint(), + SensorScanGeneration.blueprint(), + TerrainAnalysis.blueprint( + extra_args=[ + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--vehicleHeight", + "1.2", + ] + ), + TerrainMapExt.blueprint(), + LocalPlanner.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "1.0", + "--autonomySpeed", + "1.0", + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--minRelZ", + "-1.5", + ] + ), + PathFollower.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "1.0", + "--autonomySpeed", + "1.0", + "--maxAccel", + "2.0", + "--slowDwnDisThre", + "0.2", + ] + ), + ClickToGoal.blueprint(), + G1HighLevelDdsSdk.blueprint(), + RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **_rerun_config), + ) + .remappings( + [ + # FastLio2 outputs "lidar"; SmartNav modules expect "registered_scan" + (FastLio2, "lidar", "registered_scan"), + ] + ) + .global_config(n_workers=8, robot_model="unitree_g1") +) + + +def main() -> None: + unitree_g1_nav_pgo_onboard.build().loop() + + +__all__ = ["unitree_g1_nav_pgo_onboard"] + +if __name__ == "__main__": + 
main() diff --git a/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_sim.py b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_sim.py new file mode 100644 index 0000000000..d9647e20dc --- /dev/null +++ b/dimos/robot/unitree/g1/blueprints/navigation/unitree_g1_nav_sim.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""G1 nav sim — FAR planner + PGO loop closure + local obstacle avoidance. + +Full navigation stack with: +- FAR visibility-graph global route planner +- PGO pose graph optimization with loop closure detection (GTSAM iSAM2) +- Local planner for reactive obstacle avoidance +- Path follower for velocity control + +Odometry routing (per CMU ICRA 2022 Fig. 11): +- Local path modules (LocalPlanner, PathFollower, SensorScanGen): + use raw odometry — they follow paths in the local odometry frame. +- Global/terrain modules (FarPlanner, ClickToGoal, TerrainAnalysis): + use PGO corrected_odometry — they need globally consistent positions + for terrain classification, visibility graphs, and goal coordinates. 
+ +Data flow: + Click → ClickToGoal (corrected_odom) → goal → FarPlanner (corrected_odom) + → way_point → LocalPlanner (raw odom) → path → PathFollower (raw odom) + → nav_cmd_vel → CmdVelMux → cmd_vel → UnityBridgeModule + + registered_scan + odometry → PGO → corrected_odometry + global_map +""" + +from __future__ import annotations + +from typing import Any + +from dimos.core.blueprints import autoconnect +from dimos.core.global_config import global_config +from dimos.navigation.smartnav.blueprints._rerun_helpers import ( + goal_path_override, + path_override, + sensor_scan_override, + static_floor, + static_robot, + terrain_map_ext_override, + terrain_map_override, + waypoint_override, +) +from dimos.navigation.smartnav.modules.click_to_goal.click_to_goal import ClickToGoal +from dimos.navigation.smartnav.modules.cmd_vel_mux import CmdVelMux +from dimos.navigation.smartnav.modules.far_planner.far_planner import FarPlanner +from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner +from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower +from dimos.navigation.smartnav.modules.pgo.pgo import PGO +from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import ( + SensorScanGeneration, +) +from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis +from dimos.navigation.smartnav.modules.terrain_map_ext.terrain_map_ext import TerrainMapExt +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.simulation.unity.module import UnityBridgeModule +from dimos.visualization.vis_module import vis_module + + +def _rerun_blueprint() -> Any: + import rerun.blueprint as rrb + + return rrb.Blueprint( + rrb.Vertical( + rrb.Spatial3DView(origin="world", name="3D"), + rrb.Spatial2DView(origin="world/color_image", name="Camera"), + row_shares=[2, 1], + ), + ) + + +_vis = vis_module( + viewer_backend=global_config.viewer, + rerun_config={ + 
"blueprint": _rerun_blueprint, + "pubsubs": [LCM()], + "min_interval_sec": 0.25, + "visual_override": { + "world/camera_info": UnityBridgeModule.rerun_suppress_camera_info, + "world/sensor_scan": sensor_scan_override, + "world/terrain_map": terrain_map_override, + "world/terrain_map_ext": terrain_map_ext_override, + "world/path": path_override, + "world/way_point": waypoint_override, + "world/goal_path": goal_path_override, + }, + "static": { + "world/color_image": UnityBridgeModule.rerun_static_pinhole, + "world/floor": static_floor, + "world/tf/robot": static_robot, + }, + }, +) + +unitree_g1_nav_sim = ( + autoconnect( + UnityBridgeModule.blueprint( + unity_binary="", + unity_scene="home_building_1", + vehicle_height=1.24, + ), + SensorScanGeneration.blueprint(), + TerrainAnalysis.blueprint( + extra_args=[ + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + ] + ), + TerrainMapExt.blueprint(), + FarPlanner.blueprint( + sensor_range=30.0, + visibility_range=25.0, + ), + LocalPlanner.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "2.0", + "--autonomySpeed", + "2.0", + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--minRelZ", + "-1.0", + ] + ), + PathFollower.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "2.0", + "--autonomySpeed", + "2.0", + "--maxAccel", + "4.0", + "--slowDwnDisThre", + "0.2", + ] + ), + PGO.blueprint(), + ClickToGoal.blueprint(), + CmdVelMux.blueprint(), + _vis, + ) + .remappings( + [ + # PathFollower cmd_vel → CmdVelMux nav input (avoid name collision with mux output) + (PathFollower, "cmd_vel", "nav_cmd_vel"), + # Unity needs the extended (persistent) terrain map for Z-height, not the local one + (UnityBridgeModule, "terrain_map", "terrain_map_ext"), + # Global-scale planners use PGO-corrected odometry (per CMU ICRA 2022): + # "Loop closure adjustments are used by the high-level planners since + # they are in charge of planning at the global scale. 
Modules such as + # local planner and terrain analysis only care about the local + # environment surrounding the vehicle and work in the odometry frame." + (FarPlanner, "odometry", "corrected_odometry"), + (ClickToGoal, "odometry", "corrected_odometry"), + (TerrainAnalysis, "odometry", "corrected_odometry"), + ] + ) + .global_config(n_workers=8, robot_model="unitree_g1", simulation=True) +) + + +def main() -> None: + unitree_g1_nav_sim.build().loop() + + +__all__ = ["unitree_g1_nav_sim"] + +if __name__ == "__main__": + main() diff --git a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_rosnav_onboard.py b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_rosnav_onboard.py new file mode 100644 index 0000000000..201e1beb91 --- /dev/null +++ b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_rosnav_onboard.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""G1 with ROSNav in hardware mode + replanning A* local planner.""" + +import os + +from dimos.core.blueprints import autoconnect +from dimos.navigation.replanning_a_star.module import ReplanningAStarPlanner +from dimos.navigation.rosnav.rosnav_module import ROSNav +from dimos.robot.unitree.g1.blueprints.primitive._mapper import _mapper +from dimos.robot.unitree.g1.blueprints.primitive._vis import _vis +from dimos.robot.unitree.g1.effectors.high_level.dds_sdk import G1HighLevelDdsSdk + +unitree_g1_rosnav_onboard = ( + autoconnect( + _vis, + _mapper, + G1HighLevelDdsSdk.blueprint(), + ReplanningAStarPlanner.blueprint(), + ROSNav.blueprint( + mode="hardware", + vehicle_height=1.24, + unitree_ip=os.getenv("ROBOT_IP", "192.168.12.1"), + unitree_conn=os.getenv("ROSNAV_UNITREE_CONN", "LocalAP"), + lidar_interface=os.getenv("ROSNAV_LIDAR_INTERFACE", "eth0"), + lidar_computer_ip=os.getenv("ROSNAV_LIDAR_COMPUTER_IP", "192.168.123.5"), + lidar_gateway=os.getenv("ROSNAV_LIDAR_GATEWAY", "192.168.123.1"), + lidar_ip=os.getenv("ROSNAV_LIDAR_IP", "192.168.123.120"), + ), + ) + .remappings( + [ + (ROSNav, "teleop_cmd_vel", "tele_cmd_vel"), + ] + ) + .global_config(n_workers=8, robot_model="unitree_g1") +) + +__all__ = ["unitree_g1_rosnav_onboard"] diff --git a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_rosnav_sim.py b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_rosnav_sim.py new file mode 100644 index 0000000000..ed99c2bb30 --- /dev/null +++ b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_rosnav_sim.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""G1 with ROSNav + external Unity simulation. + +The Unity simulator runs on the host (via UnityBridgeModule) and provides +lidar, camera, and odometry data. ROSNav runs in hardware mode inside +Docker — its nav stack receives the external sensor data via ROS2 topics +republished by ROSNav's ext_* input streams. + +cmd_vel flows back from the nav stack (or teleop) through LCM to the +UnityBridgeModule, which drives the simulated robot. +""" + +from typing import Any + +from dimos.core.blueprints import autoconnect +from dimos.core.global_config import global_config +from dimos.mapping.costmapper import CostMapper +from dimos.mapping.voxels import VoxelGridMapper +from dimos.navigation.rosnav.rosnav_module import ROSNav +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.robot.unitree.g1.blueprints.primitive._mapper import _mapper +from dimos.simulation.unity.module import UnityBridgeModule +from dimos.visualization.vis_module import vis_module +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule + + +def _static_path_frame(rr: Any) -> list[Any]: + return [rr.Transform3D(parent_frame="tf#/sensor")] + + +def _static_base_link(rr: Any) -> list[Any]: + """Green wireframe box tracking the robot. + + Attached to ``tf#/sensor`` because the UnityBridgeModule publishes + ``map → sensor`` (there is no separate ``base_link`` frame in external + sim mode). 
+ """ + return [ + rr.Boxes3D( + half_sizes=[0.2, 0.15, 0.62], + centers=[[0, 0, -0.62]], + colors=[(0, 255, 127)], + fill_mode="MajorWireframe", + ), + rr.Transform3D(parent_frame="tf#/sensor"), + ] + + +_vis_sim = vis_module( + viewer_backend=global_config.viewer, + rerun_config={ + "pubsubs": [LCM()], + "visual_override": { + "world/camera_info": UnityBridgeModule.rerun_suppress_camera_info, + }, + "static": { + "world/color_image": UnityBridgeModule.rerun_static_pinhole, + "world/tf/base_link": _static_base_link, + "world/path": _static_path_frame, + }, + }, +) + +unitree_g1_rosnav_sim = ( + autoconnect( + _vis_sim, + _mapper, + WebsocketVisModule.blueprint(), + UnityBridgeModule.blueprint(), + ROSNav.blueprint(mode="external_sim", vehicle_height=1.24, mount_sim_assets=True), + ) + .remappings( + [ + # Wire Unity sensor outputs → ROSNav external inputs. + # Use "ext_*" names matching the UnityBridgeModule output names + # to avoid colliding with ROSNav's own output streams of the same type. 
+ (UnityBridgeModule, "registered_scan", "ext_registered_scan"), + (UnityBridgeModule, "odometry", "ext_odometry"), + # Feed local terrain data from nav stack to Unity for Z-height adjustment + # Rename VoxelGridMapper/CostMapper streams to avoid collisions + (VoxelGridMapper, "lidar", "global_pointcloud"), + (VoxelGridMapper, "global_map", "global_voxel_map"), + (CostMapper, "global_map", "global_voxel_map"), + # Teleop: WebsocketVisModule cmd_vel → ROSNav tele_cmd_vel + (WebsocketVisModule, "cmd_vel", "tele_cmd_vel"), + ] + ) + .global_config(n_workers=4, robot_model="unitree_g1") +) + +__all__ = ["unitree_g1_rosnav_sim"] diff --git a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py index 5b127fb697..37751b0222 100644 --- a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py +++ b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py @@ -19,8 +19,8 @@ from dimos.core.blueprints import autoconnect from dimos.core.transport import pSHMTransport from dimos.msgs.sensor_msgs.Image import Image -from dimos.robot.foxglove_bridge import FoxgloveBridge from dimos.robot.unitree.g1.blueprints.perceptive.unitree_g1 import unitree_g1 +from dimos.visualization.vis_module import vis_module unitree_g1_shm = autoconnect( unitree_g1.transports( @@ -30,10 +30,8 @@ ), } ), - FoxgloveBridge.blueprint( - shm_channels=[ - "/color_image#sensor_msgs.Image", - ] + vis_module( + foxglove_config={"shm_channels": ["/color_image#sensor_msgs.Image"]}, ), ) diff --git a/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py b/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py index ff59c9b8ef..709dfbe140 100644 --- a/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py +++ b/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py @@ -41,6 +41,7 @@ WavefrontFrontierExplorer, ) from 
dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.visualization.vis_module import vis_module from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule @@ -109,18 +110,7 @@ def _g1_rerun_blueprint() -> Any: }, } -if global_config.viewer == "foxglove": - from dimos.robot.foxglove_bridge import FoxgloveBridge - - _with_vis = autoconnect(FoxgloveBridge.blueprint()) -elif global_config.viewer.startswith("rerun"): - from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode - - _with_vis = autoconnect( - RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config) - ) -else: - _with_vis = autoconnect() +_with_vis = vis_module(viewer_backend=global_config.viewer, rerun_config=rerun_config) def _create_webcam() -> Webcam: diff --git a/dimos/robot/unitree/g1/connection.py b/dimos/robot/unitree/g1/connection.py index bc2ca7d3d9..e435ab4fa7 100644 --- a/dimos/robot/unitree/g1/connection.py +++ b/dimos/robot/unitree/g1/connection.py @@ -117,3 +117,6 @@ def deploy(dimos: ModuleCoordinator, ip: str, local_planner: LocalPlanner) -> "M connection.cmd_vel.connect(local_planner.cmd_vel) connection.start() return connection + + +__all__ = ["G1Connection", "G1ConnectionBase", "deploy"] diff --git a/dimos/robot/unitree/g1/effectors/high_level/dds_sdk.py b/dimos/robot/unitree/g1/effectors/high_level/dds_sdk.py new file mode 100644 index 0000000000..ec5e003d53 --- /dev/null +++ b/dimos/robot/unitree/g1/effectors/high_level/dds_sdk.py @@ -0,0 +1,435 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""G1 high-level control via native Unitree SDK2 (DDS).""" + +import difflib +from enum import IntEnum +import json +import threading +import time +from typing import Any + +from reactivex.disposable import Disposable +from unitree_sdk2py.comm.motion_switcher.motion_switcher_client import ( # type: ignore[import-not-found] + MotionSwitcherClient, +) +from unitree_sdk2py.core.channel import ChannelFactoryInitialize # type: ignore[import-not-found] +from unitree_sdk2py.g1.loco.g1_loco_api import ( # type: ignore[import-not-found] + ROBOT_API_ID_LOCO_GET_BALANCE_MODE, + ROBOT_API_ID_LOCO_GET_FSM_ID, + ROBOT_API_ID_LOCO_GET_FSM_MODE, +) +from unitree_sdk2py.g1.loco.g1_loco_client import LocoClient # type: ignore[import-not-found] + +from dimos.agents.annotation import skill +from dimos.core.core import rpc +from dimos.core.global_config import GlobalConfig, global_config +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import In +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.robot.unitree.g1.effectors.high_level.high_level_spec import HighLevelG1Spec +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +_LOCO_API_IDS = { + "GET_FSM_ID": ROBOT_API_ID_LOCO_GET_FSM_ID, + "GET_FSM_MODE": ROBOT_API_ID_LOCO_GET_FSM_MODE, + "GET_BALANCE_MODE": ROBOT_API_ID_LOCO_GET_BALANCE_MODE, +} + + +# G1 Arm Actions - all use api_id 7106 on topic "rt/api/arm/request" +G1_ARM_CONTROLS = [ + ("Handshake", 27, "Perform a handshake gesture with the right 
hand."), + ("HighFive", 18, "Give a high five with the right hand."), + ("Hug", 19, "Perform a hugging gesture with both arms."), + ("HighWave", 26, "Wave with the hand raised high."), + ("Clap", 17, "Clap hands together."), + ("FaceWave", 25, "Wave near the face level."), + ("LeftKiss", 12, "Blow a kiss with the left hand."), + ("ArmHeart", 20, "Make a heart shape with both arms overhead."), + ("RightHeart", 21, "Make a heart gesture with the right hand."), + ("HandsUp", 15, "Raise both hands up in the air."), + ("XRay", 24, "Hold arms in an X-ray pose position."), + ("RightHandUp", 23, "Raise only the right hand up."), + ("Reject", 22, "Make a rejection or 'no' gesture."), + ("CancelAction", 99, "Cancel any current arm action and return hands to neutral position."), +] + +# G1 Movement Modes - all use api_id 7101 on topic "rt/api/sport/request" +G1_MODE_CONTROLS = [ + ("WalkMode", 500, "Switch to normal walking mode."), + ("WalkControlWaist", 501, "Switch to walking mode with waist control."), + ("RunMode", 801, "Switch to running mode."), +] + +_ARM_COMMANDS: dict[str, tuple[int, str]] = { + name: (id_, description) for name, id_, description in G1_ARM_CONTROLS +} + +_MODE_COMMANDS: dict[str, tuple[int, str]] = { + name: (id_, description) for name, id_, description in G1_MODE_CONTROLS +} + +_ARM_COMMANDS_DOC = "\n".join(f'- "{name}": {desc}' for name, (_, desc) in _ARM_COMMANDS.items()) +_MODE_COMMANDS_DOC = "\n".join(f'- "{name}": {desc}' for name, (_, desc) in _MODE_COMMANDS.items()) + + +class FsmState(IntEnum): + ZERO_TORQUE = 0 + DAMP = 1 + SIT = 3 + AI_MODE = 200 + LIE_TO_STANDUP = 702 + SQUAT_STANDUP_TOGGLE = 706 + + +# Module +class G1HighLevelDdsSdkConfig(ModuleConfig): + ip: str | None = None + network_interface: str = "eth0" + connection_mode: str = "ai" + ai_standup: bool = True + motion_switcher_timeout: float = 5.0 + loco_client_timeout: float = 10.0 + cmd_vel_timeout: float = 0.2 + + +class G1HighLevelDdsSdk(Module, HighLevelG1Spec): + """G1 
high-level control module using the native Unitree SDK2 over DDS. + + Suitable for onboard control running directly on the robot. + """ + + cmd_vel: In[Twist] + default_config = G1HighLevelDdsSdkConfig + config: G1HighLevelDdsSdkConfig + + # Primary timing knob — individual delays in methods are fractions of this. + _standup_step_delay: float = 3.0 + + def __init__(self, *args: Any, g: GlobalConfig = global_config, **kwargs: Any) -> None: + super().__init__(*args, g=g, **kwargs) + self._global_config = g + self._stop_timer: threading.Timer | None = None + self._running = False + self._mode_selected = False + self.motion_switcher: Any = None + self.loco_client: Any = None + + # lifecycle + + @rpc + def start(self) -> None: + super().start() + + network_interface = self.config.network_interface + + # Initialise DDS channel factory + logger.info(f"Initializing DDS on interface: {network_interface}") + ChannelFactoryInitialize(0, network_interface) + + # Motion switcher (required before LocoClient commands work) + self.motion_switcher = MotionSwitcherClient() + self.motion_switcher.SetTimeout(self.config.motion_switcher_timeout) + self.motion_switcher.Init() + logger.info("Motion switcher initialized") + + # Locomotion client + self.loco_client = LocoClient() + self.loco_client.SetTimeout(self.config.loco_client_timeout) + self.loco_client.Init() + + self.loco_client._RegistApi(_LOCO_API_IDS["GET_FSM_ID"], 0) + self.loco_client._RegistApi(_LOCO_API_IDS["GET_FSM_MODE"], 0) + self.loco_client._RegistApi(_LOCO_API_IDS["GET_BALANCE_MODE"], 0) + + self._select_motion_mode() + self._running = True + + if self.cmd_vel._transport is not None: + self._disposables.add(Disposable(self.cmd_vel.subscribe(self.move))) + logger.info("G1 DDS SDK connection started") + + @rpc + def stop(self) -> None: + if self._stop_timer: + self._stop_timer.cancel() + self._stop_timer = None + + if self.loco_client is not None: + try: + self.loco_client.StopMove() + except Exception as e: + 
logger.error(f"Error stopping robot: {e}") + + self._running = False + logger.info("G1 DDS SDK connection stopped") + super().stop() + + # HighLevelG1Spec + + @rpc + def move(self, twist: Twist, duration: float = 0.0) -> bool: + assert self.loco_client is not None + vx = twist.linear.x + vy = twist.linear.y + vyaw = twist.angular.z + + if self._stop_timer: + self._stop_timer.cancel() + self._stop_timer = None + + try: + if duration > 0: + logger.info(f"Moving: vx={vx}, vy={vy}, vyaw={vyaw}, duration={duration}") + code = self.loco_client.SetVelocity(vx, vy, vyaw, duration) + if code != 0: + logger.warning(f"SetVelocity returned code: {code}") + return False + else: + + def auto_stop() -> None: + try: + logger.debug("Auto-stop timer triggered") + self.loco_client.StopMove() + except Exception as e: + logger.error(f"Auto-stop failed: {e}") + + # Send move command before starting the timeout timer to avoid + # a race where the timer fires before the move is sent. + self.loco_client.Move(vx, vy, vyaw, continous_move=True) + + self._stop_timer = threading.Timer(self.config.cmd_vel_timeout, auto_stop) + self._stop_timer.daemon = True + self._stop_timer.start() + + return True + except Exception as e: + logger.error(f"Failed to send movement command: {e}") + return False + + @rpc + def get_state(self) -> str: + fsm_id = self._get_fsm_id() + if fsm_id is None: + return "Unknown (query failed)" + try: + return FsmState(fsm_id).name + except ValueError: + return f"UNKNOWN_{fsm_id}" + + @rpc + def publish_request(self, topic: str, data: dict[str, Any]) -> dict[str, Any]: + logger.info(f"Publishing request to topic: {topic} with data: {data}") + assert self.loco_client is not None + + api_id = data.get("api_id") + parameter = data.get("parameter", {}) + + try: + if api_id == 7101: # SET_FSM_ID + fsm_id = parameter.get("data", 0) + code = self.loco_client.SetFsmId(fsm_id) + return {"code": code} + elif api_id == 7105: # SET_VELOCITY + velocity = parameter.get("velocity", [0, 0, 
0]) + dur = parameter.get("duration", 1.0) + code = self.loco_client.SetVelocity(velocity[0], velocity[1], velocity[2], dur) + return {"code": code} + else: + logger.warning(f"Unsupported API ID: {api_id}") + return {"code": -1, "error": "unsupported_api"} + except Exception as e: + logger.error(f"publish_request failed: {e}") + return {"code": -1, "error": str(e)} + + @rpc + def stand_up(self) -> bool: + assert self.loco_client is not None + try: + logger.info(f"Current state before stand_up: {self.get_state()}") + + if self.config.ai_standup: + fsm_id = self._get_fsm_id() + if fsm_id == FsmState.ZERO_TORQUE: + logger.info("Robot in zero torque, enabling damp mode...") + self.loco_client.SetFsmId(FsmState.DAMP) + time.sleep(self._standup_step_delay / 3) + if fsm_id != FsmState.AI_MODE: + logger.info("Starting AI mode...") + self.loco_client.SetFsmId(FsmState.AI_MODE) + time.sleep(self._standup_step_delay / 2) + else: + logger.info("Enabling damp mode...") + self.loco_client.SetFsmId(FsmState.DAMP) + time.sleep(self._standup_step_delay / 3) + + logger.info("Executing Squat2StandUp...") + self.loco_client.SetFsmId(FsmState.SQUAT_STANDUP_TOGGLE) + time.sleep(self._standup_step_delay) + logger.info(f"Final state: {self.get_state()}") + return True + except Exception as e: + logger.error(f"Standup failed: {e}") + return False + + @rpc + def lie_down(self) -> bool: + assert self.loco_client is not None + try: + self.loco_client.StandUp2Squat() + time.sleep(self._standup_step_delay / 3) + self.loco_client.Damp() + return True + except Exception as e: + logger.error(f"Lie down failed: {e}") + return False + + def disconnect(self) -> None: + self.stop() + + # skills (LLM-callable) + + @skill + def move_velocity( + self, x: float, y: float = 0.0, yaw: float = 0.0, duration: float = 0.0 + ) -> str: + """Move the robot using direct velocity commands. Determine duration required based on user distance instructions. 
+ + Example call: + args = { "x": 0.5, "y": 0.0, "yaw": 0.0, "duration": 2.0 } + move_velocity(**args) + + Args: + x: Forward velocity (m/s) + y: Left/right velocity (m/s) + yaw: Rotational velocity (rad/s) + duration: How long to move (seconds) + """ + twist = Twist(linear=Vector3(x, y, 0), angular=Vector3(0, 0, yaw)) + self.move(twist, duration=duration) + return f"Started moving with velocity=({x}, {y}, {yaw}) for {duration} seconds" + + @skill + def execute_arm_command(self, command_name: str) -> str: + """Execute a Unitree G1 arm command.""" + return self._execute_g1_command(_ARM_COMMANDS, 7106, "rt/api/arm/request", command_name) + + execute_arm_command.__doc__ = f"""Execute a Unitree G1 arm command. + + Example usage: + + execute_arm_command("ArmHeart") + + Here are all the command names and what they do. + + {_ARM_COMMANDS_DOC} + """ + + @skill + def execute_mode_command(self, command_name: str) -> str: + """Execute a Unitree G1 mode command.""" + return self._execute_g1_command(_MODE_COMMANDS, 7101, "rt/api/sport/request", command_name) + + execute_mode_command.__doc__ = f"""Execute a Unitree G1 mode command. + + Example usage: + + execute_mode_command("RunMode") + + Here are all the command names and what they do. + + {_MODE_COMMANDS_DOC} + """ + + # private helpers + + def _execute_g1_command( + self, + command_dict: dict[str, tuple[int, str]], + api_id: int, + topic: str, + command_name: str, + ) -> str: + if command_name not in command_dict: + suggestions = difflib.get_close_matches( + command_name, command_dict.keys(), n=3, cutoff=0.6 + ) + return f"There's no '{command_name}' command. Did you mean: {suggestions}" + + id_, _ = command_dict[command_name] + + try: + self.publish_request(topic, {"api_id": api_id, "parameter": {"data": id_}}) + return f"'{command_name}' command executed successfully." + except Exception as e: + logger.error(f"Failed to execute {command_name}: {e}") + return "Failed to execute the command." 
+ + def _select_motion_mode(self) -> None: + if not self.motion_switcher or self._mode_selected: + return + + try: + code, result = self.motion_switcher.CheckMode() + if code == 0 and result: + current_mode = result.get("name", "none") + logger.info(f"Current motion mode: {current_mode}") + if current_mode and current_mode != "none": + logger.warning( + f"Robot is in '{current_mode}' mode. " + "If SDK commands don't work, you may need to activate " + "via controller: L1+A then L1+UP " + "(for chinese L2+B then L2+up then R2+A)" + ) + except Exception as e: + logger.debug(f"Could not check current mode: {e}") + + mode = self.config.connection_mode + logger.info(f"Selecting motion mode: {mode}") + code, _ = self.motion_switcher.SelectMode(mode) + if code == 0: + logger.info(f"Motion mode '{mode}' selected successfully") + self._mode_selected = True + time.sleep(self._standup_step_delay / 6) + else: + logger.error( + f"Failed to select mode '{mode}': code={code}\n" + " The robot may need to be activated via controller first:\n" + " 1. Press L1 + A on the controller\n" + " 2. Then press L1 + UP\n" + " This enables the AI Sport client required for SDK control." 
+ ) + + def _get_fsm_id(self) -> int | None: + try: + code, data = self.loco_client._Call(_LOCO_API_IDS["GET_FSM_ID"], "{}") + if code == 0 and data: + result = json.loads(data) if isinstance(data, str) else data + fsm_id = result.get("data") if isinstance(result, dict) else result + logger.debug(f"Current FSM ID: {fsm_id}") + return fsm_id + else: + logger.warning(f"Failed to get FSM ID: code={code}, data={data}") + return None + except Exception as e: + logger.error(f"Error getting FSM ID: {e}") + return None + + +__all__ = ["FsmState", "G1HighLevelDdsSdk", "G1HighLevelDdsSdkConfig"] diff --git a/dimos/robot/unitree/g1/effectors/high_level/high_level_spec.py b/dimos/robot/unitree/g1/effectors/high_level/high_level_spec.py new file mode 100644 index 0000000000..cb4e53d81b --- /dev/null +++ b/dimos/robot/unitree/g1/effectors/high_level/high_level_spec.py @@ -0,0 +1,50 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Spec for G1 high-level control interface. + +Any high-level control module (WebRTC, native SDK, etc.) must implement +this protocol so that skill containers and blueprints can work against +a single, stable API. +""" + +from typing import Any, Protocol + +from dimos.core.stream import In +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.spec.utils import Spec + + +class HighLevelG1Spec(Spec, Protocol): + """Common high-level control interface for the Unitree G1. 
+ + Implementations provide velocity control, state queries, and + posture commands regardless of the underlying transport (WebRTC, + native SDK, etc.). + """ + + cmd_vel: In[Twist] + + def move(self, twist: Twist, duration: float = 0.0) -> bool: ... + + def get_state(self) -> str: ... + + def publish_request(self, topic: str, data: dict[str, Any]) -> dict[str, Any]: ... + + def stand_up(self) -> bool: ... + + def lie_down(self) -> bool: ... + + +__all__ = ["HighLevelG1Spec"] diff --git a/dimos/robot/unitree/g1/effectors/high_level/high_level_test.py b/dimos/robot/unitree/g1/effectors/high_level/high_level_test.py new file mode 100644 index 0000000000..71d5885b58 --- /dev/null +++ b/dimos/robot/unitree/g1/effectors/high_level/high_level_test.py @@ -0,0 +1,602 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for G1 high-level control modules (DDS SDK and WebRTC).""" + +from __future__ import annotations + +from enum import IntEnum +import json +import sys +from typing import Any +from unittest.mock import MagicMock, call, patch + +import pytest + + +# Stub out unitree_sdk2py so we can import dds_sdk without the real SDK +def _install_sdk_stubs() -> dict[str, MagicMock]: + stubs: dict[str, MagicMock] = {} + for mod_name in [ + "unitree_sdk2py", + "unitree_sdk2py.comm", + "unitree_sdk2py.comm.motion_switcher", + "unitree_sdk2py.comm.motion_switcher.motion_switcher_client", + "unitree_sdk2py.core", + "unitree_sdk2py.core.channel", + "unitree_sdk2py.g1", + "unitree_sdk2py.g1.loco", + "unitree_sdk2py.g1.loco.g1_loco_api", + "unitree_sdk2py.g1.loco.g1_loco_client", + ]: + mock = MagicMock() + stubs[mod_name] = mock + sys.modules[mod_name] = mock + + # Wire up named attributes the module actually imports + api_mod = stubs["unitree_sdk2py.g1.loco.g1_loco_api"] + api_mod.ROBOT_API_ID_LOCO_GET_FSM_ID = 7001 + api_mod.ROBOT_API_ID_LOCO_GET_FSM_MODE = 7002 + api_mod.ROBOT_API_ID_LOCO_GET_BALANCE_MODE = 7003 + + client_mod = stubs["unitree_sdk2py.g1.loco.g1_loco_client"] + client_mod.LocoClient = MagicMock + + switcher_mod = stubs["unitree_sdk2py.comm.motion_switcher.motion_switcher_client"] + switcher_mod.MotionSwitcherClient = MagicMock + + channel_mod = stubs["unitree_sdk2py.core.channel"] + channel_mod.ChannelFactoryInitialize = MagicMock() + + return stubs + + +# Stub out unitree_webrtc_connect too +def _install_webrtc_stubs() -> dict[str, MagicMock]: + stubs: dict[str, MagicMock] = {} + for mod_name in [ + "unitree_webrtc_connect", + "unitree_webrtc_connect.constants", + "unitree_webrtc_connect.webrtc_driver", + ]: + mock = MagicMock() + stubs[mod_name] = mock + sys.modules[mod_name] = mock + + constants = stubs["unitree_webrtc_connect.constants"] + constants.RTC_TOPIC = "rt/topic" + constants.SPORT_CMD = "sport_cmd" + # VUI_COLOR is used both as a type and a 
value (VUI_COLOR.RED) in connection.py + constants.VUI_COLOR = MagicMock() + + driver = stubs["unitree_webrtc_connect.webrtc_driver"] + driver.UnitreeWebRTCConnection = MagicMock + driver.WebRTCConnectionMethod = MagicMock() + + return stubs + + +_sdk_stubs = _install_sdk_stubs() +_webrtc_stubs = _install_webrtc_stubs() + +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.robot.unitree.g1.effectors.high_level.dds_sdk import ( + FsmState, + G1HighLevelDdsSdk, + G1HighLevelDdsSdkConfig, +) +from dimos.robot.unitree.g1.effectors.high_level.webrtc import ( + _ARM_COMMANDS, + _MODE_COMMANDS, + G1_ARM_CONTROLS, + G1_MODE_CONTROLS, + G1HighLevelWebRtc, + G1HighLevelWebRtcConfig, +) + +# FsmState enum tests + + +class TestFsmState: + def test_is_int_enum(self) -> None: + assert issubclass(FsmState, IntEnum) + + def test_values(self) -> None: + assert FsmState.ZERO_TORQUE == 0 # type: ignore[comparison-overlap] + assert FsmState.DAMP == 1 # type: ignore[comparison-overlap] + assert FsmState.SIT == 3 # type: ignore[comparison-overlap] + assert FsmState.AI_MODE == 200 # type: ignore[comparison-overlap] + assert FsmState.LIE_TO_STANDUP == 702 # type: ignore[comparison-overlap] + assert FsmState.SQUAT_STANDUP_TOGGLE == 706 # type: ignore[comparison-overlap] + + def test_name_lookup(self) -> None: + assert FsmState(0).name == "ZERO_TORQUE" + assert FsmState(1).name == "DAMP" + assert FsmState(200).name == "AI_MODE" + assert FsmState(706).name == "SQUAT_STANDUP_TOGGLE" + + def test_int_comparison(self) -> None: + assert FsmState.DAMP == 1 # type: ignore[comparison-overlap] + assert FsmState.AI_MODE != 0 # type: ignore[comparison-overlap] + + def test_unknown_value_raises(self) -> None: + with pytest.raises(ValueError): + FsmState(999) + + def test_iteration(self) -> None: + names = [s.name for s in FsmState] + assert "ZERO_TORQUE" in names + assert "AI_MODE" in names + assert len(names) == 6 + + +# Config tests + + 
def _make_dds_module(**config_overrides: Any) -> G1HighLevelDdsSdk:
    """Return a ``G1HighLevelDdsSdk`` whose collaborators are all mocks.

    The instance is allocated without running the real ``__init__``; the
    attributes below are then assigned by hand. ``config_overrides`` are
    forwarded to ``G1HighLevelDdsSdkConfig``.
    """

    def _skip_init(self: Any, *a: Any, **kw: Any) -> None:
        return None

    with patch.object(G1HighLevelDdsSdk, "__init__", _skip_init):
        module = G1HighLevelDdsSdk.__new__(G1HighLevelDdsSdk)

    # Insertion order mirrors the order in which the attributes are assigned.
    attributes: dict[str, Any] = {
        "config": G1HighLevelDdsSdkConfig(**config_overrides),
        "_global_config": MagicMock(),
        "_stop_timer": None,
        "_running": False,
        "_mode_selected": False,
        "motion_switcher": MagicMock(),
        "loco_client": MagicMock(),
        "_standup_step_delay": 0.0,  # no real sleeps in tests
    }
    for attr_name, attr_value in attributes.items():
        setattr(module, attr_name, attr_value)
    return module
class TestDdsSdkStandUp:
    """Check which FSM ids ``stand_up`` sends for each configuration."""

    def test_ai_standup_from_zero_torque(self) -> None:
        """From ZERO_TORQUE the AI path goes DAMP -> AI_MODE -> toggle."""
        module = _make_dds_module(ai_standup=True)
        reported = json.dumps({"data": FsmState.ZERO_TORQUE})
        module.loco_client._Call.return_value = (0, reported)
        assert module.stand_up() is True
        sent = module.loco_client.SetFsmId.call_args_list
        assert sent[:3] == [
            call(FsmState.DAMP),
            call(FsmState.AI_MODE),
            call(FsmState.SQUAT_STANDUP_TOGGLE),
        ]

    def test_ai_standup_already_ai_mode(self) -> None:
        """When already in AI_MODE only the toggle is sent."""
        module = _make_dds_module(ai_standup=True)
        reported = json.dumps({"data": FsmState.AI_MODE})
        module.loco_client._Call.return_value = (0, reported)
        assert module.stand_up() is True
        # Should skip DAMP and AI_MODE, go straight to toggle
        assert module.loco_client.SetFsmId.call_args_list == [
            call(FsmState.SQUAT_STANDUP_TOGGLE)
        ]

    def test_normal_standup(self) -> None:
        """The non-AI path sends DAMP followed by the toggle."""
        module = _make_dds_module(ai_standup=False)
        assert module.stand_up() is True
        sent = module.loco_client.SetFsmId.call_args_list
        assert sent[:2] == [call(FsmState.DAMP), call(FsmState.SQUAT_STANDUP_TOGGLE)]

    def test_standup_exception(self) -> None:
        """A failing ``SetFsmId`` is reported as ``False``."""
        module = _make_dds_module(ai_standup=False)
        module.loco_client.SetFsmId.side_effect = RuntimeError("comms lost")
        assert module.stand_up() is False
class TestDdsSdkMove:
    """Exercise ``move`` in timed and continuous form against a mocked loco client."""

    @staticmethod
    def _twist(x: float, y: float = 0, yaw: float = 0) -> Twist:
        """Build a twist from forward, lateral and yaw velocities."""
        return Twist(linear=Vector3(x, y, 0), angular=Vector3(0, 0, yaw))

    def test_move_with_duration(self) -> None:
        """A timed move forwards the velocities and duration to ``SetVelocity``."""
        module = _make_dds_module()
        loco = module.loco_client
        loco.SetVelocity.return_value = 0
        outcome = module.move(self._twist(1.0, 0.5, 0.3), duration=2.0)
        assert outcome is True
        loco.SetVelocity.assert_called_once_with(1.0, 0.5, 0.3, 2.0)

    def test_move_with_duration_error_code(self) -> None:
        """A ``-1`` return from ``SetVelocity`` makes ``move`` return ``False``."""
        module = _make_dds_module()
        module.loco_client.SetVelocity.return_value = -1
        outcome = module.move(self._twist(1.0), duration=1.0)
        assert outcome is False

    def test_move_continuous(self) -> None:
        """Without a duration, ``Move`` is used and a stop timer is set."""
        module = _make_dds_module()
        outcome = module.move(self._twist(0.5, yaw=0.1))
        assert outcome is True
        module.loco_client.Move.assert_called_once_with(0.5, 0, 0.1, continous_move=True)
        # Timer should have been started
        assert module._stop_timer is not None
        module._stop_timer.cancel()  # cleanup

    def test_move_exception(self) -> None:
        """An exception raised by the loco client is reported as ``False``."""
        module = _make_dds_module()
        module.loco_client.SetVelocity.side_effect = RuntimeError("err")
        outcome = module.move(self._twist(1.0), duration=1.0)
        assert outcome is False
class TestWebRtcConstants:
    """Sanity checks on the WebRTC command tables and their derived lookups."""

    def test_arm_controls_structure(self) -> None:
        """Every arm control is a ``(str, int, str)`` triple."""
        for entry in G1_ARM_CONTROLS:
            name, id_, desc = entry
            for value, expected_type in ((name, str), (id_, int), (desc, str)):
                assert isinstance(value, expected_type)

    def test_mode_controls_structure(self) -> None:
        """Every mode control is a ``(str, int, str)`` triple."""
        for entry in G1_MODE_CONTROLS:
            name, id_, desc = entry
            for value, expected_type in ((name, str), (id_, int), (desc, str)):
                assert isinstance(value, expected_type)

    def test_arm_commands_dict(self) -> None:
        """The arm lookup has one key per table row, including known names."""
        for required in ("Handshake", "CancelAction"):
            assert required in _ARM_COMMANDS
        assert len(G1_ARM_CONTROLS) == len(_ARM_COMMANDS)

    def test_mode_commands_dict(self) -> None:
        """The mode lookup has one key per table row, including known names."""
        for required in ("WalkMode", "RunMode"):
            assert required in _MODE_COMMANDS
        assert len(G1_MODE_CONTROLS) == len(_MODE_COMMANDS)
class TestWebRtcArmCommand:
    """Check ``execute_arm_command`` for a known and an unknown command name."""

    def test_valid_command(self) -> None:
        """A known command is published with the arm api id and its action id."""
        mod = _make_webrtc_module()
        mod.connection.publish_request.return_value = {"code": 0}  # type: ignore[union-attr]
        result = mod.execute_arm_command("Handshake")
        assert "successfully" in result
        # "Handshake" is action id 27 in G1_ARM_CONTROLS; arm commands use api_id 7106.
        mod.connection.publish_request.assert_called_once_with(  # type: ignore[union-attr]
            "rt/api/arm/request", {"api_id": 7106, "parameter": {"data": 27}}
        )

    def test_invalid_command(self) -> None:
        """An unknown command is rejected and nothing is sent to the robot."""
        mod = _make_webrtc_module()
        result = mod.execute_arm_command("NotARealCommand")
        # Match the rejection prefix exactly: the command name is echoed back
        # and itself contains "no", so a substring check for "no" would also
        # accept the success message.
        assert result.startswith("There's no 'NotARealCommand' command")
        mod.connection.publish_request.assert_not_called()  # type: ignore[union-attr]
class FsmSimulator:
    """Test double that tracks FSM state and rejects transitions not in its table.

    ``VALID_TRANSITIONS`` maps each state to the set of states that may be
    requested next. Used to verify that stand_up / lie_down issue commands
    in a valid order.
    """

    VALID_TRANSITIONS: dict[FsmState, set[FsmState]] = {
        FsmState.ZERO_TORQUE: {FsmState.DAMP},
        FsmState.DAMP: {FsmState.AI_MODE, FsmState.SQUAT_STANDUP_TOGGLE, FsmState.ZERO_TORQUE},
        FsmState.SIT: {FsmState.DAMP, FsmState.SQUAT_STANDUP_TOGGLE},
        FsmState.AI_MODE: {FsmState.SQUAT_STANDUP_TOGGLE, FsmState.DAMP, FsmState.ZERO_TORQUE},
        FsmState.LIE_TO_STANDUP: {FsmState.DAMP, FsmState.SIT},
        FsmState.SQUAT_STANDUP_TOGGLE: {
            FsmState.DAMP,
            FsmState.AI_MODE,
            FsmState.SIT,
            FsmState.SQUAT_STANDUP_TOGGLE,
        },
    }

    def __init__(self, initial: FsmState = FsmState.ZERO_TORQUE) -> None:
        """Start in ``initial`` and record it as the first history entry."""
        self.state = initial
        self.history: list[FsmState] = [initial]

    def transition(self, target: FsmState) -> None:
        """Request a move to ``target`` and append it to ``history``.

        Raises:
            ValueError: if ``target`` differs from the current state and is
                not listed as a valid successor of it.
        """
        # Self-transitions are no-ops on the real robot
        if target != self.state:
            allowed = self.VALID_TRANSITIONS.get(self.state, set())
            if target not in allowed:
                allowed_names = [s.name for s in allowed]
                raise ValueError(
                    f"Invalid transition: {self.state.name} -> {target.name}. "
                    f"Valid targets: {allowed_names}"
                )
            self.state = target
        self.history.append(target)
FsmState.SQUAT_STANDUP_TOGGLE, + ] + + def test_ai_standup_from_damp_valid_transitions(self) -> None: + mod, sim = _make_dds_with_fsm_sim(FsmState.DAMP, ai_standup=True) + assert mod.stand_up() is True + assert sim.history == [ + FsmState.DAMP, + FsmState.AI_MODE, + FsmState.SQUAT_STANDUP_TOGGLE, + ] + + def test_ai_standup_already_in_ai_mode(self) -> None: + mod, sim = _make_dds_with_fsm_sim(FsmState.AI_MODE, ai_standup=True) + assert mod.stand_up() is True + assert sim.history == [FsmState.AI_MODE, FsmState.SQUAT_STANDUP_TOGGLE] + + def test_normal_standup_from_zero_torque_invalid(self) -> None: + """Normal standup tries DAMP first, which is valid from ZERO_TORQUE.""" + mod, sim = _make_dds_with_fsm_sim(FsmState.ZERO_TORQUE, ai_standup=False) + assert mod.stand_up() is True + assert sim.history == [ + FsmState.ZERO_TORQUE, + FsmState.DAMP, + FsmState.SQUAT_STANDUP_TOGGLE, + ] + + def test_normal_standup_from_damp(self) -> None: + mod, sim = _make_dds_with_fsm_sim(FsmState.DAMP, ai_standup=False) + assert mod.stand_up() is True + assert sim.history == [ + FsmState.DAMP, + # DAMP -> DAMP is not in valid transitions, but SetFsmId + # is called unconditionally; the real robot handles this as a no-op. + # Our sim models it as valid since the robot stays in DAMP. 
class TestLieDownTransitions:
    """Run ``lie_down`` against the FSM simulator and inspect the recorded states."""

    def test_lie_down_from_standing(self) -> None:
        """Assumes the robot is in SQUAT_STANDUP_TOGGLE (standing) state."""
        standing = FsmState.SQUAT_STANDUP_TOGGLE
        module, simulator = _make_dds_with_fsm_sim(standing)
        assert module.lie_down() is True
        # StandUp2Squat toggles -> SQUAT_STANDUP_TOGGLE, then Damp -> DAMP
        expected_history = [standing, standing, FsmState.DAMP]
        assert simulator.history == expected_history

    def test_lie_down_from_ai_mode(self) -> None:
        """Starting from AI_MODE, DAMP must appear somewhere in the history."""
        module, simulator = _make_dds_with_fsm_sim(FsmState.AI_MODE)
        assert module.lie_down() is True
        assert any(state == FsmState.DAMP for state in simulator.history)
# G1 Arm Actions - all use api_id 7106 on topic "rt/api/arm/request"
# Each entry is (command name, numeric action id, description).
G1_ARM_CONTROLS: list[tuple[str, int, str]] = [
    ("Handshake", 27, "Perform a handshake gesture with the right hand."),
    ("HighFive", 18, "Give a high five with the right hand."),
    ("Hug", 19, "Perform a hugging gesture with both arms."),
    ("HighWave", 26, "Wave with the hand raised high."),
    ("Clap", 17, "Clap hands together."),
    ("FaceWave", 25, "Wave near the face level."),
    ("LeftKiss", 12, "Blow a kiss with the left hand."),
    ("ArmHeart", 20, "Make a heart shape with both arms overhead."),
    ("RightHeart", 21, "Make a heart gesture with the right hand."),
    ("HandsUp", 15, "Raise both hands up in the air."),
    ("XRay", 24, "Hold arms in an X-ray pose position."),
    ("RightHandUp", 23, "Raise only the right hand up."),
    ("Reject", 22, "Make a rejection or 'no' gesture."),
    ("CancelAction", 99, "Cancel any current arm action and return hands to neutral position."),
]

# G1 Movement Modes - all use api_id 7101 on topic "rt/api/sport/request"
# Same (command name, numeric id, description) layout as the arm table.
G1_MODE_CONTROLS: list[tuple[str, int, str]] = [
    ("WalkMode", 500, "Switch to normal walking mode."),
    ("WalkControlWaist", 501, "Switch to walking mode with waist control."),
    ("RunMode", 801, "Switch to running mode."),
]
class G1HighLevelWebRtc(Module, HighLevelG1Spec):
    """G1 high-level control module using WebRTC transport.

    Wraps :class:`UnitreeWebRTCConnection` and exposes the
    :class:`HighLevelG1Spec` interface plus LLM-callable skills for
    arm gestures, movement modes, and velocity control.

    ``connection`` is ``None`` until :meth:`start` has run and again after
    :meth:`stop`; :meth:`get_state` reports that as ``"Not connected"``.
    """

    cmd_vel: In[Twist]
    default_config = G1HighLevelWebRtcConfig
    config: G1HighLevelWebRtcConfig

    connection: UnitreeWebRTCConnection | None

    def __init__(self, *args: Any, g: GlobalConfig = global_config, **kwargs: Any) -> None:
        super().__init__(*args, g=g, **kwargs)
        self._global_config = g
        # Assign the attribute up front: the class-level annotation alone does
        # not create it, so get_state()/stop() before start() would otherwise
        # raise AttributeError instead of seeing "not connected".
        self.connection = None

    # lifecycle

    @rpc
    def start(self) -> None:
        """Open the WebRTC connection and route ``cmd_vel`` messages to :meth:`move`."""
        super().start()
        assert self.config.ip is not None, "ip must be set in G1HighLevelWebRtcConfig"
        self.connection = UnitreeWebRTCConnection(self.config.ip, self.config.connection_mode)
        self.connection.start()
        self._disposables.add(Disposable(self.cmd_vel.subscribe(self.move)))

    @rpc
    def stop(self) -> None:
        """Stop the connection (if any), then always run the base-class stop."""
        # Drop the handle first so the module reports "Not connected" afterwards,
        # even if stopping the connection fails.
        connection, self.connection = self.connection, None
        try:
            if connection is not None:
                connection.stop()
        finally:
            super().stop()

    # HighLevelG1Spec

    @rpc
    def move(self, twist: Twist, duration: float = 0.0) -> bool:
        """Forward a velocity command to the connection and return its result."""
        assert self.connection is not None
        return self.connection.move(twist, duration)

    @rpc
    def get_state(self) -> str:
        """Return a short connection-status string."""
        if self.connection is None:
            return "Not connected"
        return "Connected (WebRTC)"

    @rpc
    def publish_request(self, topic: str, data: dict[str, Any]) -> dict[str, Any]:
        """Log and forward a raw request to the connection, returning its reply."""
        logger.info(f"Publishing request to topic: {topic} with data: {data}")
        assert self.connection is not None
        return self.connection.publish_request(topic, data)  # type: ignore[no-any-return]

    @rpc
    def stand_up(self) -> bool:
        """Delegate to the connection's ``standup``."""
        assert self.connection is not None
        return self.connection.standup()

    @rpc
    def lie_down(self) -> bool:
        """Delegate to the connection's ``liedown``."""
        assert self.connection is not None
        return self.connection.liedown()

    # skills (LLM-callable)

    @skill
    def move_velocity(
        self, x: float, y: float = 0.0, yaw: float = 0.0, duration: float = 0.0
    ) -> str:
        """Move the robot using direct velocity commands. Determine duration required based on user distance instructions.

        Example call:
            args = { "x": 0.5, "y": 0.0, "yaw": 0.0, "duration": 2.0 }
            move_velocity(**args)

        Args:
            x: Forward velocity (m/s)
            y: Left/right velocity (m/s)
            yaw: Rotational velocity (rad/s)
            duration: How long to move (seconds)
        """
        twist = Twist(linear=Vector3(x, y, 0), angular=Vector3(0, 0, yaw))
        # Report a rejected command instead of always claiming success.
        if not self.move(twist, duration=duration):
            return f"Failed to start moving with velocity=({x}, {y}, {yaw})"
        return f"Started moving with velocity=({x}, {y}, {yaw}) for {duration} seconds"

    @skill
    def execute_arm_command(self, command_name: str) -> str:
        """Execute a Unitree G1 arm command."""
        return self._execute_g1_command(_ARM_COMMANDS, 7106, "rt/api/arm/request", command_name)

    # Replace the placeholder docstring with one that lists every command name.
    execute_arm_command.__doc__ = f"""Execute a Unitree G1 arm command.

    Example usage:

        execute_arm_command("ArmHeart")

    Here are all the command names and what they do.

    {_ARM_COMMANDS_DOC}
    """

    @skill
    def execute_mode_command(self, command_name: str) -> str:
        """Execute a Unitree G1 mode command."""
        return self._execute_g1_command(_MODE_COMMANDS, 7101, "rt/api/sport/request", command_name)

    # Replace the placeholder docstring with one that lists every command name.
    execute_mode_command.__doc__ = f"""Execute a Unitree G1 mode command.

    Example usage:

        execute_mode_command("RunMode")

    Here are all the command names and what they do.

    {_MODE_COMMANDS_DOC}
    """

    # private helpers

    def _execute_g1_command(
        self,
        command_dict: dict[str, tuple[int, str]],
        api_id: int,
        topic: str,
        command_name: str,
    ) -> str:
        """Look up ``command_name`` and publish it; return a user-facing message.

        Unknown names are not sent; the message then lists up to three close
        matches. Any exception raised while publishing is logged and turned
        into a failure message rather than propagated.
        """
        if command_name not in command_dict:
            suggestions = difflib.get_close_matches(
                command_name, command_dict.keys(), n=3, cutoff=0.6
            )
            return f"There's no '{command_name}' command. Did you mean: {suggestions}"

        id_, _ = command_dict[command_name]

        try:
            self.publish_request(topic, {"api_id": api_id, "parameter": {"data": id_}})
            return f"'{command_name}' command executed successfully."
        except Exception as e:
            logger.error(f"Failed to execute {command_name}: {e}")
            return "Failed to execute the command."
def draw_ui(stdscr: Any, state_text: str = "Not connected") -> None:
    """Redraw the whole control screen, clipped to the current terminal size.

    Args:
        stdscr: The curses window to draw on.
        state_text: Text shown after ``Status:`` on the final line.

    Lines that do not fit vertically are skipped and lines that do not fit
    horizontally are truncated, so a small terminal never makes curses raise.
    """
    stdscr.clear()
    height, width = stdscr.getmaxyx()

    def put(row: int, col: int, text: str, *attrs: int) -> None:
        # Leave the last column unused: writing into the lower-right cell
        # makes curses raise even though the text is printed.
        room = width - col - 1
        if 0 <= row < height and room > 0:
            stdscr.addnstr(row, col, text, room, *attrs)

    # Title
    title = "🤖 G1 Arrow Key Control"
    # Clamp at 0 so a terminal narrower than the title gets a left-aligned,
    # truncated title instead of a negative column.
    put(0, max(0, (width - len(title)) // 2), title, curses.A_BOLD)

    # Controls
    controls = [
        "",
        "Movement Controls:",
        " ↑/W - Move forward",
        " ↓/S - Move backward",
        " ←/A - Rotate left",
        " →/D - Rotate right",
        " Q - Strafe left",
        " E - Strafe right",
        " SPACE - Stop",
        "",
        "Robot Controls:",
        " 1 - Stand up",
        " 2 - Lie down",
        " R - Show robot state",
        "",
        " ESC/Ctrl+C - Quit",
        "",
        f"Status: {state_text}",
    ]

    start_row = 2
    # Bound by the actual screen row (start_row + index), not the list index.
    for row, line in enumerate(controls, start=start_row):
        put(row, 2, line)

    stdscr.refresh()
time.sleep(1) + + draw_ui(stdscr, "✓ Connected - Ready for commands") + + # Movement parameters + linear_speed = 0.3 # m/s for forward/backward/strafe + angular_speed = 0.5 # rad/s for rotation + move_duration = 0.2 # Duration of each movement pulse + + try: + last_cmd_time = 0.0 + cmd_cooldown = 0.15 # Minimum time between commands + + while True: + key = stdscr.getch() + current_time = time.time() + + # Skip if in cooldown period + if current_time - last_cmd_time < cmd_cooldown: + continue + + if key == -1: # No key pressed + continue + + # Handle quit + if key == 27 or key == 3: # ESC or Ctrl+C + break + + # Convert key to character + try: + key_char = chr(key).lower() if key < 256 else None + except ValueError: + key_char = None + + # Movement commands + twist = None + action = None + + # Arrow keys + if key == curses.KEY_UP or key_char == "w": + twist = Twist(linear=Vector3(linear_speed, 0, 0), angular=Vector3(0, 0, 0)) + action = "Moving forward..." + elif key == curses.KEY_DOWN or key_char == "s": + twist = Twist(linear=Vector3(-linear_speed, 0, 0), angular=Vector3(0, 0, 0)) + action = "Moving backward..." + elif key == curses.KEY_LEFT or key_char == "a": + twist = Twist(linear=Vector3(0, 0, 0), angular=Vector3(0, 0, angular_speed)) + action = "Rotating left..." + elif key == curses.KEY_RIGHT or key_char == "d": + twist = Twist(linear=Vector3(0, 0, 0), angular=Vector3(0, 0, -angular_speed)) + action = "Rotating right..." + elif key_char == "q": + twist = Twist(linear=Vector3(0, linear_speed, 0), angular=Vector3(0, 0, 0)) + action = "Strafing left..." + elif key_char == "e": + twist = Twist(linear=Vector3(0, -linear_speed, 0), angular=Vector3(0, 0, 0)) + action = "Strafing right..." 
+ elif key_char == " ": + conn.move( + Twist(linear=Vector3(0, 0, 0), angular=Vector3(0, 0, 0)), duration=move_duration + ) + action = "🛑 Stopped" + last_cmd_time = current_time + + # Robot state commands + elif key_char == "1": + draw_ui(stdscr, "Standing up...") + conn.stand_up() + action = "✓ Standup complete" + last_cmd_time = current_time + elif key_char == "2": + draw_ui(stdscr, "Lying down...") + conn.lie_down() + action = "✓ Liedown complete" + last_cmd_time = current_time + elif key_char == "r": + state = conn.get_state() + action = f"State: {state}" + last_cmd_time = current_time + + # Execute movement + if twist is not None: + conn.move(twist, duration=move_duration) + last_cmd_time = current_time + + # Update UI with action + if action: + draw_ui(stdscr, action) + + except KeyboardInterrupt: + pass + finally: + draw_ui(stdscr, "Stopping and disconnecting...") + conn.disconnect() + draw_ui(stdscr, "✓ Disconnected") + time.sleep(1) + + +if __name__ == "__main__": + print("\n⚠️ WARNING: Ensure area is clear around robot!") + print("Starting in 3 seconds...") + time.sleep(3) + + try: + curses.wrapper(main) + except Exception as e: + print(f"\n✗ Error: {e}") + import traceback + + traceback.print_exc() + + print("\n✓ Done") diff --git a/dimos/robot/unitree/g1/tests/test_arrow_control_cmd_vel.py b/dimos/robot/unitree/g1/tests/test_arrow_control_cmd_vel.py new file mode 100644 index 0000000000..d53ec6fffd --- /dev/null +++ b/dimos/robot/unitree/g1/tests/test_arrow_control_cmd_vel.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python3 +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def main(stdscr: Any) -> None:
    """Run the curses teleop loop, publishing Twist messages on ``CMD_VEL_CHANNEL``.

    Movement keys are rate-limited by a short cooldown. Quit (ESC/Ctrl+C) and
    stop (SPACE) are handled before the cooldown check so they are never
    dropped. On exit a zero twist is published before disconnecting.

    Args:
        stdscr: The curses window supplied by ``curses.wrapper``.
    """
    curses.curs_set(0)
    stdscr.nodelay(1)
    stdscr.timeout(100)

    draw_ui(stdscr, "Initializing...")

    # Set up G1HighLevelDdsSdk with cmd_vel LCM transport so it subscribes
    conn = G1HighLevelDdsSdk(network_interface="eth0")
    conn.cmd_vel.transport = LCMTransport("/cmd_vel", Twist)
    conn.start()
    time.sleep(1)

    # Raw LCM publisher — messages go to the transport above
    lc = lcm.LCM()

    draw_ui(stdscr, "Connected - publishing on " + CMD_VEL_CHANNEL)

    linear_speed = 0.3  # m/s
    angular_speed = 0.5  # rad/s
    cmd_cooldown = 0.15

    # key -> ((forward, lateral, yaw), status text)
    motions = {
        "w": ((linear_speed, 0, 0), "Moving forward..."),
        "s": ((-linear_speed, 0, 0), "Moving backward..."),
        "a": ((0, 0, angular_speed), "Rotating left..."),
        "d": ((0, 0, -angular_speed), "Rotating right..."),
        "q": ((0, linear_speed, 0), "Strafing left..."),
        "e": ((0, -linear_speed, 0), "Strafing right..."),
    }
    # Arrow keys share the bindings of W/A/S/D.
    arrow_aliases = {
        curses.KEY_UP: "w",
        curses.KEY_DOWN: "s",
        curses.KEY_LEFT: "a",
        curses.KEY_RIGHT: "d",
    }

    try:
        # -inf guarantees the first command is never treated as "in cooldown".
        last_cmd_time = float("-inf")

        while True:
            key = stdscr.getch()
            if key == -1:
                continue

            # Quit is honoured immediately, regardless of the cooldown.
            if key == 27 or key == 3:  # ESC or Ctrl+C
                break

            command = arrow_aliases.get(key)
            if command is None and 0 <= key < 256:
                command = chr(key).lower()

            # Monotonic clock: the cooldown must not be affected by wall-clock jumps.
            current_time = time.monotonic()

            # Stop bypasses the cooldown so it cannot be swallowed right after
            # a movement command.
            if command == " ":
                stop = Twist(linear=Vector3(0, 0, 0), angular=Vector3(0, 0, 0))
                publish_twist(lc, stop)
                last_cmd_time = current_time
                draw_ui(stdscr, "Stopped")
                continue

            if current_time - last_cmd_time < cmd_cooldown:
                continue

            if command in motions:
                (forward, lateral, yaw), action = motions[command]
                twist = Twist(linear=Vector3(forward, lateral, 0), angular=Vector3(0, 0, yaw))
                publish_twist(lc, twist)
            elif command == "1":
                draw_ui(stdscr, "Standing up...")
                conn.stand_up()
                action = "Standup complete"
            elif command == "2":
                draw_ui(stdscr, "Lying down...")
                conn.lie_down()
                action = "Liedown complete"
            else:
                # Unmapped key: nothing to send, no cooldown started.
                continue

            last_cmd_time = current_time
            draw_ui(stdscr, action)

    except KeyboardInterrupt:
        pass
    finally:
        draw_ui(stdscr, "Stopping...")
        stop = Twist(linear=Vector3(0, 0, 0), angular=Vector3(0, 0, 0))
        publish_twist(lc, stop)
        time.sleep(0.5)
        # NOTE(review): confirm G1HighLevelDdsSdk exposes disconnect(); only
        # start()/stop()-style lifecycle methods are visible on the sibling module.
        conn.disconnect()
        draw_ui(stdscr, "Done")
        time.sleep(1)
dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.visualization.vis_module import vis_module - with_vis = autoconnect( - _transports_base, - FoxgloveBridge.blueprint(shm_channels=["/color_image#sensor_msgs.Image"]), - ) -elif global_config.viewer.startswith("rerun"): - from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode +_vis = vis_module( + viewer_backend=global_config.viewer, + rerun_config=rerun_config, + foxglove_config={"shm_channels": ["/color_image#sensor_msgs.Image"]}, +) - with_vis = autoconnect( - _transports_base, - RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config), - ) -else: - with_vis = _transports_base +_with_vis = autoconnect(_transports_base, _vis) unitree_go2_basic = ( autoconnect( - with_vis, + _with_vis, GO2Connection.blueprint(), - WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_go2") .configurators(ClockSyncConfigurator()) diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py index 1c55f3e93c..f8ade355e8 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py @@ -22,15 +22,13 @@ from dimos.core.blueprints import autoconnect from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator -from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import with_vis +from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import _with_vis as with_vis from dimos.robot.unitree.go2.fleet_connection import Go2FleetConnection -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule unitree_go2_fleet = ( autoconnect( with_vis, Go2FleetConnection.blueprint(), - WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_go2") .configurators(ClockSyncConfigurator()) diff --git 
a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_smartnav.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_smartnav.py new file mode 100644 index 0000000000..dfc69a859a --- /dev/null +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_smartnav.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python3 +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Go2 SmartNav: native C++ navigation with PGO loop closure. + +Uses the SmartNav native modules (terrain analysis, local planner, +path follower) with PGO for loop-closure-corrected odometry. +OdomAdapter bridges GO2Connection's PoseStamped odom to Odometry +for the SmartNav modules. 
+ +Data flow: + GO2Connection.lidar → registered_scan → TerrainAnalysis + LocalPlanner + PGO + GO2Connection.odom → raw_odom → OdomAdapter → odometry → all nav modules + PGO.corrected_odometry → OdomAdapter → odom (corrected PoseStamped) + TerrainAnalysis → terrain_map → TerrainMapExt → LocalPlanner + LocalPlanner → path → PathFollower → nav_cmd_vel → CmdVelMux → cmd_vel + ClickToGoal → way_point → LocalPlanner + Keyboard teleop → tele_cmd_vel → CmdVelMux → cmd_vel → GO2Connection +""" + +from typing import Any + +from dimos.core.blueprints import autoconnect +from dimos.core.global_config import global_config +from dimos.navigation.smartnav.modules.click_to_goal.click_to_goal import ClickToGoal +from dimos.navigation.smartnav.modules.cmd_vel_mux import CmdVelMux +from dimos.navigation.smartnav.modules.local_planner.local_planner import LocalPlanner +from dimos.navigation.smartnav.modules.odom_adapter.odom_adapter import OdomAdapter +from dimos.navigation.smartnav.modules.path_follower.path_follower import PathFollower +from dimos.navigation.smartnav.modules.pgo.pgo import PGO +from dimos.navigation.smartnav.modules.sensor_scan_generation.sensor_scan_generation import ( + SensorScanGeneration, +) +from dimos.navigation.smartnav.modules.terrain_analysis.terrain_analysis import TerrainAnalysis +from dimos.navigation.smartnav.modules.terrain_map_ext.terrain_map_ext import TerrainMapExt +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.robot.unitree.go2.connection import GO2Connection +from dimos.visualization.rerun.websocket_server import RerunWebSocketServer +from dimos.visualization.vis_module import vis_module + + +def _convert_camera_info(camera_info: Any) -> Any: + return camera_info.to_rerun( + image_topic="/world/color_image", + optical_frame="camera_optical", + ) + + +def _convert_global_map(grid: Any) -> Any: + return grid.to_rerun(voxel_size=0.1, mode="boxes") + + +def _convert_navigation_costmap(grid: Any) -> Any: + return grid.to_rerun( + 
colormap="Accent", + z_offset=0.015, + opacity=0.2, + background="#484981", + ) + + +def _static_base_link(rr: Any) -> list[Any]: + return [ + rr.Boxes3D( + half_sizes=[0.35, 0.155, 0.2], + colors=[(0, 255, 127)], + fill_mode="wireframe", + ), + rr.Transform3D(parent_frame="tf#/base_link"), + ] + + +def _go2_rerun_blueprint() -> Any: + import rerun.blueprint as rrb + + return rrb.Blueprint( + rrb.Horizontal( + rrb.Spatial2DView(origin="world/color_image", name="Camera"), + rrb.Spatial3DView(origin="world", name="3D"), + column_shares=[1, 2], + ), + ) + + +_vis = vis_module( + viewer_backend=global_config.viewer, + rerun_config={ + "blueprint": _go2_rerun_blueprint, + "pubsubs": [LCM()], + "visual_override": { + "world/camera_info": _convert_camera_info, + "world/global_map": _convert_global_map, + "world/navigation_costmap": _convert_navigation_costmap, + }, + "static": { + "world/tf/base_link": _static_base_link, + }, + }, +) + +unitree_go2_smartnav = ( + autoconnect( + GO2Connection.blueprint(), + SensorScanGeneration.blueprint(), + OdomAdapter.blueprint(), + PGO.blueprint(), + TerrainAnalysis.blueprint(extra_args=["--obstacleHeightThre", "0.2", "--maxRelZ", "1.5"]), + TerrainMapExt.blueprint(), + LocalPlanner.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "1.0", + "--autonomySpeed", + "1.0", + "--obstacleHeightThre", + "0.2", + "--maxRelZ", + "1.5", + "--minRelZ", + "-0.5", + ] + ), + PathFollower.blueprint( + extra_args=[ + "--autonomyMode", + "true", + "--maxSpeed", + "1.0", + "--autonomySpeed", + "1.0", + "--maxAccel", + "2.0", + "--slowDwnDisThre", + "0.2", + ] + ), + ClickToGoal.blueprint(), + CmdVelMux.blueprint(), + _vis, + ) + .remappings( + [ + # GO2Connection outputs PoseStamped odom, rename to avoid collision + # with OdomAdapter's Odometry output + (GO2Connection, "odom", "raw_odom"), + (GO2Connection, "lidar", "registered_scan"), + # PathFollower cmd_vel → CmdVelMux nav input + (PathFollower, "cmd_vel", "nav_cmd_vel"), + # 
Keyboard teleop → CmdVelMux + (RerunWebSocketServer, "tele_cmd_vel", "tele_cmd_vel"), + # ClickToGoal plans at global scale — needs PGO-corrected odometry + (ClickToGoal, "odometry", "corrected_odometry"), + (TerrainAnalysis, "odometry", "corrected_odometry"), + ] + ) + .global_config(n_workers=8, robot_model="unitree_go2") +) + +__all__ = ["unitree_go2_smartnav"] diff --git a/dimos/robot/unitree/go2/connection.py b/dimos/robot/unitree/go2/connection.py index 5123dc9a31..e245ab13f4 100644 --- a/dimos/robot/unitree/go2/connection.py +++ b/dimos/robot/unitree/go2/connection.py @@ -253,8 +253,6 @@ def onimage(image: Image) -> None: self.connection.balance_stand() self.connection.set_obstacle_avoidance(self.config.g.obstacle_avoidance) - # self.record("go2_bigoffice") - @rpc def stop(self) -> None: self.liedown() diff --git a/dimos/simulation/unity/module.py b/dimos/simulation/unity/module.py index d051154065..837433927f 100644 --- a/dimos/simulation/unity/module.py +++ b/dimos/simulation/unity/module.py @@ -74,6 +74,10 @@ # LFS data asset name for the Unity sim binary _LFS_ASSET = "unity_sim_x86" +# Google Drive folder containing VLA Challenge environment zips +_GDRIVE_FOLDER_ID = "1UD5v6cSfcwIMWmsq9WSk7blJut4kgb-1" +_DEFAULT_SCENE = "office_1" + # Read timeout for the Unity TCP connection (seconds). If Unity stops # sending data for longer than this the bridge treats it as a hung # connection and drops it. @@ -146,6 +150,61 @@ def _validate_platform() -> None: ) +def _download_unity_scene(scene: str, dest_dir: Path) -> Path: + """Download a Unity environment zip from Google Drive and extract it. + + Returns the path to the Model.x86_64 binary. + """ + import zipfile + + try: + import gdown # type: ignore[import-untyped] + except ImportError: + raise RuntimeError( + "Unity sim binary not found and 'gdown' is not installed for auto-download. 
" + "Install it with: pip install gdown\n" + "Or manually download from: " + f"https://drive.google.com/drive/folders/{_GDRIVE_FOLDER_ID}" + ) from None + + dest_dir.mkdir(parents=True, exist_ok=True) + zip_path = dest_dir / f"{scene}.zip" + + if not zip_path.exists(): + print("\n" + "=" * 70, flush=True) + print(f" DOWNLOADING UNITY SIMULATOR — scene: '{scene}'", flush=True) + print(" Source: Google Drive (VLA Challenge environments)", flush=True) + print(f" Destination: {dest_dir}", flush=True) + print(" This is a one-time download.", flush=True) + print("=" * 70 + "\n", flush=True) + gdown.download_folder(id=_GDRIVE_FOLDER_ID, output=str(dest_dir), quiet=False) + for candidate in dest_dir.rglob(f"{scene}.zip"): + zip_path = candidate + break + + if not zip_path.exists(): + raise FileNotFoundError( + f"Failed to download scene '{scene}'. " + f"Check https://drive.google.com/drive/folders/{_GDRIVE_FOLDER_ID}" + ) + + extract_dir = dest_dir / scene + if not extract_dir.exists(): + logger.info(f"Extracting {zip_path}...") + with zipfile.ZipFile(zip_path, "r") as zf: + zf.extractall(dest_dir) + + binary = extract_dir / "environment" / "Model.x86_64" + if not binary.exists(): + raise FileNotFoundError( + f"Extracted scene but Model.x86_64 not found at {binary}. " + f"Expected structure: {scene}/environment/Model.x86_64" + ) + + binary.chmod(binary.stat().st_mode | 0o111) + return binary + + # Config @@ -158,9 +217,19 @@ class UnityBridgeConfig(ModuleConfig): """ # Path to the Unity x86_64 binary. Leave empty to auto-resolve - # from LFS data (unity_sim_x86/environment/Model.x86_64). + # from LFS data or auto-download from Google Drive. unity_binary: str = "" + # Scene name for auto-download (e.g. "office_1", "hotel_room_1"). + # Only used when unity_binary is not found and auto_download is True. + unity_scene: str = _DEFAULT_SCENE + + # Directory to download/cache Unity scenes. 
+ unity_cache_dir: str = "~/.cache/dimos/unity_envs" + + # Auto-download the scene from Google Drive if binary is missing. + auto_download: bool = True + # Max seconds to wait for Unity to connect after launch. unity_connect_timeout: float = 30.0 @@ -356,6 +425,14 @@ def _resolve_binary(self) -> Path | None: except Exception as e: logger.warning(f"Failed to resolve Unity binary from LFS: {e}") + # Auto-download from Google Drive (VLA Challenge scenes) + if cfg.auto_download: + try: + cache = Path(cfg.unity_cache_dir).expanduser() + return _download_unity_scene(cfg.unity_scene, cache) + except Exception as e: + logger.warning(f"Auto-download failed: {e}") + return None def _launch_unity(self) -> None: diff --git a/dimos/teleop/quest/blueprints.py b/dimos/teleop/quest/blueprints.py index d6367310de..1b67de3b75 100644 --- a/dimos/teleop/quest/blueprints.py +++ b/dimos/teleop/quest/blueprints.py @@ -26,12 +26,12 @@ from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.teleop.quest.quest_extensions import ArmTeleopModule from dimos.teleop.quest.quest_types import Buttons -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.vis_module import vis_module # Arm teleop with press-and-hold engage (has rerun viz) teleop_quest_rerun = autoconnect( ArmTeleopModule.blueprint(), - RerunBridgeModule.blueprint(), + vis_module("rerun"), ).transports( { ("left_controller_output", PoseStamped): LCMTransport("/teleop/left_delta", PoseStamped), diff --git a/dimos/test_no_sections.py b/dimos/test_no_sections.py index 902288b2e6..377bb442d3 100644 --- a/dimos/test_no_sections.py +++ b/dimos/test_no_sections.py @@ -75,7 +75,10 @@ def _should_scan(path: str) -> bool: def _is_ignored_dir(dirpath: str) -> bool: parts = dirpath.split(os.sep) - return bool(IGNORED_DIRS.intersection(parts)) + if IGNORED_DIRS.intersection(parts): + return True + # Skip directories with .ignore suffix (e.g. 
logs.ignore/) + return any(p.endswith(".ignore") for p in parts) def _is_whitelisted(rel_path: str, line: str) -> bool: @@ -91,7 +94,7 @@ def find_section_markers() -> list[tuple[str, int, str]]: for dirpath, dirnames, filenames in os.walk(REPO_ROOT): # Prune ignored directories in-place - dirnames[:] = [d for d in dirnames if d not in IGNORED_DIRS] + dirnames[:] = [d for d in dirnames if d not in IGNORED_DIRS and not d.endswith(".ignore")] if _is_ignored_dir(dirpath): continue diff --git a/dimos/utils/change_detect.py b/dimos/utils/change_detect.py new file mode 100644 index 0000000000..357e60382f --- /dev/null +++ b/dimos/utils/change_detect.py @@ -0,0 +1,228 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Change detection utility for file content hashing. + +Tracks whether a set of files (by path, directory, or glob pattern) have +changed since the last check. Useful for skipping expensive rebuilds when +source files haven't been modified. + +Path entries are type-dispatched: + +- ``str`` / ``Path`` / ``LfsPath`` — treated as **literal** file or directory + paths (no glob expansion, even if the path contains ``*``). +- ``Glob`` — expanded with :func:`glob.glob` to match filesystem patterns. 
+""" + +from __future__ import annotations + +from collections.abc import Sequence +import fcntl +import glob as glob_mod +import hashlib +import os +from pathlib import Path +from typing import Union + +import xxhash + +from dimos.utils.data import LfsPath +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class Glob(str): + """A string that should be interpreted as a filesystem glob pattern. + + Wraps a plain ``str`` to signal that :func:`did_change` should expand it + with :func:`glob.glob` rather than treating it as a literal path. + + Example:: + + Glob("src/**/*.c") + """ + + +PathEntry = Union[str, Path, LfsPath, Glob] +"""A single entry in a change-detection path list.""" + + +def _get_cache_dir() -> Path: + """Return the directory used to store change-detection cache files. + + Uses ``/dimos_cache/change_detect/`` when running inside a + venv, otherwise falls back to ``~/.cache/dimos/change_detect/``. + """ + venv = os.environ.get("VIRTUAL_ENV") + if venv: + return Path(venv) / "dimos_cache" / "change_detect" + return Path.home() / ".cache" / "dimos" / "change_detect" + + +def _safe_filename(cache_name: str) -> str: + """Convert an arbitrary cache name into a safe filename.""" + safe_chars = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-") + if all(c in safe_chars for c in cache_name) and len(cache_name) <= 200: + return cache_name + digest = hashlib.sha256(cache_name.encode()).hexdigest()[:16] + return digest + + +def _add_path(files: set[Path], p: Path) -> None: + """Add *p* (file or directory, walked recursively) to *files*.""" + if p.is_file(): + files.add(p.resolve()) + elif p.is_dir(): + for root, _dirs, filenames in os.walk(p): + for fname in filenames: + files.add(Path(root, fname).resolve()) + + +def _resolve_paths(paths: Sequence[PathEntry], cwd: str | Path | None = None) -> list[Path]: + """Resolve a mixed list of path entries into a sorted list of files.""" + files: set[Path] = set() + 
for entry in paths: + if isinstance(entry, Glob): + pattern = str(entry) + if not Path(pattern).is_absolute(): + if cwd is None: + raise ValueError( + f"Relative path {pattern!r} passed to change detection without a cwd. " + "Either provide an absolute path or pass cwd= so relatives can be resolved." + ) + pattern = str(Path(cwd) / pattern) + expanded = glob_mod.glob(pattern, recursive=True) + if not expanded: + logger.warning("Glob pattern matched no files", pattern=pattern) + continue + for match in expanded: + _add_path(files, Path(match)) + else: + path_str = str(entry) + if not Path(path_str).is_absolute(): + if cwd is None: + raise ValueError( + f"Relative path {path_str!r} passed to change detection without a cwd. " + "Either provide an absolute path or pass cwd= so relatives can be resolved." + ) + path_str = str(Path(cwd) / path_str) + p = Path(path_str) + if not p.exists(): + logger.warning("Path does not exist", path=path_str) + continue + _add_path(files, p) + return sorted(files) + + +def _hash_files(files: list[Path]) -> str: + """Compute an aggregate xxhash digest over the sorted file list.""" + h = xxhash.xxh64() + for fpath in files: + try: + # Include the path so additions/deletions/renames are detected + h.update(str(fpath).encode()) + h.update(fpath.read_bytes()) + except (OSError, PermissionError): + logger.warning("Cannot read file for hashing", path=str(fpath)) + return h.hexdigest() + + +def did_change( + cache_name: str, + paths: Sequence[PathEntry], + cwd: str | Path | None = None, +) -> bool: + """Check if any files/dirs matching the given paths have changed since last check. 
+ + Examples:: + + # Absolute paths — no cwd needed + did_change("my_build", ["/src/main.cpp"]) + + # Use Glob for wildcard patterns (str is always literal) + did_change("c_sources", [Glob("/src/**/*.c"), Glob("/include/**/*.h")]) + + # Relative paths — must pass cwd + did_change("my_build", ["src/main.cpp"], cwd="/home/user/project") + + # Mix literal paths and globs + did_change("config_check", ["config.yaml", Glob("templates/*.j2")], cwd="/project") + + # Track a whole directory (walked recursively) + did_change("assets", ["/data/models/"]) + + # Second call with no file changes → False + did_change("my_build", ["/src/main.cpp"]) # True (first call, no cache) + did_change("my_build", ["/src/main.cpp"]) # False (nothing changed) + + # After editing a file → True again + Path("/src/main.cpp").write_text("// changed") + did_change("my_build", ["/src/main.cpp"]) # True + + # Relative path without cwd → ValueError + did_change("bad", ["src/main.cpp"]) # raises ValueError + + Returns ``True`` on the first call (no previous cache), and on subsequent + calls returns ``True`` only if file contents differ from the last check. + The cache is always updated, so two consecutive calls with no changes + return ``True`` then ``False``. 
+ """ + if not paths: + return False + + files = _resolve_paths(paths, cwd=cwd) + + if not files: + logger.warning( + "No source files found for change detection, skipping rebuild check", + cache_name=cache_name, + ) + return False + + current_hash = _hash_files(files) + + cache_dir = _get_cache_dir() + cache_dir.mkdir(parents=True, exist_ok=True) + cache_file = cache_dir / f"{_safe_filename(cache_name)}.hash" + lock_file = cache_dir / f"{_safe_filename(cache_name)}.lock" + + changed = True + with open(lock_file, "w") as lf: + fcntl.flock(lf, fcntl.LOCK_EX) + try: + if cache_file.exists(): + previous_hash = cache_file.read_text().strip() + changed = current_hash != previous_hash + # Always update the cache with the current hash + cache_file.write_text(current_hash) + finally: + fcntl.flock(lf, fcntl.LOCK_UN) + + return changed + + +def clear_cache(cache_name: str) -> bool: + """Remove the cached hash so the next ``did_change`` call returns ``True``. + + Example:: + + clear_cache("my_build") + did_change("my_build", ["/src/main.c"]) # always True after clear + """ + cache_file = _get_cache_dir() / f"{_safe_filename(cache_name)}.hash" + if cache_file.exists(): + cache_file.unlink() + return True + return False diff --git a/dimos/utils/generic.py b/dimos/utils/generic.py index 84168ce057..6aa1859659 100644 --- a/dimos/utils/generic.py +++ b/dimos/utils/generic.py @@ -13,13 +13,52 @@ # limitations under the License. 
from collections.abc import Callable +import functools import hashlib import json import os +from pathlib import Path +import platform import string +import sys from typing import Any, Generic, TypeVar, overload import uuid + +@functools.lru_cache(maxsize=1) +def is_jetson() -> bool: + """Check if running on an NVIDIA Jetson device.""" + if sys.platform != "linux": + return False + # Check kernel release for Tegra (most lightweight) + if "tegra" in platform.release().lower(): + return True + # Check device tree (works in containers with proper mounts) + try: + return "nvidia,tegra" in Path("/proc/device-tree/compatible").read_text() + except (FileNotFoundError, PermissionError): + pass + # Check for L4T release file + return Path("/etc/nv_tegra_release").exists() + + +def get_local_ips() -> list[tuple[str, str]]: + """Return ``(ip, interface_name)`` for every non-loopback IPv4 address. + + Picks up physical, virtual, and VPN interfaces (including Tailscale). + """ + import socket + + import psutil + + results: list[tuple[str, str]] = [] + for iface, addrs in psutil.net_if_addrs().items(): + for addr in addrs: + if addr.family == socket.AF_INET and not addr.address.startswith("127."): + results.append((addr.address, iface)) + return results + + _T = TypeVar("_T") diff --git a/dimos/utils/test_change_detect.py b/dimos/utils/test_change_detect.py new file mode 100644 index 0000000000..42bd8a62e9 --- /dev/null +++ b/dimos/utils/test_change_detect.py @@ -0,0 +1,135 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the change detection utility.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from dimos.utils.change_detect import Glob, clear_cache, did_change + + +@pytest.fixture(autouse=True) +def _use_tmp_cache(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """Redirect the change-detection cache to a temp dir for every test.""" + monkeypatch.setattr( + "dimos.utils.change_detect._get_cache_dir", + lambda: tmp_path / "cache", + ) + + +@pytest.fixture() +def src_dir(tmp_path: Path) -> Path: + """A temp directory with two source files for testing.""" + d = tmp_path / "src" + d.mkdir() + (d / "a.c").write_text("int main() { return 0; }") + (d / "b.c").write_text("void helper() {}") + return d + + +def test_first_call_returns_true(src_dir: Path) -> None: + assert did_change("test_cache", [str(src_dir)]) is True + + +def test_second_call_no_change_returns_false(src_dir: Path) -> None: + did_change("test_cache", [str(src_dir)]) + assert did_change("test_cache", [str(src_dir)]) is False + + +def test_file_modified_returns_true(src_dir: Path) -> None: + did_change("test_cache", [str(src_dir)]) + (src_dir / "a.c").write_text("int main() { return 1; }") + assert did_change("test_cache", [str(src_dir)]) is True + + +def test_file_added_to_dir_returns_true(src_dir: Path) -> None: + did_change("test_cache", [str(src_dir)]) + (src_dir / "c.c").write_text("void new_func() {}") + assert did_change("test_cache", [str(src_dir)]) is True + + +def test_file_deleted_returns_true(src_dir: Path) -> None: + did_change("test_cache", [str(src_dir)]) + (src_dir / "b.c").unlink() + assert did_change("test_cache", [str(src_dir)]) is True + + +def test_glob_pattern(src_dir: Path) -> None: + pattern = Glob(str(src_dir / "*.c")) + assert did_change("glob_cache", [pattern]) is True + assert did_change("glob_cache", [pattern]) is False + 
(src_dir / "a.c").write_text("changed!") + assert did_change("glob_cache", [pattern]) is True + + +def test_str_with_glob_chars_is_literal(tmp_path: Path) -> None: + """A plain str containing '*' must NOT be glob-expanded.""" + weird_name = tmp_path / "file[1].txt" + weird_name.write_text("content") + # str path — treated literally, should find the file + assert did_change("literal_test", [str(weird_name)]) is True + assert did_change("literal_test", [str(weird_name)]) is False + + +def test_separate_cache_names_independent(src_dir: Path) -> None: + paths = [str(src_dir)] + did_change("cache_a", paths) + did_change("cache_b", paths) + # Both caches are now up-to-date + assert did_change("cache_a", paths) is False + assert did_change("cache_b", paths) is False + # Modify a file — both caches should report changed independently + (src_dir / "a.c").write_text("changed") + assert did_change("cache_a", paths) is True + # cache_b hasn't been checked since the change + assert did_change("cache_b", paths) is True + + +def test_clear_cache(src_dir: Path) -> None: + paths = [str(src_dir)] + did_change("clear_test", paths) + assert did_change("clear_test", paths) is False + assert clear_cache("clear_test") is True + assert did_change("clear_test", paths) is True + + +def test_clear_cache_nonexistent() -> None: + assert clear_cache("does_not_exist") is False + + +def test_empty_paths_returns_false() -> None: + assert did_change("empty_test", []) is False + + +def test_nonexistent_path_warns(caplog: pytest.LogCaptureFixture) -> None: + """A non-existent absolute path logs a warning and doesn't crash.""" + result = did_change("missing_test", ["/nonexistent/path/to/file.c"]) + # No resolvable files → returns False (skip rebuild) + assert result is False + + +def test_relative_path_without_cwd_raises() -> None: + """Relative paths without cwd= should raise ValueError.""" + with pytest.raises(ValueError, match="Relative path.*without a cwd"): + did_change("rel_test", 
["some/relative/path.c"]) + + +def test_relative_path_with_cwd(src_dir: Path) -> None: + """Relative paths should resolve against the provided cwd.""" + assert did_change("cwd_test", ["src/a.c"], cwd=src_dir.parent) is True + assert did_change("cwd_test", ["src/a.c"], cwd=src_dir.parent) is False diff --git a/dimos/visualization/rerun/bridge.py b/dimos/visualization/rerun/bridge.py index de89c5d347..843ae421f4 100644 --- a/dimos/visualization/rerun/bridge.py +++ b/dimos/visualization/rerun/bridge.py @@ -56,6 +56,7 @@ RERUN_GRPC_PORT = 9876 RERUN_WEB_PORT = 9090 + # TODO OUT visual annotations # # In the future it would be nice if modules can annotate their individual OUTs with (general or rerun specific) @@ -133,7 +134,9 @@ def to_rerun(self) -> RerunData: ... def _hex_to_rgba(hex_color: str) -> int: """Convert '#RRGGBB' to a 0xRRGGBBAA int (fully opaque).""" h = hex_color.lstrip("#") - return (int(h, 16) << 8) | 0xFF + if len(h) == 6: + return int(h + "ff", 16) + return int(h[:8], 16) def _with_graph_tab(bp: Blueprint) -> Blueprint: @@ -157,7 +160,7 @@ def _default_blueprint() -> Blueprint: import rerun as rr import rerun.blueprint as rrb - return rrb.Blueprint( + return rrb.Blueprint( # type: ignore[no-any-return] rrb.Spatial3DView( origin="world", background=rrb.Background(kind="SolidColor", color=[0, 0, 0]), @@ -223,10 +226,12 @@ class RerunBridgeModule(Module[Config]): """ default_config = Config + _last_log: dict[str, float] = {} - GV_SCALE = 100.0 # graphviz inches to rerun screen units - MODULE_RADIUS = 30.0 - CHANNEL_RADIUS = 20.0 + # Graphviz layout scale and node radii for blueprint graph + GV_SCALE = 100.0 + MODULE_RADIUS = 20.0 + CHANNEL_RADIUS = 12.0 @lru_cache(maxsize=256) def _visual_override_for_entity_path( @@ -310,13 +315,14 @@ def start(self) -> None: super().start() - self._last_log: dict[str, float] = {} + self._last_log: dict[str, float] = {} # reset on each start logger.info("Rerun bridge starting", viewer_mode=self.config.viewer_mode) # 
Initialize and spawn Rerun viewer rr.init("dimos") if self.config.viewer_mode == "native": + spawned = False try: import rerun_bindings @@ -325,6 +331,7 @@ def start(self) -> None: executable_name="dimos-viewer", memory_limit=self.config.memory_limit, ) + spawned = True except ImportError: pass # dimos-viewer not installed except Exception: @@ -332,12 +339,31 @@ def start(self) -> None: "dimos-viewer found but failed to spawn, falling back to stock rerun", exc_info=True, ) - rr.spawn(connect=True, memory_limit=self.config.memory_limit) + if not spawned: + try: + rr.spawn(connect=True, memory_limit=self.config.memory_limit) + except (RuntimeError, FileNotFoundError): + logger.warning( + "Rerun native viewer not available (headless?). " + "Bridge will continue without a viewer — data is still " + "accessible via rerun-connect or rerun-web.", + exc_info=True, + ) elif self.config.viewer_mode == "web": server_uri = rr.serve_grpc() rr.serve_web_viewer(connect_to=server_uri, open_browser=False) elif self.config.viewer_mode == "connect": - rr.connect_grpc(self.config.connect_url) + # Serve gRPC so external viewers (dimos-viewer) can connect to us. + # Extract the port from the connect_url to match what viewers expect. 
+ from urllib.parse import urlparse + + parsed = urlparse(self.config.connect_url.replace("rerun+", "", 1)) + grpc_port = parsed.port or RERUN_GRPC_PORT + rr.serve_grpc( + grpc_port=grpc_port, + server_memory_limit=self.config.memory_limit, + ) + logger.info(f"Rerun gRPC server ready at {self.config.connect_url}") # "none" - just init, no viewer (connect externally) if self.config.blueprint: @@ -437,6 +463,7 @@ def log_blueprint_graph(self, dot_code: str, module_names: list[str]) -> None: @rpc def stop(self) -> None: + self._visual_override_for_entity_path.cache_clear() super().stop() diff --git a/dimos/visualization/rerun/test_viewer_ws_e2e.py b/dimos/visualization/rerun/test_viewer_ws_e2e.py new file mode 100644 index 0000000000..ea8351f2f6 --- /dev/null +++ b/dimos/visualization/rerun/test_viewer_ws_e2e.py @@ -0,0 +1,328 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end test: dimos-viewer (headless) → WebSocket → RerunWebSocketServer. + +dimos-viewer is started in ``--connect`` mode so it initialises its WebSocket +client. The viewer needs a gRPC proxy to connect to; we give it a non-existent +one so the viewer starts up anyway but produces no visualisation. The important +part is that the WebSocket client inside the viewer tries to connect to +``ws://127.0.0.1:/ws``. + +Because the viewer is a native GUI application it cannot run headlessly in CI +without a display. 
class TestViewerProtocolE2E:
    """Verify the full Python-server side of the viewer ↔ DimOS protocol.

    These tests use the ``RerunWebSocketServer`` as the server and a dummy
    WebSocket client (playing the role of dimos-viewer) to inject messages.
    They confirm every message type is correctly routed and that only click
    messages produce stream publishes.
    """

    @staticmethod
    def _run_session(
        connections: list[list[dict[str, Any]]],
        on_point: Any,
        done: threading.Event | None = None,
        *,
        delay: float = 0.05,
        timeout: float = 3.0,
    ) -> None:
        """Start a server, replay *connections*, then always stop the server.

        Each inner list is sent over its own WebSocket connection, in order.
        ``server.stop()`` runs in a ``finally`` so a readiness timeout or a
        failed send cannot leave the shared test port bound for later tests.
        """
        server = _make_server()
        server.start()
        try:
            _wait_for_server(_E2E_PORT)
            server.clicked_point.subscribe(on_point)
            for messages in connections:
                _send_messages(_E2E_PORT, messages, delay=delay)
            if done is not None:
                done.wait(timeout=timeout)
        finally:
            server.stop()

    def test_viewer_click_reaches_stream(self) -> None:
        """A viewer click message received over WebSocket publishes PointStamped."""
        received: list[Any] = []
        done = threading.Event()

        def _on_pt(pt: Any) -> None:
            received.append(pt)
            done.set()

        self._run_session(
            [
                [
                    {
                        "type": "click",
                        "x": 10.0,
                        "y": 20.0,
                        "z": 0.5,
                        "entity_path": "/world/robot",
                        "timestamp_ms": 42000,
                    }
                ]
            ],
            _on_pt,
            done,
        )

        assert len(received) == 1
        pt = received[0]
        assert abs(pt.x - 10.0) < 1e-9
        assert abs(pt.y - 20.0) < 1e-9
        assert abs(pt.z - 0.5) < 1e-9
        assert pt.frame_id == "/world/robot"
        assert abs(pt.ts - 42.0) < 1e-6

    def test_viewer_keyboard_twist_no_publish(self) -> None:
        """Twist messages from keyboard control do not publish clicked_point."""
        received: list[Any] = []

        self._run_session(
            [
                [
                    {
                        "type": "twist",
                        "linear_x": 0.5,
                        "linear_y": 0.0,
                        "linear_z": 0.0,
                        "angular_x": 0.0,
                        "angular_y": 0.0,
                        "angular_z": 0.8,
                    }
                ]
            ],
            received.append,
        )

        assert received == []

    def test_viewer_stop_no_publish(self) -> None:
        """Stop messages do not publish clicked_point."""
        received: list[Any] = []

        self._run_session([[{"type": "stop"}]], received.append)

        assert received == []

    def test_full_viewer_session_sequence(self) -> None:
        """Realistic session: connect, heartbeats, click, WASD, stop → one point."""
        received: list[Any] = []
        done = threading.Event()

        def _on_pt(pt: Any) -> None:
            received.append(pt)
            done.set()

        self._run_session(
            [
                [
                    # Initial heartbeats (viewer connects and starts 1 Hz heartbeat)
                    {"type": "heartbeat", "timestamp_ms": 1000},
                    {"type": "heartbeat", "timestamp_ms": 2000},
                    # User clicks a point in the 3D viewport
                    {
                        "type": "click",
                        "x": 3.14,
                        "y": 2.71,
                        "z": 1.41,
                        "entity_path": "/world",
                        "timestamp_ms": 3000,
                    },
                    # User presses W (forward)
                    {
                        "type": "twist",
                        "linear_x": 0.5,
                        "linear_y": 0.0,
                        "linear_z": 0.0,
                        "angular_x": 0.0,
                        "angular_y": 0.0,
                        "angular_z": 0.0,
                    },
                    # User releases W
                    {"type": "stop"},
                    # Another heartbeat
                    {"type": "heartbeat", "timestamp_ms": 4000},
                ]
            ],
            _on_pt,
            done,
            delay=0.2,
        )

        assert len(received) == 1, f"Expected exactly 1 click, got {len(received)}"
        pt = received[0]
        assert abs(pt.x - 3.14) < 1e-9
        assert abs(pt.y - 2.71) < 1e-9
        assert abs(pt.z - 1.41) < 1e-9

    def test_reconnect_after_disconnect(self) -> None:
        """Server keeps accepting new connections after a client disconnects."""
        received: list[Any] = []
        all_done = threading.Event()

        def _on_pt(pt: Any) -> None:
            received.append(pt)
            if len(received) >= 2:
                all_done.set()

        def _click(x: float) -> dict[str, Any]:
            return {
                "type": "click",
                "x": x,
                "y": 0.0,
                "z": 0.0,
                "entity_path": "",
                "timestamp_ms": 0,
            }

        # Two separate connections: the second simulates a viewer reconnect.
        self._run_session([[_click(1.0)], [_click(2.0)]], _on_pt, all_done, timeout=5.0)

        xs = sorted(pt.x for pt in received)
        assert xs == [1.0, 2.0], f"Unexpected xs: {xs}"
WebSocket + client attempts to connect to our Python server.""" + + @pytest.mark.skipif( + shutil.which("dimos-viewer") is None + or "--connect" + not in subprocess.run(["dimos-viewer", "--help"], capture_output=True, text=True).stdout, + reason="dimos-viewer binary not installed or does not support --connect", + ) + def test_viewer_ws_client_connects(self) -> None: + """dimos-viewer --connect starts and its WS client connects to our server.""" + server = _make_server() + server.start() + _wait_for_server(_E2E_PORT) + + received: list[Any] = [] + + def _on_pt(pt: Any) -> None: + received.append(pt) + + server.clicked_point.subscribe(_on_pt) + + # Start dimos-viewer in --connect mode, pointing it at a non-existent gRPC + # proxy (it will fail to stream data, but that's fine) and at our WS server. + # Use DISPLAY="" to prevent it from opening a window (it will exit quickly + # without a display, but the WebSocket connection happens before the GUI loop). + proc = subprocess.Popen( + [ + "dimos-viewer", + "--connect", + f"--ws-url=ws://127.0.0.1:{_E2E_PORT}/ws", + ], + env={ + **os.environ, + "DISPLAY": "", + }, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + # Give the viewer up to 5 s to connect its WebSocket client to our server. + # We detect the connection by waiting for the server to accept a client. + deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + # Check if any connection was established by sending a message and + # verifying the viewer is still running. + if proc.poll() is not None: + # Viewer exited (expected without a display) — check if it connected first. + break + time.sleep(0.1) + + proc.terminate() + try: + proc.wait(timeout=3) + except subprocess.TimeoutExpired: + proc.kill() + + stdout = proc.stdout.read().decode(errors="replace") if proc.stdout else "" + stderr = proc.stderr.read().decode(errors="replace") if proc.stderr else "" + server.stop() + + # The viewer should log that it is connecting to our WS URL. 
+ # Check both stdout and stderr since log output destination varies. + combined = stdout + stderr + assert f"ws://127.0.0.1:{_E2E_PORT}" in combined, ( + f"Viewer did not attempt WS connection.\nstdout:\n{stdout}\nstderr:\n{stderr}" + ) diff --git a/dimos/visualization/rerun/test_websocket_server.py b/dimos/visualization/rerun/test_websocket_server.py new file mode 100644 index 0000000000..cec85fbb11 --- /dev/null +++ b/dimos/visualization/rerun/test_websocket_server.py @@ -0,0 +1,469 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for RerunWebSocketServer. + +Uses ``MockViewerPublisher`` to simulate dimos-viewer sending events, matching +the exact JSON protocol used by the Rust ``WsPublisher`` in the viewer. +""" + +import asyncio +import json +import threading +import time +from typing import Any + +from dimos.visualization.rerun.websocket_server import RerunWebSocketServer + +_TEST_PORT = 13031 + + +class MockViewerPublisher: + """Python mirror of the Rust WsPublisher in dimos-viewer. + + Connects to a running ``RerunWebSocketServer`` and exposes the same + ``send_click`` / ``send_twist`` / ``send_stop`` / ``send_heartbeat`` + API that the real viewer uses. Useful for unit tests that need to + exercise the server without a real viewer binary. 
+ + Usage:: + + with MockViewerPublisher("ws://127.0.0.1:13031/ws") as pub: + pub.send_click(1.0, 2.0, 0.0, "/world", timestamp_ms=1000) + pub.send_twist(0.5, 0.0, 0.0, 0.0, 0.0, 0.8) + pub.send_stop() + """ + + def __init__(self, url: str) -> None: + self._url = url + self._ws: Any = None + self._loop: asyncio.AbstractEventLoop | None = None + + def __enter__(self) -> "MockViewerPublisher": + self._loop = asyncio.new_event_loop() + self._ws = self._loop.run_until_complete(self._connect()) + return self + + def __exit__(self, *_: Any) -> None: + if self._ws is not None and self._loop is not None: + self._loop.run_until_complete(self._ws.close()) + if self._loop is not None: + self._loop.close() + + async def _connect(self) -> Any: + import websockets.asyncio.client as ws_client + + return await ws_client.connect(self._url) + + def send_click( + self, + x: float, + y: float, + z: float, + entity_path: str = "", + timestamp_ms: int = 0, + ) -> None: + """Send a click event — matches viewer SelectionChange handler output.""" + self._send( + { + "type": "click", + "x": x, + "y": y, + "z": z, + "entity_path": entity_path, + "timestamp_ms": timestamp_ms, + } + ) + + def send_twist( + self, + linear_x: float, + linear_y: float, + linear_z: float, + angular_x: float, + angular_y: float, + angular_z: float, + ) -> None: + """Send a twist (WASD keyboard) event.""" + self._send( + { + "type": "twist", + "linear_x": linear_x, + "linear_y": linear_y, + "linear_z": linear_z, + "angular_x": angular_x, + "angular_y": angular_y, + "angular_z": angular_z, + } + ) + + def send_stop(self) -> None: + """Send a stop event (Space bar or key release).""" + self._send({"type": "stop"}) + + def send_heartbeat(self, timestamp_ms: int = 0) -> None: + """Send a heartbeat (1 Hz keepalive from viewer).""" + self._send({"type": "heartbeat", "timestamp_ms": timestamp_ms}) + + def flush(self, delay: float = 0.1) -> None: + """Wait briefly so the server processes queued messages.""" + 
class TestClickMessages:
    """Click events sent by the mock viewer are published on ``clicked_point``."""

    @staticmethod
    def _publish_clicks(clicks: list[dict[str, Any]], *, timeout: float = 2.0) -> list[Any]:
        """Send *clicks* over one connection and return what was published.

        Each dict holds keyword arguments for ``MockViewerPublisher.send_click``.
        The module is stopped in a ``finally`` so a readiness timeout or send
        failure cannot leave the shared test port bound for later tests.
        """
        mod = _make_module()
        mod.start()
        received: list[Any] = []
        all_arrived = threading.Event()

        def _cb(pt: Any) -> None:
            received.append(pt)
            if len(received) >= len(clicks):
                all_arrived.set()

        try:
            _wait_for_server(_TEST_PORT)
            mod.clicked_point.subscribe(_cb)

            with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub:
                for click in clicks:
                    pub.send_click(**click)
                pub.flush()

            all_arrived.wait(timeout=timeout)
        finally:
            mod.stop()
        return received

    def test_click_publishes_point_stamped(self) -> None:
        """A single click publishes one PointStamped with correct coords."""
        received = self._publish_clicks(
            [{"x": 1.5, "y": 2.5, "z": 0.0, "entity_path": "/world", "timestamp_ms": 1000}]
        )

        assert len(received) == 1
        pt = received[0]
        assert abs(pt.x - 1.5) < 1e-9
        assert abs(pt.y - 2.5) < 1e-9
        assert abs(pt.z - 0.0) < 1e-9

    def test_click_sets_frame_id_from_entity_path(self) -> None:
        """entity_path is stored as frame_id on the published PointStamped."""
        received = self._publish_clicks(
            [{"x": 0.0, "y": 0.0, "z": 0.0, "entity_path": "/robot/base", "timestamp_ms": 2000}]
        )

        assert received and received[0].frame_id == "/robot/base"

    def test_click_timestamp_converted_from_ms(self) -> None:
        """timestamp_ms is converted to seconds on PointStamped.ts."""
        received = self._publish_clicks(
            [{"x": 0.0, "y": 0.0, "z": 0.0, "entity_path": "", "timestamp_ms": 5000}]
        )

        assert received and abs(received[0].ts - 5.0) < 1e-6

    def test_multiple_clicks_all_published(self) -> None:
        """A burst of clicks all arrive on the stream."""
        received = self._publish_clicks(
            [
                {"x": 1.0, "y": 0.0, "z": 0.0},
                {"x": 2.0, "y": 0.0, "z": 0.0},
                {"x": 3.0, "y": 0.0, "z": 0.0},
            ],
            timeout=3.0,
        )

        assert sorted(pt.x for pt in received) == [1.0, 2.0, 3.0]
TestNonClickMessages: + def test_heartbeat_does_not_publish(self) -> None: + """Heartbeat messages must not trigger a clicked_point publish.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + clicks: list[Any] = [] + twists: list[Any] = [] + twist_done = threading.Event() + mod.clicked_point.subscribe(clicks.append) + mod.tele_cmd_vel.subscribe(_collect(twists, twist_done)) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_heartbeat(9999) + # Send a canary twist so we know the server processed everything + pub.send_stop() + pub.flush() + + twist_done.wait(timeout=2.0) + mod.stop() + assert clicks == [] + + def test_twist_does_not_publish_clicked_point(self) -> None: + """Twist messages must not trigger a clicked_point publish.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + clicks: list[Any] = [] + twists: list[Any] = [] + twist_done = threading.Event() + mod.clicked_point.subscribe(clicks.append) + mod.tele_cmd_vel.subscribe(_collect(twists, twist_done)) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_twist(0.5, 0.0, 0.0, 0.0, 0.0, 0.8) + pub.flush() + + twist_done.wait(timeout=2.0) + mod.stop() + assert clicks == [] + + def test_stop_does_not_publish_clicked_point(self) -> None: + """Stop messages must not trigger a clicked_point publish.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + clicks: list[Any] = [] + twists: list[Any] = [] + twist_done = threading.Event() + mod.clicked_point.subscribe(clicks.append) + mod.tele_cmd_vel.subscribe(_collect(twists, twist_done)) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_stop() + pub.flush() + + twist_done.wait(timeout=2.0) + mod.stop() + assert clicks == [] + + def test_twist_publishes_on_tele_cmd_vel(self) -> None: + """Twist messages publish a Twist on the tele_cmd_vel stream.""" + mod = _make_module() + mod.start() + 
_wait_for_server(_TEST_PORT) + + received: list[Any] = [] + done = threading.Event() + mod.tele_cmd_vel.subscribe(_collect(received, done)) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_twist(0.5, 0.0, 0.0, 0.0, 0.0, 0.8) + pub.flush() + + done.wait(timeout=2.0) + mod.stop() + + assert len(received) == 1 + tw = received[0] + assert abs(tw.linear.x - 0.5) < 1e-9 + assert abs(tw.angular.z - 0.8) < 1e-9 + + def test_stop_publishes_zero_twist_on_tele_cmd_vel(self) -> None: + """Stop messages publish a zero Twist on the tele_cmd_vel stream.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + received: list[Any] = [] + done = threading.Event() + mod.tele_cmd_vel.subscribe(_collect(received, done)) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_stop() + pub.flush() + + done.wait(timeout=2.0) + mod.stop() + + assert len(received) == 1 + tw = received[0] + assert tw.is_zero() + + def test_twist_publishes_stop_explore_cmd_on_first_twist(self) -> None: + """First twist publishes Bool(data=True) on stop_explore_cmd; stop resets.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + explore_cmds: list[Any] = [] + twists: list[Any] = [] + first_done = threading.Event() + mod.stop_explore_cmd.subscribe(_collect(explore_cmds, first_done)) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_twist(0.5, 0.0, 0.0, 0.0, 0.0, 0.0) + pub.flush() + + first_done.wait(timeout=2.0) + assert len(explore_cmds) == 1 + assert explore_cmds[0].data is True + + # Second twist within same connection should NOT publish another stop_explore_cmd + twist_done = threading.Event() + mod.tele_cmd_vel.subscribe(_collect(twists, twist_done)) + + pub.send_twist(0.3, 0.0, 0.0, 0.0, 0.0, 0.0) + pub.flush() + + twist_done.wait(timeout=2.0) + assert len(explore_cmds) == 1 # still just the first one + + # After stop + new twist within same connection, stop_explore_cmd 
should fire again + second_done = threading.Event() + + def _on_second(msg: Any) -> None: + explore_cmds.append(msg) + if len(explore_cmds) >= 2: + second_done.set() + + mod.stop_explore_cmd.subscribe(_on_second) + + pub.send_stop() + pub.send_twist(0.1, 0.0, 0.0, 0.0, 0.0, 0.0) + pub.flush() + + second_done.wait(timeout=2.0) + + mod.stop() + assert len(explore_cmds) >= 2 + + def test_invalid_json_does_not_crash(self) -> None: + """Malformed JSON is silently dropped; server stays alive.""" + import websockets.asyncio.client as ws_client + + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + async def _send_bad() -> None: + async with ws_client.connect(f"ws://127.0.0.1:{_TEST_PORT}/ws") as ws: + await ws.send("this is not json {{") + await asyncio.sleep(0.1) + await ws.send(json.dumps({"type": "heartbeat", "timestamp_ms": 0})) + await asyncio.sleep(0.1) + + asyncio.run(_send_bad()) + mod.stop() + + def test_mixed_message_sequence(self) -> None: + """Realistic sequence: heartbeat → click → twist → stop publishes one point.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + # Subscribe before sending so we don't race against the click dispatch. 
+ received: list[Any] = [] + done = threading.Event() + + def _cb(pt: Any) -> None: + received.append(pt) + done.set() + + mod.clicked_point.subscribe(_cb) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_heartbeat(1000) + pub.send_click(7.0, 8.0, 9.0, "/map", timestamp_ms=1100) + pub.send_twist(0.3, 0.0, 0.0, 0.0, 0.0, 0.2) + pub.send_stop() + pub.flush() + + done.wait(timeout=2.0) + mod.stop() + + assert len(received) == 1 + assert abs(received[0].x - 7.0) < 1e-9 + assert abs(received[0].y - 8.0) < 1e-9 + assert abs(received[0].z - 9.0) < 1e-9 diff --git a/dimos/visualization/rerun/websocket_server.py b/dimos/visualization/rerun/websocket_server.py new file mode 100644 index 0000000000..9275018e32 --- /dev/null +++ b/dimos/visualization/rerun/websocket_server.py @@ -0,0 +1,231 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""WebSocket server module that receives events from dimos-viewer. + +When dimos-viewer is started with ``--connect``, LCM multicast is unavailable +across machines. The viewer falls back to sending click, twist, and stop events +as JSON over a WebSocket connection. This module acts as the server-side +counterpart: it listens for those connections and translates incoming messages +into DimOS stream publishes. 
class RerunWebSocketServer(Module[Config]):
    """Receives dimos-viewer WebSocket events and publishes them as DimOS streams.

    The viewer connects to this module (not the other way around) when running
    in ``--connect`` mode. Each click event is converted to a ``PointStamped``
    and published on the ``clicked_point`` stream so downstream modules (e.g.
    ``ReplanningAStarPlanner``) can consume it without modification.

    The WebSocket server runs on its own asyncio event loop in a background
    thread; ``start()`` launches it and ``stop()`` signals it and waits for
    the thread to finish.

    Outputs:
        clicked_point: 3-D world-space point from the most recent viewer click.
        tele_cmd_vel: Twist velocity commands from keyboard teleop, including stop events.
        stop_explore_cmd: ``Bool(data=True)``, published when a twist arrives
            while no client is currently in teleop.
    """

    default_config = Config

    clicked_point: Out[PointStamped]
    tele_cmd_vel: Out[Twist]
    stop_explore_cmd: Out[Bool]

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._teleop_clients: set[int] = set()  # ids of clients currently in teleop
        self._ws_loop: asyncio.AbstractEventLoop | None = None
        self._server_thread: threading.Thread | None = None
        self._stop_event: asyncio.Event | None = None
        # Set once the server is bound OR the server thread has given up.
        self._server_ready = threading.Event()
        # True only while the listening socket is actually bound.
        self._serving = False

    @rpc
    def start(self) -> None:
        """Launch the server thread and wait up to ``start_timeout`` for it to bind."""
        super().start()
        # Reset state left over from a previous start/stop cycle.
        self._server_ready.clear()
        self._serving = False
        self._server_thread = threading.Thread(
            target=self._run_server, daemon=True, name="rerun-ws-server"
        )
        self._server_thread.start()
        if not self._server_ready.wait(timeout=self.config.start_timeout):
            logger.error(
                f"RerunWebSocketServer did not become ready within "
                f"{self.config.start_timeout}s on {self.config.host}:{self.config.port}"
            )
            return
        if not self._serving:
            # _run_server() already logged the traceback; don't advertise
            # URLs that nothing is listening on.
            logger.error(
                f"RerunWebSocketServer is not listening on "
                f"{self.config.host}:{self.config.port}"
            )
            return
        self._log_connect_hints()

    @rpc
    def stop(self) -> None:
        """Signal the server to shut down and wait for its thread to exit.

        Safe to call more than once, and returns immediately (apart from
        ``super().stop()``) if ``start()`` was never called.
        """
        thread = self._server_thread
        if thread is not None:
            # Wait briefly for the server thread to initialise _stop_event so
            # we don't silently skip the shutdown signal (race with _serve()).
            self._server_ready.wait(timeout=self.config.start_timeout)
            loop, stop_event = self._ws_loop, self._stop_event
            if loop is not None and stop_event is not None and not loop.is_closed():
                try:
                    loop.call_soon_threadsafe(stop_event.set)
                except RuntimeError:
                    # The loop was closed between the check and the call;
                    # the server is already gone.
                    pass
            # Join so the port is released by the time stop() returns.
            thread.join(timeout=self.config.start_timeout)
            if thread.is_alive():
                logger.warning("RerunWebSocketServer thread did not exit in time")
            else:
                self._server_thread = None
        super().stop()

    def _log_connect_hints(self) -> None:
        """Log the WebSocket URL(s) that viewers should connect to."""
        import socket

        from dimos.utils.generic import get_local_ips

        local_ips = get_local_ips()
        hostname = socket.gethostname()
        ws_url = f"ws://127.0.0.1:{self.config.port}/ws"

        lines = [
            "",
            "=" * 60,
            f"RerunWebSocketServer listening on {ws_url}",
            "",
        ]
        if local_ips:
            lines.append("From another machine on the network:")
            for ip, iface in local_ips:
                lines.append(f"  ws://{ip}:{self.config.port}/ws  # {iface}")
            lines.append("")
        lines.append(f"  hostname: {hostname}")
        lines.append("=" * 60)
        lines.append("")

        logger.info("\n".join(lines))

    def _run_server(self) -> None:
        """Entry point for the background server thread."""
        loop = asyncio.new_event_loop()
        self._ws_loop = loop
        try:
            loop.run_until_complete(self._serve())
        except Exception:
            logger.error("RerunWebSocketServer failed to start", exc_info=True)
        finally:
            self._serving = False
            self._server_ready.set()  # unblock start()/stop() even on failure
            loop.close()

    async def _serve(self) -> None:
        """Bind the server and keep it open until the stop event is set."""
        import websockets.asyncio.server as ws_server

        self._stop_event = asyncio.Event()

        async with ws_server.serve(
            self._handle_client,
            host=self.config.host,
            port=self.config.port,
            # Ping every 30 s, allow 30 s for pong — generous enough to
            # survive brief network hiccups while still detecting dead clients.
            ping_interval=30,
            ping_timeout=30,
        ):
            # Mark as serving *before* releasing start(), which reads the flag.
            self._serving = True
            self._server_ready.set()
            await self._stop_event.wait()

    async def _handle_client(self, websocket: Any) -> None:
        """Serve one viewer connection until it closes."""
        if hasattr(websocket, "request") and websocket.request.path != "/ws":
            await websocket.close(1008, "Not Found")
            return
        addr = websocket.remote_address
        client_id = id(websocket)
        logger.info(f"RerunWebSocketServer: viewer connected from {addr}")
        try:
            async for raw in websocket:
                self._dispatch(raw, client_id)
        except websockets.ConnectionClosed as exc:
            logger.debug(f"RerunWebSocketServer: client {addr} disconnected ({exc})")
        finally:
            # NOTE(review): a client that drops mid-teleop is forgotten here
            # but no zero Twist is published — confirm a downstream timeout
            # stops the robot in that case.
            self._teleop_clients.discard(client_id)

    @staticmethod
    def _number(msg: dict[str, Any], key: str) -> float:
        """Return ``msg[key]`` as a float (0 when absent).

        Raises:
            TypeError, ValueError: If the value is not convertible to float.
        """
        return float(msg.get(key, 0))

    def _dispatch(self, raw: str | bytes, client_id: int) -> None:
        """Decode one viewer message and publish it on the matching stream.

        Malformed input (non-JSON, non-object, non-numeric fields) is logged
        and dropped so a single bad frame cannot tear down the connection.
        """
        try:
            msg = json.loads(raw)
        except ValueError:
            # Covers json.JSONDecodeError and, for bytes payloads, UnicodeDecodeError.
            logger.warning(f"RerunWebSocketServer: ignoring non-JSON message: {raw!r}")
            return

        if not isinstance(msg, dict):
            logger.warning(f"RerunWebSocketServer: expected JSON object, got {type(msg).__name__}")
            return

        msg_type = msg.get("type")

        if msg_type == "click":
            try:
                pt = PointStamped(
                    x=self._number(msg, "x"),
                    y=self._number(msg, "y"),
                    z=self._number(msg, "z"),
                    ts=self._number(msg, "timestamp_ms") / 1000.0,
                    frame_id=str(msg.get("entity_path", "")),
                )
            except (TypeError, ValueError) as exc:
                logger.warning(f"RerunWebSocketServer: dropping malformed click message: {exc}")
                return
            logger.debug(f"RerunWebSocketServer: click → {pt}")
            self.clicked_point.publish(pt)

        elif msg_type == "twist":
            try:
                twist = Twist(
                    linear=Vector3(
                        self._number(msg, "linear_x"),
                        self._number(msg, "linear_y"),
                        self._number(msg, "linear_z"),
                    ),
                    angular=Vector3(
                        self._number(msg, "angular_x"),
                        self._number(msg, "angular_y"),
                        self._number(msg, "angular_z"),
                    ),
                )
            except (TypeError, ValueError) as exc:
                logger.warning(f"RerunWebSocketServer: dropping malformed twist message: {exc}")
                return
            logger.debug(f"RerunWebSocketServer: twist → {twist}")
            # Only the transition "nobody in teleop" → "somebody in teleop"
            # publishes stop_explore_cmd.
            if not self._teleop_clients:
                self.stop_explore_cmd.publish(Bool(data=True))
            self._teleop_clients.add(client_id)
            self.tele_cmd_vel.publish(twist)

        elif msg_type == "stop":
            logger.debug("RerunWebSocketServer: stop")
            self._teleop_clients.discard(client_id)
            self.tele_cmd_vel.publish(Twist.zero())

        elif msg_type == "heartbeat":
            logger.debug(f"RerunWebSocketServer: heartbeat ts={msg.get('timestamp_ms')}")

        else:
            logger.warning(f"RerunWebSocketServer: unknown message type {msg_type!r}")
def vis_module(
    viewer_backend: ViewerBackend,
    rerun_config: dict[str, Any] | None = None,
    foxglove_config: dict[str, Any] | None = None,
) -> Blueprint:
    """Build the visualization blueprint for the chosen viewer backend.

    The selected viewer bridge (Rerun or Foxglove) is combined with
    ``RerunWebSocketServer`` and ``WebsocketVisModule``; the ``"none"``
    backend yields only ``WebsocketVisModule``.

    Example usage::

        from dimos.core.global_config import global_config
        viz = vis_module(
            global_config.viewer,
            rerun_config={
                "visual_override": {
                    "world/camera_info": lambda ci: ci.to_rerun(...),
                },
                "static": {
                    "world/tf/base_link": lambda rr: [rr.Boxes3D(...)],
                },
            },
        )

    Args:
        viewer_backend: One of ``rerun``, ``rerun-web``, ``rerun-connect``,
            ``foxglove`` or ``none``.
        rerun_config: Extra keyword arguments for ``RerunBridgeModule.blueprint``.
            The caller's dict is not mutated.
        foxglove_config: Extra keyword arguments for ``FoxgloveBridge.blueprint``.

    Raises:
        ValueError: If *viewer_backend* is not one of the values above.
    """
    from dimos.visualization.rerun.websocket_server import RerunWebSocketServer
    from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule

    foxglove_kwargs: dict[str, Any] = {} if foxglove_config is None else foxglove_config
    rerun_kwargs: dict[str, Any] = {} if rerun_config is None else rerun_config

    if viewer_backend == "foxglove":
        from dimos.robot.foxglove_bridge import FoxgloveBridge

        return autoconnect(
            FoxgloveBridge.blueprint(**foxglove_kwargs),
            RerunWebSocketServer.blueprint(),
            WebsocketVisModule.blueprint(),
        )

    if viewer_backend in ("rerun", "rerun-web", "rerun-connect"):
        from dimos.protocol.pubsub.impl.lcmpubsub import LCM
        from dimos.visualization.rerun.bridge import _BACKEND_TO_MODE, RerunBridgeModule

        # Work on a copy; LCM is only the fallback when no pubsubs were given.
        bridge_kwargs = dict(rerun_kwargs)
        bridge_kwargs.setdefault("pubsubs", [LCM()])
        mode = _BACKEND_TO_MODE.get(viewer_backend, "native")
        return autoconnect(
            RerunBridgeModule.blueprint(viewer_mode=mode, **bridge_kwargs),
            RerunWebSocketServer.blueprint(),
            WebsocketVisModule.blueprint(),
        )

    if viewer_backend == "none":
        return autoconnect(WebsocketVisModule.blueprint())

    raise ValueError(
        f"Unknown viewer_backend {viewer_backend!r}. "
        f"Expected one of: rerun, rerun-web, rerun-connect, foxglove, none"
    )
+ logger.debug(f"[DEBUG] Subscribed to global_costmap OK, transport={transport}") + except Exception as e: + logger.warning(f"Failed to subscribe to global_costmap: {e}", exc_info=True) + + logger.info("WebsocketVisModule.start() complete") @rpc def stop(self) -> None: @@ -209,7 +221,11 @@ def stop(self) -> None: async def _disconnect_all() -> None: await self.sio.disconnect() - asyncio.run_coroutine_threadsafe(_disconnect_all(), self._broadcast_loop) + fut = asyncio.run_coroutine_threadsafe(_disconnect_all(), self._broadcast_loop) + try: + fut.result(timeout=2.0) + except Exception: + pass if self._broadcast_loop and not self._broadcast_loop.is_closed(): self._broadcast_loop.call_soon_threadsafe(self._broadcast_loop.stop) diff --git a/docker/navigation/Dockerfile b/docker/navigation/Dockerfile index dc2ce54f39..cd940ab375 100644 --- a/docker/navigation/Dockerfile +++ b/docker/navigation/Dockerfile @@ -325,7 +325,6 @@ RUN cat > ${WORKSPACE}/config/fastdds.xml <<'EOF' shm_transport SHM 10485760 - 1048576 diff --git a/docker/navigation/build.sh b/docker/navigation/build.sh index 371db08b49..750c105d6c 100755 --- a/docker/navigation/build.sh +++ b/docker/navigation/build.sh @@ -62,12 +62,13 @@ cd "$SCRIPT_DIR" # Use fastlio2 branch which has both arise_slam and FASTLIO2 TARGET_BRANCH="fastlio2" TARGET_REMOTE="origin" -CLONE_URL="https://github.com/dimensionalOS/ros-navigation-autonomy-stack.git" +CLONE_URL_SSH="git@github.com:dimensionalOS/ros-navigation-autonomy-stack.git" +CLONE_URL_HTTPS="https://github.com/dimensionalOS/ros-navigation-autonomy-stack.git" # Clone or checkout ros-navigation-autonomy-stack if [ ! 
-d "ros-navigation-autonomy-stack" ]; then echo -e "${YELLOW}Cloning ros-navigation-autonomy-stack repository (${TARGET_BRANCH} branch)...${NC}" - git clone -b ${TARGET_BRANCH} ${CLONE_URL} ros-navigation-autonomy-stack + git clone -b ${TARGET_BRANCH} ${CLONE_URL_SSH} ros-navigation-autonomy-stack || git clone -b ${TARGET_BRANCH} ${CLONE_URL_HTTPS} ros-navigation-autonomy-stack echo -e "${GREEN}Repository cloned successfully!${NC}" else # Directory exists, ensure we're on the correct branch @@ -100,14 +101,20 @@ fi if [ ! -d "unity_models" ]; then echo -e "${YELLOW}Using office_building_1 as the Unity environment...${NC}" - tar -xf ../../data/.lfs/office_building_1.tar.gz + LFS_ASSET="../../data/.lfs/office_building_1.tar.gz" + # If the file is still a Git LFS pointer (not yet downloaded), fetch it now. + if file "$LFS_ASSET" | grep -q "ASCII text"; then + echo -e "${YELLOW}office_building_1.tar.gz is an LFS pointer — fetching via git lfs...${NC}" + git -C "$(realpath ../../)" lfs pull --include="data/.lfs/office_building_1.tar.gz" + fi + tar -xf "$LFS_ASSET" mv office_building_1 unity_models fi echo "" echo -e "${YELLOW}Building Docker image with docker compose...${NC}" echo "This will take a while as it needs to:" -echo " - Download base ROS ${ROS_DISTRO^} image" +echo " - Download base ROS ${ROS_DISTRO} image" echo " - Install ROS packages and dependencies" echo " - Build the autonomy stack (arise_slam + FASTLIO2)" echo " - Build Livox-SDK2 for Mid-360 lidar" @@ -117,7 +124,31 @@ echo "" cd ../.. -docker compose -f docker/navigation/docker-compose.yml build +# Detect host architecture and pass it as a build arg so the Dockerfile's +# base-${TARGETARCH} stage resolves correctly (the standard docker builder +# does not set TARGETARCH automatically without --platform). 
+HOST_ARCH=$(uname -m) +case "$HOST_ARCH" in + x86_64) TARGETARCH="amd64" ;; + aarch64|arm64) TARGETARCH="arm64" ;; + *) TARGETARCH="$HOST_ARCH" ;; +esac +echo -e "${GREEN}Detected architecture: ${HOST_ARCH} → TARGETARCH=${TARGETARCH}${NC}" + +# Prefer the Docker Compose V2 plugin; fall back to the legacy standalone binary. +# Auto-install the plugin if neither is available. +if ! docker compose version &>/dev/null; then + echo -e "${YELLOW}Docker Compose not found — installing docker-compose-plugin...${NC}" + sudo apt-get update -qq && sudo apt-get install -y docker-compose-v2 || sudo apt-get install -y docker-compose-plugin + if ! docker compose version &>/dev/null; then + echo -e "${RED}Error: Failed to install Docker Compose.${NC}" + echo "Please install it manually: sudo apt-get install docker-compose-v2" + echo "or follow https://docs.docker.com/compose/install/" + exit 1 + fi +fi + +docker compose -f docker/navigation/docker-compose.yml build --build-arg TARGETARCH="$TARGETARCH" echo "" echo -e "${GREEN}============================================${NC}" @@ -127,13 +158,13 @@ echo -e "${GREEN}SLAM: arise_slam + FASTLIO2 (both included)${NC}" echo -e "${GREEN}============================================${NC}" echo "" echo "To run in SIMULATION mode:" -echo -e "${YELLOW} ./start.sh --simulation --${ROS_DISTRO}${NC}" +echo -e "${YELLOW} ./start.sh --simulation --image ${ROS_DISTRO}${NC}" echo "" echo "To run in HARDWARE mode:" echo " 1. Configure your hardware settings in .env file" echo " (copy from .env.hardware if needed)" echo " 2. 
Run the hardware container:" -echo -e "${YELLOW} ./start.sh --hardware --${ROS_DISTRO}${NC}" +echo -e "${YELLOW} ./start.sh --hardware --image ${ROS_DISTRO}${NC}" echo "" echo "To use FASTLIO2 instead of arise_slam, set LOCALIZATION_METHOD:" echo -e "${YELLOW} LOCALIZATION_METHOD=fastlio ./start.sh --hardware --${ROS_DISTRO}${NC}" diff --git a/docker/navigation/docker-compose.yml b/docker/navigation/docker-compose.yml index 6546968757..6a96adf4f9 100644 --- a/docker/navigation/docker-compose.yml +++ b/docker/navigation/docker-compose.yml @@ -7,6 +7,7 @@ services: network: host args: ROS_DISTRO: ${ROS_DISTRO:-humble} + TARGETARCH: ${TARGETARCH} image: dimos_autonomy_stack:${IMAGE_TAG:-humble} container_name: dimos_simulation_container profiles: ["", "simulation"] # Active by default (empty profile) AND with --profile simulation @@ -65,7 +66,8 @@ services: # Device access (for joystick controllers) devices: - /dev/input:/dev/input - - /dev/dri:/dev/dri + # DRI GPU device: set by start.sh when /dev/dri exists (desktop); falls back to /dev/null on Jetson/Tegra + - ${DRI_DEVICE:-/dev/null}:${DRI_DEVICE:-/dev/null} # Working directory working_dir: /workspace/dimos @@ -81,6 +83,7 @@ services: network: host args: ROS_DISTRO: ${ROS_DISTRO:-humble} + TARGETARCH: ${TARGETARCH} image: dimos_autonomy_stack:${IMAGE_TAG:-humble} container_name: dimos_hardware_container profiles: ["hardware"] @@ -170,8 +173,8 @@ services: devices: # Joystick controller (specific device to avoid permission issues) - /dev/input/js0:/dev/input/js0 - # GPU access - - /dev/dri:/dev/dri + # DRI GPU device: set by start.sh when /dev/dri exists (desktop); falls back to /dev/null on Jetson/Tegra + - ${DRI_DEVICE:-/dev/null}:${DRI_DEVICE:-/dev/null} # Motor controller serial ports - ${MOTOR_SERIAL_DEVICE:-/dev/ttyACM0}:${MOTOR_SERIAL_DEVICE:-/dev/ttyACM0} # Additional serial ports (can be enabled via environment) @@ -251,6 +254,7 @@ services: network: host args: ROS_DISTRO: ${ROS_DISTRO:-humble} + 
TARGETARCH: ${TARGETARCH} image: dimos_autonomy_stack:${IMAGE_TAG:-humble} container_name: dimos_bagfile_container profiles: ["bagfile"] @@ -302,7 +306,8 @@ services: # Device access (for joystick controllers) devices: - /dev/input:/dev/input - - /dev/dri:/dev/dri + # DRI GPU device: set by start.sh when /dev/dri exists (desktop); falls back to /dev/null on Jetson/Tegra + - ${DRI_DEVICE:-/dev/null}:${DRI_DEVICE:-/dev/null} # Working directory working_dir: /ros2_ws @@ -316,7 +321,7 @@ services: echo "Bagfile playback mode (use_sim_time=true)" echo "" echo "Launch files ready. Play your bagfile with:" - echo " ros2 bag play --clock /ros2_ws/bagfiles/" + echo " ros2 bag play /ros2_ws/bagfiles/ --clock" echo "" # Launch with SLAM method based on LOCALIZATION_METHOD if [ "$LOCALIZATION_METHOD" = "fastlio" ]; then diff --git a/docker/navigation/foxglove_utility/twist_relay.py b/docker/navigation/foxglove_utility/twist_relay.py index 6e72d5104b..68b68856e6 100644 --- a/docker/navigation/foxglove_utility/twist_relay.py +++ b/docker/navigation/foxglove_utility/twist_relay.py @@ -37,16 +37,22 @@ def __init__(self): output_topic = self.get_parameter("output_topic").value self.frame_id = self.get_parameter("frame_id").value - # QoS for real-time control - qos = QoSProfile( + # BEST_EFFORT subscriber: drop stale teleop input rather than queue it + sub_qos = QoSProfile( reliability=ReliabilityPolicy.BEST_EFFORT, history=HistoryPolicy.KEEP_LAST, depth=1 ) + # RELIABLE publisher: vehicleSimulator and the nav planner subscribe with RELIABLE (default) + pub_qos = QoSProfile( + reliability=ReliabilityPolicy.RELIABLE, history=HistoryPolicy.KEEP_LAST, depth=1 + ) # Subscribe to Twist (from Foxglove Teleop) - self.subscription = self.create_subscription(Twist, input_topic, self.twist_callback, qos) + self.subscription = self.create_subscription( + Twist, input_topic, self.twist_callback, sub_qos + ) # Publish TwistStamped - self.publisher = self.create_publisher(TwistStamped, 
output_topic, qos) + self.publisher = self.create_publisher(TwistStamped, output_topic, pub_qos) self.get_logger().info( f"Twist relay: {input_topic} (Twist) -> {output_topic} (TwistStamped)" diff --git a/docker/navigation/start.sh b/docker/navigation/start.sh index be45908a33..0102abf81c 100755 --- a/docker/navigation/start.sh +++ b/docker/navigation/start.sh @@ -98,6 +98,15 @@ export ROS_DISTRO export LOCALIZATION_METHOD export IMAGE_TAG="${ROS_DISTRO}" +# Detect host architecture and export for docker-compose build args +HOST_ARCH=$(uname -m) +case "$HOST_ARCH" in + x86_64) TARGETARCH="amd64" ;; + aarch64|arm64) TARGETARCH="arm64" ;; + *) TARGETARCH="$HOST_ARCH" ;; +esac +export TARGETARCH + SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" cd "$SCRIPT_DIR" @@ -349,7 +358,7 @@ elif [ "$MODE" = "bagfile" ]; then echo " - RViz2 visualization" fi echo "" - echo -e "${YELLOW}Remember to play bagfile with: ros2 bag play --clock ${NC}" + echo -e "${YELLOW}Remember to play bagfile with: ros2 bag play ${NC} --clock" echo "" echo "To enter the container from another terminal:" echo -e " ${YELLOW}docker exec -it ${CONTAINER_NAME} bash${NC}" @@ -374,7 +383,13 @@ elif [ "$MODE" = "bagfile" ]; then mkdir -p bagfiles config maps fi -# Build compose command +# Enable DRI device passthrough on systems that support it (not available on Jetson/Tegra) +if [ -e "/dev/dri" ]; then + export DRI_DEVICE="/dev/dri" + echo -e "${GREEN}/dev/dri detected — enabling DRI device passthrough${NC}" +fi + +# Build compose command (for hardware and bagfile modes) COMPOSE_CMD="docker compose -f docker-compose.yml" if [ "$DEV_MODE" = "true" ]; then COMPOSE_CMD="$COMPOSE_CMD -f docker-compose.dev.yml" diff --git a/pyproject.toml b/pyproject.toml index 7e2f38546e..e478f0d717 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,7 @@ dependencies = [ "annotation-protocol>=1.4.0", "lazy_loader", "plum-dispatch==2.5.7", + "xxhash>=3.0.0", # Logging "structlog>=25.5.0,<26", 
"colorlog==6.9.0", @@ -291,6 +292,12 @@ sim = [ "pygame>=2.6.1", ] +navigation = [ + # PGO (pose graph optimization) — gtsam-develop has aarch64 wheels; stable gtsam 4.2 does not + "gtsam>=4.2; platform_machine != 'aarch64'", + "gtsam-develop; platform_machine == 'aarch64'", +] + # NOTE: jetson-jp6-cuda126 extra is disabled due to 404 errors from wheel URLs # The pypi.jetson-ai-lab.io URLs are currently unavailable. Update with working URLs when available. # jetson-jp6-cuda126 = [ @@ -335,6 +342,11 @@ base = [ "dimos[agents,web,perception,visualization,sim]", ] +[tool.uv] +# gtsam-develop incorrectly declares pytest as a runtime dependency (packaging bug). +# Override it to keep our pinned version. +override-dependencies = ["pytest==8.3.5"] + [tool.ruff] line-length = 100 exclude = [ @@ -460,4 +472,6 @@ ignore = [ "dimos/dashboard/dimos.rbl", "dimos/web/dimos_interface/themes.json", "dimos/manipulation/manipulation_module.py", + "dimos/navigation/smartnav/modules/*/main.cpp", + "dimos/navigation/smartnav/common/*.hpp", ] diff --git a/uv.lock b/uv.lock index 529842294b..373cf846bc 100644 --- a/uv.lock +++ b/uv.lock @@ -24,6 +24,9 @@ resolution-markers = [ "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", ] +[manifest] +overrides = [{ name = "pytest", specifier = "==8.3.5" }] + [[package]] name = "absl-py" version = "2.4.0" @@ -1714,6 +1717,7 @@ dependencies = [ { name = "toolz" }, { name = "typer" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, + { name = "xxhash" }, ] [package.optional-dependencies] @@ -1920,6 +1924,10 @@ misc = [ { name = "xarm-python-sdk" }, { name = "yapf" }, ] +navigation = [ + { name = "gtsam", marker = "platform_machine != 'aarch64'" }, + { name = "gtsam-develop", marker = "platform_machine == 'aarch64'" }, +] perception = [ { name = "filterpy" }, { 
name = "hydra-core" }, @@ -2014,6 +2022,8 @@ requires-dist = [ { name = "filterpy", marker = "extra == 'perception'", specifier = ">=1.4.5" }, { name = "gdown", marker = "extra == 'misc'", specifier = "==5.2.0" }, { name = "googlemaps", marker = "extra == 'misc'", specifier = ">=4.10.0" }, + { name = "gtsam", marker = "platform_machine != 'aarch64' and extra == 'navigation'", specifier = ">=4.2" }, + { name = "gtsam-develop", marker = "platform_machine == 'aarch64' and extra == 'navigation'" }, { name = "hydra-core", marker = "extra == 'perception'", specifier = ">=1.3.0" }, { name = "ipykernel", marker = "extra == 'misc'" }, { name = "kaleido", marker = "extra == 'manipulation'", specifier = ">=0.2.1" }, @@ -2150,9 +2160,10 @@ requires-dist = [ { name = "xarm-python-sdk", marker = "extra == 'manipulation'", specifier = ">=1.17.0" }, { name = "xarm-python-sdk", marker = "extra == 'misc'", specifier = ">=1.17.0" }, { name = "xformers", marker = "platform_machine == 'x86_64' and extra == 'cuda'", specifier = ">=0.0.20" }, + { name = "xxhash", specifier = ">=3.0.0" }, { name = "yapf", marker = "extra == 'misc'", specifier = "==0.40.2" }, ] -provides-extras = ["misc", "visualization", "agents", "web", "perception", "unitree", "manipulation", "cpu", "cuda", "dev", "psql", "sim", "drone", "dds", "docker", "base"] +provides-extras = ["misc", "visualization", "agents", "web", "perception", "unitree", "manipulation", "cpu", "cuda", "dev", "psql", "sim", "navigation", "drone", "dds", "docker", "base"] [[package]] name = "dimos-lcm" @@ -3022,6 +3033,46 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/b2/b096ccce418882fbfda4f7496f9357aaa9a5af1896a9a7f60d9f2b275a06/grpcio-1.78.0-cp314-cp314-win_amd64.whl", hash = "sha256:dce09d6116df20a96acfdbf85e4866258c3758180e8c49845d6ba8248b6d0bbb", size = 4929852, upload-time = "2026-02-06T09:56:45.885Z" }, ] +[[package]] +name = "gtsam" +version = "4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ 
+ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and platform_machine != 'aarch64') or (python_full_version < '3.11' and sys_platform != 'linux')" }, + { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and platform_machine != 'aarch64') or (python_full_version >= '3.11' and sys_platform != 'linux')" }, + { name = "pyparsing", marker = "platform_machine != 'aarch64' or sys_platform != 'linux'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/3f/6325c300cc92ca2495570e41ab5dffc3147837d5fce77e243714c1d646dd/gtsam-4.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:e1c2958b5e8895ff5822114119e9b303fc12b2ded4c9ede0a7a6844f6eb7be1a", size = 21626497, upload-time = "2023-09-03T16:47:54.229Z" }, + { url = "https://files.pythonhosted.org/packages/11/d8/ab317fdedeca03362d316c1b5c32cabc06f7a8948363f1c96d4a7f3fca8d/gtsam-4.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:b89d231f1de264d475ce1c0fcac320fdccb1260e2df01f8fb8775cdaddf57fb0", size = 19858294, upload-time = "2023-09-03T16:48:10.742Z" }, + { url = "https://files.pythonhosted.org/packages/9d/e1/8c87e6cc713f18be917fd110570e0646a132ab189781c551e4347d175170/gtsam-4.2-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:59d452e645f7ca598e89170785c9e8ac1b5a2b0a0d32bcda2c279b7b51bb379f", size = 21621183, upload-time = "2023-09-03T16:48:27.783Z" }, + { url = "https://files.pythonhosted.org/packages/89/80/95d842fa51fef2223f3920f82ffea241bcc6f6b6ba6d7aa96c6e46e19474/gtsam-4.2-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:979b7c886724ac403d5c323613fd9800c8ac7ab224e2909faa8b266e5de742b2", size = 21788738, upload-time = "2023-09-03T16:48:45.876Z" }, + { url = "https://files.pythonhosted.org/packages/0e/25/b380725edd25c7ab50814a544ffaa748af41fe639d5294818cf4e67b2584/gtsam-4.2-cp311-cp311-macosx_11_0_x86_64.whl", hash = 
"sha256:2996c349f03182df739f284adc62f455337b5272c4227bb970cbac7622b1d8e8", size = 21626506, upload-time = "2023-09-03T16:49:03.477Z" }, + { url = "https://files.pythonhosted.org/packages/ae/41/a7e26b58289f0f634123615fffa3c7eefd9b559cc1709d7c6ae67da7c87d/gtsam-4.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:0dc13b8a76ef862359a15ea9feec7cfd7df9d2bf2545b5f96297592c63e76aad", size = 20149822, upload-time = "2023-09-03T16:49:20.709Z" }, + { url = "https://files.pythonhosted.org/packages/bd/ba/d7b3ddccb179ca28db84edb7dd223aba6733ae9f59015b6263391bc8976b/gtsam-4.2-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:473978b6d32ab45903433a745ca166f0441058a726c7f26e9f51ce5aa2a03721", size = 21620301, upload-time = "2023-09-03T16:49:37.761Z" }, + { url = "https://files.pythonhosted.org/packages/51/8e/a1cf54ea59c81b300d31f407270d8fe8b7e6fdcfd4c7353b0a0857602808/gtsam-4.2-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:3793d5524a0417924fddb9d9662afc8129357fd52f07c422776024ddab3780c1", size = 22352720, upload-time = "2023-10-02T15:24:54.997Z" }, +] + +[[package]] +name = "gtsam-develop" +version = "4.3a1.dev202603292217" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "pytest" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/76/9c/77920ffa36c1896ad043be07b87a265bc144df1cf940c4ed329f477a16a9/gtsam_develop-4.3a1.dev202603292217-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:54cc6bc76a3f0b770868d55c260b1f4bf9b48b0ed5bef9e8c9d648cb2d3f2793", size = 40983785, upload-time = "2026-03-29T22:54:24.812Z" }, + { url = 
"https://files.pythonhosted.org/packages/40/e8/6409fb20da11874e1e2d7a8b17378ebabe258a0cd7db917065efad599646/gtsam_develop-4.3a1.dev202603292217-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:26b9b58466e15724de0774a5025a5d86881fd4d655fd626acf87a0e23954f1fe", size = 29140058, upload-time = "2026-03-29T22:54:27.425Z" }, + { url = "https://files.pythonhosted.org/packages/d1/24/e452f87ad57a287ec13af4b16dd4a9b2dd295c7c7c9265fddf58cdc5657e/gtsam_develop-4.3a1.dev202603292217-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:dfed09f8a7cf6c196a60c50757a9d9d92efb85543393fbf2fc8cf53130d09382", size = 41178348, upload-time = "2026-03-29T22:54:32.679Z" }, + { url = "https://files.pythonhosted.org/packages/1d/a6/219ae92eb5a77395777bf24b7978e2806e0427707e21d57f4517eded47ec/gtsam_develop-4.3a1.dev202603292217-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1caab4d70162c2e9291c4b28489c71da22d226b7cfda03ceea6d5597ef122fff", size = 29132018, upload-time = "2026-03-29T22:54:35.418Z" }, + { url = "https://files.pythonhosted.org/packages/53/16/a9c4042d47c24e5d8d8784d51248ec8a2491efac17e324c75ff791fa56cd/gtsam_develop-4.3a1.dev202603292217-cp313-cp313-macosx_10_15_universal2.whl", hash = "sha256:7311d39048f2757c9f03d3e6f5caec5a7e5afa2b6b01c5952869bd74185ea08a", size = 41177166, upload-time = "2026-03-29T22:54:40.579Z" }, + { url = "https://files.pythonhosted.org/packages/25/ae/447867b2c7a71547e1b6ab8e6bfbbbaf8b16ac03fb1131e66bbe5bd2dfe3/gtsam_develop-4.3a1.dev202603292217-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:4cc2486f4516190bbaa1162355df39ddc3440142cc60d8e3a5fb9f355b37866e", size = 29133773, upload-time = "2026-03-29T22:54:43.393Z" }, + { url = "https://files.pythonhosted.org/packages/9d/07/cb3336d5d371bb233dd48608a5dd11f39762aec32f87785ccb3c0458be1d/gtsam_develop-4.3a1.dev202603292217-cp314-cp314-macosx_10_15_universal2.whl", hash = 
"sha256:ed51a9fbbca8967ad5a3a218dcd8d02ffdc0239204c4f7b541deff1331d3e9f4", size = 41193548, upload-time = "2026-03-29T22:54:55.153Z" }, + { url = "https://files.pythonhosted.org/packages/f5/e8/6764978a14b3d3911c29d5bad5e2a3e360e6f424e8400d6ec51d894c28ce/gtsam_develop-4.3a1.dev202603292217-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:608e4ea983bdc65ae42a9adc18eac505b2936568ea64d639020819f45c2cb00d", size = 29158254, upload-time = "2026-03-29T22:54:57.873Z" }, +] + [[package]] name = "h11" version = "0.16.0"