diff --git a/.gitignore b/.gitignore index 8ed8bbc..e99c1c1 100644 --- a/.gitignore +++ b/.gitignore @@ -8,5 +8,4 @@ camera_samples build dist .strands_robots -.coverage -.kiro \ No newline at end of file +.coverage \ No newline at end of file diff --git a/README.md b/README.md index ecfab2f..3f7e904 100644 --- a/README.md +++ b/README.md @@ -162,6 +162,23 @@ trainer = create_trainer("groot", trainer.train() ``` +### Networking: Device Connect + +Strands Robots uses [Device Connect](https://github.com/arm/device-connect) by Arm as its primary networking layer — registry-based discovery, structured RPC schemas, device-to-device events, and policy enforcement. Every `Robot()` automatically registers as a Device Connect device when `device-connect-edge` is installed (zero configuration in D2D mode). + +If `device-connect-edge` is not installed, robots fall back to a built-in Zenoh P2P mesh for basic peer discovery and coordination. + +```python +from strands_robots.tools.robot_mesh import robot_mesh + +robot_mesh(action="peers") # discover devices +robot_mesh(action="tell", target="so100-lab-1", # invoke + instruction="pick up the cube") +robot_mesh(action="emergency_stop") # e-stop all +``` + +See the [Device Connect guide](strands_robots/device_connect/GUIDE.md) for architecture, E2E demos, and production deployment. + ## Simulation Three backends. Same `Robot()` interface. diff --git a/pyproject.toml b/pyproject.toml index 3ef1792..7817f6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,8 @@ dependencies = [ "Pillow>=8.0.0", "msgpack>=1.0.0", "pyzmq>=27.0.0", + "device-connect-edge>=0.2.0", + "device-connect-agent-tools>=0.1.0", ] [project.optional-dependencies] @@ -66,14 +68,11 @@ isaac = [ "usd-core>=24.0", # OpenUSD (pxr) for MJCF→USD asset conversion ] newton = [ - "newton>=1.0.0", - "warp-lang>=1.11.0", - "mujoco-warp", - "mujoco>=3.5.0", + "warp-lang>=1.9.0", + "newton-sim>=0.3.0", "trimesh", - "scipy", + "mujoco>=3.5.0", "numpy>=1.24.0", - # Requires NVIDIA GPU + CUDA ] cosmos-transfer = [ "torch>=2.1.0", @@ -96,6 +95,10 @@ cosmos-predict = [ zenoh = [ "eclipse-zenoh>=1.0.0", ] +device-connect = [ + "device-connect-edge", + "device-connect-agent-tools", +] vla = [ "transformers>=5.0.0", "sentencepiece>=0.1.99", @@ -119,6 +122,7 @@ all = [ "strands-robots[cosmos-transfer]", "strands-robots[cosmos-predict]", "strands-robots[zenoh]", + "strands-robots[device-connect]", ] dev = [ "pytest>=6.0", diff --git a/strands_robots/__init__.py b/strands_robots/__init__.py index cafa089..ee6139b 100644 --- a/strands_robots/__init__.py +++ b/strands_robots/__init__.py @@ -366,6 +366,27 @@ except (ImportError, AttributeError, OSError): pass +# Device Connect integration (optional — wraps robots as DC devices) +try: + from strands_robots.device_connect import ( + ReachyMiniDriver, + RobotDeviceDriver, + SimulationDeviceDriver, + init_device_connect, + ) + + __all__.extend( + ["init_device_connect", "RobotDeviceDriver", "SimulationDeviceDriver", "ReachyMiniDriver"] + ) + try: + from device_connect_agent_tools.adapters.strands import discover_devices, invoke_device + + __all__.extend(["discover_devices", "invoke_device"]) + except ImportError: + pass +except (ImportError, AttributeError, OSError): + pass + # Zenoh Robot Mesh (peer-to-peer — every Robot is a peer by default) try: from strands_robots.tools.robot_mesh import robot_mesh diff --git a/strands_robots/device_connect/GUIDE.md b/strands_robots/device_connect/GUIDE.md new file mode 100644 index 0000000..e92233b --- /dev/null +++ b/strands_robots/device_connect/GUIDE.md @@ -0,0 +1,392 @@ +# Device Connect Integration + +Strands Robots uses [Device Connect](https://github.com/arm/device-connect), a **device-aware runtime** by Arm — to handle discovery, presence, structured RPC, event routing, and safety — so you can focus on building cross-device experiences instead of re-implementing infrastructure. + +> **Fallback behavior:** If `device-connect-edge` is not installed, Strands Robots automatically falls back to a built-in Zenoh P2P mesh (`zenoh_mesh.py`) for basic peer discovery and coordination. Device Connect is the recommended and primary networking layer. + +### Quick Start + +```python +from strands_robots import Robot + +r = Robot("so100") +r.run() # starts listening for commands. Ctrl+C to stop. +``` + +`Robot()` creates the robot. `.run()` starts Device Connect with D2D defaults (Zenoh multicast scouting, no broker, no env vars) and blocks — the robot becomes discoverable on the LAN and listens for commands. Without `.run()`, the script exits and the robot is removed from the network. + +You can optionally pass `peer_id="so100-lab-1"` for a stable address; otherwise one is auto-generated (e.g. `so100-a3f1b2`). + +**Robot lifecycle:** + +| Pattern | Behavior | +|---|---| +| `r = Robot("so100"); r.run()` | **Option A — Foreground server.** Process stays alive, listens for commands. Ctrl+C to stop. | +| `r = Robot("so100")` | **Option B — Agent-controlled.** A Strands Agent discovers the robot via `robot_mesh` or `discover_devices()` and invokes commands remotely. | + +From another process, discover and invoke: + +```python +from strands_robots.tools.robot_mesh import robot_mesh + +robot_mesh(action="peers") # discover devices +robot_mesh(action="tell", target="so100-lab-1", # invoke + instruction="pick up the cube") +robot_mesh(action="emergency_stop") # e-stop all +``` + +### Architecture + +```mermaid +graph TD + subgraph "Device Connect Infrastructure" + ZENOH_R["Zenoh Router"] + ETCD["etcd (Registry)"] + REG["Registry Service"] + end + + subgraph "Robot Process" + ROBOT["Robot('so100')"] + ADAPTER["RobotDeviceDriver"] + RUNTIME["DeviceRuntime"] + ROBOT --> ADAPTER + ADAPTER --> RUNTIME + RUNTIME --> ZENOH_R + end + + subgraph "Agent Process" + AGENT["Strands Agent"] + TOOLS["discover_devices + invoke_device"] + AGENT --> TOOLS + TOOLS --> ZENOH_R + end + + ZENOH_R --> REG + REG --> ETCD +``` + +### E2E Demo + +No Docker needed. No env vars. Devices discover each other directly on the LAN via Zenoh multicast scouting. `Robot()` and `robot_mesh()` auto-configure D2D mode when no broker URL is set. + +#### Setup + +##### 1. Install + +> `setup.sh` installs `uv`, Python 3.12, creates a venv, and installs all dependencies. + +```bash +git clone --branch feat/device-connect-integration-draft https://github.com/atsyplikhin/robots.git +cd robots +./strands_robots/device_connect/setup.sh +source .venv/bin/activate +``` + +##### 2. Start a robot + +```bash +screen -S robot # start a persistent session +python -c " +from strands_robots import Robot +r = Robot('so100') +r.run() +" +# once online, Ctrl+a then d to detach +# screen -ls to list sessions +# screen -r robot to reattach +``` + +Expected output: + +``` +device_connect_edge.device. - INFO - Using ZENOH messaging backend +device_connect_edge.device. - INFO - Connected to ZENOH broker: [] +device_connect_edge.device. - INFO - Driver connected: strands_sim +device_connect_edge.device. - INFO - Subscribed to commands on device-connect.default..cmd +🤖 is online. Ctrl+C to stop. +``` + +> `` is auto-generated (e.g. `so100-a3f1b2`) unless you pass a fixed peer ID: +> ```python +> r = Robot('so100', peer_id='so100-lab-1') +> ``` + +#### Option A: Using the `robot_mesh` Strands tool + +The `robot_mesh` tool auto-detects Device Connect and uses it when available, falling back to the plain Zenoh mesh otherwise. + +**Discover peers:** + +```python +python -c " +from strands_robots.tools.robot_mesh import robot_mesh +print(robot_mesh(action='peers')) +" +``` + +Expected output: + +``` +Discovered 1 device(s): + [sim] — idle + Functions: execute, getFeatures, getStatus, reset, step, stop +``` + +**Tell a robot to execute an instruction:** + +Use the `` from the discover step above as the `target`: + +```python +python -c " +from strands_robots.tools.robot_mesh import robot_mesh +print(robot_mesh(action='tell', target='', + instruction='pick up the cube', policy_provider='mock')) +" +``` + +Expected output: + +``` +-> : pick up the cube + {"status": "success", "content": [...]} +``` + +**Emergency stop all devices:** + +```python +python -c " +from strands_robots.tools.robot_mesh import robot_mesh +print(robot_mesh(action='emergency_stop')) +" +``` + +Expected output: + +``` +E-STOP: 1/1 devices stopped +``` + +#### Option B: Discover and invoke with `device-connect-agent-tools` directly + +```python +python -c " +from device_connect_agent_tools import connect, discover_devices, invoke_device + +connect() + +devices = discover_devices() +print(f'Found {len(devices)} robot(s):') +for d in devices: + print(f' {d[\"device_id\"]} — {d.get(\"status\", {}).get(\"availability\", \"?\")}') + +if devices: + result = invoke_device( + devices[0]['device_id'], 'execute', + {'instruction': 'pick up the cube', 'policy_provider': 'mock'}, + ) + print(f'Execute result: {result}') + + status = invoke_device(devices[0]['device_id'], 'getStatus') + print(f'Status: {status}') +" +``` + +Expected output: + +``` +Found 1 robot(s): + — idle +Execute result: {'success': True, 'result': {'status': 'success', 'content': [...]}} +Status: {'success': True, 'result': {...}} # full sim state dict +``` + +#### Full Infrastructure (Optional) + +For production deployments, you can add Docker infrastructure for persistent registry, distributed state, cross-network routing, and authentication. + +Start the Device Connect infrastructure (Zenoh router + etcd + device registry): + +```bash +git clone --depth 1 https://github.com/arm/device-connect.git +cd device-connect/packages/device-connect-server +docker compose -f infra/docker-compose-dev.yml up -d +cd ../../.. +``` + +This starts: + +| Service | Port | Purpose | +|---|---|---| +| Zenoh router | `:7447` | Messaging (RPC, events, heartbeats) | +| etcd | `:2379` | Device registry storage | +| Device registry | `:8080` | REST API for device metadata | + +Set environment variables (all terminals): + +```bash +export MESSAGING_BACKEND=zenoh +export ZENOH_CONNECT=tcp/localhost:7447 +export DEVICE_CONNECT_ALLOW_INSECURE=true +``` + +All the options above (A–B) work identically with full infrastructure — the only difference is that devices register in etcd and discovery goes through the registry service instead of multicast scouting. + +> **What infrastructure adds over D2D:** +> - **Persistent device registry** — devices register with TTL-based leases; stale devices are auto-cleaned. Agents can discover devices by type, location, or capability via `discover_devices()`. +> - **Distributed state & locks** — etcd-backed key-value store with atomic distributed locks for coordinating shared resources (e.g., preventing two agents from using the same robotic arm simultaneously). +> - **Cross-network routing** — the Zenoh router (or NATS broker) enables communication across subnets and sites, not just the local LAN. +> - **Authentication & authorization** — mTLS ensures only devices with certificates signed by the trusted CA can exchange data. Full authorization (per-device permissions, topic-level ACLs, certificate revocation) requires the router/registry infrastructure. + +#### Running the Tests + +```bash +pip install pytest pytest-cov # if not already installed + +# Unit tests (no Docker needed) +python3 -m pytest tests/test_device_connect_drivers.py -v + +# Integration tests (requires Docker infrastructure) +MESSAGING_BACKEND=zenoh ZENOH_CONNECT=tcp/localhost:7447 \ + DEVICE_CONNECT_ALLOW_INSECURE=true \ + python3 -m pytest tests/test_device_connect_integration.py -v +``` + +#### Control Loop Smoke Test + +A self-contained script runs a 200-step mock-policy control loop while a Zenoh listener captures Device Connect events (stateUpdate, observationUpdate, presence, heartbeat) and asserts minimum thresholds: + +```bash +bash strands_robots/device_connect/test_control_loop_dc.sh +``` + +It installs dependencies, starts a Zenoh event listener, runs `Robot("so100")` with a mock policy for 200 steps, then validates that the expected events were published over Device Connect. + +--- + +## Reachy Mini (Zenoh-Native Devices) + +Reachy Mini has built-in Zenoh support — it publishes joint positions, head pose, and IMU data natively over Zenoh topics. This makes it a special case: it can be bridged directly via `subscribe()` or wrapped as a Device Connect device for structured RPC. + +### Bridging via Subscribe + +Use the mesh's `subscribe()` to read Reachy's native Zenoh topics directly: + +```python +sim = Robot("so100") + +# Subscribe to Reachy's head pose +sim.mesh.subscribe("reachy_mini/head_pose", + lambda topic, data: print(f"Reachy looking at: {data}")) + +# Subscribe to Reachy's joint positions +sim.mesh.subscribe("reachy_mini/joint_positions", name="reachy_joints") + +# Mirror Reachy's movements in simulation +def mirror_reachy(topic, data): + joints = data.get("head_joint_positions", []) + if joints: + # Map Reachy joints to sim joints... + pass + +sim.mesh.subscribe("reachy_mini/joint_positions", mirror_reachy) +``` + +### Architecture + +```mermaid +graph TD + subgraph "Reachy Mini Process" + REACHY["ReachyMiniDriver"] + RRUNTIME["DeviceRuntime"] + ZENOH_HW["Zenoh → Reachy HW"] + REACHY --> RRUNTIME + REACHY --> ZENOH_HW + end + + subgraph "Network" + ZENOH["Zenoh Mesh
(multicast or router)"] + end + + subgraph "Agent Process" + AGENT["Strands Agent"] + TOOLS["discover_devices + invoke_device"] + AGENT --> TOOLS + TOOLS --> ZENOH + end + + RRUNTIME --> ZENOH +``` + +### As a Device Connect Device + +Wrap Reachy Mini with `ReachyMiniDriver` to expose it as a structured Device Connect device with RPC commands (`look`, `nod`, etc.): + +```python +from strands_robots.device_connect import ReachyMiniDriver +from device_connect_edge import DeviceRuntime + +driver = ReachyMiniDriver(host="reachy-mini.local") +runtime = DeviceRuntime( + driver=driver, + device_id="reachy-mini-1", + messaging_urls=["tcp/localhost:7447"], + allow_insecure=True, +) +await runtime.run() + +# Now any agent can discover and control it: +invoke_device("reachy-mini-1", "look", {"pitch": -15, "yaw": 30}) +invoke_device("reachy-mini-1", "nod") +``` + +### E2E Demo + +> Requires a Reachy Mini robot. + +**Setup depends on your hardware variant:** + +| Variant | Connection | Setup | +|---|---|---| +| **Reachy Mini** (wireless) | Wi-Fi, onboard Pi | `host='reachy-mini.local'` — no extra install needed | +| **Reachy Mini Lite** (USB) | USB, no Pi | `pip install reachy-mini` then run `reachy-mini` daemon locally. Use `host='localhost'` | + +**Start the Reachy Mini driver:** + +```python +python -c " +import asyncio +from strands_robots.device_connect import ReachyMiniDriver +from device_connect_edge import DeviceRuntime + +# For Lite (USB): host='localhost' (requires reachy-mini daemon running) +# For Wireless: host='reachy-mini.local' +driver = ReachyMiniDriver(host='reachy-mini.local') +runtime = DeviceRuntime( + driver=driver, + device_id='reachy-mini-1', + messaging_urls=['tcp/localhost:7447'], + allow_insecure=True, +) + +asyncio.run(runtime.run()) +" +``` + +Expected output: + +``` +Reachy Mini driver connected: reachy-mini.local +device_connect_edge.device.reachy-mini-1 - INFO - Device registered +device_connect_edge.device.reachy-mini-1 - INFO - Subscribed to commands on device-connect.default.reachy-mini-1.cmd +``` + +**In another terminal, invoke RPCs:** + +```python +python -c " +from device_connect_agent_tools import connect, invoke_device +connect() +print(invoke_device('reachy-mini-1', 'look', {'pitch': -15, 'yaw': 30})) +print(invoke_device('reachy-mini-1', 'nod')) +" +``` diff --git a/strands_robots/device_connect/__init__.py b/strands_robots/device_connect/__init__.py new file mode 100644 index 0000000..e9f3200 --- /dev/null +++ b/strands_robots/device_connect/__init__.py @@ -0,0 +1,194 @@ +"""Device Connect integration for strands-robots. + +Provides DeviceDriver adapters that wrap Robot and Simulation instances, +exposing them to Device Connect's device registry, RPC routing, and event system. + +Usage: + from strands_robots.device_connect import init_device_connect + + robot = Robot("so100") + runtime = await init_device_connect(robot, peer_id="so100-lab-1") + + # Now discoverable via Device Connect tools: + # discover_devices(device_type="strands_robot") + # invoke_device("so100-lab-1", "execute", {"instruction": "pick up cube"}) +""" + +import asyncio +import logging +import os +import threading +import uuid +from typing import Optional + +from device_connect_edge import DeviceRuntime + +from strands_robots.device_connect.reachy_mini_driver import ReachyMiniDriver +from strands_robots.device_connect.robot_driver import RobotDeviceDriver +from strands_robots.device_connect.sim_driver import SimulationDeviceDriver + +logger = logging.getLogger(__name__) + +__all__ = [ + "init_device_connect", + "init_device_connect_sync", + "RobotDeviceDriver", + "SimulationDeviceDriver", + "ReachyMiniDriver", +] + + +async def init_device_connect( + robot, + peer_id: Optional[str] = None, + peer_type: str = "robot", + messaging_url: Optional[str] = None, + messaging_backend: Optional[str] = None, + tenant: str = "default", + allow_insecure: Optional[bool] = None, +) -> DeviceRuntime: + """Initialize Device Connect for a Robot or Simulation. + + Drop-in replacement for init_mesh(). Creates a DeviceDriver adapter + and starts a DeviceRuntime in the background. + + When messaging_backend="zenoh" and messaging_url is None, the runtime + enters D2D mode — devices discover each other directly via Zenoh + multicast scouting on the LAN. No broker, no Docker, no env vars. + + Args: + robot: A Robot or Simulation instance to wrap. + peer_id: Device ID for registration (auto-generated if None). + peer_type: "robot" or "sim" — selects the appropriate driver. + messaging_url: Explicit messaging URL (overrides env vars). + messaging_backend: Messaging backend — "zenoh" or "nats". + None = auto-detect from MESSAGING_BACKEND env var (default "zenoh"). + tenant: Device Connect tenant namespace. + allow_insecure: Allow insecure connections. None = auto-detect: + respects DEVICE_CONNECT_ALLOW_INSECURE env var if set, + otherwise defaults to True in D2D mode (no broker URL). + + Returns: + The running DeviceRuntime instance. + """ + if peer_type == "sim": + driver = SimulationDeviceDriver(robot) + else: + driver = RobotDeviceDriver(robot) + + device_id = peer_id or f"{getattr(robot, 'tool_name_str', 'robot')}-{uuid.uuid4().hex[:4]}" + + urls = [messaging_url] if messaging_url else None + + # Resolve messaging_backend: explicit arg > env var > default "zenoh" + if messaging_backend is None: + messaging_backend = os.environ.get("MESSAGING_BACKEND", "zenoh") + + # Resolve allow_insecure: env var > explicit arg > D2D default + if allow_insecure is None: + env_val = os.environ.get("DEVICE_CONNECT_ALLOW_INSECURE") + if env_val is not None: + allow_insecure = env_val.lower() in ("true", "1", "yes") + elif urls is None: + # D2D mode — no broker, default insecure for dev convenience + allow_insecure = True + + runtime = DeviceRuntime( + driver=driver, + device_id=device_id, + messaging_urls=urls, + messaging_backend=messaging_backend, + tenant=tenant, + allow_insecure=allow_insecure, + ) + + # Provide robot-specific heartbeat data + runtime.set_heartbeat_provider(lambda: _build_heartbeat(robot, peer_type)) + + # Start runtime in background task; store ref to prevent GC + runtime._background_task = asyncio.create_task(runtime.run()) + + logger.info("Device Connect initialized: %s (%s, backend=%s, d2d=%s)", + device_id, peer_type, messaging_backend, urls is None) + return runtime + + +def init_device_connect_sync( + robot, + peer_id: Optional[str] = None, + peer_type: str = "robot", + messaging_url: Optional[str] = None, + messaging_backend: Optional[str] = None, + tenant: str = "default", + allow_insecure: Optional[bool] = None, +) -> "DeviceRuntime": + """Non-blocking sync wrapper around init_device_connect(). + + Starts the DeviceRuntime on a dedicated daemon thread so the caller + returns immediately — matching the Zenoh mesh ``init_mesh()`` pattern. + The runtime stays alive as long as the process (daemon thread). + + Same parameters as :func:`init_device_connect`. + """ + loop = asyncio.new_event_loop() + ready = threading.Event() + runtime_holder = [None] + error_holder = [None] + + async def _start(): + try: + rt = await init_device_connect( + robot, + peer_id=peer_id, + peer_type=peer_type, + messaging_url=messaging_url, + messaging_backend=messaging_backend, + tenant=tenant, + allow_insecure=allow_insecure, + ) + runtime_holder[0] = rt + except Exception as exc: + error_holder[0] = exc + finally: + ready.set() + + def _run(): + asyncio.set_event_loop(loop) + loop.run_until_complete(_start()) + loop.run_forever() + + thread = threading.Thread(target=_run, daemon=True, name="device-connect-runtime") + thread.start() + ready.wait(timeout=30.0) + + if error_holder[0] is not None: + raise error_holder[0] + + runtime = runtime_holder[0] + if runtime is not None: + runtime._loop = loop + runtime._thread = thread + return runtime + + +def _build_heartbeat(robot, peer_type: str) -> dict: + """Build heartbeat payload with robot-specific metadata.""" + data = { + "peer_type": peer_type, + "tool_name": getattr(robot, "tool_name_str", "unknown"), + } + + if peer_type == "robot": + task = getattr(robot, "_task_state", None) + if task: + data["task_status"] = getattr(task.status, "value", "unknown") + data["instruction"] = task.instruction or "" + data["step_count"] = task.step_count + elif peer_type == "sim": + world = getattr(robot, "_world", None) + if world: + data["sim_time"] = world.sim_time + data["step_count"] = world.step_count + data["robots"] = list(world.robots.keys()) + + return data diff --git a/strands_robots/device_connect/reachy_mini_driver.py b/strands_robots/device_connect/reachy_mini_driver.py new file mode 100644 index 0000000..d3a4639 --- /dev/null +++ b/strands_robots/device_connect/reachy_mini_driver.py @@ -0,0 +1,321 @@ +"""ReachyMiniDriver — Device Connect DeviceDriver for Pollen Reachy Mini robots. + +Auto-detects hardware variant via the daemon's ``wireless_version`` flag: +- **Wireless** (has onboard Pi): uses Zenoh transport for real-time I/O. +- **Lite** (USB-only, no Pi): uses WebSocket to the daemon directly. + +REST API calls go through reachy_transport.api() for daemon/move operations. +""" + +import asyncio +import json +import logging +import math +from typing import Optional + +from device_connect_edge.drivers import DeviceDriver, emit, on, rpc +from device_connect_edge.types import DeviceIdentity, DeviceStatus + +from strands_robots.device_connect.reachy_transport import ( + api, + rpy_to_pose, + identity_pose, + ZenohLink, + WebSocketLink, +) + +logger = logging.getLogger(__name__) + + +class ReachyMiniDriver(DeviceDriver): + """Device Connect driver for Pollen Reachy Mini. + + Auto-detects Wireless (Zenoh) vs Lite (WebSocket) via the daemon's + ``wireless_version`` flag. REST API calls work the same for both. + """ + + device_type = "reachy_mini" + + def __init__( + self, + host: str = "reachy-mini.local", + prefix: str = "reachy_mini", + api_port: int = 8000, + ): + super().__init__() + self._host = host + self._prefix = prefix + self._api_port = api_port + self._latest_joints: Optional[dict] = None + self._latest_imu: Optional[dict] = None + self._hw = None + + @property + def identity(self) -> DeviceIdentity: + return DeviceIdentity( + device_type="reachy_mini", + manufacturer="Pollen Robotics", + model=f"Reachy Mini @ {self._host}", + description="Pollen Reachy Mini expressive robot head with antennas", + ) + + @property + def status(self) -> DeviceStatus: + return DeviceStatus(availability="idle") + + async def connect(self) -> None: + """Connect to the Reachy Mini, auto-detecting Wireless vs Lite.""" + try: + status = await asyncio.to_thread( + api, self._host, self._api_port, "/api/daemon/status" + ) + is_lite = not status.get("wireless_version", True) + except Exception: + is_lite = False + + if is_lite: + self._hw = WebSocketLink(self._host, self._api_port) + logger.info("Connected to Reachy Mini Lite at %s (WebSocket)", self._host) + else: + self._hw = ZenohLink(self.transport, self._prefix) + logger.info("Connected to Reachy Mini at %s (Zenoh)", self._host) + + await self._hw.start( + on_joints=lambda d: setattr(self, "_latest_joints", d), + on_imu=lambda d: setattr(self, "_latest_imu", d), + ) + + async def disconnect(self) -> None: + """Tear down the hardware link.""" + if self._hw: + await self._hw.stop() + + # ── Helpers ──────────────────────────────────────────────── + + async def _send_cmd(self, cmd: dict) -> None: + """Send a real-time command via the active hardware link.""" + await self._hw.send_cmd(cmd) + + # ── Movement RPCs (Zenoh via transport) ──────────────────── + + @rpc() + async def look( + self, + pitch: float = 0, + roll: float = 0, + yaw: float = 0, + x: float = 0, + y: float = 0, + z: float = 0, + ) -> dict: + """Set head pose instantly. + + Args: + pitch: Pitch angle in degrees + roll: Roll angle in degrees + yaw: Yaw angle in degrees + x: X offset in mm + y: Y offset in mm + z: Z offset in mm + """ + await self._send_cmd({"head_pose": rpy_to_pose(pitch, roll, yaw, x, y, z)}) + return {"status": "success", "pitch": pitch, "roll": roll, "yaw": yaw} + + @rpc() + async def antennas(self, left: float = 0, right: float = 0) -> dict: + """Set antenna angles. + + Args: + left: Left antenna angle in degrees + right: Right antenna angle in degrees + """ + await self._send_cmd( + {"antennas_joint_positions": [math.radians(left), math.radians(right)]} + ) + return {"status": "success", "left": left, "right": right} + + @rpc() + async def body(self, yaw: float = 0) -> dict: + """Set body yaw angle. + + Args: + yaw: Body yaw angle in degrees + """ + await self._send_cmd({"body_yaw": math.radians(yaw)}) + return {"status": "success", "yaw": yaw} + + # ── Sensor RPCs (cached from transport subscription) ─────── + + @rpc() + async def getJoints(self) -> dict: + """Get current joint positions (head + antennas).""" + d = self._latest_joints + if d is not None: + head = d.get("head_joint_positions", []) + ant = d.get("antennas_joint_positions", []) + return { + "status": "success", + "head": [math.degrees(j) for j in head], + "antennas": [math.degrees(j) for j in ant], + } + return {"status": "error", "reason": "no joint data"} + + @rpc() + async def getImu(self) -> dict: + """Get IMU data (accelerometer, gyroscope, quaternion, temperature).""" + d = self._latest_imu + if d is not None: + return { + "status": "success", + "accelerometer": d.get("accelerometer"), + "gyroscope": d.get("gyroscope"), + "quaternion": d.get("quaternion"), + "temperature": d.get("temperature"), + } + return {"status": "error", "reason": "no IMU data"} + + # ── Motor RPCs (Zenoh via transport) ─────────────────────── + + @rpc() + async def enableMotors(self, motor_ids: str = "") -> dict: + """Enable motors (torque on). + + Args: + motor_ids: Comma-separated motor IDs (empty = all) + """ + ids = [s.strip() for s in motor_ids.split(",") if s.strip()] or None + await self._send_cmd({"torque": True, "ids": ids}) + return {"status": "success", "enabled": motor_ids or "all"} + + @rpc() + async def disableMotors(self, motor_ids: str = "") -> dict: + """Disable motors (torque off). + + Args: + motor_ids: Comma-separated motor IDs (empty = all) + """ + ids = [s.strip() for s in motor_ids.split(",") if s.strip()] or None + await self._send_cmd({"torque": False, "ids": ids}) + return {"status": "success", "disabled": motor_ids or "all"} + + # ── Move RPCs (REST) ────────────────────────────────────── + + @rpc() + async def playMove(self, move_name: str, library: str = "emotions") -> dict: + """Play a recorded move from the HuggingFace library. + + Args: + move_name: Name of the move to play + library: Move library (emotions or dance) + """ + ds = f"pollen-robotics/reachy-mini-{'emotions' if library == 'emotions' else 'dances'}-library" + result = await asyncio.to_thread( + api, self._host, self._api_port, + f"/api/move/play/recorded-move-dataset/{ds}/{move_name}", "POST", + ) + return {"status": "success", "move": move_name, "result": result} + + @rpc() + async def listMoves(self, library: str = "emotions") -> dict: + """List available recorded moves. + + Args: + library: Move library (emotions or dance) + """ + ds = f"pollen-robotics/reachy-mini-{'emotions' if library == 'emotions' else 'dances'}-library" + result = await asyncio.to_thread( + api, self._host, self._api_port, + f"/api/move/recorded-move-datasets/list/{ds}", + ) + return {"status": "success", "moves": result} + + # ── Expression RPCs (Zenoh animations via transport) ─────── + + @rpc() + async def nod(self) -> dict: + """Nod the head (yes gesture).""" + for _ in range(3): + await self._send_cmd({"head_pose": rpy_to_pose(15, 0, 0)}) + await asyncio.sleep(0.25) + await self._send_cmd({"head_pose": rpy_to_pose(-10, 0, 0)}) + await asyncio.sleep(0.25) + await self._send_cmd({"head_pose": identity_pose()}) + return {"status": "success", "expression": "nod"} + + @rpc() + async def shake(self) -> dict: + """Shake the head (no gesture).""" + for _ in range(3): + await self._send_cmd({"head_pose": rpy_to_pose(0, 0, 25)}) + await asyncio.sleep(0.2) + await self._send_cmd({"head_pose": rpy_to_pose(0, 0, -25)}) + await asyncio.sleep(0.2) + await self._send_cmd({"head_pose": identity_pose()}) + return {"status": "success", "expression": "shake"} + + @rpc() + async def happy(self) -> dict: + """Happy antenna wiggle expression.""" + for _ in range(4): + await self._send_cmd( + {"antennas_joint_positions": [math.radians(60), math.radians(-60)]} + ) + await asyncio.sleep(0.2) + await self._send_cmd( + {"antennas_joint_positions": [math.radians(-60), math.radians(60)]} + ) + await asyncio.sleep(0.2) + await self._send_cmd({"antennas_joint_positions": [0, 0]}) + return {"status": "success", "expression": "happy"} + + # ── Lifecycle RPCs (REST) ───────────────────────────────── + + @rpc() + async def wakeUp(self) -> dict: + """Wake up the robot (enable motors + play wake animation).""" + result = await asyncio.to_thread( + api, self._host, self._api_port, "/api/move/play/wake_up", "POST", + ) + return {"status": "success", "result": result} + + @rpc() + async def sleep(self) -> dict: + """Put robot to sleep (play sleep animation + disable motors).""" + result = await asyncio.to_thread( + api, self._host, self._api_port, "/api/move/play/goto_sleep", "POST", + ) + return {"status": "success", "result": result} + + @rpc() + async def stopMotion(self) -> dict: + """Stop all current motion.""" + result = await asyncio.to_thread( + api, self._host, self._api_port, "/api/move/stop", "POST", + ) + return {"status": "success", "result": result} + + @rpc() + async def getDaemonStatus(self) -> dict: + """Get daemon status, motor state, and control frequency.""" + result = await asyncio.to_thread( + api, self._host, self._api_port, "/api/daemon/status", + ) + return {"status": "success", **result} + + # ── Events ──────────────────────────────────────────────── + + @emit() + async def emergencyStop(self, reason: str = ""): + """Emitted when this device triggers an emergency stop. + + Args: + reason: Why the emergency stop was triggered + """ + pass + + @on(event_name="emergencyStop") + async def onEmergencyStop(self, device_id: str, event_name: str, payload: dict): + """React to emergencyStop — disable motors and stop motion.""" + logger.warning("Emergency stop received from %s — disabling motors", device_id) + await self.stopMotion() + await self.disableMotors() diff --git a/strands_robots/device_connect/reachy_transport.py b/strands_robots/device_connect/reachy_transport.py new file mode 100644 index 0000000..402bac4 --- /dev/null +++ b/strands_robots/device_connect/reachy_transport.py @@ -0,0 +1,163 @@ +"""Shared transport helpers for Reachy Mini robots. + +REST API helpers, pose math, and hardware link abstractions +used by ReachyMiniDriver. +""" + +import asyncio +import json +import logging +import math +import socket +from abc import ABC, abstractmethod +from typing import Callable, Optional + +logger = logging.getLogger(__name__) + + +def resolve_host(host: str) -> str: + """Resolve hostname to IP address.""" + try: + return socket.gethostbyname(host) + except socket.gaierror: + return host + + +# ── REST API ───────────────────────────────────────────────────── + +def api(host: str, port: int, path: str, method: str = "GET", data: Optional[dict] = None) -> dict: + """Call Reachy Mini daemon REST API.""" + import urllib.error + import urllib.request + url = f"http://{host}:{port}{path}" + req = urllib.request.Request(url, method=method) + req.add_header("Content-Type", "application/json") + body = json.dumps(data).encode() if data else None + try: + with urllib.request.urlopen(req, body, timeout=10) as resp: + return json.loads(resp.read().decode()) + except urllib.error.HTTPError as e: + return {"error": e.read().decode(), "code": e.code} + except Exception as e: + return {"error": str(e)} + + +# ── Pose math ──────────────────────────────────────────────────── + +def rpy_to_pose(pitch_deg: float, roll_deg: float, yaw_deg: float, + x_mm: float = 0, y_mm: float = 0, z_mm: float = 0) -> list: + """Convert RPY (degrees) + XYZ (mm) to 4x4 pose matrix.""" + p, r, y = math.radians(pitch_deg), math.radians(roll_deg), math.radians(yaw_deg) + cr, sr = math.cos(r), math.sin(r) + cp, sp = math.cos(p), math.sin(p) + cy, sy = math.cos(y), math.sin(y) + return [ + [cy*cp, cy*sp*sr - sy*cr, cy*sp*cr + sy*sr, x_mm/1000], + [sy*cp, sy*sp*sr + cy*cr, sy*sp*cr - cy*sr, y_mm/1000], + [-sp, cp*sr, cp*cr, z_mm/1000], + [0, 0, 0, 1], + ] + + +def identity_pose() -> list: + """Return a 4x4 identity pose matrix.""" + return [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] + + +# ── Hardware link abstraction ─────────────────────────────────── + +class HardwareLink(ABC): + """Abstract interface for real-time I/O with Reachy Mini hardware.""" + + @abstractmethod + async def start(self, on_joints: Callable, on_imu: Callable) -> None: + """Begin receiving sensor data and enable command sending.""" + + @abstractmethod + async def stop(self) -> None: + """Tear down the connection.""" + + @abstractmethod + async def send_cmd(self, cmd: dict) -> None: + """Send a real-time command to the robot.""" + + +class ZenohLink(HardwareLink): + """Wireless variant — real-time I/O via Device Connect's Zenoh transport.""" + + def __init__(self, transport, prefix: str): + self._transport = transport + self._prefix = prefix + + async def start(self, on_joints: Callable, on_imu: Callable) -> None: + async def _on_joints(data: bytes, _reply=None): + try: + on_joints(json.loads(data.decode())) + except Exception: + pass + + async def _on_imu(data: bytes, _reply=None): + try: + on_imu(json.loads(data.decode())) + except Exception: + pass + + await self._transport.subscribe(f"{self._prefix}/joint_positions", _on_joints) + await self._transport.subscribe(f"{self._prefix}/imu_data", _on_imu) + + async def stop(self) -> None: + pass # Transport teardown handled by DeviceRuntime + + async def send_cmd(self, cmd: dict) -> None: + await self._transport.publish( + f"{self._prefix}/command", json.dumps(cmd).encode() + ) + + +class WebSocketLink(HardwareLink): + """Lite variant — real-time I/O via daemon's WebSocket.""" + + _WS_CMD_MAP = { + "head_pose": lambda c: {"type": "set_target", "head": [v for row in c["head_pose"] for v in row]}, + "antennas_joint_positions": lambda c: {"type": "set_antennas", "antennas": c["antennas_joint_positions"]}, + "body_yaw": lambda c: {"type": "set_body_yaw", "body_yaw": c["body_yaw"]}, + "torque": lambda c: {"type": "set_torque", "on": c["torque"], "ids": c.get("ids")}, + } + + def __init__(self, host: str, port: int): + self._host = host + self._port = port + self._ws = None + self._read_task: Optional[asyncio.Task] = None + + async def start(self, on_joints: Callable, on_imu: Callable) -> None: + import websockets + + self._ws = await websockets.connect(f"ws://{self._host}:{self._port}/ws/sdk") + self._read_task = asyncio.create_task(self._read_loop(on_joints, on_imu)) + + async def _read_loop(self, on_joints: Callable, on_imu: Callable) -> None: + async for raw in self._ws: + try: + msg = json.loads(raw) + t = msg.get("type") + if t == "joint_positions": + on_joints(msg) + elif t == "imu_data": + on_imu(msg) + except Exception: + pass + + async def stop(self) -> None: + if self._read_task: + self._read_task.cancel() + if self._ws: + await self._ws.close() + + async def send_cmd(self, cmd: dict) -> None: + if not self._ws: + return + for key, fn in self._WS_CMD_MAP.items(): + if key in cmd: + await self._ws.send(json.dumps(fn(cmd))) + return diff --git a/strands_robots/device_connect/robot_driver.py b/strands_robots/device_connect/robot_driver.py new file mode 100644 index 0000000..30e7acf --- /dev/null +++ b/strands_robots/device_connect/robot_driver.py @@ -0,0 +1,196 @@ +"""RobotDeviceDriver — Device Connect DeviceDriver adapter wrapping a strands-robots Robot. + +Exposes the Robot's task execution, status, and observation methods as +structured RPCs and events via Device Connect's DeviceDriver interface. +""" + +import asyncio +import logging + +from device_connect_edge.drivers import DeviceDriver, emit, on, periodic, rpc +from device_connect_edge.types import DeviceIdentity, DeviceStatus + +logger = logging.getLogger(__name__) + + +class RobotDeviceDriver(DeviceDriver): + """Device Connect device driver wrapping a strands-robots Robot instance.""" + + device_type = "strands_robot" + + def __init__(self, robot): + super().__init__() + self._robot = robot + + @property + def identity(self) -> DeviceIdentity: + return DeviceIdentity( + device_type="strands_robot", + manufacturer="strands-robots", + model=getattr(self._robot, "tool_name_str", "robot"), + description="Strands Robots LeRobot-based robot arm", + ) + + @property + def status(self) -> DeviceStatus: + task = getattr(self._robot, "_task_state", None) + is_busy = ( + task is not None + and hasattr(task, "status") + and getattr(task.status, "value", "idle") == "running" + ) + return DeviceStatus( + availability="busy" if is_busy else "idle", + busy_score=1.0 if is_busy else 0.0, + ) + + async def connect(self) -> None: + """No-op — the Robot manages its own hardware connection.""" + pass + + async def disconnect(self) -> None: + """No-op — the Robot manages its own hardware shutdown.""" + pass + + # ── RPCs ────────────────────────────────────────────────── + + @rpc() + async def execute( + self, + instruction: str, + policy_provider: str = "mock", + duration: float = 30.0, + policy_port: int = 0, + ) -> dict: + """Execute a VLA task instruction on the robot. + + Args: + instruction: Natural language task instruction + policy_provider: Policy backend (groot, mock, lerobot_local, ...) + duration: Maximum task duration in seconds + policy_port: Policy server port (0 for default) + """ + return self._robot.start_task( + instruction, + policy_provider, + policy_port or None, + "localhost", + duration, + ) + + @rpc() + async def stop(self) -> dict: + """Stop the currently running task.""" + return self._robot.stop_task() + + @rpc() + async def getStatus(self) -> dict: + """Get current task execution status.""" + return self._robot.get_task_status() + + @rpc() + async def getFeatures(self) -> dict: + """Get robot observation and action features.""" + return self._robot.get_features() + + @rpc() + async def getState(self) -> dict: + """Get current robot state (joints, task info). + + Returns joint positions and task state if a task is running. + """ + result = {} + task = getattr(self._robot, "_task_state", None) + if task: + result["task_status"] = getattr(task.status, "value", "unknown") + result["instruction"] = task.instruction + result["step_count"] = task.step_count + + # Try to read observation from the underlying LeRobot robot + inner = getattr(self._robot, "robot", None) + if inner and hasattr(inner, "get_observation"): + try: + obs = await asyncio.to_thread(inner.get_observation) + # Filter out camera frames (numpy arrays) — only include scalars + result["joints"] = { + k: float(v) + for k, v in obs.items() + if not hasattr(v, "shape") + } + except Exception as e: + logger.debug("Could not read observation: %s", e) + + return result + + # ── Events ──────────────────────────────────────────────── + + @emit() + async def taskStarted(self, instruction: str, policy_provider: str): + """Emitted when a VLA task begins execution. + + Args: + instruction: The task instruction + policy_provider: The policy backend used + """ + pass + + @emit() + async def taskComplete(self, instruction: str, steps: int, duration: float): + """Emitted when a VLA task finishes. + + Args: + instruction: The task instruction + steps: Total steps executed + duration: Total execution time in seconds + """ + pass + + @emit() + async def streamStep(self, step: int, observation: dict, action: dict): + """Emitted for each VLA inference step (high frequency). + + Args: + step: Step number + observation: Observation dict (joints only, no camera frames) + action: Action dict + """ + pass + + @emit() + async def emergencyStop(self, reason: str = ""): + """Emitted when this device triggers an emergency stop. + + Args: + reason: Why the emergency stop was triggered + """ + pass + + @on(event_name="emergencyStop") + async def onEmergencyStop(self, device_id: str, event_name: str, payload: dict): + """React to emergencyStop from ANY device on the network.""" + logger.warning("Emergency stop received from %s — stopping task", device_id) + self._robot.stop_task() + + # ── Periodic state publishing ───────────────────────────── + + @periodic(interval=0.1, wait_for_completion=True) + async def _publishState(self): + """Publish robot state at 10Hz.""" + task = getattr(self._robot, "_task_state", None) + if task and getattr(task.status, "value", "idle") == "running": + await self.stateUpdate( + task_status="running", + instruction=task.instruction, + step_count=task.step_count, + ) + + @emit() + async def stateUpdate(self, task_status: str = "", instruction: str = "", step_count: int = 0): + """Periodic state update. + + Args: + task_status: Current task status + instruction: Current task instruction + step_count: Steps completed so far + """ + pass diff --git a/strands_robots/device_connect/setup.sh b/strands_robots/device_connect/setup.sh new file mode 100755 index 0000000..4f73565 --- /dev/null +++ b/strands_robots/device_connect/setup.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash +# setup.sh — one-command environment setup for Strands Robots + Device Connect +# +# Usage: +# ./strands_robots/device_connect/setup.sh +# +set -euo pipefail + +PYTHON_VERSION="3.12" +VENV_DIR=".venv" +REPO_ROOT="$(cd "$(dirname "$0")/../.." && pwd)" + +echo "============================================================" +echo " Strands Robots — Environment Setup" +echo "============================================================" +echo "" + +# ── 0. Install uv (if needed) ──────────────────────────────────────────────── +if ! command -v uv &>/dev/null; then + echo "[0/2] uv not found — installing..." + curl -LsSf https://astral.sh/uv/install.sh | sh + export PATH="$HOME/.local/bin:$PATH" +else + echo "[0/2] uv $(uv --version) ✓" +fi + +# ── 1. Install Python (if needed) ──────────────────────────────────────────── +if ! uv python find "$PYTHON_VERSION" &>/dev/null; then + echo "[1/2] Python $PYTHON_VERSION not found — installing via uv..." + uv python install "$PYTHON_VERSION" +else + echo "[1/2] Python $PYTHON_VERSION ✓" +fi + +# ── 2. Create virtual environment and install ──────────────────────────────── +if [ ! -d "$REPO_ROOT/$VENV_DIR" ]; then + echo "[2/2] Creating virtual environment and installing packages..." + uv venv --python "$PYTHON_VERSION" "$REPO_ROOT/$VENV_DIR" +else + echo "[2/2] Virtual environment exists, installing packages..." +fi + +# shellcheck disable=SC1091 +source "$REPO_ROOT/$VENV_DIR/bin/activate" +uv pip install -e "$REPO_ROOT[sim]" + +echo "" +echo "============================================================" +echo " Setup complete" +echo "============================================================" +echo "" +echo "Activate the environment:" +echo " source $REPO_ROOT/$VENV_DIR/bin/activate" diff --git a/strands_robots/device_connect/sim_driver.py b/strands_robots/device_connect/sim_driver.py new file mode 100644 index 0000000..f4a040f --- /dev/null +++ b/strands_robots/device_connect/sim_driver.py @@ -0,0 +1,230 @@ +"""SimulationDeviceDriver — Device Connect DeviceDriver adapter wrapping a strands-robots Simulation. + +Exposes the Simulation's physics stepping, policy execution, and world +state as structured RPCs and events via Device Connect's DeviceDriver interface. +""" + +import logging + +from device_connect_edge.drivers import DeviceDriver, emit, on, periodic, rpc +from device_connect_edge.types import DeviceIdentity, DeviceStatus + +logger = logging.getLogger(__name__) + + +class SimulationDeviceDriver(DeviceDriver): + """Device Connect device driver wrapping a strands-robots Simulation instance.""" + + device_type = "strands_sim" + + def __init__(self, sim): + super().__init__() + self._sim = sim + + @property + def identity(self) -> DeviceIdentity: + return DeviceIdentity( + device_type="strands_sim", + manufacturer="strands-robots", + model=getattr(self._sim, "tool_name_str", "simulation"), + description="Strands Robots MuJoCo simulation", + ) + + @property + def status(self) -> DeviceStatus: + world = getattr(self._sim, "_world", None) + is_busy = False + if world: + for robot in world.robots.values(): + if getattr(robot, "policy_running", False): + is_busy = True + break + return DeviceStatus( + availability="busy" if is_busy else "idle", + busy_score=1.0 if is_busy else 0.0, + ) + + async def connect(self) -> None: + """No-op — the Simulation manages its own MuJoCo state.""" + pass + + async def disconnect(self) -> None: + """No-op — the Simulation manages its own cleanup.""" + pass + + # ── RPCs ────────────────────────────────────────────────── + + @rpc() + async def execute( + self, + instruction: str, + policy_provider: str = "mock", + duration: float = 30.0, + robot_name: str = "", + ) -> dict: + """Execute a policy on a simulated robot. + + Args: + instruction: Natural language task instruction + policy_provider: Policy backend (mock, lerobot_local, ...) + duration: Maximum task duration in seconds + robot_name: Target robot name (empty = first robot) + """ + # Determine robot name + name = robot_name + if not name: + world = getattr(self._sim, "_world", None) + if world and world.robots: + name = next(iter(world.robots)) + else: + return {"status": "error", "reason": "no robots in simulation"} + + print(f"▶ Executing policy '{policy_provider}' on {name}: {instruction}", flush=True) + return self._sim.start_policy( + robot_name=name, + policy_provider=policy_provider, + instruction=instruction, + duration=duration, + ) + + @rpc() + async def stop(self) -> dict: + """Stop all running policies.""" + print("⏹ Stop command received — stopping all policies", flush=True) + world = getattr(self._sim, "_world", None) + if world: + for robot in world.robots.values(): + robot.policy_running = False + return {"status": "success", "content": [{"text": "All policies stopped"}]} + + @rpc() + async def getStatus(self) -> dict: + """Get simulation state and running policies.""" + if hasattr(self._sim, "get_state"): + return self._sim.get_state() + return {"status": "idle"} + + @rpc() + async def getFeatures(self) -> dict: + """Get simulation features (joints, actuators, cameras).""" + return self._sim.get_features() + + @rpc() + async def step(self, n_steps: int = 1) -> dict: + """Step simulation physics forward. + + Args: + n_steps: Number of physics steps to take + """ + return self._sim.step(n_steps) + + @rpc() + async def reset(self) -> dict: + """Reset simulation to initial state.""" + return self._sim.reset() + + # ── Events ──────────────────────────────────────────────── + + @emit() + async def policyStarted(self, robot_name: str, instruction: str, policy_provider: str): + """Emitted when a policy begins execution. + + Args: + robot_name: The simulated robot running the policy + instruction: The task instruction + policy_provider: The policy backend used + """ + pass + + @emit() + async def policyComplete(self, robot_name: str, instruction: str, steps: int): + """Emitted when a policy finishes. + + Args: + robot_name: The simulated robot + instruction: The task instruction + steps: Total steps executed + """ + pass + + @emit() + async def emergencyStop(self, reason: str = ""): + """Emitted when this device triggers an emergency stop. + + Args: + reason: Why the emergency stop was triggered + """ + pass + + @on(event_name="emergencyStop") + async def onEmergencyStop(self, device_id: str, event_name: str, payload: dict): + """React to emergencyStop from ANY device on the network.""" + print(f"🛑 Emergency stop received from {device_id} — stopping all policies", flush=True) + world = getattr(self._sim, "_world", None) + if world: + for robot in world.robots.values(): + robot.policy_running = False + + # ── Periodic state publishing ───────────────────────────── + + @periodic(interval=0.1, wait_for_completion=True) + async def _publishState(self): + """Publish simulation state at 10Hz.""" + world = getattr(self._sim, "_world", None) + if not world: + return + running = { + name: {"steps": r.policy_steps, "instruction": r.policy_instruction} + for name, r in world.robots.items() + if r.policy_running + } + if running: + await self.stateUpdate( + sim_time=world.sim_time, + step_count=world.step_count, + running_policies=running, + ) + # Publish per-robot joint observations from MuJoCo state + data = getattr(world, "_data", None) + robots = world.robots if isinstance(world.robots, dict) else {} + for name, robot in robots.items(): + try: + joint_names = getattr(robot, "joint_names", []) + joint_ids = getattr(robot, "joint_ids", []) + joints = {} + if data is not None and joint_names and joint_ids: + for jname, jid in zip(joint_names, joint_ids): + joints[jname] = float(data.qpos[jid]) + await self.observationUpdate( + robot_name=name, + sim_time=world.sim_time, + step_count=world.step_count, + joints=joints, + ) + except Exception as e: + logger.debug("observationUpdate skipped for %s: %s", name, e) + + @emit() + async def stateUpdate(self, sim_time: float = 0.0, step_count: int = 0, running_policies: dict = None): + """Periodic simulation state update. + + Args: + sim_time: Current simulation time + step_count: Total physics steps + running_policies: Dict of running policy info per robot + """ + pass + + @emit() + async def observationUpdate( + self, robot_name: str = "", sim_time: float = 0.0, step_count: int = 0, joints: dict = None + ): + """Periodic per-robot observation with joint positions. + + Args: + robot_name: Name of the robot + sim_time: Current simulation time + step_count: Total physics steps + joints: Dict of joint name -> position (radians) + """ + pass diff --git a/strands_robots/device_connect/test_control_loop_dc.sh b/strands_robots/device_connect/test_control_loop_dc.sh new file mode 100755 index 0000000..62eb4b9 --- /dev/null +++ b/strands_robots/device_connect/test_control_loop_dc.sh @@ -0,0 +1,197 @@ +#!/usr/bin/env bash +# test_control_loop_dc.sh — End-to-end test: control loop + Zenoh event listener +# +# Verifies that Robot("so100") publishes Device Connect events over Zenoh +# while a mock-policy control loop is running. +# +# Usage: +# bash strands_robots/device_connect/test_control_loop_dc.sh +# +set -euo pipefail + +REPO_ROOT="$(cd "$(dirname "$0")/../.." && pwd)" +WORKSPACE_ROOT="$(cd "$REPO_ROOT/.." && pwd)" +export MUJOCO_GL="${MUJOCO_GL:-egl}" +export DEVICE_CONNECT_ALLOW_INSECURE=true + +EVENTS_LOG=$(mktemp /tmp/zenoh_events_XXXX.log) +LOOP_LOG=$(mktemp /tmp/control_loop_XXXX.log) +LISTENER_PID="" + +cleanup() { + [ -n "$LISTENER_PID" ] && kill "$LISTENER_PID" 2>/dev/null || true + echo "" + echo "Logs:" + echo " Events: $EVENTS_LOG" + echo " Control loop: $LOOP_LOG" +} +trap cleanup EXIT + +# ── 1. Install dependencies ──────────────────────────────────────────── +echo "==> Installing device-connect-edge..." +pip install -e "$WORKSPACE_ROOT/device-connect/packages/device-connect-edge" -q + +echo "==> Installing device-connect-agent-tools..." +pip install -e "$WORKSPACE_ROOT/device-connect/packages/device-connect-agent-tools[strands]" -q + +echo "==> Installing strands-robots[sim]..." +pip install -e "$REPO_ROOT[sim]" -q + +echo "==> All dependencies installed." +echo "" + +# ── 2. Start Zenoh listener ──────────────────────────────────────────── +echo "==> Starting Zenoh event listener..." +python3 -c " +import json, time, zenoh + +def on_sample(sample): + try: + data = json.loads(sample.payload.to_bytes().decode()) + except Exception: + data = str(sample.payload.to_bytes().decode()[:200]) + print(f'[{time.strftime(\"%H:%M:%S\")}] {sample.key_expr}: {json.dumps(data, default=str)}', flush=True) + +session = zenoh.open(zenoh.Config()) +sub = session.declare_subscriber('device-connect/**', on_sample) +print('LISTENER_READY', flush=True) +try: + while True: + time.sleep(0.1) +except KeyboardInterrupt: + pass +finally: + sub.undeclare() + session.close() +" > "$EVENTS_LOG" 2>&1 & +LISTENER_PID=$! + +# Wait for listener to be ready +for i in $(seq 1 30); do + grep -q "LISTENER_READY" "$EVENTS_LOG" 2>/dev/null && break + sleep 0.2 +done +echo " Listener PID: $LISTENER_PID" +echo "" + +# ── 3. Run the control loop ──────────────────────────────────────────── +echo "==> Running control loop (200 steps @ 50Hz)..." +python3 -c " +import os, sys, time +os.environ.setdefault('MUJOCO_GL', 'egl') + +from strands_robots.factory import Robot +from strands_robots.policies import create_policy + +robot = Robot('so100') +# Wait for DC runtime to connect and start periodic publishers +time.sleep(3) + +policy = create_policy('mock') +for step in range(200): + obs = robot.get_observation() + action = policy.get_actions_sync(obs, instruction='pick up the cube') + robot.apply_action(action) + if step % 50 == 0: + print(f' Step {step}/200', flush=True) + +print('Control loop done.', flush=True) +robot.cleanup() +print('Cleanup complete.', flush=True) +" 2>&1 | tee "$LOOP_LOG" + +# Give trailing events a moment to arrive +sleep 2 + +# ── 4. Stop the listener ─────────────────────────────────────────────── +kill "$LISTENER_PID" 2>/dev/null || true +wait "$LISTENER_PID" 2>/dev/null || true +LISTENER_PID="" + +# ── 5. Validate captured events ──────────────────────────────────────── +echo "" +echo "============================================================" +echo " ZENOH EVENT SUMMARY" +echo "============================================================" + +TOTAL=$(grep -c '^\[' "$EVENTS_LOG" 2>/dev/null || echo 0) +STATE_UPDATES=$(grep -c 'event/stateUpdate' "$EVENTS_LOG" 2>/dev/null || echo 0) +OBS_UPDATES=$(grep -c 'event/observationUpdate' "$EVENTS_LOG" 2>/dev/null || echo 0) +PRESENCE=$(grep -c '/presence' "$EVENTS_LOG" 2>/dev/null || echo 0) +HEARTBEATS=$(grep -c '/heartbeat' "$EVENTS_LOG" 2>/dev/null || echo 0) + +printf " %-25s %s\n" "stateUpdate events:" "$STATE_UPDATES" +printf " %-25s %s\n" "observationUpdate events:" "$OBS_UPDATES" +printf " %-25s %s\n" "presence events:" "$PRESENCE" +printf " %-25s %s\n" "heartbeat events:" "$HEARTBEATS" +printf " %-25s %s\n" "TOTAL:" "$TOTAL" +echo "" + +# Show a sample observationUpdate with joint data +SAMPLE_OBS=$(grep 'event/observationUpdate' "$EVENTS_LOG" | tail -1) +if [ -n "$SAMPLE_OBS" ]; then + echo " Sample observationUpdate:" + echo " $SAMPLE_OBS" | python3 -c " +import sys, json +line = sys.stdin.read().strip() +payload = line.split(': ', 1)[1] +data = json.loads(payload) +params = data.get('params', {}) +print(f\" robot: {params.get('robot_name')}\") +print(f\" sim_time: {params.get('sim_time')}\") +print(f\" step: {params.get('step_count')}\") +joints = params.get('joints', {}) +for name, val in joints.items(): + print(f\" {name:>15s}: {val:+.6f} rad\") +" 2>/dev/null || echo " (could not parse sample)" + echo "" +fi + +# ── 6. Assert minimum thresholds ─────────────────────────────────────── +PASS=true + +if [ "$TOTAL" -lt 10 ]; then + echo "FAIL: Expected >= 10 total events, got $TOTAL" + PASS=false +fi + +if [ "$STATE_UPDATES" -lt 5 ]; then + echo "FAIL: Expected >= 5 stateUpdate events, got $STATE_UPDATES" + PASS=false +fi + +if [ "$PRESENCE" -lt 1 ]; then + echo "FAIL: Expected >= 1 presence event, got $PRESENCE" + PASS=false +fi + +if [ "$HEARTBEATS" -lt 1 ]; then + echo "FAIL: Expected >= 1 heartbeat event, got $HEARTBEATS" + PASS=false +fi + +if [ "$OBS_UPDATES" -lt 5 ]; then + echo "FAIL: Expected >= 5 observationUpdate events, got $OBS_UPDATES" + PASS=false +fi + +# Check no "Failed to publish" in control loop output +PUBLISH_ERRORS=$(grep -c "Failed to publish" "$LOOP_LOG" 2>/dev/null || true) +PUBLISH_ERRORS="${PUBLISH_ERRORS:-0}" +if [ "$PUBLISH_ERRORS" -gt 0 ]; then + echo "FAIL: Found $PUBLISH_ERRORS 'Failed to publish' errors (missing cleanup?)" + PASS=false +fi + +if [ "$PASS" = true ]; then + echo "============================================================" + echo " ALL CHECKS PASSED" + echo "============================================================" + exit 0 +else + echo "" + echo "============================================================" + echo " SOME CHECKS FAILED — see logs above" + echo "============================================================" + exit 1 +fi diff --git a/strands_robots/factory.py b/strands_robots/factory.py index 24ff849..0412d0f 100644 --- a/strands_robots/factory.py +++ b/strands_robots/factory.py @@ -7,6 +7,7 @@ import logging import os +import time from typing import Any, Dict, List, Optional from strands_robots.registry import get_hardware_type, has_hardware, resolve_name @@ -94,6 +95,7 @@ def Robot( mode = _auto_detect_mode(canonical) # ── Simulation backends ── + instance = None if mode == "sim": if backend == "isaac": from strands_robots.isaac.isaac_sim_backend import ( @@ -105,16 +107,15 @@ def Robot( num_envs=num_envs, device=kwargs.pop("device", "cuda:0"), ) - isaac_backend = IsaacSimBackend(config=config) - isaac_backend.create_world() - result = isaac_backend.add_robot( + instance = IsaacSimBackend(config=config) + instance.create_world() + result = instance.add_robot( name=canonical, data_config=canonical, position=position or [0.0, 0.0, 0.0], ) if result.get("status") == "error": raise RuntimeError(f"Failed to create Isaac robot '{canonical}': {result}") - return isaac_backend elif backend == "newton": from strands_robots.newton.newton_backend import NewtonBackend, NewtonConfig @@ -130,9 +131,9 @@ def Robot( substeps=kwargs.pop("substeps", 1), physics_dt=kwargs.pop("physics_dt", 1.0 / 200.0), ) - newton_backend = NewtonBackend(config=config) - newton_backend.create_world() - result = newton_backend.add_robot( + instance = NewtonBackend(config=config) + instance.create_world() + result = instance.add_robot( name=canonical, data_config=canonical, position=tuple(position) if position else (0.0, 0.0, 0.0), @@ -140,34 +141,31 @@ def Robot( if result.get("status") == "error": raise RuntimeError(f"Failed to create Newton robot '{canonical}': {result.get('message', result)}") if num_envs > 1: - newton_backend.replicate(num_envs=num_envs) - return newton_backend + instance.replicate(num_envs=num_envs) else: # MuJoCo CPU backend (default) from strands_robots.simulation import Simulation sim_name = canonical - sim = Simulation(tool_name=f"{canonical}_sim", mesh=mesh, peer_id=peer_id, **kwargs) - sim._dispatch_action("create_world", {}) - result = sim._dispatch_action( - "add_robot", - { - "robot_name": canonical, - "data_config": sim_name, - "position": position or [0.0, 0.0, 0.0], - }, + instance = Simulation( + tool_name=f"{canonical}_sim", mesh=mesh, peer_id=peer_id, **kwargs ) + instance._dispatch_action("create_world", {}) + result = instance._dispatch_action("add_robot", { + "robot_name": canonical, + "data_config": sim_name, + "position": position or [0.0, 0.0, 0.0], + }) if result.get("status") == "error": raise RuntimeError(f"Failed to create sim robot '{canonical}': {result}") - return sim # ── Real hardware ── else: from strands_robots.robot import Robot as HardwareRobot real_type = get_hardware_type(canonical) or canonical - return HardwareRobot( + instance = HardwareRobot( tool_name=canonical, robot=real_type, cameras=cameras, @@ -176,6 +174,53 @@ def Robot( **kwargs, ) + # Store metadata for .run() + instance._peer_id = peer_id or f"{canonical}-{os.urandom(3).hex()}" + instance._peer_type = "sim" if mode == "sim" else "robot" + instance._device_connect_runtime = None + + # Attach .run() method for foreground server mode + instance.run = lambda: _run_foreground(instance) + + return instance + + +def _run_foreground(instance): + """Start Device Connect and block — robot listens for commands. + + Call this to keep the process alive as a server. Ctrl+C to stop. + + Usage: + r = Robot("so100") + r.run() # blocks here, listening for commands + """ + import signal + import threading + + peer_id = getattr(instance, "_peer_id", "robot") + peer_type = getattr(instance, "_peer_type", "robot") + + # Init Device Connect + try: + from strands_robots.device_connect import init_device_connect_sync + + instance._device_connect_runtime = init_device_connect_sync( + instance, peer_id=peer_id, peer_type=peer_type, + ) + except Exception as e: + logger.warning("Device Connect init failed: %s", e) + if hasattr(instance, "_init_mesh_fallback"): + instance._init_mesh_fallback() + + print(f"🤖 {peer_id} is online. Ctrl+C to stop.") + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + print(f"\n🛑 Shutting down {peer_id}...", flush=True) + print(f"👋 {peer_id} stopped.", flush=True) + os._exit(0) + def list_robots(mode: str = "all") -> List[Dict[str, Any]]: """List available robots. diff --git a/strands_robots/robot.py b/strands_robots/robot.py index a0c07eb..833fe41 100644 --- a/strands_robots/robot.py +++ b/strands_robots/robot.py @@ -15,6 +15,7 @@ import asyncio import logging +import os import threading import time from concurrent.futures import Future, ThreadPoolExecutor @@ -157,14 +158,31 @@ def __init__( if data_config: logger.info("⚙️ Data config: %s", data_config) - # Zenoh mesh — every Robot is a peer by default + # Device Connect / mesh init is deferred to the factory + self._device_connect_runtime = None + self._peer_id = peer_id + self._mesh_enabled = mesh + self.mesh = None + + def _init_mesh_fallback(self): + """Fallback: init Zenoh mesh when device-connect-edge is not installed.""" try: from strands_robots.zenoh_mesh import init_mesh - self.mesh = init_mesh(self, peer_id=peer_id, peer_type="robot", mesh=mesh) + self.mesh = init_mesh(self, peer_id=self._peer_id, peer_type="robot", mesh=self._mesh_enabled) except Exception as e: logger.debug("Mesh init skipped: %s", e) - self.mesh = None + + async def _serve(self): + """Initialize Device Connect and block — keeps robot discoverable.""" + from strands_robots.device_connect import init_device_connect + + peer_id = self._peer_id or f"{self.tool_name_str}-{os.urandom(3).hex()}" + self._device_connect_runtime = await init_device_connect( + self, peer_id=peer_id, peer_type="robot" + ) + print(f"Robot '{self.tool_name_str}' running — discoverable as '{peer_id}' via Device Connect.") + await self._device_connect_runtime._background_task def _initialize_robot( self, @@ -1358,6 +1376,13 @@ async def stream( def cleanup(self): """Cleanup resources and stop any running tasks.""" try: + # Stop Device Connect runtime + if hasattr(self, "_device_connect_runtime") and self._device_connect_runtime: + try: + self._device_connect_runtime.stop() + except Exception: + pass + # Stop mesh if hasattr(self, "mesh") and self.mesh: self.mesh.stop() diff --git a/strands_robots/tools/robot_mesh.py b/strands_robots/tools/robot_mesh.py index 750f7dc..718778c 100644 --- a/strands_robots/tools/robot_mesh.py +++ b/strands_robots/tools/robot_mesh.py @@ -1,17 +1,49 @@ -"""Robot Mesh Tool — agent-facing tool for robot mesh coordination. +#!/usr/bin/env python3 +""" +Robot Mesh Tool — agent-facing tool for robot mesh coordination. -Since every Robot() is already on the mesh, this tool just provides -the agent interface for discovery, messaging, and coordination. +Uses Device Connect's discover_devices and invoke_device for registry-based +discovery and structured RPC invocation. Falls back to the plain Zenoh +mesh if Device Connect is not available. """ import json import logging +import os from typing import Any, Dict from strands import tool logger = logging.getLogger(__name__) +_dc_connected = False + + +class _ToolResult(dict): + """Dict that prints its text content cleanly.""" + + def __str__(self): + content = self.get("content", []) + if content and isinstance(content[0], dict): + return content[0].get("text", super().__str__()) + return super().__str__() + + +def _ensure_connected(): + """Ensure Device Connect agent-side connection is established.""" + global _dc_connected + if _dc_connected: + return + # Set P2P defaults before import (MessagingConfig reads env at import time) + os.environ.setdefault("MESSAGING_BACKEND", "zenoh") + os.environ.setdefault("DEVICE_CONNECT_ALLOW_INSECURE", "true") + from device_connect_agent_tools.connection import connect, get_connection + try: + get_connection() + except Exception: + connect() + _dc_connected = True + @tool def robot_mesh( @@ -24,9 +56,9 @@ def robot_mesh( duration: float = 30.0, timeout: float = 30.0, ) -> Dict[str, Any]: - """Robot mesh — discover and coordinate all robots on the Zenoh network. + """Robot mesh — discover and coordinate all robots on the network. - Every Robot() is automatically a mesh peer. This tool lets you see them, + Every Robot() is automatically a network peer. This tool lets you see them, talk to them, and coordinate them. Also sees Reachy Mini. Args: @@ -37,11 +69,11 @@ def robot_mesh( - "broadcast": Send command to ALL peers - "stop": Stop a specific peer's task - "emergency_stop": E-STOP all robots - - "status": Mesh overview + - "status": Network overview - "subscribe": Subscribe to any Zenoh topic (target=topic pattern) - "watch": Watch a robot's VLA execution stream (target=peer_id) - "inbox": Read buffered messages from subscriptions - target: Peer ID for tell/send/stop, or topic for subscribe + target: Device ID for tell/send/stop, or topic for subscribe instruction: Natural language instruction for tell command: JSON command string for send/broadcast policy_provider: Policy for tell (groot, mock, lerobot_local, ...) @@ -54,10 +86,146 @@ def robot_mesh( Examples: robot_mesh(action="peers") - robot_mesh(action="tell", target="so100_sim-a1b2", instruction="pick up the cube") - robot_mesh(action="send", target="Mac-fc0610", command='{"action": "status"}') + robot_mesh(action="tell", target="so100-lab-1", instruction="pick up the cube") robot_mesh(action="emergency_stop") """ + try: + _ensure_connected() + return _device_connect_dispatch(action, target, instruction, command, + policy_provider, policy_port, duration, timeout) + except Exception as e: + logger.debug(f"Device Connect dispatch failed, falling back to Zenoh mesh: {e}") + return _mesh_dispatch(action, target, instruction, command, + policy_provider, policy_port, duration, timeout) + + +def _device_connect_dispatch(action, target, instruction, command, + policy_provider, policy_port, duration, timeout): + """Dispatch using Device Connect's discover_devices + invoke_device.""" + try: + from device_connect_agent_tools.connection import get_connection + conn = get_connection() + + if action == "peers": + devices = conn.list_devices() + text = f"Discovered {len(devices)} device(s):\n" + for d in devices: + dtype = d.get("device_type", "?") + icon = {"strands_robot": "robot", "strands_sim": "sim", + "reachy_mini": "reachy"}.get(dtype, dtype) + status = d.get("status", {}) + avail = status.get("availability", "?") if isinstance(status, dict) else "?" + text += f" [{icon}] {d['device_id']} — {avail}\n" + funcs = d.get("functions", []) + if funcs: + names = [f["name"] if isinstance(f, dict) else f for f in funcs] + text += f" Functions: {', '.join(names)}\n" + return _ToolResult({"status": "success", "content": [{"text": text}]}) + + elif action == "tell": + if not target or not instruction: + return _ToolResult({"status": "error", "content": [{"text": "target and instruction required"}]}) + params = { + "instruction": instruction, + "policy_provider": policy_provider, + "duration": duration, + } + if policy_port: + params["policy_port"] = policy_port + result = conn.invoke(target, "execute", params, timeout=timeout) + r = result.get("result", result) + return _ToolResult({"status": "success", "content": [{"text": f"-> {target}: {instruction}\n {json.dumps(r, default=str)}"}]}) + + elif action == "send": + if not target: + return _ToolResult({"status": "error", "content": [{"text": "target required"}]}) + func = "getStatus" + params = {} + if command: + cmd = json.loads(command) + func = cmd.pop("action", cmd.pop("function", "getStatus")) + params = cmd + result = conn.invoke(target, func, params, timeout=timeout) + r = result.get("result", result) + return _ToolResult({"status": "success", "content": [{"text": f"{target}:\n{json.dumps(r, indent=2, default=str)[:2000]}"}]}) + + elif action == "stop": + if not target: + return _ToolResult({"status": "error", "content": [{"text": "target required"}]}) + result = conn.invoke(target, "stop", timeout=5.0) + r = result.get("result", result) + return _ToolResult({"status": "success", "content": [{"text": f"Stop {target}: {json.dumps(r, default=str)}"}]}) + + elif action == "emergency_stop": + devices = conn.list_devices() + stopped = 0 + for d in devices: + try: + conn.invoke(d["device_id"], "stop", timeout=3.0) + stopped += 1 + except Exception: + pass + return _ToolResult({"status": "success", "content": [{"text": f"E-STOP: {stopped}/{len(devices)} devices stopped"}]}) + + elif action == "broadcast": + func = "getStatus" + params = {} + if command: + cmd = json.loads(command) + func = cmd.pop("action", cmd.pop("function", "getStatus")) + params = cmd + results = conn.broadcast(func, params, timeout=timeout) + text = f"Broadcast '{func}' -> {len(results)} response(s):\n" + for r in results: + status_str = "ok" if "result" in r else f"error: {r.get('error', '?')}" + text += f" {r['device_id']}: {status_str}\n" + return _ToolResult({"status": "success", "content": [{"text": text}]}) + + elif action == "subscribe": + if not target: + return _ToolResult({"status": "error", "content": [{"text": "target (subject pattern) required. E.g. 'device-connect.default.*.event.>'"}]}) + name = conn.subscribe_buffered(target) + return _ToolResult({"status": "success", "content": [{"text": f"Subscribed to: {target}\nMessages buffered in inbox['{name}']\nUse action='inbox' to read."}]}) + + elif action == "watch": + if not target: + return _ToolResult({"status": "error", "content": [{"text": "target (device_id) required"}]}) + subject = f"device-connect.{conn.zone}.{target}.event.>" + name = conn.subscribe_buffered(subject, name=f"stream:{target}") + return _ToolResult({"status": "success", "content": [{"text": f"Watching events from: {target}\nMessages in inbox['{name}']"}]}) + + elif action == "inbox": + inbox = conn.get_inbox() + if not inbox: + return _ToolResult({"status": "success", "content": [{"text": "No subscriptions active"}]}) + text = f"Inbox ({len(inbox)} subscription(s)):\n" + for name, msgs in inbox.items(): + text += f" {name}: {len(msgs)} messages\n" + if msgs: + last = msgs[-1] + # Messages are (subject, data) tuples — matching Zenoh mesh format + text += f" Latest: {json.dumps(last[1], default=str)[:200]}\n" + return _ToolResult({"status": "success", "content": [{"text": text}]}) + + elif action == "status": + devices = conn.list_devices() + text = f"Network: {len(devices)} device(s)\n" + for d in devices: + status = d.get("status", {}) + avail = status.get("availability", "?") if isinstance(status, dict) else "?" + text += f" {d['device_id']} ({d.get('device_type', '?')}) — {avail}\n" + return _ToolResult({"status": "success", "content": [{"text": text}]}) + + else: + return _ToolResult({"status": "error", "content": [{"text": f"Unknown action: {action}. Try: peers, tell, send, broadcast, stop, emergency_stop, status, subscribe, watch, inbox"}]}) + + except Exception as e: + return _ToolResult({"status": "error", "content": [{"text": f"Error: {e}"}]}) + + +def _mesh_dispatch(action, target, instruction, command, + policy_provider, policy_port, duration, timeout): + """Fallback: dispatch using plain Zenoh mesh.""" try: from strands_robots.zenoh_mesh import _LOCAL_ROBOTS, get_peers @@ -76,97 +244,60 @@ def robot_mesh( text += "**Discovered peers:**\n" for p in peers: icon = {"robot": "🤖", "sim": "🎮", "agent": "🧠"}.get(p.get("type", ""), "🔧") - text += f" {icon} {p['peer_id']} ({p.get('type', '?')}) — {p.get('hostname', '?')}, {p.get('age', 0)}s ago\n" + text += f" {icon} {p['peer_id']} ({p.get('type','?')}) — {p.get('hostname','?')}, {p.get('age',0)}s ago\n" if p.get("task_status"): text += f" Task: {p['task_status']} — {p.get('instruction', '')}\n" elif not local: text += "No peers. Create a Robot() — it auto-joins the mesh.\n" - return {"status": "success", "content": [{"text": text}]} + return _ToolResult({"status": "success", "content": [{"text": text}]}) elif action == "tell": if not target or not instruction: - return { - "status": "error", - "content": [{"text": "target and instruction required"}], - } + return _ToolResult({"status": "error", "content": [{"text": "target and instruction required"}]}) mesh = _any_mesh() if not mesh: - return { - "status": "error", - "content": [{"text": "No local robots on mesh"}], - } - cmd = { - "action": "execute", - "instruction": instruction, - "policy_provider": policy_provider, - "duration": duration, - } + return _ToolResult({"status": "error", "content": [{"text": "No local robots on mesh"}]}) + cmd = {"action": "execute", "instruction": instruction, + "policy_provider": policy_provider, "duration": duration} if policy_port: cmd["policy_port"] = policy_port r = mesh.send(target, cmd, timeout=timeout) - return { - "status": "success", - "content": [{"text": f"📨 → {target}: {instruction}\n\n{json.dumps(r, indent=2, default=str)[:2000]}"}], - } + return _ToolResult({"status": "success", "content": [{"text": f"📨 → {target}: {instruction}\n\n{json.dumps(r, indent=2, default=str)[:2000]}"}]}) elif action == "send": if not target: - return {"status": "error", "content": [{"text": "target required"}]} + return _ToolResult({"status": "error", "content": [{"text": "target required"}]}) mesh = _any_mesh() if not mesh: - return { - "status": "error", - "content": [{"text": "No local robots on mesh"}], - } + return _ToolResult({"status": "error", "content": [{"text": "No local robots on mesh"}]}) cmd = json.loads(command) if command else {"action": "status"} r = mesh.send(target, cmd, timeout=timeout) - return { - "status": "success", - "content": [{"text": f"📨 {target}:\n{json.dumps(r, indent=2, default=str)[:2000]}"}], - } + return _ToolResult({"status": "success", "content": [{"text": f"📨 {target}:\n{json.dumps(r, indent=2, default=str)[:2000]}"}]}) elif action == "broadcast": mesh = _any_mesh() if not mesh: - return { - "status": "error", - "content": [{"text": "No local robots on mesh"}], - } + return _ToolResult({"status": "error", "content": [{"text": "No local robots on mesh"}]}) cmd = json.loads(command) if command else {"action": "status"} rs = mesh.broadcast(cmd, timeout=timeout) - return { - "status": "success", - "content": [{"text": f"📢 {len(rs)} responses:\n{json.dumps(rs, indent=2, default=str)[:3000]}"}], - } + return _ToolResult({"status": "success", "content": [{"text": f"📢 {len(rs)} responses:\n{json.dumps(rs, indent=2, default=str)[:3000]}"}]}) elif action == "stop": if not target: - return {"status": "error", "content": [{"text": "target required"}]} + return _ToolResult({"status": "error", "content": [{"text": "target required"}]}) mesh = _any_mesh() if not mesh: - return { - "status": "error", - "content": [{"text": "No local robots on mesh"}], - } + return _ToolResult({"status": "error", "content": [{"text": "No local robots on mesh"}]}) r = mesh.send(target, {"action": "stop"}, timeout=5.0) - return { - "status": "success", - "content": [{"text": f"🛑 {target}: {json.dumps(r, default=str)}"}], - } + return _ToolResult({"status": "success", "content": [{"text": f"🛑 {target}: {json.dumps(r, default=str)}"}]}) elif action == "emergency_stop": mesh = _any_mesh() if not mesh: - return { - "status": "error", - "content": [{"text": "No local robots on mesh"}], - } + return _ToolResult({"status": "error", "content": [{"text": "No local robots on mesh"}]}) rs = mesh.emergency_stop() - return { - "status": "success", - "content": [{"text": f"🚨 E-STOP → {len(rs)} responses"}], - } + return _ToolResult({"status": "success", "content": [{"text": f"🚨 E-STOP → {len(rs)} responses"}]}) elif action == "status": local = list(_LOCAL_ROBOTS.keys()) @@ -175,83 +306,48 @@ def robot_mesh( for rid in local: m = _LOCAL_ROBOTS[rid] text += f" • {rid} ({m.peer_type}) alive={m.alive}\n" - return {"status": "success", "content": [{"text": text}]} + return _ToolResult({"status": "success", "content": [{"text": text}]}) elif action == "subscribe": if not target: - return { - "status": "error", - "content": [ - {"text": "target (topic pattern) required. E.g. 'reachy_mini/*' or '*/joint_positions'"} - ], - } + return _ToolResult({"status": "error", "content": [{"text": "target (topic pattern) required. E.g. 'reachy_mini/*' or '*/joint_positions'"}]}) mesh = _any_mesh() if not mesh: - return { - "status": "error", - "content": [{"text": "No local robots on mesh"}], - } + return _ToolResult({"status": "error", "content": [{"text": "No local robots on mesh"}]}) name = mesh.subscribe(target) - return { - "status": "success", - "content": [ - { - "text": f"📡 Subscribed to: {target}\nMessages buffered in inbox['{name}']\nUse action='inbox' to read." - } - ], - } + return _ToolResult({"status": "success", "content": [{"text": f"📡 Subscribed to: {target}\nMessages buffered in inbox['{name}']\nUse action='inbox' to read."}]}) elif action == "watch": if not target: - return { - "status": "error", - "content": [{"text": "target (peer_id) required"}], - } + return _ToolResult({"status": "error", "content": [{"text": "target (peer_id) required"}]}) mesh = _any_mesh() if not mesh: - return { - "status": "error", - "content": [{"text": "No local robots on mesh"}], - } + return _ToolResult({"status": "error", "content": [{"text": "No local robots on mesh"}]}) name = mesh.on_stream(target) - return { - "status": "success", - "content": [{"text": f"👁️ Watching VLA stream from: {target}\nMessages in inbox['{name}']"}], - } + return _ToolResult({"status": "success", "content": [{"text": f"👁️ Watching VLA stream from: {target}\nMessages in inbox['{name}']"}]}) elif action == "inbox": mesh = _any_mesh() - if not mesh or not hasattr(mesh, "inbox"): - return { - "status": "success", - "content": [{"text": "No subscriptions active"}], - } + if not mesh or not hasattr(mesh, 'inbox'): + return _ToolResult({"status": "success", "content": [{"text": "No subscriptions active"}]}) text = f"📬 Inbox ({len(mesh.inbox)} subscriptions):\n" for name, msgs in mesh.inbox.items(): text += f" • {name}: {len(msgs)} messages\n" if msgs: last = msgs[-1] text += f" Latest: {json.dumps(last[1], default=str)[:200]}\n" - return {"status": "success", "content": [{"text": text}]} + return _ToolResult({"status": "success", "content": [{"text": text}]}) else: - return { - "status": "error", - "content": [ - { - "text": f"Unknown action: {action}. Try: peers, tell, send, broadcast, stop, emergency_stop, status, subscribe, watch, inbox" - } - ], - } + return _ToolResult({"status": "error", "content": [{"text": f"Unknown action: {action}. Try: peers, tell, send, broadcast, stop, emergency_stop, status, subscribe, watch, inbox"}]}) except ImportError as e: - return {"status": "error", "content": [{"text": f"Mesh unavailable: {e}"}]} + return _ToolResult({"status": "error", "content": [{"text": f"Mesh unavailable: {e}"}]}) except Exception as e: - return {"status": "error", "content": [{"text": f"Error: {e}"}]} + return _ToolResult({"status": "error", "content": [{"text": f"Error: {e}"}]}) def _any_mesh(): """Get any local mesh instance.""" from strands_robots.zenoh_mesh import _LOCAL_ROBOTS - return next(iter(_LOCAL_ROBOTS.values()), None) if _LOCAL_ROBOTS else None diff --git a/tests/test_device_connect_all_robots.py b/tests/test_device_connect_all_robots.py new file mode 100644 index 0000000..7b646a1 --- /dev/null +++ b/tests/test_device_connect_all_robots.py @@ -0,0 +1,704 @@ +"""Parametrized Device Connect tests across all 38 registered robots. + +Validates that RobotDeviceDriver and SimulationDeviceDriver work correctly +with every robot's specific configuration (joint counts, observation shapes, +identity, status, RPC delegation). Also tests multi-robot simulation scenarios, +edge cases, and robot_mesh dispatch with diverse device types. + +All external dependencies (Zenoh, LeRobot, device_connect_edge, strands) are mocked. +No Docker, GPU, or hardware required. +""" + +import asyncio +import json +import pathlib +import sys +from dataclasses import dataclass +from enum import Enum +from unittest.mock import MagicMock, patch + +import pytest + +# ── Mock heavy dependencies before importing ────────────────────── + +mock_device_connect_edge = MagicMock() +mock_drivers = MagicMock() + + +class _FakeDeviceDriver: + """Minimal stub so our drivers can subclass it.""" + + device_type = None + + def __init__(self): + self._transport = None + + def set_device(self, device): + pass + + @property + def transport(self): + return self._transport + + +def _passthrough_decorator(*args, **kwargs): + if len(args) == 1 and callable(args[0]): + return args[0] + + def wrapper(func): + for k, v in kwargs.items(): + setattr(func, f"_{k}", v) + return func + + return wrapper + + +mock_drivers.DeviceDriver = _FakeDeviceDriver +mock_drivers.rpc = _passthrough_decorator +mock_drivers.emit = _passthrough_decorator +mock_drivers.periodic = _passthrough_decorator +mock_drivers.on = _passthrough_decorator + +mock_types = MagicMock() + + +@dataclass +class FakeDeviceIdentity: + device_type: str = None + manufacturer: str = None + model: str = None + description: str = None + serial_number: str = None + firmware_version: str = None + arch: str = None + commissioning_comment: str = None + + +@dataclass +class FakeDeviceStatus: + availability: str = "idle" + busy_score: float = 0.0 + location: str = None + battery: int = None + online: bool = True + error_state: str = None + + +mock_types.DeviceIdentity = FakeDeviceIdentity +mock_types.DeviceStatus = FakeDeviceStatus + +_saved_modules = {} +_mock_keys = ( + "device_connect_edge", + "device_connect_edge.drivers", + "device_connect_edge.types", + "device_connect_edge.device", +) +_strands_dc_keys = [k for k in sys.modules if k.startswith("strands_robots.device_connect")] +for _key in list(_mock_keys) + _strands_dc_keys: + _saved_modules[_key] = sys.modules.get(_key) + +sys.modules["device_connect_edge"] = mock_device_connect_edge +sys.modules["device_connect_edge.drivers"] = mock_drivers +sys.modules["device_connect_edge.types"] = mock_types +sys.modules["device_connect_edge.device"] = MagicMock() + +mock_device_runtime = MagicMock() +mock_device_connect_edge.DeviceRuntime = mock_device_runtime + +from strands_robots.device_connect.robot_driver import RobotDeviceDriver # noqa: E402 +from strands_robots.device_connect.sim_driver import SimulationDeviceDriver # noqa: E402 + + +def teardown_module(): + """Restore real device_connect_edge modules.""" + for key, original in _saved_modules.items(): + if original is None: + sys.modules.pop(key, None) + else: + sys.modules[key] = original + for key in list(sys.modules): + if key.startswith("strands_robots.device_connect"): + sys.modules.pop(key, None) + + +# ── Load robot registry ────────────────────────────────────────── + +_REGISTRY_PATH = pathlib.Path(__file__).resolve().parents[1] / "strands_robots" / "registry" / "robots.json" +_REGISTRY = json.loads(_REGISTRY_PATH.read_text())["robots"] + +ALL_ROBOTS = [(name, info) for name, info in _REGISTRY.items()] +SIM_ROBOTS = [(name, info) for name, info in ALL_ROBOTS if "asset" in info] +REAL_ONLY_ROBOTS = [(name, info) for name, info in ALL_ROBOTS if "asset" not in info] + + +# ── Task state mocks ───────────────────────────────────────────── + + +class FakeTaskStatus(Enum): + IDLE = "idle" + RUNNING = "running" + COMPLETED = "completed" + STOPPED = "stopped" + ERROR = "error" + + +@dataclass +class FakeTaskState: + status: FakeTaskStatus = FakeTaskStatus.IDLE + instruction: str = "" + start_time: float = 0.0 + duration: float = 0.0 + step_count: int = 0 + error_message: str = "" + + +# ── Observation helper ─────────────────────────────────────────── + + +class _FakeArray: + """Mimics a numpy array with a .shape attribute.""" + + def __init__(self, shape): + self.shape = shape + + +def _generate_observation(joint_count, include_arrays=True): + """Generate a realistic observation dict for a robot with N joints.""" + obs = {} + for i in range(joint_count): + obs[f"joint_{i}"] = float(i) * 0.1 + if include_arrays: + obs["image"] = _FakeArray((480, 640, 3)) + obs["depth"] = _FakeArray((480, 640)) + return obs + + +# ── Mock factories ─────────────────────────────────────────────── + + +def _get_joint_count(info): + """Get joint count from registry info, defaulting to 6 for real-only robots.""" + return info.get("joints", 6) + + +def _make_mock_robot(name, info, task_status="idle"): + """Create a mock robot matching the registry entry's configuration.""" + joint_count = _get_joint_count(info) + robot = MagicMock() + robot.tool_name_str = name + robot._task_state = FakeTaskState( + status=FakeTaskStatus(task_status), + instruction="pick up the cube" if task_status == "running" else "", + step_count=42 if task_status == "running" else 0, + ) + robot.start_task.return_value = {"status": "success", "content": [{"text": "Task started"}]} + robot.stop_task.return_value = {"status": "success", "content": [{"text": "Task stopped"}]} + robot.get_task_status.return_value = {"status": "success", "content": [{"text": "Status info"}]} + + features = {f"joint_{i}": "float" for i in range(joint_count)} + robot.get_features.return_value = { + "status": "success", + "content": [{"json": {"observation_features": features, "action_features": features}}], + } + + robot.robot = MagicMock() + robot.robot.get_observation.return_value = _generate_observation(joint_count) + return robot + + +def _make_mock_sim(name, info, robots_in_world=None): + """Create a mock simulation matching the registry entry's configuration.""" + sim = MagicMock() + sim.tool_name_str = f"{name}_sim" + + world = MagicMock() + if robots_in_world is None: + robot_data = MagicMock() + robot_data.policy_running = False + robot_data.policy_steps = 0 + robot_data.policy_instruction = "" + world.robots = {name: robot_data} + else: + world.robots = robots_in_world + world.sim_time = 0.0 + world.step_count = 0 + sim._world = world + + sim.start_policy.return_value = {"status": "success", "content": [{"text": "Policy started"}]} + sim.get_state.return_value = {"status": "success", "content": [{"text": "State info"}]} + sim.get_features.return_value = {"status": "success", "content": [{"json": {"features": {}}}]} + sim.step.return_value = {"status": "success", "content": [{"text": "Stepped"}]} + sim.reset.return_value = {"status": "success", "content": [{"text": "Reset"}]} + return sim + + +# ── TestRobotDriverAllRobots ───────────────────────────────────── + + +class TestRobotDriverAllRobots: + """Parametrized tests for RobotDeviceDriver across all 38 registered robots.""" + + @pytest.mark.parametrize("robot_name,robot_info", ALL_ROBOTS, ids=[r[0] for r in ALL_ROBOTS]) + def test_identity(self, robot_name, robot_info): + robot = _make_mock_robot(robot_name, robot_info) + driver = RobotDeviceDriver(robot) + assert driver.identity.device_type == "strands_robot" + assert driver.identity.model == robot_name + assert driver.identity.manufacturer == "strands-robots" + + @pytest.mark.parametrize("robot_name,robot_info", ALL_ROBOTS, ids=[r[0] for r in ALL_ROBOTS]) + def test_status_idle(self, robot_name, robot_info): + robot = _make_mock_robot(robot_name, robot_info, task_status="idle") + driver = RobotDeviceDriver(robot) + assert driver.status.availability == "idle" + assert driver.status.busy_score == 0.0 + + @pytest.mark.parametrize("robot_name,robot_info", ALL_ROBOTS, ids=[r[0] for r in ALL_ROBOTS]) + def test_status_busy(self, robot_name, robot_info): + robot = _make_mock_robot(robot_name, robot_info, task_status="running") + driver = RobotDeviceDriver(robot) + assert driver.status.availability == "busy" + assert driver.status.busy_score == 1.0 + + @pytest.mark.parametrize("robot_name,robot_info", ALL_ROBOTS, ids=[r[0] for r in ALL_ROBOTS]) + def test_execute_delegates(self, robot_name, robot_info): + robot = _make_mock_robot(robot_name, robot_info) + driver = RobotDeviceDriver(robot) + result = asyncio.run( + driver.execute("pick up cube", "groot", 30.0, 0) + ) + robot.start_task.assert_called_once_with("pick up cube", "groot", None, "localhost", 30.0) + assert result["status"] == "success" + + @pytest.mark.parametrize("robot_name,robot_info", ALL_ROBOTS, ids=[r[0] for r in ALL_ROBOTS]) + def test_get_state_joint_count(self, robot_name, robot_info): + joint_count = _get_joint_count(robot_info) + robot = _make_mock_robot(robot_name, robot_info, task_status="running") + driver = RobotDeviceDriver(robot) + result = asyncio.run(driver.getState()) + assert "joints" in result + assert len(result["joints"]) == joint_count + + @pytest.mark.parametrize("robot_name,robot_info", ALL_ROBOTS, ids=[r[0] for r in ALL_ROBOTS]) + def test_get_state_filters_arrays(self, robot_name, robot_info): + robot = _make_mock_robot(robot_name, robot_info) + driver = RobotDeviceDriver(robot) + result = asyncio.run(driver.getState()) + if "joints" in result: + for key, value in result["joints"].items(): + assert isinstance(value, float), f"Non-float value for {key}: {type(value)}" + assert not key.startswith("image") and not key.startswith("depth") + + @pytest.mark.parametrize("robot_name,robot_info", ALL_ROBOTS, ids=[r[0] for r in ALL_ROBOTS]) + def test_get_state_task_info(self, robot_name, robot_info): + robot = _make_mock_robot(robot_name, robot_info, task_status="running") + driver = RobotDeviceDriver(robot) + result = asyncio.run(driver.getState()) + assert result["task_status"] == "running" + assert result["instruction"] == "pick up the cube" + assert result["step_count"] == 42 + + @pytest.mark.parametrize("robot_name,robot_info", ALL_ROBOTS, ids=[r[0] for r in ALL_ROBOTS]) + def test_get_features_delegates(self, robot_name, robot_info): + robot = _make_mock_robot(robot_name, robot_info) + driver = RobotDeviceDriver(robot) + result = asyncio.run(driver.getFeatures()) + robot.get_features.assert_called_once() + assert result["status"] == "success" + + +# ── TestSimDriverAllRobots ─────────────────────────────────────── + + +class TestSimDriverAllRobots: + """Parametrized tests for SimulationDeviceDriver across all 32 sim-capable robots.""" + + @pytest.mark.parametrize("robot_name,robot_info", SIM_ROBOTS, ids=[r[0] for r in SIM_ROBOTS]) + def test_identity(self, robot_name, robot_info): + sim = _make_mock_sim(robot_name, robot_info) + driver = SimulationDeviceDriver(sim) + assert driver.identity.device_type == "strands_sim" + assert driver.identity.model == f"{robot_name}_sim" + assert driver.identity.manufacturer == "strands-robots" + + @pytest.mark.parametrize("robot_name,robot_info", SIM_ROBOTS, ids=[r[0] for r in SIM_ROBOTS]) + def test_status_idle(self, robot_name, robot_info): + sim = _make_mock_sim(robot_name, robot_info) + driver = SimulationDeviceDriver(sim) + assert driver.status.availability == "idle" + + @pytest.mark.parametrize("robot_name,robot_info", SIM_ROBOTS, ids=[r[0] for r in SIM_ROBOTS]) + def test_status_busy(self, robot_name, robot_info): + sim = _make_mock_sim(robot_name, robot_info) + driver = SimulationDeviceDriver(sim) + sim._world.robots[robot_name].policy_running = True + assert driver.status.availability == "busy" + + @pytest.mark.parametrize("robot_name,robot_info", SIM_ROBOTS, ids=[r[0] for r in SIM_ROBOTS]) + def test_execute_auto_detects_robot(self, robot_name, robot_info): + sim = _make_mock_sim(robot_name, robot_info) + driver = SimulationDeviceDriver(sim) + result = asyncio.run( + driver.execute("pick up cube", "mock", 30.0, "") + ) + sim.start_policy.assert_called_once_with( + robot_name=robot_name, policy_provider="mock", instruction="pick up cube", duration=30.0 + ) + assert result["status"] == "success" + + @pytest.mark.parametrize("robot_name,robot_info", SIM_ROBOTS, ids=[r[0] for r in SIM_ROBOTS]) + def test_stop_sets_policy_running_false(self, robot_name, robot_info): + sim = _make_mock_sim(robot_name, robot_info) + sim._world.robots[robot_name].policy_running = True + driver = SimulationDeviceDriver(sim) + result = asyncio.run(driver.stop()) + assert sim._world.robots[robot_name].policy_running is False + assert result["status"] == "success" + + @pytest.mark.parametrize("robot_name,robot_info", SIM_ROBOTS, ids=[r[0] for r in SIM_ROBOTS]) + def test_step_delegates(self, robot_name, robot_info): + sim = _make_mock_sim(robot_name, robot_info) + driver = SimulationDeviceDriver(sim) + result = asyncio.run(driver.step(10)) + sim.step.assert_called_once_with(10) + assert result["status"] == "success" + + @pytest.mark.parametrize("robot_name,robot_info", SIM_ROBOTS, ids=[r[0] for r in SIM_ROBOTS]) + def test_reset_delegates(self, robot_name, robot_info): + sim = _make_mock_sim(robot_name, robot_info) + driver = SimulationDeviceDriver(sim) + result = asyncio.run(driver.reset()) + sim.reset.assert_called_once() + assert result["status"] == "success" + + +# ── TestRealOnlyRobots ─────────────────────────────────────────── + + +class TestRealOnlyRobots: + """Tests for real-only robots (no sim asset): lekiwi, reachy2, hope_jr, earthrover, omx, bi_openarm.""" + + @pytest.mark.parametrize("robot_name,robot_info", REAL_ONLY_ROBOTS, ids=[r[0] for r in REAL_ONLY_ROBOTS]) + def test_driver_creation(self, robot_name, robot_info): + robot = _make_mock_robot(robot_name, robot_info) + driver = RobotDeviceDriver(robot) + assert driver is not None + + @pytest.mark.parametrize("robot_name,robot_info", REAL_ONLY_ROBOTS, ids=[r[0] for r in REAL_ONLY_ROBOTS]) + def test_identity_no_asset(self, robot_name, robot_info): + robot = _make_mock_robot(robot_name, robot_info) + driver = RobotDeviceDriver(robot) + assert driver.identity.model == robot_name + assert driver.identity.device_type == "strands_robot" + + @pytest.mark.parametrize("robot_name,robot_info", REAL_ONLY_ROBOTS, ids=[r[0] for r in REAL_ONLY_ROBOTS]) + def test_execute_delegates(self, robot_name, robot_info): + robot = _make_mock_robot(robot_name, robot_info) + driver = RobotDeviceDriver(robot) + result = asyncio.run( + driver.execute("move forward", "mock", 10.0, 0) + ) + robot.start_task.assert_called_once() + assert result["status"] == "success" + + +# ── TestMultiRobotSimulation ───────────────────────────────────── + + +class TestMultiRobotSimulation: + """Tests for multi-robot simulation scenarios with diverse joint counts.""" + + def _make_robot_data(self, running=False): + robot_data = MagicMock() + robot_data.policy_running = running + robot_data.policy_steps = 0 + robot_data.policy_instruction = "" + return robot_data + + def test_mixed_joint_counts(self): + """so100 (13 joints) + unitree_g1 (46 joints) in one world.""" + robots_in_world = { + "so100": self._make_robot_data(), + "unitree_g1": self._make_robot_data(), + } + sim = _make_mock_sim("mixed", _REGISTRY["so100"], robots_in_world=robots_in_world) + driver = SimulationDeviceDriver(sim) + # Execute auto-detects first robot + asyncio.run( + driver.execute("test", "mock", 10.0, "") + ) + sim.start_policy.assert_called_once() + call_kwargs = sim.start_policy.call_args + assert call_kwargs[1]["robot_name"] in ("so100", "unitree_g1") + + def test_stop_all_policies(self): + """Stop sets policy_running=False on all robots in the world.""" + robots_in_world = { + "so100": self._make_robot_data(running=True), + "panda": self._make_robot_data(running=True), + "unitree_go2": self._make_robot_data(running=True), + } + sim = _make_mock_sim("fleet", _REGISTRY["so100"], robots_in_world=robots_in_world) + driver = SimulationDeviceDriver(sim) + asyncio.run(driver.stop()) + for name, robot_data in robots_in_world.items(): + assert robot_data.policy_running is False, f"{name} still running" + + def test_execute_with_explicit_robot_name(self): + """Target a specific robot in a multi-robot sim.""" + robots_in_world = { + "so100": self._make_robot_data(), + "unitree_g1": self._make_robot_data(), + } + sim = _make_mock_sim("multi", _REGISTRY["so100"], robots_in_world=robots_in_world) + driver = SimulationDeviceDriver(sim) + asyncio.run( + driver.execute("walk forward", "mock", 30.0, "unitree_g1") + ) + sim.start_policy.assert_called_once_with( + robot_name="unitree_g1", policy_provider="mock", instruction="walk forward", duration=30.0 + ) + + def test_execute_empty_world(self): + """Returns error when no robots in simulation.""" + sim = _make_mock_sim("empty", _REGISTRY["so100"], robots_in_world={}) + driver = SimulationDeviceDriver(sim) + result = asyncio.run( + driver.execute("test", "mock", 10.0, "") + ) + assert result["status"] == "error" + + +# ── TestEdgeCases ──────────────────────────────────────────────── + + +class TestEdgeCases: + """Edge case tests for driver robustness.""" + + def test_observation_all_arrays(self): + """Observation with only array values → joints dict is empty.""" + robot = _make_mock_robot("so100", _REGISTRY["so100"]) + robot.robot.get_observation.return_value = { + "camera_front": _FakeArray((480, 640, 3)), + "camera_wrist": _FakeArray((480, 640, 3)), + "depth": _FakeArray((480, 640)), + } + driver = RobotDeviceDriver(robot) + result = asyncio.run(driver.getState()) + assert result.get("joints", {}) == {} + + def test_observation_empty(self): + """Empty observation → no joints key or empty joints.""" + robot = _make_mock_robot("so100", _REGISTRY["so100"]) + robot.robot.get_observation.return_value = {} + driver = RobotDeviceDriver(robot) + result = asyncio.run(driver.getState()) + assert result.get("joints", {}) == {} + + def test_observation_raises(self): + """get_observation() throws → getState still returns task info.""" + robot = _make_mock_robot("so100", _REGISTRY["so100"], task_status="running") + robot.robot.get_observation.side_effect = RuntimeError("hardware disconnected") + driver = RobotDeviceDriver(robot) + result = asyncio.run(driver.getState()) + assert result["task_status"] == "running" + assert result["instruction"] == "pick up the cube" + assert "joints" not in result + + def test_no_inner_robot(self): + """robot.robot is None → getState skips observation.""" + robot = _make_mock_robot("so100", _REGISTRY["so100"], task_status="running") + robot.robot = None + driver = RobotDeviceDriver(robot) + result = asyncio.run(driver.getState()) + assert result["task_status"] == "running" + assert "joints" not in result + + def test_no_task_state(self): + """_task_state is None → status is idle, getState has no task info.""" + robot = _make_mock_robot("so100", _REGISTRY["so100"]) + robot._task_state = None + driver = RobotDeviceDriver(robot) + assert driver.status.availability == "idle" + result = asyncio.run(driver.getState()) + assert "task_status" not in result + + def test_float_conversion_failure(self): + """Non-numeric scalar in observation → graceful handling via exception catch.""" + robot = _make_mock_robot("so100", _REGISTRY["so100"]) + robot.robot.get_observation.return_value = { + "joint_0": 0.5, + "metadata": "not_a_number", + } + driver = RobotDeviceDriver(robot) + # float("not_a_number") raises ValueError; the driver wraps get_observation + # in a try/except, so it either filters it out or catches the error + result = asyncio.run(driver.getState()) + # Either joints has only joint_0, or the whole observation was skipped + if "joints" in result: + assert "metadata" not in result["joints"] or isinstance(result["joints"].get("metadata"), float) + + def test_max_joint_robot(self): + """unitree_g1 (46 joints) — all joints appear in getState.""" + info = _REGISTRY["unitree_g1"] + robot = _make_mock_robot("unitree_g1", info, task_status="running") + driver = RobotDeviceDriver(robot) + result = asyncio.run(driver.getState()) + assert len(result["joints"]) == 46 + + def test_min_joint_robot(self): + """koch (7 joints) — correct joint count.""" + info = _REGISTRY["koch"] + robot = _make_mock_robot("koch", info, task_status="running") + driver = RobotDeviceDriver(robot) + result = asyncio.run(driver.getState()) + assert len(result["joints"]) == 7 + + +# ── TestRobotMeshDispatchAllTypes ──────────────────────────────── + + +# Mock strands @tool decorator and device_connect_agent_tools for robot_mesh imports +_mesh_mock_keys = ("strands", "device_connect_agent_tools", "device_connect_agent_tools.connection") +for _key in _mesh_mock_keys: + if _key not in sys.modules: + if _key == "strands": + _m = MagicMock() + _m.tool = lambda fn: fn + sys.modules[_key] = _m + else: + sys.modules.setdefault(_key, MagicMock()) + + +class _FakeConnection: + """Fake connection with all methods the dispatch uses.""" + + def __init__(self, devices=None): + self.zone = "default" + self._devices = devices or [] + self._invoke_results = {} + self._inbox = {} + self._sync_subs = {} + + def list_devices(self, device_type=None): + if device_type: + return [d for d in self._devices if d.get("device_type") == device_type] + return list(self._devices) + + def invoke(self, device_id, function, params=None, timeout=30.0): + key = (device_id, function) + if key in self._invoke_results: + return self._invoke_results[key] + return {"result": {"status": "ok"}} + + def broadcast(self, function, params=None, timeout=5.0): + results = [] + for d in self._devices: + try: + r = self.invoke(d["device_id"], function, params, timeout=timeout) + results.append({"device_id": d["device_id"], "result": r}) + except Exception as e: + results.append({"device_id": d["device_id"], "error": str(e)}) + return results + + def subscribe_buffered(self, subject, name=None): + name = name or subject + self._inbox[name] = [] + self._sync_subs[name] = True + return name + + def get_inbox(self, name=None): + if name is not None: + return {name: list(self._inbox.get(name, []))} + return {k: list(v) for k, v in self._inbox.items()} + + +# Build a diverse fleet of sample devices from the registry +_CATEGORY_REPRESENTATIVES = { + "arm": ("so100", "strands_robot"), + "bimanual": ("aloha", "strands_robot"), + "hand": ("shadow_hand", "strands_robot"), + "humanoid": ("unitree_g1", "strands_sim"), + "expressive": ("reachy_mini", "strands_robot"), + "mobile": ("unitree_go2", "strands_sim"), + "mobile_manip": ("google_robot", "strands_sim"), +} + +DIVERSE_DEVICES = [] +for category, (robot_name, device_type) in _CATEGORY_REPRESENTATIVES.items(): + DIVERSE_DEVICES.append({ + "device_id": f"{robot_name}-{category}-1", + "device_type": device_type, + "status": {"availability": "idle"}, + "functions": [{"name": "execute"}, {"name": "stop"}, {"name": "getStatus"}], + "events": ["taskStarted", "taskComplete"] if device_type == "strands_robot" else ["stateUpdate"], + }) + + +class TestRobotMeshDispatchAllTypes: + """Tests robot_mesh dispatch with a diverse fleet spanning all robot categories.""" + + def _get_dispatch(self): + from strands_robots.tools.robot_mesh import _device_connect_dispatch + return _device_connect_dispatch + + def _call(self, dispatch, conn, action, **kwargs): + defaults = dict( + target="", instruction="", command="", + policy_provider="mock", policy_port=0, + duration=30.0, timeout=5.0, + ) + defaults.update(kwargs) + with patch("device_connect_agent_tools.connection.get_connection", return_value=conn): + return dispatch(action, **{k: defaults[k] for k in [ + "target", "instruction", "command", + "policy_provider", "policy_port", "duration", "timeout", + ]}) + + def test_peers_lists_all_categories(self): + conn = _FakeConnection(devices=DIVERSE_DEVICES) + dispatch = self._get_dispatch() + result = self._call(dispatch, conn, "peers") + assert result["status"] == "success" + text = result["content"][0]["text"] + assert f"{len(DIVERSE_DEVICES)} device(s)" in text + for device in DIVERSE_DEVICES: + assert device["device_id"] in text + + def test_tell_arm_robot(self): + conn = _FakeConnection(devices=DIVERSE_DEVICES) + dispatch = self._get_dispatch() + result = self._call(dispatch, conn, "tell", target="so100-arm-1", instruction="pick up cube") + assert result["status"] == "success" + assert "so100-arm-1" in result["content"][0]["text"] + + def test_tell_humanoid_sim(self): + conn = _FakeConnection(devices=DIVERSE_DEVICES) + dispatch = self._get_dispatch() + result = self._call(dispatch, conn, "tell", target="unitree_g1-humanoid-1", instruction="walk forward") + assert result["status"] == "success" + assert "unitree_g1-humanoid-1" in result["content"][0]["text"] + + def test_tell_mobile_robot(self): + conn = _FakeConnection(devices=DIVERSE_DEVICES) + dispatch = self._get_dispatch() + result = self._call(dispatch, conn, "tell", target="unitree_go2-mobile-1", instruction="navigate to door") + assert result["status"] == "success" + assert "unitree_go2-mobile-1" in result["content"][0]["text"] + + def test_emergency_stop_all_types(self): + conn = _FakeConnection(devices=DIVERSE_DEVICES) + dispatch = self._get_dispatch() + result = self._call(dispatch, conn, "emergency_stop") + assert result["status"] == "success" + text = result["content"][0]["text"] + assert "E-STOP" in text + assert f"{len(DIVERSE_DEVICES)}/{len(DIVERSE_DEVICES)}" in text + + def test_status_mixed_fleet(self): + conn = _FakeConnection(devices=DIVERSE_DEVICES) + dispatch = self._get_dispatch() + result = self._call(dispatch, conn, "status") + assert result["status"] == "success" + assert f"{len(DIVERSE_DEVICES)} device(s)" in result["content"][0]["text"] diff --git a/tests/test_device_connect_drivers.py b/tests/test_device_connect_drivers.py new file mode 100644 index 0000000..1a49ff6 --- /dev/null +++ b/tests/test_device_connect_drivers.py @@ -0,0 +1,855 @@ +"""Unit tests for Device Connect DeviceDriver adapters. + +Tests RobotDeviceDriver, SimulationDeviceDriver, ReachyMiniDriver, +init_device_connect(), and the updated robot_mesh tool. + +All external dependencies (Zenoh, LeRobot, device_connect_edge, strands) are mocked. +""" + +import asyncio +import json +import math +import sys +import unittest +from dataclasses import dataclass +from enum import Enum +from unittest.mock import AsyncMock, MagicMock, patch + + +# ── Mock heavy dependencies before importing ────────────────────── + +# Mock device_connect_edge +mock_device_connect_edge = MagicMock() +mock_drivers = MagicMock() + + +class _FakeDeviceDriver: + """Minimal stub so our drivers can subclass it.""" + device_type = None + + def __init__(self): + self._transport = None + + def set_device(self, device): + pass + + @property + def transport(self): + return self._transport + + +# Make @rpc, @emit, @periodic, @on pass-through decorators +def _passthrough_decorator(*args, **kwargs): + if len(args) == 1 and callable(args[0]): + return args[0] + def wrapper(func): + # Tag the function so tests can verify decorator usage + for k, v in kwargs.items(): + setattr(func, f"_{k}", v) + return func + return wrapper + + +mock_drivers.DeviceDriver = _FakeDeviceDriver +mock_drivers.rpc = _passthrough_decorator +mock_drivers.emit = _passthrough_decorator +mock_drivers.periodic = _passthrough_decorator +mock_drivers.on = _passthrough_decorator + +mock_types = MagicMock() + + +@dataclass +class FakeDeviceIdentity: + device_type: str = None + manufacturer: str = None + model: str = None + description: str = None + serial_number: str = None + firmware_version: str = None + arch: str = None + commissioning_comment: str = None + + +@dataclass +class FakeDeviceStatus: + availability: str = "idle" + busy_score: float = 0.0 + location: str = None + battery: int = None + online: bool = True + error_state: str = None + + +mock_types.DeviceIdentity = FakeDeviceIdentity +mock_types.DeviceStatus = FakeDeviceStatus + +# Save originals so we can restore after this module's tests run +_saved_modules = {} +_mock_keys = ("device_connect_edge", "device_connect_edge.drivers", + "device_connect_edge.types", "device_connect_edge.device") +# Also track strands_robots.device_connect submodules that will be imported +# with the mocked base class — they need to be purged so later tests re-import +# with the real base class. +_strands_dc_keys = [k for k in sys.modules if k.startswith("strands_robots.device_connect")] +for _key in list(_mock_keys) + _strands_dc_keys: + _saved_modules[_key] = sys.modules.get(_key) + +sys.modules["device_connect_edge"] = mock_device_connect_edge +sys.modules["device_connect_edge.drivers"] = mock_drivers +sys.modules["device_connect_edge.types"] = mock_types +sys.modules["device_connect_edge.device"] = MagicMock() + +# Mock DeviceRuntime +mock_device_runtime = MagicMock() +mock_device_connect_edge.DeviceRuntime = mock_device_runtime + +# Now import our modules +from strands_robots.device_connect.robot_driver import RobotDeviceDriver +from strands_robots.device_connect.sim_driver import SimulationDeviceDriver + + +def teardown_module(): + """Restore real device_connect_edge modules so other test files are not affected. + + Also purge cached strands_robots.device_connect submodules that were imported + with the mock base class, so later test files get fresh imports with the real base. + """ + # Restore device_connect_edge modules + for key, original in _saved_modules.items(): + if original is None: + sys.modules.pop(key, None) + else: + sys.modules[key] = original + # Purge ALL strands_robots.device_connect submodules — they were imported + # with the mock DeviceDriver base class and must be re-imported fresh. + for key in list(sys.modules): + if key.startswith("strands_robots.device_connect"): + sys.modules.pop(key, None) + + +# ── Task state mocks ────────────────────────────────────────────── + +class FakeTaskStatus(Enum): + IDLE = "idle" + RUNNING = "running" + COMPLETED = "completed" + STOPPED = "stopped" + ERROR = "error" + + +@dataclass +class FakeTaskState: + status: FakeTaskStatus = FakeTaskStatus.IDLE + instruction: str = "" + start_time: float = 0.0 + duration: float = 0.0 + step_count: int = 0 + error_message: str = "" + + +def _make_mock_robot(tool_name="so100", task_status="idle"): + robot = MagicMock() + robot.tool_name_str = tool_name + robot._task_state = FakeTaskState( + status=FakeTaskStatus(task_status), + instruction="pick up the cube" if task_status == "running" else "", + step_count=42 if task_status == "running" else 0, + ) + robot.start_task.return_value = {"status": "success", "content": [{"text": "Task started"}]} + robot.stop_task.return_value = {"status": "success", "content": [{"text": "Task stopped"}]} + robot.get_task_status.return_value = {"status": "success", "content": [{"text": "Status info"}]} + robot.get_features.return_value = { + "status": "success", + "content": [{"json": {"observation_features": {"joint1": "float"}, "action_features": {"joint1": "float"}}}], + } + # Mock inner lerobot robot + robot.robot = MagicMock() + robot.robot.get_observation.return_value = {"joint1": 0.5, "joint2": -1.2} + return robot + + +def _make_mock_sim(tool_name="so100_sim"): + sim = MagicMock() + sim.tool_name_str = tool_name + + # SimWorld-like structure + robot_data = MagicMock() + robot_data.policy_running = False + robot_data.policy_steps = 0 + robot_data.policy_instruction = "" + + world = MagicMock() + world.robots = {"so100": robot_data} + world.sim_time = 0.0 + world.step_count = 0 + sim._world = world + + sim.start_policy.return_value = {"status": "success", "content": [{"text": "Policy started"}]} + sim.get_state.return_value = {"status": "success", "content": [{"text": "State info"}]} + sim.get_features.return_value = {"status": "success", "content": [{"json": {"features": {}}}]} + sim.step.return_value = {"status": "success", "content": [{"text": "Stepped"}]} + sim.reset.return_value = {"status": "success", "content": [{"text": "Reset"}]} + return sim + + +# ── TestRobotDeviceDriver ───────────────────────────────────────── + +class TestRobotDeviceDriver(unittest.TestCase): + + def test_identity(self): + robot = _make_mock_robot(tool_name="so100") + driver = RobotDeviceDriver(robot) + identity = driver.identity + self.assertEqual(identity.device_type, "strands_robot") + self.assertEqual(identity.manufacturer, "strands-robots") + self.assertEqual(identity.model, "so100") + + def test_status_idle(self): + robot = _make_mock_robot(task_status="idle") + driver = RobotDeviceDriver(robot) + status = driver.status + self.assertEqual(status.availability, "idle") + self.assertEqual(status.busy_score, 0.0) + + def test_status_busy(self): + robot = _make_mock_robot(task_status="running") + driver = RobotDeviceDriver(robot) + status = driver.status + self.assertEqual(status.availability, "busy") + self.assertEqual(status.busy_score, 1.0) + + def test_execute_rpc(self): + robot = _make_mock_robot() + driver = RobotDeviceDriver(robot) + result = asyncio.run( + driver.execute("pick up cube", "groot", 30.0, 0) + ) + robot.start_task.assert_called_once_with("pick up cube", "groot", None, "localhost", 30.0) + self.assertEqual(result["status"], "success") + + def test_execute_rpc_with_port(self): + robot = _make_mock_robot() + driver = RobotDeviceDriver(robot) + asyncio.run( + driver.execute("wave", "groot", 10.0, 50051) + ) + robot.start_task.assert_called_once_with("wave", "groot", 50051, "localhost", 10.0) + + def test_stop_rpc(self): + robot = _make_mock_robot() + driver = RobotDeviceDriver(robot) + result = asyncio.run(driver.stop()) + robot.stop_task.assert_called_once() + self.assertEqual(result["status"], "success") + + def test_get_status_rpc(self): + robot = _make_mock_robot() + driver = RobotDeviceDriver(robot) + result = asyncio.run(driver.getStatus()) + robot.get_task_status.assert_called_once() + self.assertEqual(result["status"], "success") + + def test_get_features_rpc(self): + robot = _make_mock_robot() + driver = RobotDeviceDriver(robot) + result = asyncio.run(driver.getFeatures()) + robot.get_features.assert_called_once() + self.assertEqual(result["status"], "success") + + def test_get_state_rpc(self): + robot = _make_mock_robot(task_status="running") + driver = RobotDeviceDriver(robot) + result = asyncio.run(driver.getState()) + self.assertEqual(result["task_status"], "running") + self.assertEqual(result["instruction"], "pick up the cube") + self.assertEqual(result["step_count"], 42) + # Joints should be read from inner robot + self.assertIn("joints", result) + self.assertAlmostEqual(result["joints"]["joint1"], 0.5) + + def test_connect_disconnect_noop(self): + robot = _make_mock_robot() + driver = RobotDeviceDriver(robot) + asyncio.run(driver.connect()) + asyncio.run(driver.disconnect()) + # Should not raise + + def test_emergency_stop_handler(self): + robot = _make_mock_robot() + driver = RobotDeviceDriver(robot) + asyncio.run( + driver.onEmergencyStop("other-robot", "emergencyStop", {"reason": "test"}) + ) + robot.stop_task.assert_called_once() + + +# ── TestSimulationDeviceDriver ──────────────────────────────────── + +class TestSimulationDeviceDriver(unittest.TestCase): + + def test_identity(self): + sim = _make_mock_sim(tool_name="mujoco_sim") + driver = SimulationDeviceDriver(sim) + identity = driver.identity + self.assertEqual(identity.device_type, "strands_sim") + self.assertEqual(identity.manufacturer, "strands-robots") + self.assertEqual(identity.model, "mujoco_sim") + + def test_status_idle(self): + sim = _make_mock_sim() + driver = SimulationDeviceDriver(sim) + status = driver.status + self.assertEqual(status.availability, "idle") + + def test_status_busy(self): + sim = _make_mock_sim() + sim._world.robots["so100"].policy_running = True + driver = SimulationDeviceDriver(sim) + status = driver.status + self.assertEqual(status.availability, "busy") + + def test_identity_sim_type(self): + sim = _make_mock_sim() + driver = SimulationDeviceDriver(sim) + self.assertEqual(driver.device_type, "strands_sim") + + def test_execute_rpc(self): + sim = _make_mock_sim() + driver = SimulationDeviceDriver(sim) + result = asyncio.run( + driver.execute("pick up cube", "mock", 10.0) + ) + sim.start_policy.assert_called_once_with( + robot_name="so100", + policy_provider="mock", + instruction="pick up cube", + duration=10.0, + ) + self.assertEqual(result["status"], "success") + + def test_execute_with_robot_name(self): + sim = _make_mock_sim() + driver = SimulationDeviceDriver(sim) + asyncio.run( + driver.execute("wave", "mock", 5.0, robot_name="arm2") + ) + sim.start_policy.assert_called_once_with( + robot_name="arm2", + policy_provider="mock", + instruction="wave", + duration=5.0, + ) + + def test_stop_rpc(self): + sim = _make_mock_sim() + sim._world.robots["so100"].policy_running = True + driver = SimulationDeviceDriver(sim) + result = asyncio.run(driver.stop()) + self.assertEqual(result["status"], "success") + self.assertFalse(sim._world.robots["so100"].policy_running) + + def test_get_status_rpc(self): + sim = _make_mock_sim() + driver = SimulationDeviceDriver(sim) + result = asyncio.run(driver.getStatus()) + sim.get_state.assert_called_once() + + def test_get_features_rpc(self): + sim = _make_mock_sim() + driver = SimulationDeviceDriver(sim) + result = asyncio.run(driver.getFeatures()) + sim.get_features.assert_called_once() + + def test_step_rpc(self): + sim = _make_mock_sim() + driver = SimulationDeviceDriver(sim) + result = asyncio.run(driver.step(10)) + sim.step.assert_called_once_with(10) + + def test_reset_rpc(self): + sim = _make_mock_sim() + driver = SimulationDeviceDriver(sim) + result = asyncio.run(driver.reset()) + sim.reset.assert_called_once() + + def test_emergency_stop_handler(self): + sim = _make_mock_sim() + sim._world.robots["so100"].policy_running = True + driver = SimulationDeviceDriver(sim) + asyncio.run( + driver.onEmergencyStop("other-device", "emergencyStop", {"reason": "test"}) + ) + self.assertFalse(sim._world.robots["so100"].policy_running) + + +# ── TestReachyMiniDriver ───────────────────────────────────────── + +class TestReachyMiniDriver(unittest.TestCase): + + def setUp(self): + # Mock reachy_transport module but keep real ZenohLink/WebSocketLink + from strands_robots.device_connect.reachy_transport import ZenohLink, WebSocketLink + self.mock_transport_mod = MagicMock() + self.mock_transport_mod.api.return_value = {"status": "ok"} + self.mock_transport_mod.rpy_to_pose.side_effect = lambda *args, **kwargs: [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]] + self.mock_transport_mod.identity_pose.return_value = [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]] + self.mock_transport_mod.ZenohLink = ZenohLink + self.mock_transport_mod.WebSocketLink = WebSocketLink + + self.transport_patcher = patch.dict(sys.modules, { + "strands_robots.device_connect.reachy_transport": self.mock_transport_mod, + }) + self.transport_patcher.start() + + # Re-import to pick up mocks + if "strands_robots.device_connect.reachy_mini_driver" in sys.modules: + del sys.modules["strands_robots.device_connect.reachy_mini_driver"] + from strands_robots.device_connect.reachy_mini_driver import ReachyMiniDriver + self.ReachyMiniDriver = ReachyMiniDriver + + def tearDown(self): + self.transport_patcher.stop() + + def _make_driver(self, **kwargs): + """Create a driver with a mocked Device Connect transport and ZenohLink-like _hw.""" + driver = self.ReachyMiniDriver(**kwargs) + mock_transport = AsyncMock() + mock_transport.publish = AsyncMock() + mock_transport.subscribe = AsyncMock() + driver._transport = mock_transport + + # Create a HW link that delegates to mock_transport (like ZenohLink does) + prefix = driver._prefix + class _MockZenohLink: + async def send_cmd(self, cmd): + await mock_transport.publish( + f"{prefix}/command", json.dumps(cmd).encode() + ) + async def start(self, on_joints, on_imu): + await mock_transport.subscribe(f"{prefix}/joint_positions", on_joints) + await mock_transport.subscribe(f"{prefix}/imu_data", on_imu) + async def stop(self): + pass + driver._hw = _MockZenohLink() + return driver + + def test_identity(self): + driver = self.ReachyMiniDriver(host="192.168.1.50") + identity = driver.identity + self.assertEqual(identity.device_type, "reachy_mini") + self.assertEqual(identity.manufacturer, "Pollen Robotics") + self.assertIn("192.168.1.50", identity.model) + + def test_look_rpc(self): + driver = self._make_driver() + result = asyncio.run( + driver.look(pitch=15, yaw=30) + ) + self.assertEqual(result["status"], "success") + self.assertEqual(result["pitch"], 15) + self.assertEqual(result["yaw"], 30) + # Verify transport.publish was called with the command topic + driver._transport.publish.assert_awaited() + topic = driver._transport.publish.call_args[0][0] + self.assertEqual(topic, "reachy_mini/command") + + def test_antennas_rpc(self): + driver = self._make_driver() + result = asyncio.run( + driver.antennas(left=45, right=-30) + ) + self.assertEqual(result["status"], "success") + self.assertEqual(result["left"], 45) + self.assertEqual(result["right"], -30) + driver._transport.publish.assert_awaited() + + def test_get_joints_rpc(self): + driver = self._make_driver() + # Pre-populate cached joint data + driver._latest_joints = { + "head_joint_positions": [0.1, 0.2, 0.3], + "antennas_joint_positions": [0.5, -0.5], + } + result = asyncio.run(driver.getJoints()) + self.assertEqual(result["status"], "success") + self.assertIn("head", result) + self.assertIn("antennas", result) + + def test_get_joints_no_data(self): + driver = self._make_driver() + result = asyncio.run(driver.getJoints()) + self.assertEqual(result["status"], "error") + + def test_get_imu_rpc(self): + driver = self._make_driver() + driver._latest_imu = { + "accelerometer": [0.1, 0.2, 9.8], + "gyroscope": [0.0, 0.0, 0.0], + "quaternion": [1, 0, 0, 0], + "temperature": 35.2, + } + result = asyncio.run(driver.getImu()) + self.assertEqual(result["status"], "success") + self.assertAlmostEqual(result["temperature"], 35.2) + + def test_get_imu_no_data(self): + driver = self._make_driver() + result = asyncio.run(driver.getImu()) + self.assertEqual(result["status"], "error") + + def test_enable_motors_rpc(self): + driver = self._make_driver() + result = asyncio.run(driver.enableMotors()) + self.assertEqual(result["status"], "success") + self.assertEqual(result["enabled"], "all") + driver._transport.publish.assert_awaited() + + def test_disable_motors_rpc(self): + driver = self._make_driver() + result = asyncio.run( + driver.disableMotors(motor_ids="head_pitch,head_yaw") + ) + self.assertEqual(result["status"], "success") + self.assertEqual(result["disabled"], "head_pitch,head_yaw") + driver._transport.publish.assert_awaited() + + def test_play_move_rpc(self): + driver = self._make_driver() + result = asyncio.run( + driver.playMove("happy", library="emotions") + ) + self.assertEqual(result["status"], "success") + self.assertEqual(result["move"], "happy") + + def test_nod_rpc(self): + driver = self._make_driver() + result = asyncio.run(driver.nod()) + self.assertEqual(result["status"], "success") + self.assertEqual(result["expression"], "nod") + # nod sends multiple publish calls (head_pose animation) + self.assertGreater(driver._transport.publish.await_count, 1) + + def test_shake_rpc(self): + driver = self._make_driver() + result = asyncio.run(driver.shake()) + self.assertEqual(result["status"], "success") + self.assertEqual(result["expression"], "shake") + + def test_happy_rpc(self): + driver = self._make_driver() + result = asyncio.run(driver.happy()) + self.assertEqual(result["status"], "success") + self.assertEqual(result["expression"], "happy") + + def test_wake_up_rpc(self): + driver = self._make_driver() + result = asyncio.run(driver.wakeUp()) + self.assertEqual(result["status"], "success") + + def test_sleep_rpc(self): + driver = self._make_driver() + result = asyncio.run(driver.sleep()) + self.assertEqual(result["status"], "success") + + def test_stop_motion_rpc(self): + driver = self._make_driver() + result = asyncio.run(driver.stopMotion()) + self.assertEqual(result["status"], "success") + + def test_daemon_status_rpc(self): + self.mock_transport_mod.api.return_value = {"state": "ready", "version": "1.0"} + driver = self._make_driver() + result = asyncio.run(driver.getDaemonStatus()) + self.assertEqual(result["status"], "success") + self.assertEqual(result["state"], "ready") + + @patch("strands_robots.device_connect.reachy_mini_driver.api") + def test_connect_subscribes(self, mock_api): + # Simulate wireless variant (wireless_version=True) + mock_api.return_value = {"wireless_version": True} + driver = self.ReachyMiniDriver() + mock_transport = AsyncMock() + mock_transport.publish = AsyncMock() + mock_transport.subscribe = AsyncMock() + driver._transport = mock_transport + # connect() creates ZenohLink and subscribes via transport + asyncio.run(driver.connect()) + self.assertEqual(mock_transport.subscribe.await_count, 2) + topics = [call[0][0] for call in mock_transport.subscribe.call_args_list] + self.assertIn("reachy_mini/joint_positions", topics) + self.assertIn("reachy_mini/imu_data", topics) + + def test_disconnect(self): + driver = self._make_driver() + asyncio.run(driver.disconnect()) + + def test_emergency_stop_handler(self): + driver = self._make_driver() + asyncio.run( + driver.onEmergencyStop("other-device", "emergencyStop", {"reason": "test"}) + ) + # stopMotion calls REST API, disableMotors calls transport.publish + driver._transport.publish.assert_awaited() + + def test_command_payload_format(self): + """Verify that transport.publish receives correct JSON payload.""" + driver = self._make_driver() + asyncio.run(driver.enableMotors()) + _, payload_bytes = driver._transport.publish.call_args[0] + payload = json.loads(payload_bytes.decode()) + self.assertTrue(payload["torque"]) + self.assertIsNone(payload["ids"]) + + +# ── TestInitDeviceConnect ───────────────────────────────────────── + +class TestInitDeviceConnect(unittest.TestCase): + + @patch("strands_robots.device_connect.DeviceRuntime") + def test_creates_robot_driver(self, MockRuntime): + from strands_robots.device_connect import init_device_connect + mock_runtime = MagicMock() + mock_runtime.run = AsyncMock() + mock_runtime.set_heartbeat_provider = MagicMock() + MockRuntime.return_value = mock_runtime + + robot = _make_mock_robot() + loop = asyncio.new_event_loop() + result = loop.run_until_complete(init_device_connect(robot, peer_id="test-1", peer_type="robot")) + loop.close() + + # Verify DeviceRuntime was created with a RobotDeviceDriver + call_kwargs = MockRuntime.call_args + self.assertIsNotNone(call_kwargs) + driver = call_kwargs.kwargs.get("driver") or call_kwargs[1].get("driver") + self.assertEqual(type(driver).__name__, "RobotDeviceDriver") + self.assertEqual(driver._robot, robot) + + @patch("strands_robots.device_connect.DeviceRuntime") + def test_creates_sim_driver(self, MockRuntime): + from strands_robots.device_connect import init_device_connect + mock_runtime = MagicMock() + mock_runtime.run = AsyncMock() + mock_runtime.set_heartbeat_provider = MagicMock() + MockRuntime.return_value = mock_runtime + + sim = _make_mock_sim() + loop = asyncio.new_event_loop() + result = loop.run_until_complete(init_device_connect(sim, peer_id="test-sim", peer_type="sim")) + loop.close() + + call_kwargs = MockRuntime.call_args + driver = call_kwargs.kwargs.get("driver") or call_kwargs[1].get("driver") + self.assertEqual(type(driver).__name__, "SimulationDeviceDriver") + + @patch("strands_robots.device_connect.DeviceRuntime") + def test_generates_device_id(self, MockRuntime): + from strands_robots.device_connect import init_device_connect + mock_runtime = MagicMock() + mock_runtime.run = AsyncMock() + mock_runtime.set_heartbeat_provider = MagicMock() + MockRuntime.return_value = mock_runtime + + robot = _make_mock_robot(tool_name="so100") + loop = asyncio.new_event_loop() + result = loop.run_until_complete(init_device_connect(robot)) + loop.close() + + call_kwargs = MockRuntime.call_args + device_id = call_kwargs.kwargs.get("device_id") or call_kwargs[1].get("device_id") + self.assertTrue(device_id.startswith("so100-")) + + @patch("strands_robots.device_connect.DeviceRuntime") + def test_explicit_device_id(self, MockRuntime): + from strands_robots.device_connect import init_device_connect + mock_runtime = MagicMock() + mock_runtime.run = AsyncMock() + mock_runtime.set_heartbeat_provider = MagicMock() + MockRuntime.return_value = mock_runtime + + robot = _make_mock_robot() + loop = asyncio.new_event_loop() + result = loop.run_until_complete(init_device_connect(robot, peer_id="my-robot-42")) + loop.close() + + call_kwargs = MockRuntime.call_args + device_id = call_kwargs.kwargs.get("device_id") or call_kwargs[1].get("device_id") + self.assertEqual(device_id, "my-robot-42") + + @patch("strands_robots.device_connect.DeviceRuntime") + def test_sets_heartbeat_provider(self, MockRuntime): + from strands_robots.device_connect import init_device_connect + mock_runtime = MagicMock() + mock_runtime.run = AsyncMock() + mock_runtime.set_heartbeat_provider = MagicMock() + MockRuntime.return_value = mock_runtime + + robot = _make_mock_robot() + loop = asyncio.new_event_loop() + result = loop.run_until_complete(init_device_connect(robot, peer_id="test-hb")) + loop.close() + + mock_runtime.set_heartbeat_provider.assert_called_once() + + +# ── TestEmergencyStop (cross-driver) ────────────────────────────── + +class TestEmergencyStop(unittest.TestCase): + + def test_robot_reacts_to_emergency_stop(self): + robot = _make_mock_robot() + driver = RobotDeviceDriver(robot) + asyncio.run( + driver.onEmergencyStop("reachy-1", "emergencyStop", {"reason": "button pressed"}) + ) + robot.stop_task.assert_called_once() + + def test_sim_reacts_to_emergency_stop(self): + sim = _make_mock_sim() + sim._world.robots["so100"].policy_running = True + driver = SimulationDeviceDriver(sim) + asyncio.run( + driver.onEmergencyStop("barista-001", "emergencyStop", {"reason": "agent-initiated"}) + ) + self.assertFalse(sim._world.robots["so100"].policy_running) + + +# ── TestRobotMeshTool (Device Connect backend) ─────────────────── + +class TestRobotMeshToolDeviceConnect(unittest.TestCase): + + def setUp(self): + # Mock device_connect_agent_tools.connection + self.mock_conn = MagicMock() + self.mock_conn.list_devices.return_value = [ + { + "device_id": "so100-lab-1", + "device_type": "strands_robot", + "status": {"availability": "idle"}, + "functions": [{"name": "execute"}, {"name": "stop"}], + "events": [], + }, + { + "device_id": "reachy-mini-1", + "device_type": "reachy_mini", + "status": {"availability": "idle"}, + "functions": [{"name": "look"}, {"name": "nod"}], + "events": [], + }, + ] + self.mock_conn.invoke.return_value = { + "jsonrpc": "2.0", + "id": "test", + "result": {"status": "accepted"}, + } + + # Mock the device_connect_agent_tools modules before importing + mock_aft = MagicMock() + mock_aft_conn = MagicMock() + mock_aft_conn.get_connection.return_value = self.mock_conn + self._saved_modules = {} + for mod in ["device_connect_agent_tools", "device_connect_agent_tools.connection", + "device_connect_agent_tools.tools", "device_connect_agent_tools.agent", + "device_connect_agent_tools.adapters", "device_connect_agent_tools.adapters.strands"]: + self._saved_modules[mod] = sys.modules.get(mod) + sys.modules[mod] = mock_aft if mod == "device_connect_agent_tools" else mock_aft_conn + + # Force reimport of robot_mesh to pick up the mocked modules + if "strands_robots.tools.robot_mesh" in sys.modules: + del sys.modules["strands_robots.tools.robot_mesh"] + + def tearDown(self): + for mod, saved in self._saved_modules.items(): + if saved is None: + sys.modules.pop(mod, None) + else: + sys.modules[mod] = saved + if "strands_robots.tools.robot_mesh" in sys.modules: + del sys.modules["strands_robots.tools.robot_mesh"] + + def test_peers_action(self): + from strands_robots.tools.robot_mesh import _device_connect_dispatch + result = _device_connect_dispatch("peers", "", "", "", "mock", 0, 30.0, 30.0) + self.assertEqual(result["status"], "success") + text = result["content"][0]["text"] + self.assertIn("so100-lab-1", text) + self.assertIn("reachy-mini-1", text) + + def test_tell_action(self): + from strands_robots.tools.robot_mesh import _device_connect_dispatch + result = _device_connect_dispatch( + "tell", "so100-lab-1", "pick up cube", "", "groot", 0, 30.0, 30.0, + ) + self.assertEqual(result["status"], "success") + self.mock_conn.invoke.assert_called_once() + call_args = self.mock_conn.invoke.call_args + self.assertEqual(call_args[0][0], "so100-lab-1") + self.assertEqual(call_args[0][1], "execute") + + def test_stop_action(self): + from strands_robots.tools.robot_mesh import _device_connect_dispatch + result = _device_connect_dispatch("stop", "so100-lab-1", "", "", "mock", 0, 30.0, 30.0) + self.assertEqual(result["status"], "success") + self.mock_conn.invoke.assert_called_once_with("so100-lab-1", "stop", timeout=5.0) + + def test_emergency_stop(self): + from strands_robots.tools.robot_mesh import _device_connect_dispatch + result = _device_connect_dispatch("emergency_stop", "", "", "", "mock", 0, 30.0, 30.0) + self.assertEqual(result["status"], "success") + self.assertIn("2", result["content"][0]["text"]) # 2 devices stopped + + def test_missing_target(self): + from strands_robots.tools.robot_mesh import _device_connect_dispatch + result = _device_connect_dispatch("tell", "", "do something", "", "mock", 0, 30.0, 30.0) + self.assertEqual(result["status"], "error") + self.assertIn("target", result["content"][0]["text"]) + + def test_missing_instruction(self): + from strands_robots.tools.robot_mesh import _device_connect_dispatch + result = _device_connect_dispatch("tell", "so100-lab-1", "", "", "mock", 0, 30.0, 30.0) + self.assertEqual(result["status"], "error") + self.assertIn("instruction", result["content"][0]["text"]) + + def test_status_action(self): + from strands_robots.tools.robot_mesh import _device_connect_dispatch + result = _device_connect_dispatch("status", "", "", "", "mock", 0, 30.0, 30.0) + self.assertEqual(result["status"], "success") + self.assertIn("2 device(s)", result["content"][0]["text"]) + + +# ── TestReachyTransport ─────────────────────────────────────────── + +class TestReachyTransport(unittest.TestCase): + """Test the extracted transport helpers.""" + + def test_rpy_to_pose_identity(self): + from strands_robots.device_connect.reachy_transport import rpy_to_pose + pose = rpy_to_pose(0, 0, 0) + # Should be close to identity rotation + self.assertAlmostEqual(pose[0][0], 1.0, places=5) + self.assertAlmostEqual(pose[1][1], 1.0, places=5) + self.assertAlmostEqual(pose[2][2], 1.0, places=5) + self.assertAlmostEqual(pose[3][3], 1.0, places=5) + + def test_rpy_to_pose_translation(self): + from strands_robots.device_connect.reachy_transport import rpy_to_pose + pose = rpy_to_pose(0, 0, 0, x_mm=100, y_mm=200, z_mm=300) + self.assertAlmostEqual(pose[0][3], 0.1, places=5) # 100mm = 0.1m + self.assertAlmostEqual(pose[1][3], 0.2, places=5) + self.assertAlmostEqual(pose[2][3], 0.3, places=5) + + def test_identity_pose(self): + from strands_robots.device_connect.reachy_transport import identity_pose + pose = identity_pose() + self.assertEqual(pose, [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]) + + def test_resolve_host_ip(self): + from strands_robots.device_connect.reachy_transport import resolve_host + # IP should pass through unchanged + result = resolve_host("192.168.1.1") + self.assertEqual(result, "192.168.1.1") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_device_connect_integration.py b/tests/test_device_connect_integration.py new file mode 100644 index 0000000..2d821ff --- /dev/null +++ b/tests/test_device_connect_integration.py @@ -0,0 +1,368 @@ +"""Integration tests for Device Connect DeviceDriver adapters. + +Requires Docker infrastructure running: + - Zenoh router (:7447) + - etcd (:2379) + - device-registry (:8000) + +Start with: + cd device-connect/packages/device-connect-server + docker compose -f infra/docker-compose-dev.yml up -d + +Run with: + MESSAGING_BACKEND=zenoh ZENOH_CONNECT=tcp/localhost:7447 \ + DEVICE_CONNECT_ALLOW_INSECURE=true python3 -m pytest tests/test_device_connect_integration.py -v +""" + +import asyncio +import os +from unittest.mock import MagicMock + +import pytest + +pytestmark = [ + pytest.mark.integration, + pytest.mark.skipif( + not os.getenv("DEVICE_CONNECT_ALLOW_INSECURE"), + reason="Requires Docker infrastructure (set DEVICE_CONNECT_ALLOW_INSECURE=true)", + ), +] + + +def _make_mock_robot(tool_name="itest-robot"): + """Create a mock Robot for integration testing.""" + from dataclasses import dataclass + from enum import Enum + + class TaskStatus(Enum): + IDLE = "idle" + RUNNING = "running" + + @dataclass + class TaskState: + status: TaskStatus = TaskStatus.IDLE + instruction: str = "" + step_count: int = 0 + + robot = MagicMock() + robot.tool_name_str = tool_name + robot._task_state = TaskState() + robot.start_task.return_value = {"status": "success", "content": [{"text": "Task started"}]} + robot.stop_task.return_value = {"status": "success", "content": [{"text": "Task stopped"}]} + robot.get_task_status.return_value = {"status": "success", "content": [{"text": "Idle"}]} + robot.get_features.return_value = {"status": "success", "content": [{"json": {}}]} + robot.robot = MagicMock() + robot.robot.get_observation.return_value = {"joint1": 0.5} + return robot + + +def _make_mock_sim(tool_name="itest-sim"): + """Create a mock Simulation for integration testing.""" + sim = MagicMock() + sim.tool_name_str = tool_name + + robot_data = MagicMock() + robot_data.policy_running = False + robot_data.policy_steps = 0 + robot_data.policy_instruction = "" + + world = MagicMock() + world.robots = {"arm1": robot_data} + world.sim_time = 0.0 + world.step_count = 0 + sim._world = world + + sim.start_policy.return_value = {"status": "success", "content": [{"text": "Started"}]} + sim.get_state.return_value = {"status": "success", "content": [{"text": "State"}]} + sim.get_features.return_value = {"status": "success", "content": [{"json": {}}]} + sim.step.return_value = {"status": "success", "content": [{"text": "Stepped"}]} + sim.reset.return_value = {"status": "success", "content": [{"text": "Reset"}]} + return sim + + +@pytest.fixture(autouse=True) +def device_connect_env(): + """Set environment for Device Connect messaging. + + Supports both Zenoh and NATS backends. The backend is chosen by the + MESSAGING_BACKEND env-var (default ``nats`` to match the standard + docker-compose-itest.yml setup). + """ + backend = os.getenv("MESSAGING_BACKEND", "nats") + os.environ.setdefault("MESSAGING_BACKEND", backend) + + if backend == "zenoh": + url = os.getenv("ZENOH_CONNECT", "tcp/localhost:7447") + os.environ.setdefault("ZENOH_CONNECT", url) + os.environ.setdefault("MESSAGING_URLS", url) + else: + url = os.getenv("NATS_URL", "nats://localhost:4222") + os.environ.setdefault("NATS_URL", url) + os.environ.setdefault("MESSAGING_URLS", url) + + os.environ.setdefault("DEVICE_CONNECT_ALLOW_INSECURE", "true") + yield + + +class TestRobotDriverRegistration: + """Test that RobotDeviceDriver registers and is discoverable.""" + + async def test_robot_driver_registers(self): + """Create RobotDeviceDriver + DeviceRuntime, verify device is discoverable.""" + from device_connect_edge import DeviceRuntime + from strands_robots.device_connect.robot_driver import RobotDeviceDriver + + robot = _make_mock_robot() + driver = RobotDeviceDriver(robot) + runtime = DeviceRuntime( + driver=driver, + device_id="itest-robot-001", + allow_insecure=True, + ) + + task = asyncio.create_task(runtime.run()) + try: + # Wait for registration + await asyncio.sleep(3) + + from device_connect_agent_tools.connection import connect, disconnect, get_connection + await asyncio.to_thread(connect) + try: + conn = get_connection() + devices = await asyncio.to_thread(conn.list_devices, device_type="strands_robot") + device_ids = [d["device_id"] for d in devices] + assert "itest-robot-001" in device_ids, f"Expected itest-robot-001 in {device_ids}" + finally: + await asyncio.to_thread(disconnect) + finally: + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + async def test_robot_execute_rpc(self): + """Discover robot and invoke execute RPC.""" + from device_connect_edge import DeviceRuntime + from strands_robots.device_connect.robot_driver import RobotDeviceDriver + + robot = _make_mock_robot() + driver = RobotDeviceDriver(robot) + runtime = DeviceRuntime( + driver=driver, + device_id="itest-robot-exec", + allow_insecure=True, + ) + + task = asyncio.create_task(runtime.run()) + try: + await asyncio.sleep(3) + + from device_connect_agent_tools.connection import connect, disconnect, get_connection + await asyncio.to_thread(connect) + try: + conn = get_connection() + result = await asyncio.to_thread( + conn.invoke, + "itest-robot-exec", "execute", + {"instruction": "test move", "policy_provider": "mock", "duration": 5.0}, + ) + assert "result" in result, f"Expected result in {result}" + robot.start_task.assert_called_once() + finally: + await asyncio.to_thread(disconnect) + finally: + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + async def test_robot_stop_rpc(self): + """Invoke stop RPC on a registered robot.""" + from device_connect_edge import DeviceRuntime + from strands_robots.device_connect.robot_driver import RobotDeviceDriver + + robot = _make_mock_robot() + driver = RobotDeviceDriver(robot) + runtime = DeviceRuntime( + driver=driver, + device_id="itest-robot-stop", + allow_insecure=True, + ) + + task = asyncio.create_task(runtime.run()) + try: + await asyncio.sleep(3) + + from device_connect_agent_tools.connection import connect, disconnect, get_connection + await asyncio.to_thread(connect) + try: + conn = get_connection() + result = await asyncio.to_thread(conn.invoke, "itest-robot-stop", "stop") + assert "result" in result + robot.stop_task.assert_called_once() + finally: + await asyncio.to_thread(disconnect) + finally: + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + +class TestSimDriverRegistration: + """Test that SimulationDeviceDriver registers and is discoverable.""" + + async def test_sim_driver_registers(self): + """Create SimulationDeviceDriver + DeviceRuntime, verify device is discoverable.""" + from device_connect_edge import DeviceRuntime + from strands_robots.device_connect.sim_driver import SimulationDeviceDriver + + sim = _make_mock_sim() + driver = SimulationDeviceDriver(sim) + runtime = DeviceRuntime( + driver=driver, + device_id="itest-sim-001", + allow_insecure=True, + ) + + task = asyncio.create_task(runtime.run()) + try: + await asyncio.sleep(3) + + from device_connect_agent_tools.connection import connect, disconnect, get_connection + await asyncio.to_thread(connect) + try: + conn = get_connection() + devices = await asyncio.to_thread(conn.list_devices, device_type="strands_sim") + device_ids = [d["device_id"] for d in devices] + assert "itest-sim-001" in device_ids + finally: + await asyncio.to_thread(disconnect) + finally: + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + async def test_sim_step_rpc(self): + """Invoke step RPC on a registered simulation.""" + from device_connect_edge import DeviceRuntime + from strands_robots.device_connect.sim_driver import SimulationDeviceDriver + + sim = _make_mock_sim() + driver = SimulationDeviceDriver(sim) + runtime = DeviceRuntime( + driver=driver, + device_id="itest-sim-step", + allow_insecure=True, + ) + + task = asyncio.create_task(runtime.run()) + try: + await asyncio.sleep(3) + + from device_connect_agent_tools.connection import connect, disconnect, get_connection + await asyncio.to_thread(connect) + try: + conn = get_connection() + result = await asyncio.to_thread(conn.invoke, "itest-sim-step", "step", {"n_steps": 10}) + assert "result" in result + sim.step.assert_called_once_with(10) + finally: + await asyncio.to_thread(disconnect) + finally: + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + +class TestMultipleDevices: + """Test multiple devices registered simultaneously.""" + + async def test_multiple_devices_discoverable(self): + """Register 3 devices and verify all are discoverable.""" + from device_connect_edge import DeviceRuntime + from strands_robots.device_connect.robot_driver import RobotDeviceDriver + from strands_robots.device_connect.sim_driver import SimulationDeviceDriver + + robot1 = _make_mock_robot("robot-a") + robot2 = _make_mock_robot("robot-b") + sim1 = _make_mock_sim("sim-c") + + runtimes = [] + tasks = [] + for device_id, driver_cls, instance in [ + ("itest-multi-a", RobotDeviceDriver, robot1), + ("itest-multi-b", RobotDeviceDriver, robot2), + ("itest-multi-c", SimulationDeviceDriver, sim1), + ]: + driver = driver_cls(instance) + runtime = DeviceRuntime( + driver=driver, + device_id=device_id, + allow_insecure=True, + ) + runtimes.append(runtime) + tasks.append(asyncio.create_task(runtime.run())) + + try: + await asyncio.sleep(5) + + from device_connect_agent_tools.connection import connect, disconnect, get_connection + await asyncio.to_thread(connect) + try: + conn = get_connection() + devices = await asyncio.to_thread(conn.list_devices) + device_ids = {d["device_id"] for d in devices} + assert "itest-multi-a" in device_ids + assert "itest-multi-b" in device_ids + assert "itest-multi-c" in device_ids + finally: + await asyncio.to_thread(disconnect) + finally: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + + +class TestInitDeviceConnectE2E: + """End-to-end test of init_device_connect().""" + + async def test_init_device_connect_e2e(self): + """init_device_connect() -> device registers -> discoverable -> invocable.""" + from strands_robots.device_connect import init_device_connect + + robot = _make_mock_robot("e2e-robot") + runtime = await init_device_connect(robot, peer_id="itest-e2e-robot") + + try: + # Wait for registration + await asyncio.sleep(3) + + from device_connect_agent_tools.connection import connect, disconnect, get_connection + await asyncio.to_thread(connect) + try: + conn = get_connection() + + # Discoverable + devices = await asyncio.to_thread(conn.list_devices, device_type="strands_robot") + device_ids = [d["device_id"] for d in devices] + assert "itest-e2e-robot" in device_ids + + # Invocable + result = await asyncio.to_thread(conn.invoke, "itest-e2e-robot", "getStatus") + assert "result" in result + finally: + await asyncio.to_thread(disconnect) + finally: + await runtime.stop() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_robot_mesh_tool.py b/tests/test_robot_mesh_tool.py new file mode 100644 index 0000000..6377d5d --- /dev/null +++ b/tests/test_robot_mesh_tool.py @@ -0,0 +1,251 @@ +"""Unit tests for the robot_mesh tool — Device Connect dispatch path + fallback. + +All external dependencies are mocked. No Docker or real connections needed. +""" + +import json +import sys +import unittest +from unittest.mock import MagicMock, patch + + +# ── Mock heavy dependencies before importing ────────────────────── + +# Mock strands @tool decorator as passthrough +mock_strands = MagicMock() +mock_strands.tool = lambda fn: fn +sys.modules.setdefault("strands", mock_strands) + +# Mock device_connect_agent_tools +mock_dc_connection = MagicMock() +sys.modules.setdefault("device_connect_agent_tools", MagicMock()) +sys.modules.setdefault("device_connect_agent_tools.connection", mock_dc_connection) + + +class _FakeConnection: + """Fake _DeviceConnectConnection with all methods the dispatch uses.""" + + def __init__(self, devices=None): + self.zone = "default" + self._devices = devices or [] + self._invoke_results = {} + self._inbox = {} + self._sync_subs = {} + + def list_devices(self, device_type=None): + if device_type: + return [d for d in self._devices if d.get("device_type") == device_type] + return list(self._devices) + + def invoke(self, device_id, function, params=None, timeout=30.0): + key = (device_id, function) + if key in self._invoke_results: + return self._invoke_results[key] + return {"result": {"status": "ok"}} + + def broadcast(self, function, params=None, timeout=5.0): + results = [] + for d in self._devices: + try: + r = self.invoke(d["device_id"], function, params, timeout=timeout) + results.append({"device_id": d["device_id"], "result": r}) + except Exception as e: + results.append({"device_id": d["device_id"], "error": str(e)}) + return results + + def subscribe_buffered(self, subject, name=None): + name = name or subject + self._inbox[name] = [] + self._sync_subs[name] = True + return name + + def get_inbox(self, name=None): + if name is not None: + return {name: list(self._inbox.get(name, []))} + return {k: list(v) for k, v in self._inbox.items()} + + +SAMPLE_DEVICES = [ + { + "device_id": "so100-lab-1", + "device_type": "strands_robot", + "status": {"availability": "idle"}, + "functions": [{"name": "execute"}, {"name": "stop"}, {"name": "getStatus"}], + "events": ["taskStarted", "taskComplete"], + }, + { + "device_id": "panda-sim-1", + "device_type": "strands_sim", + "status": {"availability": "idle"}, + "functions": [{"name": "execute"}, {"name": "step"}, {"name": "reset"}], + "events": ["stateUpdate"], + }, +] + + +class TestDeviceConnectDispatch(unittest.TestCase): + """Test _device_connect_dispatch handles all 10 actions.""" + + def setUp(self): + self.conn = _FakeConnection(devices=SAMPLE_DEVICES) + # Patch get_connection to return our fake + self.patcher = patch( + "device_connect_agent_tools.connection.get_connection", + return_value=self.conn, + ) + self.patcher.start() + + # Import after mocking + from strands_robots.tools.robot_mesh import _device_connect_dispatch + self.dispatch = _device_connect_dispatch + + def tearDown(self): + self.patcher.stop() + + def _call(self, action, **kwargs): + defaults = dict( + target="", instruction="", command="", + policy_provider="mock", policy_port=0, + duration=30.0, timeout=5.0, + ) + defaults.update(kwargs) + return self.dispatch(action, **{k: defaults[k] for k in [ + "target", "instruction", "command", + "policy_provider", "policy_port", "duration", "timeout", + ]}) + + def test_peers(self): + result = self._call("peers") + self.assertEqual(result["status"], "success") + text = result["content"][0]["text"] + self.assertIn("so100-lab-1", text) + self.assertIn("panda-sim-1", text) + self.assertIn("2 device(s)", text) + + def test_tell(self): + result = self._call("tell", target="so100-lab-1", instruction="pick up cube") + self.assertEqual(result["status"], "success") + text = result["content"][0]["text"] + self.assertIn("so100-lab-1", text) + self.assertIn("pick up cube", text) + + def test_tell_missing_args(self): + result = self._call("tell", target="", instruction="") + self.assertEqual(result["status"], "error") + + def test_send(self): + result = self._call("send", target="so100-lab-1") + self.assertEqual(result["status"], "success") + + def test_send_with_command(self): + cmd = json.dumps({"action": "getFeatures"}) + result = self._call("send", target="so100-lab-1", command=cmd) + self.assertEqual(result["status"], "success") + + def test_stop(self): + result = self._call("stop", target="so100-lab-1") + self.assertEqual(result["status"], "success") + self.assertIn("Stop", result["content"][0]["text"]) + + def test_stop_missing_target(self): + result = self._call("stop", target="") + self.assertEqual(result["status"], "error") + + def test_emergency_stop(self): + result = self._call("emergency_stop") + self.assertEqual(result["status"], "success") + self.assertIn("E-STOP", result["content"][0]["text"]) + self.assertIn("2/2", result["content"][0]["text"]) + + def test_broadcast(self): + result = self._call("broadcast") + self.assertEqual(result["status"], "success") + text = result["content"][0]["text"] + self.assertIn("2 response(s)", text) + self.assertIn("so100-lab-1", text) + self.assertIn("panda-sim-1", text) + + def test_broadcast_with_command(self): + cmd = json.dumps({"function": "getStatus"}) + result = self._call("broadcast", command=cmd) + self.assertEqual(result["status"], "success") + + def test_subscribe(self): + result = self._call("subscribe", target="device-connect.default.*.event.>") + self.assertEqual(result["status"], "success") + text = result["content"][0]["text"] + self.assertIn("Subscribed", text) + self.assertIn("inbox", text) + # Verify subscription was created + self.assertIn("device-connect.default.*.event.>", self.conn._sync_subs) + + def test_subscribe_missing_target(self): + result = self._call("subscribe", target="") + self.assertEqual(result["status"], "error") + + def test_watch(self): + result = self._call("watch", target="so100-lab-1") + self.assertEqual(result["status"], "success") + text = result["content"][0]["text"] + self.assertIn("Watching", text) + self.assertIn("so100-lab-1", text) + # Verify subscription uses correct subject pattern + self.assertIn("stream:so100-lab-1", self.conn._sync_subs) + + def test_watch_missing_target(self): + result = self._call("watch", target="") + self.assertEqual(result["status"], "error") + + def test_inbox_empty(self): + result = self._call("inbox") + self.assertEqual(result["status"], "success") + self.assertIn("No subscriptions", result["content"][0]["text"]) + + def test_inbox_with_messages(self): + # Create a subscription and add messages + self.conn.subscribe_buffered("test-subject", name="test") + self.conn._inbox["test"] = [ + ("device-connect.default.so100-lab-1.event.taskStarted", {"event_name": "taskStarted", "device_id": "so100-lab-1"}), + ("device-connect.default.so100-lab-1.event.taskComplete", {"event_name": "taskComplete", "device_id": "so100-lab-1"}), + ] + result = self._call("inbox") + self.assertEqual(result["status"], "success") + text = result["content"][0]["text"] + self.assertIn("test", text) + self.assertIn("2 messages", text) + + def test_status(self): + result = self._call("status") + self.assertEqual(result["status"], "success") + text = result["content"][0]["text"] + self.assertIn("2 device(s)", text) + + def test_unknown_action(self): + result = self._call("nonexistent") + self.assertEqual(result["status"], "error") + self.assertIn("Unknown action", result["content"][0]["text"]) + # Verify all valid actions are listed in the error message + for a in ["peers", "tell", "send", "broadcast", "stop", + "emergency_stop", "status", "subscribe", "watch", "inbox"]: + self.assertIn(a, result["content"][0]["text"]) + + +class TestFallbackToZenoh(unittest.TestCase): + """Test that robot_mesh falls back to Zenoh when Device Connect fails.""" + + def test_fallback_on_dispatch_error(self): + """When _device_connect_dispatch raises, _mesh_dispatch should be called.""" + with patch("strands_robots.tools.robot_mesh._ensure_connected", side_effect=Exception("no DC")), \ + patch("strands_robots.tools.robot_mesh._mesh_dispatch") as mock_mesh: + mock_mesh.return_value = {"status": "success", "content": [{"text": "zenoh fallback"}]} + + from strands_robots.tools.robot_mesh import robot_mesh + result = robot_mesh(action="peers") + + mock_mesh.assert_called_once() + self.assertEqual(result["status"], "success") + self.assertIn("zenoh fallback", result["content"][0]["text"]) + + +if __name__ == "__main__": + unittest.main()