diff --git a/changes.md b/changes.md new file mode 100644 index 0000000000..d1e4b4b2e7 --- /dev/null +++ b/changes.md @@ -0,0 +1,19 @@ +# PR #1643 (rconnect) — Paul Review Fixes + +## Commits (local, not pushed) + +### 1. `81769d273` — Log exception + unblock stop() on startup failure +- If `_serve()` throws, `_server_ready` was never set → `stop()` blocked 5s +- Now logs exception and sets `_server_ready` in finally +- **Revert:** `git revert 81769d273` + +## Reviewer was wrong on +- `_server_ready` race — it IS set inside `async with` (after bind), not before +- `msg.get("x") or 0` — code already uses `msg.get("x", 0)` correctly + +## Not addressed (need Jeff's input) +- `vis_module` always bundling `RerunWebSocketServer` — opt-out design choice +- `LCM()` instantiated for non-rerun backends — wasted resource +- `rerun-connect` skipping `WebsocketVisModule` — intentional? +- Default `host = "0.0.0.0"` — intentional for remote viewer use case +- Hardcoded test ports — should use port=0 for parallel safety diff --git a/dimos/hardware/sensors/camera/module.py b/dimos/hardware/sensors/camera/module.py index b8165658d9..a02f947af1 100644 --- a/dimos/hardware/sensors/camera/module.py +++ b/dimos/hardware/sensors/camera/module.py @@ -22,6 +22,7 @@ from dimos.agents.annotation import skill from dimos.core.blueprints import autoconnect from dimos.core.core import rpc +from dimos.core.global_config import global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out from dimos.hardware.sensors.camera.spec import CameraHardware @@ -32,7 +33,7 @@ from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier from dimos.spec import perception -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.vis_module import vis_module def default_transform() -> Transform: @@ -120,5 +121,5 @@ def stop(self) -> None: demo_camera = autoconnect( CameraModule.blueprint(), - RerunBridgeModule.blueprint(), + vis_module(viewer_backend=global_config.viewer), ) diff --git a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py index f3de842b46..b39dd7bcec 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py +++ b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py @@ -15,36 +15,45 @@ from dimos.core.blueprints import autoconnect from dimos.hardware.sensors.lidar.fastlio2.module import FastLio2 from dimos.mapping.voxels import VoxelGridMapper -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.vis_module import vis_module voxel_size = 0.05 mid360_fastlio = autoconnect( FastLio2.blueprint(voxel_size=voxel_size, map_voxel_size=voxel_size, map_freq=-1), - RerunBridgeModule.blueprint( - visual_override={ - "world/lidar": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), - } + vis_module( + "rerun", + rerun_config={ + "visual_override": { + "world/lidar": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), + }, + }, ), ).global_config(n_workers=2, robot_model="mid360_fastlio2") mid360_fastlio_voxels = autoconnect( FastLio2.blueprint(), VoxelGridMapper.blueprint(publish_interval=1.0, voxel_size=voxel_size, carve_columns=False), - RerunBridgeModule.blueprint( - visual_override={ - "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), - "world/lidar": None, - } + vis_module( + "rerun", + rerun_config={ + "visual_override": { + "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), + "world/lidar": None, + }, + }, ), ).global_config(n_workers=3, robot_model="mid360_fastlio2_voxels") mid360_fastlio_voxels_native = autoconnect( FastLio2.blueprint(voxel_size=voxel_size, map_voxel_size=voxel_size, map_freq=3.0), - RerunBridgeModule.blueprint( - visual_override={ - "world/lidar": None, - "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), - } + vis_module( + "rerun", + rerun_config={ + "visual_override": { + "world/lidar": None, + "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), + }, + }, ), ).global_config(n_workers=2, robot_model="mid360_fastlio2") diff --git a/dimos/hardware/sensors/lidar/livox/livox_blueprints.py b/dimos/hardware/sensors/lidar/livox/livox_blueprints.py index c8835b3e89..958af084e2 100644 --- a/dimos/hardware/sensors/lidar/livox/livox_blueprints.py +++ b/dimos/hardware/sensors/lidar/livox/livox_blueprints.py @@ -14,9 +14,9 @@ from dimos.core.blueprints import autoconnect from dimos.hardware.sensors.lidar.livox.module import Mid360 -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.vis_module import vis_module mid360 = autoconnect( Mid360.blueprint(), - RerunBridgeModule.blueprint(), + vis_module("rerun"), ).global_config(n_workers=2, robot_model="mid360") diff --git a/dimos/manipulation/grasping/demo_grasping.py b/dimos/manipulation/grasping/demo_grasping.py index 782283029b..f1ce67709e 100644 --- a/dimos/manipulation/grasping/demo_grasping.py +++ b/dimos/manipulation/grasping/demo_grasping.py @@ -14,15 +14,14 @@ # limitations under the License. from pathlib import Path -from dimos.agents.mcp.mcp_client import McpClient -from dimos.agents.mcp.mcp_server import McpServer +from dimos.agents.agent import Agent from dimos.core.blueprints import autoconnect from dimos.hardware.sensors.camera.realsense.camera import RealSenseCamera from dimos.manipulation.grasping.graspgen_module import graspgen from dimos.manipulation.grasping.grasping import GraspingModule from dimos.perception.detection.detectors.yoloe import YoloePromptMode from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule -from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.visualization.vis_module import vis_module camera_module = RealSenseCamera.blueprint(enable_pointcloud=False) @@ -44,7 +43,6 @@ ("/tmp", "/tmp", "rw") ], # Grasp visualization debug standalone: python -m dimos.manipulation.grasping.visualize_grasps ), - FoxgloveBridge.blueprint(), - McpServer.blueprint(), - McpClient.blueprint(), + vis_module("foxglove"), + Agent.blueprint(), ).global_config(viewer="foxglove") diff --git a/dimos/perception/demo_object_scene_registration.py b/dimos/perception/demo_object_scene_registration.py index c6d8c96625..13fb26cbb5 100644 --- a/dimos/perception/demo_object_scene_registration.py +++ b/dimos/perception/demo_object_scene_registration.py @@ -13,14 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.agents.mcp.mcp_client import McpClient -from dimos.agents.mcp.mcp_server import McpServer +from dimos.agents.agent import Agent from dimos.core.blueprints import autoconnect from dimos.hardware.sensors.camera.realsense.camera import RealSenseCamera from dimos.hardware.sensors.camera.zed.compat import ZEDCamera from dimos.perception.detection.detectors.yoloe import YoloePromptMode from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule -from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.visualization.vis_module import vis_module camera_choice = "zed" @@ -34,7 +33,6 @@ demo_object_scene_registration = autoconnect( camera_module, ObjectSceneRegistrationModule.blueprint(target_frame="world", prompt_mode=YoloePromptMode.LRPC), - FoxgloveBridge.blueprint(), - McpServer.blueprint(), - McpClient.blueprint(), + vis_module("foxglove"), + Agent.blueprint(), ).global_config(viewer="foxglove") diff --git a/dimos/robot/drone/blueprints/basic/drone_basic.py b/dimos/robot/drone/blueprints/basic/drone_basic.py index 2c0b5ccb16..f2efda9e98 100644 --- a/dimos/robot/drone/blueprints/basic/drone_basic.py +++ b/dimos/robot/drone/blueprints/basic/drone_basic.py @@ -20,10 +20,9 @@ from dimos.core.blueprints import autoconnect from dimos.core.global_config import global_config -from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.robot.drone.camera_module import DroneCameraModule from dimos.robot.drone.connection_module import DroneConnectionModule -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule +from dimos.visualization.vis_module import vis_module def _static_drone_body(rr: Any) -> list[Any]: @@ -60,23 +59,12 @@ def _drone_rerun_blueprint() -> Any: _rerun_config = { "blueprint": _drone_rerun_blueprint, - "pubsubs": [LCM()], "static": { "world/tf/base_link": _static_drone_body, }, } -# Conditional visualization -if global_config.viewer == "foxglove": - from dimos.robot.foxglove_bridge import FoxgloveBridge - - _vis = FoxgloveBridge.blueprint() -elif global_config.viewer.startswith("rerun"): - from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode - - _vis = RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **_rerun_config) -else: - _vis = autoconnect() +_vis = vis_module(global_config.viewer, rerun_config=_rerun_config) # Determine connection string based on replay flag connection_string = "udp:0.0.0.0:14550" @@ -92,7 +80,6 @@ def _drone_rerun_blueprint() -> Any: outdoor=False, ), DroneCameraModule.blueprint(camera_intrinsics=[1000.0, 1000.0, 960.0, 540.0]), - WebsocketVisModule.blueprint(), ) __all__ = [ diff --git a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py index 5b127fb697..721487d717 100644 --- a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py +++ b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py @@ -17,10 +17,11 @@ from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE from dimos.core.blueprints import autoconnect +from dimos.core.global_config import global_config from dimos.core.transport import pSHMTransport from dimos.msgs.sensor_msgs.Image import Image -from dimos.robot.foxglove_bridge import FoxgloveBridge from dimos.robot.unitree.g1.blueprints.perceptive.unitree_g1 import unitree_g1 +from dimos.visualization.vis_module import vis_module unitree_g1_shm = autoconnect( unitree_g1.transports( @@ -30,10 +31,9 @@ ), } ), - FoxgloveBridge.blueprint( - shm_channels=[ - "/color_image#sensor_msgs.Image", - ] + vis_module( + viewer_backend=global_config.viewer, + foxglove_config={"shm_channels": ["/color_image#sensor_msgs.Image"]}, ), ) diff --git a/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py b/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py index ff59c9b8ef..97db531134 100644 --- a/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py +++ b/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py @@ -40,8 +40,7 @@ from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector import ( WavefrontFrontierExplorer, ) -from dimos.protocol.pubsub.impl.lcmpubsub import LCM -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule +from dimos.visualization.vis_module import vis_module def _convert_camera_info(camera_info: Any) -> Any: @@ -98,7 +97,6 @@ def _g1_rerun_blueprint() -> Any: rerun_config = { "blueprint": _g1_rerun_blueprint, - "pubsubs": [LCM()], "visual_override": { "world/camera_info": _convert_camera_info, "world/global_map": _convert_global_map, @@ -109,18 +107,7 @@ def _g1_rerun_blueprint() -> Any: }, } -if global_config.viewer == "foxglove": - from dimos.robot.foxglove_bridge import FoxgloveBridge - - _with_vis = autoconnect(FoxgloveBridge.blueprint()) -elif global_config.viewer.startswith("rerun"): - from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode - - _with_vis = autoconnect( - RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config) - ) -else: - _with_vis = autoconnect() +_with_vis = vis_module(viewer_backend=global_config.viewer, rerun_config=rerun_config) def _create_webcam() -> Webcam: @@ -155,8 +142,6 @@ def _create_webcam() -> Webcam: VoxelGridMapper.blueprint(voxel_size=0.1), CostMapper.blueprint(), WavefrontFrontierExplorer.blueprint(), - # Visualization - WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_g1") .transports( diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py index f32561e11d..f62e4c606a 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py @@ -22,10 +22,9 @@ from dimos.core.global_config import global_config from dimos.core.transport import pSHMTransport from dimos.msgs.sensor_msgs.Image import Image -from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator from dimos.robot.unitree.go2.connection import GO2Connection -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule +from dimos.visualization.vis_module import vis_module # Mac has some issue with high bandwidth UDP, so we use pSHMTransport for color_image # actually we can use pSHMTransport for all platforms, and for all streams @@ -95,9 +94,6 @@ def _go2_rerun_blueprint() -> Any: rerun_config = { "blueprint": _go2_rerun_blueprint, - # any pubsub that supports subscribe_all and topic that supports str(topic) - # is acceptable here - "pubsubs": [LCM()], # Custom converters for specific rerun entity paths # Normally all these would be specified in their respectative modules # Until this is implemented we have central overrides here @@ -114,29 +110,19 @@ def _go2_rerun_blueprint() -> Any: }, } - -if global_config.viewer == "foxglove": - from dimos.robot.foxglove_bridge import FoxgloveBridge - - with_vis = autoconnect( - _transports_base, - FoxgloveBridge.blueprint(shm_channels=["/color_image#sensor_msgs.Image"]), - ) -elif global_config.viewer.startswith("rerun"): - from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode - - with_vis = autoconnect( - _transports_base, - RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config), - ) -else: - with_vis = _transports_base +_with_vis = autoconnect( + _transports_base, + vis_module( + viewer_backend=global_config.viewer, + rerun_config=rerun_config, + foxglove_config={"shm_channels": ["/color_image#sensor_msgs.Image"]}, + ), +) unitree_go2_basic = ( autoconnect( - with_vis, + _with_vis, GO2Connection.blueprint(), - WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_go2") .configurators(ClockSyncConfigurator()) diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py index 1c55f3e93c..0468cad40d 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py @@ -22,15 +22,13 @@ from dimos.core.blueprints import autoconnect from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator -from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import with_vis +from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import _with_vis from dimos.robot.unitree.go2.fleet_connection import Go2FleetConnection -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule unitree_go2_fleet = ( autoconnect( - with_vis, + _with_vis, Go2FleetConnection.blueprint(), - WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_go2") .configurators(ClockSyncConfigurator()) diff --git a/dimos/teleop/quest/blueprints.py b/dimos/teleop/quest/blueprints.py index d6367310de..1b67de3b75 100644 --- a/dimos/teleop/quest/blueprints.py +++ b/dimos/teleop/quest/blueprints.py @@ -26,12 +26,12 @@ from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.teleop.quest.quest_extensions import ArmTeleopModule from dimos.teleop.quest.quest_types import Buttons -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.vis_module import vis_module # Arm teleop with press-and-hold engage (has rerun viz) teleop_quest_rerun = autoconnect( ArmTeleopModule.blueprint(), - RerunBridgeModule.blueprint(), + vis_module("rerun"), ).transports( { ("left_controller_output", PoseStamped): LCMTransport("/teleop/left_delta", PoseStamped), diff --git a/dimos/utils/generic.py b/dimos/utils/generic.py index 84168ce057..6aa1859659 100644 --- a/dimos/utils/generic.py +++ b/dimos/utils/generic.py @@ -13,13 +13,52 @@ # limitations under the License. from collections.abc import Callable +import functools import hashlib import json import os +from pathlib import Path +import platform import string +import sys from typing import Any, Generic, TypeVar, overload import uuid + +@functools.lru_cache(maxsize=1) +def is_jetson() -> bool: + """Check if running on an NVIDIA Jetson device.""" + if sys.platform != "linux": + return False + # Check kernel release for Tegra (most lightweight) + if "tegra" in platform.release().lower(): + return True + # Check device tree (works in containers with proper mounts) + try: + return "nvidia,tegra" in Path("/proc/device-tree/compatible").read_text() + except (FileNotFoundError, PermissionError): + pass + # Check for L4T release file + return Path("/etc/nv_tegra_release").exists() + + +def get_local_ips() -> list[tuple[str, str]]: + """Return ``(ip, interface_name)`` for every non-loopback IPv4 address. + + Picks up physical, virtual, and VPN interfaces (including Tailscale). + """ + import socket + + import psutil + + results: list[tuple[str, str]] = [] + for iface, addrs in psutil.net_if_addrs().items(): + for addr in addrs: + if addr.family == socket.AF_INET and not addr.address.startswith("127."): + results.append((addr.address, iface)) + return results + + _T = TypeVar("_T") diff --git a/dimos/visualization/rerun/bridge.py b/dimos/visualization/rerun/bridge.py index de89c5d347..843ae421f4 100644 --- a/dimos/visualization/rerun/bridge.py +++ b/dimos/visualization/rerun/bridge.py @@ -56,6 +56,7 @@ RERUN_GRPC_PORT = 9876 RERUN_WEB_PORT = 9090 + # TODO OUT visual annotations # # In the future it would be nice if modules can annotate their individual OUTs with (general or rerun specific) @@ -133,7 +134,9 @@ def to_rerun(self) -> RerunData: ... def _hex_to_rgba(hex_color: str) -> int: """Convert '#RRGGBB' to a 0xRRGGBBAA int (fully opaque).""" h = hex_color.lstrip("#") - return (int(h, 16) << 8) | 0xFF + if len(h) == 6: + return int(h + "ff", 16) + return int(h[:8], 16) def _with_graph_tab(bp: Blueprint) -> Blueprint: @@ -157,7 +160,7 @@ def _default_blueprint() -> Blueprint: import rerun as rr import rerun.blueprint as rrb - return rrb.Blueprint( + return rrb.Blueprint( # type: ignore[no-any-return] rrb.Spatial3DView( origin="world", background=rrb.Background(kind="SolidColor", color=[0, 0, 0]), @@ -223,10 +226,12 @@ class RerunBridgeModule(Module[Config]): """ default_config = Config + _last_log: dict[str, float] = {} - GV_SCALE = 100.0 # graphviz inches to rerun screen units - MODULE_RADIUS = 30.0 - CHANNEL_RADIUS = 20.0 + # Graphviz layout scale and node radii for blueprint graph + GV_SCALE = 100.0 + MODULE_RADIUS = 20.0 + CHANNEL_RADIUS = 12.0 @lru_cache(maxsize=256) def _visual_override_for_entity_path( @@ -310,13 +315,14 @@ def start(self) -> None: super().start() - self._last_log: dict[str, float] = {} + self._last_log: dict[str, float] = {} # reset on each start logger.info("Rerun bridge starting", viewer_mode=self.config.viewer_mode) # Initialize and spawn Rerun viewer rr.init("dimos") if self.config.viewer_mode == "native": + spawned = False try: import rerun_bindings @@ -325,6 +331,7 @@ def start(self) -> None: executable_name="dimos-viewer", memory_limit=self.config.memory_limit, ) + spawned = True except ImportError: pass # dimos-viewer not installed except Exception: @@ -332,12 +339,31 @@ def start(self) -> None: "dimos-viewer found but failed to spawn, falling back to stock rerun", exc_info=True, ) - rr.spawn(connect=True, memory_limit=self.config.memory_limit) + if not spawned: + try: + rr.spawn(connect=True, memory_limit=self.config.memory_limit) + except (RuntimeError, FileNotFoundError): + logger.warning( + "Rerun native viewer not available (headless?). " + "Bridge will continue without a viewer — data is still " + "accessible via rerun-connect or rerun-web.", + exc_info=True, + ) elif self.config.viewer_mode == "web": server_uri = rr.serve_grpc() rr.serve_web_viewer(connect_to=server_uri, open_browser=False) elif self.config.viewer_mode == "connect": - rr.connect_grpc(self.config.connect_url) + # Serve gRPC so external viewers (dimos-viewer) can connect to us. + # Extract the port from the connect_url to match what viewers expect. + from urllib.parse import urlparse + + parsed = urlparse(self.config.connect_url.replace("rerun+", "", 1)) + grpc_port = parsed.port or RERUN_GRPC_PORT + rr.serve_grpc( + grpc_port=grpc_port, + server_memory_limit=self.config.memory_limit, + ) + logger.info(f"Rerun gRPC server ready at {self.config.connect_url}") # "none" - just init, no viewer (connect externally) if self.config.blueprint: @@ -437,6 +463,7 @@ def log_blueprint_graph(self, dot_code: str, module_names: list[str]) -> None: @rpc def stop(self) -> None: + self._visual_override_for_entity_path.cache_clear() super().stop() diff --git a/dimos/visualization/rerun/test_viewer_ws_e2e.py b/dimos/visualization/rerun/test_viewer_ws_e2e.py new file mode 100644 index 0000000000..ea8351f2f6 --- /dev/null +++ b/dimos/visualization/rerun/test_viewer_ws_e2e.py @@ -0,0 +1,328 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end test: dimos-viewer (headless) → WebSocket → RerunWebSocketServer. + +dimos-viewer is started in ``--connect`` mode so it initialises its WebSocket +client. The viewer needs a gRPC proxy to connect to; we give it a non-existent +one so the viewer starts up anyway but produces no visualisation. The important +part is that the WebSocket client inside the viewer tries to connect to +``ws://127.0.0.1:/ws``. + +Because the viewer is a native GUI application it cannot run headlessly in CI +without a display. This test therefore verifies the connection at the protocol +level by using the ``RerunWebSocketServer`` module directly as the server and +injecting synthetic JSON messages that mimic what the viewer would send once a +user clicks in the 3D viewport. +""" + +import asyncio +import json +import os +import shutil +import subprocess +import threading +import time +from typing import Any + +import pytest + +from dimos.visualization.rerun.websocket_server import RerunWebSocketServer + +_E2E_PORT = 13032 + + +def _make_server(port: int = _E2E_PORT) -> RerunWebSocketServer: + return RerunWebSocketServer(port=port) + + +def _wait_for_server(port: int, timeout: float = 5.0) -> None: + import websockets.asyncio.client as ws_client + + async def _probe() -> None: + async with ws_client.connect(f"ws://127.0.0.1:{port}/ws"): + pass + + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + asyncio.run(_probe()) + return + except Exception: + time.sleep(0.05) + raise TimeoutError(f"Server on port {port} did not become ready within {timeout}s") + + +def _send_messages(port: int, messages: list[dict[str, Any]], *, delay: float = 0.05) -> None: + import websockets.asyncio.client as ws_client + + async def _run() -> None: + async with ws_client.connect(f"ws://127.0.0.1:{port}/ws") as ws: + for msg in messages: + await ws.send(json.dumps(msg)) + await asyncio.sleep(delay) + + asyncio.run(_run()) + + +class TestViewerProtocolE2E: + """Verify the full Python-server side of the viewer ↔ DimOS protocol. + + These tests use the ``RerunWebSocketServer`` as the server and a dummy + WebSocket client (playing the role of dimos-viewer) to inject messages. + They confirm every message type is correctly routed and that only click + messages produce stream publishes. + """ + + def test_viewer_click_reaches_stream(self) -> None: + """A viewer click message received over WebSocket publishes PointStamped.""" + server = _make_server() + server.start() + _wait_for_server(_E2E_PORT) + + received: list[Any] = [] + done = threading.Event() + + def _on_pt(pt: Any) -> None: + received.append(pt) + done.set() + + server.clicked_point.subscribe(_on_pt) + + _send_messages( + _E2E_PORT, + [ + { + "type": "click", + "x": 10.0, + "y": 20.0, + "z": 0.5, + "entity_path": "/world/robot", + "timestamp_ms": 42000, + } + ], + ) + + done.wait(timeout=3.0) + server.stop() + + assert len(received) == 1 + pt = received[0] + assert abs(pt.x - 10.0) < 1e-9 + assert abs(pt.y - 20.0) < 1e-9 + assert abs(pt.z - 0.5) < 1e-9 + assert pt.frame_id == "/world/robot" + assert abs(pt.ts - 42.0) < 1e-6 + + def test_viewer_keyboard_twist_no_publish(self) -> None: + """Twist messages from keyboard control do not publish clicked_point.""" + server = _make_server() + server.start() + _wait_for_server(_E2E_PORT) + + received: list[Any] = [] + server.clicked_point.subscribe(received.append) + + _send_messages( + _E2E_PORT, + [ + { + "type": "twist", + "linear_x": 0.5, + "linear_y": 0.0, + "linear_z": 0.0, + "angular_x": 0.0, + "angular_y": 0.0, + "angular_z": 0.8, + } + ], + ) + + server.stop() + assert received == [] + + def test_viewer_stop_no_publish(self) -> None: + """Stop messages do not publish clicked_point.""" + server = _make_server() + server.start() + _wait_for_server(_E2E_PORT) + + received: list[Any] = [] + server.clicked_point.subscribe(received.append) + + _send_messages(_E2E_PORT, [{"type": "stop"}]) + + server.stop() + assert received == [] + + def test_full_viewer_session_sequence(self) -> None: + """Realistic session: connect, heartbeats, click, WASD, stop → one point.""" + server = _make_server() + server.start() + _wait_for_server(_E2E_PORT) + + received: list[Any] = [] + done = threading.Event() + + def _on_pt(pt: Any) -> None: + received.append(pt) + done.set() + + server.clicked_point.subscribe(_on_pt) + + _send_messages( + _E2E_PORT, + [ + # Initial heartbeats (viewer connects and starts 1 Hz heartbeat) + {"type": "heartbeat", "timestamp_ms": 1000}, + {"type": "heartbeat", "timestamp_ms": 2000}, + # User clicks a point in the 3D viewport + { + "type": "click", + "x": 3.14, + "y": 2.71, + "z": 1.41, + "entity_path": "/world", + "timestamp_ms": 3000, + }, + # User presses W (forward) + { + "type": "twist", + "linear_x": 0.5, + "linear_y": 0.0, + "linear_z": 0.0, + "angular_x": 0.0, + "angular_y": 0.0, + "angular_z": 0.0, + }, + # User releases W + {"type": "stop"}, + # Another heartbeat + {"type": "heartbeat", "timestamp_ms": 4000}, + ], + delay=0.2, + ) + + done.wait(timeout=3.0) + server.stop() + + assert len(received) == 1, f"Expected exactly 1 click, got {len(received)}" + pt = received[0] + assert abs(pt.x - 3.14) < 1e-9 + assert abs(pt.y - 2.71) < 1e-9 + assert abs(pt.z - 1.41) < 1e-9 + + def test_reconnect_after_disconnect(self) -> None: + """Server keeps accepting new connections after a client disconnects.""" + server = _make_server() + server.start() + _wait_for_server(_E2E_PORT) + + received: list[Any] = [] + all_done = threading.Event() + + def _on_pt(pt: Any) -> None: + received.append(pt) + if len(received) >= 2: + all_done.set() + + server.clicked_point.subscribe(_on_pt) + + # First connection — send one click and disconnect + _send_messages( + _E2E_PORT, + [{"type": "click", "x": 1.0, "y": 0.0, "z": 0.0, "entity_path": "", "timestamp_ms": 0}], + ) + + # Second connection (simulating viewer reconnect) — send another click + _send_messages( + _E2E_PORT, + [{"type": "click", "x": 2.0, "y": 0.0, "z": 0.0, "entity_path": "", "timestamp_ms": 0}], + ) + + all_done.wait(timeout=5.0) + server.stop() + + xs = sorted(pt.x for pt in received) + assert xs == [1.0, 2.0], f"Unexpected xs: {xs}" + + +class TestViewerBinaryConnectMode: + """Smoke test: dimos-viewer binary starts in --connect mode and its WebSocket + client attempts to connect to our Python server.""" + + @pytest.mark.skipif( + shutil.which("dimos-viewer") is None + or "--connect" + not in subprocess.run(["dimos-viewer", "--help"], capture_output=True, text=True).stdout, + reason="dimos-viewer binary not installed or does not support --connect", + ) + def test_viewer_ws_client_connects(self) -> None: + """dimos-viewer --connect starts and its WS client connects to our server.""" + server = _make_server() + server.start() + _wait_for_server(_E2E_PORT) + + received: list[Any] = [] + + def _on_pt(pt: Any) -> None: + received.append(pt) + + server.clicked_point.subscribe(_on_pt) + + # Start dimos-viewer in --connect mode, pointing it at a non-existent gRPC + # proxy (it will fail to stream data, but that's fine) and at our WS server. + # Use DISPLAY="" to prevent it from opening a window (it will exit quickly + # without a display, but the WebSocket connection happens before the GUI loop). + proc = subprocess.Popen( + [ + "dimos-viewer", + "--connect", + f"--ws-url=ws://127.0.0.1:{_E2E_PORT}/ws", + ], + env={ + **os.environ, + "DISPLAY": "", + }, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + # Give the viewer up to 5 s to connect its WebSocket client to our server. + # We detect the connection by waiting for the server to accept a client. + deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + # Check if any connection was established by sending a message and + # verifying the viewer is still running. + if proc.poll() is not None: + # Viewer exited (expected without a display) — check if it connected first. + break + time.sleep(0.1) + + proc.terminate() + try: + proc.wait(timeout=3) + except subprocess.TimeoutExpired: + proc.kill() + + stdout = proc.stdout.read().decode(errors="replace") if proc.stdout else "" + stderr = proc.stderr.read().decode(errors="replace") if proc.stderr else "" + server.stop() + + # The viewer should log that it is connecting to our WS URL. + # Check both stdout and stderr since log output destination varies. + combined = stdout + stderr + assert f"ws://127.0.0.1:{_E2E_PORT}" in combined, ( + f"Viewer did not attempt WS connection.\nstdout:\n{stdout}\nstderr:\n{stderr}" + ) diff --git a/dimos/visualization/rerun/test_websocket_server.py b/dimos/visualization/rerun/test_websocket_server.py new file mode 100644 index 0000000000..cec85fbb11 --- /dev/null +++ b/dimos/visualization/rerun/test_websocket_server.py @@ -0,0 +1,469 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for RerunWebSocketServer. + +Uses ``MockViewerPublisher`` to simulate dimos-viewer sending events, matching +the exact JSON protocol used by the Rust ``WsPublisher`` in the viewer. +""" + +import asyncio +import json +import threading +import time +from typing import Any + +from dimos.visualization.rerun.websocket_server import RerunWebSocketServer + +_TEST_PORT = 13031 + + +class MockViewerPublisher: + """Python mirror of the Rust WsPublisher in dimos-viewer. + + Connects to a running ``RerunWebSocketServer`` and exposes the same + ``send_click`` / ``send_twist`` / ``send_stop`` / ``send_heartbeat`` + API that the real viewer uses. Useful for unit tests that need to + exercise the server without a real viewer binary. + + Usage:: + + with MockViewerPublisher("ws://127.0.0.1:13031/ws") as pub: + pub.send_click(1.0, 2.0, 0.0, "/world", timestamp_ms=1000) + pub.send_twist(0.5, 0.0, 0.0, 0.0, 0.0, 0.8) + pub.send_stop() + """ + + def __init__(self, url: str) -> None: + self._url = url + self._ws: Any = None + self._loop: asyncio.AbstractEventLoop | None = None + + def __enter__(self) -> "MockViewerPublisher": + self._loop = asyncio.new_event_loop() + self._ws = self._loop.run_until_complete(self._connect()) + return self + + def __exit__(self, *_: Any) -> None: + if self._ws is not None and self._loop is not None: + self._loop.run_until_complete(self._ws.close()) + if self._loop is not None: + self._loop.close() + + async def _connect(self) -> Any: + import websockets.asyncio.client as ws_client + + return await ws_client.connect(self._url) + + def send_click( + self, + x: float, + y: float, + z: float, + entity_path: str = "", + timestamp_ms: int = 0, + ) -> None: + """Send a click event — matches viewer SelectionChange handler output.""" + self._send( + { + "type": "click", + "x": x, + "y": y, + "z": z, + "entity_path": entity_path, + "timestamp_ms": timestamp_ms, + } + ) + + def send_twist( + self, + linear_x: float, + linear_y: float, + linear_z: float, + angular_x: float, + angular_y: float, + angular_z: float, + ) -> None: + """Send a twist (WASD keyboard) event.""" + self._send( + { + "type": "twist", + "linear_x": linear_x, + "linear_y": linear_y, + "linear_z": linear_z, + "angular_x": angular_x, + "angular_y": angular_y, + "angular_z": angular_z, + } + ) + + def send_stop(self) -> None: + """Send a stop event (Space bar or key release).""" + self._send({"type": "stop"}) + + def send_heartbeat(self, timestamp_ms: int = 0) -> None: + """Send a heartbeat (1 Hz keepalive from viewer).""" + self._send({"type": "heartbeat", "timestamp_ms": timestamp_ms}) + + def flush(self, delay: float = 0.1) -> None: + """Wait briefly so the server processes queued messages.""" + time.sleep(delay) + + def _send(self, msg: dict[str, Any]) -> None: + assert self._loop is not None and self._ws is not None, "Not connected" + self._loop.run_until_complete(self._ws.send(json.dumps(msg))) + + +def _collect(received: list[Any], done: threading.Event) -> Any: + """Return a callback that appends to *received* and signals *done*.""" + + def _cb(msg: Any) -> None: + received.append(msg) + done.set() + + return _cb + + +def _make_module(port: int = _TEST_PORT) -> RerunWebSocketServer: + return RerunWebSocketServer(port=port) + + +def _wait_for_server(port: int, timeout: float = 3.0) -> None: + """Block until the WebSocket server accepts an upgrade handshake.""" + + async def _probe() -> None: + import websockets.asyncio.client as ws_client + + async with ws_client.connect(f"ws://127.0.0.1:{port}/ws"): + pass + + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + asyncio.run(_probe()) + return + except Exception: + time.sleep(0.05) + raise TimeoutError(f"Server on port {port} did not become ready within {timeout}s") + + +class TestRerunWebSocketServerStartup: + def test_server_binds_port(self) -> None: + """After start(), the server must be reachable on the configured port.""" + mod = _make_module() + mod.start() + try: + _wait_for_server(_TEST_PORT) + finally: + mod.stop() + + def test_stop_is_idempotent(self) -> None: + """Calling stop() twice must not raise.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + mod.stop() + mod.stop() + + +class TestClickMessages: + def test_click_publishes_point_stamped(self) -> None: + """A single click publishes one PointStamped with correct coords.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + received: list[Any] = [] + done = threading.Event() + mod.clicked_point.subscribe(_collect(received, done)) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_click(1.5, 2.5, 0.0, "/world", timestamp_ms=1000) + pub.flush() + + done.wait(timeout=2.0) + mod.stop() + + assert len(received) == 1 + pt = received[0] + assert abs(pt.x - 1.5) < 1e-9 + assert abs(pt.y - 2.5) < 1e-9 + assert abs(pt.z - 0.0) < 1e-9 + + def test_click_sets_frame_id_from_entity_path(self) -> None: + """entity_path is stored as frame_id on the published PointStamped.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + received: list[Any] = [] + done = threading.Event() + mod.clicked_point.subscribe(_collect(received, done)) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_click(0.0, 0.0, 0.0, "/robot/base", timestamp_ms=2000) + pub.flush() + + done.wait(timeout=2.0) + mod.stop() + assert received and received[0].frame_id == "/robot/base" + + def test_click_timestamp_converted_from_ms(self) -> None: + """timestamp_ms is converted to seconds on PointStamped.ts.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + received: list[Any] = [] + done = threading.Event() + mod.clicked_point.subscribe(_collect(received, done)) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_click(0.0, 0.0, 0.0, "", timestamp_ms=5000) + pub.flush() + + done.wait(timeout=2.0) + mod.stop() + assert received and abs(received[0].ts - 5.0) < 1e-6 + + def test_multiple_clicks_all_published(self) -> None: + """A burst of clicks all arrive on the stream.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + received: list[Any] = [] + all_arrived = threading.Event() + + def _cb(pt: Any) -> None: + received.append(pt) + if len(received) >= 3: + all_arrived.set() + + mod.clicked_point.subscribe(_cb) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_click(1.0, 0.0, 0.0) + pub.send_click(2.0, 0.0, 0.0) + pub.send_click(3.0, 0.0, 0.0) + pub.flush() + + all_arrived.wait(timeout=3.0) + mod.stop() + + assert sorted(pt.x for pt in received) == [1.0, 2.0, 3.0] + + +class TestNonClickMessages: + def test_heartbeat_does_not_publish(self) -> None: + """Heartbeat messages must not trigger a clicked_point publish.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + clicks: list[Any] = [] + twists: list[Any] = [] + twist_done = threading.Event() + mod.clicked_point.subscribe(clicks.append) + mod.tele_cmd_vel.subscribe(_collect(twists, twist_done)) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_heartbeat(9999) + # Send a canary twist so we know the server processed everything + pub.send_stop() + pub.flush() + + twist_done.wait(timeout=2.0) + mod.stop() + assert clicks == [] + + def test_twist_does_not_publish_clicked_point(self) -> None: + """Twist messages must not trigger a clicked_point publish.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + clicks: list[Any] = [] + twists: list[Any] = [] + twist_done = threading.Event() + mod.clicked_point.subscribe(clicks.append) + mod.tele_cmd_vel.subscribe(_collect(twists, twist_done)) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_twist(0.5, 0.0, 0.0, 0.0, 0.0, 0.8) + pub.flush() + + twist_done.wait(timeout=2.0) + mod.stop() + assert clicks == [] + + def test_stop_does_not_publish_clicked_point(self) -> None: + """Stop messages must not trigger a clicked_point publish.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + clicks: list[Any] = [] + twists: list[Any] = [] + twist_done = threading.Event() + mod.clicked_point.subscribe(clicks.append) + mod.tele_cmd_vel.subscribe(_collect(twists, twist_done)) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_stop() + pub.flush() + + twist_done.wait(timeout=2.0) + mod.stop() + assert clicks == [] + + def test_twist_publishes_on_tele_cmd_vel(self) -> None: + """Twist messages publish a Twist on the tele_cmd_vel stream.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + received: list[Any] = [] + done = threading.Event() + mod.tele_cmd_vel.subscribe(_collect(received, done)) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_twist(0.5, 0.0, 0.0, 0.0, 0.0, 0.8) + pub.flush() + + done.wait(timeout=2.0) + mod.stop() + + assert len(received) == 1 + tw = received[0] + assert abs(tw.linear.x - 0.5) < 1e-9 + assert abs(tw.angular.z - 0.8) < 1e-9 + + def test_stop_publishes_zero_twist_on_tele_cmd_vel(self) -> None: + """Stop messages publish a zero Twist on the tele_cmd_vel stream.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + received: list[Any] = [] + done = threading.Event() + mod.tele_cmd_vel.subscribe(_collect(received, done)) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_stop() + pub.flush() + + done.wait(timeout=2.0) + mod.stop() + + assert len(received) == 1 + tw = received[0] + assert tw.is_zero() + + def test_twist_publishes_stop_explore_cmd_on_first_twist(self) -> None: + """First twist publishes Bool(data=True) on stop_explore_cmd; stop resets.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + explore_cmds: list[Any] = [] + twists: list[Any] = [] + first_done = threading.Event() + mod.stop_explore_cmd.subscribe(_collect(explore_cmds, first_done)) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_twist(0.5, 0.0, 0.0, 0.0, 0.0, 0.0) + pub.flush() + + first_done.wait(timeout=2.0) + assert len(explore_cmds) == 1 + assert explore_cmds[0].data is True + + # Second twist within same connection should NOT publish another stop_explore_cmd + twist_done = threading.Event() + mod.tele_cmd_vel.subscribe(_collect(twists, twist_done)) + + pub.send_twist(0.3, 0.0, 0.0, 0.0, 0.0, 0.0) + pub.flush() + + twist_done.wait(timeout=2.0) + assert len(explore_cmds) == 1 # still just the first one + + # After stop + new twist within same connection, stop_explore_cmd should fire again + second_done = threading.Event() + + def _on_second(msg: Any) -> None: + explore_cmds.append(msg) + if len(explore_cmds) >= 2: + second_done.set() + + mod.stop_explore_cmd.subscribe(_on_second) + + pub.send_stop() + pub.send_twist(0.1, 0.0, 0.0, 0.0, 0.0, 0.0) + pub.flush() + + second_done.wait(timeout=2.0) + + mod.stop() + assert len(explore_cmds) >= 2 + + def test_invalid_json_does_not_crash(self) -> None: + """Malformed JSON is silently dropped; server stays alive.""" + import websockets.asyncio.client as ws_client + + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + async def _send_bad() -> None: + async with ws_client.connect(f"ws://127.0.0.1:{_TEST_PORT}/ws") as ws: + await ws.send("this is not json {{") + await asyncio.sleep(0.1) + await ws.send(json.dumps({"type": "heartbeat", "timestamp_ms": 0})) + await asyncio.sleep(0.1) + + asyncio.run(_send_bad()) + mod.stop() + + def test_mixed_message_sequence(self) -> None: + """Realistic sequence: heartbeat → click → twist → stop publishes one point.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + # Subscribe before sending so we don't race against the click dispatch. + received: list[Any] = [] + done = threading.Event() + + def _cb(pt: Any) -> None: + received.append(pt) + done.set() + + mod.clicked_point.subscribe(_cb) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_heartbeat(1000) + pub.send_click(7.0, 8.0, 9.0, "/map", timestamp_ms=1100) + pub.send_twist(0.3, 0.0, 0.0, 0.0, 0.0, 0.2) + pub.send_stop() + pub.flush() + + done.wait(timeout=2.0) + mod.stop() + + assert len(received) == 1 + assert abs(received[0].x - 7.0) < 1e-9 + assert abs(received[0].y - 8.0) < 1e-9 + assert abs(received[0].z - 9.0) < 1e-9 diff --git a/dimos/visualization/rerun/websocket_server.py b/dimos/visualization/rerun/websocket_server.py new file mode 100644 index 0000000000..9275018e32 --- /dev/null +++ b/dimos/visualization/rerun/websocket_server.py @@ -0,0 +1,231 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""WebSocket server module that receives events from dimos-viewer. + +When dimos-viewer is started with ``--connect``, LCM multicast is unavailable +across machines. The viewer falls back to sending click, twist, and stop events +as JSON over a WebSocket connection. This module acts as the server-side +counterpart: it listens for those connections and translates incoming messages +into DimOS stream publishes. + +Message format (newline-delimited JSON, ``"type"`` discriminant): + + {"type":"heartbeat","timestamp_ms":1234567890} + {"type":"click","x":1.0,"y":2.0,"z":3.0,"entity_path":"/world","timestamp_ms":1234567890} + {"type":"twist","linear_x":0.5,"linear_y":0.0,"linear_z":0.0, + "angular_x":0.0,"angular_y":0.0,"angular_z":0.8} + {"type":"stop"} +""" + +import asyncio +import json +import threading +from typing import Any + +from dimos_lcm.std_msgs import Bool # type: ignore[import-untyped] +import websockets + +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import Out +from dimos.msgs.geometry_msgs.PointStamped import PointStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class Config(ModuleConfig): + # Intentionally binds 0.0.0.0 by default so the viewer can connect from + # any machine on the network (the typical robot deployment scenario). + host: str = "0.0.0.0" + port: int = 3030 + start_timeout: float = 10.0 # seconds to wait for the server to bind + + +class RerunWebSocketServer(Module[Config]): + """Receives dimos-viewer WebSocket events and publishes them as DimOS streams. + + The viewer connects to this module (not the other way around) when running + in ``--connect`` mode. Each click event is converted to a ``PointStamped`` + and published on the ``clicked_point`` stream so downstream modules (e.g. + ``ReplanningAStarPlanner``) can consume it without modification. + + Outputs: + clicked_point: 3-D world-space point from the most recent viewer click. + tele_cmd_vel: Twist velocity commands from keyboard teleop, including stop events. + """ + + default_config = Config + + clicked_point: Out[PointStamped] + tele_cmd_vel: Out[Twist] + stop_explore_cmd: Out[Bool] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._teleop_clients: set[int] = set() # ids of clients currently in teleop + self._ws_loop: asyncio.AbstractEventLoop | None = None + self._server_thread: threading.Thread | None = None + self._stop_event: asyncio.Event | None = None + self._server_ready = threading.Event() + + @rpc + def start(self) -> None: + super().start() + self._server_thread = threading.Thread( + target=self._run_server, daemon=True, name="rerun-ws-server" + ) + self._server_thread.start() + self._server_ready.wait(timeout=self.config.start_timeout) + self._log_connect_hints() + + @rpc + def stop(self) -> None: + # Wait briefly for the server thread to initialise _stop_event so we + # don't silently skip the shutdown signal (race with _serve()). + self._server_ready.wait(timeout=self.config.start_timeout) + if ( + self._ws_loop is not None + and not self._ws_loop.is_closed() + and self._stop_event is not None + ): + self._ws_loop.call_soon_threadsafe(self._stop_event.set) + super().stop() + + def _log_connect_hints(self) -> None: + """Log the WebSocket URL(s) that viewers should connect to.""" + import socket + + from dimos.utils.generic import get_local_ips + + local_ips = get_local_ips() + hostname = socket.gethostname() + ws_url = f"ws://127.0.0.1:{self.config.port}/ws" + + lines = [ + "", + "=" * 60, + f"RerunWebSocketServer listening on {ws_url}", + "", + ] + if local_ips: + lines.append("From another machine on the network:") + for ip, iface in local_ips: + lines.append(f" ws://{ip}:{self.config.port}/ws # {iface}") + lines.append("") + lines.append(f" hostname: {hostname}") + lines.append("=" * 60) + lines.append("") + + logger.info("\n".join(lines)) + + def _run_server(self) -> None: + """Entry point for the background server thread.""" + self._ws_loop = asyncio.new_event_loop() + try: + self._ws_loop.run_until_complete(self._serve()) + except Exception: + logger.error("RerunWebSocketServer failed to start", exc_info=True) + finally: + self._server_ready.set() # unblock stop() even on failure + self._ws_loop.close() + + async def _serve(self) -> None: + import websockets.asyncio.server as ws_server + + self._stop_event = asyncio.Event() + + async with ws_server.serve( + self._handle_client, + host=self.config.host, + port=self.config.port, + # Ping every 30 s, allow 30 s for pong — generous enough to + # survive brief network hiccups while still detecting dead clients. + ping_interval=30, + ping_timeout=30, + ): + self._server_ready.set() + await self._stop_event.wait() + + async def _handle_client(self, websocket: Any) -> None: + if hasattr(websocket, "request") and websocket.request.path != "/ws": + await websocket.close(1008, "Not Found") + return + addr = websocket.remote_address + client_id = id(websocket) + logger.info(f"RerunWebSocketServer: viewer connected from {addr}") + try: + async for raw in websocket: + self._dispatch(raw, client_id) + except websockets.ConnectionClosed as exc: + logger.debug(f"RerunWebSocketServer: client {addr} disconnected ({exc})") + finally: + self._teleop_clients.discard(client_id) + + def _dispatch(self, raw: str | bytes, client_id: int) -> None: + try: + msg = json.loads(raw) + except json.JSONDecodeError: + logger.warning(f"RerunWebSocketServer: ignoring non-JSON message: {raw!r}") + return + + if not isinstance(msg, dict): + logger.warning(f"RerunWebSocketServer: expected JSON object, got {type(msg).__name__}") + return + + msg_type = msg.get("type") + + if msg_type == "click": + pt = PointStamped( + x=float(msg.get("x", 0)), + y=float(msg.get("y", 0)), + z=float(msg.get("z", 0)), + ts=float(msg.get("timestamp_ms", 0)) / 1000.0, + frame_id=str(msg.get("entity_path", "")), + ) + logger.debug(f"RerunWebSocketServer: click → {pt}") + self.clicked_point.publish(pt) + + elif msg_type == "twist": + twist = Twist( + linear=Vector3( + float(msg.get("linear_x", 0)), + float(msg.get("linear_y", 0)), + float(msg.get("linear_z", 0)), + ), + angular=Vector3( + float(msg.get("angular_x", 0)), + float(msg.get("angular_y", 0)), + float(msg.get("angular_z", 0)), + ), + ) + logger.debug(f"RerunWebSocketServer: twist → {twist}") + if not self._teleop_clients: + self.stop_explore_cmd.publish(Bool(data=True)) + self._teleop_clients.add(client_id) + self.tele_cmd_vel.publish(twist) + + elif msg_type == "stop": + logger.debug("RerunWebSocketServer: stop") + self._teleop_clients.discard(client_id) + self.tele_cmd_vel.publish(Twist.zero()) + + elif msg_type == "heartbeat": + logger.debug(f"RerunWebSocketServer: heartbeat ts={msg.get('timestamp_ms')}") + + else: + logger.warning(f"RerunWebSocketServer: unknown message type {msg_type!r}") diff --git a/dimos/visualization/vis_module.py b/dimos/visualization/vis_module.py new file mode 100644 index 0000000000..400a912ce4 --- /dev/null +++ b/dimos/visualization/vis_module.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared visualization module factory for all robot blueprints.""" + +from typing import Any + +from dimos.core.blueprints import Blueprint, autoconnect +from dimos.core.global_config import ViewerBackend + + +def vis_module( + viewer_backend: ViewerBackend, + rerun_config: dict[str, Any] | None = None, + foxglove_config: dict[str, Any] | None = None, +) -> Blueprint: + """Create a visualization blueprint based on the selected viewer backend. + + Bundles the appropriate viewer module (Rerun or Foxglove) together with + the ``WebsocketVisModule`` and ``RerunWebSocketServer`` so that the web + dashboard and remote viewer connections work out of the box. + + Example usage:: + + from dimos.core.global_config import global_config + viz = vis_module( + global_config.viewer, + rerun_config={ + "visual_override": { + "world/camera_info": lambda ci: ci.to_rerun(...), + }, + "static": { + "world/tf/base_link": lambda rr: [rr.Boxes3D(...)], + }, + }, + ) + """ + from dimos.visualization.rerun.websocket_server import RerunWebSocketServer + from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule + + if foxglove_config is None: + foxglove_config = {} + if rerun_config is None: + rerun_config = {} + + match viewer_backend: + case "foxglove": + from dimos.robot.foxglove_bridge import FoxgloveBridge + + return autoconnect( + FoxgloveBridge.blueprint(**foxglove_config), + RerunWebSocketServer.blueprint(), + WebsocketVisModule.blueprint(), + ) + case "rerun" | "rerun-web" | "rerun-connect": + from dimos.protocol.pubsub.impl.lcmpubsub import LCM + from dimos.visualization.rerun.bridge import _BACKEND_TO_MODE, RerunBridgeModule + + rerun_config = {**rerun_config} + rerun_config.setdefault("pubsubs", [LCM()]) + viewer_mode = _BACKEND_TO_MODE.get(viewer_backend, "native") + return autoconnect( + RerunBridgeModule.blueprint(viewer_mode=viewer_mode, **rerun_config), + RerunWebSocketServer.blueprint(), + WebsocketVisModule.blueprint(), + ) + case "none": + return autoconnect(WebsocketVisModule.blueprint()) + case _: + raise ValueError( + f"Unknown viewer_backend {viewer_backend!r}. " + f"Expected one of: rerun, rerun-web, rerun-connect, foxglove, none" + )