diff --git a/pyproject.toml b/pyproject.toml index 40382ec..127889b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,9 +48,13 @@ groot-service = [ lerobot = [ "lerobot>=0.5.0,<0.6.0", ] +mesh = [ + "eclipse-zenoh>=0.11.0,<1.0.0", +] all = [ "strands-robots[groot-service]", "strands-robots[lerobot]", + "strands-robots[mesh]", ] dev = [ "pytest>=6.0,<9.0.0", @@ -124,7 +128,7 @@ ignore_missing_imports = false # Third-party libs without type stubs [[tool.mypy.overrides]] -module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*"] +module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*", "zenoh.*"] ignore_missing_imports = true # @tool decorator injects runtime signatures mypy cannot check @@ -151,6 +155,11 @@ warn_return_any = false module = ["strands_robots.__init__"] disallow_untyped_defs = false +# Mesh session — zenoh typed as Any to avoid import-time dependency +[[tool.mypy.overrides]] +module = ["strands_robots.mesh_session"] +warn_return_any = false + # Registry modules — dynamic JSON loading returns Any [[tool.mypy.overrides]] module = ["strands_robots.registry.*"] diff --git a/strands_robots/mesh_session.py b/strands_robots/mesh_session.py new file mode 100644 index 0000000..b4c8032 --- /dev/null +++ b/strands_robots/mesh_session.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 +""" +Zenoh session singleton — one session per process, ref-counted. + +Every Mesh instance shares the same ``zenoh.Session`` to avoid +duplicating discovery traffic and file descriptors. The session is +opened lazily on the first ``MeshSession.open()`` call and closed +when the last consumer calls ``MeshSession.close()`` (or the process +exits via the ``atexit`` hook). 
import atexit
import json
import logging
import os
import threading
from typing import Any

logger = logging.getLogger(__name__)


class MeshSession:
    """Process-level singleton over ``zenoh.Session``.

    Thread-safe, ref-counted, fork-aware. All state lives on the class and
    is guarded by ``_lock``; the class is never instantiated.

    Usage::

        session = MeshSession.open()   # refcount +1
        # ... use session ...
        MeshSession.close()            # refcount -1; actual close at 0

    Environment variables read when a new session is created:

    * ``ZENOH_CONNECT`` — comma-separated endpoints to connect to.
    * ``ZENOH_LISTEN`` — comma-separated endpoints to listen on.
    * ``STRANDS_MESH`` — ``false`` / ``0`` / ``no`` / ``off`` disables the
      mesh; ``open()`` then returns ``None`` without importing zenoh.
    """

    _lock = threading.Lock()
    _session: Any = None  # zenoh.Session (typed as Any to avoid import)
    _refcount: int = 0
    _pid: int | None = None  # PID of the process that opened ``_session``
    _atexit_registered: bool = False
    # STRANDS_MESH values (case-insensitive) that switch the mesh off.
    _DISABLED_VALUES = frozenset({"0", "false", "no", "off"})

    @classmethod
    def open(cls, config_overrides: dict[str, Any] | None = None) -> Any:
        """Acquire the shared Zenoh session.

        Creates the session on first call. Subsequent calls increment
        the reference count and return the same session.

        Args:
            config_overrides: Optional dict of Zenoh config JSON5 paths
                to values. Example::

                    {"connect/endpoints": ["tcp/127.0.0.1:7447"]}

                Overrides are applied *after* environment-variable config
                and only take effect on the call that actually creates the
                session; they are ignored (with a debug log) when a session
                already exists.

        Returns:
            A ``zenoh.Session`` instance, or ``None`` if eclipse-zenoh
            is not installed or the mesh is disabled via ``STRANDS_MESH``.

        Raises:
            RuntimeError: If the Zenoh session cannot be opened after
                applying configuration.
        """
        with cls._lock:
            # Fork detection: if PID changed, the session is stale.
            current_pid = os.getpid()
            if cls._session is not None and cls._pid != current_pid:
                logger.warning(
                    "PID changed (%s → %s) — re-initialising Zenoh session (probable fork). Old session abandoned.",
                    cls._pid,
                    current_pid,
                )
                # Don't close the parent's session — just discard our ref.
                cls._session = None
                cls._refcount = 0

            if cls._session is not None:
                if config_overrides:
                    logger.debug(
                        "Zenoh session already open — ignoring config_overrides %r",
                        config_overrides,
                    )
                cls._refcount += 1
                return cls._session

            mesh_flag = os.getenv("STRANDS_MESH", "")
            if mesh_flag.strip().lower() in cls._DISABLED_VALUES:
                logger.debug("STRANDS_MESH=%s — mesh disabled, no Zenoh session opened", mesh_flag)
                return None

            # Lazy import — avoid pulling zenoh at strands_robots import time.
            try:
                import zenoh
            except ImportError:
                logger.debug(
                    "eclipse-zenoh not installed — mesh disabled. Install with: pip install strands-robots[mesh]"
                )
                return None

            cfg = cls._build_config(zenoh, config_overrides)

            try:
                session = zenoh.open(cfg)
            except Exception as exc:
                raise RuntimeError(
                    f"Failed to open Zenoh session: {exc}. Check ZENOH_CONNECT / ZENOH_LISTEN env vars."
                ) from exc

            # Publish state only after the open succeeded so a failure
            # leaves the singleton untouched.
            cls._session = session
            cls._refcount = 1
            cls._pid = current_pid

            if not cls._atexit_registered:
                atexit.register(cls._atexit_cleanup)
                cls._atexit_registered = True

            logger.info("Zenoh mesh session opened (pid=%s)", current_pid)
            return cls._session

    @classmethod
    def close(cls) -> None:
        """Release one reference to the shared session.

        When the reference count reaches zero the underlying
        ``zenoh.Session`` is closed. A session inherited from another
        process (PID mismatch) is discarded without being closed, matching
        the fork handling in ``open()``.
        """
        with cls._lock:
            if cls._refcount <= 0:
                return

            if cls._session is not None and cls._pid != os.getpid():
                # The references counted here belong to the parent process.
                cls._drop_session_locked()
                return

            cls._refcount -= 1
            if cls._refcount == 0 and cls._session is not None:
                cls._drop_session_locked()
                logger.info("Zenoh mesh session closed (refcount → 0)")

    @classmethod
    def _atexit_cleanup(cls) -> None:
        """Best-effort cleanup at interpreter shutdown."""
        with cls._lock:
            cls._drop_session_locked()

    # --- Introspection helpers (testing / debugging) ---

    @classmethod
    def is_open(cls) -> bool:
        """Return ``True`` if a session is currently open."""
        with cls._lock:
            return cls._session is not None

    @classmethod
    def refcount(cls) -> int:
        """Return the current reference count."""
        with cls._lock:
            return cls._refcount

    @classmethod
    def _reset(cls) -> None:
        """Force-reset internal state. **Testing only.**"""
        with cls._lock:
            cls._drop_session_locked()

    # --- Private helpers ---

    @staticmethod
    def _split_endpoints(raw: str) -> list[str]:
        """Split a comma-separated endpoint list, dropping blank entries."""
        return [part.strip() for part in raw.split(",") if part.strip()]

    @classmethod
    def _build_config(cls, zenoh_module: Any, config_overrides: dict[str, Any] | None) -> Any:
        """Build a ``zenoh.Config`` from env vars, then programmatic overrides.

        A setting that cannot be applied is logged as a warning and skipped
        so one bad value does not prevent the session from opening.
        """
        cfg = zenoh_module.Config()

        # --- Environment-variable overrides ---
        for env_var, path in (
            ("ZENOH_CONNECT", "connect/endpoints"),
            ("ZENOH_LISTEN", "listen/endpoints"),
        ):
            raw = os.getenv(env_var)
            if not raw:
                continue
            endpoints = cls._split_endpoints(raw)
            if not endpoints:
                # e.g. "," or " " — nothing usable to hand to Zenoh.
                logger.warning("Ignoring %s=%r: no endpoints found", env_var, raw)
                continue
            try:
                cfg.insert_json5(path, json.dumps(endpoints))
            except Exception as exc:  # noqa: BLE001
                logger.warning("Failed to apply %s=%s: %s", env_var, raw, exc)

        # --- Programmatic overrides ---
        for path, value in (config_overrides or {}).items():
            try:
                cfg.insert_json5(path, json.dumps(value))
            except Exception as exc:  # noqa: BLE001
                logger.warning(
                    "Failed to apply config override %s=%r: %s",
                    path,
                    value,
                    exc,
                )

        return cfg

    @classmethod
    def _drop_session_locked(cls) -> None:
        """Forget the session, closing it only if this process opened it.

        The caller must hold ``_lock``.
        """
        session, owner_pid = cls._session, cls._pid
        cls._session = None
        cls._refcount = 0
        cls._pid = None

        if session is None:
            return
        if owner_pid != os.getpid():
            # Inherited across a fork: the opening process still owns it.
            logger.debug(
                "Discarding Zenoh session inherited from pid %s without closing it",
                owner_pid,
            )
            return
        try:
            session.close()
        except Exception as exc:  # noqa: BLE001
            # Best-effort; session may already be dead.
            logger.debug("Ignoring error while closing Zenoh session: %s", exc)
# Tests for strands_robots.mesh_session: every test substitutes a mock
# ``zenoh`` module, so neither a network nor eclipse-zenoh is required.

import os
import threading
from unittest.mock import MagicMock, patch

import pytest

from strands_robots.mesh_session import MeshSession

# Environment variables that influence the mesh session; removed before each
# test so the developer's or CI's ambient configuration cannot leak in.
_MESH_ENV_VARS = ("ZENOH_CONNECT", "ZENOH_LISTEN", "STRANDS_MESH")


@pytest.fixture(autouse=True)
def _clean_session(monkeypatch):
    """Give every test a pristine environment and a reset MeshSession."""
    for name in _MESH_ENV_VARS:
        monkeypatch.delenv(name, raising=False)
    MeshSession._reset()
    yield
    MeshSession._reset()


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_mock_zenoh():
    """Build a mock ``zenoh`` module with open() and Config.

    Returns:
        ``(mock_module, mock_session, mock_config)`` where
        ``mock_module.open()`` returns ``mock_session`` and
        ``mock_module.Config()`` returns ``mock_config``.
    """
    mock_zenoh = MagicMock()
    mock_session = MagicMock()
    mock_zenoh.open.return_value = mock_session
    mock_config = MagicMock()
    mock_zenoh.Config.return_value = mock_config
    return mock_zenoh, mock_session, mock_config


# ---------------------------------------------------------------------------
# Basic lifecycle
# ---------------------------------------------------------------------------


class TestSessionLifecycle:
    """Verify open / close / refcount behaviour."""

    def test_open_returns_session(self):
        mock_zenoh, mock_session, _ = _make_mock_zenoh()
        with patch.dict("sys.modules", {"zenoh": mock_zenoh}):
            session = MeshSession.open()

        assert session is mock_session
        assert MeshSession.is_open()
        assert MeshSession.refcount() == 1

    def test_second_open_reuses_session(self):
        mock_zenoh, _, _ = _make_mock_zenoh()
        with patch.dict("sys.modules", {"zenoh": mock_zenoh}):
            s1 = MeshSession.open()
            s2 = MeshSession.open()

        assert s1 is s2
        assert MeshSession.refcount() == 2
        mock_zenoh.open.assert_called_once()  # only one real open

    def test_close_decrements_refcount(self):
        mock_zenoh, mock_session, _ = _make_mock_zenoh()
        with patch.dict("sys.modules", {"zenoh": mock_zenoh}):
            MeshSession.open()
            MeshSession.open()

        MeshSession.close()
        assert MeshSession.refcount() == 1
        assert MeshSession.is_open()
        mock_session.close.assert_not_called()

    def test_close_at_zero_closes_session(self):
        mock_zenoh, mock_session, _ = _make_mock_zenoh()
        with patch.dict("sys.modules", {"zenoh": mock_zenoh}):
            MeshSession.open()

        MeshSession.close()
        assert MeshSession.refcount() == 0
        assert not MeshSession.is_open()
        mock_session.close.assert_called_once()

    def test_close_when_already_zero_is_noop(self):
        MeshSession.close()  # should not raise
        assert MeshSession.refcount() == 0


# ---------------------------------------------------------------------------
# Import failure (zenoh not installed)
# ---------------------------------------------------------------------------


class TestZenohNotInstalled:
    """Verify graceful degradation when eclipse-zenoh is absent."""

    def test_returns_none_when_zenoh_missing(self):
        # A ``None`` entry in sys.modules makes ``import zenoh`` raise
        # ImportError; patch.dict restores the previous entry on exit.
        with patch.dict("sys.modules", {"zenoh": None}):
            session = MeshSession.open()

        assert session is None
        assert not MeshSession.is_open()
        assert MeshSession.refcount() == 0


# ---------------------------------------------------------------------------
# Environment variable configuration
# ---------------------------------------------------------------------------


class TestEnvConfig:
    """Verify ZENOH_CONNECT and ZENOH_LISTEN are applied."""

    def test_connect_env_applied(self, monkeypatch):
        monkeypatch.setenv("ZENOH_CONNECT", "tcp/10.0.0.1:7447")
        mock_zenoh, _, mock_config = _make_mock_zenoh()
        with patch.dict("sys.modules", {"zenoh": mock_zenoh}):
            MeshSession.open()

        mock_config.insert_json5.assert_any_call("connect/endpoints", '["tcp/10.0.0.1:7447"]')

    def test_listen_env_applied(self, monkeypatch):
        monkeypatch.setenv("ZENOH_LISTEN", "tcp/0.0.0.0:7448")
        mock_zenoh, _, mock_config = _make_mock_zenoh()
        with patch.dict("sys.modules", {"zenoh": mock_zenoh}):
            MeshSession.open()

        mock_config.insert_json5.assert_any_call("listen/endpoints", '["tcp/0.0.0.0:7448"]')

    def test_multiple_connect_endpoints(self, monkeypatch):
        monkeypatch.setenv("ZENOH_CONNECT", "tcp/a:1,tcp/b:2")
        mock_zenoh, _, mock_config = _make_mock_zenoh()
        with patch.dict("sys.modules", {"zenoh": mock_zenoh}):
            MeshSession.open()

        mock_config.insert_json5.assert_any_call("connect/endpoints", '["tcp/a:1", "tcp/b:2"]')


# ---------------------------------------------------------------------------
# Programmatic config overrides
# ---------------------------------------------------------------------------


class TestConfigOverrides:
    """Verify programmatic config_overrides are applied."""

    def test_overrides_applied(self):
        mock_zenoh, _, mock_config = _make_mock_zenoh()
        with patch.dict("sys.modules", {"zenoh": mock_zenoh}):
            MeshSession.open(config_overrides={"mode": "client"})

        mock_config.insert_json5.assert_any_call("mode", '"client"')

    def test_rejected_override_does_not_abort_open(self):
        mock_zenoh, mock_session, mock_config = _make_mock_zenoh()
        mock_config.insert_json5.side_effect = ValueError("bad key")
        with patch.dict("sys.modules", {"zenoh": mock_zenoh}):
            session = MeshSession.open(config_overrides={"bogus": 1})

        # A rejected override is logged, not raised.
        assert session is mock_session
        assert MeshSession.refcount() == 1


# ---------------------------------------------------------------------------
# Fork detection
# ---------------------------------------------------------------------------


class TestForkDetection:
    """Verify session re-init when PID changes (simulated fork)."""

    def test_pid_change_reinitialises(self):
        mock_zenoh, mock_session, _ = _make_mock_zenoh()
        with patch.dict("sys.modules", {"zenoh": mock_zenoh}):
            MeshSession.open()
            assert mock_zenoh.open.call_count == 1

            # Simulate fork — PID changes
            MeshSession._pid = -1

            MeshSession.open()

        # Should have opened a new session
        assert mock_zenoh.open.call_count == 2
        assert MeshSession.refcount() == 1  # reset, not incremented
        # The inherited session belongs to the parent and must stay open.
        mock_session.close.assert_not_called()


# ---------------------------------------------------------------------------
# Open failure (RuntimeError)
# ---------------------------------------------------------------------------


class TestOpenFailure:
    """Verify RuntimeError when zenoh.open() fails."""

    def test_raises_runtime_error_on_open_failure(self):
        mock_zenoh = MagicMock()
        mock_zenoh.open.side_effect = Exception("Connection refused")
        mock_zenoh.Config.return_value = MagicMock()

        with patch.dict("sys.modules", {"zenoh": mock_zenoh}):
            with pytest.raises(RuntimeError, match="Failed to open Zenoh session"):
                MeshSession.open()

        # A failed open must not leave a half-initialised singleton behind.
        assert not MeshSession.is_open()
        assert MeshSession.refcount() == 0


# ---------------------------------------------------------------------------
# Thread safety
# ---------------------------------------------------------------------------


class TestThreadSafety:
    """Verify concurrent open/close doesn't corrupt state."""

    def test_concurrent_open_close(self):
        mock_zenoh, _, _ = _make_mock_zenoh()
        errors = []

        def guarded(action):
            try:
                action()
            except Exception as exc:  # collected and asserted on below
                errors.append(exc)

        def run_concurrently(action, count=10):
            workers = [threading.Thread(target=guarded, args=(action,)) for _ in range(count)]
            for worker in workers:
                worker.start()
            for worker in workers:
                worker.join(timeout=5)
            # join() returns silently on timeout; check the threads really ended.
            assert not any(worker.is_alive() for worker in workers), "worker thread hung"

        # sys.modules is patched once, from the main thread only: patch.dict
        # clears and refills the dict on exit, which is unsafe to do from
        # several threads at the same time.
        with patch.dict("sys.modules", {"zenoh": mock_zenoh}):
            run_concurrently(MeshSession.open)
            assert not errors, f"Errors in concurrent open: {errors}"
            assert MeshSession.refcount() == 10
            mock_zenoh.open.assert_called_once()

            run_concurrently(MeshSession.close)
            assert not errors, f"Errors in concurrent close: {errors}"
            assert MeshSession.refcount() == 0
            assert not MeshSession.is_open()


# ---------------------------------------------------------------------------
# atexit cleanup
# ---------------------------------------------------------------------------


class TestAtexitCleanup:
    """Verify atexit hook cleans up properly."""

    def test_atexit_closes_session(self):
        mock_zenoh, mock_session, _ = _make_mock_zenoh()
        with patch.dict("sys.modules", {"zenoh": mock_zenoh}):
            MeshSession.open()

        MeshSession._atexit_cleanup()
        assert not MeshSession.is_open()
        assert MeshSession.refcount() == 0
        mock_session.close.assert_called_once()

    def test_atexit_noop_when_no_session(self):
        MeshSession._atexit_cleanup()  # should not raise
        assert MeshSession.refcount() == 0