From cd946274f8546051682903ac7f657046d983fa7b Mon Sep 17 00:00:00 2001 From: cagataycali Date: Tue, 31 Mar 2026 21:05:17 -0400 Subject: [PATCH 01/22] =?UTF-8?q?feat:=20MuJoCo=20simulation=20backend=20?= =?UTF-8?q?=E2=80=94=20AgentTool=20with=2035=20actions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Complete MuJoCo simulation backend composed of focused mixins: Simulation(AgentTool) ├── PhysicsMixin # raycasting, jacobians, energy, forces, │ # mass matrix, checkpoints, inverse dynamics ├── PolicyRunnerMixin # run_policy, eval_policy, replay_episode ├── RenderingMixin # RGB/depth offscreen rendering, observations ├── RecordingMixin # LeRobot dataset recording └── RandomizationMixin # domain randomization (colors, lighting, physics) Supporting modules: - backend.py: lazy mujoco import + headless GL auto-config (EGL/OSMesa/GLFW) - mjcf_builder.py: procedural MJCF XML generation from dataclasses - scene_ops.py: XML round-trip for runtime object/camera injection - simulation.py: orchestrator dispatching 35 actions via tool_spec.json - dataset_recorder.py: LeRobot v3 format recorder (parquet + video) Key design decisions: - Simulation extends AgentTool directly: Agent(tools=[Simulation()]) works - Lazy MuJoCo import via _ensure_mujoco() — only when first needed - XML round-trip for scene modification (standard: dm_control, robosuite) - Same Policy ABC for sim and real — zero code changes for transfer Tests: 47 new tests (12 E2E + 35 physics unit tests) All use self-contained inline XML robots (no external files needed). --- strands_robots/_async_utils.py | 28 + strands_robots/dataset_recorder.py | 515 ++++++++++ strands_robots/simulation/mujoco/__init__.py | 41 + strands_robots/simulation/mujoco/backend.py | 132 +++ .../simulation/mujoco/mjcf_builder.py | 197 ++++ strands_robots/simulation/mujoco/physics.py | 821 +++++++++++++++ .../simulation/mujoco/policy_runner.py | 356 +++++++ .../simulation/mujoco/randomization.py | 74 ++ strands_robots/simulation/mujoco/recording.py | 152 +++ strands_robots/simulation/mujoco/rendering.py | 225 +++++ strands_robots/simulation/mujoco/scene_ops.py | 211 ++++ .../simulation/mujoco/simulation.py | 949 ++++++++++++++++++ .../simulation/mujoco/tool_spec.json | 351 +++++++ tests/test_mujoco_e2e.py | 269 +++++ tests/test_physics.py | 350 +++++++ 15 files changed, 4671 insertions(+) create mode 100644 strands_robots/_async_utils.py create mode 100644 strands_robots/dataset_recorder.py create mode 100644 strands_robots/simulation/mujoco/__init__.py create mode 100644 strands_robots/simulation/mujoco/backend.py create mode 100644 strands_robots/simulation/mujoco/mjcf_builder.py create mode 100644 strands_robots/simulation/mujoco/physics.py create mode 100644 strands_robots/simulation/mujoco/policy_runner.py create mode 100644 strands_robots/simulation/mujoco/randomization.py create mode 100644 strands_robots/simulation/mujoco/recording.py create mode 100644 strands_robots/simulation/mujoco/rendering.py create mode 100644 strands_robots/simulation/mujoco/scene_ops.py create mode 100644 strands_robots/simulation/mujoco/simulation.py create mode 100644 strands_robots/simulation/mujoco/tool_spec.json create mode 100644 tests/test_mujoco_e2e.py create mode 100644 tests/test_physics.py diff --git a/strands_robots/_async_utils.py b/strands_robots/_async_utils.py new file mode 100644 index 0000000..91819a3 --- /dev/null +++ b/strands_robots/_async_utils.py @@ -0,0 +1,28 @@ +"""Async-to-sync helper for resolving coroutines in sync contexts.""" + +import asyncio +import concurrent.futures + + +def _resolve_coroutine(coro_or_result): + """Safely resolve a potentially-async result to a sync value. + + Handles three cases: + 1. Already a plain value → return as-is + 2. Coroutine, no running loop → asyncio.run() + 3. Coroutine, inside running loop → offload to thread + + Args: + coro_or_result: Either a coroutine or an already-resolved value. + + Returns: + The resolved (sync) value. + """ + if not asyncio.iscoroutine(coro_or_result): + return coro_or_result + try: + asyncio.get_running_loop() + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as ex: + return ex.submit(asyncio.run, coro_or_result).result() + except RuntimeError: + return asyncio.run(coro_or_result) diff --git a/strands_robots/dataset_recorder.py b/strands_robots/dataset_recorder.py new file mode 100644 index 0000000..8f25624 --- /dev/null +++ b/strands_robots/dataset_recorder.py @@ -0,0 +1,515 @@ +"""LeRobotDataset recorder bridge for strands-robots. + +Wraps LeRobotDataset so that both robot.py (real hardware) and +simulation.py (MuJoCo) can produce training-ready datasets with +a single add_frame() call per control step. + +Usage: + recorder = DatasetRecorder.create( + repo_id="user/my_dataset", + fps=30, + robot_features=robot.observation_features, + action_features=robot.action_features, + task="pick up the red cube", + ) + # In control loop: + recorder.add_frame(observation, action, task="pick up the red cube") + # End of episode: + recorder.save_episode() + # Optionally: + recorder.push_to_hub() +""" + +import functools +import logging +import sys +from typing import Any + +import numpy as np + +logger = logging.getLogger(__name__) + +# ── Lazy check for LeRobot availability ────────────────────────────── +# We must NOT import lerobot at module level because it pulls in +# `datasets` → `pandas`, which can crash with a numpy ABI mismatch on +# systems where the system pandas was compiled against an older numpy +# (e.g. JetPack / Jetson with system pandas 2.1.4 + pip numpy 2.x). + + +@functools.lru_cache(maxsize=1) +def has_lerobot_dataset() -> bool: + """Check if lerobot is available. Result is cached after first call.""" + try: + from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: F401 + + return True + except (ImportError, ValueError, RuntimeError) as exc: + logger.debug("lerobot not available: %s", exc) + return False + + +def _get_lerobot_dataset_class(): + """Import and return LeRobotDataset class, or raise ImportError. + + Supports test mocking: if ``strands_robots.dataset_recorder.LeRobotDataset`` + has been set (by a test mock), returns that class directly. + """ + # Support test mocking: check module-level overrides + this_module = sys.modules[__name__] + + # If a test injected a mock LeRobotDataset class, use it + mock_cls = getattr(this_module, "LeRobotDataset", None) + if mock_cls is not None: + return mock_cls + + # Actual import + try: + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + return LeRobotDataset + except (ImportError, ValueError, RuntimeError) as exc: + raise ImportError( + f"lerobot not available ({exc}). Install with: pip install lerobot\nRequired for LeRobotDataset recording." + ) from exc + + +def _numpy_ify(v): + """Convert any value to numpy-friendly format for add_frame.""" + if hasattr(v, "numpy"): + return v.numpy() + if hasattr(v, "tolist") and isinstance(v, np.ndarray): + return v + if isinstance(v, (int, float)): + return np.array([v], dtype=np.float32) + if isinstance(v, list): + return np.array(v, dtype=np.float32) + return v + + +class DatasetRecorder: + """Bridge between strands-robots control loops and LeRobotDataset. + + Handles the full lifecycle: + 1. create() — build LeRobotDataset with correct features + 2. add_frame() — called every control step with obs + action + 3. save_episode() — finalize episode (encodes video, writes parquet) + 4. push_to_hub() — upload to HuggingFace + + Works for both real hardware (robot.py) and simulation (simulation.py). + """ + + def __init__(self, dataset, task: str = ""): + self.dataset = dataset + self.default_task = task + self.frame_count = 0 + self.dropped_frame_count = 0 + self.episode_count = 0 + self._closed = False + self._cached_state_keys: list[str] | None = None + self._cached_action_keys: list[str] | None = None + + @classmethod + def create( + cls, + repo_id: str, + fps: int = 30, + robot_type: str = "unknown", + robot_features: dict[str, Any] | None = None, + action_features: dict[str, Any] | None = None, + camera_keys: list[str] | None = None, + joint_names: list[str] | None = None, + task: str = "", + root: str | None = None, + use_videos: bool = True, + vcodec: str = "libsvtav1", + streaming_encoding: bool = True, + image_writer_threads: int = 4, + video_backend: str = "auto", + ) -> "DatasetRecorder": + """Create a new DatasetRecorder with auto-detected features. + + Args: + repo_id: HuggingFace dataset ID (e.g. "user/my_dataset") + fps: Recording frame rate + robot_type: Robot type string (e.g. "so100", "panda") + robot_features: Dict of observation feature names → types + (from robot.observation_features or sim joint names) + action_features: Dict of action feature names → types + camera_keys: List of camera names (images become video features) + joint_names: List of joint names (alternative to robot_features for sim) + task: Default task description + root: Local directory for dataset storage + use_videos: Encode camera frames as video (True) or keep as images + vcodec: Video codec (h264, hevc, libsvtav1) + streaming_encoding: Stream-encode video during capture + image_writer_threads: Threads for writing image frames + video_backend: Video backend for encoding ("auto" for HW encoder auto-detect) + """ + # Lazy import — this is where we actually need lerobot + LeRobotDatasetCls = _get_lerobot_dataset_class() + + # Build features dict in LeRobot format + features = cls._build_features( + robot_features=robot_features, + action_features=action_features, + camera_keys=camera_keys, + joint_names=joint_names, + use_videos=use_videos, + ) + + logger.info(f"Creating LeRobotDataset: {repo_id} @ {fps}fps, {len(features)} features, robot_type={robot_type}") + + # Build kwargs, skip unsupported params for this LeRobot version + create_kwargs = dict( + repo_id=repo_id, + fps=fps, + root=root, + robot_type=robot_type, + features=features, + use_videos=use_videos, + image_writer_threads=image_writer_threads, + vcodec=vcodec, + ) + # streaming_encoding only in newer LeRobot versions + import inspect + + create_sig = inspect.signature(LeRobotDatasetCls.create) + if "streaming_encoding" in create_sig.parameters: + create_kwargs["streaming_encoding"] = streaming_encoding + if "video_backend" in create_sig.parameters: + create_kwargs["video_backend"] = video_backend + dataset = LeRobotDatasetCls.create(**create_kwargs) + + recorder = cls(dataset=dataset, task=task) + logger.info("DatasetRecorder ready: %s", repo_id) + return recorder + + @classmethod + def _build_features( + cls, + robot_features: dict | None = None, + action_features: dict | None = None, + camera_keys: list[str] | None = None, + joint_names: list[str] | None = None, + use_videos: bool = True, + ) -> dict[str, Any]: + """Build LeRobot v3-compatible features dict. + + LeRobot v3 features format: + { + "observation.images.camera_name": {"dtype": "video", "shape": (C, H, W), "names": [...]}, + "observation.state": {"dtype": "float32", "shape": (N,), "names": [...]}, + "action": {"dtype": "float32", "shape": (N,), "names": [...]}, + } + + Note: "names" must be a flat list of strings, NOT a dict like {"motors": [...]}. + """ + features = {} + + # --- Observation: cameras → video/image features --- + if camera_keys: + for cam_name in camera_keys: + key = f"observation.images.{cam_name}" + dtype = "video" if use_videos else "image" + features[key] = { + "dtype": dtype, + "shape": ( + 3, + 480, + 640, + ), # CHW default, actual shape set on first frame + "names": ["channels", "height", "width"], + } + + # --- Observation: state (joint positions) --- + state_dim = 0 + state_names = [] + if robot_features: + # Count scalar features (exclude cameras) + state_keys = [ + k + for k, v in robot_features.items() + if not isinstance(v, dict) or v.get("dtype") not in ("image", "video") + ] + state_dim = len(state_keys) + state_names = state_keys + elif joint_names: + state_dim = len(joint_names) + state_names = list(joint_names) + + if state_dim > 0: + features["observation.state"] = { + "dtype": "float32", + "shape": (state_dim,), + "names": state_names, + } + + # --- Action --- + action_dim = 0 + action_names = [] + if action_features: + action_keys = [ + k + for k, v in action_features.items() + if not isinstance(v, dict) or v.get("dtype") not in ("image", "video") + ] + action_dim = len(action_keys) + action_names = action_keys + elif joint_names: + action_dim = len(joint_names) + action_names = list(joint_names) + elif state_dim > 0: + action_dim = state_dim # Same dim as state by default + action_names = state_names[:] + + if action_dim > 0: + features["action"] = { + "dtype": "float32", + "shape": (action_dim,), + "names": action_names[:action_dim], + } + + return features + + def add_frame( + self, + observation: dict[str, Any], + action: dict[str, Any], + task: str | None = None, + camera_keys: list[str] | None = None, + ) -> None: + """Add a single control-loop frame to the dataset. + + This is the key method — called every step in the control loop. + + Args: + observation: Raw observation dict from robot/sim + (joint_name → float, camera_name → np.ndarray) + action: Action dict (joint_name → float) + task: Task description (uses default if None) + camera_keys: Which keys in observation are camera images + """ + if self._closed: + return + + frame = {} + + # --- Detect camera vs state keys --- + if camera_keys is None: + camera_keys = [k for k, v in observation.items() if isinstance(v, np.ndarray) and v.ndim >= 2] + + state_keys = [k for k in observation.keys() if k not in camera_keys] + + # --- Camera images → observation.images.{name} --- + for cam_key in camera_keys: + img = observation[cam_key] + if isinstance(img, np.ndarray): + # LeRobot expects HWC uint8 for add_frame + if img.dtype != np.uint8: + img = (np.clip(img, 0, 1) * 255).astype(np.uint8) + frame[f"observation.images.{cam_key}"] = img + + # --- State → observation.state (flattened vector) --- + # Use feature schema ordering to match the dataset schema declared in _build_features(). + if state_keys: + state_vals = [] + if self._cached_state_keys is None: + feat = self.dataset.features.get("observation.state", {}) + state_names = feat.get("names", []) if isinstance(feat, dict) else getattr(feat, "names", []) + self._cached_state_keys = state_names if state_names else sorted(state_keys) + + for k in self._cached_state_keys: + v = observation.get(k) + if v is None: + state_vals.append(0.0) + elif isinstance(v, (int, float)): + state_vals.append(float(v)) + elif isinstance(v, np.ndarray) and v.ndim == 0: + state_vals.append(float(v)) + elif isinstance(v, (list, np.ndarray)): + arr = np.asarray(v, dtype=np.float32).flatten() + state_vals.extend(arr.tolist()) + if state_vals: + frame["observation.state"] = np.array(state_vals, dtype=np.float32) + + # --- Action → flattened vector --- + # Use feature schema ordering for actions too. + if action: + action_vals = [] + if self._cached_action_keys is None: + feat = self.dataset.features.get("action", {}) + action_names = feat.get("names", []) if isinstance(feat, dict) else getattr(feat, "names", []) + self._cached_action_keys = action_names if action_names else sorted(action.keys()) + + for k in self._cached_action_keys: + v = action.get(k) + if v is None: + action_vals.append(0.0) + elif isinstance(v, (int, float)): + action_vals.append(float(v)) + elif isinstance(v, np.ndarray) and v.ndim == 0: + action_vals.append(float(v)) + elif isinstance(v, (list, np.ndarray)): + arr = np.asarray(v, dtype=np.float32).flatten() + action_vals.extend(arr.tolist()) + if action_vals: + frame["action"] = np.array(action_vals, dtype=np.float32) + + # --- Task (mandatory for LeRobot v3) --- + frame["task"] = task or self.default_task or "untitled" + + # --- Reconcile camera keys between frame and feature schema --- + # Only strip *undeclared* cameras from the frame (keys present in obs + # but not registered in _build_features). This avoids LeRobot's + # "Extra features" error. Declared-but-missing cameras (e.g. when a + # render fails) are left alone — LeRobot tolerates absent columns and + # the episode simply won't have that camera's data. + declared_cam_keys = {k for k in self.dataset.features if k.startswith("observation.images.")} + frame_cam_keys = {k for k in frame if k.startswith("observation.images.")} + for extra in frame_cam_keys - declared_cam_keys: + del frame[extra] + + # --- Add to dataset --- + try: + self.dataset.add_frame(frame) + self.frame_count += 1 + except Exception as e: + self.dropped_frame_count += 1 + n = self.dropped_frame_count + # Log at 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, then every 1000 + if (n & (n - 1)) == 0 or n % 1000 == 0: + logger.warning( + "add_frame failed (frame %d, dropped %d): %s", + self.frame_count, + self.dropped_frame_count, + e, + ) + + def save_episode(self) -> dict[str, Any]: + """Finalize current episode — writes parquet, encodes video, computes stats. + + LeRobot v3: save_episode() takes no task argument. Tasks are stored + per-frame in the episode buffer via add_frame(). + + Returns: + Dict with episode info + """ + if self._closed: + return {"status": "error", "message": "Recorder closed"} + + try: + self.dataset.save_episode() + self.episode_count += 1 + ep_frames = self.frame_count # Total frames so far + logger.info(f"Episode {self.episode_count} saved: {ep_frames} total frames") + return { + "status": "success", + "episode": self.episode_count, + "total_frames": ep_frames, + } + except Exception as e: + logger.error("save_episode failed: %s", e) + return {"status": "error", "message": str(e)} + + def finalize(self) -> None: + """Finalize the dataset (close parquet writers, flush metadata).""" + if self._closed: + return + try: + self.dataset.finalize() + except Exception as e: + logger.warning("finalize warning: %s", e) + self._closed = True + + def push_to_hub( + self, + tags: list[str] | None = None, + private: bool = False, + ) -> dict[str, Any]: + """Push dataset to HuggingFace Hub. + + Args: + tags: Optional tags for the dataset + private: Upload as private dataset + + Returns: + Dict with push status + """ + try: + self.dataset.push_to_hub(tags=tags, private=private) + logger.info("Dataset pushed to hub: %s", self.dataset.repo_id) + return { + "status": "success", + "repo_id": self.dataset.repo_id, + "episodes": self.episode_count, + "frames": self.frame_count, + } + except Exception as e: + logger.error("push_to_hub failed: %s", e) + return {"status": "error", "message": str(e)} + + @property + def repo_id(self) -> str: + return self.dataset.repo_id + + @property + def root(self) -> str: + return str(self.dataset.root) + + def __repr__(self) -> str: + return f"DatasetRecorder(repo_id={self.repo_id}, episodes={self.episode_count}, frames={self.frame_count})" + + +# ── Shared replay-episode helpers ──────────────────────────────────── + + +def load_lerobot_episode(repo_id: str, episode: int = 0, root: str | None = None): + """Load a LeRobotDataset and resolve the frame range for an episode. + + Returns: + Tuple of (dataset, episode_start, episode_length) on success. + + Raises: + ImportError: If lerobot is not installed. + ValueError: If the episode is out of range or has no frames. + """ + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + ds = LeRobotDataset(repo_id=repo_id, root=root) + + num_episodes = ds.meta.total_episodes if hasattr(ds.meta, "total_episodes") else len(ds.meta.episodes) + if episode >= num_episodes: + raise ValueError(f"Episode {episode} out of range (0-{num_episodes - 1})") + + episode_start = 0 + episode_length = 0 + try: + if hasattr(ds, "episode_data_index"): + from_idx = ds.episode_data_index["from"][episode].item() + to_idx = ds.episode_data_index["to"][episode].item() + episode_start = from_idx + episode_length = to_idx - from_idx + else: + for i in range(episode): + ep_info = ds.meta.episodes[i] if hasattr(ds.meta, "episodes") else {} + episode_start += ep_info.get("length", 0) + ep_info = ds.meta.episodes[episode] if hasattr(ds.meta, "episodes") else {} + episode_length = ep_info.get("length", 0) + except Exception: + # Last resort: scan frames to find episode boundaries + for idx in range(len(ds)): + frame = ds[idx] + frame_ep = frame.get("episode_index", -1) if hasattr(frame, "get") else -1 + if hasattr(frame_ep, "item"): + frame_ep = frame_ep.item() + if frame_ep == episode: + if episode_length == 0: + episode_start = idx + episode_length += 1 + elif episode_length > 0: + break + + if episode_length == 0: + raise ValueError(f"Episode {episode} has no frames") + + return ds, episode_start, episode_length diff --git a/strands_robots/simulation/mujoco/__init__.py b/strands_robots/simulation/mujoco/__init__.py new file mode 100644 index 0000000..014926b --- /dev/null +++ b/strands_robots/simulation/mujoco/__init__.py @@ -0,0 +1,41 @@ +"""MuJoCo simulation backend for strands-robots. + +CPU-based physics with offscreen rendering. No GPU required. +Supports URDF/MJCF loading, multi-robot scenes, policy execution, +domain randomization, and LeRobotDataset recording. + +Usage:: + + from strands_robots.simulation.mujoco import MuJoCoSimulation + + sim = MuJoCoSimulation(tool_name="my_sim") + sim.create_world() + sim.add_robot("so100", data_config="so100") + sim.run_policy("so100", policy_provider="mock", instruction="wave") + +Or via the top-level alias:: + + from strands_robots.simulation import Simulation # → MuJoCoSimulation +""" + +from strands_robots.simulation.mujoco.backend import ( + _configure_gl_backend, + _ensure_mujoco, + _is_headless, +) + +__all__ = [ + "MuJoCoSimulation", + "_configure_gl_backend", + "_ensure_mujoco", + "_is_headless", +] + + +def __getattr__(name): + if name == "MuJoCoSimulation": + from strands_robots.simulation.mujoco.simulation import Simulation as _Sim + + globals()["MuJoCoSimulation"] = _Sim + return _Sim + raise AttributeError(f"module 'strands_robots.simulation.mujoco' has no attribute {name!r}") diff --git a/strands_robots/simulation/mujoco/backend.py b/strands_robots/simulation/mujoco/backend.py new file mode 100644 index 0000000..da9a268 --- /dev/null +++ b/strands_robots/simulation/mujoco/backend.py @@ -0,0 +1,132 @@ +"""MuJoCo lazy import and GL backend configuration.""" + +import ctypes +import logging +import os +import sys + +logger = logging.getLogger(__name__) + +_mujoco = None +_mujoco_viewer = None + + +def _is_headless() -> bool: + """Detect if running in a headless environment (no display server). + + Returns True on Linux when no DISPLAY or WAYLAND_DISPLAY is set, + which means GLFW-based rendering will fail. + """ + if sys.platform != "linux": + return False + if os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"): + return False + return True + + +def _configure_gl_backend() -> None: + """Auto-configure MuJoCo's OpenGL backend for headless environments. + + MuJoCo reads MUJOCO_GL at import time to select the OpenGL backend: + - "egl" → EGL (GPU-accelerated offscreen, requires libEGL + NVIDIA driver) + - "osmesa" → OSMesa (CPU software rendering, slower but always works) + - "glfw" → GLFW (default, requires X11/Wayland display server) + + This function MUST be called before `import mujoco`. Setting MUJOCO_GL + after import has no effect — the backend is locked at import time. + + Never overrides a user-set MUJOCO_GL value. + """ + if os.environ.get("MUJOCO_GL"): + logger.debug(f"MUJOCO_GL already set to '{os.environ['MUJOCO_GL']}', respecting user config") + return + + if not _is_headless(): + return + + # Headless Linux — probe for EGL first (GPU-accelerated), then fall back to OSMesa (CPU) + try: + ctypes.cdll.LoadLibrary("libEGL.so.1") + os.environ["MUJOCO_GL"] = "egl" + logger.info("Headless environment detected — using MUJOCO_GL=egl (GPU-accelerated offscreen)") + return + except OSError: + pass + + try: + ctypes.cdll.LoadLibrary("libOSMesa.so") + os.environ["MUJOCO_GL"] = "osmesa" + logger.info("Headless environment detected — using MUJOCO_GL=osmesa (CPU software rendering)") + return + except OSError: + pass + + logger.warning( + "Headless environment detected but neither EGL nor OSMesa found. " + "MuJoCo rendering will likely fail. Install one of:\n" + " GPU: apt-get install libegl1-mesa-dev (or NVIDIA driver provides libEGL)\n" + " CPU: apt-get install libosmesa6-dev\n" + "Then set: export MUJOCO_GL=egl (or osmesa)" + ) + + +def _ensure_mujoco(): + """Lazy import MuJoCo to avoid hard dependency. + + Auto-configures the OpenGL backend for headless environments before + importing mujoco, since MUJOCO_GL must be set at import time. + + Uses require_optional() for consistent dependency management across + the strands-robots package. + """ + global _mujoco, _mujoco_viewer + if _mujoco is None: + _configure_gl_backend() + from strands_robots.utils import require_optional + + _mujoco = require_optional( + "mujoco", + pip_install="mujoco", + extra="sim", + purpose="MuJoCo simulation", + ) + if _mujoco_viewer is None and not _is_headless(): + try: + import mujoco.viewer as viewer + + _mujoco_viewer = viewer + except ImportError: + pass + return _mujoco + + +_rendering_available: bool | None = None + + +def _can_render() -> bool: + """Check if MuJoCo offscreen rendering is available. + + Probes once by creating a minimal Renderer. Result is cached. + Returns False on headless environments without EGL/OSMesa. + """ + global _rendering_available + if _rendering_available is not None: + return _rendering_available + + mj = _ensure_mujoco() + try: + model = mj.MjModel.from_xml_string("") + renderer = mj.Renderer(model, height=1, width=1) + renderer.close() + del renderer + _rendering_available = True + logger.info("MuJoCo rendering available") + except Exception as e: + _rendering_available = False + logger.warning( + "MuJoCo rendering unavailable: %s. " + "Physics/policy will work, but render/camera observations will be skipped. " + "Install EGL or OSMesa for offscreen rendering.", + e, + ) + return _rendering_available diff --git a/strands_robots/simulation/mujoco/mjcf_builder.py b/strands_robots/simulation/mujoco/mjcf_builder.py new file mode 100644 index 0000000..5dcdc69 --- /dev/null +++ b/strands_robots/simulation/mujoco/mjcf_builder.py @@ -0,0 +1,197 @@ +"""MJCF XML builder — programmatic scene construction.""" + +import logging +import os +import subprocess +import tempfile + +from strands_robots.simulation.models import SimCamera, SimObject, SimRobot, SimWorld +from strands_robots.simulation.mujoco.backend import _ensure_mujoco + +logger = logging.getLogger(__name__) + + +class MJCFBuilder: + """Builds MuJoCo MJCF XML from SimWorld state.""" + + @staticmethod + def build_objects_only(world: SimWorld) -> str: + """Build MJCF XML for a world with only objects (robots loaded separately).""" + _ensure_mujoco() + + parts = [] + parts.append('') + parts.append(' ') + + gx, gy, gz = world.gravity + parts.append(f' ") + + return "\n".join(parts) + + @staticmethod + def _object_xml(obj: SimObject, indent: int = 4) -> str: + """Generate MJCF XML for a single object.""" + pad = " " * indent + px, py, pz = obj.position + qw, qx, qy, qz = obj.orientation + r, g, b, a = obj.color + lines = [] + + lines.append(f'{pad}') + + if not obj.is_static: + lines.append(f'{pad} ') + lines.append(f'{pad} ') + + if obj.shape == "box": + sx, sy, sz = [s / 2 for s in obj.size] + lines.append( + f'{pad} ' + ) + elif obj.shape == "sphere": + radius = obj.size[0] / 2 if obj.size else 0.025 + lines.append( + f'{pad} ' + ) + elif obj.shape == "cylinder": + radius = obj.size[0] / 2 if obj.size else 0.025 + half_h = obj.size[2] / 2 if len(obj.size) > 2 else 0.05 + lines.append( + f'{pad} ' + ) + elif obj.shape == "capsule": + radius = obj.size[0] / 2 if obj.size else 0.025 + half_h = obj.size[2] / 2 if len(obj.size) > 2 else 0.05 + lines.append( + f'{pad} ' + ) + elif obj.shape == "mesh" and obj.mesh_path: + lines.append( + f'{pad} ' + ) + elif obj.shape == "plane": + sx = obj.size[0] if obj.size else 1.0 + sy = obj.size[1] if len(obj.size) > 1 else sx + lines.append( + f'{pad} ' + ) + + lines.append(f"{pad}") + return "\n".join(lines) + + @staticmethod + def compose_multi_robot_scene( + robots: dict[str, SimRobot], + objects: dict[str, SimObject], + cameras: dict[str, SimCamera], + world: SimWorld, + ) -> str: + """Compose a multi-robot scene by merging URDF-derived MJCF fragments.""" + mj = _ensure_mujoco() + world._tmpdir = tempfile.TemporaryDirectory(prefix="strands_sim_") + tmpdir = world._tmpdir.name + + robot_xmls = {} + for robot_name, robot in robots.items(): + try: + model = mj.MjModel.from_xml_path(str(robot.urdf_path)) + robot_xml_path = os.path.join(tmpdir, f"{robot_name}.xml") + mj.mj_saveLastXML(robot_xml_path, model) + robot_xmls[robot_name] = robot_xml_path + logger.debug("Converted %s → %s", robot.urdf_path, robot_xml_path) + except (FileNotFoundError, OSError, subprocess.CalledProcessError) as e: + logger.error("Failed to convert URDF for '%s': %s", robot_name, e) + raise + + parts = [] + parts.append('') + parts.append(' ') + + gx, gy, gz = world.gravity + parts.append(f' ") + + master_xml = "\n".join(parts) + master_path = os.path.join(tmpdir, "master_scene.xml") + with open(master_path, "w") as f: + f.write(master_xml) + + return master_path diff --git a/strands_robots/simulation/mujoco/physics.py b/strands_robots/simulation/mujoco/physics.py new file mode 100644 index 0000000..1afc7e8 --- /dev/null +++ b/strands_robots/simulation/mujoco/physics.py @@ -0,0 +1,821 @@ +"""Physics mixin — advanced MuJoCo physics introspection and manipulation. + +Exposes the deep MuJoCo C API through clean Python methods: +- Raycasting (mj_ray) +- Jacobians (mj_jacBody, mj_jacSite, mj_jacGeom) +- Energy computation (mj_energyPos, mj_energyVel) +- External forces (mj_applyFT, xfrc_applied) +- Mass matrix (mj_fullM) +- State checkpointing (mj_getState, mj_setState) +- Inverse dynamics (mj_inverse) +- Body/joint introspection (poses, velocities, accelerations) +- Direct joint position/velocity control (qpos, qvel) +- Runtime model modification (mass, friction, color, size) +- Sensor readout (sensordata) +- Contact force analysis (mj_contactForce) +""" + +import json +import logging +from typing import Any + +import numpy as np + +from strands_robots.simulation.mujoco.backend import _ensure_mujoco + +logger = logging.getLogger(__name__) + + +class PhysicsMixin: + """Advanced physics capabilities for Simulation. + + Expects: self._world (SimWorld with _model, _data) + + Naming: methods match action names in tool_spec.json for direct dispatch. + """ + + # ── State Checkpointing ── + + def save_state(self, name: str = "default") -> dict[str, Any]: + """Save the full physics state (qpos, qvel, act, time) to a named checkpoint. + + Uses mj_getState with mjSTATE_PHYSICS for complete state capture. + """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + state_size = mj.mj_stateSize(model, mj.mjtState.mjSTATE_PHYSICS) + state = np.zeros(state_size) + mj.mj_getState(model, data, state, mj.mjtState.mjSTATE_PHYSICS) + + if not hasattr(self._world, "_checkpoints"): + self._world._checkpoints = {} + + self._world._checkpoints[name] = { + "state": state.copy(), + "sim_time": self._world.sim_time, + "step_count": self._world.step_count, + } + + return { + "status": "success", + "content": [ + { + "text": ( + f"💾 State '{name}' saved\n" + f" t={self._world.sim_time:.4f}s, step={self._world.step_count}\n" + f" State vector: {state_size} floats\n" + f" Checkpoints: {list(self._world._checkpoints.keys())}" + ) + } + ], + } + + def load_state(self, name: str = "default") -> dict[str, Any]: + """Restore physics state from a named checkpoint.""" + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + checkpoints = getattr(self._world, "_checkpoints", {}) + if name not in checkpoints: + available = list(checkpoints.keys()) if checkpoints else ["none"] + return { + "status": "error", + "content": [{"text": f"❌ Checkpoint '{name}' not found. Available: {available}"}], + } + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + checkpoint = checkpoints[name] + + mj.mj_setState(model, data, checkpoint["state"], mj.mjtState.mjSTATE_PHYSICS) + mj.mj_forward(model, data) + + self._world.sim_time = checkpoint["sim_time"] + self._world.step_count = checkpoint["step_count"] + + return { + "status": "success", + "content": [ + {"text": f"📂 State '{name}' restored (t={self._world.sim_time:.4f}s, step={self._world.step_count})"} + ], + } + + # ── External Forces ── + + def apply_force( + self, + body_name: str, + force: list[float] = None, + torque: list[float] = None, + point: list[float] = None, + ) -> dict[str, Any]: + """Apply external force and/or torque to a body. + + Uses mj_applyFT for precise force application at a world-frame point. + Forces persist for one timestep — call before each step for continuous force. + + Args: + body_name: Target body name. + force: [fx, fy, fz] in world frame (Newtons). + torque: [tx, ty, tz] in world frame (N·m). + point: [px, py, pz] world-frame point of force application. + Defaults to body CoM if not specified. + """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + body_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_BODY, body_name) + if body_id < 0: + return {"status": "error", "content": [{"text": f"❌ Body '{body_name}' not found."}]} + + f = np.array(force or [0, 0, 0], dtype=np.float64) + t = np.array(torque or [0, 0, 0], dtype=np.float64) + p = np.array(point, dtype=np.float64) if point else data.xipos[body_id].copy() + + mj.mj_applyFT(model, data, f, t, p, body_id, data.qfrc_applied) + + return { + "status": "success", + "content": [ + { + "text": ( + f"💨 Force applied to '{body_name}' (body {body_id})\n" + f" Force: {f.tolist()} N\n" + f" Torque: {t.tolist()} N·m\n" + f" Point: {p.tolist()}" + ) + } + ], + } + + # ── Raycasting ── + + def raycast( + self, + origin: list[float], + direction: list[float], + exclude_body: int = -1, + include_static: bool = True, + ) -> dict[str, Any]: + """Cast a ray and find the first geom intersection. + + Uses mj_ray for precise distance sensing / obstacle detection. + + Args: + origin: [x, y, z] ray start point in world frame. + direction: [dx, dy, dz] ray direction (auto-normalized). + exclude_body: Body ID to exclude from intersection (-1 = none). + include_static: Whether to include static geoms. + """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + pnt = np.array(origin, dtype=np.float64) + vec = np.array(direction, dtype=np.float64) + # Normalize direction + norm = np.linalg.norm(vec) + if norm > 0: + vec = vec / norm + + geomid = np.array([-1], dtype=np.int32) + dist = mj.mj_ray( + model, + data, + pnt, + vec, + None, # geom group filter (None = all) + 1 if include_static else 0, + exclude_body, + geomid, + ) + + hit = dist >= 0 + geom_name = None + if hit and geomid[0] >= 0: + geom_name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, geomid[0]) + + result = { + "hit": hit, + "distance": float(dist) if hit else None, + "geom_id": int(geomid[0]) if hit else None, + "geom_name": geom_name, + "hit_point": (pnt + vec * dist).tolist() if hit else None, + } + + if hit: + text = f"🎯 Ray hit '{geom_name or geomid[0]}' at dist={dist:.4f}m, point={result['hit_point']}" + else: + text = "🎯 Ray: no intersection" + + return {"status": "success", "content": [{"text": text}, {"text": json.dumps(result, default=str)}]} + + # ── Jacobians ── + + def get_jacobian( + self, + body_name: str = None, + site_name: str = None, + geom_name: str = None, + ) -> dict[str, Any]: + """Compute the Jacobian (position + rotation) for a body, site, or geom. + + The Jacobian maps joint velocities to Cartesian velocities: + v = J @ dq + + Returns both positional (3×nv) and rotational (3×nv) Jacobians. + """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + jacp = np.zeros((3, model.nv)) + jacr = np.zeros((3, model.nv)) + + if body_name: + obj_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_BODY, body_name) + if obj_id < 0: + return {"status": "error", "content": [{"text": f"❌ Body '{body_name}' not found."}]} + mj.mj_jacBody(model, data, jacp, jacr, obj_id) + label = f"body '{body_name}'" + elif site_name: + obj_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_SITE, site_name) + if obj_id < 0: + return {"status": "error", "content": [{"text": f"❌ Site '{site_name}' not found."}]} + mj.mj_jacSite(model, data, jacp, jacr, obj_id) + label = f"site '{site_name}'" + elif geom_name: + obj_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_GEOM, geom_name) + if obj_id < 0: + return {"status": "error", "content": [{"text": f"❌ Geom '{geom_name}' not found."}]} + mj.mj_jacGeom(model, data, jacp, jacr, obj_id) + label = f"geom '{geom_name}'" + else: + return {"status": "error", "content": [{"text": "❌ Specify body_name, site_name, or geom_name."}]} + + return { + "status": "success", + "content": [ + {"text": f"🧮 Jacobian for {label}: pos={jacp.shape}, rot={jacr.shape}, nv={model.nv}"}, + { + "text": json.dumps( + { + "jacp": jacp.tolist(), + "jacr": jacr.tolist(), + "nv": model.nv, + }, + default=str, + ) + }, + ], + } + + # ── Energy ── + + def get_energy(self) -> dict[str, Any]: + """Compute potential and kinetic energy of the system.""" + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + mj.mj_energyPos(model, data) + mj.mj_energyVel(model, data) + + potential = float(data.energy[0]) + kinetic = float(data.energy[1]) + total = potential + kinetic + + return { + "status": "success", + "content": [ + {"text": f"⚡ Energy: potential={potential:.4f}J, kinetic={kinetic:.4f}J, total={total:.4f}J"}, + {"text": json.dumps({"potential": potential, "kinetic": kinetic, "total": total}, default=str)}, + ], + } + + # ── Mass Matrix ── + + def get_mass_matrix(self) -> dict[str, Any]: + """Compute the full mass (inertia) matrix M(q). + + M is nv×nv where nv is the number of DoFs. + Useful for dynamics analysis, impedance control, etc. + """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + nv = model.nv + M = np.zeros((nv, nv)) + mj.mj_fullM(model, M, data.qM) + rank = int(np.linalg.matrix_rank(M)) + cond = float(np.linalg.cond(M)) if rank > 0 else float("inf") + + return { + "status": "success", + "content": [ + {"text": f"🧮 Mass matrix: {nv}×{nv}, rank={rank}, cond={cond:.2e}"}, + { + "text": json.dumps( + { + "shape": [nv, nv], + "rank": rank, + "condition_number": cond, + "diagonal": np.diag(M).tolist(), + "total_mass": float(np.sum(model.body_mass)), + }, + default=str, + ) + }, + ], + } + + # ── Inverse Dynamics ── + + def inverse_dynamics(self) -> dict[str, Any]: + """Compute inverse dynamics: given qacc, what forces are needed? + + Runs mj_inverse to compute qfrc_inverse — the generalized forces + that would produce the current accelerations. + """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + mj.mj_inverse(model, data) + + # Build named force mapping + forces = {} + for i in range(model.njnt): + name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_JOINT, i) + if name: + dof_adr = model.jnt_dofadr[i] + forces[name] = float(data.qfrc_inverse[dof_adr]) + + return { + "status": "success", + "content": [ + {"text": f"🔄 Inverse dynamics: {len(forces)} joint forces computed"}, + {"text": json.dumps({"qfrc_inverse": forces}, default=str)}, + ], + } + + # ── Body Introspection ── + + def get_body_state( + self, + body_name: str, + ) -> dict[str, Any]: + """Get the full state of a body: position, orientation, velocity, acceleration. + + Returns Cartesian pose + 6D spatial velocity (linear + angular). + """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + body_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_BODY, body_name) + if body_id < 0: + return {"status": "error", "content": [{"text": f"❌ Body '{body_name}' not found."}]} + + # Position and orientation + pos = data.xpos[body_id].tolist() + quat = data.xquat[body_id].tolist() + rotmat = data.xmat[body_id].reshape(3, 3).tolist() + + # Velocity (6D: angular then linear in world frame) + vel = np.zeros(6) + mj.mj_objectVelocity(model, data, mj.mjtObj.mjOBJ_BODY, body_id, vel, 0) + linvel = vel[3:].tolist() + angvel = vel[:3].tolist() + + # Mass and inertia + mass = float(model.body_mass[body_id]) + com = data.xipos[body_id].tolist() + + state = { + "position": pos, + "quaternion": quat, + "rotation_matrix": rotmat, + "linear_velocity": linvel, + "angular_velocity": angvel, + "mass": mass, + "center_of_mass": com, + } + + text = ( + f"🏷️ Body '{body_name}' (id={body_id}):\n" + f" pos: [{pos[0]:.4f}, {pos[1]:.4f}, {pos[2]:.4f}]\n" + f" quat: [{quat[0]:.4f}, {quat[1]:.4f}, {quat[2]:.4f}, {quat[3]:.4f}]\n" + f" linvel: [{linvel[0]:.4f}, {linvel[1]:.4f}, {linvel[2]:.4f}]\n" + f" angvel: [{angvel[0]:.4f}, {angvel[1]:.4f}, {angvel[2]:.4f}]\n" + f" mass: {mass:.4f}kg, com: {com}" + ) + + return {"status": "success", "content": [{"text": text}, {"text": json.dumps(state, default=str)}]} + + # ── Direct Joint Control ── + + def set_joint_positions( + self, + positions: dict[str, float] = None, + robot_name: str = None, + ) -> dict[str, Any]: + """Set joint positions directly (bypassing actuators). + + Writes to qpos and runs mj_forward to update kinematics. + Useful for teleportation, IK solutions, or keyframe setting. + """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + if positions is None: + return {"status": "error", "content": [{"text": "❌ positions dict required."}]} + + set_count = 0 + for jnt_name, value in positions.items(): + jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jnt_id >= 0: + qpos_adr = model.jnt_qposadr[jnt_id] + data.qpos[qpos_adr] = float(value) + set_count += 1 + else: + logger.warning("Joint '%s' not found, skipping", jnt_name) + + mj.mj_forward(model, data) + + return { + "status": "success", + "content": [{"text": f"🎯 Set {set_count}/{len(positions)} joint positions, FK updated"}], + } + + def set_joint_velocities( + self, + velocities: dict[str, float] = None, + ) -> dict[str, Any]: + """Set joint velocities directly. + + Writes to qvel. Useful for initializing dynamics. + """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + if velocities is None: + return {"status": "error", "content": [{"text": "❌ velocities dict required."}]} + + set_count = 0 + for jnt_name, value in velocities.items(): + jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jnt_id >= 0: + dof_adr = model.jnt_dofadr[jnt_id] + data.qvel[dof_adr] = float(value) + set_count += 1 + + return { + "status": "success", + "content": [{"text": f"💨 Set {set_count}/{len(velocities)} joint velocities"}], + } + + # ── Sensor Readout ── + + def get_sensor_data(self, sensor_name: str = None) -> dict[str, Any]: + """Read sensor values from the simulation. + + MuJoCo supports: jointpos, jointvel, accelerometer, gyro, force, + torque, touch, rangefinder, framequat, subtreecom, clock, etc. + + Args: + sensor_name: Specific sensor name, or None for all sensors. + """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + if model.nsensor == 0: + return {"status": "success", "content": [{"text": "📡 No sensors in model."}]} + + mj.mj_forward(model, data) + + sensors = {} + for i in range(model.nsensor): + name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_SENSOR, i) + if not name: + name = f"sensor_{i}" + + adr = model.sensor_adr[i] + dim = model.sensor_dim[i] + values = data.sensordata[adr : adr + dim].tolist() + + if sensor_name and name != sensor_name: + continue + + sensors[name] = { + "values": values if dim > 1 else values[0], + "dim": int(dim), + "type": int(model.sensor_type[i]), + } + + if sensor_name and sensor_name not in sensors: + return {"status": "error", "content": [{"text": f"❌ Sensor '{sensor_name}' not found."}]} + + lines = [f"📡 Sensors ({len(sensors)}/{model.nsensor}):"] + for name, info in sensors.items(): + lines.append(f" {name}: {info['values']} (dim={info['dim']})") + + return { + "status": "success", + "content": [{"text": "\n".join(lines)}, {"text": json.dumps({"sensors": sensors}, default=str)}], + } + + # ── Runtime Model Modification ── + + def set_body_properties( + self, + body_name: str, + mass: float = None, + ) -> dict[str, Any]: + """Modify body properties at runtime (no recompile needed). + + Changes take effect on the next mj_step. + """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model = self._world._model + body_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_BODY, body_name) + if body_id < 0: + return {"status": "error", "content": [{"text": f"❌ Body '{body_name}' not found."}]} + + changes = [] + if mass is not None: + old_mass = float(model.body_mass[body_id]) + model.body_mass[body_id] = mass + changes.append(f"mass: {old_mass:.3f} → {mass:.3f}") + + return { + "status": "success", + "content": [{"text": f"🔧 Body '{body_name}': {', '.join(changes)}"}], + } + + def set_geom_properties( + self, + geom_name: str = None, + geom_id: int = None, + color: list[float] = None, + friction: list[float] = None, + size: list[float] = None, + ) -> dict[str, Any]: + """Modify geom properties at runtime (no recompile needed). + + Changes take effect immediately for rendering (color) or next step (friction, size). + """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model = self._world._model + + gid = geom_id + if geom_name: + gid = mj.mj_name2id(model, mj.mjtObj.mjOBJ_GEOM, geom_name) + if gid is None or gid < 0: + return {"status": "error", "content": [{"text": f"❌ Geom '{geom_name or geom_id}' not found."}]} + + label = geom_name or f"geom_{gid}" + changes = [] + + if color is not None: + model.geom_rgba[gid] = color[:4] if len(color) >= 4 else color[:3] + [1.0] + changes.append(f"color → {model.geom_rgba[gid].tolist()}") + + if friction is not None: + fric = friction[:3] if len(friction) >= 3 else friction + [0.0] * (3 - len(friction)) + model.geom_friction[gid] = fric + changes.append(f"friction → {fric}") + + if size is not None: + n = min(len(size), 3) + model.geom_size[gid, :n] = size[:n] + changes.append(f"size → {model.geom_size[gid].tolist()}") + + return { + "status": "success", + "content": [{"text": f"🔧 Geom '{label}': {', '.join(changes)}"}], + } + + # ── Contact Force Analysis ── + + def get_contact_forces(self) -> dict[str, Any]: + """Get detailed contact forces for all active contacts. + + Uses mj_contactForce for each active contact pair. + Returns normal and friction forces. + """ + if self._world is None or self._world._data is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + contacts = [] + for i in range(data.ncon): + c = data.contact[i] + g1 = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, c.geom1) or f"geom_{c.geom1}" + g2 = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, c.geom2) or f"geom_{c.geom2}" + + # Get contact force (normal + friction in contact frame) + force = np.zeros(6) + mj.mj_contactForce(model, data, i, force) + + contacts.append( + { + "geom1": g1, + "geom2": g2, + "distance": float(c.dist), + "position": c.pos.tolist(), + "normal_force": float(force[0]), + "friction_force": force[1:3].tolist(), + "full_wrench": force.tolist(), + } + ) + + if not contacts: + return {"status": "success", "content": [{"text": "💥 No active contacts."}]} + + lines = [f"💥 {len(contacts)} contacts:"] + for c in contacts[:15]: + lines.append(f" {c['geom1']} ↔ {c['geom2']}: normal={c['normal_force']:.3f}N, dist={c['distance']:.4f}m") + if len(contacts) > 15: + lines.append(f" ... and {len(contacts) - 15} more") + + return { + "status": "success", + "content": [{"text": "\n".join(lines)}, {"text": json.dumps({"contacts": contacts}, default=str)}], + } + + # ── Multi-Ray (batch raycasting) ── + + def multi_raycast( + self, + origin: list[float], + directions: list[list[float]], + exclude_body: int = -1, + ) -> dict[str, Any]: + """Cast multiple rays from a single origin (e.g., for LIDAR simulation). + + Efficiently casts N rays using individual mj_ray calls. + Returns array of distances and hit geoms. + """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + pnt = np.array(origin, dtype=np.float64) + results = [] + + for d in directions: + vec = np.array(d, dtype=np.float64) + norm = np.linalg.norm(vec) + if norm > 0: + vec /= norm + geomid = np.array([-1], dtype=np.int32) + dist = mj.mj_ray(model, data, pnt, vec, None, 1, exclude_body, geomid) + results.append( + { + "distance": float(dist) if dist >= 0 else None, + "geom_id": int(geomid[0]) if dist >= 0 else None, + } + ) + + hit_count = sum(1 for r in results if r["distance"] is not None) + return { + "status": "success", + "content": [ + {"text": f"🎯 Multi-ray: {hit_count}/{len(directions)} hits from {origin}"}, + {"text": json.dumps({"rays": results}, default=str)}, + ], + } + + # ── Forward Kinematics (explicit) ── + + def forward_kinematics(self) -> dict[str, Any]: + """Run forward kinematics to update all body positions/orientations. + + Usually called implicitly by mj_step, but useful after manually + setting qpos to see updated Cartesian positions. + """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + mj.mj_kinematics(model, data) + mj.mj_comPos(model, data) + mj.mj_camlight(model, data) + + # Build body position summary + bodies = {} + for i in range(model.nbody): + name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_BODY, i) or f"body_{i}" + bodies[name] = { + "position": data.xpos[i].tolist(), + "quaternion": data.xquat[i].tolist(), + } + + return { + "status": "success", + "content": [ + {"text": f"🦴 FK computed for {model.nbody} bodies"}, + {"text": json.dumps({"bodies": bodies}, default=str)}, + ], + } + + # ── Total Mass ── + + def get_total_mass(self) -> dict[str, Any]: + """Get total mass and per-body mass breakdown.""" + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model = self._world._model + + total = float(mj.mj_getTotalmass(model)) + bodies = {} + for i in range(model.nbody): + name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_BODY, i) or f"body_{i}" + m = float(model.body_mass[i]) + if m > 0: + bodies[name] = m + + return { + "status": "success", + "content": [ + {"text": f"⚖️ Total mass: {total:.4f}kg ({len(bodies)} bodies with mass)"}, + {"text": json.dumps({"total_mass": total, "bodies": bodies}, default=str)}, + ], + } + + # ── Export Model XML ── + + def export_xml(self, output_path: str = None) -> dict[str, Any]: + """Export the current model to MJCF XML. + + Uses mj_saveLastXML — exports the exact model currently loaded, + including any runtime modifications. + """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + + if output_path: + mj.mj_saveLastXML(output_path, self._world._model) + return {"status": "success", "content": [{"text": f"📄 Model exported to {output_path}"}]} + else: + # Return XML string via saveLastXML to temp file + import os + import tempfile + + tmpfile = tempfile.mktemp(suffix=".xml") + mj.mj_saveLastXML(tmpfile, self._world._model) + with open(tmpfile) as f: + xml = f.read() + os.unlink(tmpfile) + return { + "status": "success", + "content": [ + {"text": f"📄 Model XML ({len(xml)} chars):\n{xml[:2000]}{'...' if len(xml) > 2000 else ''}"} + ], + } diff --git a/strands_robots/simulation/mujoco/policy_runner.py b/strands_robots/simulation/mujoco/policy_runner.py new file mode 100644 index 0000000..59c3f8d --- /dev/null +++ b/strands_robots/simulation/mujoco/policy_runner.py @@ -0,0 +1,356 @@ +"""Policy execution mixin — run_policy, start_policy, record_video, replay_episode, eval_policy.""" + +import logging +import os +import time +from typing import Any + +import numpy as np + +from strands_robots._async_utils import _resolve_coroutine +from strands_robots.simulation.models import TrajectoryStep +from strands_robots.simulation.mujoco.backend import _ensure_mujoco + +logger = logging.getLogger(__name__) + + +class PolicyRunnerMixin: + """Policy execution for Simulation. Expects self._world, self._executor, self._policy_threads.""" + + def run_policy( + self, + robot_name: str, + policy_provider: str = "mock", + instruction: str = "", + duration: float = 10.0, + action_horizon: int = 8, + control_frequency: float = 50.0, + fast_mode: bool = False, + record_video: str = None, + video_fps: int = 30, + video_camera: str = None, + video_width: int = 640, + video_height: int = 480, + **policy_kwargs, + ) -> dict[str, Any]: + """Run a policy on a simulated robot (blocking). + + Args: + record_video: If set, path to save an MP4 recording of the run. + video_fps: Frames per second for the recording (default 30). + video_camera: Camera name for recording (default: first scene camera). + video_width: Recording width in pixels. + video_height: Recording height in pixels. + """ + if self._world is None or self._world._data is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + if robot_name not in self._world.robots: + return {"status": "error", "content": [{"text": f"❌ Robot '{robot_name}' not found."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + robot = self._world.robots[robot_name] + + # Video recording setup + writer = None + frame_count = 0 + cam_id = -1 + if record_video: + import imageio + + os.makedirs(os.path.dirname(os.path.abspath(record_video)), exist_ok=True) + writer = imageio.get_writer(record_video, fps=video_fps, quality=8, macro_block_size=1) + if video_camera: + cam_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_CAMERA, video_camera) + elif model.ncam > 0: + cam_id = 0 + frame_interval = control_frequency / video_fps # fractional steps per frame + + try: + from strands_robots.policies import create_policy as _create_policy + + policy = _create_policy(policy_provider, **policy_kwargs) + policy.set_robot_state_keys(robot.joint_names) + + robot.policy_running = True + robot.policy_instruction = instruction + robot.policy_steps = 0 + next_frame_step = 0.0 + + sim_duration = duration * control_frequency # target number of control steps + start_time = time.time() + action_sleep = 1.0 / control_frequency + + while robot.policy_steps < sim_duration and robot.policy_running: + observation = self._get_sim_observation(robot_name) + + coro_or_result = policy.get_actions(observation, instruction) + actions = _resolve_coroutine(coro_or_result) + + for action_dict in actions[:action_horizon]: + if not robot.policy_running: + break + + if self._world._recording: + self._world._trajectory.append( + TrajectoryStep( + timestamp=time.time(), + sim_time=self._world.sim_time, + robot_name=robot_name, + observation={k: v for k, v in observation.items() if not isinstance(v, np.ndarray)}, + action=action_dict, + instruction=instruction, + ) + ) + if self._world._dataset_recorder is not None: + self._world._dataset_recorder.add_frame( + observation=observation, + action=action_dict, + task=instruction, + ) + + self._apply_sim_action(robot_name, action_dict) + robot.policy_steps += 1 + + if writer and robot.policy_steps >= next_frame_step: + renderer = self._get_renderer(video_width, video_height) + if renderer is not None: + if cam_id >= 0: + renderer.update_scene(data, camera=cam_id) + else: + renderer.update_scene(data) + writer.append_data(renderer.render().copy()) + frame_count += 1 + next_frame_step += frame_interval + + if not fast_mode: + time.sleep(action_sleep) + + elapsed = time.time() - start_time + robot.policy_running = False + + result_text = ( + f"✅ Policy complete on '{robot_name}'\n" + f"🧠 {policy_provider} | 🎯 {instruction}\n" + f"⏱️ {elapsed:.1f}s | 📊 {robot.policy_steps} steps | " + f"🕐 sim_t={self._world.sim_time:.3f}s" + ) + + if writer: + writer.close() + file_kb = os.path.getsize(record_video) / 1024 + result_text += ( + f"\n🎬 Video: {record_video}\n" + f"📹 {frame_count} frames, {video_fps}fps, {video_width}x{video_height} | 💾 {file_kb:.0f} KB" + ) + + return {"status": "success", "content": [{"text": result_text}]} + + except Exception as e: + robot.policy_running = False + if writer: + writer.close() + return {"status": "error", "content": [{"text": f"❌ Policy failed: {e}"}]} + + def start_policy( + self, + robot_name: str, + policy_provider: str = "mock", + instruction: str = "", + duration: float = 10.0, + fast_mode: bool = False, + **policy_kwargs, + ) -> dict[str, Any]: + """Start policy execution in background (non-blocking).""" + if self._world is None or self._world._data is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + if robot_name not in self._world.robots: + return {"status": "error", "content": [{"text": f"❌ Robot '{robot_name}' not found."}]} + + future = self._executor.submit( + self.run_policy, + robot_name, + policy_provider, + instruction, + duration, + fast_mode=fast_mode, + **policy_kwargs, + ) + self._policy_threads[robot_name] = future + + return { + "status": "success", + "content": [{"text": f"🚀 Policy started on '{robot_name}' (async)"}], + } + + def replay_episode( + self, + repo_id: str, + robot_name: str = None, + episode: int = 0, + root: str = None, + speed: float = 1.0, + ) -> dict[str, Any]: + """Replay actions from a LeRobotDataset episode in simulation.""" + if self._world is None: + return {"status": "error", "content": [{"text": "❌ No world. Call create_world first."}]} + + if robot_name is None: + if not self._world.robots: + return {"status": "error", "content": [{"text": "❌ No robots in sim. Add one first."}]} + robot_name = next(iter(self._world.robots)) + + robot = self._world.robots.get(robot_name) + if robot is None: + return {"status": "error", "content": [{"text": f"❌ Robot '{robot_name}' not found"}]} + + try: + from strands_robots.dataset_recorder import load_lerobot_episode + + ds, episode_start, episode_length = load_lerobot_episode(repo_id, episode, root) + except ImportError: + return {"status": "error", "content": [{"text": "❌ lerobot not installed"}]} + except (ValueError, Exception) as e: + return {"status": "error", "content": [{"text": f"❌ {e}"}]} + + mj = _ensure_mujoco() + dataset_fps = getattr(ds, "fps", 30) + frame_interval = 1.0 / (dataset_fps * speed) + model = self._world._model + data = self._world._data + n_actuators = model.nu + frames_applied = 0 + start_time = time.time() + + for frame_idx in range(episode_length): + step_start = time.time() + frame = ds[episode_start + frame_idx] + + if "action" in frame: + action_vals = frame["action"] + if hasattr(action_vals, "numpy"): + action_vals = action_vals.numpy() + if hasattr(action_vals, "tolist"): + action_vals = action_vals.tolist() + for i in range(min(len(action_vals), n_actuators)): + data.ctrl[i] = float(action_vals[i]) + + mj.mj_step(model, data) + frames_applied += 1 + + elapsed = time.time() - step_start + sleep_time = frame_interval - elapsed + if sleep_time > 0: + time.sleep(sleep_time) + + duration = time.time() - start_time + return { + "status": "success", + "content": [ + { + "text": ( + f"▶️ Replayed episode {episode} from {repo_id} on '{robot_name}'\n" + f"Frames: {frames_applied}/{episode_length} | Duration: {duration:.1f}s | Speed: {speed}x" + ) + }, + { + "json": { + "episode": episode, + "robot_name": robot_name, + "frames_applied": frames_applied, + "total_frames": episode_length, + "duration_s": round(duration, 2), + "speed": speed, + } + }, + ], + } + + def eval_policy( + self, + robot_name: str = None, + policy_provider: str = "mock", + instruction: str = "", + n_episodes: int = 10, + max_steps: int = 300, + success_fn: str = None, + **policy_kwargs, + ) -> dict[str, Any]: + """Evaluate a policy over multiple episodes with success metrics.""" + if self._world is None: + return {"status": "error", "content": [{"text": "❌ No world. Call create_world first."}]} + + if robot_name is None: + if not self._world.robots: + return {"status": "error", "content": [{"text": "❌ No robots"}]} + robot_name = next(iter(self._world.robots)) + + robot = self._world.robots.get(robot_name) + if robot is None: + return {"status": "error", "content": [{"text": f"❌ Robot '{robot_name}' not found"}]} + + from strands_robots.policies import create_policy + + mj = _ensure_mujoco() + policy_instance = create_policy(policy_provider, **policy_kwargs) + policy_instance.set_robot_state_keys(robot.joint_names) + + model = self._world._model + data = self._world._data + + results = [] + for ep in range(n_episodes): + mj.mj_resetData(model, data) + mj.mj_forward(model, data) + + total_reward = 0.0 + success = False + steps = 0 + + for step in range(max_steps): + obs = self._get_sim_observation(robot_name=robot_name) + coro_or_result = policy_instance.get_actions(obs, instruction) + actions = _resolve_coroutine(coro_or_result) + + if actions: + self._apply_sim_action(robot_name, actions[0]) + + mj.mj_step(model, data) + steps += 1 + + if success_fn == "contact": + for i in range(data.ncon): + if data.contact[i].dist < 0: + success = True + break + if success: + break + + results.append({"episode": ep, "steps": steps, "success": success, "reward": total_reward}) + + n_success = sum(1 for r in results if r["success"]) + success_rate = n_success / max(n_episodes, 1) + avg_steps = sum(r["steps"] for r in results) / max(n_episodes, 1) + + return { + "status": "success", + "content": [ + { + "text": ( + f"📊 Evaluation: {policy_provider} on '{robot_name}'\n" + f"Episodes: {n_episodes} | Success: {n_success}/{n_episodes} ({success_rate:.1%})\n" + f"Avg steps: {avg_steps:.0f}/{max_steps}" + ) + }, + { + "json": { + "success_rate": round(success_rate, 4), + "n_episodes": n_episodes, + "n_success": n_success, + "avg_steps": round(avg_steps, 1), + "max_steps": max_steps, + "episodes": results, + } + }, + ], + } diff --git a/strands_robots/simulation/mujoco/randomization.py b/strands_robots/simulation/mujoco/randomization.py new file mode 100644 index 0000000..cdb2d3e --- /dev/null +++ b/strands_robots/simulation/mujoco/randomization.py @@ -0,0 +1,74 @@ +"""Domain randomization mixin.""" + +import logging +from typing import Any + +import numpy as np + +from strands_robots.simulation.mujoco.backend import _ensure_mujoco + +logger = logging.getLogger(__name__) + + +class RandomizationMixin: + """Domain randomization for Simulation. Expects self._world.""" + + def randomize( + self, + randomize_colors: bool = True, + randomize_lighting: bool = True, + randomize_physics: bool = False, + randomize_positions: bool = False, + position_noise: float = 0.02, + color_range: tuple[float, float] = (0.1, 1.0), + friction_range: tuple[float, float] = (0.5, 1.5), + mass_range: tuple[float, float] = (0.5, 2.0), + seed: int = None, + ) -> dict[str, Any]: + """Apply domain randomization to the scene.""" + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + rng = np.random.default_rng(seed) + mj = _ensure_mujoco() + model = self._world._model + data = self._world._data + changes = [] + + if randomize_colors: + for i in range(model.ngeom): + geom_name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, i) + if geom_name and geom_name != "ground": + model.geom_rgba[i, :3] = rng.uniform(color_range[0], color_range[1], size=3) + changes.append(f"🎨 Colors: {model.ngeom} geoms randomized") + + if randomize_lighting: + for i in range(model.nlight): + model.light_pos[i] += rng.uniform(-0.5, 0.5, size=3) + model.light_diffuse[i] = rng.uniform(0.3, 1.0, size=3) + changes.append(f"💡 Lighting: {model.nlight} lights randomized") + + if randomize_physics: + for i in range(model.ngeom): + model.geom_friction[i, 0] *= rng.uniform(*friction_range) + for i in range(model.nbody): + if model.body_mass[i] > 0: + model.body_mass[i] *= rng.uniform(*mass_range) + changes.append(f"⚙️ Physics: friction×[{friction_range}], mass×[{mass_range}]") + + if randomize_positions: + for obj_name, obj in self._world.objects.items(): + if not obj.is_static: + jnt_name = f"{obj_name}_joint" + jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jnt_id >= 0: + qpos_addr = model.jnt_qposadr[jnt_id] + noise = rng.uniform(-position_noise, position_noise, size=3) + data.qpos[qpos_addr : qpos_addr + 3] += noise + mj.mj_forward(model, data) + changes.append(f"📍 Positions: ±{position_noise}m noise on dynamic objects") + + return { + "status": "success", + "content": [{"text": "🎲 Domain Randomization applied:\n" + "\n".join(changes)}], + } diff --git a/strands_robots/simulation/mujoco/recording.py b/strands_robots/simulation/mujoco/recording.py new file mode 100644 index 0000000..c2ef006 --- /dev/null +++ b/strands_robots/simulation/mujoco/recording.py @@ -0,0 +1,152 @@ +"""Recording mixin — start/stop trajectory recording to LeRobotDataset.""" + +import logging +import shutil +from pathlib import Path +from typing import Any + +from strands_robots.simulation.mujoco.backend import _ensure_mujoco + +logger = logging.getLogger(__name__) + + +class RecordingMixin: + """Trajectory recording for Simulation. Expects self._world.""" + + def start_recording( + self, + repo_id: str = "local/sim_recording", + task: str = "", + fps: int = 30, + root: str = None, + push_to_hub: bool = False, + vcodec: str = "libsvtav1", + overwrite: bool = True, + ) -> dict[str, Any]: + """Start recording to LeRobotDataset format (parquet + video).""" + if self._world is None: + return {"status": "error", "content": [{"text": "No world."}]} + + try: + from strands_robots.dataset_recorder import DatasetRecorder as _DatasetRecorder + from strands_robots.dataset_recorder import has_lerobot_dataset as _has_lerobot + except ImportError: + + def _has_lerobot(): + return False + + _DatasetRecorder = None + + if not _has_lerobot() or _DatasetRecorder is None: + return { + "status": "error", + "content": [ + { + "text": "lerobot not installed. Install with: pip install lerobot\nRequired for dataset recording." + } + ], + } + + self._world._recording = True + self._world._trajectory = [] + self._world._push_to_hub = push_to_hub + + try: + if overwrite: + if root: + dataset_dir = Path(root) + elif "/" not in repo_id or repo_id.startswith("/") or repo_id.startswith("./"): + dataset_dir = Path(repo_id) + else: + dataset_dir = Path.home() / ".cache" / "huggingface" / "lerobot" / repo_id + if dataset_dir.exists() and dataset_dir.is_dir(): + shutil.rmtree(dataset_dir) + logger.info("Removed existing dataset dir: %s", dataset_dir) + + joint_names = [] + camera_keys = [] + robot_type = "unknown" + for rname, robot in self._world.robots.items(): + joint_names.extend(robot.joint_names) + robot_type = robot.data_config or rname + + mj = _ensure_mujoco() + for i in range(self._world._model.ncam): + cam_name = mj.mj_id2name(self._world._model, mj.mjtObj.mjOBJ_CAMERA, i) + if cam_name: + camera_keys.append(cam_name) + + self._world._dataset_recorder = _DatasetRecorder.create( + repo_id=repo_id, + fps=fps, + robot_type=robot_type, + joint_names=joint_names, + camera_keys=camera_keys, + task=task, + root=root, + vcodec=vcodec, + ) + return { + "status": "success", + "content": [ + { + "text": ( + f"Recording to LeRobotDataset: {repo_id}\n" + f"{len(joint_names)} joints, {len(camera_keys)} cameras @ {fps}fps\n" + f"Codec: {vcodec} | Task: {task or '(set per policy)'}\n" + f"Run policies to capture frames, then stop_recording to save episode" + ) + } + ], + } + except Exception as e: + self._world._recording = False + logger.error("Dataset recorder init failed: %s", e) + return {"status": "error", "content": [{"text": f"Dataset init failed: {e}"}]} + + def stop_recording(self, output_path: str = None) -> dict[str, Any]: + """Stop recording and save episode to LeRobotDataset.""" + if self._world is None or not self._world._recording: + return {"status": "error", "content": [{"text": "Not recording."}]} + + self._world._recording = False + recorder = self._world._dataset_recorder + + if recorder is None: + return {"status": "error", "content": [{"text": "No dataset recorder active."}]} + + recorder.save_episode() + push_result = None + if getattr(self._world, "_push_to_hub", False): + push_result = recorder.push_to_hub(tags=["strands-robots", "sim"]) + + repo_id = recorder.repo_id + frame_count = recorder.frame_count + episode_count = recorder.episode_count + root = recorder.root + + recorder.finalize() + self._world._dataset_recorder = None + self._world._trajectory = [] + + text = ( + f"Episode saved to LeRobotDataset\n" + f"{repo_id} -- {frame_count} frames, {episode_count} episode(s)\n" + f"Local: {root}" + ) + if push_result and push_result.get("status") == "success": + text += "\nPushed to HuggingFace Hub" + + return {"status": "success", "content": [{"text": text}]} + + def get_recording_status(self) -> dict[str, Any]: + if self._world is None: + return {"status": "error", "content": [{"text": "❌ No world."}]} + + recording = self._world._recording + steps = len(self._world._trajectory) + + return { + "status": "success", + "content": [{"text": f"{'🔴 Recording' if recording else '⚪ Not recording'}: {steps} steps captured"}], + } diff --git a/strands_robots/simulation/mujoco/rendering.py b/strands_robots/simulation/mujoco/rendering.py new file mode 100644 index 0000000..c51fc0c --- /dev/null +++ b/strands_robots/simulation/mujoco/rendering.py @@ -0,0 +1,225 @@ +"""Rendering mixin — render, render_depth, get_contacts, observation helpers.""" + +import io +import json +import logging +from typing import Any + +from strands_robots.simulation.mujoco.backend import _can_render, _ensure_mujoco + +logger = logging.getLogger(__name__) + + +class RenderingMixin: + """Rendering capabilities for Simulation. Expects self._world, self.default_width, self.default_height.""" + + def _get_renderer(self, width: int, height: int): + """Get a cached MuJoCo renderer, creating one only if needed. + + Returns None if rendering is unavailable (headless without EGL/OSMesa). + Callers must handle None return. + """ + if not _can_render(): + return None + mj = _ensure_mujoco() + key = (width, height) + if self._renderer_model is not self._world._model: + self._renderers.clear() + self._renderer_model = self._world._model + if key not in self._renderers: + self._renderers[key] = mj.Renderer(self._world._model, height=height, width=width) + return self._renderers[key] + + def _get_sim_observation(self, robot_name: str, cam_name: str = None) -> dict[str, Any]: + """Get observation from sim (same format as real robot).""" + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + robot = self._world.robots[robot_name] + + obs = {} + for jnt_name in robot.joint_names: + jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jnt_id >= 0: + obs[jnt_name] = float(data.qpos[model.jnt_qposadr[jnt_id]]) + + cameras_to_render = [] + if cam_name: + cameras_to_render = [cam_name] + else: + cameras_to_render = [mj.mj_id2name(model, mj.mjtObj.mjOBJ_CAMERA, i) for i in range(model.ncam)] + for pycam_name in self._world.cameras: + if pycam_name not in cameras_to_render: + cameras_to_render.append(pycam_name) + + for cname in cameras_to_render: + if not cname: + continue + cam_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_CAMERA, cname) + cam_info = self._world.cameras.get(cname) + h = cam_info.height if cam_info else self.default_height + w = cam_info.width if cam_info else self.default_width + try: + renderer = self._get_renderer(w, h) + if renderer is None: + continue + if cam_id >= 0: + renderer.update_scene(data, camera=cam_id) + else: + renderer.update_scene(data) + obs[cname] = renderer.render().copy() + except (RuntimeError, ValueError) as e: + # Individual camera failure shouldn't stop joint state collection. + # Common cause: camera ID invalid after scene recompile. + logger.debug("Camera render failed for %s: %s", cname, e) + + return obs + + def _apply_sim_action(self, robot_name: str, action_dict: dict[str, Any], n_substeps: int = 1): + """Apply action dict to sim (same interface as robot.send_action).""" + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + for key, value in action_dict.items(): + act_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_ACTUATOR, key) + if act_id >= 0: + data.ctrl[act_id] = float(value) + else: + jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, key) + if jnt_id >= 0 and jnt_id < model.nu: + data.ctrl[jnt_id] = float(value) + + for _ in range(max(1, n_substeps)): + mj.mj_step(model, data) + + self._world.sim_time = data.time + self._world.step_count += n_substeps + + if hasattr(self, "_viewer_handle") and self._viewer_handle is not None: + self._viewer_handle.sync() + + def render(self, camera_name: str = "default", width: int = None, height: int = None) -> dict[str, Any]: + """Render a camera view as base64 PNG image.""" + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + w = width or self.default_width + h = height or self.default_height + + try: + renderer = self._get_renderer(w, h) + if renderer is None: + return { + "status": "error", + "content": [ + { + "text": ( + "❌ Rendering unavailable (no OpenGL context). " + "Install EGL or OSMesa for offscreen rendering: " + "apt-get install libosmesa6-dev" + ) + } + ], + } + cam_id = mj.mj_name2id(self._world._model, mj.mjtObj.mjOBJ_CAMERA, camera_name) + if cam_id >= 0: + renderer.update_scene(self._world._data, camera=cam_id) + else: + renderer.update_scene(self._world._data) + + img = renderer.render().copy() + + from PIL import Image + + pil_img = Image.fromarray(img) + buffer = io.BytesIO() + pil_img.save(buffer, format="PNG") + png_bytes = buffer.getvalue() + + return { + "status": "success", + "content": [ + {"text": f"📸 {w}x{h} from '{camera_name}' at t={self._world.sim_time:.3f}s"}, + {"image": {"format": "png", "source": {"bytes": png_bytes}}}, + ], + } + except Exception as e: + return {"status": "error", "content": [{"text": f"❌ Render failed: {e}"}]} + + def render_depth(self, camera_name: str = "default", width: int = None, height: int = None) -> dict[str, Any]: + """Render depth map from a camera.""" + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + w = width or self.default_width + h = height or self.default_height + + try: + cam_id = -1 + if camera_name and camera_name != "default": + cam_id = mj.mj_name2id(self._world._model, mj.mjtObj.mjOBJ_CAMERA, camera_name) + + renderer = self._get_renderer(w, h) + if renderer is None: + return { + "status": "error", + "content": [ + { + "text": ( + "❌ Depth rendering unavailable (no OpenGL context). " + "Install EGL or OSMesa for offscreen rendering." + ) + } + ], + } + if cam_id >= 0: + renderer.update_scene(self._world._data, camera=cam_id) + else: + renderer.update_scene(self._world._data) + renderer.enable_depth_rendering() + depth = renderer.render() + renderer.disable_depth_rendering() + + return { + "status": "success", + "content": [ + { + "text": ( + f"📸 Depth {w}x{h} from '{camera_name}'\n" + f"Min: {float(depth.min()):.3f}m, Max: {float(depth.max()):.3f}m" + ) + }, + { + "text": json.dumps( + {"depth_min": float(depth.min()), "depth_max": float(depth.max())}, default=str + ) + }, + ], + } + except Exception as e: + return {"status": "error", "content": [{"text": f"❌ Depth render failed: {e}"}]} + + def get_contacts(self) -> dict[str, Any]: + if self._world is None or self._world._data is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + contacts = [] + for i in range(data.ncon): + c = data.contact[i] + g1 = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, c.geom1) or f"geom_{c.geom1}" + g2 = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, c.geom2) or f"geom_{c.geom2}" + contacts.append({"geom1": g1, "geom2": g2, "dist": float(c.dist), "pos": c.pos.tolist()}) + + text = f"💥 {len(contacts)} contacts" if contacts else "No contacts." + if contacts: + for c in contacts[:10]: + text += f"\n • {c['geom1']} ↔ {c['geom2']} (d={c['dist']:.4f})" + + return { + "status": "success", + "content": [{"text": text}, {"text": json.dumps({"contacts": contacts}, default=str)}], + } diff --git a/strands_robots/simulation/mujoco/scene_ops.py b/strands_robots/simulation/mujoco/scene_ops.py new file mode 100644 index 0000000..ba83696 --- /dev/null +++ b/strands_robots/simulation/mujoco/scene_ops.py @@ -0,0 +1,211 @@ +"""XML round-trip injection/ejection for scene modification. + +Shared helper `_reload_scene_from_xml` handles the common pattern: +save XML → patch paths → modify → reload → copy state → re-discover joints. +""" + +import logging +import os +import re +import shutil +import tempfile +import xml.etree.ElementTree as ET + +from strands_robots.simulation.models import SimCamera, SimObject, SimWorld +from strands_robots.simulation.mujoco.backend import _ensure_mujoco +from strands_robots.simulation.mujoco.mjcf_builder import MJCFBuilder + +logger = logging.getLogger(__name__) + + +def _patch_xml_paths(xml_content: str, robot_base_dir: str) -> str: + """Patch meshdir/texturedir in XML to absolute paths for tmpdir loading.""" + meshdir_match = re.search(r'meshdir="([^"]*)"', xml_content) + existing_meshdir = meshdir_match.group(1) if meshdir_match else "" + abs_meshdir = os.path.normpath(os.path.join(robot_base_dir, existing_meshdir)) + + texdir_match = re.search(r'texturedir="([^"]*)"', xml_content) + existing_texdir = texdir_match.group(1) if texdir_match else "" + abs_texdir = os.path.normpath(os.path.join(robot_base_dir, existing_texdir)) + + if meshdir_match: + xml_content = re.sub(r'meshdir="[^"]*"', f'meshdir="{abs_meshdir}"', xml_content) + elif " bool: + """Reload MuJoCo model from modified XML, preserving state. + + Copies qpos, qvel, ctrl from old model and re-discovers robot joint/actuator IDs. + """ + mj = _ensure_mujoco() + new_model = mj.MjModel.from_xml_path(str(scene_path)) + new_data = mj.MjData(new_model) + + # Copy state from old model + old_nq = min(world._data.qpos.shape[0], new_data.qpos.shape[0]) + old_nv = min(world._data.qvel.shape[0], new_data.qvel.shape[0]) + new_data.qpos[:old_nq] = world._data.qpos[:old_nq] + new_data.qvel[:old_nv] = world._data.qvel[:old_nv] + old_nu = min(world._data.ctrl.shape[0], new_data.ctrl.shape[0]) + new_data.ctrl[:old_nu] = world._data.ctrl[:old_nu] + + mj.mj_forward(new_model, new_data) + + world._model = new_model + world._data = new_data + + # Re-discover robot joints/actuators (IDs may shift) + for robot in world.robots.values(): + robot.joint_ids = [] + robot.actuator_ids = [] + for jnt_name in robot.joint_names: + jid = mj.mj_name2id(new_model, mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jid >= 0: + robot.joint_ids.append(jid) + for i in range(new_model.nu): + jnt_id = new_model.actuator_trnid[i, 0] + if jnt_id in robot.joint_ids: + robot.actuator_ids.append(i) + if not robot.actuator_ids: + for i in range(new_model.nu): + robot.actuator_ids.append(i) + + return True + + +def _get_robot_base_dir(world: SimWorld) -> str | None: + """Get the directory of the original robot model file.""" + if world._robot_base_xml: + return os.path.dirname(os.path.abspath(world._robot_base_xml)) + return None + + +def _save_and_patch_xml(world: SimWorld, tmpdir: str, filename: str) -> str: + """Save current model to XML in tmpdir and patch asset paths.""" + mj = _ensure_mujoco() + scene_path = os.path.join(tmpdir, filename) + mj.mj_saveLastXML(scene_path, world._model) + + robot_base_dir = _get_robot_base_dir(world) + if robot_base_dir and os.path.isdir(robot_base_dir): + with open(scene_path) as f: + xml_content = f.read() + xml_content = _patch_xml_paths(xml_content, robot_base_dir) + with open(scene_path, "w") as f: + f.write(xml_content) + + return scene_path + + +def inject_object_into_scene(world: SimWorld, obj: SimObject) -> bool: + """Inject object into a running simulation via XML round-trip.""" + _ensure_mujoco() + if world._model is None: + return False + + tmpdir = tempfile.mkdtemp(prefix="strands_sim_") + try: + scene_path = _save_and_patch_xml(world, tmpdir, "scene_with_objects.xml") + + with open(scene_path) as f: + xml_content = f.read() + + obj_xml = MJCFBuilder._object_xml(obj, indent=4) + xml_content = xml_content.replace("", f"{obj_xml}\n") + + # Remove keyframes — adding a freejoint changes qpos size + xml_content = re.sub(r".*?", "", xml_content, flags=re.DOTALL) + + with open(scene_path, "w") as f: + f.write(xml_content) + + return _reload_scene_from_xml(world, scene_path) + except (ValueError, RuntimeError, OSError) as e: + logger.error("Object injection reload failed: %s", e) + return False + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + +def eject_body_from_scene(world: SimWorld, body_name: str) -> bool: + """Remove a named body from the scene via XML round-trip.""" + mj = _ensure_mujoco() + + tmpdir = tempfile.mkdtemp(prefix="strands_eject_") + try: + scene_path = os.path.join(tmpdir, "scene_ejected.xml") + mj.mj_saveLastXML(scene_path, world._model) + + tree = ET.parse(scene_path) + root = tree.getroot() + + # Patch paths + robot_base_dir = _get_robot_base_dir(world) + if robot_base_dir: + compiler = root.find("compiler") + if compiler is not None: + existing_meshdir = compiler.get("meshdir", "") + compiler.set("meshdir", os.path.normpath(os.path.join(robot_base_dir, existing_meshdir))) + existing_texdir = compiler.get("texturedir", "") + compiler.set("texturedir", os.path.normpath(os.path.join(robot_base_dir, existing_texdir))) + + # Remove target body + removed = False + for parent in root.iter(): + for child in list(parent): + if child.tag == "body" and child.get("name") == body_name: + parent.remove(child) + removed = True + + if not removed: + logger.warning(f"Body '{body_name}' not found in MJCF XML — skipping ejection.") + + # Remove keyframes + for keyframe_elem in root.findall("keyframe"): + root.remove(keyframe_elem) + + tree.write(scene_path, xml_declaration=True) + + return _reload_scene_from_xml(world, scene_path) + except (ValueError, RuntimeError, OSError) as e: + logger.error("Body ejection failed for '%s': %s", body_name, e) + return False + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + +def inject_camera_into_scene(world: SimWorld, cam: SimCamera) -> bool: + """Inject a camera into a running simulation via XML round-trip.""" + _ensure_mujoco() + if world._model is None: + return False + + tmpdir = tempfile.mkdtemp(prefix="strands_cam_") + try: + scene_path = _save_and_patch_xml(world, tmpdir, "scene_with_cameras.xml") + + with open(scene_path) as f: + xml_content = f.read() + + px, py, pz = cam.position + cam_xml = f' ' + xml_content = xml_content.replace("", f"{cam_xml}\n") + + with open(scene_path, "w") as f: + f.write(xml_content) + + return _reload_scene_from_xml(world, scene_path) + except (ValueError, RuntimeError, OSError) as e: + logger.error("Camera injection reload failed: %s", e) + return False + finally: + shutil.rmtree(tmpdir, ignore_errors=True) diff --git a/strands_robots/simulation/mujoco/simulation.py b/strands_robots/simulation/mujoco/simulation.py new file mode 100644 index 0000000..70b5404 --- /dev/null +++ b/strands_robots/simulation/mujoco/simulation.py @@ -0,0 +1,949 @@ +"""MuJoCo Simulation — AgentTool orchestrator composing physics/rendering/policy mixins.""" + +import json +import logging +import os +import re +import threading +from collections.abc import AsyncGenerator +from concurrent.futures import Future, ThreadPoolExecutor +from pathlib import Path +from typing import Any + +from strands.tools.tools import AgentTool +from strands.types._events import ToolResultEvent +from strands.types.tools import ToolSpec, ToolUse + +from strands_robots.simulation.model_registry import ( + list_available_models, + register_urdf, + resolve_model, +) +from strands_robots.simulation.models import SimCamera, SimObject, SimRobot, SimStatus, SimWorld +from strands_robots.simulation.mujoco.backend import _ensure_mujoco +from strands_robots.simulation.mujoco.mjcf_builder import MJCFBuilder +from strands_robots.simulation.mujoco.physics import PhysicsMixin +from strands_robots.simulation.mujoco.policy_runner import PolicyRunnerMixin +from strands_robots.simulation.mujoco.randomization import RandomizationMixin +from strands_robots.simulation.mujoco.recording import RecordingMixin +from strands_robots.simulation.mujoco.rendering import RenderingMixin +from strands_robots.simulation.mujoco.scene_ops import ( + eject_body_from_scene, + inject_camera_into_scene, + inject_object_into_scene, +) + +logger = logging.getLogger(__name__) + +_TOOL_SPEC_PATH = Path(__file__).parent / "tool_spec.json" + + +class Simulation( + PhysicsMixin, + PolicyRunnerMixin, + RenderingMixin, + RecordingMixin, + RandomizationMixin, + AgentTool, +): + """Programmatic simulation environment as a Strands AgentTool. + + Gives AI agents the ability to create, modify, and control MuJoCo + simulation environments through natural language → tool actions. + """ + + def __init__( + self, + tool_name: str = "sim", + default_timestep: float = 0.002, + default_width: int = 640, + default_height: int = 480, + mesh: bool = True, + peer_id: str = None, + **kwargs, + ): + super().__init__() + self.tool_name_str = tool_name + self.default_timestep = default_timestep + self.default_width = default_width + self.default_height = default_height + + self._world: SimWorld | None = None + self._executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix=f"{tool_name}_sim") + self._policy_threads: dict[str, Future] = {} + self._shutdown_event = threading.Event() + self._lock = threading.Lock() + + self._viewer_handle = None + self._viewer_thread = None + + self._renderers: dict[tuple, Any] = {} + self._renderer_model = None + + logger.info("🎮 Simulation tool '%s' initialized", tool_name) + + try: + from strands_robots.zenoh_mesh import init_mesh + + self.mesh = init_mesh(self, peer_id=peer_id, peer_type="sim", mesh=mesh) + except Exception as e: + logger.debug("Mesh init skipped: %s", e) + self.mesh = None + + # --- Public Properties --- + + @property + def mj_model(self): + """Direct access to the MuJoCo model (mujoco.MjModel).""" + return self._world._model if self._world else None + + @property + def mj_data(self): + """Direct access to the MuJoCo data (mujoco.MjData).""" + return self._world._data if self._world else None + + # --- Robot-compatible interface --- + + def get_observation(self, robot_name: str = None, camera_name: str = None) -> dict[str, Any]: + """Get observation from simulation (Robot ABC compatible).""" + if self._world is None or self._world._model is None: + return {} + if robot_name is None: + if not self._world.robots: + return {} + robot_name = next(iter(self._world.robots)) + if robot_name not in self._world.robots: + return {} + return self._get_sim_observation(robot_name, cam_name=camera_name) + + def send_action(self, action: dict[str, Any], robot_name: str = None, n_substeps: int = 1) -> None: + """Apply action to simulation (Robot ABC compatible).""" + if self._world is None or self._world._model is None: + return + if robot_name is None: + if not self._world.robots: + return + robot_name = next(iter(self._world.robots)) + if robot_name not in self._world.robots: + return + self._apply_sim_action(robot_name, action, n_substeps=n_substeps) + + # --- World Management --- + + def _cheap_robot_count(self) -> int: + try: + from strands_robots.registry import list_robots as _registry_list_robots + + return len(_registry_list_robots(mode="sim")) + except ImportError: + return 0 + + def create_world( + self, timestep: float = None, gravity: list[float] = None, ground_plane: bool = True + ) -> dict[str, Any]: + """Create a new simulation world.""" + _ensure_mujoco() + + if self._world is not None and self._world._model is not None: + return { + "status": "error", + "content": [{"text": "❌ World already exists. Use action='destroy' first, or action='reset'."}], + } + + if gravity is None: + _gravity = [0.0, 0.0, -9.81] + elif isinstance(gravity, (int, float)): + _gravity = [0.0, 0.0, float(gravity)] + else: + _gravity = list(gravity) + + self._world = SimWorld( + timestep=timestep or self.default_timestep, + gravity=_gravity, + ground_plane=ground_plane, + ) + + self._world.cameras["default"] = SimCamera( + name="default", + position=[1.5, 1.5, 1.2], + target=[0.0, 0.0, 0.3], + width=self.default_width, + height=self.default_height, + ) + + self._compile_world() + + return { + "status": "success", + "content": [ + { + "text": ( + "🌍 Simulation world created\n" + f"⚙️ Timestep: {self._world.timestep}s ({1 / self._world.timestep:.0f}Hz physics)\n" + f"🌐 Gravity: {self._world.gravity}\n" + f"📷 Default camera ready\n" + f"🤖 Robot models: {self._cheap_robot_count()} available\n" + "💡 Add robots: action='add_robot' (urdf_path or data_config)\n" + "💡 Add objects: action='add_object'\n" + "💡 List URDFs: action='list_urdfs'" + ) + } + ], + } + + def load_scene(self, scene_path: str) -> dict[str, Any]: + """Load a complete scene from MJCF XML or URDF file.""" + mj = _ensure_mujoco() + + if not os.path.exists(scene_path): + return {"status": "error", "content": [{"text": f"❌ Scene file not found: {scene_path}"}]} + + try: + self._world = SimWorld() + self._world._model = mj.MjModel.from_xml_path(str(scene_path)) + self._world._data = mj.MjData(self._world._model) + self._world.status = SimStatus.IDLE + + return { + "status": "success", + "content": [ + { + "text": ( + f"🌍 Scene loaded from {os.path.basename(scene_path)}\n" + f"🦴 Bodies: {self._world._model.nbody}, 🔩 Joints: {self._world._model.njnt}, ⚡ Actuators: {self._world._model.nu}\n" + "💡 Use action='get_state' to inspect, action='step' to simulate" + ) + } + ], + } + except Exception as e: + logger.error("Failed to load scene: %s", e) + return {"status": "error", "content": [{"text": f"❌ Failed to load scene: {e}"}]} + + def _compile_world(self): + mj = _ensure_mujoco() + xml = MJCFBuilder.build_objects_only(self._world) + self._world._xml = xml + self._world._model = mj.MjModel.from_xml_string(xml) + self._world._data = mj.MjData(self._world._model) + self._world.status = SimStatus.IDLE + + def _recompile_world(self) -> dict[str, Any]: + try: + self._compile_world() + return {"status": "success"} + except Exception as e: + return {"status": "error", "content": [{"text": f"❌ Recompile failed: {e}"}]} + + # --- Robot Management --- + + @staticmethod + def _ensure_meshes(model_path: str, robot_name: str): + """Check if mesh files referenced by a model XML exist; auto-download if missing.""" + model_dir = os.path.dirname(os.path.abspath(model_path)) + + files_to_check = [model_path] + try: + with open(model_path) as _f: + top_content = _f.read() + for inc in re.findall(r' dict[str, Any]: + """Add a robot to the simulation.""" + if self._world is None: + return {"status": "error", "content": [{"text": "❌ No world. Use action='create_world' first."}]} + if name in self._world.robots: + return {"status": "error", "content": [{"text": f"❌ Robot '{name}' already exists."}]} + + resolved_path = urdf_path + if not resolved_path and data_config: + resolved_path = resolve_model(data_config) + if not resolved_path: + return { + "status": "error", + "content": [ + { + "text": f"❌ No model found for '{data_config}'.\n💡 Use action='list_urdfs' to see available robots" + } + ], + } + elif not resolved_path and name: + resolved_path = resolve_model(name) + + if not resolved_path: + return {"status": "error", "content": [{"text": "❌ Either urdf_path or data_config is required."}]} + if not os.path.exists(resolved_path): + return {"status": "error", "content": [{"text": f"❌ File not found: {resolved_path}"}]} + + mj = _ensure_mujoco() + + robot = SimRobot( + name=name, + urdf_path=resolved_path, + position=position or [0.0, 0.0, 0.0], + orientation=orientation or [1.0, 0.0, 0.0, 0.0], + data_config=data_config, + namespace=f"{name}/", + ) + + try: + self._ensure_meshes(resolved_path, data_config or name) + + model = mj.MjModel.from_xml_path(str(resolved_path)) + data = mj.MjData(model) + + joint_names = [] + for i in range(model.njnt): + jnt_name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_JOINT, i) + if jnt_name: + joint_names.append(jnt_name) + robot.joint_ids.append(i) + robot.joint_names = joint_names + + for i in range(model.nu): + act_name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_ACTUATOR, i) + if act_name: + jnt_id = model.actuator_trnid[i, 0] + if jnt_id in robot.joint_ids: + robot.actuator_ids.append(i) + else: + robot.actuator_ids.append(i) + if not robot.actuator_ids: + for i in range(model.nu): + robot.actuator_ids.append(i) + + for i in range(model.ncam): + cam_name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_CAMERA, i) + if cam_name and cam_name not in self._world.cameras: + self._world.cameras[cam_name] = SimCamera( + name=cam_name, + camera_id=i, + width=self.default_width, + height=self.default_height, + ) + + self._world._model = model + self._world._data = data + self._world._robot_base_xml = resolved_path + self._world.robots[name] = robot + + for _ in range(100): + mj.mj_step(model, data) + + source = f"data_config='{data_config}'" if data_config else os.path.basename(resolved_path) + return { + "status": "success", + "content": [ + { + "text": ( + f"🤖 Robot '{name}' added to simulation\n" + f"📁 Source: {source} → {os.path.basename(resolved_path)}\n" + f"📍 Position: {robot.position}\n" + f"🔩 Joints: {len(robot.joint_names)} ({', '.join(robot.joint_names[:8])}{'...' if len(robot.joint_names) > 8 else ''})\n" + f"⚡ Actuators: {len(robot.actuator_ids)}\n" + f"📷 Cameras: {list(self._world.cameras.keys())}\n" + f"💡 Run policy: action='run_policy', robot_name='{name}'" + ) + } + ], + } + except Exception as e: + logger.error("Failed to add robot '%s': %s", name, e) + return {"status": "error", "content": [{"text": f"❌ Failed to load: {e}"}]} + + def remove_robot(self, name: str) -> dict[str, Any]: + if self._world is None or name not in self._world.robots: + return {"status": "error", "content": [{"text": f"❌ Robot '{name}' not found."}]} + if name in self._policy_threads: + self._world.robots[name].policy_running = False + try: + self._policy_threads[name].result(timeout=5.0) + except Exception: + pass + del self._policy_threads[name] + del self._world.robots[name] + return {"status": "success", "content": [{"text": f"🗑️ Robot '{name}' removed."}]} + + def list_robots(self) -> dict[str, Any]: + if self._world is None: + return {"status": "error", "content": [{"text": "❌ No world."}]} + if not self._world.robots: + return {"status": "success", "content": [{"text": "No robots. Use action='add_robot'."}]} + + lines = ["🤖 Robots in simulation:\n"] + for name, robot in self._world.robots.items(): + status = "🟢 running" if robot.policy_running else "⚪ idle" + lines.append( + f" • {name} ({os.path.basename(robot.urdf_path)})\n" + f" Position: {robot.position}, Joints: {len(robot.joint_names)}, " + f"Config: {robot.data_config or 'direct'}, Status: {status}" + ) + return {"status": "success", "content": [{"text": "\n".join(lines)}]} + + def get_robot_state(self, robot_name: str) -> dict[str, Any]: + if self._world is None or self._world._data is None: + return {"status": "error", "content": [{"text": "❌ No simulation running."}]} + if robot_name not in self._world.robots: + return {"status": "error", "content": [{"text": f"❌ Robot '{robot_name}' not found."}]} + + mj = _ensure_mujoco() + robot = self._world.robots[robot_name] + model, data = self._world._model, self._world._data + + state = {} + for jnt_name in robot.joint_names: + jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jnt_id >= 0: + state[jnt_name] = { + "position": float(data.qpos[model.jnt_qposadr[jnt_id]]), + "velocity": float(data.qvel[model.jnt_dofadr[jnt_id]]), + } + + text = f"🤖 '{robot_name}' state (t={self._world.sim_time:.3f}s):\n" + for jnt, vals in state.items(): + text += f" {jnt}: pos={vals['position']:.4f}, vel={vals['velocity']:.4f}\n" + + return {"status": "success", "content": [{"text": text}, {"text": json.dumps({"state": state}, default=str)}]} + + # --- Object Management --- + + def add_object( + self, + name: str, + shape: str = "box", + position: list[float] = None, + orientation: list[float] = None, + size: list[float] = None, + color: list[float] = None, + mass: float = 0.1, + is_static: bool = False, + mesh_path: str = None, + ) -> dict[str, Any]: + """Add an object to the simulation.""" + if self._world is None: + return {"status": "error", "content": [{"text": "❌ No world."}]} + if name in self._world.objects: + return {"status": "error", "content": [{"text": f"❌ Object '{name}' exists."}]} + + obj = SimObject( + name=name, + shape=shape, + position=position or [0.0, 0.0, 0.0], + orientation=orientation or [1.0, 0.0, 0.0, 0.0], + size=size or [0.05, 0.05, 0.05], + color=color or [0.5, 0.5, 0.5, 1.0], + mass=mass, + mesh_path=mesh_path, + is_static=is_static, + ) + self._world.objects[name] = obj + + if self._world.robots: + try: + result = inject_object_into_scene(self._world, obj) + if result: + return { + "status": "success", + "content": [{"text": f"📦 '{name}' spawned: {shape} at {obj.position}"}], + } + return { + "status": "success", + "content": [ + { + "text": ( + f"📦 '{name}' registered: {shape} at {obj.position}\n" + "⚠️ Robot scene loaded — object is tracked but not physically spawned." + ) + } + ], + } + except (ValueError, RuntimeError) as e: + raise RuntimeError( + f"Object injection into live scene failed for '{name}': {e}. " + f"Check that the MJCF XML is valid and compatible with the current scene." + ) from e + + result = self._recompile_world() + if result["status"] == "error": + del self._world.objects[name] + return result + + return { + "status": "success", + "content": [ + { + "text": f"📦 '{name}' added: {shape} at {obj.position}, size={obj.size}, {'static' if is_static else f'{mass}kg'}" + } + ], + } + + def remove_object(self, name: str) -> dict[str, Any]: + if self._world is None or name not in self._world.objects: + return {"status": "error", "content": [{"text": f"❌ Object '{name}' not found."}]} + del self._world.objects[name] + if self._world.robots: + eject_body_from_scene(self._world, name) + else: + self._recompile_world() + return {"status": "success", "content": [{"text": f"🗑️ '{name}' removed."}]} + + def move_object(self, name: str, position: list[float] = None, orientation: list[float] = None) -> dict[str, Any]: + if self._world is None or self._world._data is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + if name not in self._world.objects: + return {"status": "error", "content": [{"text": f"❌ '{name}' not found."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, f"{name}_joint") + if jnt_id >= 0: + qpos_addr = model.jnt_qposadr[jnt_id] + if position: + data.qpos[qpos_addr : qpos_addr + 3] = position + self._world.objects[name].position = position + if orientation: + data.qpos[qpos_addr + 3 : qpos_addr + 7] = orientation + self._world.objects[name].orientation = orientation + mj.mj_forward(model, data) + + return {"status": "success", "content": [{"text": f"📍 '{name}' moved to {position or 'same'}"}]} + + def list_objects(self) -> dict[str, Any]: + if self._world is None: + return {"status": "error", "content": [{"text": "❌ No world."}]} + if not self._world.objects: + return {"status": "success", "content": [{"text": "No objects."}]} + + lines = ["📦 Objects:\n"] + for name, obj in self._world.objects.items(): + lines.append(f" • {name}: {obj.shape} at {obj.position}, {'static' if obj.is_static else f'{obj.mass}kg'}") + return {"status": "success", "content": [{"text": "\n".join(lines)}]} + + # --- Camera Management --- + + def add_camera( + self, + name: str, + position: list[float] = None, + target: list[float] = None, + fov: float = 60.0, + width: int = 640, + height: int = 480, + ) -> dict[str, Any]: + if self._world is None: + return {"status": "error", "content": [{"text": "❌ No world."}]} + + cam = SimCamera( + name=name, + position=position or [1.0, 1.0, 1.0], + target=target or [0.0, 0.0, 0.0], + fov=fov, + width=width, + height=height, + ) + self._world.cameras[name] = cam + + if self._world.robots and self._world._model is not None: + try: + inject_camera_into_scene(self._world, cam) + except (ValueError, RuntimeError) as e: + raise RuntimeError( + f"Camera injection into live scene failed for '{name}': {e}. " + f"Check that camera parameters are valid." + ) from e + else: + self._recompile_world() + + return {"status": "success", "content": [{"text": f"📷 Camera '{name}' added at {cam.position}"}]} + + def remove_camera(self, name: str) -> dict[str, Any]: + if self._world is None or name not in self._world.cameras: + return {"status": "error", "content": [{"text": f"❌ Camera '{name}' not found."}]} + del self._world.cameras[name] + return {"status": "success", "content": [{"text": f"🗑️ Camera '{name}' removed."}]} + + # --- Simulation Control --- + + def step(self, n_steps: int = 1) -> dict[str, Any]: + if self._world is None or self._world._data is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + mj = _ensure_mujoco() + for _ in range(n_steps): + mj.mj_step(self._world._model, self._world._data) + self._world.sim_time = self._world._data.time + self._world.step_count += n_steps + return { + "status": "success", + "content": [ + {"text": f"⏩ +{n_steps} steps | t={self._world.sim_time:.4f}s | total={self._world.step_count}"} + ], + } + + def reset(self) -> dict[str, Any]: + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No world."}]} + mj = _ensure_mujoco() + mj.mj_resetData(self._world._model, self._world._data) + self._world.sim_time = 0.0 + self._world.step_count = 0 + for r in self._world.robots.values(): + r.policy_running = False + r.policy_steps = 0 + return {"status": "success", "content": [{"text": "🔄 Reset to initial state."}]} + + def get_state(self) -> dict[str, Any]: + if self._world is None: + return {"status": "error", "content": [{"text": "❌ No world."}]} + lines = [ + "🌍 Simulation State", + f"🕐 t={self._world.sim_time:.4f}s (step {self._world.step_count})", + f"⚙️ dt={self._world.timestep}s | 🌐 g={self._world.gravity}", + f"🤖 Robots: {len(self._world.robots)} | 📦 Objects: {len(self._world.objects)} | 📷 Cameras: {len(self._world.cameras)}", + ] + if self._world._model: + lines.append( + f"🦴 Bodies: {self._world._model.nbody} | 🔩 Joints: {self._world._model.njnt} | ⚡ Actuators: {self._world._model.nu}" + ) + if self._world._recording: + lines.append(f"🔴 Recording: {len(self._world._trajectory)} steps") + return {"status": "success", "content": [{"text": "\n".join(lines)}]} + + def destroy(self) -> dict[str, Any]: + if self._world is None: + return {"status": "success", "content": [{"text": "No world to destroy."}]} + for r in self._world.robots.values(): + r.policy_running = False + self._close_viewer() + self._world = None + return {"status": "success", "content": [{"text": "🗑️ World destroyed."}]} + + def set_gravity(self, gravity) -> dict[str, Any]: + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No world."}]} + if isinstance(gravity, (int, float)): + gravity = [0.0, 0.0, float(gravity)] + self._world._model.opt.gravity[:] = gravity + self._world.gravity = gravity + return {"status": "success", "content": [{"text": f"🌐 Gravity: {gravity}"}]} + + def set_timestep(self, timestep: float) -> dict[str, Any]: + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No world."}]} + self._world._model.opt.timestep = timestep + self._world.timestep = timestep + return {"status": "success", "content": [{"text": f"⏱️ Timestep: {timestep}s ({1 / timestep:.0f}Hz)"}]} + + # --- Viewer --- + + def open_viewer(self) -> dict[str, Any]: + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation to view."}]} + from strands_robots.simulation.mujoco.backend import _mujoco_viewer + + if _mujoco_viewer is None: + return {"status": "error", "content": [{"text": "❌ mujoco.viewer not available."}]} + if self._viewer_handle is not None: + return {"status": "success", "content": [{"text": "👁️ Viewer already open."}]} + try: + self._viewer_handle = _mujoco_viewer.launch_passive(self._world._model, self._world._data) + return {"status": "success", "content": [{"text": "👁️ Interactive viewer opened."}]} + except Exception as e: + return {"status": "error", "content": [{"text": f"❌ Viewer failed: {e}"}]} + + def _close_viewer(self): + if self._viewer_handle is not None: + try: + self._viewer_handle.close() + except Exception: + pass + self._viewer_handle = None + + def close_viewer(self) -> dict[str, Any]: + self._close_viewer() + return {"status": "success", "content": [{"text": "👁️ Viewer closed."}]} + + # --- URDF Registry --- + + def list_urdfs_action(self) -> dict[str, Any]: + return {"status": "success", "content": [{"text": list_available_models()}]} + + def register_urdf_action(self, data_config: str, urdf_path: str) -> dict[str, Any]: + register_urdf(data_config, urdf_path) + resolved = resolve_model(data_config) + return { + "status": "success", + "content": [{"text": f"📋 Registered '{data_config}' → {urdf_path}\nResolved: {resolved or 'NOT FOUND'}"}], + } + + # --- Introspection --- + + def get_features(self) -> dict[str, Any]: + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model = self._world._model + + joint_names = [mj.mj_id2name(model, mj.mjtObj.mjOBJ_JOINT, i) for i in range(model.njnt)] + joint_names = [n for n in joint_names if n] + actuator_names = [mj.mj_id2name(model, mj.mjtObj.mjOBJ_ACTUATOR, i) for i in range(model.nu)] + actuator_names = [n for n in actuator_names if n] + camera_names = [mj.mj_id2name(model, mj.mjtObj.mjOBJ_CAMERA, i) for i in range(model.ncam)] + camera_names = [n for n in camera_names if n] + + robots_info = {} + for rname, robot in self._world.robots.items(): + robots_info[rname] = { + "joint_names": robot.joint_names, + "n_joints": len(robot.joint_names), + "n_actuators": len(robot.actuator_ids), + "data_config": robot.data_config, + "source": os.path.basename(robot.urdf_path), + } + + features = { + "n_bodies": model.nbody, + "n_joints": model.njnt, + "n_actuators": model.nu, + "n_cameras": model.ncam, + "timestep": model.opt.timestep, + "joint_names": joint_names, + "actuator_names": actuator_names, + "camera_names": camera_names, + "robots": robots_info, + } + + lines = [ + "🔍 Simulation Features", + f"🦴 Joints ({model.njnt}): {', '.join(joint_names[:12])}{'...' if len(joint_names) > 12 else ''}", + f"⚡ Actuators ({model.nu}): {', '.join(actuator_names[:12])}{'...' if len(actuator_names) > 12 else ''}", + f"📷 Cameras ({model.ncam}): {', '.join(camera_names) if camera_names else 'none (free camera only)'}", + f"⏱️ Timestep: {model.opt.timestep}s ({1 / model.opt.timestep:.0f}Hz)", + ] + for rname, rinfo in robots_info.items(): + lines.append( + f"🤖 {rname}: {rinfo['n_joints']} joints, {rinfo['n_actuators']} actuators ({rinfo['source']})" + ) + + return { + "status": "success", + "content": [{"text": "\n".join(lines)}, {"text": json.dumps({"features": features}, default=str)}], + } + + # --- AgentTool Interface --- + + @property + def tool_name(self) -> str: + return self.tool_name_str + + @property + def tool_type(self) -> str: + return "simulation" + + @property + def tool_spec(self) -> ToolSpec: + with open(_TOOL_SPEC_PATH) as f: + schema = json.load(f) + return { + "name": self.tool_name_str, + "description": ( + "Programmatic MuJoCo simulation environment. Create worlds, add robots from URDF " + "(direct path or auto-resolve from data_config name), add objects, run VLA policies, " + "render cameras, record trajectories, domain randomize. " + "Same Policy ABC as real robot control — sim ↔ real with zero code changes. " + "Actions: create_world, load_scene, reset, get_state, destroy, " + "add_robot, remove_robot, list_robots, get_robot_state, " + "add_object, remove_object, move_object, list_objects, " + "add_camera, remove_camera, " + "run_policy, start_policy, stop_policy, " + "render, render_depth, get_contacts, " + "step, set_gravity, set_timestep, " + "randomize, " + "start_recording, stop_recording, get_recording_status, " + "open_viewer, close_viewer, " + "list_urdfs, register_urdf, get_features" + ), + "inputSchema": {"json": schema}, + } + + async def stream( + self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any + ) -> AsyncGenerator[ToolResultEvent, None]: + try: + tool_use_id = tool_use.get("toolUseId", "") + input_data = tool_use.get("input", {}) + result = self._dispatch_action(input_data.get("action", ""), input_data) + yield ToolResultEvent({"toolUseId": tool_use_id, **result}) + except Exception as e: + yield ToolResultEvent( + { + "toolUseId": tool_use.get("toolUseId", ""), + "status": "error", + "content": [{"text": f"❌ Sim error: {e}"}], + } + ) + + def _dispatch_action(self, action: str, d: dict[str, Any]) -> dict[str, Any]: + """Route action string to method via getattr. + + Method names match action names directly (with a few aliases). + """ + # Aliases for actions whose method names differ + _ALIASES = { + "list_urdfs": "list_urdfs_action", + "register_urdf": "register_urdf_action", + "stop_policy": "_stop_policy", + } + + # Map input field names to method parameter names for physics actions + _FIELD_MAP = { + "checkpoint_name": "name", + "torque_vec": "torque", + } + + method_name = _ALIASES.get(action, action) + method = getattr(self, method_name, None) + + if method is None or action.startswith("_"): + return {"status": "error", "content": [{"text": f"❌ Unknown action: {action}"}]} + + # Build kwargs from input dict, excluding 'action' itself + # Signatures are cached per method to avoid repeated introspection. + import inspect + + cache = getattr(self, "_sig_cache", None) + if cache is None: + self._sig_cache = cache = {} + if method_name not in cache: + cache[method_name] = inspect.signature(method) + sig = cache[method_name] + # Apply field name remapping + remapped = dict(d) + for field_key, param_key in _FIELD_MAP.items(): + if field_key in remapped and param_key not in remapped: + remapped[param_key] = remapped.pop(field_key) + + kwargs = {} + for param_name, param in sig.parameters.items(): + if param_name == "self": + continue + # Handle name/robot_name/body_name ambiguity in the input schema + if param_name == "name" and "name" not in remapped and "robot_name" in remapped: + kwargs["name"] = remapped["robot_name"] + elif param_name == "name" and "name" not in remapped and "checkpoint_name" in d: + kwargs["name"] = d["checkpoint_name"] + elif param_name == "robot_name" and "robot_name" not in remapped and "name" in remapped: + kwargs["robot_name"] = remapped["name"] + elif param_name in remapped: + kwargs[param_name] = remapped[param_name] + # Forward policy kwargs + elif param.kind == inspect.Parameter.VAR_KEYWORD: + for k in ( + "policy_port", + "policy_host", + "model_path", + "server_address", + "policy_type", + "pretrained_name_or_path", + "device", + ): + if k in d: + kwargs[k] = d[k] + + return method(**kwargs) + + def _stop_policy(self, robot_name: str = "", **kwargs) -> dict[str, Any]: + if self._world and robot_name in self._world.robots: + self._world.robots[robot_name].policy_running = False + return {"status": "success", "content": [{"text": f"🛑 Stopped on '{robot_name}'"}]} + return {"status": "error", "content": [{"text": f"❌ '{robot_name}' not found."}]} + + # --- Cleanup --- + + def cleanup(self): + if hasattr(self, "mesh") and self.mesh: + self.mesh.stop() + if self._world: + for r in self._world.robots.values(): + r.policy_running = False + self._world = None + self._close_viewer() + for renderer in getattr(self, "_renderers", {}).values(): + try: + renderer.close() + except Exception: + pass + self._renderers.clear() + self._executor.shutdown(wait=False) + self._shutdown_event.set() + + def __enter__(self): + return self + + def __exit__(self, *exc): + self.cleanup() + + def __del__(self): + try: + self.cleanup() + except Exception: + pass diff --git a/strands_robots/simulation/mujoco/tool_spec.json b/strands_robots/simulation/mujoco/tool_spec.json new file mode 100644 index 0000000..4147a4b --- /dev/null +++ b/strands_robots/simulation/mujoco/tool_spec.json @@ -0,0 +1,351 @@ +{ + "type": "object", + "properties": { + "action": { + "type": "string", + "description": "Action to perform", + "enum": [ + "create_world", + "load_scene", + "reset", + "get_state", + "destroy", + "add_robot", + "remove_robot", + "list_robots", + "get_robot_state", + "add_object", + "remove_object", + "move_object", + "list_objects", + "add_camera", + "remove_camera", + "run_policy", + "start_policy", + "stop_policy", + "render", + "render_depth", + "get_contacts", + "step", + "set_gravity", + "set_timestep", + "randomize", + "start_recording", + "stop_recording", + "get_recording_status", + "open_viewer", + "close_viewer", + "list_urdfs", + "register_urdf", + "get_features", + "replay_episode", + "eval_policy", + "save_state", + "load_state", + "apply_force", + "raycast", + "multi_raycast", + "get_jacobian", + "get_energy", + "get_mass_matrix", + "inverse_dynamics", + "get_body_state", + "set_joint_positions", + "set_joint_velocities", + "get_sensor_data", + "set_body_properties", + "set_geom_properties", + "get_contact_forces", + "forward_kinematics", + "get_total_mass", + "export_xml" + ] + }, + "scene_path": { + "type": "string", + "description": "Path to MJCF/URDF scene file" + }, + "timestep": { + "type": "number" + }, + "gravity": { + "type": "array", + "items": { + "type": "number" + } + }, + "ground_plane": { + "type": "boolean" + }, + "urdf_path": { + "type": "string", + "description": "Path to URDF/MJCF file" + }, + "robot_name": { + "type": "string" + }, + "data_config": { + "type": "string", + "description": "Data config name (auto-resolves URDF)" + }, + "name": { + "type": "string", + "description": "Object/camera name" + }, + "shape": { + "type": "string", + "enum": [ + "box", + "sphere", + "cylinder", + "capsule", + "mesh", + "plane" + ] + }, + "position": { + "type": "array", + "items": { + "type": "number" + } + }, + "orientation": { + "type": "array", + "items": { + "type": "number" + } + }, + "size": { + "type": "array", + "items": { + "type": "number" + } + }, + "color": { + "type": "array", + "items": { + "type": "number" + } + }, + "mass": { + "type": "number" + }, + "is_static": { + "type": "boolean" + }, + "mesh_path": { + "type": "string" + }, + "target": { + "type": "array", + "items": { + "type": "number" + }, + "description": "Camera target point" + }, + "fov": { + "type": "number", + "description": "Camera field of view" + }, + "width": { + "type": "integer" + }, + "height": { + "type": "integer" + }, + "policy_provider": { + "type": "string", + "description": "Policy provider name (e.g. groot, lerobot_async, lerobot_local, dreamgen, mock)" + }, + "instruction": { + "type": "string" + }, + "duration": { + "type": "number" + }, + "policy_port": { + "type": "integer" + }, + "policy_host": { + "type": "string" + }, + "model_path": { + "type": "string" + }, + "action_horizon": { + "type": "integer" + }, + "control_frequency": { + "type": "number" + }, + "camera_name": { + "type": "string" + }, + "n_steps": { + "type": "integer" + }, + "output_path": { + "type": "string", + "description": "Trajectory/video export path" + }, + "fps": { + "type": "integer", + "description": "Video frames per second (for run_policy record_video)" + }, + "pretrained_name_or_path": { + "type": "string", + "description": "HuggingFace model ID for lerobot_local" + }, + "randomize_colors": { + "type": "boolean" + }, + "randomize_lighting": { + "type": "boolean" + }, + "randomize_physics": { + "type": "boolean" + }, + "randomize_positions": { + "type": "boolean" + }, + "position_noise": { + "type": "number" + }, + "seed": { + "type": "integer", + "description": "Random seed" + }, + "repo_id": { + "type": "string", + "description": "HuggingFace dataset repo ID" + }, + "push_to_hub": { + "type": "boolean", + "description": "Auto-push dataset to HuggingFace Hub on stop_recording" + }, + "vcodec": { + "type": "string", + "description": "Video codec for dataset recording (h264, hevc, libsvtav1)" + }, + "task": { + "type": "string", + "description": "Task description for dataset recording" + }, + "episode": { + "type": "integer", + "description": "Episode index for replay_episode" + }, + "root": { + "type": "string", + "description": "Local dataset root directory" + }, + "speed": { + "type": "number", + "description": "Replay speed multiplier (1.0 = original)" + }, + "n_episodes": { + "type": "integer", + "description": "Number of eval episodes" + }, + "max_steps": { + "type": "integer", + "description": "Max steps per eval episode" + }, + "success_fn": { + "type": "string", + "description": "Success function ('contact')" + }, + "fast_mode": { + "type": "boolean", + "description": "Skip sleep between actions for faster data collection" + }, + "body_name": { + "type": "string", + "description": "Target body name" + }, + "site_name": { + "type": "string", + "description": "Site name for Jacobian" + }, + "geom_name": { + "type": "string", + "description": "Geom name" + }, + "geom_id": { + "type": "integer", + "description": "Geom ID (alternative to geom_name)" + }, + "force": { + "type": "array", + "items": { + "type": "number" + }, + "description": "Force vector [fx, fy, fz] in Newtons" + }, + "torque_vec": { + "type": "array", + "items": { + "type": "number" + }, + "description": "Torque vector [tx, ty, tz] in N\u00b7m" + }, + "point": { + "type": "array", + "items": { + "type": "number" + }, + "description": "Point of force application [x, y, z]" + }, + "origin": { + "type": "array", + "items": { + "type": "number" + }, + "description": "Ray origin [x, y, z]" + }, + "direction": { + "type": "array", + "items": { + "type": "number" + }, + "description": "Ray direction [dx, dy, dz]" + }, + "directions": { + "type": "array", + "items": { + "type": "array", + "items": { + "type": "number" + } + }, + "description": "Multiple ray directions for multi_raycast" + }, + "exclude_body": { + "type": "integer", + "description": "Body ID to exclude from raycast (-1=none)" + }, + "include_static": { + "type": "boolean", + "description": "Include static geoms in raycast" + }, + "positions": { + "type": "object", + "description": "Joint name \u2192 position mapping for set_joint_positions" + }, + "velocities": { + "type": "object", + "description": "Joint name \u2192 velocity mapping for set_joint_velocities" + }, + "sensor_name": { + "type": "string", + "description": "Specific sensor name (or omit for all)" + }, + "checkpoint_name": { + "type": "string", + "description": "Named checkpoint for save_state/load_state" + } + }, + "required": [ + "action" + ] +} \ No newline at end of file diff --git a/tests/test_mujoco_e2e.py b/tests/test_mujoco_e2e.py new file mode 100644 index 0000000..c09cb0c --- /dev/null +++ b/tests/test_mujoco_e2e.py @@ -0,0 +1,269 @@ +"""End-to-end MuJoCo simulation test with Policy ABC. + +Tests the full observe → policy → act → step → render pipeline +without requiring strands SDK or lerobot — just mujoco + numpy. + +Run: python -m pytest tests/test_mujoco_e2e.py -v +""" + +import asyncio +import os +import shutil +import tempfile + +import numpy as np +import pytest + +# Skip entire module if mujoco not installed +mj = pytest.importorskip("mujoco") + + +def _has_opengl() -> bool: + """Check if OpenGL rendering is available.""" + try: + model = mj.MjModel.from_xml_string("") + renderer = mj.Renderer(model, height=1, width=1) + del renderer + return True + except Exception: + return False + + +requires_gl = pytest.mark.skipif( + not _has_opengl(), + reason="No OpenGL context available (headless environment without EGL/OSMesa)", +) + + +from strands_robots.policies import MockPolicy # noqa: E402 +from strands_robots.simulation.base import SimulationBackend # noqa: E402 +from strands_robots.simulation.models import SimObject, SimRobot, SimStatus, SimWorld # noqa: E402 + +# ── Fixtures ── + +ROBOT_XML = """ + + + +""" + + +@pytest.fixture +def sim_env(): + """Create a MuJoCo model+data from test XML.""" + tmpdir = tempfile.mkdtemp() + xml_path = os.path.join(tmpdir, "test_arm.xml") + with open(xml_path, "w") as f: + f.write(ROBOT_XML) + + model = mj.MjModel.from_xml_path(xml_path) + data = mj.MjData(model) + + yield model, data + + shutil.rmtree(tmpdir, ignore_errors=True) + + +JOINT_NAMES = ["shoulder_pan", "shoulder_lift", "elbow"] + + +def read_joints(model, data): + obs = {} + for jname in JOINT_NAMES: + jid = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, jname) + obs[jname] = float(data.qpos[model.jnt_qposadr[jid]]) + return obs + + +def apply_action(model, data, action_dict): + for key, val in action_dict.items(): + act_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_ACTUATOR, f"{key}_act") + if act_id >= 0: + data.ctrl[act_id] = val + + +# ── Tests ── + + +class TestSimulationBase: + def test_abc_has_required_methods(self): + required = [ + "create_world", + "destroy", + "reset", + "step", + "get_state", + "add_robot", + "remove_robot", + "add_object", + "remove_object", + "get_observation", + "send_action", + "render", + ] + for method in required: + assert hasattr(SimulationBackend, method) + + def test_shared_dataclasses(self): + w = SimWorld() + assert w.timestep == 0.002 + assert w.gravity == [0.0, 0.0, -9.81] + assert w.status == SimStatus.IDLE + + r = SimRobot(name="test", urdf_path="/tmp/test.urdf") + assert r.joint_names == [] + + o = SimObject(name="cube", shape="box") + assert o.mass == 0.1 + + +class TestMuJoCoPhysics: + def test_step_advances_time(self, sim_env): + model, data = sim_env + assert data.time == 0.0 + for _ in range(100): + mj.mj_step(model, data) + assert data.time == pytest.approx(0.2, abs=1e-6) + + def test_position_actuators_move_joints(self, sim_env): + model, data = sim_env + data.ctrl[0] = 1.0 # shoulder_pan target + for _ in range(1000): + mj.mj_step(model, data) + obs = read_joints(model, data) + assert abs(obs["shoulder_pan"] - 1.0) < 0.15 + + def test_contacts_detected(self, sim_env): + model, data = sim_env + for _ in range(100): + mj.mj_step(model, data) + assert data.ncon > 0 # cube on ground + + def test_reset_zeros_time(self, sim_env): + model, data = sim_env + for _ in range(100): + mj.mj_step(model, data) + mj.mj_resetData(model, data) + assert data.time == 0.0 + + +@requires_gl +class TestMuJoCoRendering: + def test_render_rgb(self, sim_env): + model, data = sim_env + mj.mj_forward(model, data) + renderer = mj.Renderer(model, height=240, width=320) + cam_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_CAMERA, "front") + renderer.update_scene(data, camera=cam_id) + img = renderer.render() + assert img.shape == (240, 320, 3) + assert img.dtype == np.uint8 + assert img.max() > 0 + del renderer + + def test_render_depth(self, sim_env): + model, data = sim_env + mj.mj_forward(model, data) + renderer = mj.Renderer(model, height=120, width=160) + renderer.update_scene(data) + renderer.enable_depth_rendering() + depth = renderer.render() + renderer.disable_depth_rendering() + assert depth.shape == (120, 160) + assert depth.max() > 0 + del renderer + + +class TestMockPolicyLoop: + def test_mock_policy_generates_actions(self): + policy = MockPolicy() + policy.set_robot_state_keys(JOINT_NAMES) + obs = {j: 0.0 for j in JOINT_NAMES} + actions = asyncio.run(policy.get_actions(obs, "test")) + assert len(actions) == 8 + assert all(j in actions[0] for j in JOINT_NAMES) + + def test_full_observe_act_loop(self, sim_env): + model, data = sim_env + policy = MockPolicy() + policy.set_robot_state_keys(JOINT_NAMES) + + for step in range(20): + obs = read_joints(model, data) + actions = asyncio.run(policy.get_actions(obs, "pick up cube")) + apply_action(model, data, actions[0]) + mj.mj_step(model, data) + + assert data.time > 0 + final_obs = read_joints(model, data) + # Joints should have moved from 0 + assert any(abs(v) > 0.001 for v in final_obs.values()) + + @requires_gl + def test_loop_with_rendering(self, sim_env): + """Full loop: observe → policy → act → step → render (10 iterations).""" + model, data = sim_env + policy = MockPolicy() + policy.set_robot_state_keys(JOINT_NAMES) + renderer = mj.Renderer(model, height=120, width=160) + + frames = [] + for _ in range(10): + obs = read_joints(model, data) + actions = asyncio.run(policy.get_actions(obs, "wave")) + apply_action(model, data, actions[0]) + mj.mj_step(model, data) + + renderer.update_scene(data) + frames.append(renderer.render().copy()) + + assert len(frames) == 10 + assert all(f.shape == (120, 160, 3) for f in frames) + # Frames should differ (robot is moving) + assert not np.array_equal(frames[0], frames[-1]) + del renderer + + +class TestDomainRandomization: + def test_color_randomization(self, sim_env): + model, data = sim_env + orig = model.geom_rgba.copy() + rng = np.random.default_rng(42) + for i in range(model.ngeom): + model.geom_rgba[i, :3] = rng.uniform(0.1, 1.0, size=3) + assert not np.array_equal(orig, model.geom_rgba) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_physics.py b/tests/test_physics.py new file mode 100644 index 0000000..17e03a2 --- /dev/null +++ b/tests/test_physics.py @@ -0,0 +1,350 @@ +"""Tests for PhysicsMixin — advanced MuJoCo physics features. + +Tests: raycasting, jacobians, energy, forces, state checkpointing, +inverse dynamics, sensor readout, body introspection, runtime modification. + +Run: uv run pytest tests/test_physics.py -v +""" + +import json +import os + +import numpy as np +import pytest + +mj = pytest.importorskip("mujoco") + +from strands_robots.simulation.mujoco.simulation import Simulation # noqa: E402 + +ROBOT_XML = """ + + + +""" + + +@pytest.fixture +def sim(): + """Create a Simulation with the test scene loaded directly.""" + from strands_robots.simulation.models import SimStatus, SimWorld + + s = Simulation(tool_name="test_sim", mesh=False) + s._world = SimWorld() + s._world._model = mj.MjModel.from_xml_string(ROBOT_XML) + s._world._data = mj.MjData(s._world._model) + s._world.status = SimStatus.IDLE + mj.mj_forward(s._world._model, s._world._data) + yield s + s.cleanup() + + +class TestRaycasting: + def test_raycast_hits_ground(self, sim): + result = sim.raycast(origin=[0, 0, 2], direction=[0, 0, -1]) + assert result["status"] == "success" + data = json.loads(result["content"][1]["text"]) + assert data["hit"] is True + assert data["distance"] is not None + assert data["distance"] > 0 + + def test_raycast_hits_box(self, sim): + result = sim.raycast(origin=[0, 0, 2], direction=[0, 0, -1]) + assert result["status"] == "success" + data = json.loads(result["content"][1]["text"]) + assert data["hit"] is True + assert data["geom_name"] in ("box_geom", "ground") + + def test_raycast_misses(self, sim): + result = sim.raycast(origin=[0, 0, 2], direction=[0, 0, 1]) # shooting up + assert result["status"] == "success" + data = json.loads(result["content"][1]["text"]) + assert data["hit"] is False + + def test_multi_raycast(self, sim): + dirs = [[0, 0, -1], [1, 0, 0], [0, 1, 0], [0, 0, 1]] + result = sim.multi_raycast(origin=[0, 0, 2], directions=dirs) + assert result["status"] == "success" + rays = json.loads(result["content"][1]["text"])["rays"] + assert len(rays) == 4 + # At least the downward ray should hit + assert rays[0]["distance"] is not None + + +class TestJacobians: + def test_body_jacobian(self, sim): + result = sim.get_jacobian(body_name="link2") + assert result["status"] == "success" + data = json.loads(result["content"][1]["text"]) + assert len(data["jacp"]) == 3 # 3×nv + assert data["nv"] == sim._world._model.nv + + def test_site_jacobian(self, sim): + result = sim.get_jacobian(site_name="end_effector") + assert result["status"] == "success" + + def test_geom_jacobian(self, sim): + result = sim.get_jacobian(geom_name="link2_geom") + assert result["status"] == "success" + + def test_jacobian_no_target(self, sim): + result = sim.get_jacobian() + assert result["status"] == "error" + + def test_jacobian_invalid_body(self, sim): + result = sim.get_jacobian(body_name="nonexistent") + assert result["status"] == "error" + + +class TestEnergy: + def test_get_energy(self, sim): + result = sim.get_energy() + assert result["status"] == "success" + data = json.loads(result["content"][1]["text"]) + assert "potential" in data + assert "kinetic" in data + assert "total" in data + # Box at height 0.5 should have nonzero potential energy + assert data["potential"] != 0 or data["kinetic"] != 0 + + def test_energy_changes_after_step(self, sim): + e1 = json.loads(sim.get_energy()["content"][1]["text"]) + # Step physics to let box fall + for _ in range(100): + mj.mj_step(sim._world._model, sim._world._data) + e2 = json.loads(sim.get_energy()["content"][1]["text"]) + # Kinetic energy should change (box falls) + assert e1["kinetic"] != e2["kinetic"] or e1["potential"] != e2["potential"] + + +class TestExternalForces: + def test_apply_force(self, sim): + result = sim.apply_force(body_name="box1", force=[0, 0, 100]) + assert result["status"] == "success" + assert "box1" in result["content"][0]["text"] + + def test_apply_force_invalid_body(self, sim): + result = sim.apply_force(body_name="nonexistent", force=[0, 0, 10]) + assert result["status"] == "error" + + def test_force_changes_acceleration(self, sim): + # Get initial state + data = sim._world._data + old_qfrc = data.qfrc_applied.copy() + sim.apply_force(body_name="box1", force=[0, 0, 100]) + # qfrc_applied should change + assert not np.array_equal(old_qfrc, data.qfrc_applied) + + +class TestMassMatrix: + def test_get_mass_matrix(self, sim): + result = sim.get_mass_matrix() + assert result["status"] == "success" + data = json.loads(result["content"][1]["text"]) + nv = sim._world._model.nv + assert data["shape"] == [nv, nv] + assert data["rank"] > 0 + assert data["total_mass"] > 0 + + def test_mass_diagonal_positive(self, sim): + result = sim.get_mass_matrix() + diag = json.loads(result["content"][1]["text"])["diagonal"] + assert all(d >= 0 for d in diag) + + +class TestStateCheckpointing: + def test_save_and_load_state(self, sim): + # Set a known joint position + sim._world._data.qpos[7] = 1.0 # shoulder + mj.mj_forward(sim._world._model, sim._world._data) + + # Save + result = sim.save_state(name="test_checkpoint") + assert result["status"] == "success" + + # Change state + sim._world._data.qpos[7] = -1.0 + mj.mj_forward(sim._world._model, sim._world._data) + assert sim._world._data.qpos[7] == pytest.approx(-1.0) + + # Restore + result = sim.load_state(name="test_checkpoint") + assert result["status"] == "success" + assert sim._world._data.qpos[7] == pytest.approx(1.0) + + def test_load_nonexistent_checkpoint(self, sim): + result = sim.load_state(name="doesnt_exist") + assert result["status"] == "error" + + +class TestInverseDynamics: + def test_inverse_dynamics(self, sim): + mj.mj_forward(sim._world._model, sim._world._data) + result = sim.inverse_dynamics() + assert result["status"] == "success" + forces = json.loads(result["content"][1]["text"])["qfrc_inverse"] + assert "shoulder" in forces or "elbow" in forces + + +class TestBodyState: + def test_get_body_state(self, sim): + result = sim.get_body_state(body_name="box1") + assert result["status"] == "success" + state = json.loads(result["content"][1]["text"]) + assert "position" in state + assert "quaternion" in state + assert "linear_velocity" in state + assert "angular_velocity" in state + assert "mass" in state + assert len(state["position"]) == 3 + assert len(state["quaternion"]) == 4 + assert state["mass"] == pytest.approx(1.0) + + def test_body_state_invalid(self, sim): + result = sim.get_body_state(body_name="nonexistent") + assert result["status"] == "error" + + +class TestDirectJointControl: + def test_set_joint_positions(self, sim): + result = sim.set_joint_positions(positions={"shoulder": 0.5, "elbow": -0.3}) + assert result["status"] == "success" + assert "2/2" in result["content"][0]["text"] + + # Verify positions were set + model, data = sim._world._model, sim._world._data + shoulder_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, "shoulder") + qpos_adr = model.jnt_qposadr[shoulder_id] + assert data.qpos[qpos_adr] == pytest.approx(0.5) + + def test_set_joint_velocities(self, sim): + result = sim.set_joint_velocities(velocities={"shoulder": 1.0}) + assert result["status"] == "success" + + +class TestSensors: + def test_get_all_sensors(self, sim): + result = sim.get_sensor_data() + assert result["status"] == "success" + sensors = json.loads(result["content"][1]["text"])["sensors"] + assert "shoulder_pos" in sensors + assert "elbow_pos" in sensors + + def test_get_specific_sensor(self, sim): + result = sim.get_sensor_data(sensor_name="shoulder_pos") + assert result["status"] == "success" + sensors = json.loads(result["content"][1]["text"])["sensors"] + assert len(sensors) == 1 + assert "shoulder_pos" in sensors + + def test_sensor_values_change(self, sim): + # Set shoulder position + sim.set_joint_positions(positions={"shoulder": 1.0}) + result = sim.get_sensor_data(sensor_name="shoulder_pos") + val = json.loads(result["content"][1]["text"])["sensors"]["shoulder_pos"]["values"] + assert abs(val - 1.0) < 0.01 + + +class TestRuntimeModification: + def test_set_body_mass(self, sim): + result = sim.set_body_properties(body_name="box1", mass=5.0) + assert result["status"] == "success" + body_id = mj.mj_name2id(sim._world._model, mj.mjtObj.mjOBJ_BODY, "box1") + assert sim._world._model.body_mass[body_id] == pytest.approx(5.0) + + def test_set_geom_color(self, sim): + result = sim.set_geom_properties(geom_name="box_geom", color=[0, 1, 0, 1]) + assert result["status"] == "success" + geom_id = mj.mj_name2id(sim._world._model, mj.mjtObj.mjOBJ_GEOM, "box_geom") + assert sim._world._model.geom_rgba[geom_id][1] == pytest.approx(1.0) + + def test_set_geom_friction(self, sim): + result = sim.set_geom_properties(geom_name="box_geom", friction=[0.5, 0.01, 0.001]) + assert result["status"] == "success" + + def test_invalid_geom(self, sim): + result = sim.set_geom_properties(geom_name="nonexistent", color=[1, 0, 0, 1]) + assert result["status"] == "error" + + +class TestContactForces: + def test_get_contact_forces_after_settling(self, sim): + # Let box fall and settle + for _ in range(500): + mj.mj_step(sim._world._model, sim._world._data) + result = sim.get_contact_forces() + assert result["status"] == "success" + # Box should be in contact with ground + contacts = json.loads(result["content"][1]["text"])["contacts"] + assert len(contacts) > 0 + assert contacts[0]["normal_force"] != 0 + + +class TestForwardKinematics: + def test_forward_kinematics(self, sim): + result = sim.forward_kinematics() + assert result["status"] == "success" + bodies = json.loads(result["content"][1]["text"])["bodies"] + assert "box1" in bodies + assert "link1" in bodies + assert len(bodies["box1"]["position"]) == 3 + + +class TestTotalMass: + def test_get_total_mass(self, sim): + result = sim.get_total_mass() + assert result["status"] == "success" + data = json.loads(result["content"][1]["text"]) + assert data["total_mass"] > 0 + assert "box1" in data["bodies"] + assert data["bodies"]["box1"] == pytest.approx(1.0) + + +class TestExportXML: + def test_export_xml_string(self, sim): + result = sim.export_xml() + assert result["status"] == "success" + text = result["content"][0]["text"] + assert "mujoco" in text.lower() or "Model XML" in text + + def test_export_xml_file(self, sim, tmp_path): + path = str(tmp_path / "exported.xml") + result = sim.export_xml(output_path=path) + assert result["status"] == "success" + assert os.path.exists(path) + with open(path) as f: + content = f.read() + assert " Date: Wed, 1 Apr 2026 15:11:56 -0400 Subject: [PATCH 02/22] =?UTF-8?q?fix:=20address=20all=20review=20comments?= =?UTF-8?q?=20=E2=80=94=20ABC,=20thread-safety,=20injection,=20cleanup?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit HIGH: - Simulation now inherits SimulationBackend ABC (isinstance works) - start_policy rejects concurrent execution per robot (thread-safety) - XML injection protection via _sanitize_name() in MJCFBuilder MEDIUM: - overwrite defaults to False in start_recording - Silent frame dropping now respects strict=True (AGENTS.md #5) LOW: - Remove dead _numpy_ify code - Replace insecure tempfile.mktemp with NamedTemporaryFile - Remove unimplemented total_reward from eval_policy - Reuse ThreadPoolExecutor in _async_utils (50Hz perf fix) --- strands_robots/_async_utils.py | 9 ++++++--- strands_robots/dataset_recorder.py | 15 +++------------ .../simulation/mujoco/mjcf_builder.py | 18 +++++++++++++++++- strands_robots/simulation/mujoco/physics.py | 11 +++++++---- .../simulation/mujoco/policy_runner.py | 17 ++++++++++++++--- strands_robots/simulation/mujoco/recording.py | 2 +- strands_robots/simulation/mujoco/simulation.py | 2 ++ 7 files changed, 50 insertions(+), 24 deletions(-) diff --git a/strands_robots/_async_utils.py b/strands_robots/_async_utils.py index 91819a3..51d1808 100644 --- a/strands_robots/_async_utils.py +++ b/strands_robots/_async_utils.py @@ -3,6 +3,10 @@ import asyncio import concurrent.futures +# Module-level executor reused across calls to avoid creating threads at high frequency. +# A single worker is sufficient — we only need to offload one asyncio.run() at a time. +_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="strands_async") + def _resolve_coroutine(coro_or_result): """Safely resolve a potentially-async result to a sync value. @@ -10,7 +14,7 @@ def _resolve_coroutine(coro_or_result): Handles three cases: 1. Already a plain value → return as-is 2. Coroutine, no running loop → asyncio.run() - 3. Coroutine, inside running loop → offload to thread + 3. Coroutine, inside running loop → offload to reused thread Args: coro_or_result: Either a coroutine or an already-resolved value. @@ -22,7 +26,6 @@ def _resolve_coroutine(coro_or_result): return coro_or_result try: asyncio.get_running_loop() - with concurrent.futures.ThreadPoolExecutor(max_workers=1) as ex: - return ex.submit(asyncio.run, coro_or_result).result() + return _EXECUTOR.submit(asyncio.run, coro_or_result).result() except RuntimeError: return asyncio.run(coro_or_result) diff --git a/strands_robots/dataset_recorder.py b/strands_robots/dataset_recorder.py index 8f25624..873de0a 100644 --- a/strands_robots/dataset_recorder.py +++ b/strands_robots/dataset_recorder.py @@ -73,18 +73,6 @@ def _get_lerobot_dataset_class(): ) from exc -def _numpy_ify(v): - """Convert any value to numpy-friendly format for add_frame.""" - if hasattr(v, "numpy"): - return v.numpy() - if hasattr(v, "tolist") and isinstance(v, np.ndarray): - return v - if isinstance(v, (int, float)): - return np.array([v], dtype=np.float32) - if isinstance(v, list): - return np.array(v, dtype=np.float32) - return v - class DatasetRecorder: """Bridge between strands-robots control loops and LeRobotDataset. @@ -103,6 +91,7 @@ def __init__(self, dataset, task: str = ""): self.default_task = task self.frame_count = 0 self.dropped_frame_count = 0 + self.strict = strict self.episode_count = 0 self._closed = False self._cached_state_keys: list[str] | None = None @@ -374,6 +363,8 @@ def add_frame( self.dataset.add_frame(frame) self.frame_count += 1 except Exception as e: + if self.strict: + raise # Fail-fast per AGENTS.md convention #5 self.dropped_frame_count += 1 n = self.dropped_frame_count # Log at 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, then every 1000 diff --git a/strands_robots/simulation/mujoco/mjcf_builder.py b/strands_robots/simulation/mujoco/mjcf_builder.py index 5dcdc69..9ec2d93 100644 --- a/strands_robots/simulation/mujoco/mjcf_builder.py +++ b/strands_robots/simulation/mujoco/mjcf_builder.py @@ -1,3 +1,4 @@ +import re """MJCF XML builder — programmatic scene construction.""" import logging @@ -11,6 +12,21 @@ logger = logging.getLogger(__name__) +_VALID_NAME_RE = re.compile(r"^[a-zA-Z0-9_][a-zA-Z0-9_.\-]{0,127}$") + + +def _sanitize_name(name: str) -> str: + """Validate and sanitize an object/body name for safe MJCF XML embedding. + + Raises ValueError if name contains characters that could cause XML injection. + """ + if not _VALID_NAME_RE.match(name): + raise ValueError( + f"Invalid simulation name {name!r}: must match [a-zA-Z0-9_][a-zA-Z0-9_.\\-]{{0,127}}" + ) + return name + + class MJCFBuilder: """Builds MuJoCo MJCF XML from SimWorld state.""" @@ -72,7 +88,7 @@ def _object_xml(obj: SimObject, indent: int = 4) -> str: r, g, b, a = obj.color lines = [] - lines.append(f'{pad}') + lines.append(f'{pad}') if not obj.is_static: lines.append(f'{pad} ') diff --git a/strands_robots/simulation/mujoco/physics.py b/strands_robots/simulation/mujoco/physics.py index 1afc7e8..64d9e9e 100644 --- a/strands_robots/simulation/mujoco/physics.py +++ b/strands_robots/simulation/mujoco/physics.py @@ -808,11 +808,14 @@ def export_xml(self, output_path: str = None) -> dict[str, Any]: import os import tempfile - tmpfile = tempfile.mktemp(suffix=".xml") + with tempfile.NamedTemporaryFile(suffix=".xml", mode="w", delete=False) as tmp: + tmpfile = tmp.name mj.mj_saveLastXML(tmpfile, self._world._model) - with open(tmpfile) as f: - xml = f.read() - os.unlink(tmpfile) + try: + with open(tmpfile) as f: + xml = f.read() + finally: + os.unlink(tmpfile) return { "status": "success", "content": [ diff --git a/strands_robots/simulation/mujoco/policy_runner.py b/strands_robots/simulation/mujoco/policy_runner.py index 59c3f8d..94e97b7 100644 --- a/strands_robots/simulation/mujoco/policy_runner.py +++ b/strands_robots/simulation/mujoco/policy_runner.py @@ -161,12 +161,24 @@ def start_policy( fast_mode: bool = False, **policy_kwargs, ) -> dict[str, Any]: - """Start policy execution in background (non-blocking).""" + """Start policy execution in background (non-blocking). + + Only one policy may run per robot at a time — MuJoCo model/data + are not thread-safe for concurrent writes. + """ if self._world is None or self._world._data is None: return {"status": "error", "content": [{"text": "❌ No simulation."}]} if robot_name not in self._world.robots: return {"status": "error", "content": [{"text": f"❌ Robot '{robot_name}' not found."}]} + # Reject if a policy is already running on this robot (thread-safety) + existing = self._policy_threads.get(robot_name) + if existing is not None and not existing.done(): + return { + "status": "error", + "content": [{"text": f"❌ Policy already running on '{robot_name}'. Stop it first."}], + } + future = self._executor.submit( self.run_policy, robot_name, @@ -303,7 +315,6 @@ def eval_policy( mj.mj_resetData(model, data) mj.mj_forward(model, data) - total_reward = 0.0 success = False steps = 0 @@ -326,7 +337,7 @@ def eval_policy( if success: break - results.append({"episode": ep, "steps": steps, "success": success, "reward": total_reward}) + results.append({"episode": ep, "steps": steps, "success": success}) n_success = sum(1 for r in results if r["success"]) success_rate = n_success / max(n_episodes, 1) diff --git a/strands_robots/simulation/mujoco/recording.py b/strands_robots/simulation/mujoco/recording.py index c2ef006..1a9e52a 100644 --- a/strands_robots/simulation/mujoco/recording.py +++ b/strands_robots/simulation/mujoco/recording.py @@ -21,7 +21,7 @@ def start_recording( root: str = None, push_to_hub: bool = False, vcodec: str = "libsvtav1", - overwrite: bool = True, + overwrite: bool = False, ) -> dict[str, Any]: """Start recording to LeRobotDataset format (parquet + video).""" if self._world is None: diff --git a/strands_robots/simulation/mujoco/simulation.py b/strands_robots/simulation/mujoco/simulation.py index 70b5404..f05418f 100644 --- a/strands_robots/simulation/mujoco/simulation.py +++ b/strands_robots/simulation/mujoco/simulation.py @@ -27,6 +27,7 @@ from strands_robots.simulation.mujoco.randomization import RandomizationMixin from strands_robots.simulation.mujoco.recording import RecordingMixin from strands_robots.simulation.mujoco.rendering import RenderingMixin +from strands_robots.simulation.base import SimulationBackend from strands_robots.simulation.mujoco.scene_ops import ( eject_body_from_scene, inject_camera_into_scene, @@ -44,6 +45,7 @@ class Simulation( RenderingMixin, RecordingMixin, RandomizationMixin, + SimulationBackend, AgentTool, ): """Programmatic simulation environment as a Strands AgentTool. From b35e99e5d89027b9f43cf083db5183f5ea7e8db3 Mon Sep 17 00:00:00 2001 From: cagataycali Date: Wed, 1 Apr 2026 15:26:14 -0400 Subject: [PATCH 03/22] =?UTF-8?q?fix:=20resolve=20lint=20errors=20?= =?UTF-8?q?=E2=80=94=20import=20ordering,=20format,=20strict=20param?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - mjcf_builder.py: move 'import re' after docstring into import block - simulation.py: sort SimulationBackend import alphabetically - dataset_recorder.py: add strict param to __init__ signature - Run ruff format on both files All checks pass: ruff check ✅, ruff format ✅, 335 tests ✅ --- strands_robots/dataset_recorder.py | 3 +-- strands_robots/simulation/mujoco/mjcf_builder.py | 6 ++---- strands_robots/simulation/mujoco/simulation.py | 2 +- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/strands_robots/dataset_recorder.py b/strands_robots/dataset_recorder.py index 873de0a..f07bb2f 100644 --- a/strands_robots/dataset_recorder.py +++ b/strands_robots/dataset_recorder.py @@ -73,7 +73,6 @@ def _get_lerobot_dataset_class(): ) from exc - class DatasetRecorder: """Bridge between strands-robots control loops and LeRobotDataset. @@ -86,7 +85,7 @@ class DatasetRecorder: Works for both real hardware (robot.py) and simulation (simulation.py). """ - def __init__(self, dataset, task: str = ""): + def __init__(self, dataset, task: str = "", strict: bool = True): self.dataset = dataset self.default_task = task self.frame_count = 0 diff --git a/strands_robots/simulation/mujoco/mjcf_builder.py b/strands_robots/simulation/mujoco/mjcf_builder.py index 9ec2d93..6dbf543 100644 --- a/strands_robots/simulation/mujoco/mjcf_builder.py +++ b/strands_robots/simulation/mujoco/mjcf_builder.py @@ -1,8 +1,8 @@ -import re """MJCF XML builder — programmatic scene construction.""" import logging import os +import re import subprocess import tempfile @@ -21,9 +21,7 @@ def _sanitize_name(name: str) -> str: Raises ValueError if name contains characters that could cause XML injection. """ if not _VALID_NAME_RE.match(name): - raise ValueError( - f"Invalid simulation name {name!r}: must match [a-zA-Z0-9_][a-zA-Z0-9_.\\-]{{0,127}}" - ) + raise ValueError(f"Invalid simulation name {name!r}: must match [a-zA-Z0-9_][a-zA-Z0-9_.\\-]{{0,127}}") return name diff --git a/strands_robots/simulation/mujoco/simulation.py b/strands_robots/simulation/mujoco/simulation.py index f05418f..41fd4b2 100644 --- a/strands_robots/simulation/mujoco/simulation.py +++ b/strands_robots/simulation/mujoco/simulation.py @@ -14,6 +14,7 @@ from strands.types._events import ToolResultEvent from strands.types.tools import ToolSpec, ToolUse +from strands_robots.simulation.base import SimulationBackend from strands_robots.simulation.model_registry import ( list_available_models, register_urdf, @@ -27,7 +28,6 @@ from strands_robots.simulation.mujoco.randomization import RandomizationMixin from strands_robots.simulation.mujoco.recording import RecordingMixin from strands_robots.simulation.mujoco.rendering import RenderingMixin -from strands_robots.simulation.base import SimulationBackend from strands_robots.simulation.mujoco.scene_ops import ( eject_body_from_scene, inject_camera_into_scene, From bf9dfe08615aebee1c529afff854ec210b37a175 Mon Sep 17 00:00:00 2001 From: cagataycali Date: Wed, 1 Apr 2026 15:39:43 -0400 Subject: [PATCH 04/22] fix: acquire _lock around MuJoCo data mutations + sanitize all XML names MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Wrap data.ctrl writes + mj_step calls with self._lock in run_policy and eval_policy to prevent concurrent MuJoCo data access - Apply _sanitize_name() to ALL user-provided names interpolated into MJCF XML (geom, joint, mesh, camera), not just body names - Import _sanitize_name in scene_ops for camera name validation Addresses review comments on thread-safety and XML injection. ruff check ✅, ruff format ✅, 335 tests ✅ --- .../simulation/mujoco/mjcf_builder.py | 26 +++++++++-------- .../simulation/mujoco/policy_runner.py | 28 ++++++++++--------- strands_robots/simulation/mujoco/scene_ops.py | 4 +-- 3 files changed, 32 insertions(+), 26 deletions(-) diff --git a/strands_robots/simulation/mujoco/mjcf_builder.py b/strands_robots/simulation/mujoco/mjcf_builder.py index 6dbf543..22fa655 100644 --- a/strands_robots/simulation/mujoco/mjcf_builder.py +++ b/strands_robots/simulation/mujoco/mjcf_builder.py @@ -53,7 +53,7 @@ def build_objects_only(world: SimWorld) -> str: parts.append(' ') for obj in world.objects.values(): if obj.shape == "mesh" and obj.mesh_path: - parts.append(f' ') + parts.append(f' ') parts.append(" ") parts.append(" ") @@ -67,7 +67,9 @@ def build_objects_only(world: SimWorld) -> str: for cam in world.cameras.values(): px, py, pz = cam.position - parts.append(f' ') + parts.append( + f' ' + ) for obj in world.objects.values(): parts.append(MJCFBuilder._object_xml(obj, indent=4)) @@ -89,44 +91,44 @@ def _object_xml(obj: SimObject, indent: int = 4) -> str: lines.append(f'{pad}') if not obj.is_static: - lines.append(f'{pad} ') + lines.append(f'{pad} ') lines.append(f'{pad} ') if obj.shape == "box": sx, sy, sz = [s / 2 for s in obj.size] lines.append( - f'{pad} ' ) elif obj.shape == "sphere": radius = obj.size[0] / 2 if obj.size else 0.025 lines.append( - f'{pad} ' + f'{pad} ' ) elif obj.shape == "cylinder": radius = obj.size[0] / 2 if obj.size else 0.025 half_h = obj.size[2] / 2 if len(obj.size) > 2 else 0.05 lines.append( - f'{pad} ' ) elif obj.shape == "capsule": radius = obj.size[0] / 2 if obj.size else 0.025 half_h = obj.size[2] / 2 if len(obj.size) > 2 else 0.05 lines.append( - f'{pad} ' ) elif obj.shape == "mesh" and obj.mesh_path: lines.append( - f'{pad} ' ) elif obj.shape == "plane": sx = obj.size[0] if obj.size else 1.0 sy = obj.size[1] if len(obj.size) > 1 else sx lines.append( - f'{pad} ' + f'{pad} ' ) lines.append(f"{pad}") @@ -176,7 +178,7 @@ def compose_multi_robot_scene( parts.append(' ') for obj in objects.values(): if obj.shape == "mesh" and obj.mesh_path: - parts.append(f' ') + parts.append(f' ') parts.append(" ") parts.append(" ") @@ -190,7 +192,9 @@ def compose_multi_robot_scene( for cam in cameras.values(): px, py, pz = cam.position - parts.append(f' ') + parts.append( + f' ' + ) for robot_name, robot in robots.items(): xml_path = robot_xmls[robot_name] diff --git a/strands_robots/simulation/mujoco/policy_runner.py b/strands_robots/simulation/mujoco/policy_runner.py index 94e97b7..d204f37 100644 --- a/strands_robots/simulation/mujoco/policy_runner.py +++ b/strands_robots/simulation/mujoco/policy_runner.py @@ -238,16 +238,17 @@ def replay_episode( step_start = time.time() frame = ds[episode_start + frame_idx] - if "action" in frame: - action_vals = frame["action"] - if hasattr(action_vals, "numpy"): - action_vals = action_vals.numpy() - if hasattr(action_vals, "tolist"): - action_vals = action_vals.tolist() - for i in range(min(len(action_vals), n_actuators)): - data.ctrl[i] = float(action_vals[i]) - - mj.mj_step(model, data) + with self._lock: + if "action" in frame: + action_vals = frame["action"] + if hasattr(action_vals, "numpy"): + action_vals = action_vals.numpy() + if hasattr(action_vals, "tolist"): + action_vals = action_vals.tolist() + for i in range(min(len(action_vals), n_actuators)): + data.ctrl[i] = float(action_vals[i]) + + mj.mj_step(model, data) frames_applied += 1 elapsed = time.time() - step_start @@ -323,10 +324,11 @@ def eval_policy( coro_or_result = policy_instance.get_actions(obs, instruction) actions = _resolve_coroutine(coro_or_result) - if actions: - self._apply_sim_action(robot_name, actions[0]) + with self._lock: + if actions: + self._apply_sim_action(robot_name, actions[0]) - mj.mj_step(model, data) + mj.mj_step(model, data) steps += 1 if success_fn == "contact": diff --git a/strands_robots/simulation/mujoco/scene_ops.py b/strands_robots/simulation/mujoco/scene_ops.py index ba83696..9352f90 100644 --- a/strands_robots/simulation/mujoco/scene_ops.py +++ b/strands_robots/simulation/mujoco/scene_ops.py @@ -13,7 +13,7 @@ from strands_robots.simulation.models import SimCamera, SimObject, SimWorld from strands_robots.simulation.mujoco.backend import _ensure_mujoco -from strands_robots.simulation.mujoco.mjcf_builder import MJCFBuilder +from strands_robots.simulation.mujoco.mjcf_builder import MJCFBuilder, _sanitize_name logger = logging.getLogger(__name__) @@ -197,7 +197,7 @@ def inject_camera_into_scene(world: SimWorld, cam: SimCamera) -> bool: xml_content = f.read() px, py, pz = cam.position - cam_xml = f' ' + cam_xml = f' ' xml_content = xml_content.replace("", f"{cam_xml}\n") with open(scene_path, "w") as f: From 37cdb19ea36eff5a5d3ed96b3f570187c7d2e539 Mon Sep 17 00:00:00 2001 From: cagataycali Date: Wed, 1 Apr 2026 16:08:22 -0400 Subject: [PATCH 05/22] ci: add MuJoCo system deps (libosmesa6-dev + MUJOCO_GL=osmesa) --- .github/workflows/test-lint.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml index 79b15d9..b171e27 100644 --- a/.github/workflows/test-lint.yml +++ b/.github/workflows/test-lint.yml @@ -26,6 +26,11 @@ jobs: python-version: '3.12' cache: 'pip' + - name: Install system dependencies (OpenGL for MuJoCo) + run: | + sudo apt-get update + sudo apt-get install -y libosmesa6-dev + - name: Install dependencies run: | pip install --no-cache-dir hatch @@ -35,4 +40,6 @@ jobs: run: hatch run lint - name: Run tests + env: + MUJOCO_GL: osmesa run: hatch run test -x --strict-markers From 9987c236271e80764a10c595f3c2c8d157cf685d Mon Sep 17 00:00:00 2001 From: cagataycali Date: Wed, 1 Apr 2026 16:14:10 -0400 Subject: [PATCH 06/22] feat: add [sim] extra with mujoco dependency Adds mujoco>=3.0.0,<4.0.0 to the [sim] optional-dependencies group, and includes it in [all] so CI installs it via 'uv sync --extra all --extra dev'. --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index f1a7090..8a38ea5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ lerobot = [ ] sim = [ "robot_descriptions>=1.11.0,<2.0.0", + "mujoco>=3.0.0,<4.0.0", ] all = [ "strands-robots[groot-service]", From 8a6c5102d2e90d0d36a7845b28afac55a387354f Mon Sep 17 00:00:00 2001 From: strands-agent Date: Thu, 2 Apr 2026 00:46:12 +0000 Subject: [PATCH 07/22] fix: rename [sim] extra to [sim-mujoco] per review Address yinsong1986's feedback to namespace the optional dependency group as [sim-mujoco] for clarity when additional sim backends are added. --- pyproject.toml | 3 +++ strands_robots/simulation/mujoco/simulation.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8a38ea5..3ee9205 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,12 +50,15 @@ lerobot = [ ] sim = [ "robot_descriptions>=1.11.0,<2.0.0", +] +sim-mujoco = [ "mujoco>=3.0.0,<4.0.0", ] all = [ "strands-robots[groot-service]", "strands-robots[lerobot]", "strands-robots[sim]", + "strands-robots[sim-mujoco]", ] dev = [ "pytest>=6.0,<9.0.0", diff --git a/strands_robots/simulation/mujoco/simulation.py b/strands_robots/simulation/mujoco/simulation.py index 41fd4b2..af9af64 100644 --- a/strands_robots/simulation/mujoco/simulation.py +++ b/strands_robots/simulation/mujoco/simulation.py @@ -295,7 +295,7 @@ def _ensure_meshes(model_path: str, robot_name: str): { "text": ( f"❌ Auto-download failed for '{robot_name}': {e}. " - f"Install robot_descriptions: pip install strands-robots[sim]" + f"Install robot_descriptions: pip install strands-robots[sim-mujoco]" ) } ], From 2f6d6e6798b9e730b464a2620878bba3c877adb8 Mon Sep 17 00:00:00 2001 From: cagataycali Date: Mon, 6 Apr 2026 02:43:27 -0400 Subject: [PATCH 08/22] =?UTF-8?q?fix:=20rebase=20on=20simulation-foundatio?= =?UTF-8?q?n=20=E2=80=94=20SimulationBackend=E2=86=92SimEngine,=20update?= =?UTF-8?q?=20lazy=20imports?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- strands_robots/simulation/__init__.py | 39 +++++++++++++------ .../simulation/mujoco/simulation.py | 4 +- tests/test_mujoco_e2e.py | 4 +- 3 files changed, 31 insertions(+), 16 deletions(-) diff --git a/strands_robots/simulation/__init__.py b/strands_robots/simulation/__init__.py index d9674a9..4aea6f8 100644 --- a/strands_robots/simulation/__init__.py +++ b/strands_robots/simulation/__init__.py @@ -7,9 +7,19 @@ ├── base.py ← SimEngine ABC ├── factory.py ← create_simulation() + backend registration ├── models.py ← shared dataclasses (SimWorld, SimRobot, ...) - └── model_registry.py ← URDF/MJCF resolution (shared across backends) - - # MuJoCo backend added in subsequent PRs. + ├── model_registry.py ← URDF/MJCF resolution (shared across backends) + └── mujoco/ ← MuJoCo CPU backend + ├── __init__.py + ├── backend.py ← lazy mujoco import + GL config + ├── mjcf_builder.py ← MJCF XML builder + ├── physics.py ← advanced physics (raycasting, jacobians, forces) + ├── scene_ops.py ← XML round-trip inject/eject + ├── rendering.py ← render RGB/depth, observations + ├── policy_runner.py ← run_policy, eval_policy, replay + ├── randomization.py ← domain randomization + ├── recording.py ← LeRobotDataset recording + ├── tool_spec.json ← AgentTool input schema + └── simulation.py ← Simulation (AgentTool orchestrator) Usage:: @@ -62,10 +72,15 @@ TrajectoryStep, ) -# --- Heavy imports (lazy — loaded when mujoco backend is available) --- -# MuJoCo-specific lazy imports will be added when the mujoco/ subpackage -# is introduced. For now, only the lightweight foundation is available. -_LAZY_IMPORTS: dict[str, tuple[str, str]] = {} +# --- Heavy imports (lazy — need strands SDK + mujoco) --- +_LAZY_IMPORTS: dict[str, tuple[str, str]] = { + "Simulation": ("strands_robots.simulation.mujoco.simulation", "Simulation"), + "MuJoCoSimulation": ("strands_robots.simulation.mujoco.simulation", "Simulation"), + "MJCFBuilder": ("strands_robots.simulation.mujoco.mjcf_builder", "MJCFBuilder"), + "_configure_gl_backend": ("strands_robots.simulation.mujoco.backend", "_configure_gl_backend"), + "_ensure_mujoco": ("strands_robots.simulation.mujoco.backend", "_ensure_mujoco"), + "_is_headless": ("strands_robots.simulation.mujoco.backend", "_is_headless"), +} __all__ = [ @@ -75,9 +90,9 @@ "create_simulation", "list_backends", "register_backend", - # Default backend alias (available when mujoco backend is installed) - # "Simulation", - # "MuJoCoSimulation", + # Default backend alias + "Simulation", + "MuJoCoSimulation", # Shared dataclasses "SimStatus", "SimRobot", @@ -85,8 +100,8 @@ "SimCamera", "SimWorld", "TrajectoryStep", - # MuJoCo builder (available when mujoco backend is installed) - # "MJCFBuilder", + # MuJoCo builder + "MJCFBuilder", # Model registry "register_urdf", "resolve_model", diff --git a/strands_robots/simulation/mujoco/simulation.py b/strands_robots/simulation/mujoco/simulation.py index af9af64..0893eb2 100644 --- a/strands_robots/simulation/mujoco/simulation.py +++ b/strands_robots/simulation/mujoco/simulation.py @@ -14,7 +14,7 @@ from strands.types._events import ToolResultEvent from strands.types.tools import ToolSpec, ToolUse -from strands_robots.simulation.base import SimulationBackend +from strands_robots.simulation.base import SimEngine from strands_robots.simulation.model_registry import ( list_available_models, register_urdf, @@ -45,7 +45,7 @@ class Simulation( RenderingMixin, RecordingMixin, RandomizationMixin, - SimulationBackend, + SimEngine, AgentTool, ): """Programmatic simulation environment as a Strands AgentTool. diff --git a/tests/test_mujoco_e2e.py b/tests/test_mujoco_e2e.py index c09cb0c..c6d2d1e 100644 --- a/tests/test_mujoco_e2e.py +++ b/tests/test_mujoco_e2e.py @@ -36,7 +36,7 @@ def _has_opengl() -> bool: from strands_robots.policies import MockPolicy # noqa: E402 -from strands_robots.simulation.base import SimulationBackend # noqa: E402 +from strands_robots.simulation.base import SimEngine # noqa: E402 from strands_robots.simulation.models import SimObject, SimRobot, SimStatus, SimWorld # noqa: E402 # ── Fixtures ── @@ -133,7 +133,7 @@ def test_abc_has_required_methods(self): "render", ] for method in required: - assert hasattr(SimulationBackend, method) + assert hasattr(SimEngine, method) def test_shared_dataclasses(self): w = SimWorld() From fb5a14a9265de75b6df6bdc5220479cd56cbffd5 Mon Sep 17 00:00:00 2001 From: cagataycali Date: Mon, 6 Apr 2026 03:09:03 -0400 Subject: [PATCH 09/22] =?UTF-8?q?fix:=20resolve=20all=20mypy=20errors=20?= =?UTF-8?q?=E2=80=94=20mixin=20overrides,=20Optional=20types,=20import=20s?= =?UTF-8?q?tubs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add mujoco.* to third-party ignore-missing-imports list - Add mypy override for simulation.mujoco.* with disable_error_code for attr-defined (cooperative mixin pattern), assignment (implicit Optional), override (extended signatures), and misc (MRO conflicts) - Add mypy override for _async_utils and dataset_recorder (pre-existing) - Fix add_robot/add_object/add_camera/move_object signatures: use X | None - Fix set_gravity, cleanup, __enter__/__exit__/__del__ return annotations - Fix randomization seed: int → int | None - Fix backend _ensure_mujoco return type annotation - Fix __init__.py __getattr__ type annotation --- pyproject.toml | 19 ++++++++- strands_robots/_async_utils.py | 2 +- strands_robots/simulation/mujoco/__init__.py | 2 +- strands_robots/simulation/mujoco/backend.py | 5 ++- .../simulation/mujoco/randomization.py | 2 +- .../simulation/mujoco/simulation.py | 39 ++++++++++--------- 6 files changed, 45 insertions(+), 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3ee9205..c9fa9b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,7 +132,7 @@ ignore_missing_imports = false # Third-party libs without type stubs [[tool.mypy.overrides]] -module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*", "robot_descriptions.*"] +module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*", "robot_descriptions.*", "mujoco.*"] ignore_missing_imports = true # @tool decorator injects runtime signatures mypy cannot check @@ -165,6 +165,23 @@ module = ["strands_robots.registry.*"] warn_return_any = false disallow_untyped_defs = false +# MuJoCo simulation — mixins use cooperative self._world patterns +# attr-defined: Mixins access self._world/self._lock/etc. from Simulation (cooperative pattern) +# assignment: PEP 484 implicit Optional (= None on typed params) +# override: Subclass signatures extend base with extra params (orientation, mesh_path) +# misc: Multiple inheritance method resolution conflicts between mixin + ABC +[[tool.mypy.overrides]] +module = ["strands_robots.simulation.mujoco.*"] +disallow_untyped_defs = false +warn_return_any = false +disable_error_code = ["attr-defined", "assignment", "override", "misc", "import-not-found", "import-untyped", "has-type", "typeddict-item", "index", "return-value"] + +# Async utils and dataset recorder — thin wrappers with dynamic types +[[tool.mypy.overrides]] +module = ["strands_robots._async_utils", "strands_robots.dataset_recorder"] +disallow_untyped_defs = false +warn_return_any = false + # Test files — relaxed type checking for mocks, fixtures, and test utilities [[tool.mypy.overrides]] module = ["tests.*", "tests_integ.*"] diff --git a/strands_robots/_async_utils.py b/strands_robots/_async_utils.py index 51d1808..ac145fe 100644 --- a/strands_robots/_async_utils.py +++ b/strands_robots/_async_utils.py @@ -8,7 +8,7 @@ _EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="strands_async") -def _resolve_coroutine(coro_or_result): +def _resolve_coroutine(coro_or_result): # type: ignore[no-untyped-def] """Safely resolve a potentially-async result to a sync value. Handles three cases: diff --git a/strands_robots/simulation/mujoco/__init__.py b/strands_robots/simulation/mujoco/__init__.py index 014926b..869040a 100644 --- a/strands_robots/simulation/mujoco/__init__.py +++ b/strands_robots/simulation/mujoco/__init__.py @@ -32,7 +32,7 @@ ] -def __getattr__(name): +def __getattr__(name: str) -> "type": if name == "MuJoCoSimulation": from strands_robots.simulation.mujoco.simulation import Simulation as _Sim diff --git a/strands_robots/simulation/mujoco/backend.py b/strands_robots/simulation/mujoco/backend.py index da9a268..38f97c2 100644 --- a/strands_robots/simulation/mujoco/backend.py +++ b/strands_robots/simulation/mujoco/backend.py @@ -4,6 +4,7 @@ import logging import os import sys +from typing import Any logger = logging.getLogger(__name__) @@ -24,7 +25,7 @@ def _is_headless() -> bool: return True -def _configure_gl_backend() -> None: +def _configure_gl_backend() -> None: # noqa: C901 """Auto-configure MuJoCo's OpenGL backend for headless environments. MuJoCo reads MUJOCO_GL at import time to select the OpenGL backend: @@ -70,7 +71,7 @@ def _configure_gl_backend() -> None: ) -def _ensure_mujoco(): +def _ensure_mujoco() -> "Any": """Lazy import MuJoCo to avoid hard dependency. Auto-configures the OpenGL backend for headless environments before diff --git a/strands_robots/simulation/mujoco/randomization.py b/strands_robots/simulation/mujoco/randomization.py index cdb2d3e..8003d64 100644 --- a/strands_robots/simulation/mujoco/randomization.py +++ b/strands_robots/simulation/mujoco/randomization.py @@ -23,7 +23,7 @@ def randomize( color_range: tuple[float, float] = (0.1, 1.0), friction_range: tuple[float, float] = (0.5, 1.5), mass_range: tuple[float, float] = (0.5, 2.0), - seed: int = None, + seed: int | None = None, ) -> dict[str, Any]: """Apply domain randomization to the scene.""" if self._world is None or self._world._model is None: diff --git a/strands_robots/simulation/mujoco/simulation.py b/strands_robots/simulation/mujoco/simulation.py index 0893eb2..6e50cea 100644 --- a/strands_robots/simulation/mujoco/simulation.py +++ b/strands_robots/simulation/mujoco/simulation.py @@ -304,10 +304,10 @@ def _ensure_meshes(model_path: str, robot_name: str): def add_robot( self, name: str, - urdf_path: str = None, - data_config: str = None, - position: list[float] = None, - orientation: list[float] = None, + urdf_path: str | None = None, + data_config: str | None = None, + position: list[float] | None = None, + orientation: list[float] | None = None, ) -> dict[str, Any]: """Add a robot to the simulation.""" if self._world is None: @@ -471,13 +471,14 @@ def add_object( self, name: str, shape: str = "box", - position: list[float] = None, - orientation: list[float] = None, - size: list[float] = None, - color: list[float] = None, + position: list[float] | None = None, + orientation: list[float] | None = None, + size: list[float] | None = None, + color: list[float] | None = None, mass: float = 0.1, is_static: bool = False, - mesh_path: str = None, + mesh_path: str | None = None, + **kwargs: Any, ) -> dict[str, Any]: """Add an object to the simulation.""" if self._world is None: @@ -547,7 +548,9 @@ def remove_object(self, name: str) -> dict[str, Any]: self._recompile_world() return {"status": "success", "content": [{"text": f"🗑️ '{name}' removed."}]} - def move_object(self, name: str, position: list[float] = None, orientation: list[float] = None) -> dict[str, Any]: + def move_object( + self, name: str, position: list[float] | None = None, orientation: list[float] | None = None + ) -> dict[str, Any]: if self._world is None or self._world._data is None: return {"status": "error", "content": [{"text": "❌ No simulation."}]} if name not in self._world.objects: @@ -585,8 +588,8 @@ def list_objects(self) -> dict[str, Any]: def add_camera( self, name: str, - position: list[float] = None, - target: list[float] = None, + position: list[float] | None = None, + target: list[float] | None = None, fov: float = 60.0, width: int = 640, height: int = 480, @@ -678,7 +681,7 @@ def destroy(self) -> dict[str, Any]: self._world = None return {"status": "success", "content": [{"text": "🗑️ World destroyed."}]} - def set_gravity(self, gravity) -> dict[str, Any]: + def set_gravity(self, gravity: list[float] | float | int) -> dict[str, Any]: if self._world is None or self._world._model is None: return {"status": "error", "content": [{"text": "❌ No world."}]} if isinstance(gravity, (int, float)): @@ -711,7 +714,7 @@ def open_viewer(self) -> dict[str, Any]: except Exception as e: return {"status": "error", "content": [{"text": f"❌ Viewer failed: {e}"}]} - def _close_viewer(self): + def _close_viewer(self) -> None: if self._viewer_handle is not None: try: self._viewer_handle.close() @@ -921,7 +924,7 @@ def _stop_policy(self, robot_name: str = "", **kwargs) -> dict[str, Any]: # --- Cleanup --- - def cleanup(self): + def cleanup(self) -> None: if hasattr(self, "mesh") and self.mesh: self.mesh.stop() if self._world: @@ -938,13 +941,13 @@ def cleanup(self): self._executor.shutdown(wait=False) self._shutdown_event.set() - def __enter__(self): + def __enter__(self) -> "Simulation": return self - def __exit__(self, *exc): + def __exit__(self, *exc: object) -> None: self.cleanup() - def __del__(self): + def __del__(self) -> None: try: self.cleanup() except Exception: From 8b8c78dc793a05eaf9d4771444602e924cc20752 Mon Sep 17 00:00:00 2001 From: cagataycali Date: Mon, 6 Apr 2026 03:15:02 -0400 Subject: [PATCH 10/22] fix: properly fix mypy errors instead of blanket suppression MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Removed 4 error categories from disable_error_code (assignment, typeddict-item, index, return-value) by fixing the actual code: - physics.py: Fix 16 implicit Optional params (X = None → X | None = None) - rendering.py: Fix 3 implicit Optional params - recording.py: Fix 2 implicit Optional params + inline ignore for fallback - policy_runner.py: Fix 5 implicit Optional params + inline ignore for narrowed arg - simulation.py: Fix send_action/create_world signatures to match base, fix variable name reuse bug (result → recompile_result), inline ignore for TypedDict ** expansion Remaining suppressed (all legitimate): - attr-defined (137): cooperative mixin pattern (self._world on mixins) - misc (3): MRO conflicts + import fallback redefinition - override (1): add_object extends base with orientation/mesh_path params - import-not-found (1): imageio optional dep - import-untyped (1): internal zenoh_mesh - has-type (1): dynamic renderer cache --- pyproject.toml | 2 +- strands_robots/simulation/mujoco/physics.py | 34 +++++++++---------- .../simulation/mujoco/policy_runner.py | 14 ++++---- strands_robots/simulation/mujoco/recording.py | 6 ++-- strands_robots/simulation/mujoco/rendering.py | 10 ++++-- .../simulation/mujoco/simulation.py | 16 ++++----- 6 files changed, 43 insertions(+), 39 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c9fa9b9..df0c197 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -174,7 +174,7 @@ disallow_untyped_defs = false module = ["strands_robots.simulation.mujoco.*"] disallow_untyped_defs = false warn_return_any = false -disable_error_code = ["attr-defined", "assignment", "override", "misc", "import-not-found", "import-untyped", "has-type", "typeddict-item", "index", "return-value"] +disable_error_code = ["attr-defined", "misc", "override", "import-not-found", "import-untyped", "has-type"] # Async utils and dataset recorder — thin wrappers with dynamic types [[tool.mypy.overrides]] diff --git a/strands_robots/simulation/mujoco/physics.py b/strands_robots/simulation/mujoco/physics.py index 64d9e9e..f3e3c3c 100644 --- a/strands_robots/simulation/mujoco/physics.py +++ b/strands_robots/simulation/mujoco/physics.py @@ -109,9 +109,9 @@ def load_state(self, name: str = "default") -> dict[str, Any]: def apply_force( self, body_name: str, - force: list[float] = None, - torque: list[float] = None, - point: list[float] = None, + force: list[float] | None = None, + torque: list[float] | None = None, + point: list[float] | None = None, ) -> dict[str, Any]: """Apply external force and/or torque to a body. @@ -223,9 +223,9 @@ def raycast( def get_jacobian( self, - body_name: str = None, - site_name: str = None, - geom_name: str = None, + body_name: str | None = None, + site_name: str | None = None, + geom_name: str | None = None, ) -> dict[str, Any]: """Compute the Jacobian (position + rotation) for a body, site, or geom. @@ -437,8 +437,8 @@ def get_body_state( def set_joint_positions( self, - positions: dict[str, float] = None, - robot_name: str = None, + positions: dict[str, float] | None = None, + robot_name: str | None = None, ) -> dict[str, Any]: """Set joint positions directly (bypassing actuators). @@ -473,7 +473,7 @@ def set_joint_positions( def set_joint_velocities( self, - velocities: dict[str, float] = None, + velocities: dict[str, float] | None = None, ) -> dict[str, Any]: """Set joint velocities directly. @@ -503,7 +503,7 @@ def set_joint_velocities( # ── Sensor Readout ── - def get_sensor_data(self, sensor_name: str = None) -> dict[str, Any]: + def get_sensor_data(self, sensor_name: str | None = None) -> dict[str, Any]: """Read sensor values from the simulation. MuJoCo supports: jointpos, jointvel, accelerometer, gyro, force, @@ -559,7 +559,7 @@ def get_sensor_data(self, sensor_name: str = None) -> dict[str, Any]: def set_body_properties( self, body_name: str, - mass: float = None, + mass: float | None = None, ) -> dict[str, Any]: """Modify body properties at runtime (no recompile needed). @@ -587,11 +587,11 @@ def set_body_properties( def set_geom_properties( self, - geom_name: str = None, - geom_id: int = None, - color: list[float] = None, - friction: list[float] = None, - size: list[float] = None, + geom_name: str | None = None, + geom_id: int | None = None, + color: list[float] | None = None, + friction: list[float] | None = None, + size: list[float] | None = None, ) -> dict[str, Any]: """Modify geom properties at runtime (no recompile needed). @@ -789,7 +789,7 @@ def get_total_mass(self) -> dict[str, Any]: # ── Export Model XML ── - def export_xml(self, output_path: str = None) -> dict[str, Any]: + def export_xml(self, output_path: str | None = None) -> dict[str, Any]: """Export the current model to MJCF XML. Uses mj_saveLastXML — exports the exact model currently loaded, diff --git a/strands_robots/simulation/mujoco/policy_runner.py b/strands_robots/simulation/mujoco/policy_runner.py index d204f37..8a741c6 100644 --- a/strands_robots/simulation/mujoco/policy_runner.py +++ b/strands_robots/simulation/mujoco/policy_runner.py @@ -26,9 +26,9 @@ def run_policy( action_horizon: int = 8, control_frequency: float = 50.0, fast_mode: bool = False, - record_video: str = None, + record_video: str | None = None, video_fps: int = 30, - video_camera: str = None, + video_camera: str | None = None, video_width: int = 640, video_height: int = 480, **policy_kwargs, @@ -138,7 +138,7 @@ def run_policy( if writer: writer.close() - file_kb = os.path.getsize(record_video) / 1024 + file_kb = os.path.getsize(record_video) / 1024 # type: ignore[arg-type] # narrowed by `if writer` above result_text += ( f"\n🎬 Video: {record_video}\n" f"📹 {frame_count} frames, {video_fps}fps, {video_width}x{video_height} | 💾 {file_kb:.0f} KB" @@ -198,9 +198,9 @@ def start_policy( def replay_episode( self, repo_id: str, - robot_name: str = None, + robot_name: str | None = None, episode: int = 0, - root: str = None, + root: str | None = None, speed: float = 1.0, ) -> dict[str, Any]: """Replay actions from a LeRobotDataset episode in simulation.""" @@ -281,12 +281,12 @@ def replay_episode( def eval_policy( self, - robot_name: str = None, + robot_name: str | None = None, policy_provider: str = "mock", instruction: str = "", n_episodes: int = 10, max_steps: int = 300, - success_fn: str = None, + success_fn: str | None = None, **policy_kwargs, ) -> dict[str, Any]: """Evaluate a policy over multiple episodes with success metrics.""" diff --git a/strands_robots/simulation/mujoco/recording.py b/strands_robots/simulation/mujoco/recording.py index 1a9e52a..7174a69 100644 --- a/strands_robots/simulation/mujoco/recording.py +++ b/strands_robots/simulation/mujoco/recording.py @@ -18,7 +18,7 @@ def start_recording( repo_id: str = "local/sim_recording", task: str = "", fps: int = 30, - root: str = None, + root: str | None = None, push_to_hub: bool = False, vcodec: str = "libsvtav1", overwrite: bool = False, @@ -35,7 +35,7 @@ def start_recording( def _has_lerobot(): return False - _DatasetRecorder = None + _DatasetRecorder = None # type: ignore[assignment] if not _has_lerobot() or _DatasetRecorder is None: return { @@ -104,7 +104,7 @@ def _has_lerobot(): logger.error("Dataset recorder init failed: %s", e) return {"status": "error", "content": [{"text": f"Dataset init failed: {e}"}]} - def stop_recording(self, output_path: str = None) -> dict[str, Any]: + def stop_recording(self, output_path: str | None = None) -> dict[str, Any]: """Stop recording and save episode to LeRobotDataset.""" if self._world is None or not self._world._recording: return {"status": "error", "content": [{"text": "Not recording."}]} diff --git a/strands_robots/simulation/mujoco/rendering.py b/strands_robots/simulation/mujoco/rendering.py index c51fc0c..41dfc1e 100644 --- a/strands_robots/simulation/mujoco/rendering.py +++ b/strands_robots/simulation/mujoco/rendering.py @@ -30,7 +30,7 @@ def _get_renderer(self, width: int, height: int): self._renderers[key] = mj.Renderer(self._world._model, height=height, width=width) return self._renderers[key] - def _get_sim_observation(self, robot_name: str, cam_name: str = None) -> dict[str, Any]: + def _get_sim_observation(self, robot_name: str, cam_name: str | None = None) -> dict[str, Any]: """Get observation from sim (same format as real robot).""" mj = _ensure_mujoco() model, data = self._world._model, self._world._data @@ -97,7 +97,9 @@ def _apply_sim_action(self, robot_name: str, action_dict: dict[str, Any], n_subs if hasattr(self, "_viewer_handle") and self._viewer_handle is not None: self._viewer_handle.sync() - def render(self, camera_name: str = "default", width: int = None, height: int = None) -> dict[str, Any]: + def render( + self, camera_name: str = "default", width: int | None = None, height: int | None = None + ) -> dict[str, Any]: """Render a camera view as base64 PNG image.""" if self._world is None or self._world._model is None: return {"status": "error", "content": [{"text": "❌ No simulation."}]} @@ -146,7 +148,9 @@ def render(self, camera_name: str = "default", width: int = None, height: int = except Exception as e: return {"status": "error", "content": [{"text": f"❌ Render failed: {e}"}]} - def render_depth(self, camera_name: str = "default", width: int = None, height: int = None) -> dict[str, Any]: + def render_depth( + self, camera_name: str = "default", width: int | None = None, height: int | None = None + ) -> dict[str, Any]: """Render depth map from a camera.""" if self._world is None or self._world._model is None: return {"status": "error", "content": [{"text": "❌ No simulation."}]} diff --git a/strands_robots/simulation/mujoco/simulation.py b/strands_robots/simulation/mujoco/simulation.py index 6e50cea..ce46e48 100644 --- a/strands_robots/simulation/mujoco/simulation.py +++ b/strands_robots/simulation/mujoco/simulation.py @@ -61,7 +61,7 @@ def __init__( default_width: int = 640, default_height: int = 480, mesh: bool = True, - peer_id: str = None, + peer_id: str | None = None, **kwargs, ): super().__init__() @@ -106,7 +106,7 @@ def mj_data(self): # --- Robot-compatible interface --- - def get_observation(self, robot_name: str = None, camera_name: str = None) -> dict[str, Any]: + def get_observation(self, robot_name: str | None = None, camera_name: str | None = None) -> dict[str, Any]: """Get observation from simulation (Robot ABC compatible).""" if self._world is None or self._world._model is None: return {} @@ -118,7 +118,7 @@ def get_observation(self, robot_name: str = None, camera_name: str = None) -> di return {} return self._get_sim_observation(robot_name, cam_name=camera_name) - def send_action(self, action: dict[str, Any], robot_name: str = None, n_substeps: int = 1) -> None: + def send_action(self, action: dict[str, Any], robot_name: str | None = None, n_substeps: int = 1) -> None: """Apply action to simulation (Robot ABC compatible).""" if self._world is None or self._world._model is None: return @@ -141,7 +141,7 @@ def _cheap_robot_count(self) -> int: return 0 def create_world( - self, timestep: float = None, gravity: list[float] = None, ground_plane: bool = True + self, timestep: float | None = None, gravity: list[float] | None = None, ground_plane: bool = True ) -> dict[str, Any]: """Create a new simulation world.""" _ensure_mujoco() @@ -524,10 +524,10 @@ def add_object( f"Check that the MJCF XML is valid and compatible with the current scene." ) from e - result = self._recompile_world() - if result["status"] == "error": + recompile_result = self._recompile_world() + if recompile_result["status"] == "error": del self._world.objects[name] - return result + return recompile_result return { "status": "success", @@ -837,7 +837,7 @@ async def stream( tool_use_id = tool_use.get("toolUseId", "") input_data = tool_use.get("input", {}) result = self._dispatch_action(input_data.get("action", ""), input_data) - yield ToolResultEvent({"toolUseId": tool_use_id, **result}) + yield ToolResultEvent(dict(toolUseId=tool_use_id, **result)) # type: ignore[typeddict-item] except Exception as e: yield ToolResultEvent( { From 4964643a309beb91d12c1bc560894730db2db367 Mon Sep 17 00:00:00 2001 From: cagataycali Date: Mon, 6 Apr 2026 03:25:42 -0400 Subject: [PATCH 11/22] =?UTF-8?q?fix:=20zero=20mypy=20suppressions=20?= =?UTF-8?q?=E2=80=94=20proper=20type=20declarations=20instead=20of=20disab?= =?UTF-8?q?le=5Ferror=5Fcode?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace blanket disable_error_code with proper type fixes: - Add TYPE_CHECKING attribute declarations to all 5 mixins (PhysicsMixin, RenderingMixin, RecordingMixin, PolicyRunnerMixin, RandomizationMixin) so mypy can verify self._world, self._lock, etc. - Add _push_to_hub field to SimWorld dataclass (was missing) - Add orientation + mesh_path params to SimEngine.add_object base signature - Add **kwargs to RandomizationMixin.randomize to match base - Simplify SimEngine.randomize to **kwargs (backends define own params) - Add assert guards for _world None checks in rendering methods - Restructure recording.py import fallback to avoid redefinition errors - Fix _apply_sim_action Protocol stubs to match real signatures Result: 0 mypy errors, 0 disable_error_code, only 2 inline type: ignore with specific codes (arg-type for narrowed var, typeddict-item for ** expansion) --- pyproject.toml | 3 +- .../simulation/mujoco/mjcf_builder.py | 4 +- strands_robots/simulation/mujoco/physics.py | 7 ++- .../simulation/mujoco/policy_runner.py | 26 ++++++++--- .../simulation/mujoco/randomization.py | 8 +++- strands_robots/simulation/mujoco/recording.py | 44 +++++++++++-------- strands_robots/simulation/mujoco/rendering.py | 18 +++++++- strands_robots/simulation/mujoco/scene_ops.py | 4 +- .../simulation/mujoco/simulation.py | 16 ++----- 9 files changed, 84 insertions(+), 46 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index df0c197..74b455d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,7 +132,7 @@ ignore_missing_imports = false # Third-party libs without type stubs [[tool.mypy.overrides]] -module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*", "robot_descriptions.*", "mujoco.*"] +module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*", "robot_descriptions.*", "mujoco.*", "imageio.*"] ignore_missing_imports = true # @tool decorator injects runtime signatures mypy cannot check @@ -174,7 +174,6 @@ disallow_untyped_defs = false module = ["strands_robots.simulation.mujoco.*"] disallow_untyped_defs = false warn_return_any = false -disable_error_code = ["attr-defined", "misc", "override", "import-not-found", "import-untyped", "has-type"] # Async utils and dataset recorder — thin wrappers with dynamic types [[tool.mypy.overrides]] diff --git a/strands_robots/simulation/mujoco/mjcf_builder.py b/strands_robots/simulation/mujoco/mjcf_builder.py index 22fa655..c8bc70d 100644 --- a/strands_robots/simulation/mujoco/mjcf_builder.py +++ b/strands_robots/simulation/mujoco/mjcf_builder.py @@ -143,8 +143,8 @@ def compose_multi_robot_scene( ) -> str: """Compose a multi-robot scene by merging URDF-derived MJCF fragments.""" mj = _ensure_mujoco() - world._tmpdir = tempfile.TemporaryDirectory(prefix="strands_sim_") - tmpdir = world._tmpdir.name + world._backend_state["tmpdir"] = tempfile.TemporaryDirectory(prefix="strands_sim_") + tmpdir = world._backend_state["tmpdir"].name robot_xmls = {} for robot_name, robot in robots.items(): diff --git a/strands_robots/simulation/mujoco/physics.py b/strands_robots/simulation/mujoco/physics.py index f3e3c3c..3e366f6 100644 --- a/strands_robots/simulation/mujoco/physics.py +++ b/strands_robots/simulation/mujoco/physics.py @@ -17,7 +17,7 @@ import json import logging -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np @@ -27,6 +27,11 @@ class PhysicsMixin: + if TYPE_CHECKING: + from strands_robots.simulation.models import SimWorld + + _world: "SimWorld | None" + """Advanced physics capabilities for Simulation. Expects: self._world (SimWorld with _model, _data) diff --git a/strands_robots/simulation/mujoco/policy_runner.py b/strands_robots/simulation/mujoco/policy_runner.py index 8a741c6..cb0f78a 100644 --- a/strands_robots/simulation/mujoco/policy_runner.py +++ b/strands_robots/simulation/mujoco/policy_runner.py @@ -3,7 +3,7 @@ import logging import os import time -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np @@ -15,6 +15,22 @@ class PolicyRunnerMixin: + if TYPE_CHECKING: + import threading + from concurrent.futures import Future, ThreadPoolExecutor + + from strands_robots.simulation.models import SimWorld + + _world: SimWorld | None + _lock: threading.Lock + _executor: ThreadPoolExecutor + _policy_threads: dict[str, Future[Any]] + + # Methods from RenderingMixin — declared here so mypy can verify calls + def _get_renderer(self, width: int, height: int) -> Any: ... + def _get_sim_observation(self, robot_name: str, cam_name: str | None = None) -> dict[str, Any]: ... + def _apply_sim_action(self, robot_name: str, action_dict: dict[str, Any], n_substeps: int = 1) -> None: ... + """Policy execution for Simulation. Expects self._world, self._executor, self._policy_threads.""" def run_policy( @@ -91,8 +107,8 @@ def run_policy( if not robot.policy_running: break - if self._world._recording: - self._world._trajectory.append( + if self._world._backend_state.get("recording", False): + self._world._backend_state["trajectory"].append( TrajectoryStep( timestamp=time.time(), sim_time=self._world.sim_time, @@ -102,8 +118,8 @@ def run_policy( instruction=instruction, ) ) - if self._world._dataset_recorder is not None: - self._world._dataset_recorder.add_frame( + if self._world._backend_state.get("dataset_recorder") is not None: + self._world._backend_state["dataset_recorder"].add_frame( observation=observation, action=action_dict, task=instruction, diff --git a/strands_robots/simulation/mujoco/randomization.py b/strands_robots/simulation/mujoco/randomization.py index 8003d64..8851521 100644 --- a/strands_robots/simulation/mujoco/randomization.py +++ b/strands_robots/simulation/mujoco/randomization.py @@ -1,7 +1,7 @@ """Domain randomization mixin.""" import logging -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np @@ -11,6 +11,11 @@ class RandomizationMixin: + if TYPE_CHECKING: + from strands_robots.simulation.models import SimWorld + + _world: "SimWorld | None" + """Domain randomization for Simulation. Expects self._world.""" def randomize( @@ -24,6 +29,7 @@ def randomize( friction_range: tuple[float, float] = (0.5, 1.5), mass_range: tuple[float, float] = (0.5, 2.0), seed: int | None = None, + **kwargs: Any, ) -> dict[str, Any]: """Apply domain randomization to the scene.""" if self._world is None or self._world._model is None: diff --git a/strands_robots/simulation/mujoco/recording.py b/strands_robots/simulation/mujoco/recording.py index 7174a69..5fede40 100644 --- a/strands_robots/simulation/mujoco/recording.py +++ b/strands_robots/simulation/mujoco/recording.py @@ -3,7 +3,7 @@ import logging import shutil from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any from strands_robots.simulation.mujoco.backend import _ensure_mujoco @@ -11,6 +11,11 @@ class RecordingMixin: + if TYPE_CHECKING: + from strands_robots.simulation.models import SimWorld + + _world: "SimWorld | None" + """Trajectory recording for Simulation. Expects self._world.""" def start_recording( @@ -27,17 +32,17 @@ def start_recording( if self._world is None: return {"status": "error", "content": [{"text": "No world."}]} + _DatasetRecorder: Any = None + _has_lerobot = False try: from strands_robots.dataset_recorder import DatasetRecorder as _DatasetRecorder - from strands_robots.dataset_recorder import has_lerobot_dataset as _has_lerobot - except ImportError: + from strands_robots.dataset_recorder import has_lerobot_dataset as _check_lerobot - def _has_lerobot(): - return False - - _DatasetRecorder = None # type: ignore[assignment] + _has_lerobot = _check_lerobot() + except ImportError: + pass - if not _has_lerobot() or _DatasetRecorder is None: + if not _has_lerobot or _DatasetRecorder is None: return { "status": "error", "content": [ @@ -47,8 +52,8 @@ def _has_lerobot(): ], } - self._world._recording = True - self._world._trajectory = [] + self._world._backend_state["recording"] = True + self._world._backend_state["trajectory"] = [] self._world._push_to_hub = push_to_hub try: @@ -76,7 +81,8 @@ def _has_lerobot(): if cam_name: camera_keys.append(cam_name) - self._world._dataset_recorder = _DatasetRecorder.create( + assert _DatasetRecorder is not None # checked above + self._world._backend_state["dataset_recorder"] = _DatasetRecorder.create( repo_id=repo_id, fps=fps, robot_type=robot_type, @@ -100,17 +106,17 @@ def _has_lerobot(): ], } except Exception as e: - self._world._recording = False + self._world._backend_state["recording"] = False logger.error("Dataset recorder init failed: %s", e) return {"status": "error", "content": [{"text": f"Dataset init failed: {e}"}]} def stop_recording(self, output_path: str | None = None) -> dict[str, Any]: """Stop recording and save episode to LeRobotDataset.""" - if self._world is None or not self._world._recording: + if self._world is None or not self._world._backend_state.get("recording", False): return {"status": "error", "content": [{"text": "Not recording."}]} - self._world._recording = False - recorder = self._world._dataset_recorder + self._world._backend_state["recording"] = False + recorder = self._world._backend_state.get("dataset_recorder", None) if recorder is None: return {"status": "error", "content": [{"text": "No dataset recorder active."}]} @@ -126,8 +132,8 @@ def stop_recording(self, output_path: str | None = None) -> dict[str, Any]: root = recorder.root recorder.finalize() - self._world._dataset_recorder = None - self._world._trajectory = [] + self._world._backend_state["dataset_recorder"] = None + self._world._backend_state["trajectory"] = [] text = ( f"Episode saved to LeRobotDataset\n" @@ -143,8 +149,8 @@ def get_recording_status(self) -> dict[str, Any]: if self._world is None: return {"status": "error", "content": [{"text": "❌ No world."}]} - recording = self._world._recording - steps = len(self._world._trajectory) + recording = self._world._backend_state.get("recording", False) + steps = len(self._world._backend_state.get("trajectory", [])) return { "status": "success", diff --git a/strands_robots/simulation/mujoco/rendering.py b/strands_robots/simulation/mujoco/rendering.py index 41dfc1e..e3b89e0 100644 --- a/strands_robots/simulation/mujoco/rendering.py +++ b/strands_robots/simulation/mujoco/rendering.py @@ -3,7 +3,7 @@ import io import json import logging -from typing import Any +from typing import TYPE_CHECKING, Any from strands_robots.simulation.mujoco.backend import _can_render, _ensure_mujoco @@ -11,6 +11,15 @@ class RenderingMixin: + if TYPE_CHECKING: + from strands_robots.simulation.models import SimWorld + + _world: "SimWorld | None" + _renderer_model: Any + _renderers: dict[tuple[int, int], Any] + default_width: int + default_height: int + """Rendering capabilities for Simulation. Expects self._world, self.default_width, self.default_height.""" def _get_renderer(self, width: int, height: int): @@ -22,6 +31,7 @@ def _get_renderer(self, width: int, height: int): if not _can_render(): return None mj = _ensure_mujoco() + assert self._world is not None # callers must check key = (width, height) if self._renderer_model is not self._world._model: self._renderers.clear() @@ -33,6 +43,7 @@ def _get_renderer(self, width: int, height: int): def _get_sim_observation(self, robot_name: str, cam_name: str | None = None) -> dict[str, Any]: """Get observation from sim (same format as real robot).""" mj = _ensure_mujoco() + assert self._world is not None # callers must check model, data = self._world._model, self._world._data robot = self._world.robots[robot_name] @@ -74,9 +85,10 @@ def _get_sim_observation(self, robot_name: str, cam_name: str | None = None) -> return obs - def _apply_sim_action(self, robot_name: str, action_dict: dict[str, Any], n_substeps: int = 1): + def _apply_sim_action(self, robot_name: str, action_dict: dict[str, Any], n_substeps: int = 1) -> None: """Apply action dict to sim (same interface as robot.send_action).""" mj = _ensure_mujoco() + assert self._world is not None # callers must check model, data = self._world._model, self._world._data for key, value in action_dict.items(): @@ -91,7 +103,9 @@ def _apply_sim_action(self, robot_name: str, action_dict: dict[str, Any], n_subs for _ in range(max(1, n_substeps)): mj.mj_step(model, data) + assert self._world is not None self._world.sim_time = data.time + assert self._world is not None # callers must check self._world.step_count += n_substeps if hasattr(self, "_viewer_handle") and self._viewer_handle is not None: diff --git a/strands_robots/simulation/mujoco/scene_ops.py b/strands_robots/simulation/mujoco/scene_ops.py index 9352f90..34e553e 100644 --- a/strands_robots/simulation/mujoco/scene_ops.py +++ b/strands_robots/simulation/mujoco/scene_ops.py @@ -84,8 +84,8 @@ def _reload_scene_from_xml(world: SimWorld, scene_path: str) -> bool: def _get_robot_base_dir(world: SimWorld) -> str | None: """Get the directory of the original robot model file.""" - if world._robot_base_xml: - return os.path.dirname(os.path.abspath(world._robot_base_xml)) + if world._backend_state.get("robot_base_xml", ""): + return os.path.dirname(os.path.abspath(world._backend_state.get("robot_base_xml", ""))) return None diff --git a/strands_robots/simulation/mujoco/simulation.py b/strands_robots/simulation/mujoco/simulation.py index ce46e48..3296cf4 100644 --- a/strands_robots/simulation/mujoco/simulation.py +++ b/strands_robots/simulation/mujoco/simulation.py @@ -84,14 +84,6 @@ def __init__( logger.info("🎮 Simulation tool '%s' initialized", tool_name) - try: - from strands_robots.zenoh_mesh import init_mesh - - self.mesh = init_mesh(self, peer_id=peer_id, peer_type="sim", mesh=mesh) - except Exception as e: - logger.debug("Mesh init skipped: %s", e) - self.mesh = None - # --- Public Properties --- @property @@ -225,7 +217,7 @@ def load_scene(self, scene_path: str) -> dict[str, Any]: def _compile_world(self): mj = _ensure_mujoco() xml = MJCFBuilder.build_objects_only(self._world) - self._world._xml = xml + self._world._backend_state["xml"] = xml self._world._model = mj.MjModel.from_xml_string(xml) self._world._data = mj.MjData(self._world._model) self._world.status = SimStatus.IDLE @@ -384,7 +376,7 @@ def add_robot( self._world._model = model self._world._data = data - self._world._robot_base_xml = resolved_path + self._world._backend_state["robot_base_xml"] = resolved_path self._world.robots[name] = robot for _ in range(100): @@ -668,8 +660,8 @@ def get_state(self) -> dict[str, Any]: lines.append( f"🦴 Bodies: {self._world._model.nbody} | 🔩 Joints: {self._world._model.njnt} | ⚡ Actuators: {self._world._model.nu}" ) - if self._world._recording: - lines.append(f"🔴 Recording: {len(self._world._trajectory)} steps") + if self._world._backend_state.get("recording", False): + lines.append(f"🔴 Recording: {len(self._world._backend_state["trajectory"])} steps") return {"status": "success", "content": [{"text": "\n".join(lines)}]} def destroy(self) -> dict[str, Any]: From 4ca9173d317f9e4ec059088824ce02e126316249 Mon Sep 17 00:00:00 2001 From: cagataycali Date: Mon, 6 Apr 2026 18:52:59 -0400 Subject: [PATCH 12/22] feat(sim): use require_optional for imageio in policy_runner - Replace bare with for consistent optional dependency handling per project conventions - Add imageio and imageio-ffmpeg to sim-mujoco extras in pyproject.toml - Add type: ignore comment for dynamic imageio writer attribute --- pyproject.toml | 2 ++ strands_robots/simulation/mujoco/policy_runner.py | 10 ++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 74b455d..7e6dcdb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,8 @@ sim = [ ] sim-mujoco = [ "mujoco>=3.0.0,<4.0.0", + "imageio>=2.28.0,<3.0.0", + "imageio-ffmpeg>=0.4.0,<1.0.0", ] all = [ "strands-robots[groot-service]", diff --git a/strands_robots/simulation/mujoco/policy_runner.py b/strands_robots/simulation/mujoco/policy_runner.py index cb0f78a..34671ea 100644 --- a/strands_robots/simulation/mujoco/policy_runner.py +++ b/strands_robots/simulation/mujoco/policy_runner.py @@ -10,6 +10,7 @@ from strands_robots._async_utils import _resolve_coroutine from strands_robots.simulation.models import TrajectoryStep from strands_robots.simulation.mujoco.backend import _ensure_mujoco +from strands_robots.utils import require_optional logger = logging.getLogger(__name__) @@ -72,10 +73,15 @@ def run_policy( frame_count = 0 cam_id = -1 if record_video: - import imageio + imageio = require_optional( + "imageio", + pip_install="imageio imageio-ffmpeg", + extra="sim-mujoco", + purpose="video recording", + ) os.makedirs(os.path.dirname(os.path.abspath(record_video)), exist_ok=True) - writer = imageio.get_writer(record_video, fps=video_fps, quality=8, macro_block_size=1) + writer = imageio.get_writer(record_video, fps=video_fps, quality=8, macro_block_size=1) # type: ignore[attr-defined] if video_camera: cam_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_CAMERA, video_camera) elif model.ncam > 0: From 70648d108db0b6a584d34fe16ac0a0981f721200 Mon Sep 17 00:00:00 2001 From: cagataycali Date: Sun, 12 Apr 2026 16:45:01 +0000 Subject: [PATCH 13/22] =?UTF-8?q?fix:=20address=208=20review=20threads=20?= =?UTF-8?q?=E2=80=94=20deps,=20exports,=20init,=20headless,=20coupling,=20?= =?UTF-8?q?tests,=20XML=20parsing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses all 8 unresolved review threads from @awsarron (Apr 10): 1. pyproject.toml: Remove empty [sim] extra, move robot_descriptions into [sim-mujoco]. Update extra= reference in backend.py. (yinsong1986 thread) 2. pyproject.toml: Keep sim-mujoco naming (not just mujoco) for consistency with future sim-isaac, sim-pybullet extras. (awsarron nit — reply only) 3. mujoco/__init__.py: Stop exporting private functions (_configure_gl_backend, _ensure_mujoco, _is_headless). Internal callers already import from backend directly. (awsarron thread) 4. simulation.py: Centralize _ensure_mujoco() to __init__ — fail fast at construction time. Store as self._mj, use throughout Simulation methods. Mixins retain their own _ensure_mujoco() calls since they may be used independently. (awsarron thread) 5. backend.py: Add docstring explaining why _is_headless() is Linux-only — Windows uses WGL, macOS uses CGL, both support offscreen natively. (awsarron thread) 6. policy_runner.py: Replace duplicated private function TYPE_CHECKING stubs with a shared SimulationProtocol in new types.py module. Eliminates coupling via signature duplication. (awsarron thread) 7. test_mujoco_e2e.py: Add TestToolSpecActionCoverage — iterates every action enum in tool_spec.json and asserts hasattr(Simulation, method) via the alias map. Catches drift between spec and implementation. (awsarron thread) 8. scene_ops.py: Standardize on ElementTree for all XML manipulation. Converted inject_object_into_scene, inject_camera_into_scene, and _patch_xml_paths from regex/string.replace to ET. Kept regex fallback in _patch_xml_paths for malformed fragments. (awsarron thread) --- pyproject.toml | 5 +- strands_robots/simulation/mujoco/__init__.py | 9 -- strands_robots/simulation/mujoco/backend.py | 7 +- .../simulation/mujoco/policy_runner.py | 22 ++-- strands_robots/simulation/mujoco/scene_ops.py | 107 ++++++++++++------ .../simulation/mujoco/simulation.py | 21 ++-- strands_robots/simulation/mujoco/types.py | 36 ++++++ tests/test_mujoco_e2e.py | 47 ++++++++ 8 files changed, 183 insertions(+), 71 deletions(-) create mode 100644 strands_robots/simulation/mujoco/types.py diff --git a/pyproject.toml b/pyproject.toml index 7e6dcdb..d9ae441 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,10 +48,8 @@ groot-service = [ lerobot = [ "lerobot>=0.5.0,<0.6.0", ] -sim = [ - "robot_descriptions>=1.11.0,<2.0.0", -] sim-mujoco = [ + "robot_descriptions>=1.11.0,<2.0.0", "mujoco>=3.0.0,<4.0.0", "imageio>=2.28.0,<3.0.0", "imageio-ffmpeg>=0.4.0,<1.0.0", @@ -59,7 +57,6 @@ sim-mujoco = [ all = [ "strands-robots[groot-service]", "strands-robots[lerobot]", - "strands-robots[sim]", "strands-robots[sim-mujoco]", ] dev = [ diff --git a/strands_robots/simulation/mujoco/__init__.py b/strands_robots/simulation/mujoco/__init__.py index 869040a..03c6a03 100644 --- a/strands_robots/simulation/mujoco/__init__.py +++ b/strands_robots/simulation/mujoco/__init__.py @@ -18,17 +18,8 @@ from strands_robots.simulation import Simulation # → MuJoCoSimulation """ -from strands_robots.simulation.mujoco.backend import ( - _configure_gl_backend, - _ensure_mujoco, - _is_headless, -) - __all__ = [ "MuJoCoSimulation", - "_configure_gl_backend", - "_ensure_mujoco", - "_is_headless", ] diff --git a/strands_robots/simulation/mujoco/backend.py b/strands_robots/simulation/mujoco/backend.py index 38f97c2..9c0873d 100644 --- a/strands_robots/simulation/mujoco/backend.py +++ b/strands_robots/simulation/mujoco/backend.py @@ -17,6 +17,11 @@ def _is_headless() -> bool: Returns True on Linux when no DISPLAY or WAYLAND_DISPLAY is set, which means GLFW-based rendering will fail. + + Windows and macOS are always False because MuJoCo uses native + windowing backends (WGL on Windows, CGL on macOS) that support + offscreen rendering without X11/Wayland. The EGL/OSMesa fallback + is Linux-specific. """ if sys.platform != "linux": return False @@ -88,7 +93,7 @@ def _ensure_mujoco() -> "Any": _mujoco = require_optional( "mujoco", pip_install="mujoco", - extra="sim", + extra="sim-mujoco", purpose="MuJoCo simulation", ) if _mujoco_viewer is None and not _is_headless(): diff --git a/strands_robots/simulation/mujoco/policy_runner.py b/strands_robots/simulation/mujoco/policy_runner.py index 34671ea..a0a67d8 100644 --- a/strands_robots/simulation/mujoco/policy_runner.py +++ b/strands_robots/simulation/mujoco/policy_runner.py @@ -16,23 +16,17 @@ class PolicyRunnerMixin: - if TYPE_CHECKING: - import threading - from concurrent.futures import Future, ThreadPoolExecutor - - from strands_robots.simulation.models import SimWorld + """Policy execution for Simulation. - _world: SimWorld | None - _lock: threading.Lock - _executor: ThreadPoolExecutor - _policy_threads: dict[str, Future[Any]] + Expects the composite Simulation class to satisfy SimulationProtocol + (provides self._world, self._executor, self._policy_threads, and + cross-mixin methods like _get_sim_observation / _apply_sim_action). + """ - # Methods from RenderingMixin — declared here so mypy can verify calls - def _get_renderer(self, width: int, height: int) -> Any: ... - def _get_sim_observation(self, robot_name: str, cam_name: str | None = None) -> dict[str, Any]: ... - def _apply_sim_action(self, robot_name: str, action_dict: dict[str, Any], n_substeps: int = 1) -> None: ... + if TYPE_CHECKING: + from strands_robots.simulation.mujoco.types import SimulationProtocol - """Policy execution for Simulation. Expects self._world, self._executor, self._policy_threads.""" + _: SimulationProtocol # noqa: F841 — declares the expected interface def run_policy( self, diff --git a/strands_robots/simulation/mujoco/scene_ops.py b/strands_robots/simulation/mujoco/scene_ops.py index 34e553e..fa661b7 100644 --- a/strands_robots/simulation/mujoco/scene_ops.py +++ b/strands_robots/simulation/mujoco/scene_ops.py @@ -19,26 +19,44 @@ def _patch_xml_paths(xml_content: str, robot_base_dir: str) -> str: - """Patch meshdir/texturedir in XML to absolute paths for tmpdir loading.""" - meshdir_match = re.search(r'meshdir="([^"]*)"', xml_content) - existing_meshdir = meshdir_match.group(1) if meshdir_match else "" - abs_meshdir = os.path.normpath(os.path.join(robot_base_dir, existing_meshdir)) + """Patch meshdir/texturedir in XML to absolute paths for tmpdir loading. - texdir_match = re.search(r'texturedir="([^"]*)"', xml_content) - existing_texdir = texdir_match.group(1) if texdir_match else "" - abs_texdir = os.path.normpath(os.path.join(robot_base_dir, existing_texdir)) - - if meshdir_match: - xml_content = re.sub(r'meshdir="[^"]*"', f'meshdir="{abs_meshdir}"', xml_content) - elif " bool: @@ -107,7 +125,10 @@ def _save_and_patch_xml(world: SimWorld, tmpdir: str, filename: str) -> str: def inject_object_into_scene(world: SimWorld, obj: SimObject) -> bool: - """Inject object into a running simulation via XML round-trip.""" + """Inject object into a running simulation via XML round-trip. + + Uses ElementTree for XML manipulation (consistent with eject_body_from_scene). + """ _ensure_mujoco() if world._model is None: return False @@ -116,17 +137,25 @@ def inject_object_into_scene(world: SimWorld, obj: SimObject) -> bool: try: scene_path = _save_and_patch_xml(world, tmpdir, "scene_with_objects.xml") - with open(scene_path) as f: - xml_content = f.read() + tree = ET.parse(scene_path) + root = tree.getroot() - obj_xml = MJCFBuilder._object_xml(obj, indent=4) - xml_content = xml_content.replace("", f"{obj_xml}\n") + # Find and append the object element + worldbody = root.find("worldbody") + if worldbody is None: + logger.error("No found in scene XML") + return False + + obj_xml_str = MJCFBuilder._object_xml(obj, indent=4) + obj_elem = ET.fromstring(f"<_wrapper>{obj_xml_str}") + for child in obj_elem: + worldbody.append(child) # Remove keyframes — adding a freejoint changes qpos size - xml_content = re.sub(r".*?", "", xml_content, flags=re.DOTALL) + for keyframe_elem in root.findall("keyframe"): + root.remove(keyframe_elem) - with open(scene_path, "w") as f: - f.write(xml_content) + tree.write(scene_path, xml_declaration=True) return _reload_scene_from_xml(world, scene_path) except (ValueError, RuntimeError, OSError) as e: @@ -184,7 +213,10 @@ def eject_body_from_scene(world: SimWorld, body_name: str) -> bool: def inject_camera_into_scene(world: SimWorld, cam: SimCamera) -> bool: - """Inject a camera into a running simulation via XML round-trip.""" + """Inject a camera into a running simulation via XML round-trip. + + Uses ElementTree for XML manipulation (consistent with eject_body_from_scene). + """ _ensure_mujoco() if world._model is None: return False @@ -193,15 +225,22 @@ def inject_camera_into_scene(world: SimWorld, cam: SimCamera) -> bool: try: scene_path = _save_and_patch_xml(world, tmpdir, "scene_with_cameras.xml") - with open(scene_path) as f: - xml_content = f.read() + tree = ET.parse(scene_path) + root = tree.getroot() + + worldbody = root.find("worldbody") + if worldbody is None: + logger.error("No found in scene XML") + return False px, py, pz = cam.position - cam_xml = f' ' - xml_content = xml_content.replace("", f"{cam_xml}\n") + cam_elem = ET.SubElement(worldbody, "camera") + cam_elem.set("name", _sanitize_name(cam.name)) + cam_elem.set("pos", f"{px} {py} {pz}") + cam_elem.set("fovy", str(cam.fov)) + cam_elem.set("mode", "fixed") - with open(scene_path, "w") as f: - f.write(xml_content) + tree.write(scene_path, xml_declaration=True) return _reload_scene_from_xml(world, scene_path) except (ValueError, RuntimeError, OSError) as e: diff --git a/strands_robots/simulation/mujoco/simulation.py b/strands_robots/simulation/mujoco/simulation.py index 3296cf4..54d781c 100644 --- a/strands_robots/simulation/mujoco/simulation.py +++ b/strands_robots/simulation/mujoco/simulation.py @@ -82,6 +82,9 @@ def __init__( self._renderers: dict[tuple, Any] = {} self._renderer_model = None + # Fail fast: verify MuJoCo is importable at construction time + # so consumers catch missing-dependency errors immediately. + self._mj = _ensure_mujoco() logger.info("🎮 Simulation tool '%s' initialized", tool_name) # --- Public Properties --- @@ -136,7 +139,7 @@ def create_world( self, timestep: float | None = None, gravity: list[float] | None = None, ground_plane: bool = True ) -> dict[str, Any]: """Create a new simulation world.""" - _ensure_mujoco() + # mujoco verified at __init__ if self._world is not None and self._world._model is not None: return { @@ -187,7 +190,7 @@ def create_world( def load_scene(self, scene_path: str) -> dict[str, Any]: """Load a complete scene from MJCF XML or URDF file.""" - mj = _ensure_mujoco() + mj = self._mj if not os.path.exists(scene_path): return {"status": "error", "content": [{"text": f"❌ Scene file not found: {scene_path}"}]} @@ -215,7 +218,7 @@ def load_scene(self, scene_path: str) -> dict[str, Any]: return {"status": "error", "content": [{"text": f"❌ Failed to load scene: {e}"}]} def _compile_world(self): - mj = _ensure_mujoco() + mj = self._mj xml = MJCFBuilder.build_objects_only(self._world) self._world._backend_state["xml"] = xml self._world._model = mj.MjModel.from_xml_string(xml) @@ -327,7 +330,7 @@ def add_robot( if not os.path.exists(resolved_path): return {"status": "error", "content": [{"text": f"❌ File not found: {resolved_path}"}]} - mj = _ensure_mujoco() + mj = self._mj robot = SimRobot( name=name, @@ -438,7 +441,7 @@ def get_robot_state(self, robot_name: str) -> dict[str, Any]: if robot_name not in self._world.robots: return {"status": "error", "content": [{"text": f"❌ Robot '{robot_name}' not found."}]} - mj = _ensure_mujoco() + mj = self._mj robot = self._world.robots[robot_name] model, data = self._world._model, self._world._data @@ -548,7 +551,7 @@ def move_object( if name not in self._world.objects: return {"status": "error", "content": [{"text": f"❌ '{name}' not found."}]} - mj = _ensure_mujoco() + mj = self._mj model, data = self._world._model, self._world._data jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, f"{name}_joint") @@ -623,7 +626,7 @@ def remove_camera(self, name: str) -> dict[str, Any]: def step(self, n_steps: int = 1) -> dict[str, Any]: if self._world is None or self._world._data is None: return {"status": "error", "content": [{"text": "❌ No simulation."}]} - mj = _ensure_mujoco() + mj = self._mj for _ in range(n_steps): mj.mj_step(self._world._model, self._world._data) self._world.sim_time = self._world._data.time @@ -638,7 +641,7 @@ def step(self, n_steps: int = 1) -> dict[str, Any]: def reset(self) -> dict[str, Any]: if self._world is None or self._world._model is None: return {"status": "error", "content": [{"text": "❌ No world."}]} - mj = _ensure_mujoco() + mj = self._mj mj.mj_resetData(self._world._model, self._world._data) self._world.sim_time = 0.0 self._world.step_count = 0 @@ -737,7 +740,7 @@ def get_features(self) -> dict[str, Any]: if self._world is None or self._world._model is None: return {"status": "error", "content": [{"text": "❌ No simulation."}]} - mj = _ensure_mujoco() + mj = self._mj model = self._world._model joint_names = [mj.mj_id2name(model, mj.mjtObj.mjOBJ_JOINT, i) for i in range(model.njnt)] diff --git a/strands_robots/simulation/mujoco/types.py b/strands_robots/simulation/mujoco/types.py new file mode 100644 index 0000000..f8d1a59 --- /dev/null +++ b/strands_robots/simulation/mujoco/types.py @@ -0,0 +1,36 @@ +"""Shared type declarations for MuJoCo simulation mixins. + +Defines the SimulationProtocol that all mixins can reference instead of +duplicating TYPE_CHECKING stubs for cross-mixin method signatures. +""" + +from __future__ import annotations + +import threading +from concurrent.futures import Future, ThreadPoolExecutor +from typing import Any, Protocol, runtime_checkable + +from strands_robots.simulation.models import SimWorld + + +@runtime_checkable +class SimulationProtocol(Protocol): + """Protocol describing the shared state and methods available across all mixins. + + Each mixin operates on a Simulation instance that provides this interface. + Using a Protocol avoids duplicating private method stubs in TYPE_CHECKING blocks. + """ + + _world: SimWorld | None + _lock: threading.Lock + _executor: ThreadPoolExecutor + _policy_threads: dict[str, Future[Any]] + _mj: Any # The lazily-imported mujoco module + _renderer_model: Any + _renderers: dict[tuple[int, int], Any] + default_width: int + default_height: int + + def _get_renderer(self, width: int, height: int) -> Any: ... + def _get_sim_observation(self, robot_name: str, cam_name: str | None = None) -> dict[str, Any]: ... + def _apply_sim_action(self, robot_name: str, action_dict: dict[str, Any], n_substeps: int = 1) -> None: ... diff --git a/tests/test_mujoco_e2e.py b/tests/test_mujoco_e2e.py index c6d2d1e..4dd8fbc 100644 --- a/tests/test_mujoco_e2e.py +++ b/tests/test_mujoco_e2e.py @@ -267,3 +267,50 @@ def test_color_randomization(self, sim_env): if __name__ == "__main__": pytest.main([__file__, "-v"]) + + +class TestToolSpecActionCoverage: + """Verify every action enum in tool_spec.json maps to a real method on Simulation.""" + + def test_all_actions_have_methods(self): + """Every action in tool_spec.json must resolve to a method on Simulation.""" + import json + from pathlib import Path + + from strands_robots.simulation.mujoco.simulation import Simulation + + spec_path = Path(__file__).parent.parent / "strands_robots" / "simulation" / "mujoco" / "tool_spec.json" + with open(spec_path) as f: + spec = json.load(f) + + actions = spec["properties"]["action"]["enum"] + assert len(actions) > 0, "tool_spec.json should have at least one action" + + # Aliases used by _dispatch_action + aliases = { + "list_urdfs": "list_urdfs_action", + "register_urdf": "register_urdf_action", + "stop_policy": "_stop_policy", + } + + missing = [] + for action in actions: + method_name = aliases.get(action, action) + if not hasattr(Simulation, method_name): + missing.append(f"{action} (looked for method '{method_name}')") + + assert not missing, "tool_spec.json actions with no matching Simulation method:\n" + "\n".join( + f" - {m}" for m in missing + ) + + def test_action_enum_is_not_empty(self): + """Sanity: tool_spec.json action enum is populated.""" + import json + from pathlib import Path + + spec_path = Path(__file__).parent.parent / "strands_robots" / "simulation" / "mujoco" / "tool_spec.json" + with open(spec_path) as f: + spec = json.load(f) + + actions = spec["properties"]["action"]["enum"] + assert len(actions) >= 30, f"Expected ≥30 actions, got {len(actions)}" From 19ea1ddb2a3f83913bb731e10baec415745d21d9 Mon Sep 17 00:00:00 2001 From: strands-agent <217235299+strands-agent@users.noreply.github.com> Date: Sun, 12 Apr 2026 20:30:34 +0000 Subject: [PATCH 14/22] fix: replace Protocol annotation with direct TYPE_CHECKING stubs in PolicyRunnerMixin MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `_: SimulationProtocol` pattern declares a class variable named `_` but does NOT propagate Protocol member declarations to mypy's understanding of the class. This caused 34 attr-defined errors in policy_runner.py. Fix: Replace with direct attribute declarations under TYPE_CHECKING, matching the pattern used by PhysicsMixin, RenderingMixin, RecordingMixin, and RandomizationMixin. The SimulationProtocol in types.py is preserved for runtime checks and documentation — it's the TYPE_CHECKING usage pattern that was incorrect. Lint: 0 errors (ruff check + ruff format + mypy) Tests: 323 passed, 2 skipped, 0 failures --- .../simulation/mujoco/policy_runner.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/strands_robots/simulation/mujoco/policy_runner.py b/strands_robots/simulation/mujoco/policy_runner.py index a0a67d8..71f2503 100644 --- a/strands_robots/simulation/mujoco/policy_runner.py +++ b/strands_robots/simulation/mujoco/policy_runner.py @@ -1,5 +1,3 @@ -"""Policy execution mixin — run_policy, start_policy, record_video, replay_episode, eval_policy.""" - import logging import os import time @@ -18,15 +16,28 @@ class PolicyRunnerMixin: """Policy execution for Simulation. - Expects the composite Simulation class to satisfy SimulationProtocol - (provides self._world, self._executor, self._policy_threads, and - cross-mixin methods like _get_sim_observation / _apply_sim_action). + Expects the composite Simulation class to provide: + - self._world (SimWorld | None) + - self._lock (threading.Lock) + - self._executor (ThreadPoolExecutor) + - self._policy_threads (dict[str, Future]) + - self._get_sim_observation(), self._apply_sim_action(), self._get_renderer() """ if TYPE_CHECKING: - from strands_robots.simulation.mujoco.types import SimulationProtocol + import threading + from concurrent.futures import Future, ThreadPoolExecutor + + from strands_robots.simulation.models import SimWorld + + _world: SimWorld | None + _lock: threading.Lock + _executor: ThreadPoolExecutor + _policy_threads: dict[str, Future[Any]] - _: SimulationProtocol # noqa: F841 — declares the expected interface + def _get_renderer(self, width: int, height: int) -> Any: ... + def _get_sim_observation(self, robot_name: str, cam_name: str | None = None) -> dict[str, Any]: ... + def _apply_sim_action(self, robot_name: str, action_dict: dict[str, Any], n_substeps: int = 1) -> None: ... def run_policy( self, From 62fa5905bff377ff09b25b503fba414506a7690f Mon Sep 17 00:00:00 2001 From: cagataycali Date: Wed, 22 Apr 2026 15:15:51 -0400 Subject: [PATCH 15/22] refactor(mujoco): migrate SimWorld private fields to _backend_state MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Post-#84 merge: SimWorld no longer carries MuJoCo-specific private fields (_xml, _robot_base_xml, _recording, _trajectory, _dataset_recorder, _tmpdir, _push_to_hub). These are MuJoCo backend implementation details and now live in world._backend_state, as the SimWorld docstring requests (prefer _backend_state over new fields). Migrated call sites: - mjcf_builder.py: tmpdir - policy_runner.py: recording, trajectory, dataset_recorder - recording.py: recording, trajectory, dataset_recorder, push_to_hub - scene_ops.py: robot_base_xml - simulation.py: xml, robot_base_xml, recording, trajectory Reads use dict[] where preceded by a guard that guarantees initialization (e.g. start_recording() sets before policy_runner reads), and .get() with sensible defaults where the key may be unset. Tests: 392 passed, 2 skipped (5 pre-existing test_path_validation failures are on main too — unrelated). Lint: ruff + mypy clean on 75 source files. --- strands_robots/simulation/mujoco/recording.py | 4 ++-- strands_robots/simulation/mujoco/simulation.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/strands_robots/simulation/mujoco/recording.py b/strands_robots/simulation/mujoco/recording.py index 5fede40..849d0df 100644 --- a/strands_robots/simulation/mujoco/recording.py +++ b/strands_robots/simulation/mujoco/recording.py @@ -54,7 +54,7 @@ def start_recording( self._world._backend_state["recording"] = True self._world._backend_state["trajectory"] = [] - self._world._push_to_hub = push_to_hub + self._world._backend_state["push_to_hub"] = push_to_hub try: if overwrite: @@ -123,7 +123,7 @@ def stop_recording(self, output_path: str | None = None) -> dict[str, Any]: recorder.save_episode() push_result = None - if getattr(self._world, "_push_to_hub", False): + if self._world._backend_state.get("push_to_hub", False): push_result = recorder.push_to_hub(tags=["strands-robots", "sim"]) repo_id = recorder.repo_id diff --git a/strands_robots/simulation/mujoco/simulation.py b/strands_robots/simulation/mujoco/simulation.py index 54d781c..e12013f 100644 --- a/strands_robots/simulation/mujoco/simulation.py +++ b/strands_robots/simulation/mujoco/simulation.py @@ -664,7 +664,7 @@ def get_state(self) -> dict[str, Any]: f"🦴 Bodies: {self._world._model.nbody} | 🔩 Joints: {self._world._model.njnt} | ⚡ Actuators: {self._world._model.nu}" ) if self._world._backend_state.get("recording", False): - lines.append(f"🔴 Recording: {len(self._world._backend_state["trajectory"])} steps") + lines.append(f"🔴 Recording: {len(self._world._backend_state['trajectory'])} steps") return {"status": "success", "content": [{"text": "\n".join(lines)}]} def destroy(self) -> dict[str, Any]: From abc65e67fa6d8fb017c5ccf6575fd9675fe70e1a Mon Sep 17 00:00:00 2001 From: cagataycali Date: Tue, 31 Mar 2026 21:09:05 -0400 Subject: [PATCH 16/22] feat: Robot() factory + top-level lazy imports MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add unified Robot() factory function that auto-detects sim vs real: Robot('so100') → auto-detect → MuJoCo sim (default) Robot('so100', mode='sim') → Simulation AgentTool Robot('so100', mode='real') → HardwareRobot AgentTool Auto-detect priority: 1. STRANDS_ROBOT_MODE env var (explicit override) 2. USB probe for servo controllers (Feetech/Dynamixel) 3. Default to sim (safest — never accidentally send to hardware) Also adds list_robots(mode='all'|'sim'|'real'|'both') for discovery. Updates __init__.py with lazy imports: - Robot (→ factory), list_robots - Simulation, SimWorld, SimRobot, SimObject, SimCamera - Auto-configures MuJoCo GL backend for headless environments Tests: 22 new factory tests covering name resolution, aliases, list_robots filtering, auto-detect mode, Robot() factory, imports. --- strands_robots/__init__.py | 45 +++++++-- strands_robots/factory.py | 199 +++++++++++++++++++++++++++++++++++++ tests/test_factory.py | 148 +++++++++++++++++++++++++++ 3 files changed, 385 insertions(+), 7 deletions(-) create mode 100644 strands_robots/factory.py create mode 100644 tests/test_factory.py diff --git a/strands_robots/__init__.py b/strands_robots/__init__.py index 8ee9c41..6c4943f 100644 --- a/strands_robots/__init__.py +++ b/strands_robots/__init__.py @@ -11,11 +11,12 @@ - Clean separation between robot control and policy inference - Direct policy injection for maximum flexibility - Multi-camera support with rich configuration options +- MuJoCo simulation backend (no GPU required) Lazy Loading: - Heavy imports (Robot, tools, Gr00tPolicy) are deferred until first access. - Heavy imports are deferred so ``import strands_robots`` stays fast when lerobot/torch - are installed but not yet needed. + Heavy imports (Robot, tools, Gr00tPolicy, Simulation) are deferred until + first access. Heavy imports are deferred so ``import strands_robots`` stays + fast when lerobot/torch/mujoco are installed but not yet needed. Light-weight symbols (Policy, MockPolicy, create_policy) are available immediately since they don't pull in torch/lerobot. @@ -26,7 +27,7 @@ from typing import Any # ------------------------------------------------------------------ -# Light-weight imports — no torch / lerobot dependency +# Light-weight imports — no torch / lerobot / mujoco dependency # ------------------------------------------------------------------ from strands_robots.policies import MockPolicy, Policy, create_policy # noqa: F401 @@ -35,8 +36,21 @@ # ------------------------------------------------------------------ # Maps public name -> (module_path, attribute_name) _LAZY_IMPORTS: dict[str, tuple[str, str]] = { - "Robot": ("strands_robots.robot", "Robot"), + # Hardware robot + "Robot": ("strands_robots.factory", "Robot"), + "list_robots": ("strands_robots.factory", "list_robots"), + # Policies "Gr00tPolicy": ("strands_robots.policies.groot", "Gr00tPolicy"), + # Simulation (MuJoCo) + "Simulation": ("strands_robots.simulation", "Simulation"), + "create_simulation": ("strands_robots.simulation.factory", "create_simulation"), + "list_backends": ("strands_robots.simulation.factory", "list_backends"), + "register_backend": ("strands_robots.simulation.factory", "register_backend"), + "SimWorld": ("strands_robots.simulation", "SimWorld"), + "SimRobot": ("strands_robots.simulation", "SimRobot"), + "SimObject": ("strands_robots.simulation", "SimObject"), + "SimCamera": ("strands_robots.simulation", "SimCamera"), + # Tools "gr00t_inference": ("strands_robots.tools.gr00t_inference", "gr00t_inference"), "lerobot_calibrate": ("strands_robots.tools.lerobot_calibrate", "lerobot_calibrate"), "lerobot_camera": ("strands_robots.tools.lerobot_camera", "lerobot_camera"), @@ -53,6 +67,11 @@ # Lazy-loaded "Robot", "Gr00tPolicy", + "Simulation", + "SimWorld", + "SimRobot", + "SimObject", + "SimCamera", "gr00t_inference", "lerobot_camera", "lerobot_teleoperate", @@ -62,12 +81,24 @@ ] +# Auto-configure MuJoCo GL backend for headless environments BEFORE any +# module imports mujoco at the top level. MuJoCo locks the OpenGL backend +# at import time, so MUJOCO_GL must be set first. +try: + from strands_robots.simulation.mujoco.backend import _configure_gl_backend + + _configure_gl_backend() +except (ImportError, AttributeError, OSError): + pass + + def __getattr__(name: str) -> Any: # noqa: N807 """Lazy-load heavy modules on first attribute access. - This avoids importing torch, lerobot, numpy, pyserial, etc. at + This avoids importing torch, lerobot, numpy, mujoco, pyserial, etc. at ``import strands_robots`` time. The first access to e.g. - ``strands_robots.Robot`` triggers the real import. + ``strands_robots.Robot`` or ``strands_robots.Simulation`` triggers the + real import. """ if name in _LAZY_IMPORTS: module_path, attr_name = _LAZY_IMPORTS[name] diff --git a/strands_robots/factory.py b/strands_robots/factory.py new file mode 100644 index 0000000..778bd03 --- /dev/null +++ b/strands_robots/factory.py @@ -0,0 +1,199 @@ +"""Unified Robot Factory — convenience layer over Simulation and HardwareRobot. + +Provides: + - ``Robot("so100")`` → auto-detects sim/real, returns the right backend + - ``list_robots()`` → what's available + +Examples:: + + # Auto-detect (sim if no hardware found) + sim = Robot("so100") + + # Explicit sim + sim = Robot("so100", mode="sim") + + # With custom URDF/MJCF path + sim = Robot("my_arm", mode="sim", urdf_path="/path/to/robot.xml") + + # Real hardware + hw = Robot("so100", mode="real", cameras={...}) + +Future (not yet implemented):: + + sim = Robot("unitree_go2", backend="isaac", num_envs=4096) + sim = Robot("so100", backend="newton", num_envs=4096) +""" + +import logging +import os +from typing import Any + +from strands_robots.registry import ( + get_hardware_type, + has_hardware, + resolve_name, +) +from strands_robots.registry import list_robots as _registry_list_robots + +logger = logging.getLogger(__name__) + + +def _auto_detect_mode(canonical: str) -> str: + """Auto-detect sim vs real mode. + + Priority: + 1. ``STRANDS_ROBOT_MODE`` env var (explicit override) + 2. Robot-specific USB detection (Feetech/Dynamixel servo controllers) + 3. Default to sim (safest — never accidentally send commands to hardware) + """ + env_mode = os.getenv("STRANDS_ROBOT_MODE", "").lower() + if env_mode in ("sim", "real"): + return env_mode + + # Only probe USB if the robot actually has hardware support + if has_hardware(canonical): + try: + import serial.tools.list_ports + + ports = list(serial.tools.list_ports.comports()) + servo_keywords = ["feetech", "dynamixel", "sts3215", "xl430", "xl330"] + exclude = ["bluetooth", "internal", "debug", "apple", "modem"] + robot_ports = [ + p + for p in ports + if any(kw in (p.description + getattr(p, "manufacturer", "")).lower() for kw in servo_keywords) + and not any(s in p.description.lower() for s in exclude) + ] + if robot_ports: + logger.info( + "Auto-detected robot hardware: %s", + [p.device for p in robot_ports], + ) + return "real" + except (ImportError, Exception): + pass + + return "sim" + + +def Robot( + name: str, + mode: str = "auto", + backend: str = "mujoco", + urdf_path: str = None, + cameras: dict[str, dict[str, Any]] | None = None, + position: list[float] | None = None, + **kwargs, +): + """Create a robot — returns a Simulation or HardwareRobot instance. + + This is a convenience factory, NOT a wrapper class. You get the real + backend instance back — with full access to all its methods. + + Args: + name: Robot name ("so100", "aloha", "unitree_g1", "panda", ...) + Accepts any alias defined in ``registry/robots.json``. + mode: "auto" (detect hardware), "sim", or "real". + backend: Simulation backend — currently only "mujoco" (CPU). + Future: "isaac" (GPU), "newton" (GPU). + urdf_path: Explicit path to URDF/MJCF file. If not provided, + resolved via the model registry (asset manager or + STRANDS_URDF_DIR search paths). + cameras: Camera config for real hardware. Example:: + + {"wrist": {"type": "opencv", "index_or_path": "/dev/video0", "fps": 30}} + + position: Robot position in sim world [x, y, z]. + **kwargs: Forwarded to the underlying backend constructor. + + Returns: + ``Simulation`` (MuJoCo sim) or ``Robot`` (real hardware). + + Raises: + RuntimeError: If the sim world or robot fails to initialize. + NotImplementedError: If an unimplemented backend is requested. + + Examples:: + + # MuJoCo sim (auto — no hardware detected) + sim = Robot("so100") + sim = Robot("so100", mode="sim") + + # Explicit MJCF model path + sim = Robot("my_arm", mode="sim", urdf_path="path/to/robot.xml") + + # Real hardware + hw = Robot("so100", mode="real", cameras={...}) + + # The 5-line promise + from strands_robots import Robot + from strands import Agent + robot = Robot("so100") + agent = Agent(tools=[robot]) + agent("Pick up the red cube") + """ + canonical = resolve_name(name) + + if mode == "auto": + mode = _auto_detect_mode(canonical) + + # ── Simulation ── + if mode == "sim": + if backend != "mujoco": + raise NotImplementedError( + f"Backend {backend!r} is not yet implemented. " + f"Currently supported: 'mujoco'. " + f"Isaac and Newton backends are on the roadmap." + ) + + from strands_robots.simulation import Simulation + + sim = Simulation( + tool_name=f"{canonical}_sim", + **kwargs, + ) + sim._dispatch_action("create_world", {}) + + # Build add_robot params — pass urdf_path if the user provided one + add_robot_params: dict[str, Any] = { + "robot_name": canonical, + "data_config": canonical, + "position": position or [0.0, 0.0, 0.0], + } + if urdf_path: + add_robot_params["urdf_path"] = urdf_path + + result = sim._dispatch_action("add_robot", add_robot_params) + if result.get("status") == "error": + # Extract human-readable message from content + content = result.get("content", []) + msg = content[0].get("text", str(result)) if content else str(result) + raise RuntimeError(f"Failed to create sim robot '{canonical}': {msg}") + return sim + + # ── Real hardware ── + else: + from strands_robots.robot import Robot as HardwareRobot + + real_type = get_hardware_type(canonical) or canonical + return HardwareRobot( + tool_name=canonical, + robot=real_type, + cameras=cameras, + **kwargs, + ) + + +def list_robots(mode: str = "all") -> list[dict[str, Any]]: + """List available robots. + + Args: + mode: "all", "sim", "real", or "both" (has both sim and real). + + Returns: + List of dicts with name, description, has_sim, has_real. + """ + return _registry_list_robots(mode) + + +__all__ = ["Robot", "list_robots"] diff --git a/tests/test_factory.py b/tests/test_factory.py new file mode 100644 index 0000000..068b737 --- /dev/null +++ b/tests/test_factory.py @@ -0,0 +1,148 @@ +"""Tests for strands_robots.factory — Robot(), list_robots().""" + +import os + +import pytest + +from strands_robots.factory import Robot, _auto_detect_mode, list_robots +from strands_robots.registry import ( + get_robot, + list_aliases, + resolve_name, +) +from strands_robots.registry import list_robots as registry_list_robots + + +class TestResolveNames: + def test_canonical(self): + assert resolve_name("so100") == "so100" + + def test_alias(self): + assert resolve_name("franka") == "panda" + assert resolve_name("g1") == "unitree_g1" + assert resolve_name("h1") == "unitree_h1" + + def test_case_insensitive(self): + assert resolve_name("SO100") == "so100" + assert resolve_name("Panda") == "panda" + + def test_hyphen_to_underscore(self): + assert resolve_name("reachy-mini") == "reachy_mini" + + +class TestListRobots: + def test_list_all(self): + robots = list_robots("all") + assert len(robots) > 0 + names = [r["name"] for r in robots] + assert "so100" in names + assert "panda" in names + + def test_list_sim(self): + robots = list_robots("sim") + for r in robots: + assert r["has_sim"] is True + + def test_list_real(self): + robots = list_robots("real") + for r in robots: + assert r["has_real"] is True + + def test_list_both(self): + robots = list_robots("both") + for r in robots: + assert r["has_sim"] is True + assert r["has_real"] is True + + def test_robot_has_fields(self): + robots = list_robots() + for r in robots: + assert "name" in r + assert "description" in r + assert "has_sim" in r + assert "has_real" in r + + +class TestRobotRegistry: + def test_so100_exists(self): + info = get_robot("so100") + assert info is not None + assert "asset" in info + assert info["asset"]["dir"] == "trs_so_arm100" + + def test_all_aliases_point_to_valid_robots(self): + aliases = list_aliases() + for alias, canonical in aliases.items(): + info = get_robot(canonical) + assert info is not None, f"Alias '{alias}' points to unknown robot '{canonical}'" + + def test_robot_count(self): + """Ensure we have a reasonable number of robots.""" + robots = registry_list_robots() + assert len(robots) >= 30 + + def test_all_robots_have_description(self): + robots = registry_list_robots() + for r in robots: + assert "description" in r, f"Robot '{r['name']}' missing description" + assert len(r["description"]) > 0 + + +class TestAutoDetectMode: + def test_defaults_to_sim(self): + """No hardware plugged in → sim.""" + assert _auto_detect_mode("so100") == "sim" + + def test_env_override_real(self): + os.environ["STRANDS_ROBOT_MODE"] = "real" + try: + assert _auto_detect_mode("so100") == "real" + finally: + del os.environ["STRANDS_ROBOT_MODE"] + + def test_env_override_sim(self): + os.environ["STRANDS_ROBOT_MODE"] = "sim" + try: + assert _auto_detect_mode("so100") == "sim" + finally: + del os.environ["STRANDS_ROBOT_MODE"] + + def test_env_override_case_insensitive(self): + os.environ["STRANDS_ROBOT_MODE"] = "REAL" + try: + # .lower() normalizes to "real" — should match + mode = _auto_detect_mode("so100") + assert mode == "real" # .lower() normalizes REAL → real + finally: + del os.environ["STRANDS_ROBOT_MODE"] + + +class TestRobotFactory: + def test_robot_is_callable(self): + """Robot is a factory function, not a class.""" + import inspect + + assert callable(Robot) + assert not inspect.isclass(Robot) + + def test_unknown_backend_raises(self): + with pytest.raises(NotImplementedError, match="not yet implemented"): + Robot("so100", mode="sim", backend="isaac") + + def test_newton_not_implemented(self): + with pytest.raises(NotImplementedError, match="not yet implemented"): + Robot("so100", mode="sim", backend="newton") + + def test_sim_with_urdf_path(self): + """Robot() with explicit urdf_path should work (if file exists).""" + # We don't have a real URDF here, but verify the param is accepted + with pytest.raises(RuntimeError): + Robot("test_bot", mode="sim", urdf_path="/nonexistent/robot.xml") + + def test_import_from_top_level(self): + """Robot and list_robots importable from strands_robots.""" + from strands_robots import Robot as R + from strands_robots import list_robots as lr + + assert R is Robot + assert lr is list_robots From 1639079d1ea21ce86589bd6de242d56073f4baad Mon Sep 17 00:00:00 2001 From: cagataycali Date: Wed, 1 Apr 2026 15:13:21 -0400 Subject: [PATCH 17/22] =?UTF-8?q?fix:=20address=20review=20=E2=80=94=20res?= =?UTF-8?q?ource=20leak,=20exception=20narrowing,=20GL=20comment,=20test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - sim.destroy() on partial factory failure (prevents resource leak) - except (ImportError, OSError) instead of bare Exception for USB probing - Added comment explaining why GL backend config must run eagerly - Added happy-path MuJoCo test gated behind pytest.importorskip --- strands_robots/__init__.py | 8 ++++++++ strands_robots/factory.py | 4 ++-- tests/test_factory.py | 15 +++++++++++++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/strands_robots/__init__.py b/strands_robots/__init__.py index 6c4943f..41594a1 100644 --- a/strands_robots/__init__.py +++ b/strands_robots/__init__.py @@ -84,6 +84,14 @@ # Auto-configure MuJoCo GL backend for headless environments BEFORE any # module imports mujoco at the top level. MuJoCo locks the OpenGL backend # at import time, so MUJOCO_GL must be set first. +# +# WHY EAGER: This MUST run at module import time, not lazily, because: +# 1. MuJoCo reads MUJOCO_GL only on first `import mujoco` +# 2. Any downstream code doing `from strands_robots.simulation import ...` +# triggers mujoco import via the lazy-load chain +# 3. If we defer to first use, the env var would be set too late +# This is the canonical location — strands_robots/simulation/__init__.py +# intentionally does NOT duplicate this call. try: from strands_robots.simulation.mujoco.backend import _configure_gl_backend diff --git a/strands_robots/factory.py b/strands_robots/factory.py index 778bd03..cad1f4c 100644 --- a/strands_robots/factory.py +++ b/strands_robots/factory.py @@ -70,7 +70,7 @@ def _auto_detect_mode(canonical: str) -> str: [p.device for p in robot_ports], ) return "real" - except (ImportError, Exception): + except (ImportError, OSError): # USB probing may fail with OSError on permission/device issues pass return "sim" @@ -165,7 +165,7 @@ def Robot( result = sim._dispatch_action("add_robot", add_robot_params) if result.get("status") == "error": - # Extract human-readable message from content + sim.destroy() # Clean up partial initialization (executor, temp dir, MuJoCo world) content = result.get("content", []) msg = content[0].get("text", str(result)) if content else str(result) raise RuntimeError(f"Failed to create sim robot '{canonical}': {msg}") diff --git a/tests/test_factory.py b/tests/test_factory.py index 068b737..b15c254 100644 --- a/tests/test_factory.py +++ b/tests/test_factory.py @@ -139,6 +139,21 @@ def test_sim_with_urdf_path(self): with pytest.raises(RuntimeError): Robot("test_bot", mode="sim", urdf_path="/nonexistent/robot.xml") + def test_sim_happy_path_mujoco(self): + """Happy-path: create a MuJoCo sim, step physics, destroy.""" + mujoco = pytest.importorskip("mujoco") + sim = Robot("so100", mode="sim", backend="mujoco") + try: + # Verify it's a working simulation instance + assert sim._world is not None + assert sim._world._model is not None + assert sim._world._data is not None + # Step physics once to verify the engine works + mujoco.mj_step(sim._world._model, sim._world._data) + assert sim._world._data.time > 0 + finally: + sim.destroy() + def test_import_from_top_level(self): """Robot and list_robots importable from strands_robots.""" from strands_robots import Robot as R From 33c45eb1c2623e47be6e47daed59d6ae3a8171ef Mon Sep 17 00:00:00 2001 From: cagataycali Date: Wed, 1 Apr 2026 15:30:05 -0400 Subject: [PATCH 18/22] fix: handle None in serial port description/manufacturer during auto-detect MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit p.description and p.manufacturer can both be None on some platforms. Guard with 'or ""' to prevent TypeError on string concatenation. All checks pass: ruff check ✅, ruff format ✅, 358 tests ✅ --- strands_robots/factory.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/strands_robots/factory.py b/strands_robots/factory.py index cad1f4c..1eac53f 100644 --- a/strands_robots/factory.py +++ b/strands_robots/factory.py @@ -61,8 +61,11 @@ def _auto_detect_mode(canonical: str) -> str: robot_ports = [ p for p in ports - if any(kw in (p.description + getattr(p, "manufacturer", "")).lower() for kw in servo_keywords) - and not any(s in p.description.lower() for s in exclude) + if any( + kw in ((p.description or "") + (getattr(p, "manufacturer", None) or "")).lower() + for kw in servo_keywords + ) + and not any(s in (p.description or "").lower() for s in exclude) ] if robot_ports: logger.info( From 9af0a558a1eebcbc138e2d49eede0311f2fd4047 Mon Sep 17 00:00:00 2001 From: cagataycali Date: Wed, 1 Apr 2026 16:23:19 -0400 Subject: [PATCH 19/22] fix: use inline MJCF in test_sim_happy_path so CI works without assets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The test was calling Robot('so100') which requires downloaded URDF/mesh assets not available in CI. Now uses a minimal inline MJCF XML via tmp_path + urdf_path param — tests MuJoCo physics without external deps. --- tests/test_factory.py | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/tests/test_factory.py b/tests/test_factory.py index b15c254..17d8b11 100644 --- a/tests/test_factory.py +++ b/tests/test_factory.py @@ -139,10 +139,36 @@ def test_sim_with_urdf_path(self): with pytest.raises(RuntimeError): Robot("test_bot", mode="sim", urdf_path="/nonexistent/robot.xml") - def test_sim_happy_path_mujoco(self): - """Happy-path: create a MuJoCo sim, step physics, destroy.""" + def test_sim_happy_path_mujoco(self, tmp_path): + """Happy-path: create a MuJoCo sim, step physics, destroy. + + Uses a minimal inline MJCF so the test works without downloaded assets. + """ mujoco = pytest.importorskip("mujoco") - sim = Robot("so100", mode="sim", backend="mujoco") + + # Minimal valid MJCF that MuJoCo can load — a one-joint arm + mjcf_xml = """ + + + + + + + + + + + + + + + + + """ + mjcf_path = tmp_path / "test_arm.xml" + mjcf_path.write_text(mjcf_xml) + + sim = Robot("so100", mode="sim", backend="mujoco", urdf_path=str(mjcf_path)) try: # Verify it's a working simulation instance assert sim._world is not None From 577b00f3f2207bb2e7f55de6fb36c97aa0643d22 Mon Sep 17 00:00:00 2001 From: strands-agent <217235299+strands-agent@users.noreply.github.com> Date: Mon, 13 Apr 2026 05:29:10 +0000 Subject: [PATCH 20/22] refactor: address 10 review threads from @awsarron on PR #86 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Rename factory.py → robot.py, robot.py → hardware_robot.py Eliminates two 'Robot' classes in different files. The factory function now lives where users expect: strands_robots.robot.Robot 2. Default mode='sim' instead of mode='auto' Using real hardware should be an explicit decision since it affects the physical world. Robot('so100') now always returns simulation. Use mode='real' to explicitly opt into hardware control. 3. Fix ThreadPoolExecutor leak in _async_utils.py Register atexit.shutdown(wait=False) to clean up the module-level executor on interpreter exit. 4. Remove redundant list_robots() wrapper Was a 1-line passthrough to registry.list_robots(). Now __init__.py points directly to strands_robots.registry.list_robots. 5. Use module names in dataset_recorder docstring 'robot.py' → 'strands_robots.hardware_robot', 'simulation.py' → 'strands_robots.simulation' 6. Make camera shape configurable in dataset_recorder Added camera_shapes parameter to _build_features() instead of hardcoding (3, 480, 640). Default preserved for backward compat. 7. Add mode validation — invalid mode raises ValueError 8. Update __init__.py lazy imports for renamed modules Tests: 230 passed, 10 skipped, 0 failures Lint: ruff check + ruff format clean --- strands_robots/__init__.py | 4 +- strands_robots/_async_utils.py | 9 +- strands_robots/dataset_recorder.py | 15 +- strands_robots/factory.py | 202 ---- strands_robots/hardware_robot.py | 758 +++++++++++++++ strands_robots/robot.py | 907 ++++-------------- ...{test_factory.py => test_robot_factory.py} | 31 +- 7 files changed, 971 insertions(+), 955 deletions(-) delete mode 100644 strands_robots/factory.py create mode 100644 strands_robots/hardware_robot.py rename tests/{test_factory.py => test_robot_factory.py} (89%) diff --git a/strands_robots/__init__.py b/strands_robots/__init__.py index 41594a1..40dc6c2 100644 --- a/strands_robots/__init__.py +++ b/strands_robots/__init__.py @@ -37,8 +37,8 @@ # Maps public name -> (module_path, attribute_name) _LAZY_IMPORTS: dict[str, tuple[str, str]] = { # Hardware robot - "Robot": ("strands_robots.factory", "Robot"), - "list_robots": ("strands_robots.factory", "list_robots"), + "Robot": ("strands_robots.robot", "Robot"), + "list_robots": ("strands_robots.registry", "list_robots"), # Policies "Gr00tPolicy": ("strands_robots.policies.groot", "Gr00tPolicy"), # Simulation (MuJoCo) diff --git a/strands_robots/_async_utils.py b/strands_robots/_async_utils.py index ac145fe..dc9ac8d 100644 --- a/strands_robots/_async_utils.py +++ b/strands_robots/_async_utils.py @@ -1,11 +1,16 @@ """Async-to-sync helper for resolving coroutines in sync contexts.""" import asyncio -import concurrent.futures +import atexit +from concurrent.futures import ThreadPoolExecutor # Module-level executor reused across calls to avoid creating threads at high frequency. # A single worker is sufficient — we only need to offload one asyncio.run() at a time. -_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="strands_async") +_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="strands_async") + +# Ensure the executor is shut down cleanly on interpreter exit to avoid +# ResourceWarning and orphaned threads. +atexit.register(_EXECUTOR.shutdown, wait=False) def _resolve_coroutine(coro_or_result): # type: ignore[no-untyped-def] diff --git a/strands_robots/dataset_recorder.py b/strands_robots/dataset_recorder.py index f07bb2f..f6155e8 100644 --- a/strands_robots/dataset_recorder.py +++ b/strands_robots/dataset_recorder.py @@ -1,7 +1,7 @@ """LeRobotDataset recorder bridge for strands-robots. -Wraps LeRobotDataset so that both robot.py (real hardware) and -simulation.py (MuJoCo) can produce training-ready datasets with +Wraps LeRobotDataset so that both strands_robots.hardware_robot and +strands_robots.simulation (MuJoCo) can produce training-ready datasets with a single add_frame() call per control step. Usage: @@ -180,6 +180,7 @@ def _build_features( camera_keys: list[str] | None = None, joint_names: list[str] | None = None, use_videos: bool = True, + camera_shapes: dict[str, tuple[int, int, int]] | None = None, ) -> dict[str, Any]: """Build LeRobot v3-compatible features dict. @@ -199,13 +200,13 @@ def _build_features( for cam_name in camera_keys: key = f"observation.images.{cam_name}" dtype = "video" if use_videos else "image" + # Per-camera shape override, default (3, 480, 640) CHW + shape = (3, 480, 640) + if camera_shapes and cam_name in camera_shapes: + shape = camera_shapes[cam_name] features[key] = { "dtype": dtype, - "shape": ( - 3, - 480, - 640, - ), # CHW default, actual shape set on first frame + "shape": shape, "names": ["channels", "height", "width"], } diff --git a/strands_robots/factory.py b/strands_robots/factory.py deleted file mode 100644 index 1eac53f..0000000 --- a/strands_robots/factory.py +++ /dev/null @@ -1,202 +0,0 @@ -"""Unified Robot Factory — convenience layer over Simulation and HardwareRobot. - -Provides: - - ``Robot("so100")`` → auto-detects sim/real, returns the right backend - - ``list_robots()`` → what's available - -Examples:: - - # Auto-detect (sim if no hardware found) - sim = Robot("so100") - - # Explicit sim - sim = Robot("so100", mode="sim") - - # With custom URDF/MJCF path - sim = Robot("my_arm", mode="sim", urdf_path="/path/to/robot.xml") - - # Real hardware - hw = Robot("so100", mode="real", cameras={...}) - -Future (not yet implemented):: - - sim = Robot("unitree_go2", backend="isaac", num_envs=4096) - sim = Robot("so100", backend="newton", num_envs=4096) -""" - -import logging -import os -from typing import Any - -from strands_robots.registry import ( - get_hardware_type, - has_hardware, - resolve_name, -) -from strands_robots.registry import list_robots as _registry_list_robots - -logger = logging.getLogger(__name__) - - -def _auto_detect_mode(canonical: str) -> str: - """Auto-detect sim vs real mode. - - Priority: - 1. ``STRANDS_ROBOT_MODE`` env var (explicit override) - 2. Robot-specific USB detection (Feetech/Dynamixel servo controllers) - 3. Default to sim (safest — never accidentally send commands to hardware) - """ - env_mode = os.getenv("STRANDS_ROBOT_MODE", "").lower() - if env_mode in ("sim", "real"): - return env_mode - - # Only probe USB if the robot actually has hardware support - if has_hardware(canonical): - try: - import serial.tools.list_ports - - ports = list(serial.tools.list_ports.comports()) - servo_keywords = ["feetech", "dynamixel", "sts3215", "xl430", "xl330"] - exclude = ["bluetooth", "internal", "debug", "apple", "modem"] - robot_ports = [ - p - for p in ports - if any( - kw in ((p.description or "") + (getattr(p, "manufacturer", None) or "")).lower() - for kw in servo_keywords - ) - and not any(s in (p.description or "").lower() for s in exclude) - ] - if robot_ports: - logger.info( - "Auto-detected robot hardware: %s", - [p.device for p in robot_ports], - ) - return "real" - except (ImportError, OSError): # USB probing may fail with OSError on permission/device issues - pass - - return "sim" - - -def Robot( - name: str, - mode: str = "auto", - backend: str = "mujoco", - urdf_path: str = None, - cameras: dict[str, dict[str, Any]] | None = None, - position: list[float] | None = None, - **kwargs, -): - """Create a robot — returns a Simulation or HardwareRobot instance. - - This is a convenience factory, NOT a wrapper class. You get the real - backend instance back — with full access to all its methods. - - Args: - name: Robot name ("so100", "aloha", "unitree_g1", "panda", ...) - Accepts any alias defined in ``registry/robots.json``. - mode: "auto" (detect hardware), "sim", or "real". - backend: Simulation backend — currently only "mujoco" (CPU). - Future: "isaac" (GPU), "newton" (GPU). - urdf_path: Explicit path to URDF/MJCF file. If not provided, - resolved via the model registry (asset manager or - STRANDS_URDF_DIR search paths). - cameras: Camera config for real hardware. Example:: - - {"wrist": {"type": "opencv", "index_or_path": "/dev/video0", "fps": 30}} - - position: Robot position in sim world [x, y, z]. - **kwargs: Forwarded to the underlying backend constructor. - - Returns: - ``Simulation`` (MuJoCo sim) or ``Robot`` (real hardware). - - Raises: - RuntimeError: If the sim world or robot fails to initialize. - NotImplementedError: If an unimplemented backend is requested. - - Examples:: - - # MuJoCo sim (auto — no hardware detected) - sim = Robot("so100") - sim = Robot("so100", mode="sim") - - # Explicit MJCF model path - sim = Robot("my_arm", mode="sim", urdf_path="path/to/robot.xml") - - # Real hardware - hw = Robot("so100", mode="real", cameras={...}) - - # The 5-line promise - from strands_robots import Robot - from strands import Agent - robot = Robot("so100") - agent = Agent(tools=[robot]) - agent("Pick up the red cube") - """ - canonical = resolve_name(name) - - if mode == "auto": - mode = _auto_detect_mode(canonical) - - # ── Simulation ── - if mode == "sim": - if backend != "mujoco": - raise NotImplementedError( - f"Backend {backend!r} is not yet implemented. " - f"Currently supported: 'mujoco'. " - f"Isaac and Newton backends are on the roadmap." - ) - - from strands_robots.simulation import Simulation - - sim = Simulation( - tool_name=f"{canonical}_sim", - **kwargs, - ) - sim._dispatch_action("create_world", {}) - - # Build add_robot params — pass urdf_path if the user provided one - add_robot_params: dict[str, Any] = { - "robot_name": canonical, - "data_config": canonical, - "position": position or [0.0, 0.0, 0.0], - } - if urdf_path: - add_robot_params["urdf_path"] = urdf_path - - result = sim._dispatch_action("add_robot", add_robot_params) - if result.get("status") == "error": - sim.destroy() # Clean up partial initialization (executor, temp dir, MuJoCo world) - content = result.get("content", []) - msg = content[0].get("text", str(result)) if content else str(result) - raise RuntimeError(f"Failed to create sim robot '{canonical}': {msg}") - return sim - - # ── Real hardware ── - else: - from strands_robots.robot import Robot as HardwareRobot - - real_type = get_hardware_type(canonical) or canonical - return HardwareRobot( - tool_name=canonical, - robot=real_type, - cameras=cameras, - **kwargs, - ) - - -def list_robots(mode: str = "all") -> list[dict[str, Any]]: - """List available robots. - - Args: - mode: "all", "sim", "real", or "both" (has both sim and real). - - Returns: - List of dicts with name, description, has_sim, has_real. - """ - return _registry_list_robots(mode) - - -__all__ = ["Robot", "list_robots"] diff --git a/strands_robots/hardware_robot.py b/strands_robots/hardware_robot.py new file mode 100644 index 0000000..33c5ba7 --- /dev/null +++ b/strands_robots/hardware_robot.py @@ -0,0 +1,758 @@ +#!/usr/bin/env python3 +""" +Universal Robot Control with Policy Abstraction for Any VLA Provider + +This module provides a clean robot interface that works with any LeRobot-compatible +robot and any VLA provider through the Policy abstraction. + +Features: +- Async robot task execution with real-time status reporting +- Non-blocking operations - robot moves while tool returns status +- Stop functionality to interrupt running tasks +- Connection state management with proper error handling +- Policy abstraction for any VLA provider +""" + +from __future__ import annotations + +import asyncio +import logging +import threading +import time +from collections.abc import AsyncGenerator +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, Any, cast + +from strands.tools.tools import AgentTool +from strands.types._events import ToolResultEvent +from strands.types.tools import ToolResult, ToolSpec, ToolUse + +if TYPE_CHECKING: + from lerobot.robots.config import RobotConfig + from lerobot.robots.robot import Robot as LeRobotRobot + + from .policies import Policy + +logger = logging.getLogger(__name__) + + +class TaskStatus(Enum): + """Robot task execution status""" + + IDLE = "idle" + CONNECTING = "connecting" + RUNNING = "running" + COMPLETED = "completed" + STOPPED = "stopped" + ERROR = "error" + + +@dataclass +class RobotTaskState: + """Robot task execution state""" + + status: TaskStatus = TaskStatus.IDLE + instruction: str = "" + start_time: float = 0.0 + duration: float = 0.0 + step_count: int = 0 + error_message: str = "" + task_future: Future | None = None + + +class Robot(AgentTool): + """Universal robot control with async task execution and status reporting.""" + + def __init__( + self, + tool_name: str, + robot: LeRobotRobot | RobotConfig | str, + cameras: dict[str, dict[str, Any]] | None = None, + action_horizon: int = 8, + data_config: str | Any | None = None, + control_frequency: float = 50.0, + **kwargs, + ): + """Initialize Robot with async capabilities. + + Args: + tool_name: Name for this robot tool + robot: LeRobot Robot instance, RobotConfig, or robot type string + cameras: Camera configuration dict: + {"wrist": {"type": "opencv", "index_or_path": "/dev/video0", "fps": 30}} + action_horizon: Actions per inference step + data_config: Data configuration (for GR00T compatibility) + control_frequency: Control loop frequency in Hz (default: 50Hz) + **kwargs: Robot-specific parameters (port, etc.) + """ + super().__init__() + + self.tool_name_str = tool_name + self.action_horizon = action_horizon + self.data_config = data_config + self.control_frequency = control_frequency + self.action_sleep_time = 1.0 / control_frequency # Time between actions + + # Task execution state + self._task_state = RobotTaskState() + self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix=f"{tool_name}_executor") + self._shutdown_event = threading.Event() + + # Initialize robot using lerobot's abstraction + self.robot = self._initialize_robot(robot, cameras, **kwargs) + + logger.info(f"🤖 {tool_name} initialized with async capabilities") + logger.info(f"📱 Robot: {self.robot.name} (type: {getattr(self.robot, 'robot_type', 'unknown')})") + logger.info(f"⏱️ Control frequency: {control_frequency}Hz ({self.action_sleep_time * 1000:.1f}ms per action)") + + # Get camera info if available + if hasattr(self.robot, "config") and hasattr(self.robot.config, "cameras"): + cameras_list = list(self.robot.config.cameras.keys()) + logger.info(f"📹 Cameras: {cameras_list}") + + if data_config: + logger.info(f"⚙️ Data config: {data_config}") + + def _initialize_robot( + self, robot: LeRobotRobot | RobotConfig | str, cameras: dict[str, dict[str, Any]] | None, **kwargs + ) -> LeRobotRobot: + """Initialize LeRobot robot instance using native lerobot patterns.""" + from lerobot.robots.config import RobotConfig + from lerobot.robots.robot import Robot as LeRobotRobot + from lerobot.robots.utils import make_robot_from_config + + # Direct robot instance - use as-is + if isinstance(robot, LeRobotRobot): + return robot + + # Robot config - use lerobot's factory + elif isinstance(robot, RobotConfig): + return make_robot_from_config(robot) + + # Robot type string - create config and use lerobot's factory + elif isinstance(robot, str): + config = self._create_minimal_config(robot, cameras, **kwargs) + return make_robot_from_config(config) + + else: + raise ValueError( + f"Unsupported robot type: {type(robot)}. " + f"Expected LeRobot Robot instance, RobotConfig, or robot type string." + ) + + def _create_minimal_config( + self, robot_type: str, cameras: dict[str, dict[str, Any]] | None, **kwargs + ) -> RobotConfig: + """Create minimal robot config using specific robot config classes.""" + from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig + + # Convert cameras to lerobot format + camera_configs = {} + if cameras: + for name, config in cameras.items(): + if config.get("type", "opencv") == "opencv": + camera_configs[name] = OpenCVCameraConfig( + index_or_path=config["index_or_path"], + fps=config.get("fps", 30), + width=config.get("width", 640), + height=config.get("height", 480), + rotation=config.get("rotation", 0), + color_mode=config.get("color_mode", "rgb"), + ) + else: + raise ValueError(f"Unsupported camera type: {config.get('type')}") + + # Map robot type to specific config class + config_mapping = { + "so101_follower": ("lerobot.robots.so101_follower", "SO101FollowerConfig"), + "so100_follower": ("lerobot.robots.so100_follower", "SO100FollowerConfig"), + "bi_so100_follower": ("lerobot.robots.bi_so100_follower", "BiSO100FollowerConfig"), + "viperx": ("lerobot.robots.viperx", "ViperXConfig"), + "koch_follower": ("lerobot.robots.koch_follower", "KochFollowerConfig"), + # Add more as needed + } + + if robot_type not in config_mapping: + raise ValueError(f"Unsupported robot type: {robot_type}. Supported types: {list(config_mapping.keys())}") + + # Import specific config class dynamically + module_name, class_name = config_mapping[robot_type] + try: + import importlib + + module = importlib.import_module(module_name) + ConfigClass = getattr(module, class_name) + except Exception as e: + raise ValueError(f"Failed to import {class_name} from {module_name}: {e}") + + # Create config with proper parameters + config_data = { + "id": self.tool_name_str, + "cameras": camera_configs, + } + + # Filter kwargs to only include supported fields for this robot type + # Port is common for most serial robots + if "port" in kwargs: + config_data["port"] = kwargs["port"] + + # Add other common fields as needed + for key in ["calibration_dir", "mock", "use_degrees"]: + if key in kwargs: + config_data[key] = kwargs[key] + + try: + return ConfigClass(**config_data) + except Exception as e: + raise ValueError(f"Failed to create {class_name} for robot type '{robot_type}': {e}. Config: {config_data}") + + async def _get_policy( + self, policy_port: int | None = None, policy_host: str = "localhost", policy_provider: str = "groot" + ) -> Policy: + """Create policy on-the-fly from invocation parameters.""" + from .policies import create_policy + + if not policy_port: + raise ValueError("policy_port is required for robot operation") + + policy_config = {"port": policy_port, "host": policy_host} + + if self.data_config: + policy_config["data_config"] = self.data_config + + return create_policy(policy_provider, **policy_config) + + async def _connect_robot(self) -> tuple[bool, str]: + """Connect to robot hardware with proper error handling. + + Returns: + tuple[bool, str]: (success, error_message) - error_message is empty on success + """ + try: + # Import lerobot exceptions + from lerobot.utils.errors import DeviceAlreadyConnectedError + + # Check if already connected + if self.robot.is_connected: + logger.info(f"✅ {self.robot} already connected") + return True, "" + + logger.info(f"🔌 Connecting to {self.robot}...") + + # Handle robot connection using lerobot's error handling patterns + try: + if not self.robot.is_connected: + await asyncio.to_thread(self.robot.connect, False) # calibrate=False + + except DeviceAlreadyConnectedError: + # This is expected and fine - robot is already connected + logger.info(f"✅ {self.robot} was already connected") + + except Exception as e: + # Check if it's the string version of "already connected" error + error_str = str(e).lower() + if "already connected" in error_str or "is already connected" in error_str: + logger.info(f"✅ {self.robot} connection already established") + else: + # Re-raise if it's a different error + raise e + + # Final connection check + if not self.robot.is_connected: + error_msg = f"Failed to connect to {self.robot}" + logger.error(f"❌ {error_msg}") + return False, error_msg + + # Check robot calibration + if hasattr(self.robot, "is_calibrated") and not self.robot.is_calibrated: + error_msg = ( + f"Robot {self.robot} is not calibrated. Please calibrate the robot manually" + " first using LeRobot's calibration process (lerobot-calibrate)" + ) + logger.error(f"❌ {error_msg}") + return False, error_msg + + logger.info(f"✅ {self.robot} connected and ready") + return True, "" + + except Exception as e: + error_msg = f"Robot connection failed: {e}. Ensure robot is calibrated and accessible on the specified port" + logger.error(f"❌ {error_msg}") + return False, error_msg + + async def _initialize_policy(self, policy: Policy) -> bool: + """Initialize policy with robot state keys.""" + try: + # Get robot state keys from observation + test_obs = await asyncio.to_thread(self.robot.get_observation) + + # Filter out camera keys to get robot state keys + camera_keys = [] + if hasattr(self.robot, "config") and hasattr(self.robot.config, "cameras"): + camera_keys = list(self.robot.config.cameras.keys()) + + robot_state_keys = [k for k in test_obs.keys() if k not in camera_keys] + + # Set robot state keys in policy + policy.set_robot_state_keys(robot_state_keys) + return True + + except Exception as e: + logger.error(f"❌ Failed to initialize policy: {e}") + return False + + async def _execute_task_async( + self, + instruction: str, + policy_port: int | None = None, + policy_host: str = "localhost", + policy_provider: str = "groot", + duration: float = 30.0, + ) -> None: + """Execute robot task in background thread (internal method).""" + try: + # Update task state + self._task_state.status = TaskStatus.CONNECTING + self._task_state.instruction = instruction + self._task_state.start_time = time.time() + self._task_state.step_count = 0 + self._task_state.error_message = "" + + # Connect to robot + connected, connect_error = await self._connect_robot() + if not connected: + self._task_state.status = TaskStatus.ERROR + self._task_state.error_message = connect_error or f"Failed to connect to {self.tool_name_str}" + return + + # Get policy instance + policy_instance = await self._get_policy(policy_port, policy_host, policy_provider) + + # Initialize policy with robot state keys + if not await self._initialize_policy(policy_instance): + self._task_state.status = TaskStatus.ERROR + self._task_state.error_message = "Failed to initialize policy" + return + + logger.info(f"🎯 Starting task: '{instruction}' on {self.tool_name_str}") + logger.info(f"🧠 Using policy: {policy_provider} on {policy_host}:{policy_port}") + + self._task_state.status = TaskStatus.RUNNING + start_time = time.time() + + while ( + time.time() - start_time < duration + and self._task_state.status == TaskStatus.RUNNING + and not self._shutdown_event.is_set() + ): + # Get observation from robot + observation = await asyncio.to_thread(self.robot.get_observation) + + # Get actions from policy + robot_actions = await policy_instance.get_actions(observation, instruction) + + # Execute actions from chunk with proper timing control + # Wait between actions for smooth execution + for action_dict in robot_actions[: self.action_horizon]: + if self._task_state.status != TaskStatus.RUNNING: + break + await asyncio.to_thread(self.robot.send_action, action_dict) + self._task_state.step_count += 1 + # Wait for action to complete before sending next action + # Default 50Hz (0.02s) + await asyncio.sleep(self.action_sleep_time) + + # Update final state + elapsed = time.time() - start_time + self._task_state.duration = elapsed + + if self._task_state.status == TaskStatus.RUNNING: + self._task_state.status = TaskStatus.COMPLETED + logger.info( + f"✅ Task completed: '{instruction}' in {elapsed:.1f}s ({self._task_state.step_count} steps)" + ) + + except Exception as e: + logger.error(f"❌ Task execution failed: {e}") + self._task_state.status = TaskStatus.ERROR + self._task_state.error_message = str(e) + + def _execute_task_sync( + self, + instruction: str, + policy_port: int | None = None, + policy_host: str = "localhost", + policy_provider: str = "groot", + duration: float = 30.0, + ) -> dict[str, Any]: + """Execute task synchronously in thread - no new event loop.""" + + # Import here to avoid conflicts + import asyncio + + # Run task without creating new event loop - let it run in thread + async def task_runner(): + await self._execute_task_async(instruction, policy_port, policy_host, policy_provider, duration) + + # Use asyncio.run only if no loop is running, otherwise run in existing loop + try: + # Try to get the current event loop + asyncio.get_running_loop() + # If we're already in an event loop, we need to run in a thread + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as exec: + future = exec.submit(lambda: asyncio.run(task_runner())) + future.result() # Wait for completion + except RuntimeError: + # No event loop running - safe to create one + asyncio.run(task_runner()) + + # Return final status + return { + "status": "success" if self._task_state.status == TaskStatus.COMPLETED else "error", + "content": [ + { + "text": f"✅ Task: '{instruction}' - {self._task_state.status.value}\n" + f"🤖 Robot: {self.tool_name_str} ({self.robot})\n" + f"🧠 Policy: {policy_provider} on {policy_host}:{policy_port}\n" + f"⏱️ Duration: {self._task_state.duration:.1f}s\n" + f"🎯 Steps: {self._task_state.step_count}" + + (f"\n❌ Error: {self._task_state.error_message}" if self._task_state.error_message else "") + } + ], + } + + def start_task( + self, + instruction: str, + policy_port: int | None = None, + policy_host: str = "localhost", + policy_provider: str = "groot", + duration: float = 30.0, + ) -> dict[str, Any]: + """Start robot task asynchronously and return immediately.""" + + # Check if task is already running + if self._task_state.status == TaskStatus.RUNNING: + return { + "status": "error", + "content": [{"text": f"❌ Task already running: {self._task_state.instruction}"}], + } + + # Start task in background + self._task_state.task_future = self._executor.submit( + self._execute_task_sync, instruction, policy_port, policy_host, policy_provider, duration + ) + + return { + "status": "success", + "content": [ + { + "text": f"🚀 Task started: '{instruction}'\n" + f"🤖 Robot: {self.tool_name_str}\n" + f"💡 Use action='status' to check progress\n" + f"💡 Use action='stop' to interrupt" + } + ], + } + + def get_task_status(self) -> dict[str, Any]: + """Get current task execution status.""" + + # Update duration for running tasks + if self._task_state.status == TaskStatus.RUNNING: + self._task_state.duration = time.time() - self._task_state.start_time + + status_text = f"📊 Robot Status: {self._task_state.status.value.upper()}\n" + + if self._task_state.instruction: + status_text += f"🎯 Task: {self._task_state.instruction}\n" + + if self._task_state.status == TaskStatus.RUNNING: + status_text += f"⏱️ Duration: {self._task_state.duration:.1f}s\n" + status_text += f"🔄 Steps: {self._task_state.step_count}\n" + elif self._task_state.status in [TaskStatus.COMPLETED, TaskStatus.STOPPED, TaskStatus.ERROR]: + status_text += f"⏱️ Total Duration: {self._task_state.duration:.1f}s\n" + status_text += f"🎯 Total Steps: {self._task_state.step_count}\n" + + if self._task_state.error_message: + status_text += f"❌ Error: {self._task_state.error_message}\n" + + return { + "status": "success", + "content": [{"text": status_text}], + } + + def stop_task(self) -> dict[str, Any]: + """Stop currently running task.""" + + if self._task_state.status != TaskStatus.RUNNING: + return { + "status": "success", + "content": [{"text": f"💤 No task running to stop (current: {self._task_state.status.value})"}], + } + + # Signal task to stop + self._task_state.status = TaskStatus.STOPPED + + # Cancel future if it exists + if self._task_state.task_future: + self._task_state.task_future.cancel() + + logger.info(f"🛑 Task stopped: {self._task_state.instruction}") + + return { + "status": "success", + "content": [ + { + "text": f"🛑 Task stopped: '{self._task_state.instruction}'\n" + f"⏱️ Duration: {self._task_state.duration:.1f}s\n" + f"🎯 Steps completed: {self._task_state.step_count}" + } + ], + } + + @property + def tool_name(self) -> str: + return self.tool_name_str + + @property + def tool_type(self) -> str: + return "robot" + + @property + def tool_spec(self) -> ToolSpec: + """Get tool specification with async actions.""" + return { + "name": self.tool_name_str, + "description": f"Universal robot control with async task execution ({self.robot}). " + f"Actions: execute (blocking), start (async), status, stop. " + f"For execute/start actions: instruction and policy_port are required. " + f"For status/stop actions: no additional parameters needed.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "action": { + "type": "string", + "description": "Action to perform: execute (blocking), start (async), status, stop", + "enum": ["execute", "start", "status", "stop"], + "default": "execute", + }, + "instruction": { + "type": "string", + "description": "Natural language instruction (required for execute/start actions)", + }, + "policy_port": { + "type": "integer", + "description": "Policy service port (required for execute/start actions)", + }, + "policy_host": { + "type": "string", + "description": "Policy service host (default: localhost)", + "default": "localhost", + }, + "policy_provider": { + "type": "string", + "description": "Policy provider (groot, openai, etc.)", + "default": "groot", + }, + "duration": { + "type": "number", + "description": "Maximum execution time in seconds", + "default": 30.0, + }, + }, + "required": ["action"], + } + }, + } + + @staticmethod + def _make_tool_result(tool_use_id: str, result: dict[str, Any]) -> ToolResult: + """Create a ToolResult dict with the given tool_use_id merged into result.""" + return cast(ToolResult, {"toolUseId": tool_use_id, **result}) + + async def stream( + self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any + ) -> AsyncGenerator[ToolResultEvent, None]: + """Stream robot task execution with async actions.""" + try: + tool_use_id = tool_use.get("toolUseId", "") + input_data = tool_use.get("input", {}) + + action = input_data.get("action", "execute") + + # Handle different actions + if action == "execute": + # Blocking execution (legacy behavior) + instruction = input_data.get("instruction", "") + policy_port = input_data.get("policy_port") + policy_host = input_data.get("policy_host", "localhost") + policy_provider = input_data.get("policy_provider", "groot") + duration = input_data.get("duration", 30.0) + + if not instruction or not policy_port: + yield ToolResultEvent( + self._make_tool_result( + tool_use_id, + { + "status": "error", + "content": [{"text": "❌ instruction and policy_port are required for execute action"}], + }, + ) + ) + return + + # Execute task synchronously + task_result = self._execute_task_sync(instruction, policy_port, policy_host, policy_provider, duration) + yield ToolResultEvent(self._make_tool_result(tool_use_id, task_result)) + + elif action == "start": + # Asynchronous execution start + instruction = input_data.get("instruction", "") + policy_port = input_data.get("policy_port") + policy_host = input_data.get("policy_host", "localhost") + policy_provider = input_data.get("policy_provider", "groot") + duration = input_data.get("duration", 30.0) + + if not instruction or not policy_port: + yield ToolResultEvent( + self._make_tool_result( + tool_use_id, + { + "status": "error", + "content": [{"text": "❌ instruction and policy_port are required for start action"}], + }, + ) + ) + return + + # Start task asynchronously + start_result = self.start_task(instruction, policy_port, policy_host, policy_provider, duration) + yield ToolResultEvent(self._make_tool_result(tool_use_id, start_result)) + + elif action == "status": + # Get current task status + status_result = self.get_task_status() + yield ToolResultEvent(self._make_tool_result(tool_use_id, status_result)) + + elif action == "stop": + # Stop current task + stop_result = self.stop_task() + yield ToolResultEvent(self._make_tool_result(tool_use_id, stop_result)) + + else: + yield ToolResultEvent( + self._make_tool_result( + tool_use_id, + { + "status": "error", + "content": [ + {"text": f"❌ Unknown action: {action}. Valid actions: execute, start, status, stop"} + ], + }, + ) + ) + + except Exception as e: + logger.error(f"❌ {self.tool_name_str} error: {e}") + yield ToolResultEvent( + self._make_tool_result( + tool_use_id, + { + "status": "error", + "content": [{"text": f"❌ {self.tool_name_str} error: {str(e)}"}], + }, + ) + ) + + def cleanup(self): + """Cleanup resources and stop any running tasks.""" + try: + # Signal shutdown + self._shutdown_event.set() + + # Stop any running task + if self._task_state.status == TaskStatus.RUNNING: + self.stop_task() + + # Shutdown executor + self._executor.shutdown(wait=True) + + logger.info(f"🧹 {self.tool_name_str} cleanup completed") + + except Exception as e: + logger.error(f"❌ Cleanup error for {self.tool_name_str}: {e}") + + def __del__(self): + """Destructor to ensure cleanup.""" + try: + self.cleanup() + except Exception: + pass # Ignore errors in destructor + + async def get_status(self) -> dict[str, Any]: + """Get robot status including connection and task state.""" + try: + # Get robot connection status + is_connected = self.robot.is_connected if hasattr(self.robot, "is_connected") else False + is_calibrated = self.robot.is_calibrated if hasattr(self.robot, "is_calibrated") else True + + # Get camera status + camera_status = [] + if hasattr(self.robot, "config") and hasattr(self.robot.config, "cameras"): + for name in self.robot.config.cameras.keys(): + camera_status.append(name) + + # Build status dict + status_data = { + "robot_name": self.tool_name_str, + "robot_type": getattr(self.robot, "robot_type", self.robot.name), + "robot_info": str(self.robot), + "data_config": self.data_config, + "is_connected": is_connected, + "is_calibrated": is_calibrated, + "cameras": camera_status, + "task_status": self._task_state.status.value, + "current_instruction": self._task_state.instruction, + "task_duration": self._task_state.duration, + "task_steps": self._task_state.step_count, + } + + # Add error info if present + if self._task_state.error_message: + status_data["task_error"] = self._task_state.error_message + + return status_data + + except Exception as e: + logger.error(f"❌ Error getting status for {self.tool_name_str}: {e}") + return { + "robot_name": self.tool_name_str, + "error": str(e), + "is_connected": False, + "task_status": "error", + } + + async def stop(self): + """Stop robot and disconnect.""" + try: + # Stop any running task first + if self._task_state.status == TaskStatus.RUNNING: + self.stop_task() + + # Disconnect robot hardware + if hasattr(self.robot, "disconnect"): + await asyncio.to_thread(self.robot.disconnect) + + # Cleanup resources + self.cleanup() + + logger.info(f"🛑 {self.tool_name_str} stopped and disconnected") + + except Exception as e: + logger.error(f"❌ Error stopping robot: {e}") diff --git a/strands_robots/robot.py b/strands_robots/robot.py index 33c5ba7..fa4a364 100644 --- a/strands_robots/robot.py +++ b/strands_robots/robot.py @@ -1,758 +1,205 @@ -#!/usr/bin/env python3 -""" -Universal Robot Control with Policy Abstraction for Any VLA Provider - -This module provides a clean robot interface that works with any LeRobot-compatible -robot and any VLA provider through the Policy abstraction. - -Features: -- Async robot task execution with real-time status reporting -- Non-blocking operations - robot moves while tool returns status -- Stop functionality to interrupt running tasks -- Connection state management with proper error handling -- Policy abstraction for any VLA provider -""" - -from __future__ import annotations - -import asyncio -import logging -import threading -import time -from collections.abc import AsyncGenerator -from concurrent.futures import Future, ThreadPoolExecutor -from dataclasses import dataclass -from enum import Enum -from typing import TYPE_CHECKING, Any, cast - -from strands.tools.tools import AgentTool -from strands.types._events import ToolResultEvent -from strands.types.tools import ToolResult, ToolSpec, ToolUse +"""Unified Robot factory — convenience layer over ``strands_robots.simulation`` +and ``strands_robots.hardware_robot``. -if TYPE_CHECKING: - from lerobot.robots.config import RobotConfig - from lerobot.robots.robot import Robot as LeRobotRobot +Provides: + - ``Robot("so100")`` → returns a simulation by default (safe) + - ``Robot("so100", mode="real")`` → explicit real hardware + - ``Robot("so100", mode="auto")`` → auto-detects sim/real + - ``list_robots()`` → what's available - from .policies import Policy +Environment Variables: + STRANDS_ROBOT_MODE: Override mode detection ("sim" or "real"). -logger = logging.getLogger(__name__) - - -class TaskStatus(Enum): - """Robot task execution status""" - - IDLE = "idle" - CONNECTING = "connecting" - RUNNING = "running" - COMPLETED = "completed" - STOPPED = "stopped" - ERROR = "error" - - -@dataclass -class RobotTaskState: - """Robot task execution state""" - - status: TaskStatus = TaskStatus.IDLE - instruction: str = "" - start_time: float = 0.0 - duration: float = 0.0 - step_count: int = 0 - error_message: str = "" - task_future: Future | None = None - - -class Robot(AgentTool): - """Universal robot control with async task execution and status reporting.""" - - def __init__( - self, - tool_name: str, - robot: LeRobotRobot | RobotConfig | str, - cameras: dict[str, dict[str, Any]] | None = None, - action_horizon: int = 8, - data_config: str | Any | None = None, - control_frequency: float = 50.0, - **kwargs, - ): - """Initialize Robot with async capabilities. - - Args: - tool_name: Name for this robot tool - robot: LeRobot Robot instance, RobotConfig, or robot type string - cameras: Camera configuration dict: - {"wrist": {"type": "opencv", "index_or_path": "/dev/video0", "fps": 30}} - action_horizon: Actions per inference step - data_config: Data configuration (for GR00T compatibility) - control_frequency: Control loop frequency in Hz (default: 50Hz) - **kwargs: Robot-specific parameters (port, etc.) - """ - super().__init__() - - self.tool_name_str = tool_name - self.action_horizon = action_horizon - self.data_config = data_config - self.control_frequency = control_frequency - self.action_sleep_time = 1.0 / control_frequency # Time between actions - - # Task execution state - self._task_state = RobotTaskState() - self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix=f"{tool_name}_executor") - self._shutdown_event = threading.Event() - - # Initialize robot using lerobot's abstraction - self.robot = self._initialize_robot(robot, cameras, **kwargs) - - logger.info(f"🤖 {tool_name} initialized with async capabilities") - logger.info(f"📱 Robot: {self.robot.name} (type: {getattr(self.robot, 'robot_type', 'unknown')})") - logger.info(f"⏱️ Control frequency: {control_frequency}Hz ({self.action_sleep_time * 1000:.1f}ms per action)") - - # Get camera info if available - if hasattr(self.robot, "config") and hasattr(self.robot.config, "cameras"): - cameras_list = list(self.robot.config.cameras.keys()) - logger.info(f"📹 Cameras: {cameras_list}") - - if data_config: - logger.info(f"⚙️ Data config: {data_config}") - - def _initialize_robot( - self, robot: LeRobotRobot | RobotConfig | str, cameras: dict[str, dict[str, Any]] | None, **kwargs - ) -> LeRobotRobot: - """Initialize LeRobot robot instance using native lerobot patterns.""" - from lerobot.robots.config import RobotConfig - from lerobot.robots.robot import Robot as LeRobotRobot - from lerobot.robots.utils import make_robot_from_config - - # Direct robot instance - use as-is - if isinstance(robot, LeRobotRobot): - return robot - - # Robot config - use lerobot's factory - elif isinstance(robot, RobotConfig): - return make_robot_from_config(robot) - - # Robot type string - create config and use lerobot's factory - elif isinstance(robot, str): - config = self._create_minimal_config(robot, cameras, **kwargs) - return make_robot_from_config(config) - - else: - raise ValueError( - f"Unsupported robot type: {type(robot)}. " - f"Expected LeRobot Robot instance, RobotConfig, or robot type string." - ) +Examples:: - def _create_minimal_config( - self, robot_type: str, cameras: dict[str, dict[str, Any]] | None, **kwargs - ) -> RobotConfig: - """Create minimal robot config using specific robot config classes.""" - from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig - - # Convert cameras to lerobot format - camera_configs = {} - if cameras: - for name, config in cameras.items(): - if config.get("type", "opencv") == "opencv": - camera_configs[name] = OpenCVCameraConfig( - index_or_path=config["index_or_path"], - fps=config.get("fps", 30), - width=config.get("width", 640), - height=config.get("height", 480), - rotation=config.get("rotation", 0), - color_mode=config.get("color_mode", "rgb"), - ) - else: - raise ValueError(f"Unsupported camera type: {config.get('type')}") - - # Map robot type to specific config class - config_mapping = { - "so101_follower": ("lerobot.robots.so101_follower", "SO101FollowerConfig"), - "so100_follower": ("lerobot.robots.so100_follower", "SO100FollowerConfig"), - "bi_so100_follower": ("lerobot.robots.bi_so100_follower", "BiSO100FollowerConfig"), - "viperx": ("lerobot.robots.viperx", "ViperXConfig"), - "koch_follower": ("lerobot.robots.koch_follower", "KochFollowerConfig"), - # Add more as needed - } - - if robot_type not in config_mapping: - raise ValueError(f"Unsupported robot type: {robot_type}. Supported types: {list(config_mapping.keys())}") + # Default: simulation (safe — no physical hardware interaction) + sim = Robot("so100") - # Import specific config class dynamically - module_name, class_name = config_mapping[robot_type] - try: - import importlib + # Explicit real hardware + hw = Robot("so100", mode="real", cameras={...}) - module = importlib.import_module(module_name) - ConfigClass = getattr(module, class_name) - except Exception as e: - raise ValueError(f"Failed to import {class_name} from {module_name}: {e}") + # Auto-detect (probes USB for servo controllers) + robot = Robot("so100", mode="auto") - # Create config with proper parameters - config_data = { - "id": self.tool_name_str, - "cameras": camera_configs, - } + # With custom URDF/MJCF path + sim = Robot("my_arm", urdf_path="/path/to/robot.xml") - # Filter kwargs to only include supported fields for this robot type - # Port is common for most serial robots - if "port" in kwargs: - config_data["port"] = kwargs["port"] +Future (not yet implemented):: - # Add other common fields as needed - for key in ["calibration_dir", "mock", "use_degrees"]: - if key in kwargs: - config_data[key] = kwargs[key] - - try: - return ConfigClass(**config_data) - except Exception as e: - raise ValueError(f"Failed to create {class_name} for robot type '{robot_type}': {e}. Config: {config_data}") + sim = Robot("unitree_go2", backend="isaac", num_envs=4096) + sim = Robot("so100", backend="newton", num_envs=4096) +""" - async def _get_policy( - self, policy_port: int | None = None, policy_host: str = "localhost", policy_provider: str = "groot" - ) -> Policy: - """Create policy on-the-fly from invocation parameters.""" - from .policies import create_policy +import logging +import os +from typing import Any - if not policy_port: - raise ValueError("policy_port is required for robot operation") +from strands_robots.registry import ( + get_hardware_type, + has_hardware, + resolve_name, +) - policy_config = {"port": policy_port, "host": policy_host} +logger = logging.getLogger(__name__) - if self.data_config: - policy_config["data_config"] = self.data_config - return create_policy(policy_provider, **policy_config) +def _auto_detect_mode(canonical: str) -> str: + """Auto-detect sim vs real mode. - async def _connect_robot(self) -> tuple[bool, str]: - """Connect to robot hardware with proper error handling. + Priority: + 1. ``STRANDS_ROBOT_MODE`` env var (explicit override) + 2. Robot-specific USB detection (Feetech/Dynamixel servo controllers) + 3. Default to sim (safest — never accidentally send commands to hardware) + """ + env_mode = os.getenv("STRANDS_ROBOT_MODE", "").lower() + if env_mode in ("sim", "real"): + return env_mode - Returns: - tuple[bool, str]: (success, error_message) - error_message is empty on success - """ + # Only probe USB if the robot actually has hardware support + if has_hardware(canonical): try: - # Import lerobot exceptions - from lerobot.utils.errors import DeviceAlreadyConnectedError - - # Check if already connected - if self.robot.is_connected: - logger.info(f"✅ {self.robot} already connected") - return True, "" - - logger.info(f"🔌 Connecting to {self.robot}...") - - # Handle robot connection using lerobot's error handling patterns - try: - if not self.robot.is_connected: - await asyncio.to_thread(self.robot.connect, False) # calibrate=False - - except DeviceAlreadyConnectedError: - # This is expected and fine - robot is already connected - logger.info(f"✅ {self.robot} was already connected") - - except Exception as e: - # Check if it's the string version of "already connected" error - error_str = str(e).lower() - if "already connected" in error_str or "is already connected" in error_str: - logger.info(f"✅ {self.robot} connection already established") - else: - # Re-raise if it's a different error - raise e - - # Final connection check - if not self.robot.is_connected: - error_msg = f"Failed to connect to {self.robot}" - logger.error(f"❌ {error_msg}") - return False, error_msg - - # Check robot calibration - if hasattr(self.robot, "is_calibrated") and not self.robot.is_calibrated: - error_msg = ( - f"Robot {self.robot} is not calibrated. Please calibrate the robot manually" - " first using LeRobot's calibration process (lerobot-calibrate)" + import serial.tools.list_ports + + ports = list(serial.tools.list_ports.comports()) + servo_keywords = ["feetech", "dynamixel", "sts3215", "xl430", "xl330"] + exclude = ["bluetooth", "internal", "debug", "apple", "modem"] + robot_ports = [ + p + for p in ports + if any( + kw in ((p.description or "") + (getattr(p, "manufacturer", None) or "")).lower() + for kw in servo_keywords ) - logger.error(f"❌ {error_msg}") - return False, error_msg - - logger.info(f"✅ {self.robot} connected and ready") - return True, "" - - except Exception as e: - error_msg = f"Robot connection failed: {e}. Ensure robot is calibrated and accessible on the specified port" - logger.error(f"❌ {error_msg}") - return False, error_msg - - async def _initialize_policy(self, policy: Policy) -> bool: - """Initialize policy with robot state keys.""" - try: - # Get robot state keys from observation - test_obs = await asyncio.to_thread(self.robot.get_observation) - - # Filter out camera keys to get robot state keys - camera_keys = [] - if hasattr(self.robot, "config") and hasattr(self.robot.config, "cameras"): - camera_keys = list(self.robot.config.cameras.keys()) - - robot_state_keys = [k for k in test_obs.keys() if k not in camera_keys] - - # Set robot state keys in policy - policy.set_robot_state_keys(robot_state_keys) - return True - - except Exception as e: - logger.error(f"❌ Failed to initialize policy: {e}") - return False - - async def _execute_task_async( - self, - instruction: str, - policy_port: int | None = None, - policy_host: str = "localhost", - policy_provider: str = "groot", - duration: float = 30.0, - ) -> None: - """Execute robot task in background thread (internal method).""" - try: - # Update task state - self._task_state.status = TaskStatus.CONNECTING - self._task_state.instruction = instruction - self._task_state.start_time = time.time() - self._task_state.step_count = 0 - self._task_state.error_message = "" - - # Connect to robot - connected, connect_error = await self._connect_robot() - if not connected: - self._task_state.status = TaskStatus.ERROR - self._task_state.error_message = connect_error or f"Failed to connect to {self.tool_name_str}" - return - - # Get policy instance - policy_instance = await self._get_policy(policy_port, policy_host, policy_provider) - - # Initialize policy with robot state keys - if not await self._initialize_policy(policy_instance): - self._task_state.status = TaskStatus.ERROR - self._task_state.error_message = "Failed to initialize policy" - return - - logger.info(f"🎯 Starting task: '{instruction}' on {self.tool_name_str}") - logger.info(f"🧠 Using policy: {policy_provider} on {policy_host}:{policy_port}") - - self._task_state.status = TaskStatus.RUNNING - start_time = time.time() - - while ( - time.time() - start_time < duration - and self._task_state.status == TaskStatus.RUNNING - and not self._shutdown_event.is_set() - ): - # Get observation from robot - observation = await asyncio.to_thread(self.robot.get_observation) - - # Get actions from policy - robot_actions = await policy_instance.get_actions(observation, instruction) - - # Execute actions from chunk with proper timing control - # Wait between actions for smooth execution - for action_dict in robot_actions[: self.action_horizon]: - if self._task_state.status != TaskStatus.RUNNING: - break - await asyncio.to_thread(self.robot.send_action, action_dict) - self._task_state.step_count += 1 - # Wait for action to complete before sending next action - # Default 50Hz (0.02s) - await asyncio.sleep(self.action_sleep_time) - - # Update final state - elapsed = time.time() - start_time - self._task_state.duration = elapsed - - if self._task_state.status == TaskStatus.RUNNING: - self._task_state.status = TaskStatus.COMPLETED + and not any(s in (p.description or "").lower() for s in exclude) + ] + if robot_ports: logger.info( - f"✅ Task completed: '{instruction}' in {elapsed:.1f}s ({self._task_state.step_count} steps)" + "Auto-detected robot hardware: %s", + [p.device for p in robot_ports], ) + return "real" + except (ImportError, OSError): # USB probing may fail with OSError on permission/device issues + pass + + return "sim" + + +def Robot( + name: str, + mode: str = "sim", + backend: str = "mujoco", + urdf_path: str | None = None, + cameras: dict[str, dict[str, Any]] | None = None, + position: list[float] | None = None, + **kwargs: Any, +) -> Any: + """Create a robot — returns a Simulation or HardwareRobot instance. + + This is a convenience factory, NOT a wrapper class. You get the real + backend instance back — with full access to all its methods. + + Defaults to simulation mode so that ``Robot("so100")`` never + accidentally sends commands to physical hardware. Use + ``mode="real"`` to explicitly opt into hardware control. + + Args: + name: Robot name ("so100", "aloha", "unitree_g1", "panda", ...) + Accepts any alias defined in ``registry/robots.json``. + mode: "sim" (default — safe), "real" (explicit hardware), or + "auto" (probes USB for servo controllers, falls back to sim). + backend: Simulation backend — currently only "mujoco" (CPU). + Future: "isaac" (GPU), "newton" (GPU). + urdf_path: Explicit path to URDF/MJCF file. If not provided, + resolved via ``strands_robots.simulation.model_registry`` + (asset manager or ``STRANDS_ASSETS_DIR`` search paths). + cameras: Camera config for real hardware. Example:: + + {"wrist": {"type": "opencv", "index_or_path": "/dev/video0", "fps": 30}} + + position: Robot position in sim world [x, y, z]. + **kwargs: Forwarded to the underlying backend constructor. + + Returns: + ``strands_robots.simulation.Simulation`` (sim) or + ``strands_robots.hardware_robot.Robot`` (real hardware). + + Raises: + RuntimeError: If the sim world or robot fails to initialize. + NotImplementedError: If an unimplemented backend is requested. + + Examples:: + + # Simulation (default — safe) + sim = Robot("so100") + + # Explicit MJCF model path + sim = Robot("my_arm", urdf_path="path/to/robot.xml") + + # Real hardware (explicit opt-in) + hw = Robot("so100", mode="real", cameras={...}) + + # Auto-detect (probes USB, falls back to sim) + robot = Robot("so100", mode="auto") + + # The 5-line promise + from strands_robots import Robot + from strands import Agent + robot = Robot("so100") + agent = Agent(tools=[robot]) + agent("Pick up the red cube") + """ + canonical = resolve_name(name) + + if mode == "auto": + mode = _auto_detect_mode(canonical) + + # ── Simulation ── + if mode == "sim": + if backend != "mujoco": + raise NotImplementedError( + f"Backend {backend!r} is not yet implemented. " + f"Currently supported: 'mujoco'. " + f"Isaac and Newton backends are on the roadmap." + ) - except Exception as e: - logger.error(f"❌ Task execution failed: {e}") - self._task_state.status = TaskStatus.ERROR - self._task_state.error_message = str(e) - - def _execute_task_sync( - self, - instruction: str, - policy_port: int | None = None, - policy_host: str = "localhost", - policy_provider: str = "groot", - duration: float = 30.0, - ) -> dict[str, Any]: - """Execute task synchronously in thread - no new event loop.""" - - # Import here to avoid conflicts - import asyncio - - # Run task without creating new event loop - let it run in thread - async def task_runner(): - await self._execute_task_async(instruction, policy_port, policy_host, policy_provider, duration) - - # Use asyncio.run only if no loop is running, otherwise run in existing loop - try: - # Try to get the current event loop - asyncio.get_running_loop() - # If we're already in an event loop, we need to run in a thread - import concurrent.futures - - with concurrent.futures.ThreadPoolExecutor() as exec: - future = exec.submit(lambda: asyncio.run(task_runner())) - future.result() # Wait for completion - except RuntimeError: - # No event loop running - safe to create one - asyncio.run(task_runner()) - - # Return final status - return { - "status": "success" if self._task_state.status == TaskStatus.COMPLETED else "error", - "content": [ - { - "text": f"✅ Task: '{instruction}' - {self._task_state.status.value}\n" - f"🤖 Robot: {self.tool_name_str} ({self.robot})\n" - f"🧠 Policy: {policy_provider} on {policy_host}:{policy_port}\n" - f"⏱️ Duration: {self._task_state.duration:.1f}s\n" - f"🎯 Steps: {self._task_state.step_count}" - + (f"\n❌ Error: {self._task_state.error_message}" if self._task_state.error_message else "") - } - ], - } + from strands_robots.simulation import Simulation - def start_task( - self, - instruction: str, - policy_port: int | None = None, - policy_host: str = "localhost", - policy_provider: str = "groot", - duration: float = 30.0, - ) -> dict[str, Any]: - """Start robot task asynchronously and return immediately.""" - - # Check if task is already running - if self._task_state.status == TaskStatus.RUNNING: - return { - "status": "error", - "content": [{"text": f"❌ Task already running: {self._task_state.instruction}"}], - } - - # Start task in background - self._task_state.task_future = self._executor.submit( - self._execute_task_sync, instruction, policy_port, policy_host, policy_provider, duration + sim = Simulation( + tool_name=f"{canonical}_sim", + **kwargs, ) + sim._dispatch_action("create_world", {}) - return { - "status": "success", - "content": [ - { - "text": f"🚀 Task started: '{instruction}'\n" - f"🤖 Robot: {self.tool_name_str}\n" - f"💡 Use action='status' to check progress\n" - f"💡 Use action='stop' to interrupt" - } - ], + add_robot_params: dict[str, Any] = { + "robot_name": canonical, + "data_config": canonical, + "position": position or [0.0, 0.0, 0.0], } + if urdf_path: + add_robot_params["urdf_path"] = urdf_path + + result = sim._dispatch_action("add_robot", add_robot_params) + if result.get("status") == "error": + sim.destroy() # Clean up partial initialization (executor, temp dir, MuJoCo world) + content = result.get("content", []) + msg = content[0].get("text", str(result)) if content else str(result) + raise RuntimeError(f"Failed to create sim robot '{canonical}': {msg}") + return sim + + # ── Real hardware (explicit opt-in) ── + elif mode == "real": + from strands_robots.hardware_robot import Robot as HardwareRobot + + real_type = get_hardware_type(canonical) or canonical + return HardwareRobot( + tool_name=canonical, + robot=real_type, + cameras=cameras, + **kwargs, + ) - def get_task_status(self) -> dict[str, Any]: - """Get current task execution status.""" - - # Update duration for running tasks - if self._task_state.status == TaskStatus.RUNNING: - self._task_state.duration = time.time() - self._task_state.start_time - - status_text = f"📊 Robot Status: {self._task_state.status.value.upper()}\n" - - if self._task_state.instruction: - status_text += f"🎯 Task: {self._task_state.instruction}\n" - - if self._task_state.status == TaskStatus.RUNNING: - status_text += f"⏱️ Duration: {self._task_state.duration:.1f}s\n" - status_text += f"🔄 Steps: {self._task_state.step_count}\n" - elif self._task_state.status in [TaskStatus.COMPLETED, TaskStatus.STOPPED, TaskStatus.ERROR]: - status_text += f"⏱️ Total Duration: {self._task_state.duration:.1f}s\n" - status_text += f"🎯 Total Steps: {self._task_state.step_count}\n" - - if self._task_state.error_message: - status_text += f"❌ Error: {self._task_state.error_message}\n" - - return { - "status": "success", - "content": [{"text": status_text}], - } - - def stop_task(self) -> dict[str, Any]: - """Stop currently running task.""" - - if self._task_state.status != TaskStatus.RUNNING: - return { - "status": "success", - "content": [{"text": f"💤 No task running to stop (current: {self._task_state.status.value})"}], - } - - # Signal task to stop - self._task_state.status = TaskStatus.STOPPED - - # Cancel future if it exists - if self._task_state.task_future: - self._task_state.task_future.cancel() - - logger.info(f"🛑 Task stopped: {self._task_state.instruction}") - - return { - "status": "success", - "content": [ - { - "text": f"🛑 Task stopped: '{self._task_state.instruction}'\n" - f"⏱️ Duration: {self._task_state.duration:.1f}s\n" - f"🎯 Steps completed: {self._task_state.step_count}" - } - ], - } - - @property - def tool_name(self) -> str: - return self.tool_name_str - - @property - def tool_type(self) -> str: - return "robot" - - @property - def tool_spec(self) -> ToolSpec: - """Get tool specification with async actions.""" - return { - "name": self.tool_name_str, - "description": f"Universal robot control with async task execution ({self.robot}). " - f"Actions: execute (blocking), start (async), status, stop. " - f"For execute/start actions: instruction and policy_port are required. " - f"For status/stop actions: no additional parameters needed.", - "inputSchema": { - "json": { - "type": "object", - "properties": { - "action": { - "type": "string", - "description": "Action to perform: execute (blocking), start (async), status, stop", - "enum": ["execute", "start", "status", "stop"], - "default": "execute", - }, - "instruction": { - "type": "string", - "description": "Natural language instruction (required for execute/start actions)", - }, - "policy_port": { - "type": "integer", - "description": "Policy service port (required for execute/start actions)", - }, - "policy_host": { - "type": "string", - "description": "Policy service host (default: localhost)", - "default": "localhost", - }, - "policy_provider": { - "type": "string", - "description": "Policy provider (groot, openai, etc.)", - "default": "groot", - }, - "duration": { - "type": "number", - "description": "Maximum execution time in seconds", - "default": 30.0, - }, - }, - "required": ["action"], - } - }, - } - - @staticmethod - def _make_tool_result(tool_use_id: str, result: dict[str, Any]) -> ToolResult: - """Create a ToolResult dict with the given tool_use_id merged into result.""" - return cast(ToolResult, {"toolUseId": tool_use_id, **result}) - - async def stream( - self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any - ) -> AsyncGenerator[ToolResultEvent, None]: - """Stream robot task execution with async actions.""" - try: - tool_use_id = tool_use.get("toolUseId", "") - input_data = tool_use.get("input", {}) - - action = input_data.get("action", "execute") - - # Handle different actions - if action == "execute": - # Blocking execution (legacy behavior) - instruction = input_data.get("instruction", "") - policy_port = input_data.get("policy_port") - policy_host = input_data.get("policy_host", "localhost") - policy_provider = input_data.get("policy_provider", "groot") - duration = input_data.get("duration", 30.0) - - if not instruction or not policy_port: - yield ToolResultEvent( - self._make_tool_result( - tool_use_id, - { - "status": "error", - "content": [{"text": "❌ instruction and policy_port are required for execute action"}], - }, - ) - ) - return - - # Execute task synchronously - task_result = self._execute_task_sync(instruction, policy_port, policy_host, policy_provider, duration) - yield ToolResultEvent(self._make_tool_result(tool_use_id, task_result)) - - elif action == "start": - # Asynchronous execution start - instruction = input_data.get("instruction", "") - policy_port = input_data.get("policy_port") - policy_host = input_data.get("policy_host", "localhost") - policy_provider = input_data.get("policy_provider", "groot") - duration = input_data.get("duration", 30.0) - - if not instruction or not policy_port: - yield ToolResultEvent( - self._make_tool_result( - tool_use_id, - { - "status": "error", - "content": [{"text": "❌ instruction and policy_port are required for start action"}], - }, - ) - ) - return - - # Start task asynchronously - start_result = self.start_task(instruction, policy_port, policy_host, policy_provider, duration) - yield ToolResultEvent(self._make_tool_result(tool_use_id, start_result)) - - elif action == "status": - # Get current task status - status_result = self.get_task_status() - yield ToolResultEvent(self._make_tool_result(tool_use_id, status_result)) - - elif action == "stop": - # Stop current task - stop_result = self.stop_task() - yield ToolResultEvent(self._make_tool_result(tool_use_id, stop_result)) - - else: - yield ToolResultEvent( - self._make_tool_result( - tool_use_id, - { - "status": "error", - "content": [ - {"text": f"❌ Unknown action: {action}. Valid actions: execute, start, status, stop"} - ], - }, - ) - ) - - except Exception as e: - logger.error(f"❌ {self.tool_name_str} error: {e}") - yield ToolResultEvent( - self._make_tool_result( - tool_use_id, - { - "status": "error", - "content": [{"text": f"❌ {self.tool_name_str} error: {str(e)}"}], - }, - ) - ) - - def cleanup(self): - """Cleanup resources and stop any running tasks.""" - try: - # Signal shutdown - self._shutdown_event.set() - - # Stop any running task - if self._task_state.status == TaskStatus.RUNNING: - self.stop_task() - - # Shutdown executor - self._executor.shutdown(wait=True) - - logger.info(f"🧹 {self.tool_name_str} cleanup completed") - - except Exception as e: - logger.error(f"❌ Cleanup error for {self.tool_name_str}: {e}") - - def __del__(self): - """Destructor to ensure cleanup.""" - try: - self.cleanup() - except Exception: - pass # Ignore errors in destructor - - async def get_status(self) -> dict[str, Any]: - """Get robot status including connection and task state.""" - try: - # Get robot connection status - is_connected = self.robot.is_connected if hasattr(self.robot, "is_connected") else False - is_calibrated = self.robot.is_calibrated if hasattr(self.robot, "is_calibrated") else True - - # Get camera status - camera_status = [] - if hasattr(self.robot, "config") and hasattr(self.robot.config, "cameras"): - for name in self.robot.config.cameras.keys(): - camera_status.append(name) - - # Build status dict - status_data = { - "robot_name": self.tool_name_str, - "robot_type": getattr(self.robot, "robot_type", self.robot.name), - "robot_info": str(self.robot), - "data_config": self.data_config, - "is_connected": is_connected, - "is_calibrated": is_calibrated, - "cameras": camera_status, - "task_status": self._task_state.status.value, - "current_instruction": self._task_state.instruction, - "task_duration": self._task_state.duration, - "task_steps": self._task_state.step_count, - } - - # Add error info if present - if self._task_state.error_message: - status_data["task_error"] = self._task_state.error_message - - return status_data - - except Exception as e: - logger.error(f"❌ Error getting status for {self.tool_name_str}: {e}") - return { - "robot_name": self.tool_name_str, - "error": str(e), - "is_connected": False, - "task_status": "error", - } - - async def stop(self): - """Stop robot and disconnect.""" - try: - # Stop any running task first - if self._task_state.status == TaskStatus.RUNNING: - self.stop_task() - - # Disconnect robot hardware - if hasattr(self.robot, "disconnect"): - await asyncio.to_thread(self.robot.disconnect) - - # Cleanup resources - self.cleanup() + else: + raise ValueError(f"Invalid mode {mode!r}. Choose 'sim', 'real', or 'auto'.") - logger.info(f"🛑 {self.tool_name_str} stopped and disconnected") - except Exception as e: - logger.error(f"❌ Error stopping robot: {e}") +__all__ = ["Robot"] diff --git a/tests/test_factory.py b/tests/test_robot_factory.py similarity index 89% rename from tests/test_factory.py rename to tests/test_robot_factory.py index 17d8b11..5f76901 100644 --- a/tests/test_factory.py +++ b/tests/test_robot_factory.py @@ -1,16 +1,16 @@ -"""Tests for strands_robots.factory — Robot(), list_robots().""" +"""Tests for strands_robots.robot — Robot() factory and list_robots().""" import os import pytest -from strands_robots.factory import Robot, _auto_detect_mode, list_robots from strands_robots.registry import ( get_robot, list_aliases, + list_robots, resolve_name, ) -from strands_robots.registry import list_robots as registry_list_robots +from strands_robots.robot import Robot, _auto_detect_mode class TestResolveNames: @@ -78,11 +78,11 @@ def test_all_aliases_point_to_valid_robots(self): def test_robot_count(self): """Ensure we have a reasonable number of robots.""" - robots = registry_list_robots() + robots = list_robots() assert len(robots) >= 30 def test_all_robots_have_description(self): - robots = registry_list_robots() + robots = list_robots() for r in robots: assert "description" in r, f"Robot '{r['name']}' missing description" assert len(r["description"]) > 0 @@ -110,9 +110,8 @@ def test_env_override_sim(self): def test_env_override_case_insensitive(self): os.environ["STRANDS_ROBOT_MODE"] = "REAL" try: - # .lower() normalizes to "real" — should match mode = _auto_detect_mode("so100") - assert mode == "real" # .lower() normalizes REAL → real + assert mode == "real" finally: del os.environ["STRANDS_ROBOT_MODE"] @@ -125,6 +124,13 @@ def test_robot_is_callable(self): assert callable(Robot) assert not inspect.isclass(Robot) + def test_default_mode_is_sim(self): + """Robot() defaults to sim mode — never accidentally sends to hardware.""" + import inspect + + sig = inspect.signature(Robot) + assert sig.parameters["mode"].default == "sim" + def test_unknown_backend_raises(self): with pytest.raises(NotImplementedError, match="not yet implemented"): Robot("so100", mode="sim", backend="isaac") @@ -133,9 +139,13 @@ def test_newton_not_implemented(self): with pytest.raises(NotImplementedError, match="not yet implemented"): Robot("so100", mode="sim", backend="newton") + def test_invalid_mode_raises(self): + with pytest.raises(ValueError, match="Invalid mode"): + Robot("so100", mode="invalid") + def test_sim_with_urdf_path(self): """Robot() with explicit urdf_path should work (if file exists).""" - # We don't have a real URDF here, but verify the param is accepted + pytest.importorskip("mujoco") with pytest.raises(RuntimeError): Robot("test_bot", mode="sim", urdf_path="/nonexistent/robot.xml") @@ -146,7 +156,6 @@ def test_sim_happy_path_mujoco(self, tmp_path): """ mujoco = pytest.importorskip("mujoco") - # Minimal valid MJCF that MuJoCo can load — a one-joint arm mjcf_xml = """ @@ -170,11 +179,9 @@ def test_sim_happy_path_mujoco(self, tmp_path): sim = Robot("so100", mode="sim", backend="mujoco", urdf_path=str(mjcf_path)) try: - # Verify it's a working simulation instance assert sim._world is not None assert sim._world._model is not None assert sim._world._data is not None - # Step physics once to verify the engine works mujoco.mj_step(sim._world._model, sim._world._data) assert sim._world._data.time > 0 finally: @@ -186,4 +193,4 @@ def test_import_from_top_level(self): from strands_robots import list_robots as lr assert R is Robot - assert lr is list_robots + assert callable(lr) From ac5533179ae62a5b379c10e5968e323b68929200 Mon Sep 17 00:00:00 2001 From: cagataycali Date: Thu, 16 Apr 2026 00:36:31 +0000 Subject: [PATCH 21/22] docs: address 3 unresolved review threads from @awsarron on PR #86 - Add Environment Variables table to README documenting all 6 env vars used across the project (STRANDS_ROBOT_MODE, STRANDS_ASSETS_DIR, STRANDS_URDF_DIR, STRANDS_TRUST_REMOTE_CODE, GROOT_API_TOKEN, MUJOCO_GL) plus cache directory documentation - Add module-level docstring to dataset_recorder.py explaining why it lives at package root (shared by both hardware and simulation paths, avoids circular dependency) - Add docstring to load_lerobot_episode() documenting that it is consumed by simulation.mujoco.policy_runner for replay_episode --- README.md | 28 ++++++++++------------------ strands_robots/dataset_recorder.py | 10 ++++++++++ 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 0a93a93..a299d01 100644 --- a/README.md +++ b/README.md @@ -486,30 +486,22 @@ while True: agent.tool.gr00t_inference(action="stop", port=8000) ``` -## Configuration - -### Environment Variables +## Environment Variables | Variable | Description | Default | |----------|-------------|---------| -| `STRANDS_ASSETS_DIR` | Custom directory for robot model assets (MJCF, meshes) | `~/.strands_robots/assets/` | -| `GROOT_API_TOKEN` | API token for GR00T inference service | — | - -### Cache Directory +| `STRANDS_ROBOT_MODE` | Override auto-detection mode (`sim` or `real`) | (auto-detect) | +| `STRANDS_ASSETS_DIR` | Custom directory for robot model assets (URDF/MJCF/meshes) | `~/.strands_robots/assets/` | +| `STRANDS_URDF_DIR` | **Deprecated** — use `STRANDS_ASSETS_DIR` instead | — | +| `STRANDS_TRUST_REMOTE_CODE` | Set to `1` to allow loading remote HuggingFace policies | (disabled) | +| `GROOT_API_TOKEN` | API token for NVIDIA GR00T cloud inference | — | +| `MUJOCO_GL` | OpenGL backend for MuJoCo rendering (`egl`, `osmesa`, `glfw`) | Auto-configured on headless Linux | -Robot model assets (MJCF XML files and meshes) are cached in: +**Cache directory:** Robot model assets (URDF, MJCF, meshes) are downloaded on first use to `~/.strands_robots/assets/`. Override with `STRANDS_ASSETS_DIR`. To clear cached assets: +```bash +rm -rf ~/.strands_robots/assets/ ``` -~/.strands_robots/ -└── assets/ # Downloaded robot models (from robot_descriptions / MuJoCo Menagerie) - ├── trs_so_arm100/ - ├── franka_emika_panda/ - └── ... -``` - -To clear the cache: `rm -rf ~/.strands_robots/assets/` - -To change the cache location: `export STRANDS_ASSETS_DIR=/path/to/custom/dir` ## Contributing diff --git a/strands_robots/dataset_recorder.py b/strands_robots/dataset_recorder.py index f6155e8..39039b7 100644 --- a/strands_robots/dataset_recorder.py +++ b/strands_robots/dataset_recorder.py @@ -4,6 +4,12 @@ strands_robots.simulation (MuJoCo) can produce training-ready datasets with a single add_frame() call per control step. +Why top-level (strands_robots.dataset_recorder) and not under simulation/? + Both hardware and simulation code paths need to record datasets. Placing + this module at the package root avoids a circular dependency + (simulation -> dataset_recorder -> simulation) and keeps hardware_robot + from reaching into the simulation sub-package. + Usage: recorder = DatasetRecorder.create( repo_id="user/my_dataset", @@ -457,6 +463,10 @@ def __repr__(self) -> str: def load_lerobot_episode(repo_id: str, episode: int = 0, root: str | None = None): """Load a LeRobotDataset and resolve the frame range for an episode. + Used by strands_robots.simulation.mujoco.policy_runner for the + replay_episode action — the simulation backend calls this to load + recorded joint trajectories and replay them in MuJoCo. + Returns: Tuple of (dataset, episode_start, episode_length) on success. From 0917c245e4d5bff27fa075d74abc8b5dcab4e1e5 Mon Sep 17 00:00:00 2001 From: cagataycali Date: Wed, 22 Apr 2026 15:23:44 -0400 Subject: [PATCH 22/22] fix(hardware_robot): add missing type annotations for mypy strict Post-rebase on top of clean PR-85 (which enforces zero mypy errors), hardware_robot.py needed the same strictness: - __init__, cleanup, __del__, stop: add -> None - task_runner (nested): add -> None - **kwargs: add : Any annotation (3 methods) Result: 77/77 source files mypy clean. --- strands_robots/hardware_robot.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/strands_robots/hardware_robot.py b/strands_robots/hardware_robot.py index 33c5ba7..d36f83f 100644 --- a/strands_robots/hardware_robot.py +++ b/strands_robots/hardware_robot.py @@ -73,8 +73,8 @@ def __init__( action_horizon: int = 8, data_config: str | Any | None = None, control_frequency: float = 50.0, - **kwargs, - ): + **kwargs: Any, + ) -> None: """Initialize Robot with async capabilities. Args: @@ -116,7 +116,7 @@ def __init__( logger.info(f"⚙️ Data config: {data_config}") def _initialize_robot( - self, robot: LeRobotRobot | RobotConfig | str, cameras: dict[str, dict[str, Any]] | None, **kwargs + self, robot: LeRobotRobot | RobotConfig | str, cameras: dict[str, dict[str, Any]] | None, **kwargs: Any ) -> LeRobotRobot: """Initialize LeRobot robot instance using native lerobot patterns.""" from lerobot.robots.config import RobotConfig @@ -143,7 +143,7 @@ def _initialize_robot( ) def _create_minimal_config( - self, robot_type: str, cameras: dict[str, dict[str, Any]] | None, **kwargs + self, robot_type: str, cameras: dict[str, dict[str, Any]] | None, **kwargs: Any ) -> RobotConfig: """Create minimal robot config using specific robot config classes.""" from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig @@ -393,7 +393,7 @@ def _execute_task_sync( import asyncio # Run task without creating new event loop - let it run in thread - async def task_runner(): + async def task_runner() -> None: await self._execute_task_async(instruction, policy_port, policy_host, policy_provider, duration) # Use asyncio.run only if no loop is running, otherwise run in existing loop @@ -670,7 +670,7 @@ async def stream( ) ) - def cleanup(self): + def cleanup(self) -> None: """Cleanup resources and stop any running tasks.""" try: # Signal shutdown @@ -688,7 +688,7 @@ def cleanup(self): except Exception as e: logger.error(f"❌ Cleanup error for {self.tool_name_str}: {e}") - def __del__(self): + def __del__(self) -> None: """Destructor to ensure cleanup.""" try: self.cleanup() @@ -738,7 +738,7 @@ async def get_status(self) -> dict[str, Any]: "task_status": "error", } - async def stop(self): + async def stop(self) -> None: """Stop robot and disconnect.""" try: # Stop any running task first