diff --git a/pyproject.toml b/pyproject.toml index 40382ec..c7ffabe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,9 +48,13 @@ groot-service = [ lerobot = [ "lerobot>=0.5.0,<0.6.0", ] +mesh = [ + "eclipse-zenoh>=0.11.0,<0.12.0", +] all = [ "strands-robots[groot-service]", "strands-robots[lerobot]", + "strands-robots[mesh]", ] dev = [ "pytest>=6.0,<9.0.0", @@ -151,6 +155,11 @@ warn_return_any = false module = ["strands_robots.__init__"] disallow_untyped_defs = false +# Mesh session — zenoh types are Any to avoid import-time dependency +[[tool.mypy.overrides]] +module = ["strands_robots.mesh_session"] +warn_return_any = false + # Registry modules — dynamic JSON loading returns Any [[tool.mypy.overrides]] module = ["strands_robots.registry.*"] diff --git a/strands_robots/mesh_session.py b/strands_robots/mesh_session.py new file mode 100644 index 0000000..0f404ef --- /dev/null +++ b/strands_robots/mesh_session.py @@ -0,0 +1,211 @@ +"""Zenoh session singleton — ONE session per process, ref-counted. + +This module manages a shared ``zenoh.Session`` so that multiple ``Mesh`` +instances (one per Robot/Simulation) reuse a single network socket. + +Session lifecycle:: + + session = get_session() # ref +1, opens on first call + session2 = get_session() # ref +1, returns same session + release_session() # ref -1 + release_session() # ref → 0, session.close() + +Environment variables +--------------------- +ZENOH_CONNECT + Comma-separated Zenoh endpoints to connect to. + Example: ``tcp/10.0.0.5:7447,tcp/10.0.0.6:7447`` +ZENOH_LISTEN + Comma-separated Zenoh endpoints to listen on. + Example: ``tcp/0.0.0.0:7447`` +STRANDS_MESH_PORT + Local auto-mesh port (default 7447). The first process on a host + listens; subsequent processes connect as clients. +STRANDS_MESH + Set to ``false`` to disable mesh globally. + +Requires ``pip install strands-robots[mesh]`` (eclipse-zenoh). +""" + +from __future__ import annotations + +import json +import logging +import os +import threading +from dataclasses import dataclass +from typing import Any + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +DEFAULT_MESH_PORT = 7447 + + +@dataclass(frozen=True) +class MeshConfig: + """Immutable Zenoh connection configuration. + + Attributes: + connect: Zenoh endpoints to connect to (e.g. ``("tcp/10.0.0.5:7447",)``). + listen: Zenoh endpoints to listen on (e.g. ``("tcp/0.0.0.0:7447",)``). + port: Local auto-mesh port used when neither *connect* nor *listen* + is specified. Default ``7447``. + """ + + connect: tuple[str, ...] = () + listen: tuple[str, ...] = () + port: int = DEFAULT_MESH_PORT + + @classmethod + def from_env(cls) -> MeshConfig: + """Build a ``MeshConfig`` from environment variables. + + Reads ``ZENOH_CONNECT``, ``ZENOH_LISTEN``, and + ``STRANDS_MESH_PORT``. Missing variables produce empty tuples / + the default port. + """ + connect_raw = os.getenv("ZENOH_CONNECT", "") + listen_raw = os.getenv("ZENOH_LISTEN", "") + port = int(os.getenv("STRANDS_MESH_PORT", str(DEFAULT_MESH_PORT))) + + connect = tuple(e.strip() for e in connect_raw.split(",") if e.strip()) if connect_raw else () + listen = tuple(e.strip() for e in listen_raw.split(",") if e.strip()) if listen_raw else () + + return cls(connect=connect, listen=listen, port=port) + + +# --------------------------------------------------------------------------- +# Session singleton +# --------------------------------------------------------------------------- + +_SESSION: Any = None # Optional[zenoh.Session] — typed as Any to avoid import-time dep +_SESSION_LOCK = threading.Lock() +_SESSION_REFS: int = 0 + + +def _apply_config(zenoh_config: Any, config: MeshConfig) -> None: + """Apply *config* endpoints to a ``zenoh.Config`` object. + + Mutates *zenoh_config* in-place via ``insert_json5``. + """ + if config.connect: + zenoh_config.insert_json5("connect/endpoints", json.dumps(list(config.connect))) + if config.listen: + zenoh_config.insert_json5("listen/endpoints", json.dumps(list(config.listen))) + + +def get_session(config: MeshConfig | None = None) -> Any: + """Acquire the shared Zenoh session (lazy, ref-counted). + + On the first call the session is opened. Subsequent calls increment + the reference count and return the same session. + + When neither ``ZENOH_CONNECT`` nor ``ZENOH_LISTEN`` are set, *auto-mesh* + kicks in: the first process on the host listens on + ``tcp/127.0.0.1:{port}``; later processes connect as clients. + + Parameters + ---------- + config: + Optional explicit configuration. If ``None``, reads from + environment variables via :meth:`MeshConfig.from_env`. + + Returns + ------- + zenoh.Session | None + The shared session, or ``None`` if eclipse-zenoh is not installed + or the global kill-switch ``STRANDS_MESH=false`` is set. + """ + global _SESSION, _SESSION_REFS + + # Global kill switch + if os.getenv("STRANDS_MESH", "true").lower() == "false": + return None + + with _SESSION_LOCK: + if _SESSION is not None: + _SESSION_REFS += 1 + return _SESSION + + # Lazy import — zenoh is optional + try: + import importlib + + zenoh = importlib.import_module("zenoh") + except ImportError: + logger.debug("eclipse-zenoh not installed — mesh disabled (pip install strands-robots[mesh])") + return None + + if config is None: + config = MeshConfig.from_env() + + # If explicit endpoints are configured, use them directly + if config.connect or config.listen: + zenoh_config = zenoh.Config() + _apply_config(zenoh_config, config) + _SESSION = zenoh.open(zenoh_config) + _SESSION_REFS = 1 + logger.info("Zenoh mesh session opened (explicit config)") + return _SESSION + + # Auto-mesh: try listen+connect on localhost (first process wins) + mesh_ep = f"tcp/127.0.0.1:{config.port}" + + try: + cfg_listen = zenoh.Config() + cfg_listen.insert_json5("listen/endpoints", json.dumps([mesh_ep])) + cfg_listen.insert_json5("connect/endpoints", json.dumps([mesh_ep])) + _SESSION = zenoh.open(cfg_listen) + _SESSION_REFS = 1 + logger.info("Zenoh mesh session opened (listener on %s)", mesh_ep) + return _SESSION + except Exception: + # Port already taken — another process is listening + pass + + # Connect as client to the existing listener + cfg_client = zenoh.Config() + cfg_client.insert_json5("mode", '"client"') + cfg_client.insert_json5("connect/endpoints", json.dumps([mesh_ep])) + _SESSION = zenoh.open(cfg_client) + _SESSION_REFS = 1 + logger.info("Zenoh mesh session opened (client → %s)", mesh_ep) + return _SESSION + + +def release_session() -> None: + """Release one reference to the shared session. + + When the reference count reaches zero the session is closed. + """ + global _SESSION, _SESSION_REFS + + with _SESSION_LOCK: + if _SESSION_REFS <= 0: + return + _SESSION_REFS -= 1 + if _SESSION_REFS <= 0 and _SESSION is not None: + try: + _SESSION.close() + except Exception: + pass + _SESSION = None + _SESSION_REFS = 0 + logger.info("Zenoh mesh session closed") + + +def session_info() -> dict[str, Any]: + """Return diagnostic info about the current session state. + + Useful for dashboards and debugging. Does not acquire or release + the session. + """ + with _SESSION_LOCK: + return { + "active": _SESSION is not None, + "refs": _SESSION_REFS, + } diff --git a/tests/test_mesh_session.py b/tests/test_mesh_session.py new file mode 100644 index 0000000..7a834d9 --- /dev/null +++ b/tests/test_mesh_session.py @@ -0,0 +1,258 @@ +"""Tests for strands_robots.mesh_session — Zenoh session singleton.""" + +from __future__ import annotations + +import importlib +import threading +from unittest.mock import MagicMock, patch + +import pytest + +from strands_robots import mesh_session +from strands_robots.mesh_session import ( + DEFAULT_MESH_PORT, + MeshConfig, + get_session, + release_session, + session_info, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _reset_session_state() -> None: + """Reset module-level session state between tests.""" + mesh_session._SESSION = None + mesh_session._SESSION_REFS = 0 + + +@pytest.fixture(autouse=True) +def _clean_session(): + """Ensure each test starts and ends with a clean session.""" + _reset_session_state() + yield + _reset_session_state() + + +@pytest.fixture() +def mock_zenoh(): + """Provide a mock ``zenoh`` module with ``Config`` and ``open``.""" + mock_module = MagicMock() + mock_config = MagicMock() + mock_session = MagicMock() + + mock_module.Config.return_value = mock_config + mock_module.open.return_value = mock_session + + with patch.object(importlib, "import_module", return_value=mock_module): + yield {"module": mock_module, "config": mock_config, "session": mock_session} + + +# --------------------------------------------------------------------------- +# MeshConfig +# --------------------------------------------------------------------------- + + +class TestMeshConfig: + """Tests for MeshConfig dataclass.""" + + def test_defaults(self): + cfg = MeshConfig() + assert cfg.connect == () + assert cfg.listen == () + assert cfg.port == DEFAULT_MESH_PORT + + def test_frozen(self): + cfg = MeshConfig() + with pytest.raises(AttributeError): + cfg.port = 9999 # type: ignore[misc] + + def test_from_env_connect(self, monkeypatch): + monkeypatch.setenv("ZENOH_CONNECT", "tcp/10.0.0.1:7447,tcp/10.0.0.2:7447") + monkeypatch.delenv("ZENOH_LISTEN", raising=False) + monkeypatch.delenv("STRANDS_MESH_PORT", raising=False) + + cfg = MeshConfig.from_env() + assert cfg.connect == ("tcp/10.0.0.1:7447", "tcp/10.0.0.2:7447") + assert cfg.listen == () + assert cfg.port == DEFAULT_MESH_PORT + + def test_from_env_listen(self, monkeypatch): + monkeypatch.delenv("ZENOH_CONNECT", raising=False) + monkeypatch.setenv("ZENOH_LISTEN", "tcp/0.0.0.0:7447") + monkeypatch.delenv("STRANDS_MESH_PORT", raising=False) + + cfg = MeshConfig.from_env() + assert cfg.connect == () + assert cfg.listen == ("tcp/0.0.0.0:7447",) + + def test_from_env_port(self, monkeypatch): + monkeypatch.delenv("ZENOH_CONNECT", raising=False) + monkeypatch.delenv("ZENOH_LISTEN", raising=False) + monkeypatch.setenv("STRANDS_MESH_PORT", "8888") + + cfg = MeshConfig.from_env() + assert cfg.port == 8888 + + def test_from_env_empty(self, monkeypatch): + monkeypatch.delenv("ZENOH_CONNECT", raising=False) + monkeypatch.delenv("ZENOH_LISTEN", raising=False) + monkeypatch.delenv("STRANDS_MESH_PORT", raising=False) + + cfg = MeshConfig.from_env() + assert cfg == MeshConfig() + + +# --------------------------------------------------------------------------- +# get_session / release_session +# --------------------------------------------------------------------------- + + +class TestGetSession: + """Tests for session acquisition and ref-counting.""" + + def test_returns_session(self, mock_zenoh): + session = get_session() + assert session is mock_zenoh["session"] + + def test_refcounting_same_session(self, mock_zenoh): + s1 = get_session() + s2 = get_session() + assert s1 is s2 + # zenoh.open should only be called once + assert mock_zenoh["module"].open.call_count == 1 + assert session_info()["refs"] == 2 + + def test_release_does_not_close_above_zero(self, mock_zenoh): + get_session() + get_session() + release_session() # refs 2 → 1 + mock_zenoh["session"].close.assert_not_called() + assert session_info()["active"] is True + assert session_info()["refs"] == 1 + + def test_release_closes_at_zero(self, mock_zenoh): + get_session() + release_session() # refs 1 → 0 + mock_zenoh["session"].close.assert_called_once() + assert session_info()["active"] is False + assert session_info()["refs"] == 0 + + def test_session_reopens_after_full_release(self, mock_zenoh): + get_session() + release_session() + # Now get a fresh session + s = get_session() + assert s is mock_zenoh["session"] + assert mock_zenoh["module"].open.call_count == 2 + + def test_release_noop_when_no_session(self): + # Should not raise + release_session() + assert session_info()["refs"] == 0 + + +class TestGetSessionConfig: + """Tests for configuration application.""" + + def test_explicit_connect_config(self, mock_zenoh): + cfg = MeshConfig(connect=("tcp/10.0.0.1:7447",)) + get_session(config=cfg) + + # Should have called insert_json5 with connect endpoints + mock_zenoh["config"].insert_json5.assert_any_call("connect/endpoints", '["tcp/10.0.0.1:7447"]') + + def test_explicit_listen_config(self, mock_zenoh): + cfg = MeshConfig(listen=("tcp/0.0.0.0:7447",)) + get_session(config=cfg) + + mock_zenoh["config"].insert_json5.assert_any_call("listen/endpoints", '["tcp/0.0.0.0:7447"]') + + def test_auto_mesh_first_process_listens(self, mock_zenoh, monkeypatch): + monkeypatch.delenv("ZENOH_CONNECT", raising=False) + monkeypatch.delenv("ZENOH_LISTEN", raising=False) + + get_session() + + # First call to zenoh.open should try listen+connect (auto-mesh) + first_config = mock_zenoh["module"].Config.return_value + first_config.insert_json5.assert_any_call("listen/endpoints", '["tcp/127.0.0.1:7447"]') + + def test_auto_mesh_client_fallback(self, mock_zenoh, monkeypatch): + """When the first open (listen) fails, falls back to client mode.""" + monkeypatch.delenv("ZENOH_CONNECT", raising=False) + monkeypatch.delenv("ZENOH_LISTEN", raising=False) + + # First open raises (port taken), second succeeds + mock_zenoh["module"].open.side_effect = [OSError("Address in use"), mock_zenoh["session"]] + # Need 2 Config instances for the 2 attempts + cfg1, cfg2 = MagicMock(), MagicMock() + mock_zenoh["module"].Config.side_effect = [cfg1, cfg2] + + session = get_session() + assert session is mock_zenoh["session"] + + # Second config should be client mode + cfg2.insert_json5.assert_any_call("mode", '"client"') + + +class TestGetSessionDisabled: + """Tests for disabled/unavailable scenarios.""" + + def test_global_kill_switch(self, monkeypatch, mock_zenoh): + monkeypatch.setenv("STRANDS_MESH", "false") + assert get_session() is None + mock_zenoh["module"].open.assert_not_called() + + def test_global_kill_switch_case_insensitive(self, monkeypatch, mock_zenoh): + monkeypatch.setenv("STRANDS_MESH", "False") + assert get_session() is None + + def test_zenoh_not_installed(self): + with patch.object(importlib, "import_module", side_effect=ImportError("no zenoh")): + assert get_session() is None + + +class TestSessionInfo: + """Tests for session_info diagnostic.""" + + def test_inactive(self): + info = session_info() + assert info["active"] is False + assert info["refs"] == 0 + + def test_active(self, mock_zenoh): + get_session() + info = session_info() + assert info["active"] is True + assert info["refs"] == 1 + + +class TestThreadSafety: + """Basic thread-safety smoke tests.""" + + def test_concurrent_get_and_release(self, mock_zenoh): + """Multiple threads acquiring and releasing should not crash.""" + errors: list[Exception] = [] + + def worker(): + try: + for _ in range(50): + get_session() + for _ in range(50): + release_session() + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=worker) for _ in range(4)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=10) + + assert not errors, f"Thread errors: {errors}" + # All refs should be released + info = session_info() + assert info["refs"] == 0