diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml index 79b15d9..b171e27 100644 --- a/.github/workflows/test-lint.yml +++ b/.github/workflows/test-lint.yml @@ -26,6 +26,11 @@ jobs: python-version: '3.12' cache: 'pip' + - name: Install system dependencies (OpenGL for MuJoCo) + run: | + sudo apt-get update + sudo apt-get install -y libosmesa6-dev + - name: Install dependencies run: | pip install --no-cache-dir hatch @@ -35,4 +40,6 @@ jobs: run: hatch run lint - name: Run tests + env: + MUJOCO_GL: osmesa run: hatch run test -x --strict-markers diff --git a/README.md b/README.md index 0a93a93..a299d01 100644 --- a/README.md +++ b/README.md @@ -486,30 +486,22 @@ while True: agent.tool.gr00t_inference(action="stop", port=8000) ``` -## Configuration - -### Environment Variables +## Environment Variables | Variable | Description | Default | |----------|-------------|---------| -| `STRANDS_ASSETS_DIR` | Custom directory for robot model assets (MJCF, meshes) | `~/.strands_robots/assets/` | -| `GROOT_API_TOKEN` | API token for GR00T inference service | — | - -### Cache Directory +| `STRANDS_ROBOT_MODE` | Override auto-detection mode (`sim` or `real`) | (auto-detect) | +| `STRANDS_ASSETS_DIR` | Custom directory for robot model assets (URDF/MJCF/meshes) | `~/.strands_robots/assets/` | +| `STRANDS_URDF_DIR` | **Deprecated** — use `STRANDS_ASSETS_DIR` instead | — | +| `STRANDS_TRUST_REMOTE_CODE` | Set to `1` to allow loading remote HuggingFace policies | (disabled) | +| `GROOT_API_TOKEN` | API token for NVIDIA GR00T cloud inference | — | +| `MUJOCO_GL` | OpenGL backend for MuJoCo rendering (`egl`, `osmesa`, `glfw`) | Auto-configured on headless Linux | -Robot model assets (MJCF XML files and meshes) are cached in: +**Cache directory:** Robot model assets (URDF, MJCF, meshes) are downloaded on first use to `~/.strands_robots/assets/`. Override with `STRANDS_ASSETS_DIR`. 
To clear cached assets: +```bash +rm -rf ~/.strands_robots/assets/ ``` -~/.strands_robots/ -└── assets/ # Downloaded robot models (from robot_descriptions / MuJoCo Menagerie) - ├── trs_so_arm100/ - ├── franka_emika_panda/ - └── ... -``` - -To clear the cache: `rm -rf ~/.strands_robots/assets/` - -To change the cache location: `export STRANDS_ASSETS_DIR=/path/to/custom/dir` ## Contributing diff --git a/pyproject.toml b/pyproject.toml index f1a7090..d9ae441 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,13 +48,16 @@ groot-service = [ lerobot = [ "lerobot>=0.5.0,<0.6.0", ] -sim = [ +sim-mujoco = [ "robot_descriptions>=1.11.0,<2.0.0", + "mujoco>=3.0.0,<4.0.0", + "imageio>=2.28.0,<3.0.0", + "imageio-ffmpeg>=0.4.0,<1.0.0", ] all = [ "strands-robots[groot-service]", "strands-robots[lerobot]", - "strands-robots[sim]", + "strands-robots[sim-mujoco]", ] dev = [ "pytest>=6.0,<9.0.0", @@ -128,7 +131,7 @@ ignore_missing_imports = false # Third-party libs without type stubs [[tool.mypy.overrides]] -module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*", "robot_descriptions.*"] +module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*", "robot_descriptions.*", "mujoco.*", "imageio.*"] ignore_missing_imports = true # @tool decorator injects runtime signatures mypy cannot check @@ -161,6 +164,22 @@ module = ["strands_robots.registry.*"] warn_return_any = false disallow_untyped_defs = false +# MuJoCo simulation — mixins use cooperative self._world patterns +# attr-defined: Mixins access self._world/self._lock/etc. 
from Simulation (cooperative pattern) +# assignment: PEP 484 implicit Optional (= None on typed params) +# override: Subclass signatures extend base with extra params (orientation, mesh_path) +# misc: Multiple inheritance method resolution conflicts between mixin + ABC +[[tool.mypy.overrides]] +module = ["strands_robots.simulation.mujoco.*"] +disallow_untyped_defs = false +warn_return_any = false + +# Async utils and dataset recorder — thin wrappers with dynamic types +[[tool.mypy.overrides]] +module = ["strands_robots._async_utils", "strands_robots.dataset_recorder"] +disallow_untyped_defs = false +warn_return_any = false + # Test files — relaxed type checking for mocks, fixtures, and test utilities [[tool.mypy.overrides]] module = ["tests.*", "tests_integ.*"] diff --git a/strands_robots/__init__.py b/strands_robots/__init__.py index 8ee9c41..40dc6c2 100644 --- a/strands_robots/__init__.py +++ b/strands_robots/__init__.py @@ -11,11 +11,12 @@ - Clean separation between robot control and policy inference - Direct policy injection for maximum flexibility - Multi-camera support with rich configuration options +- MuJoCo simulation backend (no GPU required) Lazy Loading: - Heavy imports (Robot, tools, Gr00tPolicy) are deferred until first access. - Heavy imports are deferred so ``import strands_robots`` stays fast when lerobot/torch - are installed but not yet needed. + Heavy imports (Robot, tools, Gr00tPolicy, Simulation) are deferred until + first access. Heavy imports are deferred so ``import strands_robots`` stays + fast when lerobot/torch/mujoco are installed but not yet needed. Light-weight symbols (Policy, MockPolicy, create_policy) are available immediately since they don't pull in torch/lerobot. 
@@ -26,7 +27,7 @@ from typing import Any # ------------------------------------------------------------------ -# Light-weight imports — no torch / lerobot dependency +# Light-weight imports — no torch / lerobot / mujoco dependency # ------------------------------------------------------------------ from strands_robots.policies import MockPolicy, Policy, create_policy # noqa: F401 @@ -35,8 +36,21 @@ # ------------------------------------------------------------------ # Maps public name -> (module_path, attribute_name) _LAZY_IMPORTS: dict[str, tuple[str, str]] = { + # Hardware robot "Robot": ("strands_robots.robot", "Robot"), + "list_robots": ("strands_robots.registry", "list_robots"), + # Policies "Gr00tPolicy": ("strands_robots.policies.groot", "Gr00tPolicy"), + # Simulation (MuJoCo) + "Simulation": ("strands_robots.simulation", "Simulation"), + "create_simulation": ("strands_robots.simulation.factory", "create_simulation"), + "list_backends": ("strands_robots.simulation.factory", "list_backends"), + "register_backend": ("strands_robots.simulation.factory", "register_backend"), + "SimWorld": ("strands_robots.simulation", "SimWorld"), + "SimRobot": ("strands_robots.simulation", "SimRobot"), + "SimObject": ("strands_robots.simulation", "SimObject"), + "SimCamera": ("strands_robots.simulation", "SimCamera"), + # Tools "gr00t_inference": ("strands_robots.tools.gr00t_inference", "gr00t_inference"), "lerobot_calibrate": ("strands_robots.tools.lerobot_calibrate", "lerobot_calibrate"), "lerobot_camera": ("strands_robots.tools.lerobot_camera", "lerobot_camera"), @@ -53,6 +67,11 @@ # Lazy-loaded "Robot", "Gr00tPolicy", + "Simulation", + "SimWorld", + "SimRobot", + "SimObject", + "SimCamera", "gr00t_inference", "lerobot_camera", "lerobot_teleoperate", @@ -62,12 +81,32 @@ ] +# Auto-configure MuJoCo GL backend for headless environments BEFORE any +# module imports mujoco at the top level. MuJoCo locks the OpenGL backend +# at import time, so MUJOCO_GL must be set first. 
+# +# WHY EAGER: This MUST run at module import time, not lazily, because: +# 1. MuJoCo reads MUJOCO_GL only on first `import mujoco` +# 2. Any downstream code doing `from strands_robots.simulation import ...` +# triggers mujoco import via the lazy-load chain +# 3. If we defer to first use, the env var would be set too late +# This is the canonical location — strands_robots/simulation/__init__.py +# intentionally does NOT duplicate this call. +try: + from strands_robots.simulation.mujoco.backend import _configure_gl_backend + + _configure_gl_backend() +except (ImportError, AttributeError, OSError): + pass + + def __getattr__(name: str) -> Any: # noqa: N807 """Lazy-load heavy modules on first attribute access. - This avoids importing torch, lerobot, numpy, pyserial, etc. at + This avoids importing torch, lerobot, numpy, mujoco, pyserial, etc. at ``import strands_robots`` time. The first access to e.g. - ``strands_robots.Robot`` triggers the real import. + ``strands_robots.Robot`` or ``strands_robots.Simulation`` triggers the + real import. """ if name in _LAZY_IMPORTS: module_path, attr_name = _LAZY_IMPORTS[name] diff --git a/strands_robots/_async_utils.py b/strands_robots/_async_utils.py new file mode 100644 index 0000000..dc9ac8d --- /dev/null +++ b/strands_robots/_async_utils.py @@ -0,0 +1,36 @@ +"""Async-to-sync helper for resolving coroutines in sync contexts.""" + +import asyncio +import atexit +from concurrent.futures import ThreadPoolExecutor + +# Module-level executor reused across calls to avoid creating threads at high frequency. +# A single worker is sufficient — we only need to offload one asyncio.run() at a time. +_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="strands_async") + +# Ensure the executor is shut down cleanly on interpreter exit to avoid +# ResourceWarning and orphaned threads. 
def _resolve_coroutine(coro_or_result):  # type: ignore[no-untyped-def]
    """Safely resolve a potentially-async result to a sync value.

    Handles three cases:
      1. Already a plain value -> returned unchanged.
      2. Coroutine, no event loop running in this thread -> ``asyncio.run()``.
      3. Coroutine, called from inside a running loop -> run to completion on
         the shared worker thread, because ``asyncio.run()`` refuses to nest
         inside a running loop.

    Args:
        coro_or_result: Either a coroutine or an already-resolved value.

    Returns:
        The resolved (sync) value.

    Raises:
        Exception: Whatever the coroutine itself raises is propagated
            unchanged (including ``RuntimeError``).
    """
    if not asyncio.iscoroutine(coro_or_result):
        return coro_or_result

    # Keep ONLY the loop probe inside the ``try``. If the executor call were
    # guarded by the same ``except RuntimeError``, a RuntimeError raised by the
    # coroutine itself would be mistaken for "no running loop" and the
    # already-consumed coroutine would be re-run, masking the real error.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: safe to drive the coroutine directly.
        return asyncio.run(coro_or_result)

    # A loop is running here; give the coroutine its own loop on the worker
    # thread and block until it finishes.
    return _EXECUTOR.submit(asyncio.run, coro_or_result).result()
+ +Usage: + recorder = DatasetRecorder.create( + repo_id="user/my_dataset", + fps=30, + robot_features=robot.observation_features, + action_features=robot.action_features, + task="pick up the red cube", + ) + # In control loop: + recorder.add_frame(observation, action, task="pick up the red cube") + # End of episode: + recorder.save_episode() + # Optionally: + recorder.push_to_hub() +""" + +import functools +import logging +import sys +from typing import Any + +import numpy as np + +logger = logging.getLogger(__name__) + +# ── Lazy check for LeRobot availability ────────────────────────────── +# We must NOT import lerobot at module level because it pulls in +# `datasets` → `pandas`, which can crash with a numpy ABI mismatch on +# systems where the system pandas was compiled against an older numpy +# (e.g. JetPack / Jetson with system pandas 2.1.4 + pip numpy 2.x). + + +@functools.lru_cache(maxsize=1) +def has_lerobot_dataset() -> bool: + """Check if lerobot is available. Result is cached after first call.""" + try: + from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: F401 + + return True + except (ImportError, ValueError, RuntimeError) as exc: + logger.debug("lerobot not available: %s", exc) + return False + + +def _get_lerobot_dataset_class(): + """Import and return LeRobotDataset class, or raise ImportError. + + Supports test mocking: if ``strands_robots.dataset_recorder.LeRobotDataset`` + has been set (by a test mock), returns that class directly. + """ + # Support test mocking: check module-level overrides + this_module = sys.modules[__name__] + + # If a test injected a mock LeRobotDataset class, use it + mock_cls = getattr(this_module, "LeRobotDataset", None) + if mock_cls is not None: + return mock_cls + + # Actual import + try: + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + return LeRobotDataset + except (ImportError, ValueError, RuntimeError) as exc: + raise ImportError( + f"lerobot not available ({exc}). 
Install with: pip install lerobot\nRequired for LeRobotDataset recording." + ) from exc + + +class DatasetRecorder: + """Bridge between strands-robots control loops and LeRobotDataset. + + Handles the full lifecycle: + 1. create() — build LeRobotDataset with correct features + 2. add_frame() — called every control step with obs + action + 3. save_episode() — finalize episode (encodes video, writes parquet) + 4. push_to_hub() — upload to HuggingFace + + Works for both real hardware (robot.py) and simulation (simulation.py). + """ + + def __init__(self, dataset, task: str = "", strict: bool = True): + self.dataset = dataset + self.default_task = task + self.frame_count = 0 + self.dropped_frame_count = 0 + self.strict = strict + self.episode_count = 0 + self._closed = False + self._cached_state_keys: list[str] | None = None + self._cached_action_keys: list[str] | None = None + + @classmethod + def create( + cls, + repo_id: str, + fps: int = 30, + robot_type: str = "unknown", + robot_features: dict[str, Any] | None = None, + action_features: dict[str, Any] | None = None, + camera_keys: list[str] | None = None, + joint_names: list[str] | None = None, + task: str = "", + root: str | None = None, + use_videos: bool = True, + vcodec: str = "libsvtav1", + streaming_encoding: bool = True, + image_writer_threads: int = 4, + video_backend: str = "auto", + ) -> "DatasetRecorder": + """Create a new DatasetRecorder with auto-detected features. + + Args: + repo_id: HuggingFace dataset ID (e.g. "user/my_dataset") + fps: Recording frame rate + robot_type: Robot type string (e.g. 
"so100", "panda") + robot_features: Dict of observation feature names → types + (from robot.observation_features or sim joint names) + action_features: Dict of action feature names → types + camera_keys: List of camera names (images become video features) + joint_names: List of joint names (alternative to robot_features for sim) + task: Default task description + root: Local directory for dataset storage + use_videos: Encode camera frames as video (True) or keep as images + vcodec: Video codec (h264, hevc, libsvtav1) + streaming_encoding: Stream-encode video during capture + image_writer_threads: Threads for writing image frames + video_backend: Video backend for encoding ("auto" for HW encoder auto-detect) + """ + # Lazy import — this is where we actually need lerobot + LeRobotDatasetCls = _get_lerobot_dataset_class() + + # Build features dict in LeRobot format + features = cls._build_features( + robot_features=robot_features, + action_features=action_features, + camera_keys=camera_keys, + joint_names=joint_names, + use_videos=use_videos, + ) + + logger.info(f"Creating LeRobotDataset: {repo_id} @ {fps}fps, {len(features)} features, robot_type={robot_type}") + + # Build kwargs, skip unsupported params for this LeRobot version + create_kwargs = dict( + repo_id=repo_id, + fps=fps, + root=root, + robot_type=robot_type, + features=features, + use_videos=use_videos, + image_writer_threads=image_writer_threads, + vcodec=vcodec, + ) + # streaming_encoding only in newer LeRobot versions + import inspect + + create_sig = inspect.signature(LeRobotDatasetCls.create) + if "streaming_encoding" in create_sig.parameters: + create_kwargs["streaming_encoding"] = streaming_encoding + if "video_backend" in create_sig.parameters: + create_kwargs["video_backend"] = video_backend + dataset = LeRobotDatasetCls.create(**create_kwargs) + + recorder = cls(dataset=dataset, task=task) + logger.info("DatasetRecorder ready: %s", repo_id) + return recorder + + @classmethod + def _build_features( 
+ cls, + robot_features: dict | None = None, + action_features: dict | None = None, + camera_keys: list[str] | None = None, + joint_names: list[str] | None = None, + use_videos: bool = True, + camera_shapes: dict[str, tuple[int, int, int]] | None = None, + ) -> dict[str, Any]: + """Build LeRobot v3-compatible features dict. + + LeRobot v3 features format: + { + "observation.images.camera_name": {"dtype": "video", "shape": (C, H, W), "names": [...]}, + "observation.state": {"dtype": "float32", "shape": (N,), "names": [...]}, + "action": {"dtype": "float32", "shape": (N,), "names": [...]}, + } + + Note: "names" must be a flat list of strings, NOT a dict like {"motors": [...]}. + """ + features = {} + + # --- Observation: cameras → video/image features --- + if camera_keys: + for cam_name in camera_keys: + key = f"observation.images.{cam_name}" + dtype = "video" if use_videos else "image" + # Per-camera shape override, default (3, 480, 640) CHW + shape = (3, 480, 640) + if camera_shapes and cam_name in camera_shapes: + shape = camera_shapes[cam_name] + features[key] = { + "dtype": dtype, + "shape": shape, + "names": ["channels", "height", "width"], + } + + # --- Observation: state (joint positions) --- + state_dim = 0 + state_names = [] + if robot_features: + # Count scalar features (exclude cameras) + state_keys = [ + k + for k, v in robot_features.items() + if not isinstance(v, dict) or v.get("dtype") not in ("image", "video") + ] + state_dim = len(state_keys) + state_names = state_keys + elif joint_names: + state_dim = len(joint_names) + state_names = list(joint_names) + + if state_dim > 0: + features["observation.state"] = { + "dtype": "float32", + "shape": (state_dim,), + "names": state_names, + } + + # --- Action --- + action_dim = 0 + action_names = [] + if action_features: + action_keys = [ + k + for k, v in action_features.items() + if not isinstance(v, dict) or v.get("dtype") not in ("image", "video") + ] + action_dim = len(action_keys) + action_names = 
action_keys + elif joint_names: + action_dim = len(joint_names) + action_names = list(joint_names) + elif state_dim > 0: + action_dim = state_dim # Same dim as state by default + action_names = state_names[:] + + if action_dim > 0: + features["action"] = { + "dtype": "float32", + "shape": (action_dim,), + "names": action_names[:action_dim], + } + + return features + + def add_frame( + self, + observation: dict[str, Any], + action: dict[str, Any], + task: str | None = None, + camera_keys: list[str] | None = None, + ) -> None: + """Add a single control-loop frame to the dataset. + + This is the key method — called every step in the control loop. + + Args: + observation: Raw observation dict from robot/sim + (joint_name → float, camera_name → np.ndarray) + action: Action dict (joint_name → float) + task: Task description (uses default if None) + camera_keys: Which keys in observation are camera images + """ + if self._closed: + return + + frame = {} + + # --- Detect camera vs state keys --- + if camera_keys is None: + camera_keys = [k for k, v in observation.items() if isinstance(v, np.ndarray) and v.ndim >= 2] + + state_keys = [k for k in observation.keys() if k not in camera_keys] + + # --- Camera images → observation.images.{name} --- + for cam_key in camera_keys: + img = observation[cam_key] + if isinstance(img, np.ndarray): + # LeRobot expects HWC uint8 for add_frame + if img.dtype != np.uint8: + img = (np.clip(img, 0, 1) * 255).astype(np.uint8) + frame[f"observation.images.{cam_key}"] = img + + # --- State → observation.state (flattened vector) --- + # Use feature schema ordering to match the dataset schema declared in _build_features(). 
+ if state_keys: + state_vals = [] + if self._cached_state_keys is None: + feat = self.dataset.features.get("observation.state", {}) + state_names = feat.get("names", []) if isinstance(feat, dict) else getattr(feat, "names", []) + self._cached_state_keys = state_names if state_names else sorted(state_keys) + + for k in self._cached_state_keys: + v = observation.get(k) + if v is None: + state_vals.append(0.0) + elif isinstance(v, (int, float)): + state_vals.append(float(v)) + elif isinstance(v, np.ndarray) and v.ndim == 0: + state_vals.append(float(v)) + elif isinstance(v, (list, np.ndarray)): + arr = np.asarray(v, dtype=np.float32).flatten() + state_vals.extend(arr.tolist()) + if state_vals: + frame["observation.state"] = np.array(state_vals, dtype=np.float32) + + # --- Action → flattened vector --- + # Use feature schema ordering for actions too. + if action: + action_vals = [] + if self._cached_action_keys is None: + feat = self.dataset.features.get("action", {}) + action_names = feat.get("names", []) if isinstance(feat, dict) else getattr(feat, "names", []) + self._cached_action_keys = action_names if action_names else sorted(action.keys()) + + for k in self._cached_action_keys: + v = action.get(k) + if v is None: + action_vals.append(0.0) + elif isinstance(v, (int, float)): + action_vals.append(float(v)) + elif isinstance(v, np.ndarray) and v.ndim == 0: + action_vals.append(float(v)) + elif isinstance(v, (list, np.ndarray)): + arr = np.asarray(v, dtype=np.float32).flatten() + action_vals.extend(arr.tolist()) + if action_vals: + frame["action"] = np.array(action_vals, dtype=np.float32) + + # --- Task (mandatory for LeRobot v3) --- + frame["task"] = task or self.default_task or "untitled" + + # --- Reconcile camera keys between frame and feature schema --- + # Only strip *undeclared* cameras from the frame (keys present in obs + # but not registered in _build_features). This avoids LeRobot's + # "Extra features" error. Declared-but-missing cameras (e.g. 
when a + # render fails) are left alone — LeRobot tolerates absent columns and + # the episode simply won't have that camera's data. + declared_cam_keys = {k for k in self.dataset.features if k.startswith("observation.images.")} + frame_cam_keys = {k for k in frame if k.startswith("observation.images.")} + for extra in frame_cam_keys - declared_cam_keys: + del frame[extra] + + # --- Add to dataset --- + try: + self.dataset.add_frame(frame) + self.frame_count += 1 + except Exception as e: + if self.strict: + raise # Fail-fast per AGENTS.md convention #5 + self.dropped_frame_count += 1 + n = self.dropped_frame_count + # Log at 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, then every 1000 + if (n & (n - 1)) == 0 or n % 1000 == 0: + logger.warning( + "add_frame failed (frame %d, dropped %d): %s", + self.frame_count, + self.dropped_frame_count, + e, + ) + + def save_episode(self) -> dict[str, Any]: + """Finalize current episode — writes parquet, encodes video, computes stats. + + LeRobot v3: save_episode() takes no task argument. Tasks are stored + per-frame in the episode buffer via add_frame(). 
+ + Returns: + Dict with episode info + """ + if self._closed: + return {"status": "error", "message": "Recorder closed"} + + try: + self.dataset.save_episode() + self.episode_count += 1 + ep_frames = self.frame_count # Total frames so far + logger.info(f"Episode {self.episode_count} saved: {ep_frames} total frames") + return { + "status": "success", + "episode": self.episode_count, + "total_frames": ep_frames, + } + except Exception as e: + logger.error("save_episode failed: %s", e) + return {"status": "error", "message": str(e)} + + def finalize(self) -> None: + """Finalize the dataset (close parquet writers, flush metadata).""" + if self._closed: + return + try: + self.dataset.finalize() + except Exception as e: + logger.warning("finalize warning: %s", e) + self._closed = True + + def push_to_hub( + self, + tags: list[str] | None = None, + private: bool = False, + ) -> dict[str, Any]: + """Push dataset to HuggingFace Hub. + + Args: + tags: Optional tags for the dataset + private: Upload as private dataset + + Returns: + Dict with push status + """ + try: + self.dataset.push_to_hub(tags=tags, private=private) + logger.info("Dataset pushed to hub: %s", self.dataset.repo_id) + return { + "status": "success", + "repo_id": self.dataset.repo_id, + "episodes": self.episode_count, + "frames": self.frame_count, + } + except Exception as e: + logger.error("push_to_hub failed: %s", e) + return {"status": "error", "message": str(e)} + + @property + def repo_id(self) -> str: + return self.dataset.repo_id + + @property + def root(self) -> str: + return str(self.dataset.root) + + def __repr__(self) -> str: + return f"DatasetRecorder(repo_id={self.repo_id}, episodes={self.episode_count}, frames={self.frame_count})" + + +# ── Shared replay-episode helpers ──────────────────────────────────── + + +def load_lerobot_episode(repo_id: str, episode: int = 0, root: str | None = None): + """Load a LeRobotDataset and resolve the frame range for an episode. 
+ + Used by strands_robots.simulation.mujoco.policy_runner for the + replay_episode action — the simulation backend calls this to load + recorded joint trajectories and replay them in MuJoCo. + + Returns: + Tuple of (dataset, episode_start, episode_length) on success. + + Raises: + ImportError: If lerobot is not installed. + ValueError: If the episode is out of range or has no frames. + """ + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + ds = LeRobotDataset(repo_id=repo_id, root=root) + + num_episodes = ds.meta.total_episodes if hasattr(ds.meta, "total_episodes") else len(ds.meta.episodes) + if episode >= num_episodes: + raise ValueError(f"Episode {episode} out of range (0-{num_episodes - 1})") + + episode_start = 0 + episode_length = 0 + try: + if hasattr(ds, "episode_data_index"): + from_idx = ds.episode_data_index["from"][episode].item() + to_idx = ds.episode_data_index["to"][episode].item() + episode_start = from_idx + episode_length = to_idx - from_idx + else: + for i in range(episode): + ep_info = ds.meta.episodes[i] if hasattr(ds.meta, "episodes") else {} + episode_start += ep_info.get("length", 0) + ep_info = ds.meta.episodes[episode] if hasattr(ds.meta, "episodes") else {} + episode_length = ep_info.get("length", 0) + except Exception: + # Last resort: scan frames to find episode boundaries + for idx in range(len(ds)): + frame = ds[idx] + frame_ep = frame.get("episode_index", -1) if hasattr(frame, "get") else -1 + if hasattr(frame_ep, "item"): + frame_ep = frame_ep.item() + if frame_ep == episode: + if episode_length == 0: + episode_start = idx + episode_length += 1 + elif episode_length > 0: + break + + if episode_length == 0: + raise ValueError(f"Episode {episode} has no frames") + + return ds, episode_start, episode_length diff --git a/strands_robots/hardware_robot.py b/strands_robots/hardware_robot.py new file mode 100644 index 0000000..d36f83f --- /dev/null +++ b/strands_robots/hardware_robot.py @@ -0,0 +1,758 @@ +#!/usr/bin/env 
python3 +""" +Universal Robot Control with Policy Abstraction for Any VLA Provider + +This module provides a clean robot interface that works with any LeRobot-compatible +robot and any VLA provider through the Policy abstraction. + +Features: +- Async robot task execution with real-time status reporting +- Non-blocking operations - robot moves while tool returns status +- Stop functionality to interrupt running tasks +- Connection state management with proper error handling +- Policy abstraction for any VLA provider +""" + +from __future__ import annotations + +import asyncio +import logging +import threading +import time +from collections.abc import AsyncGenerator +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, Any, cast + +from strands.tools.tools import AgentTool +from strands.types._events import ToolResultEvent +from strands.types.tools import ToolResult, ToolSpec, ToolUse + +if TYPE_CHECKING: + from lerobot.robots.config import RobotConfig + from lerobot.robots.robot import Robot as LeRobotRobot + + from .policies import Policy + +logger = logging.getLogger(__name__) + + +class TaskStatus(Enum): + """Robot task execution status""" + + IDLE = "idle" + CONNECTING = "connecting" + RUNNING = "running" + COMPLETED = "completed" + STOPPED = "stopped" + ERROR = "error" + + +@dataclass +class RobotTaskState: + """Robot task execution state""" + + status: TaskStatus = TaskStatus.IDLE + instruction: str = "" + start_time: float = 0.0 + duration: float = 0.0 + step_count: int = 0 + error_message: str = "" + task_future: Future | None = None + + +class Robot(AgentTool): + """Universal robot control with async task execution and status reporting.""" + + def __init__( + self, + tool_name: str, + robot: LeRobotRobot | RobotConfig | str, + cameras: dict[str, dict[str, Any]] | None = None, + action_horizon: int = 8, + data_config: str | Any | None = None, + 
control_frequency: float = 50.0, + **kwargs: Any, + ) -> None: + """Initialize Robot with async capabilities. + + Args: + tool_name: Name for this robot tool + robot: LeRobot Robot instance, RobotConfig, or robot type string + cameras: Camera configuration dict: + {"wrist": {"type": "opencv", "index_or_path": "/dev/video0", "fps": 30}} + action_horizon: Actions per inference step + data_config: Data configuration (for GR00T compatibility) + control_frequency: Control loop frequency in Hz (default: 50Hz) + **kwargs: Robot-specific parameters (port, etc.) + """ + super().__init__() + + self.tool_name_str = tool_name + self.action_horizon = action_horizon + self.data_config = data_config + self.control_frequency = control_frequency + self.action_sleep_time = 1.0 / control_frequency # Time between actions + + # Task execution state + self._task_state = RobotTaskState() + self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix=f"{tool_name}_executor") + self._shutdown_event = threading.Event() + + # Initialize robot using lerobot's abstraction + self.robot = self._initialize_robot(robot, cameras, **kwargs) + + logger.info(f"🤖 {tool_name} initialized with async capabilities") + logger.info(f"📱 Robot: {self.robot.name} (type: {getattr(self.robot, 'robot_type', 'unknown')})") + logger.info(f"⏱️ Control frequency: {control_frequency}Hz ({self.action_sleep_time * 1000:.1f}ms per action)") + + # Get camera info if available + if hasattr(self.robot, "config") and hasattr(self.robot.config, "cameras"): + cameras_list = list(self.robot.config.cameras.keys()) + logger.info(f"📹 Cameras: {cameras_list}") + + if data_config: + logger.info(f"⚙️ Data config: {data_config}") + + def _initialize_robot( + self, robot: LeRobotRobot | RobotConfig | str, cameras: dict[str, dict[str, Any]] | None, **kwargs: Any + ) -> LeRobotRobot: + """Initialize LeRobot robot instance using native lerobot patterns.""" + from lerobot.robots.config import RobotConfig + from lerobot.robots.robot 
import Robot as LeRobotRobot + from lerobot.robots.utils import make_robot_from_config + + # Direct robot instance - use as-is + if isinstance(robot, LeRobotRobot): + return robot + + # Robot config - use lerobot's factory + elif isinstance(robot, RobotConfig): + return make_robot_from_config(robot) + + # Robot type string - create config and use lerobot's factory + elif isinstance(robot, str): + config = self._create_minimal_config(robot, cameras, **kwargs) + return make_robot_from_config(config) + + else: + raise ValueError( + f"Unsupported robot type: {type(robot)}. " + f"Expected LeRobot Robot instance, RobotConfig, or robot type string." + ) + + def _create_minimal_config( + self, robot_type: str, cameras: dict[str, dict[str, Any]] | None, **kwargs: Any + ) -> RobotConfig: + """Create minimal robot config using specific robot config classes.""" + from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig + + # Convert cameras to lerobot format + camera_configs = {} + if cameras: + for name, config in cameras.items(): + if config.get("type", "opencv") == "opencv": + camera_configs[name] = OpenCVCameraConfig( + index_or_path=config["index_or_path"], + fps=config.get("fps", 30), + width=config.get("width", 640), + height=config.get("height", 480), + rotation=config.get("rotation", 0), + color_mode=config.get("color_mode", "rgb"), + ) + else: + raise ValueError(f"Unsupported camera type: {config.get('type')}") + + # Map robot type to specific config class + config_mapping = { + "so101_follower": ("lerobot.robots.so101_follower", "SO101FollowerConfig"), + "so100_follower": ("lerobot.robots.so100_follower", "SO100FollowerConfig"), + "bi_so100_follower": ("lerobot.robots.bi_so100_follower", "BiSO100FollowerConfig"), + "viperx": ("lerobot.robots.viperx", "ViperXConfig"), + "koch_follower": ("lerobot.robots.koch_follower", "KochFollowerConfig"), + # Add more as needed + } + + if robot_type not in config_mapping: + raise ValueError(f"Unsupported robot 
type: {robot_type}. Supported types: {list(config_mapping.keys())}") + + # Import specific config class dynamically + module_name, class_name = config_mapping[robot_type] + try: + import importlib + + module = importlib.import_module(module_name) + ConfigClass = getattr(module, class_name) + except Exception as e: + raise ValueError(f"Failed to import {class_name} from {module_name}: {e}") + + # Create config with proper parameters + config_data = { + "id": self.tool_name_str, + "cameras": camera_configs, + } + + # Filter kwargs to only include supported fields for this robot type + # Port is common for most serial robots + if "port" in kwargs: + config_data["port"] = kwargs["port"] + + # Add other common fields as needed + for key in ["calibration_dir", "mock", "use_degrees"]: + if key in kwargs: + config_data[key] = kwargs[key] + + try: + return ConfigClass(**config_data) + except Exception as e: + raise ValueError(f"Failed to create {class_name} for robot type '{robot_type}': {e}. Config: {config_data}") + + async def _get_policy( + self, policy_port: int | None = None, policy_host: str = "localhost", policy_provider: str = "groot" + ) -> Policy: + """Create policy on-the-fly from invocation parameters.""" + from .policies import create_policy + + if not policy_port: + raise ValueError("policy_port is required for robot operation") + + policy_config = {"port": policy_port, "host": policy_host} + + if self.data_config: + policy_config["data_config"] = self.data_config + + return create_policy(policy_provider, **policy_config) + + async def _connect_robot(self) -> tuple[bool, str]: + """Connect to robot hardware with proper error handling. 
+ + Returns: + tuple[bool, str]: (success, error_message) - error_message is empty on success + """ + try: + # Import lerobot exceptions + from lerobot.utils.errors import DeviceAlreadyConnectedError + + # Check if already connected + if self.robot.is_connected: + logger.info(f"✅ {self.robot} already connected") + return True, "" + + logger.info(f"🔌 Connecting to {self.robot}...") + + # Handle robot connection using lerobot's error handling patterns + try: + if not self.robot.is_connected: + await asyncio.to_thread(self.robot.connect, False) # calibrate=False + + except DeviceAlreadyConnectedError: + # This is expected and fine - robot is already connected + logger.info(f"✅ {self.robot} was already connected") + + except Exception as e: + # Check if it's the string version of "already connected" error + error_str = str(e).lower() + if "already connected" in error_str or "is already connected" in error_str: + logger.info(f"✅ {self.robot} connection already established") + else: + # Re-raise if it's a different error + raise e + + # Final connection check + if not self.robot.is_connected: + error_msg = f"Failed to connect to {self.robot}" + logger.error(f"❌ {error_msg}") + return False, error_msg + + # Check robot calibration + if hasattr(self.robot, "is_calibrated") and not self.robot.is_calibrated: + error_msg = ( + f"Robot {self.robot} is not calibrated. Please calibrate the robot manually" + " first using LeRobot's calibration process (lerobot-calibrate)" + ) + logger.error(f"❌ {error_msg}") + return False, error_msg + + logger.info(f"✅ {self.robot} connected and ready") + return True, "" + + except Exception as e: + error_msg = f"Robot connection failed: {e}. 
Ensure robot is calibrated and accessible on the specified port" + logger.error(f"❌ {error_msg}") + return False, error_msg + + async def _initialize_policy(self, policy: Policy) -> bool: + """Initialize policy with robot state keys.""" + try: + # Get robot state keys from observation + test_obs = await asyncio.to_thread(self.robot.get_observation) + + # Filter out camera keys to get robot state keys + camera_keys = [] + if hasattr(self.robot, "config") and hasattr(self.robot.config, "cameras"): + camera_keys = list(self.robot.config.cameras.keys()) + + robot_state_keys = [k for k in test_obs.keys() if k not in camera_keys] + + # Set robot state keys in policy + policy.set_robot_state_keys(robot_state_keys) + return True + + except Exception as e: + logger.error(f"❌ Failed to initialize policy: {e}") + return False + + async def _execute_task_async( + self, + instruction: str, + policy_port: int | None = None, + policy_host: str = "localhost", + policy_provider: str = "groot", + duration: float = 30.0, + ) -> None: + """Execute robot task in background thread (internal method).""" + try: + # Update task state + self._task_state.status = TaskStatus.CONNECTING + self._task_state.instruction = instruction + self._task_state.start_time = time.time() + self._task_state.step_count = 0 + self._task_state.error_message = "" + + # Connect to robot + connected, connect_error = await self._connect_robot() + if not connected: + self._task_state.status = TaskStatus.ERROR + self._task_state.error_message = connect_error or f"Failed to connect to {self.tool_name_str}" + return + + # Get policy instance + policy_instance = await self._get_policy(policy_port, policy_host, policy_provider) + + # Initialize policy with robot state keys + if not await self._initialize_policy(policy_instance): + self._task_state.status = TaskStatus.ERROR + self._task_state.error_message = "Failed to initialize policy" + return + + logger.info(f"🎯 Starting task: '{instruction}' on {self.tool_name_str}") 
+ logger.info(f"🧠 Using policy: {policy_provider} on {policy_host}:{policy_port}") + + self._task_state.status = TaskStatus.RUNNING + start_time = time.time() + + while ( + time.time() - start_time < duration + and self._task_state.status == TaskStatus.RUNNING + and not self._shutdown_event.is_set() + ): + # Get observation from robot + observation = await asyncio.to_thread(self.robot.get_observation) + + # Get actions from policy + robot_actions = await policy_instance.get_actions(observation, instruction) + + # Execute actions from chunk with proper timing control + # Wait between actions for smooth execution + for action_dict in robot_actions[: self.action_horizon]: + if self._task_state.status != TaskStatus.RUNNING: + break + await asyncio.to_thread(self.robot.send_action, action_dict) + self._task_state.step_count += 1 + # Wait for action to complete before sending next action + # Default 50Hz (0.02s) + await asyncio.sleep(self.action_sleep_time) + + # Update final state + elapsed = time.time() - start_time + self._task_state.duration = elapsed + + if self._task_state.status == TaskStatus.RUNNING: + self._task_state.status = TaskStatus.COMPLETED + logger.info( + f"✅ Task completed: '{instruction}' in {elapsed:.1f}s ({self._task_state.step_count} steps)" + ) + + except Exception as e: + logger.error(f"❌ Task execution failed: {e}") + self._task_state.status = TaskStatus.ERROR + self._task_state.error_message = str(e) + + def _execute_task_sync( + self, + instruction: str, + policy_port: int | None = None, + policy_host: str = "localhost", + policy_provider: str = "groot", + duration: float = 30.0, + ) -> dict[str, Any]: + """Execute task synchronously in thread - no new event loop.""" + + # Import here to avoid conflicts + import asyncio + + # Run task without creating new event loop - let it run in thread + async def task_runner() -> None: + await self._execute_task_async(instruction, policy_port, policy_host, policy_provider, duration) + + # Use asyncio.run 
only if no loop is running, otherwise run in existing loop + try: + # Try to get the current event loop + asyncio.get_running_loop() + # If we're already in an event loop, we need to run in a thread + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as exec: + future = exec.submit(lambda: asyncio.run(task_runner())) + future.result() # Wait for completion + except RuntimeError: + # No event loop running - safe to create one + asyncio.run(task_runner()) + + # Return final status + return { + "status": "success" if self._task_state.status == TaskStatus.COMPLETED else "error", + "content": [ + { + "text": f"✅ Task: '{instruction}' - {self._task_state.status.value}\n" + f"🤖 Robot: {self.tool_name_str} ({self.robot})\n" + f"🧠 Policy: {policy_provider} on {policy_host}:{policy_port}\n" + f"⏱️ Duration: {self._task_state.duration:.1f}s\n" + f"🎯 Steps: {self._task_state.step_count}" + + (f"\n❌ Error: {self._task_state.error_message}" if self._task_state.error_message else "") + } + ], + } + + def start_task( + self, + instruction: str, + policy_port: int | None = None, + policy_host: str = "localhost", + policy_provider: str = "groot", + duration: float = 30.0, + ) -> dict[str, Any]: + """Start robot task asynchronously and return immediately.""" + + # Check if task is already running + if self._task_state.status == TaskStatus.RUNNING: + return { + "status": "error", + "content": [{"text": f"❌ Task already running: {self._task_state.instruction}"}], + } + + # Start task in background + self._task_state.task_future = self._executor.submit( + self._execute_task_sync, instruction, policy_port, policy_host, policy_provider, duration + ) + + return { + "status": "success", + "content": [ + { + "text": f"🚀 Task started: '{instruction}'\n" + f"🤖 Robot: {self.tool_name_str}\n" + f"💡 Use action='status' to check progress\n" + f"💡 Use action='stop' to interrupt" + } + ], + } + + def get_task_status(self) -> dict[str, Any]: + """Get current task execution 
status.""" + + # Update duration for running tasks + if self._task_state.status == TaskStatus.RUNNING: + self._task_state.duration = time.time() - self._task_state.start_time + + status_text = f"📊 Robot Status: {self._task_state.status.value.upper()}\n" + + if self._task_state.instruction: + status_text += f"🎯 Task: {self._task_state.instruction}\n" + + if self._task_state.status == TaskStatus.RUNNING: + status_text += f"⏱️ Duration: {self._task_state.duration:.1f}s\n" + status_text += f"🔄 Steps: {self._task_state.step_count}\n" + elif self._task_state.status in [TaskStatus.COMPLETED, TaskStatus.STOPPED, TaskStatus.ERROR]: + status_text += f"⏱️ Total Duration: {self._task_state.duration:.1f}s\n" + status_text += f"🎯 Total Steps: {self._task_state.step_count}\n" + + if self._task_state.error_message: + status_text += f"❌ Error: {self._task_state.error_message}\n" + + return { + "status": "success", + "content": [{"text": status_text}], + } + + def stop_task(self) -> dict[str, Any]: + """Stop currently running task.""" + + if self._task_state.status != TaskStatus.RUNNING: + return { + "status": "success", + "content": [{"text": f"💤 No task running to stop (current: {self._task_state.status.value})"}], + } + + # Signal task to stop + self._task_state.status = TaskStatus.STOPPED + + # Cancel future if it exists + if self._task_state.task_future: + self._task_state.task_future.cancel() + + logger.info(f"🛑 Task stopped: {self._task_state.instruction}") + + return { + "status": "success", + "content": [ + { + "text": f"🛑 Task stopped: '{self._task_state.instruction}'\n" + f"⏱️ Duration: {self._task_state.duration:.1f}s\n" + f"🎯 Steps completed: {self._task_state.step_count}" + } + ], + } + + @property + def tool_name(self) -> str: + return self.tool_name_str + + @property + def tool_type(self) -> str: + return "robot" + + @property + def tool_spec(self) -> ToolSpec: + """Get tool specification with async actions.""" + return { + "name": self.tool_name_str, + 
"description": f"Universal robot control with async task execution ({self.robot}). " + f"Actions: execute (blocking), start (async), status, stop. " + f"For execute/start actions: instruction and policy_port are required. " + f"For status/stop actions: no additional parameters needed.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "action": { + "type": "string", + "description": "Action to perform: execute (blocking), start (async), status, stop", + "enum": ["execute", "start", "status", "stop"], + "default": "execute", + }, + "instruction": { + "type": "string", + "description": "Natural language instruction (required for execute/start actions)", + }, + "policy_port": { + "type": "integer", + "description": "Policy service port (required for execute/start actions)", + }, + "policy_host": { + "type": "string", + "description": "Policy service host (default: localhost)", + "default": "localhost", + }, + "policy_provider": { + "type": "string", + "description": "Policy provider (groot, openai, etc.)", + "default": "groot", + }, + "duration": { + "type": "number", + "description": "Maximum execution time in seconds", + "default": 30.0, + }, + }, + "required": ["action"], + } + }, + } + + @staticmethod + def _make_tool_result(tool_use_id: str, result: dict[str, Any]) -> ToolResult: + """Create a ToolResult dict with the given tool_use_id merged into result.""" + return cast(ToolResult, {"toolUseId": tool_use_id, **result}) + + async def stream( + self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any + ) -> AsyncGenerator[ToolResultEvent, None]: + """Stream robot task execution with async actions.""" + try: + tool_use_id = tool_use.get("toolUseId", "") + input_data = tool_use.get("input", {}) + + action = input_data.get("action", "execute") + + # Handle different actions + if action == "execute": + # Blocking execution (legacy behavior) + instruction = input_data.get("instruction", "") + policy_port = 
input_data.get("policy_port") + policy_host = input_data.get("policy_host", "localhost") + policy_provider = input_data.get("policy_provider", "groot") + duration = input_data.get("duration", 30.0) + + if not instruction or not policy_port: + yield ToolResultEvent( + self._make_tool_result( + tool_use_id, + { + "status": "error", + "content": [{"text": "❌ instruction and policy_port are required for execute action"}], + }, + ) + ) + return + + # Execute task synchronously + task_result = self._execute_task_sync(instruction, policy_port, policy_host, policy_provider, duration) + yield ToolResultEvent(self._make_tool_result(tool_use_id, task_result)) + + elif action == "start": + # Asynchronous execution start + instruction = input_data.get("instruction", "") + policy_port = input_data.get("policy_port") + policy_host = input_data.get("policy_host", "localhost") + policy_provider = input_data.get("policy_provider", "groot") + duration = input_data.get("duration", 30.0) + + if not instruction or not policy_port: + yield ToolResultEvent( + self._make_tool_result( + tool_use_id, + { + "status": "error", + "content": [{"text": "❌ instruction and policy_port are required for start action"}], + }, + ) + ) + return + + # Start task asynchronously + start_result = self.start_task(instruction, policy_port, policy_host, policy_provider, duration) + yield ToolResultEvent(self._make_tool_result(tool_use_id, start_result)) + + elif action == "status": + # Get current task status + status_result = self.get_task_status() + yield ToolResultEvent(self._make_tool_result(tool_use_id, status_result)) + + elif action == "stop": + # Stop current task + stop_result = self.stop_task() + yield ToolResultEvent(self._make_tool_result(tool_use_id, stop_result)) + + else: + yield ToolResultEvent( + self._make_tool_result( + tool_use_id, + { + "status": "error", + "content": [ + {"text": f"❌ Unknown action: {action}. 
Valid actions: execute, start, status, stop"} + ], + }, + ) + ) + + except Exception as e: + logger.error(f"❌ {self.tool_name_str} error: {e}") + yield ToolResultEvent( + self._make_tool_result( + tool_use_id, + { + "status": "error", + "content": [{"text": f"❌ {self.tool_name_str} error: {str(e)}"}], + }, + ) + ) + + def cleanup(self) -> None: + """Cleanup resources and stop any running tasks.""" + try: + # Signal shutdown + self._shutdown_event.set() + + # Stop any running task + if self._task_state.status == TaskStatus.RUNNING: + self.stop_task() + + # Shutdown executor + self._executor.shutdown(wait=True) + + logger.info(f"🧹 {self.tool_name_str} cleanup completed") + + except Exception as e: + logger.error(f"❌ Cleanup error for {self.tool_name_str}: {e}") + + def __del__(self) -> None: + """Destructor to ensure cleanup.""" + try: + self.cleanup() + except Exception: + pass # Ignore errors in destructor + + async def get_status(self) -> dict[str, Any]: + """Get robot status including connection and task state.""" + try: + # Get robot connection status + is_connected = self.robot.is_connected if hasattr(self.robot, "is_connected") else False + is_calibrated = self.robot.is_calibrated if hasattr(self.robot, "is_calibrated") else True + + # Get camera status + camera_status = [] + if hasattr(self.robot, "config") and hasattr(self.robot.config, "cameras"): + for name in self.robot.config.cameras.keys(): + camera_status.append(name) + + # Build status dict + status_data = { + "robot_name": self.tool_name_str, + "robot_type": getattr(self.robot, "robot_type", self.robot.name), + "robot_info": str(self.robot), + "data_config": self.data_config, + "is_connected": is_connected, + "is_calibrated": is_calibrated, + "cameras": camera_status, + "task_status": self._task_state.status.value, + "current_instruction": self._task_state.instruction, + "task_duration": self._task_state.duration, + "task_steps": self._task_state.step_count, + } + + # Add error info if present + if 
self._task_state.error_message: + status_data["task_error"] = self._task_state.error_message + + return status_data + + except Exception as e: + logger.error(f"❌ Error getting status for {self.tool_name_str}: {e}") + return { + "robot_name": self.tool_name_str, + "error": str(e), + "is_connected": False, + "task_status": "error", + } + + async def stop(self) -> None: + """Stop robot and disconnect.""" + try: + # Stop any running task first + if self._task_state.status == TaskStatus.RUNNING: + self.stop_task() + + # Disconnect robot hardware + if hasattr(self.robot, "disconnect"): + await asyncio.to_thread(self.robot.disconnect) + + # Cleanup resources + self.cleanup() + + logger.info(f"🛑 {self.tool_name_str} stopped and disconnected") + + except Exception as e: + logger.error(f"❌ Error stopping robot: {e}") diff --git a/strands_robots/robot.py b/strands_robots/robot.py index 33c5ba7..fa4a364 100644 --- a/strands_robots/robot.py +++ b/strands_robots/robot.py @@ -1,758 +1,205 @@ -#!/usr/bin/env python3 -""" -Universal Robot Control with Policy Abstraction for Any VLA Provider - -This module provides a clean robot interface that works with any LeRobot-compatible -robot and any VLA provider through the Policy abstraction. 
- -Features: -- Async robot task execution with real-time status reporting -- Non-blocking operations - robot moves while tool returns status -- Stop functionality to interrupt running tasks -- Connection state management with proper error handling -- Policy abstraction for any VLA provider -""" - -from __future__ import annotations - -import asyncio -import logging -import threading -import time -from collections.abc import AsyncGenerator -from concurrent.futures import Future, ThreadPoolExecutor -from dataclasses import dataclass -from enum import Enum -from typing import TYPE_CHECKING, Any, cast - -from strands.tools.tools import AgentTool -from strands.types._events import ToolResultEvent -from strands.types.tools import ToolResult, ToolSpec, ToolUse +"""Unified Robot factory — convenience layer over ``strands_robots.simulation`` +and ``strands_robots.hardware_robot``. -if TYPE_CHECKING: - from lerobot.robots.config import RobotConfig - from lerobot.robots.robot import Robot as LeRobotRobot +Provides: + - ``Robot("so100")`` → returns a simulation by default (safe) + - ``Robot("so100", mode="real")`` → explicit real hardware + - ``Robot("so100", mode="auto")`` → auto-detects sim/real + - ``list_robots()`` → what's available - from .policies import Policy +Environment Variables: + STRANDS_ROBOT_MODE: Override mode detection ("sim" or "real"). 
-logger = logging.getLogger(__name__) - - -class TaskStatus(Enum): - """Robot task execution status""" - - IDLE = "idle" - CONNECTING = "connecting" - RUNNING = "running" - COMPLETED = "completed" - STOPPED = "stopped" - ERROR = "error" - - -@dataclass -class RobotTaskState: - """Robot task execution state""" - - status: TaskStatus = TaskStatus.IDLE - instruction: str = "" - start_time: float = 0.0 - duration: float = 0.0 - step_count: int = 0 - error_message: str = "" - task_future: Future | None = None - - -class Robot(AgentTool): - """Universal robot control with async task execution and status reporting.""" - - def __init__( - self, - tool_name: str, - robot: LeRobotRobot | RobotConfig | str, - cameras: dict[str, dict[str, Any]] | None = None, - action_horizon: int = 8, - data_config: str | Any | None = None, - control_frequency: float = 50.0, - **kwargs, - ): - """Initialize Robot with async capabilities. - - Args: - tool_name: Name for this robot tool - robot: LeRobot Robot instance, RobotConfig, or robot type string - cameras: Camera configuration dict: - {"wrist": {"type": "opencv", "index_or_path": "/dev/video0", "fps": 30}} - action_horizon: Actions per inference step - data_config: Data configuration (for GR00T compatibility) - control_frequency: Control loop frequency in Hz (default: 50Hz) - **kwargs: Robot-specific parameters (port, etc.) 
- """ - super().__init__() - - self.tool_name_str = tool_name - self.action_horizon = action_horizon - self.data_config = data_config - self.control_frequency = control_frequency - self.action_sleep_time = 1.0 / control_frequency # Time between actions - - # Task execution state - self._task_state = RobotTaskState() - self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix=f"{tool_name}_executor") - self._shutdown_event = threading.Event() - - # Initialize robot using lerobot's abstraction - self.robot = self._initialize_robot(robot, cameras, **kwargs) - - logger.info(f"🤖 {tool_name} initialized with async capabilities") - logger.info(f"📱 Robot: {self.robot.name} (type: {getattr(self.robot, 'robot_type', 'unknown')})") - logger.info(f"⏱️ Control frequency: {control_frequency}Hz ({self.action_sleep_time * 1000:.1f}ms per action)") - - # Get camera info if available - if hasattr(self.robot, "config") and hasattr(self.robot.config, "cameras"): - cameras_list = list(self.robot.config.cameras.keys()) - logger.info(f"📹 Cameras: {cameras_list}") - - if data_config: - logger.info(f"⚙️ Data config: {data_config}") - - def _initialize_robot( - self, robot: LeRobotRobot | RobotConfig | str, cameras: dict[str, dict[str, Any]] | None, **kwargs - ) -> LeRobotRobot: - """Initialize LeRobot robot instance using native lerobot patterns.""" - from lerobot.robots.config import RobotConfig - from lerobot.robots.robot import Robot as LeRobotRobot - from lerobot.robots.utils import make_robot_from_config - - # Direct robot instance - use as-is - if isinstance(robot, LeRobotRobot): - return robot - - # Robot config - use lerobot's factory - elif isinstance(robot, RobotConfig): - return make_robot_from_config(robot) - - # Robot type string - create config and use lerobot's factory - elif isinstance(robot, str): - config = self._create_minimal_config(robot, cameras, **kwargs) - return make_robot_from_config(config) - - else: - raise ValueError( - f"Unsupported robot type: 
{type(robot)}. " - f"Expected LeRobot Robot instance, RobotConfig, or robot type string." - ) +Examples:: - def _create_minimal_config( - self, robot_type: str, cameras: dict[str, dict[str, Any]] | None, **kwargs - ) -> RobotConfig: - """Create minimal robot config using specific robot config classes.""" - from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig - - # Convert cameras to lerobot format - camera_configs = {} - if cameras: - for name, config in cameras.items(): - if config.get("type", "opencv") == "opencv": - camera_configs[name] = OpenCVCameraConfig( - index_or_path=config["index_or_path"], - fps=config.get("fps", 30), - width=config.get("width", 640), - height=config.get("height", 480), - rotation=config.get("rotation", 0), - color_mode=config.get("color_mode", "rgb"), - ) - else: - raise ValueError(f"Unsupported camera type: {config.get('type')}") - - # Map robot type to specific config class - config_mapping = { - "so101_follower": ("lerobot.robots.so101_follower", "SO101FollowerConfig"), - "so100_follower": ("lerobot.robots.so100_follower", "SO100FollowerConfig"), - "bi_so100_follower": ("lerobot.robots.bi_so100_follower", "BiSO100FollowerConfig"), - "viperx": ("lerobot.robots.viperx", "ViperXConfig"), - "koch_follower": ("lerobot.robots.koch_follower", "KochFollowerConfig"), - # Add more as needed - } - - if robot_type not in config_mapping: - raise ValueError(f"Unsupported robot type: {robot_type}. 
Supported types: {list(config_mapping.keys())}") + # Default: simulation (safe — no physical hardware interaction) + sim = Robot("so100") - # Import specific config class dynamically - module_name, class_name = config_mapping[robot_type] - try: - import importlib + # Explicit real hardware + hw = Robot("so100", mode="real", cameras={...}) - module = importlib.import_module(module_name) - ConfigClass = getattr(module, class_name) - except Exception as e: - raise ValueError(f"Failed to import {class_name} from {module_name}: {e}") + # Auto-detect (probes USB for servo controllers) + robot = Robot("so100", mode="auto") - # Create config with proper parameters - config_data = { - "id": self.tool_name_str, - "cameras": camera_configs, - } + # With custom URDF/MJCF path + sim = Robot("my_arm", urdf_path="/path/to/robot.xml") - # Filter kwargs to only include supported fields for this robot type - # Port is common for most serial robots - if "port" in kwargs: - config_data["port"] = kwargs["port"] +Future (not yet implemented):: - # Add other common fields as needed - for key in ["calibration_dir", "mock", "use_degrees"]: - if key in kwargs: - config_data[key] = kwargs[key] - - try: - return ConfigClass(**config_data) - except Exception as e: - raise ValueError(f"Failed to create {class_name} for robot type '{robot_type}': {e}. 
Config: {config_data}") + sim = Robot("unitree_go2", backend="isaac", num_envs=4096) + sim = Robot("so100", backend="newton", num_envs=4096) +""" - async def _get_policy( - self, policy_port: int | None = None, policy_host: str = "localhost", policy_provider: str = "groot" - ) -> Policy: - """Create policy on-the-fly from invocation parameters.""" - from .policies import create_policy +import logging +import os +from typing import Any - if not policy_port: - raise ValueError("policy_port is required for robot operation") +from strands_robots.registry import ( + get_hardware_type, + has_hardware, + resolve_name, +) - policy_config = {"port": policy_port, "host": policy_host} +logger = logging.getLogger(__name__) - if self.data_config: - policy_config["data_config"] = self.data_config - return create_policy(policy_provider, **policy_config) +def _auto_detect_mode(canonical: str) -> str: + """Auto-detect sim vs real mode. - async def _connect_robot(self) -> tuple[bool, str]: - """Connect to robot hardware with proper error handling. + Priority: + 1. ``STRANDS_ROBOT_MODE`` env var (explicit override) + 2. Robot-specific USB detection (Feetech/Dynamixel servo controllers) + 3. 
Default to sim (safest — never accidentally send commands to hardware) + """ + env_mode = os.getenv("STRANDS_ROBOT_MODE", "").lower() + if env_mode in ("sim", "real"): + return env_mode - Returns: - tuple[bool, str]: (success, error_message) - error_message is empty on success - """ + # Only probe USB if the robot actually has hardware support + if has_hardware(canonical): try: - # Import lerobot exceptions - from lerobot.utils.errors import DeviceAlreadyConnectedError - - # Check if already connected - if self.robot.is_connected: - logger.info(f"✅ {self.robot} already connected") - return True, "" - - logger.info(f"🔌 Connecting to {self.robot}...") - - # Handle robot connection using lerobot's error handling patterns - try: - if not self.robot.is_connected: - await asyncio.to_thread(self.robot.connect, False) # calibrate=False - - except DeviceAlreadyConnectedError: - # This is expected and fine - robot is already connected - logger.info(f"✅ {self.robot} was already connected") - - except Exception as e: - # Check if it's the string version of "already connected" error - error_str = str(e).lower() - if "already connected" in error_str or "is already connected" in error_str: - logger.info(f"✅ {self.robot} connection already established") - else: - # Re-raise if it's a different error - raise e - - # Final connection check - if not self.robot.is_connected: - error_msg = f"Failed to connect to {self.robot}" - logger.error(f"❌ {error_msg}") - return False, error_msg - - # Check robot calibration - if hasattr(self.robot, "is_calibrated") and not self.robot.is_calibrated: - error_msg = ( - f"Robot {self.robot} is not calibrated. 
Please calibrate the robot manually" - " first using LeRobot's calibration process (lerobot-calibrate)" + import serial.tools.list_ports + + ports = list(serial.tools.list_ports.comports()) + servo_keywords = ["feetech", "dynamixel", "sts3215", "xl430", "xl330"] + exclude = ["bluetooth", "internal", "debug", "apple", "modem"] + robot_ports = [ + p + for p in ports + if any( + kw in ((p.description or "") + (getattr(p, "manufacturer", None) or "")).lower() + for kw in servo_keywords ) - logger.error(f"❌ {error_msg}") - return False, error_msg - - logger.info(f"✅ {self.robot} connected and ready") - return True, "" - - except Exception as e: - error_msg = f"Robot connection failed: {e}. Ensure robot is calibrated and accessible on the specified port" - logger.error(f"❌ {error_msg}") - return False, error_msg - - async def _initialize_policy(self, policy: Policy) -> bool: - """Initialize policy with robot state keys.""" - try: - # Get robot state keys from observation - test_obs = await asyncio.to_thread(self.robot.get_observation) - - # Filter out camera keys to get robot state keys - camera_keys = [] - if hasattr(self.robot, "config") and hasattr(self.robot.config, "cameras"): - camera_keys = list(self.robot.config.cameras.keys()) - - robot_state_keys = [k for k in test_obs.keys() if k not in camera_keys] - - # Set robot state keys in policy - policy.set_robot_state_keys(robot_state_keys) - return True - - except Exception as e: - logger.error(f"❌ Failed to initialize policy: {e}") - return False - - async def _execute_task_async( - self, - instruction: str, - policy_port: int | None = None, - policy_host: str = "localhost", - policy_provider: str = "groot", - duration: float = 30.0, - ) -> None: - """Execute robot task in background thread (internal method).""" - try: - # Update task state - self._task_state.status = TaskStatus.CONNECTING - self._task_state.instruction = instruction - self._task_state.start_time = time.time() - self._task_state.step_count = 0 - 
self._task_state.error_message = "" - - # Connect to robot - connected, connect_error = await self._connect_robot() - if not connected: - self._task_state.status = TaskStatus.ERROR - self._task_state.error_message = connect_error or f"Failed to connect to {self.tool_name_str}" - return - - # Get policy instance - policy_instance = await self._get_policy(policy_port, policy_host, policy_provider) - - # Initialize policy with robot state keys - if not await self._initialize_policy(policy_instance): - self._task_state.status = TaskStatus.ERROR - self._task_state.error_message = "Failed to initialize policy" - return - - logger.info(f"🎯 Starting task: '{instruction}' on {self.tool_name_str}") - logger.info(f"🧠 Using policy: {policy_provider} on {policy_host}:{policy_port}") - - self._task_state.status = TaskStatus.RUNNING - start_time = time.time() - - while ( - time.time() - start_time < duration - and self._task_state.status == TaskStatus.RUNNING - and not self._shutdown_event.is_set() - ): - # Get observation from robot - observation = await asyncio.to_thread(self.robot.get_observation) - - # Get actions from policy - robot_actions = await policy_instance.get_actions(observation, instruction) - - # Execute actions from chunk with proper timing control - # Wait between actions for smooth execution - for action_dict in robot_actions[: self.action_horizon]: - if self._task_state.status != TaskStatus.RUNNING: - break - await asyncio.to_thread(self.robot.send_action, action_dict) - self._task_state.step_count += 1 - # Wait for action to complete before sending next action - # Default 50Hz (0.02s) - await asyncio.sleep(self.action_sleep_time) - - # Update final state - elapsed = time.time() - start_time - self._task_state.duration = elapsed - - if self._task_state.status == TaskStatus.RUNNING: - self._task_state.status = TaskStatus.COMPLETED + and not any(s in (p.description or "").lower() for s in exclude) + ] + if robot_ports: logger.info( - f"✅ Task completed: 
'{instruction}' in {elapsed:.1f}s ({self._task_state.step_count} steps)" + "Auto-detected robot hardware: %s", + [p.device for p in robot_ports], ) + return "real" + except (ImportError, OSError): # USB probing may fail with OSError on permission/device issues + pass + + return "sim" + + +def Robot( + name: str, + mode: str = "sim", + backend: str = "mujoco", + urdf_path: str | None = None, + cameras: dict[str, dict[str, Any]] | None = None, + position: list[float] | None = None, + **kwargs: Any, +) -> Any: + """Create a robot — returns a Simulation or HardwareRobot instance. + + This is a convenience factory, NOT a wrapper class. You get the real + backend instance back — with full access to all its methods. + + Defaults to simulation mode so that ``Robot("so100")`` never + accidentally sends commands to physical hardware. Use + ``mode="real"`` to explicitly opt into hardware control. + + Args: + name: Robot name ("so100", "aloha", "unitree_g1", "panda", ...) + Accepts any alias defined in ``registry/robots.json``. + mode: "sim" (default — safe), "real" (explicit hardware), or + "auto" (probes USB for servo controllers, falls back to sim). + backend: Simulation backend — currently only "mujoco" (CPU). + Future: "isaac" (GPU), "newton" (GPU). + urdf_path: Explicit path to URDF/MJCF file. If not provided, + resolved via ``strands_robots.simulation.model_registry`` + (asset manager or ``STRANDS_ASSETS_DIR`` search paths). + cameras: Camera config for real hardware. Example:: + + {"wrist": {"type": "opencv", "index_or_path": "/dev/video0", "fps": 30}} + + position: Robot position in sim world [x, y, z]. + **kwargs: Forwarded to the underlying backend constructor. + + Returns: + ``strands_robots.simulation.Simulation`` (sim) or + ``strands_robots.hardware_robot.Robot`` (real hardware). + + Raises: + RuntimeError: If the sim world or robot fails to initialize. + NotImplementedError: If an unimplemented backend is requested. 
+ + Examples:: + + # Simulation (default — safe) + sim = Robot("so100") + + # Explicit MJCF model path + sim = Robot("my_arm", urdf_path="path/to/robot.xml") + + # Real hardware (explicit opt-in) + hw = Robot("so100", mode="real", cameras={...}) + + # Auto-detect (probes USB, falls back to sim) + robot = Robot("so100", mode="auto") + + # The 5-line promise + from strands_robots import Robot + from strands import Agent + robot = Robot("so100") + agent = Agent(tools=[robot]) + agent("Pick up the red cube") + """ + canonical = resolve_name(name) + + if mode == "auto": + mode = _auto_detect_mode(canonical) + + # ── Simulation ── + if mode == "sim": + if backend != "mujoco": + raise NotImplementedError( + f"Backend {backend!r} is not yet implemented. " + f"Currently supported: 'mujoco'. " + f"Isaac and Newton backends are on the roadmap." + ) - except Exception as e: - logger.error(f"❌ Task execution failed: {e}") - self._task_state.status = TaskStatus.ERROR - self._task_state.error_message = str(e) - - def _execute_task_sync( - self, - instruction: str, - policy_port: int | None = None, - policy_host: str = "localhost", - policy_provider: str = "groot", - duration: float = 30.0, - ) -> dict[str, Any]: - """Execute task synchronously in thread - no new event loop.""" - - # Import here to avoid conflicts - import asyncio - - # Run task without creating new event loop - let it run in thread - async def task_runner(): - await self._execute_task_async(instruction, policy_port, policy_host, policy_provider, duration) - - # Use asyncio.run only if no loop is running, otherwise run in existing loop - try: - # Try to get the current event loop - asyncio.get_running_loop() - # If we're already in an event loop, we need to run in a thread - import concurrent.futures - - with concurrent.futures.ThreadPoolExecutor() as exec: - future = exec.submit(lambda: asyncio.run(task_runner())) - future.result() # Wait for completion - except RuntimeError: - # No event loop running - safe to 
create one - asyncio.run(task_runner()) - - # Return final status - return { - "status": "success" if self._task_state.status == TaskStatus.COMPLETED else "error", - "content": [ - { - "text": f"✅ Task: '{instruction}' - {self._task_state.status.value}\n" - f"🤖 Robot: {self.tool_name_str} ({self.robot})\n" - f"🧠 Policy: {policy_provider} on {policy_host}:{policy_port}\n" - f"⏱️ Duration: {self._task_state.duration:.1f}s\n" - f"🎯 Steps: {self._task_state.step_count}" - + (f"\n❌ Error: {self._task_state.error_message}" if self._task_state.error_message else "") - } - ], - } + from strands_robots.simulation import Simulation - def start_task( - self, - instruction: str, - policy_port: int | None = None, - policy_host: str = "localhost", - policy_provider: str = "groot", - duration: float = 30.0, - ) -> dict[str, Any]: - """Start robot task asynchronously and return immediately.""" - - # Check if task is already running - if self._task_state.status == TaskStatus.RUNNING: - return { - "status": "error", - "content": [{"text": f"❌ Task already running: {self._task_state.instruction}"}], - } - - # Start task in background - self._task_state.task_future = self._executor.submit( - self._execute_task_sync, instruction, policy_port, policy_host, policy_provider, duration + sim = Simulation( + tool_name=f"{canonical}_sim", + **kwargs, ) + sim._dispatch_action("create_world", {}) - return { - "status": "success", - "content": [ - { - "text": f"🚀 Task started: '{instruction}'\n" - f"🤖 Robot: {self.tool_name_str}\n" - f"💡 Use action='status' to check progress\n" - f"💡 Use action='stop' to interrupt" - } - ], + add_robot_params: dict[str, Any] = { + "robot_name": canonical, + "data_config": canonical, + "position": position or [0.0, 0.0, 0.0], } + if urdf_path: + add_robot_params["urdf_path"] = urdf_path + + result = sim._dispatch_action("add_robot", add_robot_params) + if result.get("status") == "error": + sim.destroy() # Clean up partial initialization (executor, temp dir, 
MuJoCo world) + content = result.get("content", []) + msg = content[0].get("text", str(result)) if content else str(result) + raise RuntimeError(f"Failed to create sim robot '{canonical}': {msg}") + return sim + + # ── Real hardware (explicit opt-in) ── + elif mode == "real": + from strands_robots.hardware_robot import Robot as HardwareRobot + + real_type = get_hardware_type(canonical) or canonical + return HardwareRobot( + tool_name=canonical, + robot=real_type, + cameras=cameras, + **kwargs, + ) - def get_task_status(self) -> dict[str, Any]: - """Get current task execution status.""" - - # Update duration for running tasks - if self._task_state.status == TaskStatus.RUNNING: - self._task_state.duration = time.time() - self._task_state.start_time - - status_text = f"📊 Robot Status: {self._task_state.status.value.upper()}\n" - - if self._task_state.instruction: - status_text += f"🎯 Task: {self._task_state.instruction}\n" - - if self._task_state.status == TaskStatus.RUNNING: - status_text += f"⏱️ Duration: {self._task_state.duration:.1f}s\n" - status_text += f"🔄 Steps: {self._task_state.step_count}\n" - elif self._task_state.status in [TaskStatus.COMPLETED, TaskStatus.STOPPED, TaskStatus.ERROR]: - status_text += f"⏱️ Total Duration: {self._task_state.duration:.1f}s\n" - status_text += f"🎯 Total Steps: {self._task_state.step_count}\n" - - if self._task_state.error_message: - status_text += f"❌ Error: {self._task_state.error_message}\n" - - return { - "status": "success", - "content": [{"text": status_text}], - } - - def stop_task(self) -> dict[str, Any]: - """Stop currently running task.""" - - if self._task_state.status != TaskStatus.RUNNING: - return { - "status": "success", - "content": [{"text": f"💤 No task running to stop (current: {self._task_state.status.value})"}], - } - - # Signal task to stop - self._task_state.status = TaskStatus.STOPPED - - # Cancel future if it exists - if self._task_state.task_future: - self._task_state.task_future.cancel() - - 
logger.info(f"🛑 Task stopped: {self._task_state.instruction}") - - return { - "status": "success", - "content": [ - { - "text": f"🛑 Task stopped: '{self._task_state.instruction}'\n" - f"⏱️ Duration: {self._task_state.duration:.1f}s\n" - f"🎯 Steps completed: {self._task_state.step_count}" - } - ], - } - - @property - def tool_name(self) -> str: - return self.tool_name_str - - @property - def tool_type(self) -> str: - return "robot" - - @property - def tool_spec(self) -> ToolSpec: - """Get tool specification with async actions.""" - return { - "name": self.tool_name_str, - "description": f"Universal robot control with async task execution ({self.robot}). " - f"Actions: execute (blocking), start (async), status, stop. " - f"For execute/start actions: instruction and policy_port are required. " - f"For status/stop actions: no additional parameters needed.", - "inputSchema": { - "json": { - "type": "object", - "properties": { - "action": { - "type": "string", - "description": "Action to perform: execute (blocking), start (async), status, stop", - "enum": ["execute", "start", "status", "stop"], - "default": "execute", - }, - "instruction": { - "type": "string", - "description": "Natural language instruction (required for execute/start actions)", - }, - "policy_port": { - "type": "integer", - "description": "Policy service port (required for execute/start actions)", - }, - "policy_host": { - "type": "string", - "description": "Policy service host (default: localhost)", - "default": "localhost", - }, - "policy_provider": { - "type": "string", - "description": "Policy provider (groot, openai, etc.)", - "default": "groot", - }, - "duration": { - "type": "number", - "description": "Maximum execution time in seconds", - "default": 30.0, - }, - }, - "required": ["action"], - } - }, - } - - @staticmethod - def _make_tool_result(tool_use_id: str, result: dict[str, Any]) -> ToolResult: - """Create a ToolResult dict with the given tool_use_id merged into result.""" - return 
cast(ToolResult, {"toolUseId": tool_use_id, **result}) - - async def stream( - self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any - ) -> AsyncGenerator[ToolResultEvent, None]: - """Stream robot task execution with async actions.""" - try: - tool_use_id = tool_use.get("toolUseId", "") - input_data = tool_use.get("input", {}) - - action = input_data.get("action", "execute") - - # Handle different actions - if action == "execute": - # Blocking execution (legacy behavior) - instruction = input_data.get("instruction", "") - policy_port = input_data.get("policy_port") - policy_host = input_data.get("policy_host", "localhost") - policy_provider = input_data.get("policy_provider", "groot") - duration = input_data.get("duration", 30.0) - - if not instruction or not policy_port: - yield ToolResultEvent( - self._make_tool_result( - tool_use_id, - { - "status": "error", - "content": [{"text": "❌ instruction and policy_port are required for execute action"}], - }, - ) - ) - return - - # Execute task synchronously - task_result = self._execute_task_sync(instruction, policy_port, policy_host, policy_provider, duration) - yield ToolResultEvent(self._make_tool_result(tool_use_id, task_result)) - - elif action == "start": - # Asynchronous execution start - instruction = input_data.get("instruction", "") - policy_port = input_data.get("policy_port") - policy_host = input_data.get("policy_host", "localhost") - policy_provider = input_data.get("policy_provider", "groot") - duration = input_data.get("duration", 30.0) - - if not instruction or not policy_port: - yield ToolResultEvent( - self._make_tool_result( - tool_use_id, - { - "status": "error", - "content": [{"text": "❌ instruction and policy_port are required for start action"}], - }, - ) - ) - return - - # Start task asynchronously - start_result = self.start_task(instruction, policy_port, policy_host, policy_provider, duration) - yield ToolResultEvent(self._make_tool_result(tool_use_id, start_result)) - - 
elif action == "status": - # Get current task status - status_result = self.get_task_status() - yield ToolResultEvent(self._make_tool_result(tool_use_id, status_result)) - - elif action == "stop": - # Stop current task - stop_result = self.stop_task() - yield ToolResultEvent(self._make_tool_result(tool_use_id, stop_result)) - - else: - yield ToolResultEvent( - self._make_tool_result( - tool_use_id, - { - "status": "error", - "content": [ - {"text": f"❌ Unknown action: {action}. Valid actions: execute, start, status, stop"} - ], - }, - ) - ) - - except Exception as e: - logger.error(f"❌ {self.tool_name_str} error: {e}") - yield ToolResultEvent( - self._make_tool_result( - tool_use_id, - { - "status": "error", - "content": [{"text": f"❌ {self.tool_name_str} error: {str(e)}"}], - }, - ) - ) - - def cleanup(self): - """Cleanup resources and stop any running tasks.""" - try: - # Signal shutdown - self._shutdown_event.set() - - # Stop any running task - if self._task_state.status == TaskStatus.RUNNING: - self.stop_task() - - # Shutdown executor - self._executor.shutdown(wait=True) - - logger.info(f"🧹 {self.tool_name_str} cleanup completed") - - except Exception as e: - logger.error(f"❌ Cleanup error for {self.tool_name_str}: {e}") - - def __del__(self): - """Destructor to ensure cleanup.""" - try: - self.cleanup() - except Exception: - pass # Ignore errors in destructor - - async def get_status(self) -> dict[str, Any]: - """Get robot status including connection and task state.""" - try: - # Get robot connection status - is_connected = self.robot.is_connected if hasattr(self.robot, "is_connected") else False - is_calibrated = self.robot.is_calibrated if hasattr(self.robot, "is_calibrated") else True - - # Get camera status - camera_status = [] - if hasattr(self.robot, "config") and hasattr(self.robot.config, "cameras"): - for name in self.robot.config.cameras.keys(): - camera_status.append(name) - - # Build status dict - status_data = { - "robot_name": self.tool_name_str, 
- "robot_type": getattr(self.robot, "robot_type", self.robot.name), - "robot_info": str(self.robot), - "data_config": self.data_config, - "is_connected": is_connected, - "is_calibrated": is_calibrated, - "cameras": camera_status, - "task_status": self._task_state.status.value, - "current_instruction": self._task_state.instruction, - "task_duration": self._task_state.duration, - "task_steps": self._task_state.step_count, - } - - # Add error info if present - if self._task_state.error_message: - status_data["task_error"] = self._task_state.error_message - - return status_data - - except Exception as e: - logger.error(f"❌ Error getting status for {self.tool_name_str}: {e}") - return { - "robot_name": self.tool_name_str, - "error": str(e), - "is_connected": False, - "task_status": "error", - } - - async def stop(self): - """Stop robot and disconnect.""" - try: - # Stop any running task first - if self._task_state.status == TaskStatus.RUNNING: - self.stop_task() - - # Disconnect robot hardware - if hasattr(self.robot, "disconnect"): - await asyncio.to_thread(self.robot.disconnect) - - # Cleanup resources - self.cleanup() + else: + raise ValueError(f"Invalid mode {mode!r}. Choose 'sim', 'real', or 'auto'.") - logger.info(f"🛑 {self.tool_name_str} stopped and disconnected") - except Exception as e: - logger.error(f"❌ Error stopping robot: {e}") +__all__ = ["Robot"] diff --git a/strands_robots/simulation/__init__.py b/strands_robots/simulation/__init__.py index d9674a9..4aea6f8 100644 --- a/strands_robots/simulation/__init__.py +++ b/strands_robots/simulation/__init__.py @@ -7,9 +7,19 @@ ├── base.py ← SimEngine ABC ├── factory.py ← create_simulation() + backend registration ├── models.py ← shared dataclasses (SimWorld, SimRobot, ...) - └── model_registry.py ← URDF/MJCF resolution (shared across backends) - - # MuJoCo backend added in subsequent PRs. 
+ ├── model_registry.py ← URDF/MJCF resolution (shared across backends) + └── mujoco/ ← MuJoCo CPU backend + ├── __init__.py + ├── backend.py ← lazy mujoco import + GL config + ├── mjcf_builder.py ← MJCF XML builder + ├── physics.py ← advanced physics (raycasting, jacobians, forces) + ├── scene_ops.py ← XML round-trip inject/eject + ├── rendering.py ← render RGB/depth, observations + ├── policy_runner.py ← run_policy, eval_policy, replay + ├── randomization.py ← domain randomization + ├── recording.py ← LeRobotDataset recording + ├── tool_spec.json ← AgentTool input schema + └── simulation.py ← Simulation (AgentTool orchestrator) Usage:: @@ -62,10 +72,15 @@ TrajectoryStep, ) -# --- Heavy imports (lazy — loaded when mujoco backend is available) --- -# MuJoCo-specific lazy imports will be added when the mujoco/ subpackage -# is introduced. For now, only the lightweight foundation is available. -_LAZY_IMPORTS: dict[str, tuple[str, str]] = {} +# --- Heavy imports (lazy — need strands SDK + mujoco) --- +_LAZY_IMPORTS: dict[str, tuple[str, str]] = { + "Simulation": ("strands_robots.simulation.mujoco.simulation", "Simulation"), + "MuJoCoSimulation": ("strands_robots.simulation.mujoco.simulation", "Simulation"), + "MJCFBuilder": ("strands_robots.simulation.mujoco.mjcf_builder", "MJCFBuilder"), + "_configure_gl_backend": ("strands_robots.simulation.mujoco.backend", "_configure_gl_backend"), + "_ensure_mujoco": ("strands_robots.simulation.mujoco.backend", "_ensure_mujoco"), + "_is_headless": ("strands_robots.simulation.mujoco.backend", "_is_headless"), +} __all__ = [ @@ -75,9 +90,9 @@ "create_simulation", "list_backends", "register_backend", - # Default backend alias (available when mujoco backend is installed) - # "Simulation", - # "MuJoCoSimulation", + # Default backend alias + "Simulation", + "MuJoCoSimulation", # Shared dataclasses "SimStatus", "SimRobot", @@ -85,8 +100,8 @@ "SimCamera", "SimWorld", "TrajectoryStep", - # MuJoCo builder (available when mujoco backend is 
installed) - # "MJCFBuilder", + # MuJoCo builder + "MJCFBuilder", # Model registry "register_urdf", "resolve_model", diff --git a/strands_robots/simulation/mujoco/__init__.py b/strands_robots/simulation/mujoco/__init__.py new file mode 100644 index 0000000..03c6a03 --- /dev/null +++ b/strands_robots/simulation/mujoco/__init__.py @@ -0,0 +1,32 @@ +"""MuJoCo simulation backend for strands-robots. + +CPU-based physics with offscreen rendering. No GPU required. +Supports URDF/MJCF loading, multi-robot scenes, policy execution, +domain randomization, and LeRobotDataset recording. + +Usage:: + + from strands_robots.simulation.mujoco import MuJoCoSimulation + + sim = MuJoCoSimulation(tool_name="my_sim") + sim.create_world() + sim.add_robot("so100", data_config="so100") + sim.run_policy("so100", policy_provider="mock", instruction="wave") + +Or via the top-level alias:: + + from strands_robots.simulation import Simulation # → MuJoCoSimulation +""" + +__all__ = [ + "MuJoCoSimulation", +] + + +def __getattr__(name: str) -> "type": + if name == "MuJoCoSimulation": + from strands_robots.simulation.mujoco.simulation import Simulation as _Sim + + globals()["MuJoCoSimulation"] = _Sim + return _Sim + raise AttributeError(f"module 'strands_robots.simulation.mujoco' has no attribute {name!r}") diff --git a/strands_robots/simulation/mujoco/backend.py b/strands_robots/simulation/mujoco/backend.py new file mode 100644 index 0000000..9c0873d --- /dev/null +++ b/strands_robots/simulation/mujoco/backend.py @@ -0,0 +1,138 @@ +"""MuJoCo lazy import and GL backend configuration.""" + +import ctypes +import logging +import os +import sys +from typing import Any + +logger = logging.getLogger(__name__) + +_mujoco = None +_mujoco_viewer = None + + +def _is_headless() -> bool: + """Detect if running in a headless environment (no display server). + + Returns True on Linux when no DISPLAY or WAYLAND_DISPLAY is set, + which means GLFW-based rendering will fail. 
+ + Windows and macOS are always False because MuJoCo uses native + windowing backends (WGL on Windows, CGL on macOS) that support + offscreen rendering without X11/Wayland. The EGL/OSMesa fallback + is Linux-specific. + """ + if sys.platform != "linux": + return False + if os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"): + return False + return True + + +def _configure_gl_backend() -> None: # noqa: C901 + """Auto-configure MuJoCo's OpenGL backend for headless environments. + + MuJoCo reads MUJOCO_GL at import time to select the OpenGL backend: + - "egl" → EGL (GPU-accelerated offscreen, requires libEGL + NVIDIA driver) + - "osmesa" → OSMesa (CPU software rendering, slower but always works) + - "glfw" → GLFW (default, requires X11/Wayland display server) + + This function MUST be called before `import mujoco`. Setting MUJOCO_GL + after import has no effect — the backend is locked at import time. + + Never overrides a user-set MUJOCO_GL value. + """ + if os.environ.get("MUJOCO_GL"): + logger.debug(f"MUJOCO_GL already set to '{os.environ['MUJOCO_GL']}', respecting user config") + return + + if not _is_headless(): + return + + # Headless Linux — probe for EGL first (GPU-accelerated), then fall back to OSMesa (CPU) + try: + ctypes.cdll.LoadLibrary("libEGL.so.1") + os.environ["MUJOCO_GL"] = "egl" + logger.info("Headless environment detected — using MUJOCO_GL=egl (GPU-accelerated offscreen)") + return + except OSError: + pass + + try: + ctypes.cdll.LoadLibrary("libOSMesa.so") + os.environ["MUJOCO_GL"] = "osmesa" + logger.info("Headless environment detected — using MUJOCO_GL=osmesa (CPU software rendering)") + return + except OSError: + pass + + logger.warning( + "Headless environment detected but neither EGL nor OSMesa found. " + "MuJoCo rendering will likely fail. 
Install one of:\n" + " GPU: apt-get install libegl1-mesa-dev (or NVIDIA driver provides libEGL)\n" + " CPU: apt-get install libosmesa6-dev\n" + "Then set: export MUJOCO_GL=egl (or osmesa)" + ) + + +def _ensure_mujoco() -> "Any": + """Lazy import MuJoCo to avoid hard dependency. + + Auto-configures the OpenGL backend for headless environments before + importing mujoco, since MUJOCO_GL must be set at import time. + + Uses require_optional() for consistent dependency management across + the strands-robots package. + """ + global _mujoco, _mujoco_viewer + if _mujoco is None: + _configure_gl_backend() + from strands_robots.utils import require_optional + + _mujoco = require_optional( + "mujoco", + pip_install="mujoco", + extra="sim-mujoco", + purpose="MuJoCo simulation", + ) + if _mujoco_viewer is None and not _is_headless(): + try: + import mujoco.viewer as viewer + + _mujoco_viewer = viewer + except ImportError: + pass + return _mujoco + + +_rendering_available: bool | None = None + + +def _can_render() -> bool: + """Check if MuJoCo offscreen rendering is available. + + Probes once by creating a minimal Renderer. Result is cached. + Returns False on headless environments without EGL/OSMesa. + """ + global _rendering_available + if _rendering_available is not None: + return _rendering_available + + mj = _ensure_mujoco() + try: + model = mj.MjModel.from_xml_string("") + renderer = mj.Renderer(model, height=1, width=1) + renderer.close() + del renderer + _rendering_available = True + logger.info("MuJoCo rendering available") + except Exception as e: + _rendering_available = False + logger.warning( + "MuJoCo rendering unavailable: %s. " + "Physics/policy will work, but render/camera observations will be skipped. 
" + "Install EGL or OSMesa for offscreen rendering.", + e, + ) + return _rendering_available diff --git a/strands_robots/simulation/mujoco/mjcf_builder.py b/strands_robots/simulation/mujoco/mjcf_builder.py new file mode 100644 index 0000000..c8bc70d --- /dev/null +++ b/strands_robots/simulation/mujoco/mjcf_builder.py @@ -0,0 +1,215 @@ +"""MJCF XML builder — programmatic scene construction.""" + +import logging +import os +import re +import subprocess +import tempfile + +from strands_robots.simulation.models import SimCamera, SimObject, SimRobot, SimWorld +from strands_robots.simulation.mujoco.backend import _ensure_mujoco + +logger = logging.getLogger(__name__) + + +_VALID_NAME_RE = re.compile(r"^[a-zA-Z0-9_][a-zA-Z0-9_.\-]{0,127}$") + + +def _sanitize_name(name: str) -> str: + """Validate and sanitize an object/body name for safe MJCF XML embedding. + + Raises ValueError if name contains characters that could cause XML injection. + """ + if not _VALID_NAME_RE.match(name): + raise ValueError(f"Invalid simulation name {name!r}: must match [a-zA-Z0-9_][a-zA-Z0-9_.\\-]{{0,127}}") + return name + + +class MJCFBuilder: + """Builds MuJoCo MJCF XML from SimWorld state.""" + + @staticmethod + def build_objects_only(world: SimWorld) -> str: + """Build MJCF XML for a world with only objects (robots loaded separately).""" + _ensure_mujoco() + + parts = [] + parts.append('') + parts.append(' ') + + gx, gy, gz = world.gravity + parts.append(f' ") + + return "\n".join(parts) + + @staticmethod + def _object_xml(obj: SimObject, indent: int = 4) -> str: + """Generate MJCF XML for a single object.""" + pad = " " * indent + px, py, pz = obj.position + qw, qx, qy, qz = obj.orientation + r, g, b, a = obj.color + lines = [] + + lines.append(f'{pad}') + + if not obj.is_static: + lines.append(f'{pad} ') + lines.append(f'{pad} ') + + if obj.shape == "box": + sx, sy, sz = [s / 2 for s in obj.size] + lines.append( + f'{pad} ' + ) + elif obj.shape == "sphere": + radius = obj.size[0] / 2 if 
obj.size else 0.025 + lines.append( + f'{pad} ' + ) + elif obj.shape == "cylinder": + radius = obj.size[0] / 2 if obj.size else 0.025 + half_h = obj.size[2] / 2 if len(obj.size) > 2 else 0.05 + lines.append( + f'{pad} ' + ) + elif obj.shape == "capsule": + radius = obj.size[0] / 2 if obj.size else 0.025 + half_h = obj.size[2] / 2 if len(obj.size) > 2 else 0.05 + lines.append( + f'{pad} ' + ) + elif obj.shape == "mesh" and obj.mesh_path: + lines.append( + f'{pad} ' + ) + elif obj.shape == "plane": + sx = obj.size[0] if obj.size else 1.0 + sy = obj.size[1] if len(obj.size) > 1 else sx + lines.append( + f'{pad} ' + ) + + lines.append(f"{pad}") + return "\n".join(lines) + + @staticmethod + def compose_multi_robot_scene( + robots: dict[str, SimRobot], + objects: dict[str, SimObject], + cameras: dict[str, SimCamera], + world: SimWorld, + ) -> str: + """Compose a multi-robot scene by merging URDF-derived MJCF fragments.""" + mj = _ensure_mujoco() + world._backend_state["tmpdir"] = tempfile.TemporaryDirectory(prefix="strands_sim_") + tmpdir = world._backend_state["tmpdir"].name + + robot_xmls = {} + for robot_name, robot in robots.items(): + try: + model = mj.MjModel.from_xml_path(str(robot.urdf_path)) + robot_xml_path = os.path.join(tmpdir, f"{robot_name}.xml") + mj.mj_saveLastXML(robot_xml_path, model) + robot_xmls[robot_name] = robot_xml_path + logger.debug("Converted %s → %s", robot.urdf_path, robot_xml_path) + except (FileNotFoundError, OSError, subprocess.CalledProcessError) as e: + logger.error("Failed to convert URDF for '%s': %s", robot_name, e) + raise + + parts = [] + parts.append('') + parts.append(' ') + + gx, gy, gz = world.gravity + parts.append(f' ") + + master_xml = "\n".join(parts) + master_path = os.path.join(tmpdir, "master_scene.xml") + with open(master_path, "w") as f: + f.write(master_xml) + + return master_path diff --git a/strands_robots/simulation/mujoco/physics.py b/strands_robots/simulation/mujoco/physics.py new file mode 100644 index 
0000000..3e366f6 --- /dev/null +++ b/strands_robots/simulation/mujoco/physics.py @@ -0,0 +1,829 @@ +"""Physics mixin — advanced MuJoCo physics introspection and manipulation. + +Exposes the deep MuJoCo C API through clean Python methods: +- Raycasting (mj_ray) +- Jacobians (mj_jacBody, mj_jacSite, mj_jacGeom) +- Energy computation (mj_energyPos, mj_energyVel) +- External forces (mj_applyFT, xfrc_applied) +- Mass matrix (mj_fullM) +- State checkpointing (mj_getState, mj_setState) +- Inverse dynamics (mj_inverse) +- Body/joint introspection (poses, velocities, accelerations) +- Direct joint position/velocity control (qpos, qvel) +- Runtime model modification (mass, friction, color, size) +- Sensor readout (sensordata) +- Contact force analysis (mj_contactForce) +""" + +import json +import logging +from typing import TYPE_CHECKING, Any + +import numpy as np + +from strands_robots.simulation.mujoco.backend import _ensure_mujoco + +logger = logging.getLogger(__name__) + + +class PhysicsMixin: + if TYPE_CHECKING: + from strands_robots.simulation.models import SimWorld + + _world: "SimWorld | None" + + """Advanced physics capabilities for Simulation. + + Expects: self._world (SimWorld with _model, _data) + + Naming: methods match action names in tool_spec.json for direct dispatch. + """ + + # ── State Checkpointing ── + + def save_state(self, name: str = "default") -> dict[str, Any]: + """Save the full physics state (qpos, qvel, act, time) to a named checkpoint. + + Uses mj_getState with mjSTATE_PHYSICS for complete state capture. 
+ """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + state_size = mj.mj_stateSize(model, mj.mjtState.mjSTATE_PHYSICS) + state = np.zeros(state_size) + mj.mj_getState(model, data, state, mj.mjtState.mjSTATE_PHYSICS) + + if not hasattr(self._world, "_checkpoints"): + self._world._checkpoints = {} + + self._world._checkpoints[name] = { + "state": state.copy(), + "sim_time": self._world.sim_time, + "step_count": self._world.step_count, + } + + return { + "status": "success", + "content": [ + { + "text": ( + f"💾 State '{name}' saved\n" + f" t={self._world.sim_time:.4f}s, step={self._world.step_count}\n" + f" State vector: {state_size} floats\n" + f" Checkpoints: {list(self._world._checkpoints.keys())}" + ) + } + ], + } + + def load_state(self, name: str = "default") -> dict[str, Any]: + """Restore physics state from a named checkpoint.""" + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + checkpoints = getattr(self._world, "_checkpoints", {}) + if name not in checkpoints: + available = list(checkpoints.keys()) if checkpoints else ["none"] + return { + "status": "error", + "content": [{"text": f"❌ Checkpoint '{name}' not found. 
Available: {available}"}], + } + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + checkpoint = checkpoints[name] + + mj.mj_setState(model, data, checkpoint["state"], mj.mjtState.mjSTATE_PHYSICS) + mj.mj_forward(model, data) + + self._world.sim_time = checkpoint["sim_time"] + self._world.step_count = checkpoint["step_count"] + + return { + "status": "success", + "content": [ + {"text": f"📂 State '{name}' restored (t={self._world.sim_time:.4f}s, step={self._world.step_count})"} + ], + } + + # ── External Forces ── + + def apply_force( + self, + body_name: str, + force: list[float] | None = None, + torque: list[float] | None = None, + point: list[float] | None = None, + ) -> dict[str, Any]: + """Apply external force and/or torque to a body. + + Uses mj_applyFT for precise force application at a world-frame point. + Forces persist for one timestep — call before each step for continuous force. + + Args: + body_name: Target body name. + force: [fx, fy, fz] in world frame (Newtons). + torque: [tx, ty, tz] in world frame (N·m). + point: [px, py, pz] world-frame point of force application. + Defaults to body CoM if not specified. 
+ """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + body_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_BODY, body_name) + if body_id < 0: + return {"status": "error", "content": [{"text": f"❌ Body '{body_name}' not found."}]} + + f = np.array(force or [0, 0, 0], dtype=np.float64) + t = np.array(torque or [0, 0, 0], dtype=np.float64) + p = np.array(point, dtype=np.float64) if point else data.xipos[body_id].copy() + + mj.mj_applyFT(model, data, f, t, p, body_id, data.qfrc_applied) + + return { + "status": "success", + "content": [ + { + "text": ( + f"💨 Force applied to '{body_name}' (body {body_id})\n" + f" Force: {f.tolist()} N\n" + f" Torque: {t.tolist()} N·m\n" + f" Point: {p.tolist()}" + ) + } + ], + } + + # ── Raycasting ── + + def raycast( + self, + origin: list[float], + direction: list[float], + exclude_body: int = -1, + include_static: bool = True, + ) -> dict[str, Any]: + """Cast a ray and find the first geom intersection. + + Uses mj_ray for precise distance sensing / obstacle detection. + + Args: + origin: [x, y, z] ray start point in world frame. + direction: [dx, dy, dz] ray direction (auto-normalized). + exclude_body: Body ID to exclude from intersection (-1 = none). + include_static: Whether to include static geoms. 
+ """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + pnt = np.array(origin, dtype=np.float64) + vec = np.array(direction, dtype=np.float64) + # Normalize direction + norm = np.linalg.norm(vec) + if norm > 0: + vec = vec / norm + + geomid = np.array([-1], dtype=np.int32) + dist = mj.mj_ray( + model, + data, + pnt, + vec, + None, # geom group filter (None = all) + 1 if include_static else 0, + exclude_body, + geomid, + ) + + hit = dist >= 0 + geom_name = None + if hit and geomid[0] >= 0: + geom_name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, geomid[0]) + + result = { + "hit": hit, + "distance": float(dist) if hit else None, + "geom_id": int(geomid[0]) if hit else None, + "geom_name": geom_name, + "hit_point": (pnt + vec * dist).tolist() if hit else None, + } + + if hit: + text = f"🎯 Ray hit '{geom_name or geomid[0]}' at dist={dist:.4f}m, point={result['hit_point']}" + else: + text = "🎯 Ray: no intersection" + + return {"status": "success", "content": [{"text": text}, {"text": json.dumps(result, default=str)}]} + + # ── Jacobians ── + + def get_jacobian( + self, + body_name: str | None = None, + site_name: str | None = None, + geom_name: str | None = None, + ) -> dict[str, Any]: + """Compute the Jacobian (position + rotation) for a body, site, or geom. + + The Jacobian maps joint velocities to Cartesian velocities: + v = J @ dq + + Returns both positional (3×nv) and rotational (3×nv) Jacobians. 
+ """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + jacp = np.zeros((3, model.nv)) + jacr = np.zeros((3, model.nv)) + + if body_name: + obj_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_BODY, body_name) + if obj_id < 0: + return {"status": "error", "content": [{"text": f"❌ Body '{body_name}' not found."}]} + mj.mj_jacBody(model, data, jacp, jacr, obj_id) + label = f"body '{body_name}'" + elif site_name: + obj_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_SITE, site_name) + if obj_id < 0: + return {"status": "error", "content": [{"text": f"❌ Site '{site_name}' not found."}]} + mj.mj_jacSite(model, data, jacp, jacr, obj_id) + label = f"site '{site_name}'" + elif geom_name: + obj_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_GEOM, geom_name) + if obj_id < 0: + return {"status": "error", "content": [{"text": f"❌ Geom '{geom_name}' not found."}]} + mj.mj_jacGeom(model, data, jacp, jacr, obj_id) + label = f"geom '{geom_name}'" + else: + return {"status": "error", "content": [{"text": "❌ Specify body_name, site_name, or geom_name."}]} + + return { + "status": "success", + "content": [ + {"text": f"🧮 Jacobian for {label}: pos={jacp.shape}, rot={jacr.shape}, nv={model.nv}"}, + { + "text": json.dumps( + { + "jacp": jacp.tolist(), + "jacr": jacr.tolist(), + "nv": model.nv, + }, + default=str, + ) + }, + ], + } + + # ── Energy ── + + def get_energy(self) -> dict[str, Any]: + """Compute potential and kinetic energy of the system.""" + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + mj.mj_energyPos(model, data) + mj.mj_energyVel(model, data) + + potential = float(data.energy[0]) + kinetic = float(data.energy[1]) + total = potential + kinetic + + return { + "status": 
"success", + "content": [ + {"text": f"⚡ Energy: potential={potential:.4f}J, kinetic={kinetic:.4f}J, total={total:.4f}J"}, + {"text": json.dumps({"potential": potential, "kinetic": kinetic, "total": total}, default=str)}, + ], + } + + # ── Mass Matrix ── + + def get_mass_matrix(self) -> dict[str, Any]: + """Compute the full mass (inertia) matrix M(q). + + M is nv×nv where nv is the number of DoFs. + Useful for dynamics analysis, impedance control, etc. + """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + nv = model.nv + M = np.zeros((nv, nv)) + mj.mj_fullM(model, M, data.qM) + rank = int(np.linalg.matrix_rank(M)) + cond = float(np.linalg.cond(M)) if rank > 0 else float("inf") + + return { + "status": "success", + "content": [ + {"text": f"🧮 Mass matrix: {nv}×{nv}, rank={rank}, cond={cond:.2e}"}, + { + "text": json.dumps( + { + "shape": [nv, nv], + "rank": rank, + "condition_number": cond, + "diagonal": np.diag(M).tolist(), + "total_mass": float(np.sum(model.body_mass)), + }, + default=str, + ) + }, + ], + } + + # ── Inverse Dynamics ── + + def inverse_dynamics(self) -> dict[str, Any]: + """Compute inverse dynamics: given qacc, what forces are needed? + + Runs mj_inverse to compute qfrc_inverse — the generalized forces + that would produce the current accelerations. 
+ """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + mj.mj_inverse(model, data) + + # Build named force mapping + forces = {} + for i in range(model.njnt): + name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_JOINT, i) + if name: + dof_adr = model.jnt_dofadr[i] + forces[name] = float(data.qfrc_inverse[dof_adr]) + + return { + "status": "success", + "content": [ + {"text": f"🔄 Inverse dynamics: {len(forces)} joint forces computed"}, + {"text": json.dumps({"qfrc_inverse": forces}, default=str)}, + ], + } + + # ── Body Introspection ── + + def get_body_state( + self, + body_name: str, + ) -> dict[str, Any]: + """Get the full state of a body: position, orientation, velocity, acceleration. + + Returns Cartesian pose + 6D spatial velocity (linear + angular). + """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + body_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_BODY, body_name) + if body_id < 0: + return {"status": "error", "content": [{"text": f"❌ Body '{body_name}' not found."}]} + + # Position and orientation + pos = data.xpos[body_id].tolist() + quat = data.xquat[body_id].tolist() + rotmat = data.xmat[body_id].reshape(3, 3).tolist() + + # Velocity (6D: angular then linear in world frame) + vel = np.zeros(6) + mj.mj_objectVelocity(model, data, mj.mjtObj.mjOBJ_BODY, body_id, vel, 0) + linvel = vel[3:].tolist() + angvel = vel[:3].tolist() + + # Mass and inertia + mass = float(model.body_mass[body_id]) + com = data.xipos[body_id].tolist() + + state = { + "position": pos, + "quaternion": quat, + "rotation_matrix": rotmat, + "linear_velocity": linvel, + "angular_velocity": angvel, + "mass": mass, + "center_of_mass": com, + } + + text = ( + f"🏷️ Body 
'{body_name}' (id={body_id}):\n" + f" pos: [{pos[0]:.4f}, {pos[1]:.4f}, {pos[2]:.4f}]\n" + f" quat: [{quat[0]:.4f}, {quat[1]:.4f}, {quat[2]:.4f}, {quat[3]:.4f}]\n" + f" linvel: [{linvel[0]:.4f}, {linvel[1]:.4f}, {linvel[2]:.4f}]\n" + f" angvel: [{angvel[0]:.4f}, {angvel[1]:.4f}, {angvel[2]:.4f}]\n" + f" mass: {mass:.4f}kg, com: {com}" + ) + + return {"status": "success", "content": [{"text": text}, {"text": json.dumps(state, default=str)}]} + + # ── Direct Joint Control ── + + def set_joint_positions( + self, + positions: dict[str, float] | None = None, + robot_name: str | None = None, + ) -> dict[str, Any]: + """Set joint positions directly (bypassing actuators). + + Writes to qpos and runs mj_forward to update kinematics. + Useful for teleportation, IK solutions, or keyframe setting. + """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + if positions is None: + return {"status": "error", "content": [{"text": "❌ positions dict required."}]} + + set_count = 0 + for jnt_name, value in positions.items(): + jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jnt_id >= 0: + qpos_adr = model.jnt_qposadr[jnt_id] + data.qpos[qpos_adr] = float(value) + set_count += 1 + else: + logger.warning("Joint '%s' not found, skipping", jnt_name) + + mj.mj_forward(model, data) + + return { + "status": "success", + "content": [{"text": f"🎯 Set {set_count}/{len(positions)} joint positions, FK updated"}], + } + + def set_joint_velocities( + self, + velocities: dict[str, float] | None = None, + ) -> dict[str, Any]: + """Set joint velocities directly. + + Writes to qvel. Useful for initializing dynamics. 
+ """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + if velocities is None: + return {"status": "error", "content": [{"text": "❌ velocities dict required."}]} + + set_count = 0 + for jnt_name, value in velocities.items(): + jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jnt_id >= 0: + dof_adr = model.jnt_dofadr[jnt_id] + data.qvel[dof_adr] = float(value) + set_count += 1 + + return { + "status": "success", + "content": [{"text": f"💨 Set {set_count}/{len(velocities)} joint velocities"}], + } + + # ── Sensor Readout ── + + def get_sensor_data(self, sensor_name: str | None = None) -> dict[str, Any]: + """Read sensor values from the simulation. + + MuJoCo supports: jointpos, jointvel, accelerometer, gyro, force, + torque, touch, rangefinder, framequat, subtreecom, clock, etc. + + Args: + sensor_name: Specific sensor name, or None for all sensors. 
+ """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + if model.nsensor == 0: + return {"status": "success", "content": [{"text": "📡 No sensors in model."}]} + + mj.mj_forward(model, data) + + sensors = {} + for i in range(model.nsensor): + name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_SENSOR, i) + if not name: + name = f"sensor_{i}" + + adr = model.sensor_adr[i] + dim = model.sensor_dim[i] + values = data.sensordata[adr : adr + dim].tolist() + + if sensor_name and name != sensor_name: + continue + + sensors[name] = { + "values": values if dim > 1 else values[0], + "dim": int(dim), + "type": int(model.sensor_type[i]), + } + + if sensor_name and sensor_name not in sensors: + return {"status": "error", "content": [{"text": f"❌ Sensor '{sensor_name}' not found."}]} + + lines = [f"📡 Sensors ({len(sensors)}/{model.nsensor}):"] + for name, info in sensors.items(): + lines.append(f" {name}: {info['values']} (dim={info['dim']})") + + return { + "status": "success", + "content": [{"text": "\n".join(lines)}, {"text": json.dumps({"sensors": sensors}, default=str)}], + } + + # ── Runtime Model Modification ── + + def set_body_properties( + self, + body_name: str, + mass: float | None = None, + ) -> dict[str, Any]: + """Modify body properties at runtime (no recompile needed). + + Changes take effect on the next mj_step. 
+ """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model = self._world._model + body_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_BODY, body_name) + if body_id < 0: + return {"status": "error", "content": [{"text": f"❌ Body '{body_name}' not found."}]} + + changes = [] + if mass is not None: + old_mass = float(model.body_mass[body_id]) + model.body_mass[body_id] = mass + changes.append(f"mass: {old_mass:.3f} → {mass:.3f}") + + return { + "status": "success", + "content": [{"text": f"🔧 Body '{body_name}': {', '.join(changes)}"}], + } + + def set_geom_properties( + self, + geom_name: str | None = None, + geom_id: int | None = None, + color: list[float] | None = None, + friction: list[float] | None = None, + size: list[float] | None = None, + ) -> dict[str, Any]: + """Modify geom properties at runtime (no recompile needed). + + Changes take effect immediately for rendering (color) or next step (friction, size). 
+ """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model = self._world._model + + gid = geom_id + if geom_name: + gid = mj.mj_name2id(model, mj.mjtObj.mjOBJ_GEOM, geom_name) + if gid is None or gid < 0: + return {"status": "error", "content": [{"text": f"❌ Geom '{geom_name or geom_id}' not found."}]} + + label = geom_name or f"geom_{gid}" + changes = [] + + if color is not None: + model.geom_rgba[gid] = color[:4] if len(color) >= 4 else color[:3] + [1.0] + changes.append(f"color → {model.geom_rgba[gid].tolist()}") + + if friction is not None: + fric = friction[:3] if len(friction) >= 3 else friction + [0.0] * (3 - len(friction)) + model.geom_friction[gid] = fric + changes.append(f"friction → {fric}") + + if size is not None: + n = min(len(size), 3) + model.geom_size[gid, :n] = size[:n] + changes.append(f"size → {model.geom_size[gid].tolist()}") + + return { + "status": "success", + "content": [{"text": f"🔧 Geom '{label}': {', '.join(changes)}"}], + } + + # ── Contact Force Analysis ── + + def get_contact_forces(self) -> dict[str, Any]: + """Get detailed contact forces for all active contacts. + + Uses mj_contactForce for each active contact pair. + Returns normal and friction forces. 
+ """ + if self._world is None or self._world._data is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + contacts = [] + for i in range(data.ncon): + c = data.contact[i] + g1 = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, c.geom1) or f"geom_{c.geom1}" + g2 = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, c.geom2) or f"geom_{c.geom2}" + + # Get contact force (normal + friction in contact frame) + force = np.zeros(6) + mj.mj_contactForce(model, data, i, force) + + contacts.append( + { + "geom1": g1, + "geom2": g2, + "distance": float(c.dist), + "position": c.pos.tolist(), + "normal_force": float(force[0]), + "friction_force": force[1:3].tolist(), + "full_wrench": force.tolist(), + } + ) + + if not contacts: + return {"status": "success", "content": [{"text": "💥 No active contacts."}]} + + lines = [f"💥 {len(contacts)} contacts:"] + for c in contacts[:15]: + lines.append(f" {c['geom1']} ↔ {c['geom2']}: normal={c['normal_force']:.3f}N, dist={c['distance']:.4f}m") + if len(contacts) > 15: + lines.append(f" ... and {len(contacts) - 15} more") + + return { + "status": "success", + "content": [{"text": "\n".join(lines)}, {"text": json.dumps({"contacts": contacts}, default=str)}], + } + + # ── Multi-Ray (batch raycasting) ── + + def multi_raycast( + self, + origin: list[float], + directions: list[list[float]], + exclude_body: int = -1, + ) -> dict[str, Any]: + """Cast multiple rays from a single origin (e.g., for LIDAR simulation). + + Efficiently casts N rays using individual mj_ray calls. + Returns array of distances and hit geoms. 
+ """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + pnt = np.array(origin, dtype=np.float64) + results = [] + + for d in directions: + vec = np.array(d, dtype=np.float64) + norm = np.linalg.norm(vec) + if norm > 0: + vec /= norm + geomid = np.array([-1], dtype=np.int32) + dist = mj.mj_ray(model, data, pnt, vec, None, 1, exclude_body, geomid) + results.append( + { + "distance": float(dist) if dist >= 0 else None, + "geom_id": int(geomid[0]) if dist >= 0 else None, + } + ) + + hit_count = sum(1 for r in results if r["distance"] is not None) + return { + "status": "success", + "content": [ + {"text": f"🎯 Multi-ray: {hit_count}/{len(directions)} hits from {origin}"}, + {"text": json.dumps({"rays": results}, default=str)}, + ], + } + + # ── Forward Kinematics (explicit) ── + + def forward_kinematics(self) -> dict[str, Any]: + """Run forward kinematics to update all body positions/orientations. + + Usually called implicitly by mj_step, but useful after manually + setting qpos to see updated Cartesian positions. 
+ """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + mj.mj_kinematics(model, data) + mj.mj_comPos(model, data) + mj.mj_camlight(model, data) + + # Build body position summary + bodies = {} + for i in range(model.nbody): + name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_BODY, i) or f"body_{i}" + bodies[name] = { + "position": data.xpos[i].tolist(), + "quaternion": data.xquat[i].tolist(), + } + + return { + "status": "success", + "content": [ + {"text": f"🦴 FK computed for {model.nbody} bodies"}, + {"text": json.dumps({"bodies": bodies}, default=str)}, + ], + } + + # ── Total Mass ── + + def get_total_mass(self) -> dict[str, Any]: + """Get total mass and per-body mass breakdown.""" + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model = self._world._model + + total = float(mj.mj_getTotalmass(model)) + bodies = {} + for i in range(model.nbody): + name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_BODY, i) or f"body_{i}" + m = float(model.body_mass[i]) + if m > 0: + bodies[name] = m + + return { + "status": "success", + "content": [ + {"text": f"⚖️ Total mass: {total:.4f}kg ({len(bodies)} bodies with mass)"}, + {"text": json.dumps({"total_mass": total, "bodies": bodies}, default=str)}, + ], + } + + # ── Export Model XML ── + + def export_xml(self, output_path: str | None = None) -> dict[str, Any]: + """Export the current model to MJCF XML. + + Uses mj_saveLastXML — exports the exact model currently loaded, + including any runtime modifications. 
+ """ + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + + if output_path: + mj.mj_saveLastXML(output_path, self._world._model) + return {"status": "success", "content": [{"text": f"📄 Model exported to {output_path}"}]} + else: + # Return XML string via saveLastXML to temp file + import os + import tempfile + + with tempfile.NamedTemporaryFile(suffix=".xml", mode="w", delete=False) as tmp: + tmpfile = tmp.name + mj.mj_saveLastXML(tmpfile, self._world._model) + try: + with open(tmpfile) as f: + xml = f.read() + finally: + os.unlink(tmpfile) + return { + "status": "success", + "content": [ + {"text": f"📄 Model XML ({len(xml)} chars):\n{xml[:2000]}{'...' if len(xml) > 2000 else ''}"} + ], + } diff --git a/strands_robots/simulation/mujoco/policy_runner.py b/strands_robots/simulation/mujoco/policy_runner.py new file mode 100644 index 0000000..71f2503 --- /dev/null +++ b/strands_robots/simulation/mujoco/policy_runner.py @@ -0,0 +1,396 @@ +import logging +import os +import time +from typing import TYPE_CHECKING, Any + +import numpy as np + +from strands_robots._async_utils import _resolve_coroutine +from strands_robots.simulation.models import TrajectoryStep +from strands_robots.simulation.mujoco.backend import _ensure_mujoco +from strands_robots.utils import require_optional + +logger = logging.getLogger(__name__) + + +class PolicyRunnerMixin: + """Policy execution for Simulation. 
+ + Expects the composite Simulation class to provide: + - self._world (SimWorld | None) + - self._lock (threading.Lock) + - self._executor (ThreadPoolExecutor) + - self._policy_threads (dict[str, Future]) + - self._get_sim_observation(), self._apply_sim_action(), self._get_renderer() + """ + + if TYPE_CHECKING: + import threading + from concurrent.futures import Future, ThreadPoolExecutor + + from strands_robots.simulation.models import SimWorld + + _world: SimWorld | None + _lock: threading.Lock + _executor: ThreadPoolExecutor + _policy_threads: dict[str, Future[Any]] + + def _get_renderer(self, width: int, height: int) -> Any: ... + def _get_sim_observation(self, robot_name: str, cam_name: str | None = None) -> dict[str, Any]: ... + def _apply_sim_action(self, robot_name: str, action_dict: dict[str, Any], n_substeps: int = 1) -> None: ... + + def run_policy( + self, + robot_name: str, + policy_provider: str = "mock", + instruction: str = "", + duration: float = 10.0, + action_horizon: int = 8, + control_frequency: float = 50.0, + fast_mode: bool = False, + record_video: str | None = None, + video_fps: int = 30, + video_camera: str | None = None, + video_width: int = 640, + video_height: int = 480, + **policy_kwargs, + ) -> dict[str, Any]: + """Run a policy on a simulated robot (blocking). + + Args: + record_video: If set, path to save an MP4 recording of the run. + video_fps: Frames per second for the recording (default 30). + video_camera: Camera name for recording (default: first scene camera). + video_width: Recording width in pixels. + video_height: Recording height in pixels. 
+ """ + if self._world is None or self._world._data is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + if robot_name not in self._world.robots: + return {"status": "error", "content": [{"text": f"❌ Robot '{robot_name}' not found."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + robot = self._world.robots[robot_name] + + # Video recording setup + writer = None + frame_count = 0 + cam_id = -1 + if record_video: + imageio = require_optional( + "imageio", + pip_install="imageio imageio-ffmpeg", + extra="sim-mujoco", + purpose="video recording", + ) + + os.makedirs(os.path.dirname(os.path.abspath(record_video)), exist_ok=True) + writer = imageio.get_writer(record_video, fps=video_fps, quality=8, macro_block_size=1) # type: ignore[attr-defined] + if video_camera: + cam_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_CAMERA, video_camera) + elif model.ncam > 0: + cam_id = 0 + frame_interval = control_frequency / video_fps # fractional steps per frame + + try: + from strands_robots.policies import create_policy as _create_policy + + policy = _create_policy(policy_provider, **policy_kwargs) + policy.set_robot_state_keys(robot.joint_names) + + robot.policy_running = True + robot.policy_instruction = instruction + robot.policy_steps = 0 + next_frame_step = 0.0 + + sim_duration = duration * control_frequency # target number of control steps + start_time = time.time() + action_sleep = 1.0 / control_frequency + + while robot.policy_steps < sim_duration and robot.policy_running: + observation = self._get_sim_observation(robot_name) + + coro_or_result = policy.get_actions(observation, instruction) + actions = _resolve_coroutine(coro_or_result) + + for action_dict in actions[:action_horizon]: + if not robot.policy_running: + break + + if self._world._backend_state.get("recording", False): + self._world._backend_state["trajectory"].append( + TrajectoryStep( + timestamp=time.time(), + sim_time=self._world.sim_time, + 
robot_name=robot_name, + observation={k: v for k, v in observation.items() if not isinstance(v, np.ndarray)}, + action=action_dict, + instruction=instruction, + ) + ) + if self._world._backend_state.get("dataset_recorder") is not None: + self._world._backend_state["dataset_recorder"].add_frame( + observation=observation, + action=action_dict, + task=instruction, + ) + + self._apply_sim_action(robot_name, action_dict) + robot.policy_steps += 1 + + if writer and robot.policy_steps >= next_frame_step: + renderer = self._get_renderer(video_width, video_height) + if renderer is not None: + if cam_id >= 0: + renderer.update_scene(data, camera=cam_id) + else: + renderer.update_scene(data) + writer.append_data(renderer.render().copy()) + frame_count += 1 + next_frame_step += frame_interval + + if not fast_mode: + time.sleep(action_sleep) + + elapsed = time.time() - start_time + robot.policy_running = False + + result_text = ( + f"✅ Policy complete on '{robot_name}'\n" + f"🧠 {policy_provider} | 🎯 {instruction}\n" + f"⏱️ {elapsed:.1f}s | 📊 {robot.policy_steps} steps | " + f"🕐 sim_t={self._world.sim_time:.3f}s" + ) + + if writer: + writer.close() + file_kb = os.path.getsize(record_video) / 1024 # type: ignore[arg-type] # narrowed by `if writer` above + result_text += ( + f"\n🎬 Video: {record_video}\n" + f"📹 {frame_count} frames, {video_fps}fps, {video_width}x{video_height} | 💾 {file_kb:.0f} KB" + ) + + return {"status": "success", "content": [{"text": result_text}]} + + except Exception as e: + robot.policy_running = False + if writer: + writer.close() + return {"status": "error", "content": [{"text": f"❌ Policy failed: {e}"}]} + + def start_policy( + self, + robot_name: str, + policy_provider: str = "mock", + instruction: str = "", + duration: float = 10.0, + fast_mode: bool = False, + **policy_kwargs, + ) -> dict[str, Any]: + """Start policy execution in background (non-blocking). 
+ + Only one policy may run per robot at a time — MuJoCo model/data + are not thread-safe for concurrent writes. + """ + if self._world is None or self._world._data is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + if robot_name not in self._world.robots: + return {"status": "error", "content": [{"text": f"❌ Robot '{robot_name}' not found."}]} + + # Reject if a policy is already running on this robot (thread-safety) + existing = self._policy_threads.get(robot_name) + if existing is not None and not existing.done(): + return { + "status": "error", + "content": [{"text": f"❌ Policy already running on '{robot_name}'. Stop it first."}], + } + + future = self._executor.submit( + self.run_policy, + robot_name, + policy_provider, + instruction, + duration, + fast_mode=fast_mode, + **policy_kwargs, + ) + self._policy_threads[robot_name] = future + + return { + "status": "success", + "content": [{"text": f"🚀 Policy started on '{robot_name}' (async)"}], + } + + def replay_episode( + self, + repo_id: str, + robot_name: str | None = None, + episode: int = 0, + root: str | None = None, + speed: float = 1.0, + ) -> dict[str, Any]: + """Replay actions from a LeRobotDataset episode in simulation.""" + if self._world is None: + return {"status": "error", "content": [{"text": "❌ No world. Call create_world first."}]} + + if robot_name is None: + if not self._world.robots: + return {"status": "error", "content": [{"text": "❌ No robots in sim. 
Add one first."}]} + robot_name = next(iter(self._world.robots)) + + robot = self._world.robots.get(robot_name) + if robot is None: + return {"status": "error", "content": [{"text": f"❌ Robot '{robot_name}' not found"}]} + + try: + from strands_robots.dataset_recorder import load_lerobot_episode + + ds, episode_start, episode_length = load_lerobot_episode(repo_id, episode, root) + except ImportError: + return {"status": "error", "content": [{"text": "❌ lerobot not installed"}]} + except (ValueError, Exception) as e: + return {"status": "error", "content": [{"text": f"❌ {e}"}]} + + mj = _ensure_mujoco() + dataset_fps = getattr(ds, "fps", 30) + frame_interval = 1.0 / (dataset_fps * speed) + model = self._world._model + data = self._world._data + n_actuators = model.nu + frames_applied = 0 + start_time = time.time() + + for frame_idx in range(episode_length): + step_start = time.time() + frame = ds[episode_start + frame_idx] + + with self._lock: + if "action" in frame: + action_vals = frame["action"] + if hasattr(action_vals, "numpy"): + action_vals = action_vals.numpy() + if hasattr(action_vals, "tolist"): + action_vals = action_vals.tolist() + for i in range(min(len(action_vals), n_actuators)): + data.ctrl[i] = float(action_vals[i]) + + mj.mj_step(model, data) + frames_applied += 1 + + elapsed = time.time() - step_start + sleep_time = frame_interval - elapsed + if sleep_time > 0: + time.sleep(sleep_time) + + duration = time.time() - start_time + return { + "status": "success", + "content": [ + { + "text": ( + f"▶️ Replayed episode {episode} from {repo_id} on '{robot_name}'\n" + f"Frames: {frames_applied}/{episode_length} | Duration: {duration:.1f}s | Speed: {speed}x" + ) + }, + { + "json": { + "episode": episode, + "robot_name": robot_name, + "frames_applied": frames_applied, + "total_frames": episode_length, + "duration_s": round(duration, 2), + "speed": speed, + } + }, + ], + } + + def eval_policy( + self, + robot_name: str | None = None, + policy_provider: str = 
"mock", + instruction: str = "", + n_episodes: int = 10, + max_steps: int = 300, + success_fn: str | None = None, + **policy_kwargs, + ) -> dict[str, Any]: + """Evaluate a policy over multiple episodes with success metrics.""" + if self._world is None: + return {"status": "error", "content": [{"text": "❌ No world. Call create_world first."}]} + + if robot_name is None: + if not self._world.robots: + return {"status": "error", "content": [{"text": "❌ No robots"}]} + robot_name = next(iter(self._world.robots)) + + robot = self._world.robots.get(robot_name) + if robot is None: + return {"status": "error", "content": [{"text": f"❌ Robot '{robot_name}' not found"}]} + + from strands_robots.policies import create_policy + + mj = _ensure_mujoco() + policy_instance = create_policy(policy_provider, **policy_kwargs) + policy_instance.set_robot_state_keys(robot.joint_names) + + model = self._world._model + data = self._world._data + + results = [] + for ep in range(n_episodes): + mj.mj_resetData(model, data) + mj.mj_forward(model, data) + + success = False + steps = 0 + + for step in range(max_steps): + obs = self._get_sim_observation(robot_name=robot_name) + coro_or_result = policy_instance.get_actions(obs, instruction) + actions = _resolve_coroutine(coro_or_result) + + with self._lock: + if actions: + self._apply_sim_action(robot_name, actions[0]) + + mj.mj_step(model, data) + steps += 1 + + if success_fn == "contact": + for i in range(data.ncon): + if data.contact[i].dist < 0: + success = True + break + if success: + break + + results.append({"episode": ep, "steps": steps, "success": success}) + + n_success = sum(1 for r in results if r["success"]) + success_rate = n_success / max(n_episodes, 1) + avg_steps = sum(r["steps"] for r in results) / max(n_episodes, 1) + + return { + "status": "success", + "content": [ + { + "text": ( + f"📊 Evaluation: {policy_provider} on '{robot_name}'\n" + f"Episodes: {n_episodes} | Success: {n_success}/{n_episodes} ({success_rate:.1%})\n" + 
f"Avg steps: {avg_steps:.0f}/{max_steps}" + ) + }, + { + "json": { + "success_rate": round(success_rate, 4), + "n_episodes": n_episodes, + "n_success": n_success, + "avg_steps": round(avg_steps, 1), + "max_steps": max_steps, + "episodes": results, + } + }, + ], + } diff --git a/strands_robots/simulation/mujoco/randomization.py b/strands_robots/simulation/mujoco/randomization.py new file mode 100644 index 0000000..8851521 --- /dev/null +++ b/strands_robots/simulation/mujoco/randomization.py @@ -0,0 +1,80 @@ +"""Domain randomization mixin.""" + +import logging +from typing import TYPE_CHECKING, Any + +import numpy as np + +from strands_robots.simulation.mujoco.backend import _ensure_mujoco + +logger = logging.getLogger(__name__) + + +class RandomizationMixin: + if TYPE_CHECKING: + from strands_robots.simulation.models import SimWorld + + _world: "SimWorld | None" + + """Domain randomization for Simulation. Expects self._world.""" + + def randomize( + self, + randomize_colors: bool = True, + randomize_lighting: bool = True, + randomize_physics: bool = False, + randomize_positions: bool = False, + position_noise: float = 0.02, + color_range: tuple[float, float] = (0.1, 1.0), + friction_range: tuple[float, float] = (0.5, 1.5), + mass_range: tuple[float, float] = (0.5, 2.0), + seed: int | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """Apply domain randomization to the scene.""" + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + rng = np.random.default_rng(seed) + mj = _ensure_mujoco() + model = self._world._model + data = self._world._data + changes = [] + + if randomize_colors: + for i in range(model.ngeom): + geom_name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, i) + if geom_name and geom_name != "ground": + model.geom_rgba[i, :3] = rng.uniform(color_range[0], color_range[1], size=3) + changes.append(f"🎨 Colors: {model.ngeom} geoms randomized") + + if randomize_lighting: + for 
i in range(model.nlight): + model.light_pos[i] += rng.uniform(-0.5, 0.5, size=3) + model.light_diffuse[i] = rng.uniform(0.3, 1.0, size=3) + changes.append(f"💡 Lighting: {model.nlight} lights randomized") + + if randomize_physics: + for i in range(model.ngeom): + model.geom_friction[i, 0] *= rng.uniform(*friction_range) + for i in range(model.nbody): + if model.body_mass[i] > 0: + model.body_mass[i] *= rng.uniform(*mass_range) + changes.append(f"⚙️ Physics: friction×[{friction_range}], mass×[{mass_range}]") + + if randomize_positions: + for obj_name, obj in self._world.objects.items(): + if not obj.is_static: + jnt_name = f"{obj_name}_joint" + jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jnt_id >= 0: + qpos_addr = model.jnt_qposadr[jnt_id] + noise = rng.uniform(-position_noise, position_noise, size=3) + data.qpos[qpos_addr : qpos_addr + 3] += noise + mj.mj_forward(model, data) + changes.append(f"📍 Positions: ±{position_noise}m noise on dynamic objects") + + return { + "status": "success", + "content": [{"text": "🎲 Domain Randomization applied:\n" + "\n".join(changes)}], + } diff --git a/strands_robots/simulation/mujoco/recording.py b/strands_robots/simulation/mujoco/recording.py new file mode 100644 index 0000000..849d0df --- /dev/null +++ b/strands_robots/simulation/mujoco/recording.py @@ -0,0 +1,158 @@ +"""Recording mixin — start/stop trajectory recording to LeRobotDataset.""" + +import logging +import shutil +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from strands_robots.simulation.mujoco.backend import _ensure_mujoco + +logger = logging.getLogger(__name__) + + +class RecordingMixin: + if TYPE_CHECKING: + from strands_robots.simulation.models import SimWorld + + _world: "SimWorld | None" + + """Trajectory recording for Simulation. 
Expects self._world.""" + + def start_recording( + self, + repo_id: str = "local/sim_recording", + task: str = "", + fps: int = 30, + root: str | None = None, + push_to_hub: bool = False, + vcodec: str = "libsvtav1", + overwrite: bool = False, + ) -> dict[str, Any]: + """Start recording to LeRobotDataset format (parquet + video).""" + if self._world is None: + return {"status": "error", "content": [{"text": "No world."}]} + + _DatasetRecorder: Any = None + _has_lerobot = False + try: + from strands_robots.dataset_recorder import DatasetRecorder as _DatasetRecorder + from strands_robots.dataset_recorder import has_lerobot_dataset as _check_lerobot + + _has_lerobot = _check_lerobot() + except ImportError: + pass + + if not _has_lerobot or _DatasetRecorder is None: + return { + "status": "error", + "content": [ + { + "text": "lerobot not installed. Install with: pip install lerobot\nRequired for dataset recording." + } + ], + } + + self._world._backend_state["recording"] = True + self._world._backend_state["trajectory"] = [] + self._world._backend_state["push_to_hub"] = push_to_hub + + try: + if overwrite: + if root: + dataset_dir = Path(root) + elif "/" not in repo_id or repo_id.startswith("/") or repo_id.startswith("./"): + dataset_dir = Path(repo_id) + else: + dataset_dir = Path.home() / ".cache" / "huggingface" / "lerobot" / repo_id + if dataset_dir.exists() and dataset_dir.is_dir(): + shutil.rmtree(dataset_dir) + logger.info("Removed existing dataset dir: %s", dataset_dir) + + joint_names = [] + camera_keys = [] + robot_type = "unknown" + for rname, robot in self._world.robots.items(): + joint_names.extend(robot.joint_names) + robot_type = robot.data_config or rname + + mj = _ensure_mujoco() + for i in range(self._world._model.ncam): + cam_name = mj.mj_id2name(self._world._model, mj.mjtObj.mjOBJ_CAMERA, i) + if cam_name: + camera_keys.append(cam_name) + + assert _DatasetRecorder is not None # checked above + self._world._backend_state["dataset_recorder"] = 
_DatasetRecorder.create( + repo_id=repo_id, + fps=fps, + robot_type=robot_type, + joint_names=joint_names, + camera_keys=camera_keys, + task=task, + root=root, + vcodec=vcodec, + ) + return { + "status": "success", + "content": [ + { + "text": ( + f"Recording to LeRobotDataset: {repo_id}\n" + f"{len(joint_names)} joints, {len(camera_keys)} cameras @ {fps}fps\n" + f"Codec: {vcodec} | Task: {task or '(set per policy)'}\n" + f"Run policies to capture frames, then stop_recording to save episode" + ) + } + ], + } + except Exception as e: + self._world._backend_state["recording"] = False + logger.error("Dataset recorder init failed: %s", e) + return {"status": "error", "content": [{"text": f"Dataset init failed: {e}"}]} + + def stop_recording(self, output_path: str | None = None) -> dict[str, Any]: + """Stop recording and save episode to LeRobotDataset.""" + if self._world is None or not self._world._backend_state.get("recording", False): + return {"status": "error", "content": [{"text": "Not recording."}]} + + self._world._backend_state["recording"] = False + recorder = self._world._backend_state.get("dataset_recorder", None) + + if recorder is None: + return {"status": "error", "content": [{"text": "No dataset recorder active."}]} + + recorder.save_episode() + push_result = None + if self._world._backend_state.get("push_to_hub", False): + push_result = recorder.push_to_hub(tags=["strands-robots", "sim"]) + + repo_id = recorder.repo_id + frame_count = recorder.frame_count + episode_count = recorder.episode_count + root = recorder.root + + recorder.finalize() + self._world._backend_state["dataset_recorder"] = None + self._world._backend_state["trajectory"] = [] + + text = ( + f"Episode saved to LeRobotDataset\n" + f"{repo_id} -- {frame_count} frames, {episode_count} episode(s)\n" + f"Local: {root}" + ) + if push_result and push_result.get("status") == "success": + text += "\nPushed to HuggingFace Hub" + + return {"status": "success", "content": [{"text": text}]} + + def 
def _get_renderer(self, width: int, height: int):
    """Return the cached renderer for ``(width, height)``, building it on first use.

    The cache is keyed by image size and is emptied whenever the world's
    model object differs from the one the cached renderers were built for.

    Returns:
        A ``mujoco.Renderer``, or None when ``_can_render()`` reports that
        rendering is unavailable. Callers must handle None.
    """
    if not _can_render():
        return None

    mujoco = _ensure_mujoco()
    assert self._world is not None  # callers must check
    current_model = self._world._model

    if self._renderer_model is not current_model:
        # Renderers are bound to the model they were created with.
        # NOTE(review): stale renderers are dropped without an explicit
        # close(); confirm whether their GL resources need releasing.
        self._renderers.clear()
        self._renderer_model = current_model

    cache_key = (width, height)
    try:
        return self._renderers[cache_key]
    except KeyError:
        renderer = mujoco.Renderer(current_model, height=height, width=width)
        self._renderers[cache_key] = renderer
        return renderer
def _apply_sim_action(self, robot_name: str, action_dict: dict[str, Any], n_substeps: int = 1) -> None:
    """Apply an action dict to the sim and advance physics.

    Each key is looked up as an actuator name first, then as a joint name.
    Physics is stepped ``max(1, n_substeps)`` times and the world's
    ``sim_time`` / ``step_count`` are updated to match.

    Args:
        robot_name: Name of the target robot (not read by this method).
        action_dict: Mapping of actuator or joint name to control value.
        n_substeps: Physics steps to take; values below 1 are treated as 1.
    """
    mj = _ensure_mujoco()
    world = self._world
    assert world is not None  # callers must check
    model, data = world._model, world._data

    for key, value in action_dict.items():
        act_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_ACTUATOR, key)
        if act_id >= 0:
            data.ctrl[act_id] = float(value)
            continue
        # NOTE(review): the fallback uses the joint id as the ctrl index, which
        # presumably assumes actuator i drives joint i -- confirm.
        jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, key)
        if 0 <= jnt_id < model.nu:
            data.ctrl[jnt_id] = float(value)

    steps = max(1, n_substeps)
    for _ in range(steps):
        mj.mj_step(model, data)

    world.sim_time = data.time
    # Count the steps actually taken, so n_substeps < 1 cannot desynchronize
    # (or decrement) the counter.
    world.step_count += steps

    viewer = getattr(self, "_viewer_handle", None)
    if viewer is not None:
        viewer.sync()
" + "Install EGL or OSMesa for offscreen rendering: " + "apt-get install libosmesa6-dev" + ) + } + ], + } + cam_id = mj.mj_name2id(self._world._model, mj.mjtObj.mjOBJ_CAMERA, camera_name) + if cam_id >= 0: + renderer.update_scene(self._world._data, camera=cam_id) + else: + renderer.update_scene(self._world._data) + + img = renderer.render().copy() + + from PIL import Image + + pil_img = Image.fromarray(img) + buffer = io.BytesIO() + pil_img.save(buffer, format="PNG") + png_bytes = buffer.getvalue() + + return { + "status": "success", + "content": [ + {"text": f"📸 {w}x{h} from '{camera_name}' at t={self._world.sim_time:.3f}s"}, + {"image": {"format": "png", "source": {"bytes": png_bytes}}}, + ], + } + except Exception as e: + return {"status": "error", "content": [{"text": f"❌ Render failed: {e}"}]} + + def render_depth( + self, camera_name: str = "default", width: int | None = None, height: int | None = None + ) -> dict[str, Any]: + """Render depth map from a camera.""" + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + w = width or self.default_width + h = height or self.default_height + + try: + cam_id = -1 + if camera_name and camera_name != "default": + cam_id = mj.mj_name2id(self._world._model, mj.mjtObj.mjOBJ_CAMERA, camera_name) + + renderer = self._get_renderer(w, h) + if renderer is None: + return { + "status": "error", + "content": [ + { + "text": ( + "❌ Depth rendering unavailable (no OpenGL context). " + "Install EGL or OSMesa for offscreen rendering." 
+ ) + } + ], + } + if cam_id >= 0: + renderer.update_scene(self._world._data, camera=cam_id) + else: + renderer.update_scene(self._world._data) + renderer.enable_depth_rendering() + depth = renderer.render() + renderer.disable_depth_rendering() + + return { + "status": "success", + "content": [ + { + "text": ( + f"📸 Depth {w}x{h} from '{camera_name}'\n" + f"Min: {float(depth.min()):.3f}m, Max: {float(depth.max()):.3f}m" + ) + }, + { + "text": json.dumps( + {"depth_min": float(depth.min()), "depth_max": float(depth.max())}, default=str + ) + }, + ], + } + except Exception as e: + return {"status": "error", "content": [{"text": f"❌ Depth render failed: {e}"}]} + + def get_contacts(self) -> dict[str, Any]: + if self._world is None or self._world._data is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + contacts = [] + for i in range(data.ncon): + c = data.contact[i] + g1 = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, c.geom1) or f"geom_{c.geom1}" + g2 = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, c.geom2) or f"geom_{c.geom2}" + contacts.append({"geom1": g1, "geom2": g2, "dist": float(c.dist), "pos": c.pos.tolist()}) + + text = f"💥 {len(contacts)} contacts" if contacts else "No contacts." + if contacts: + for c in contacts[:10]: + text += f"\n • {c['geom1']} ↔ {c['geom2']} (d={c['dist']:.4f})" + + return { + "status": "success", + "content": [{"text": text}, {"text": json.dumps({"contacts": contacts}, default=str)}], + } diff --git a/strands_robots/simulation/mujoco/scene_ops.py b/strands_robots/simulation/mujoco/scene_ops.py new file mode 100644 index 0000000..fa661b7 --- /dev/null +++ b/strands_robots/simulation/mujoco/scene_ops.py @@ -0,0 +1,250 @@ +"""XML round-trip injection/ejection for scene modification. + +Shared helper `_reload_scene_from_xml` handles the common pattern: +save XML → patch paths → modify → reload → copy state → re-discover joints. 
def _patch_xml_paths(xml_content: str, robot_base_dir: str) -> str:
    """Rewrite meshdir/texturedir to absolute paths under ``robot_base_dir``.

    The XML is parsed with ElementTree; a ``<compiler>`` element is created as
    the first child of the root when none exists, and both attributes are set
    to ``robot_base_dir`` joined with their current value (empty when absent).

    If the content cannot be parsed, a regex fallback rewrites any
    ``meshdir="..."`` / ``texturedir="..."`` attributes it finds in the raw
    text, using the first occurrence of each to compute the absolute path.

    Args:
        xml_content: MJCF XML text.
        robot_base_dir: Directory that relative asset paths are resolved against.

    Returns:
        The patched XML text.
    """

    def _absolute(relative: str) -> str:
        return os.path.normpath(os.path.join(robot_base_dir, relative))

    try:
        root = ET.fromstring(xml_content)
    except ET.ParseError:
        # Fallback for malformed fragments — use regex as last resort
        logger.debug("ET parse failed for _patch_xml_paths, using regex fallback")
        for attr in ("meshdir", "texturedir"):
            pattern = re.compile(rf'{attr}="([^"]*)"')
            first = pattern.search(xml_content)
            if first is None:
                continue
            replacement = f'{attr}="{_absolute(first.group(1))}"'
            # A callable replacement is inserted literally; a plain string would
            # be treated as a template and choke on backslashes in the path.
            xml_content = pattern.sub(lambda _m, text=replacement: text, xml_content)
        return xml_content

    compiler = root.find("compiler")
    if compiler is None:
        compiler = ET.Element("compiler")
        root.insert(0, compiler)

    for attr in ("meshdir", "texturedir"):
        compiler.set(attr, _absolute(compiler.get(attr, "")))

    return ET.tostring(root, encoding="unicode", xml_declaration=False)
def inject_object_into_scene(world: SimWorld, obj: SimObject) -> bool:
    """Inject object into a running simulation via XML round-trip.

    The current model is saved to a temporary XML file, the object's MJCF
    (from ``MJCFBuilder._object_xml``) is appended under ``<worldbody>``,
    keyframes are dropped, and the scene is reloaded into ``world``.

    Args:
        world: World whose model is modified in place on success.
        obj: Object to add.

    Returns:
        The result of the reload on success; False when no model is loaded,
        ``<worldbody>`` is missing, or the round-trip raises a handled error
        (including an XML parse error).
    """
    _ensure_mujoco()
    if world._model is None:
        return False

    tmpdir = tempfile.mkdtemp(prefix="strands_sim_")
    try:
        scene_path = _save_and_patch_xml(world, tmpdir, "scene_with_objects.xml")

        tree = ET.parse(scene_path)
        root = tree.getroot()

        worldbody = root.find("worldbody")
        if worldbody is None:
            logger.error("No <worldbody> found in scene XML")
            return False

        # The builder output is wrapped in a throwaway element so it parses
        # as one document; every child is then moved under <worldbody>.
        obj_xml_str = MJCFBuilder._object_xml(obj, indent=4)
        wrapper = ET.fromstring(f"<_wrapper>{obj_xml_str}</_wrapper>")
        worldbody.extend(list(wrapper))

        # Remove keyframes — adding a freejoint changes qpos size
        for keyframe_elem in root.findall("keyframe"):
            root.remove(keyframe_elem)

        tree.write(scene_path, xml_declaration=True)

        return _reload_scene_from_xml(world, scene_path)
    except (ValueError, RuntimeError, OSError, ET.ParseError) as e:
        # ET.ParseError is not a ValueError subclass, so it is listed explicitly.
        logger.error("Object injection reload failed: %s", e)
        return False
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
def inject_camera_into_scene(world: SimWorld, cam: SimCamera) -> bool:
    """Inject a camera into a running simulation via XML round-trip.

    The current model is saved to a temporary XML file, a fixed-mode
    ``<camera>`` element is appended under ``<worldbody>`` with the camera's
    sanitized name, position and field of view, and the scene is reloaded.

    Args:
        world: World whose model is modified in place on success.
        cam: Camera to add; ``name``, ``position`` and ``fov`` are written.

    Returns:
        The result of the reload on success; False when no model is loaded,
        ``<worldbody>`` is missing, or the round-trip raises a handled error
        (including an XML parse error).
    """
    _ensure_mujoco()
    if world._model is None:
        return False

    tmpdir = tempfile.mkdtemp(prefix="strands_cam_")
    try:
        scene_path = _save_and_patch_xml(world, tmpdir, "scene_with_cameras.xml")

        tree = ET.parse(scene_path)
        root = tree.getroot()

        worldbody = root.find("worldbody")
        if worldbody is None:
            logger.error("No <worldbody> found in scene XML")
            return False

        px, py, pz = cam.position
        # NOTE(review): no orientation is written here, so any look-at target
        # on the camera is not reflected in the XML -- confirm this is intended.
        ET.SubElement(
            worldbody,
            "camera",
            {
                "name": _sanitize_name(cam.name),
                "pos": f"{px} {py} {pz}",
                "fovy": str(cam.fov),
                "mode": "fixed",
            },
        )

        tree.write(scene_path, xml_declaration=True)

        return _reload_scene_from_xml(world, scene_path)
    except (ValueError, RuntimeError, OSError, ET.ParseError) as e:
        # ET.ParseError is not a ValueError subclass, so it is listed explicitly.
        logger.error("Camera injection reload failed: %s", e)
        return False
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
MJCFBuilder +from strands_robots.simulation.mujoco.physics import PhysicsMixin +from strands_robots.simulation.mujoco.policy_runner import PolicyRunnerMixin +from strands_robots.simulation.mujoco.randomization import RandomizationMixin +from strands_robots.simulation.mujoco.recording import RecordingMixin +from strands_robots.simulation.mujoco.rendering import RenderingMixin +from strands_robots.simulation.mujoco.scene_ops import ( + eject_body_from_scene, + inject_camera_into_scene, + inject_object_into_scene, +) + +logger = logging.getLogger(__name__) + +_TOOL_SPEC_PATH = Path(__file__).parent / "tool_spec.json" + + +class Simulation( + PhysicsMixin, + PolicyRunnerMixin, + RenderingMixin, + RecordingMixin, + RandomizationMixin, + SimEngine, + AgentTool, +): + """Programmatic simulation environment as a Strands AgentTool. + + Gives AI agents the ability to create, modify, and control MuJoCo + simulation environments through natural language → tool actions. + """ + + def __init__( + self, + tool_name: str = "sim", + default_timestep: float = 0.002, + default_width: int = 640, + default_height: int = 480, + mesh: bool = True, + peer_id: str | None = None, + **kwargs, + ): + super().__init__() + self.tool_name_str = tool_name + self.default_timestep = default_timestep + self.default_width = default_width + self.default_height = default_height + + self._world: SimWorld | None = None + self._executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix=f"{tool_name}_sim") + self._policy_threads: dict[str, Future] = {} + self._shutdown_event = threading.Event() + self._lock = threading.Lock() + + self._viewer_handle = None + self._viewer_thread = None + + self._renderers: dict[tuple, Any] = {} + self._renderer_model = None + + # Fail fast: verify MuJoCo is importable at construction time + # so consumers catch missing-dependency errors immediately. 
+ self._mj = _ensure_mujoco() + logger.info("🎮 Simulation tool '%s' initialized", tool_name) + + # --- Public Properties --- + + @property + def mj_model(self): + """Direct access to the MuJoCo model (mujoco.MjModel).""" + return self._world._model if self._world else None + + @property + def mj_data(self): + """Direct access to the MuJoCo data (mujoco.MjData).""" + return self._world._data if self._world else None + + # --- Robot-compatible interface --- + + def get_observation(self, robot_name: str | None = None, camera_name: str | None = None) -> dict[str, Any]: + """Get observation from simulation (Robot ABC compatible).""" + if self._world is None or self._world._model is None: + return {} + if robot_name is None: + if not self._world.robots: + return {} + robot_name = next(iter(self._world.robots)) + if robot_name not in self._world.robots: + return {} + return self._get_sim_observation(robot_name, cam_name=camera_name) + + def send_action(self, action: dict[str, Any], robot_name: str | None = None, n_substeps: int = 1) -> None: + """Apply action to simulation (Robot ABC compatible).""" + if self._world is None or self._world._model is None: + return + if robot_name is None: + if not self._world.robots: + return + robot_name = next(iter(self._world.robots)) + if robot_name not in self._world.robots: + return + self._apply_sim_action(robot_name, action, n_substeps=n_substeps) + + # --- World Management --- + + def _cheap_robot_count(self) -> int: + try: + from strands_robots.registry import list_robots as _registry_list_robots + + return len(_registry_list_robots(mode="sim")) + except ImportError: + return 0 + + def create_world( + self, timestep: float | None = None, gravity: list[float] | None = None, ground_plane: bool = True + ) -> dict[str, Any]: + """Create a new simulation world.""" + # mujoco verified at __init__ + + if self._world is not None and self._world._model is not None: + return { + "status": "error", + "content": [{"text": "❌ World already 
def load_scene(self, scene_path: str) -> dict[str, Any]:
    """Load a complete scene from MJCF XML or URDF file.

    The new world is built separately and only installed as ``self._world``
    once the model and data were created, so a failed load leaves the
    previously loaded world untouched.

    Args:
        scene_path: Path to the scene file.

    Returns:
        A tool-result dict: "error" when the file does not exist or loading
        raised; otherwise "success" with body/joint/actuator counts.
    """
    mj = self._mj

    if not os.path.exists(scene_path):
        return {"status": "error", "content": [{"text": f"❌ Scene file not found: {scene_path}"}]}

    try:
        world = SimWorld()
        world._model = mj.MjModel.from_xml_path(str(scene_path))
        world._data = mj.MjData(world._model)
        world.status = SimStatus.IDLE
    except Exception as e:
        logger.error("Failed to load scene: %s", e)
        return {"status": "error", "content": [{"text": f"❌ Failed to load scene: {e}"}]}

    # Swap in the new world only after it loaded completely.
    self._world = world
    model = world._model

    return {
        "status": "success",
        "content": [
            {
                "text": (
                    f"🌍 Scene loaded from {os.path.basename(scene_path)}\n"
                    f"🦴 Bodies: {model.nbody}, 🔩 Joints: {model.njnt}, ⚡ Actuators: {model.nu}\n"
                    "💡 Use action='get_state' to inspect, action='step' to simulate"
                )
            }
        ],
    }
def _recompile_world(self) -> dict[str, Any]:
    """Run ``_compile_world`` and report the outcome as a result dict.

    Returns:
        ``{"status": "success"}`` when compilation completes, otherwise an
        error dict whose text carries the exception message.
    """
    try:
        self._compile_world()
    except Exception as exc:
        return {"status": "error", "content": [{"text": f"❌ Recompile failed: {exc}"}]}
    return {"status": "success"}
Use action='create_world' first."}]} + if name in self._world.robots: + return {"status": "error", "content": [{"text": f"❌ Robot '{name}' already exists."}]} + + resolved_path = urdf_path + if not resolved_path and data_config: + resolved_path = resolve_model(data_config) + if not resolved_path: + return { + "status": "error", + "content": [ + { + "text": f"❌ No model found for '{data_config}'.\n💡 Use action='list_urdfs' to see available robots" + } + ], + } + elif not resolved_path and name: + resolved_path = resolve_model(name) + + if not resolved_path: + return {"status": "error", "content": [{"text": "❌ Either urdf_path or data_config is required."}]} + if not os.path.exists(resolved_path): + return {"status": "error", "content": [{"text": f"❌ File not found: {resolved_path}"}]} + + mj = self._mj + + robot = SimRobot( + name=name, + urdf_path=resolved_path, + position=position or [0.0, 0.0, 0.0], + orientation=orientation or [1.0, 0.0, 0.0, 0.0], + data_config=data_config, + namespace=f"{name}/", + ) + + try: + self._ensure_meshes(resolved_path, data_config or name) + + model = mj.MjModel.from_xml_path(str(resolved_path)) + data = mj.MjData(model) + + joint_names = [] + for i in range(model.njnt): + jnt_name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_JOINT, i) + if jnt_name: + joint_names.append(jnt_name) + robot.joint_ids.append(i) + robot.joint_names = joint_names + + for i in range(model.nu): + act_name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_ACTUATOR, i) + if act_name: + jnt_id = model.actuator_trnid[i, 0] + if jnt_id in robot.joint_ids: + robot.actuator_ids.append(i) + else: + robot.actuator_ids.append(i) + if not robot.actuator_ids: + for i in range(model.nu): + robot.actuator_ids.append(i) + + for i in range(model.ncam): + cam_name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_CAMERA, i) + if cam_name and cam_name not in self._world.cameras: + self._world.cameras[cam_name] = SimCamera( + name=cam_name, + camera_id=i, + width=self.default_width, + 
def remove_robot(self, name: str) -> dict[str, Any]:
    """Remove a robot from the world, stopping its policy first if one is tracked.

    When ``name`` has an entry in ``self._policy_threads`` the robot's
    ``policy_running`` flag is cleared and the future is awaited for at most
    five seconds. Removal is best-effort: a policy that fails or times out
    does not prevent the robot from being removed, but the reason is logged
    instead of being discarded silently.

    Args:
        name: Key of the robot in ``self._world.robots``.

    Returns:
        A tool-result dict; ``status`` is "error" when there is no world or
        no robot with that name.
    """
    if self._world is None or name not in self._world.robots:
        return {"status": "error", "content": [{"text": f"❌ Robot '{name}' not found."}]}

    future = self._policy_threads.pop(name, None)
    if future is not None:
        # Signal the policy loop to stop before waiting on it.
        self._world.robots[name].policy_running = False
        try:
            future.result(timeout=5.0)
        except Exception as e:
            # Deliberately non-fatal: the robot is removed regardless.
            logger.warning("Policy for robot '%s' did not stop cleanly: %s", name, e)

    del self._world.robots[name]
    return {"status": "success", "content": [{"text": f"🗑️ Robot '{name}' removed."}]}
def get_robot_state(self, robot_name: str) -> dict[str, Any]:
    """Report position and velocity for each named joint of one robot.

    Joints whose name MuJoCo cannot resolve (negative id) are left out.

    Args:
        robot_name: Key of the robot in ``self._world.robots``.

    Returns:
        A tool-result dict. On success ``content`` holds a human-readable
        summary followed by the same data serialized as JSON under ``state``.
    """
    world = self._world
    if world is None or world._data is None:
        return {"status": "error", "content": [{"text": "❌ No simulation running."}]}
    if robot_name not in world.robots:
        return {"status": "error", "content": [{"text": f"❌ Robot '{robot_name}' not found."}]}

    mj = self._mj
    model = world._model
    data = world._data

    state = {}
    for joint in world.robots[robot_name].joint_names:
        joint_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, joint)
        if joint_id < 0:
            continue
        state[joint] = {
            "position": float(data.qpos[model.jnt_qposadr[joint_id]]),
            "velocity": float(data.qvel[model.jnt_dofadr[joint_id]]),
        }

    header = f"🤖 '{robot_name}' state (t={world.sim_time:.3f}s):\n"
    rows = [
        f" {joint}: pos={values['position']:.4f}, vel={values['velocity']:.4f}\n"
        for joint, values in state.items()
    ]
    summary = header + "".join(rows)
    payload = json.dumps({"state": state}, default=str)

    return {"status": "success", "content": [{"text": summary}, {"text": payload}]}
self._world.objects: + return {"status": "error", "content": [{"text": f"❌ Object '{name}' exists."}]} + + obj = SimObject( + name=name, + shape=shape, + position=position or [0.0, 0.0, 0.0], + orientation=orientation or [1.0, 0.0, 0.0, 0.0], + size=size or [0.05, 0.05, 0.05], + color=color or [0.5, 0.5, 0.5, 1.0], + mass=mass, + mesh_path=mesh_path, + is_static=is_static, + ) + self._world.objects[name] = obj + + if self._world.robots: + try: + result = inject_object_into_scene(self._world, obj) + if result: + return { + "status": "success", + "content": [{"text": f"📦 '{name}' spawned: {shape} at {obj.position}"}], + } + return { + "status": "success", + "content": [ + { + "text": ( + f"📦 '{name}' registered: {shape} at {obj.position}\n" + "⚠️ Robot scene loaded — object is tracked but not physically spawned." + ) + } + ], + } + except (ValueError, RuntimeError) as e: + raise RuntimeError( + f"Object injection into live scene failed for '{name}': {e}. " + f"Check that the MJCF XML is valid and compatible with the current scene." 
+ ) from e + + recompile_result = self._recompile_world() + if recompile_result["status"] == "error": + del self._world.objects[name] + return recompile_result + + return { + "status": "success", + "content": [ + { + "text": f"📦 '{name}' added: {shape} at {obj.position}, size={obj.size}, {'static' if is_static else f'{mass}kg'}" + } + ], + } + + def remove_object(self, name: str) -> dict[str, Any]: + if self._world is None or name not in self._world.objects: + return {"status": "error", "content": [{"text": f"❌ Object '{name}' not found."}]} + del self._world.objects[name] + if self._world.robots: + eject_body_from_scene(self._world, name) + else: + self._recompile_world() + return {"status": "success", "content": [{"text": f"🗑️ '{name}' removed."}]} + + def move_object( + self, name: str, position: list[float] | None = None, orientation: list[float] | None = None + ) -> dict[str, Any]: + if self._world is None or self._world._data is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + if name not in self._world.objects: + return {"status": "error", "content": [{"text": f"❌ '{name}' not found."}]} + + mj = self._mj + model, data = self._world._model, self._world._data + + jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, f"{name}_joint") + if jnt_id >= 0: + qpos_addr = model.jnt_qposadr[jnt_id] + if position: + data.qpos[qpos_addr : qpos_addr + 3] = position + self._world.objects[name].position = position + if orientation: + data.qpos[qpos_addr + 3 : qpos_addr + 7] = orientation + self._world.objects[name].orientation = orientation + mj.mj_forward(model, data) + + return {"status": "success", "content": [{"text": f"📍 '{name}' moved to {position or 'same'}"}]} + + def list_objects(self) -> dict[str, Any]: + if self._world is None: + return {"status": "error", "content": [{"text": "❌ No world."}]} + if not self._world.objects: + return {"status": "success", "content": [{"text": "No objects."}]} + + lines = ["📦 Objects:\n"] + for name, 
obj in self._world.objects.items(): + lines.append(f" • {name}: {obj.shape} at {obj.position}, {'static' if obj.is_static else f'{obj.mass}kg'}") + return {"status": "success", "content": [{"text": "\n".join(lines)}]} + + # --- Camera Management --- + + def add_camera( + self, + name: str, + position: list[float] | None = None, + target: list[float] | None = None, + fov: float = 60.0, + width: int = 640, + height: int = 480, + ) -> dict[str, Any]: + if self._world is None: + return {"status": "error", "content": [{"text": "❌ No world."}]} + + cam = SimCamera( + name=name, + position=position or [1.0, 1.0, 1.0], + target=target or [0.0, 0.0, 0.0], + fov=fov, + width=width, + height=height, + ) + self._world.cameras[name] = cam + + if self._world.robots and self._world._model is not None: + try: + inject_camera_into_scene(self._world, cam) + except (ValueError, RuntimeError) as e: + raise RuntimeError( + f"Camera injection into live scene failed for '{name}': {e}. " + f"Check that camera parameters are valid." 
def step(self, n_steps: int = 1) -> dict[str, Any]:
    """Advance the simulation by ``n_steps`` calls to ``mj_step``.

    ``n_steps`` is validated first: a negative count would run no physics
    yet still be added to ``step_count``, corrupting the counter.

    Args:
        n_steps: Number of steps to run; must be a non-negative integer.

    Returns:
        A tool-result dict reporting the steps taken, the simulation time
        read back from the MuJoCo data object, and the running step total.
    """
    if self._world is None or self._world._data is None:
        return {"status": "error", "content": [{"text": "❌ No simulation."}]}
    if not isinstance(n_steps, int) or n_steps < 0:
        return {
            "status": "error",
            "content": [{"text": f"❌ n_steps must be a non-negative integer, got {n_steps!r}."}],
        }

    mj = self._mj
    world = self._world
    for _ in range(n_steps):
        mj.mj_step(world._model, world._data)
    world.sim_time = world._data.time
    world.step_count += n_steps
    return {
        "status": "success",
        "content": [{"text": f"⏩ +{n_steps} steps | t={world.sim_time:.4f}s | total={world.step_count}"}],
    }
def set_timestep(self, timestep: float) -> dict[str, Any]:
    """Set the physics timestep on both the MuJoCo model and the world record.

    The value is validated before anything is written. Previously a timestep
    of zero was stored in the model and only then failed on ``1 / timestep``,
    and negative values were accepted outright.

    Args:
        timestep: New timestep in seconds; must be a finite number > 0.

    Returns:
        A tool-result dict; on success the text includes the equivalent rate
        in Hz.
    """
    if self._world is None or self._world._model is None:
        return {"status": "error", "content": [{"text": "❌ No world."}]}

    # The chained comparison is False for NaN and excludes +inf, so it also
    # serves as the finiteness check without importing math.
    if not isinstance(timestep, (int, float)) or not (0 < timestep < float("inf")):
        return {
            "status": "error",
            "content": [{"text": f"❌ timestep must be a finite number > 0, got {timestep!r}."}],
        }

    self._world._model.opt.timestep = timestep
    self._world.timestep = timestep
    return {"status": "success", "content": [{"text": f"⏱️ Timestep: {timestep}s ({1 / timestep:.0f}Hz)"}]}
_mujoco_viewer.launch_passive(self._world._model, self._world._data) + return {"status": "success", "content": [{"text": "👁️ Interactive viewer opened."}]} + except Exception as e: + return {"status": "error", "content": [{"text": f"❌ Viewer failed: {e}"}]} + + def _close_viewer(self) -> None: + if self._viewer_handle is not None: + try: + self._viewer_handle.close() + except Exception: + pass + self._viewer_handle = None + + def close_viewer(self) -> dict[str, Any]: + self._close_viewer() + return {"status": "success", "content": [{"text": "👁️ Viewer closed."}]} + + # --- URDF Registry --- + + def list_urdfs_action(self) -> dict[str, Any]: + return {"status": "success", "content": [{"text": list_available_models()}]} + + def register_urdf_action(self, data_config: str, urdf_path: str) -> dict[str, Any]: + register_urdf(data_config, urdf_path) + resolved = resolve_model(data_config) + return { + "status": "success", + "content": [{"text": f"📋 Registered '{data_config}' → {urdf_path}\nResolved: {resolved or 'NOT FOUND'}"}], + } + + # --- Introspection --- + + def get_features(self) -> dict[str, Any]: + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "❌ No simulation."}]} + + mj = self._mj + model = self._world._model + + joint_names = [mj.mj_id2name(model, mj.mjtObj.mjOBJ_JOINT, i) for i in range(model.njnt)] + joint_names = [n for n in joint_names if n] + actuator_names = [mj.mj_id2name(model, mj.mjtObj.mjOBJ_ACTUATOR, i) for i in range(model.nu)] + actuator_names = [n for n in actuator_names if n] + camera_names = [mj.mj_id2name(model, mj.mjtObj.mjOBJ_CAMERA, i) for i in range(model.ncam)] + camera_names = [n for n in camera_names if n] + + robots_info = {} + for rname, robot in self._world.robots.items(): + robots_info[rname] = { + "joint_names": robot.joint_names, + "n_joints": len(robot.joint_names), + "n_actuators": len(robot.actuator_ids), + "data_config": robot.data_config, + "source": 
os.path.basename(robot.urdf_path), + } + + features = { + "n_bodies": model.nbody, + "n_joints": model.njnt, + "n_actuators": model.nu, + "n_cameras": model.ncam, + "timestep": model.opt.timestep, + "joint_names": joint_names, + "actuator_names": actuator_names, + "camera_names": camera_names, + "robots": robots_info, + } + + lines = [ + "🔍 Simulation Features", + f"🦴 Joints ({model.njnt}): {', '.join(joint_names[:12])}{'...' if len(joint_names) > 12 else ''}", + f"⚡ Actuators ({model.nu}): {', '.join(actuator_names[:12])}{'...' if len(actuator_names) > 12 else ''}", + f"📷 Cameras ({model.ncam}): {', '.join(camera_names) if camera_names else 'none (free camera only)'}", + f"⏱️ Timestep: {model.opt.timestep}s ({1 / model.opt.timestep:.0f}Hz)", + ] + for rname, rinfo in robots_info.items(): + lines.append( + f"🤖 {rname}: {rinfo['n_joints']} joints, {rinfo['n_actuators']} actuators ({rinfo['source']})" + ) + + return { + "status": "success", + "content": [{"text": "\n".join(lines)}, {"text": json.dumps({"features": features}, default=str)}], + } + + # --- AgentTool Interface --- + + @property + def tool_name(self) -> str: + return self.tool_name_str + + @property + def tool_type(self) -> str: + return "simulation" + + @property + def tool_spec(self) -> ToolSpec: + with open(_TOOL_SPEC_PATH) as f: + schema = json.load(f) + return { + "name": self.tool_name_str, + "description": ( + "Programmatic MuJoCo simulation environment. Create worlds, add robots from URDF " + "(direct path or auto-resolve from data_config name), add objects, run VLA policies, " + "render cameras, record trajectories, domain randomize. " + "Same Policy ABC as real robot control — sim ↔ real with zero code changes. 
" + "Actions: create_world, load_scene, reset, get_state, destroy, " + "add_robot, remove_robot, list_robots, get_robot_state, " + "add_object, remove_object, move_object, list_objects, " + "add_camera, remove_camera, " + "run_policy, start_policy, stop_policy, " + "render, render_depth, get_contacts, " + "step, set_gravity, set_timestep, " + "randomize, " + "start_recording, stop_recording, get_recording_status, " + "open_viewer, close_viewer, " + "list_urdfs, register_urdf, get_features" + ), + "inputSchema": {"json": schema}, + } + + async def stream( + self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any + ) -> AsyncGenerator[ToolResultEvent, None]: + try: + tool_use_id = tool_use.get("toolUseId", "") + input_data = tool_use.get("input", {}) + result = self._dispatch_action(input_data.get("action", ""), input_data) + yield ToolResultEvent(dict(toolUseId=tool_use_id, **result)) # type: ignore[typeddict-item] + except Exception as e: + yield ToolResultEvent( + { + "toolUseId": tool_use.get("toolUseId", ""), + "status": "error", + "content": [{"text": f"❌ Sim error: {e}"}], + } + ) + + def _dispatch_action(self, action: str, d: dict[str, Any]) -> dict[str, Any]: + """Route action string to method via getattr. + + Method names match action names directly (with a few aliases). 
+ """ + # Aliases for actions whose method names differ + _ALIASES = { + "list_urdfs": "list_urdfs_action", + "register_urdf": "register_urdf_action", + "stop_policy": "_stop_policy", + } + + # Map input field names to method parameter names for physics actions + _FIELD_MAP = { + "checkpoint_name": "name", + "torque_vec": "torque", + } + + method_name = _ALIASES.get(action, action) + method = getattr(self, method_name, None) + + if method is None or action.startswith("_"): + return {"status": "error", "content": [{"text": f"❌ Unknown action: {action}"}]} + + # Build kwargs from input dict, excluding 'action' itself + # Signatures are cached per method to avoid repeated introspection. + import inspect + + cache = getattr(self, "_sig_cache", None) + if cache is None: + self._sig_cache = cache = {} + if method_name not in cache: + cache[method_name] = inspect.signature(method) + sig = cache[method_name] + # Apply field name remapping + remapped = dict(d) + for field_key, param_key in _FIELD_MAP.items(): + if field_key in remapped and param_key not in remapped: + remapped[param_key] = remapped.pop(field_key) + + kwargs = {} + for param_name, param in sig.parameters.items(): + if param_name == "self": + continue + # Handle name/robot_name/body_name ambiguity in the input schema + if param_name == "name" and "name" not in remapped and "robot_name" in remapped: + kwargs["name"] = remapped["robot_name"] + elif param_name == "name" and "name" not in remapped and "checkpoint_name" in d: + kwargs["name"] = d["checkpoint_name"] + elif param_name == "robot_name" and "robot_name" not in remapped and "name" in remapped: + kwargs["robot_name"] = remapped["name"] + elif param_name in remapped: + kwargs[param_name] = remapped[param_name] + # Forward policy kwargs + elif param.kind == inspect.Parameter.VAR_KEYWORD: + for k in ( + "policy_port", + "policy_host", + "model_path", + "server_address", + "policy_type", + "pretrained_name_or_path", + "device", + ): + if k in d: + kwargs[k] 
= d[k] + + return method(**kwargs) + + def _stop_policy(self, robot_name: str = "", **kwargs) -> dict[str, Any]: + if self._world and robot_name in self._world.robots: + self._world.robots[robot_name].policy_running = False + return {"status": "success", "content": [{"text": f"🛑 Stopped on '{robot_name}'"}]} + return {"status": "error", "content": [{"text": f"❌ '{robot_name}' not found."}]} + + # --- Cleanup --- + + def cleanup(self) -> None: + if hasattr(self, "mesh") and self.mesh: + self.mesh.stop() + if self._world: + for r in self._world.robots.values(): + r.policy_running = False + self._world = None + self._close_viewer() + for renderer in getattr(self, "_renderers", {}).values(): + try: + renderer.close() + except Exception: + pass + self._renderers.clear() + self._executor.shutdown(wait=False) + self._shutdown_event.set() + + def __enter__(self) -> "Simulation": + return self + + def __exit__(self, *exc: object) -> None: + self.cleanup() + + def __del__(self) -> None: + try: + self.cleanup() + except Exception: + pass diff --git a/strands_robots/simulation/mujoco/tool_spec.json b/strands_robots/simulation/mujoco/tool_spec.json new file mode 100644 index 0000000..4147a4b --- /dev/null +++ b/strands_robots/simulation/mujoco/tool_spec.json @@ -0,0 +1,351 @@ +{ + "type": "object", + "properties": { + "action": { + "type": "string", + "description": "Action to perform", + "enum": [ + "create_world", + "load_scene", + "reset", + "get_state", + "destroy", + "add_robot", + "remove_robot", + "list_robots", + "get_robot_state", + "add_object", + "remove_object", + "move_object", + "list_objects", + "add_camera", + "remove_camera", + "run_policy", + "start_policy", + "stop_policy", + "render", + "render_depth", + "get_contacts", + "step", + "set_gravity", + "set_timestep", + "randomize", + "start_recording", + "stop_recording", + "get_recording_status", + "open_viewer", + "close_viewer", + "list_urdfs", + "register_urdf", + "get_features", + "replay_episode", + 
"eval_policy", + "save_state", + "load_state", + "apply_force", + "raycast", + "multi_raycast", + "get_jacobian", + "get_energy", + "get_mass_matrix", + "inverse_dynamics", + "get_body_state", + "set_joint_positions", + "set_joint_velocities", + "get_sensor_data", + "set_body_properties", + "set_geom_properties", + "get_contact_forces", + "forward_kinematics", + "get_total_mass", + "export_xml" + ] + }, + "scene_path": { + "type": "string", + "description": "Path to MJCF/URDF scene file" + }, + "timestep": { + "type": "number" + }, + "gravity": { + "type": "array", + "items": { + "type": "number" + } + }, + "ground_plane": { + "type": "boolean" + }, + "urdf_path": { + "type": "string", + "description": "Path to URDF/MJCF file" + }, + "robot_name": { + "type": "string" + }, + "data_config": { + "type": "string", + "description": "Data config name (auto-resolves URDF)" + }, + "name": { + "type": "string", + "description": "Object/camera name" + }, + "shape": { + "type": "string", + "enum": [ + "box", + "sphere", + "cylinder", + "capsule", + "mesh", + "plane" + ] + }, + "position": { + "type": "array", + "items": { + "type": "number" + } + }, + "orientation": { + "type": "array", + "items": { + "type": "number" + } + }, + "size": { + "type": "array", + "items": { + "type": "number" + } + }, + "color": { + "type": "array", + "items": { + "type": "number" + } + }, + "mass": { + "type": "number" + }, + "is_static": { + "type": "boolean" + }, + "mesh_path": { + "type": "string" + }, + "target": { + "type": "array", + "items": { + "type": "number" + }, + "description": "Camera target point" + }, + "fov": { + "type": "number", + "description": "Camera field of view" + }, + "width": { + "type": "integer" + }, + "height": { + "type": "integer" + }, + "policy_provider": { + "type": "string", + "description": "Policy provider name (e.g. 
groot, lerobot_async, lerobot_local, dreamgen, mock)" + }, + "instruction": { + "type": "string" + }, + "duration": { + "type": "number" + }, + "policy_port": { + "type": "integer" + }, + "policy_host": { + "type": "string" + }, + "model_path": { + "type": "string" + }, + "action_horizon": { + "type": "integer" + }, + "control_frequency": { + "type": "number" + }, + "camera_name": { + "type": "string" + }, + "n_steps": { + "type": "integer" + }, + "output_path": { + "type": "string", + "description": "Trajectory/video export path" + }, + "fps": { + "type": "integer", + "description": "Video frames per second (for run_policy record_video)" + }, + "pretrained_name_or_path": { + "type": "string", + "description": "HuggingFace model ID for lerobot_local" + }, + "randomize_colors": { + "type": "boolean" + }, + "randomize_lighting": { + "type": "boolean" + }, + "randomize_physics": { + "type": "boolean" + }, + "randomize_positions": { + "type": "boolean" + }, + "position_noise": { + "type": "number" + }, + "seed": { + "type": "integer", + "description": "Random seed" + }, + "repo_id": { + "type": "string", + "description": "HuggingFace dataset repo ID" + }, + "push_to_hub": { + "type": "boolean", + "description": "Auto-push dataset to HuggingFace Hub on stop_recording" + }, + "vcodec": { + "type": "string", + "description": "Video codec for dataset recording (h264, hevc, libsvtav1)" + }, + "task": { + "type": "string", + "description": "Task description for dataset recording" + }, + "episode": { + "type": "integer", + "description": "Episode index for replay_episode" + }, + "root": { + "type": "string", + "description": "Local dataset root directory" + }, + "speed": { + "type": "number", + "description": "Replay speed multiplier (1.0 = original)" + }, + "n_episodes": { + "type": "integer", + "description": "Number of eval episodes" + }, + "max_steps": { + "type": "integer", + "description": "Max steps per eval episode" + }, + "success_fn": { + "type": "string", + 
"description": "Success function ('contact')" + }, + "fast_mode": { + "type": "boolean", + "description": "Skip sleep between actions for faster data collection" + }, + "body_name": { + "type": "string", + "description": "Target body name" + }, + "site_name": { + "type": "string", + "description": "Site name for Jacobian" + }, + "geom_name": { + "type": "string", + "description": "Geom name" + }, + "geom_id": { + "type": "integer", + "description": "Geom ID (alternative to geom_name)" + }, + "force": { + "type": "array", + "items": { + "type": "number" + }, + "description": "Force vector [fx, fy, fz] in Newtons" + }, + "torque_vec": { + "type": "array", + "items": { + "type": "number" + }, + "description": "Torque vector [tx, ty, tz] in N\u00b7m" + }, + "point": { + "type": "array", + "items": { + "type": "number" + }, + "description": "Point of force application [x, y, z]" + }, + "origin": { + "type": "array", + "items": { + "type": "number" + }, + "description": "Ray origin [x, y, z]" + }, + "direction": { + "type": "array", + "items": { + "type": "number" + }, + "description": "Ray direction [dx, dy, dz]" + }, + "directions": { + "type": "array", + "items": { + "type": "array", + "items": { + "type": "number" + } + }, + "description": "Multiple ray directions for multi_raycast" + }, + "exclude_body": { + "type": "integer", + "description": "Body ID to exclude from raycast (-1=none)" + }, + "include_static": { + "type": "boolean", + "description": "Include static geoms in raycast" + }, + "positions": { + "type": "object", + "description": "Joint name \u2192 position mapping for set_joint_positions" + }, + "velocities": { + "type": "object", + "description": "Joint name \u2192 velocity mapping for set_joint_velocities" + }, + "sensor_name": { + "type": "string", + "description": "Specific sensor name (or omit for all)" + }, + "checkpoint_name": { + "type": "string", + "description": "Named checkpoint for save_state/load_state" + } + }, + "required": [ + "action" + 
@runtime_checkable
class SimulationProtocol(Protocol):
    """Protocol describing the shared state and methods available across all mixins.

    Each mixin operates on a Simulation instance that provides this interface.
    Using a Protocol avoids duplicating private method stubs in TYPE_CHECKING blocks.

    NOTE: ``@runtime_checkable`` makes ``isinstance()`` checks possible, but
    such checks only verify that the members exist; they do not validate
    attribute types or method signatures.
    """

    # The active world, or None when no world exists.
    _world: SimWorld | None
    # NOTE(review): presumably guards concurrent access to the world; confirm against the mixins.
    _lock: threading.Lock
    # Executor that owns the futures stored in _policy_threads.
    _executor: ThreadPoolExecutor
    # Policy futures by string key; presumably the robot name (verify against callers).
    _policy_threads: dict[str, Future[Any]]
    _mj: Any  # The lazily-imported mujoco module
    # NOTE(review): looks like the model the cached renderers were built for; confirm.
    _renderer_model: Any
    # Renderer cache keyed by a pair of ints; presumably (width, height) as in _get_renderer (confirm order).
    _renderers: dict[tuple[int, int], Any]
    # Default image width and height.
    default_width: int
    default_height: int

    # Return a renderer for the requested image size.
    def _get_renderer(self, width: int, height: int) -> Any: ...
    # Build an observation dict for one robot, optionally restricted to a single camera.
    def _get_sim_observation(self, robot_name: str, cam_name: str | None = None) -> dict[str, Any]: ...
    # Apply an action dict to one robot; n_substeps presumably counts physics steps per action (confirm).
    def _apply_sim_action(self, robot_name: str, action_dict: dict[str, Any], n_substeps: int = 1) -> None: ...
def read_joints(model, data):
    """Return the current ``qpos`` value of every joint in ``JOINT_NAMES``, keyed by joint name."""
    joint_ids = [mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, joint) for joint in JOINT_NAMES]
    return {
        joint: float(data.qpos[model.jnt_qposadr[joint_id]])
        for joint, joint_id in zip(JOINT_NAMES, joint_ids)
    }
"get_observation", + "send_action", + "render", + ] + for method in required: + assert hasattr(SimEngine, method) + + def test_shared_dataclasses(self): + w = SimWorld() + assert w.timestep == 0.002 + assert w.gravity == [0.0, 0.0, -9.81] + assert w.status == SimStatus.IDLE + + r = SimRobot(name="test", urdf_path="/tmp/test.urdf") + assert r.joint_names == [] + + o = SimObject(name="cube", shape="box") + assert o.mass == 0.1 + + +class TestMuJoCoPhysics: + def test_step_advances_time(self, sim_env): + model, data = sim_env + assert data.time == 0.0 + for _ in range(100): + mj.mj_step(model, data) + assert data.time == pytest.approx(0.2, abs=1e-6) + + def test_position_actuators_move_joints(self, sim_env): + model, data = sim_env + data.ctrl[0] = 1.0 # shoulder_pan target + for _ in range(1000): + mj.mj_step(model, data) + obs = read_joints(model, data) + assert abs(obs["shoulder_pan"] - 1.0) < 0.15 + + def test_contacts_detected(self, sim_env): + model, data = sim_env + for _ in range(100): + mj.mj_step(model, data) + assert data.ncon > 0 # cube on ground + + def test_reset_zeros_time(self, sim_env): + model, data = sim_env + for _ in range(100): + mj.mj_step(model, data) + mj.mj_resetData(model, data) + assert data.time == 0.0 + + +@requires_gl +class TestMuJoCoRendering: + def test_render_rgb(self, sim_env): + model, data = sim_env + mj.mj_forward(model, data) + renderer = mj.Renderer(model, height=240, width=320) + cam_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_CAMERA, "front") + renderer.update_scene(data, camera=cam_id) + img = renderer.render() + assert img.shape == (240, 320, 3) + assert img.dtype == np.uint8 + assert img.max() > 0 + del renderer + + def test_render_depth(self, sim_env): + model, data = sim_env + mj.mj_forward(model, data) + renderer = mj.Renderer(model, height=120, width=160) + renderer.update_scene(data) + renderer.enable_depth_rendering() + depth = renderer.render() + renderer.disable_depth_rendering() + assert depth.shape == (120, 160) + 
assert depth.max() > 0 + del renderer + + +class TestMockPolicyLoop: + def test_mock_policy_generates_actions(self): + policy = MockPolicy() + policy.set_robot_state_keys(JOINT_NAMES) + obs = {j: 0.0 for j in JOINT_NAMES} + actions = asyncio.run(policy.get_actions(obs, "test")) + assert len(actions) == 8 + assert all(j in actions[0] for j in JOINT_NAMES) + + def test_full_observe_act_loop(self, sim_env): + model, data = sim_env + policy = MockPolicy() + policy.set_robot_state_keys(JOINT_NAMES) + + for step in range(20): + obs = read_joints(model, data) + actions = asyncio.run(policy.get_actions(obs, "pick up cube")) + apply_action(model, data, actions[0]) + mj.mj_step(model, data) + + assert data.time > 0 + final_obs = read_joints(model, data) + # Joints should have moved from 0 + assert any(abs(v) > 0.001 for v in final_obs.values()) + + @requires_gl + def test_loop_with_rendering(self, sim_env): + """Full loop: observe → policy → act → step → render (10 iterations).""" + model, data = sim_env + policy = MockPolicy() + policy.set_robot_state_keys(JOINT_NAMES) + renderer = mj.Renderer(model, height=120, width=160) + + frames = [] + for _ in range(10): + obs = read_joints(model, data) + actions = asyncio.run(policy.get_actions(obs, "wave")) + apply_action(model, data, actions[0]) + mj.mj_step(model, data) + + renderer.update_scene(data) + frames.append(renderer.render().copy()) + + assert len(frames) == 10 + assert all(f.shape == (120, 160, 3) for f in frames) + # Frames should differ (robot is moving) + assert not np.array_equal(frames[0], frames[-1]) + del renderer + + +class TestDomainRandomization: + def test_color_randomization(self, sim_env): + model, data = sim_env + orig = model.geom_rgba.copy() + rng = np.random.default_rng(42) + for i in range(model.ngeom): + model.geom_rgba[i, :3] = rng.uniform(0.1, 1.0, size=3) + assert not np.array_equal(orig, model.geom_rgba) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) + + +class 
TestToolSpecActionCoverage: + """Verify every action enum in tool_spec.json maps to a real method on Simulation.""" + + def test_all_actions_have_methods(self): + """Every action in tool_spec.json must resolve to a method on Simulation.""" + import json + from pathlib import Path + + from strands_robots.simulation.mujoco.simulation import Simulation + + spec_path = Path(__file__).parent.parent / "strands_robots" / "simulation" / "mujoco" / "tool_spec.json" + with open(spec_path) as f: + spec = json.load(f) + + actions = spec["properties"]["action"]["enum"] + assert len(actions) > 0, "tool_spec.json should have at least one action" + + # Aliases used by _dispatch_action + aliases = { + "list_urdfs": "list_urdfs_action", + "register_urdf": "register_urdf_action", + "stop_policy": "_stop_policy", + } + + missing = [] + for action in actions: + method_name = aliases.get(action, action) + if not hasattr(Simulation, method_name): + missing.append(f"{action} (looked for method '{method_name}')") + + assert not missing, "tool_spec.json actions with no matching Simulation method:\n" + "\n".join( + f" - {m}" for m in missing + ) + + def test_action_enum_is_not_empty(self): + """Sanity: tool_spec.json action enum is populated.""" + import json + from pathlib import Path + + spec_path = Path(__file__).parent.parent / "strands_robots" / "simulation" / "mujoco" / "tool_spec.json" + with open(spec_path) as f: + spec = json.load(f) + + actions = spec["properties"]["action"]["enum"] + assert len(actions) >= 30, f"Expected ≥30 actions, got {len(actions)}" diff --git a/tests/test_physics.py b/tests/test_physics.py new file mode 100644 index 0000000..17e03a2 --- /dev/null +++ b/tests/test_physics.py @@ -0,0 +1,350 @@ +"""Tests for PhysicsMixin — advanced MuJoCo physics features. + +Tests: raycasting, jacobians, energy, forces, state checkpointing, +inverse dynamics, sensor readout, body introspection, runtime modification. 
+ +Run: uv run pytest tests/test_physics.py -v +""" + +import json +import os + +import numpy as np +import pytest + +mj = pytest.importorskip("mujoco") + +from strands_robots.simulation.mujoco.simulation import Simulation # noqa: E402 + +ROBOT_XML = """ + + + +""" + + +@pytest.fixture +def sim(): + """Create a Simulation with the test scene loaded directly.""" + from strands_robots.simulation.models import SimStatus, SimWorld + + s = Simulation(tool_name="test_sim", mesh=False) + s._world = SimWorld() + s._world._model = mj.MjModel.from_xml_string(ROBOT_XML) + s._world._data = mj.MjData(s._world._model) + s._world.status = SimStatus.IDLE + mj.mj_forward(s._world._model, s._world._data) + yield s + s.cleanup() + + +class TestRaycasting: + def test_raycast_hits_ground(self, sim): + result = sim.raycast(origin=[0, 0, 2], direction=[0, 0, -1]) + assert result["status"] == "success" + data = json.loads(result["content"][1]["text"]) + assert data["hit"] is True + assert data["distance"] is not None + assert data["distance"] > 0 + + def test_raycast_hits_box(self, sim): + result = sim.raycast(origin=[0, 0, 2], direction=[0, 0, -1]) + assert result["status"] == "success" + data = json.loads(result["content"][1]["text"]) + assert data["hit"] is True + assert data["geom_name"] in ("box_geom", "ground") + + def test_raycast_misses(self, sim): + result = sim.raycast(origin=[0, 0, 2], direction=[0, 0, 1]) # shooting up + assert result["status"] == "success" + data = json.loads(result["content"][1]["text"]) + assert data["hit"] is False + + def test_multi_raycast(self, sim): + dirs = [[0, 0, -1], [1, 0, 0], [0, 1, 0], [0, 0, 1]] + result = sim.multi_raycast(origin=[0, 0, 2], directions=dirs) + assert result["status"] == "success" + rays = json.loads(result["content"][1]["text"])["rays"] + assert len(rays) == 4 + # At least the downward ray should hit + assert rays[0]["distance"] is not None + + +class TestJacobians: + def test_body_jacobian(self, sim): + result = 
sim.get_jacobian(body_name="link2") + assert result["status"] == "success" + data = json.loads(result["content"][1]["text"]) + assert len(data["jacp"]) == 3 # 3×nv + assert data["nv"] == sim._world._model.nv + + def test_site_jacobian(self, sim): + result = sim.get_jacobian(site_name="end_effector") + assert result["status"] == "success" + + def test_geom_jacobian(self, sim): + result = sim.get_jacobian(geom_name="link2_geom") + assert result["status"] == "success" + + def test_jacobian_no_target(self, sim): + result = sim.get_jacobian() + assert result["status"] == "error" + + def test_jacobian_invalid_body(self, sim): + result = sim.get_jacobian(body_name="nonexistent") + assert result["status"] == "error" + + +class TestEnergy: + def test_get_energy(self, sim): + result = sim.get_energy() + assert result["status"] == "success" + data = json.loads(result["content"][1]["text"]) + assert "potential" in data + assert "kinetic" in data + assert "total" in data + # Box at height 0.5 should have nonzero potential energy + assert data["potential"] != 0 or data["kinetic"] != 0 + + def test_energy_changes_after_step(self, sim): + e1 = json.loads(sim.get_energy()["content"][1]["text"]) + # Step physics to let box fall + for _ in range(100): + mj.mj_step(sim._world._model, sim._world._data) + e2 = json.loads(sim.get_energy()["content"][1]["text"]) + # Kinetic energy should change (box falls) + assert e1["kinetic"] != e2["kinetic"] or e1["potential"] != e2["potential"] + + +class TestExternalForces: + def test_apply_force(self, sim): + result = sim.apply_force(body_name="box1", force=[0, 0, 100]) + assert result["status"] == "success" + assert "box1" in result["content"][0]["text"] + + def test_apply_force_invalid_body(self, sim): + result = sim.apply_force(body_name="nonexistent", force=[0, 0, 10]) + assert result["status"] == "error" + + def test_force_changes_acceleration(self, sim): + # Get initial state + data = sim._world._data + old_qfrc = data.qfrc_applied.copy() + 
sim.apply_force(body_name="box1", force=[0, 0, 100]) + # qfrc_applied should change + assert not np.array_equal(old_qfrc, data.qfrc_applied) + + +class TestMassMatrix: + def test_get_mass_matrix(self, sim): + result = sim.get_mass_matrix() + assert result["status"] == "success" + data = json.loads(result["content"][1]["text"]) + nv = sim._world._model.nv + assert data["shape"] == [nv, nv] + assert data["rank"] > 0 + assert data["total_mass"] > 0 + + def test_mass_diagonal_positive(self, sim): + result = sim.get_mass_matrix() + diag = json.loads(result["content"][1]["text"])["diagonal"] + assert all(d >= 0 for d in diag) + + +class TestStateCheckpointing: + def test_save_and_load_state(self, sim): + # Set a known joint position + sim._world._data.qpos[7] = 1.0 # shoulder + mj.mj_forward(sim._world._model, sim._world._data) + + # Save + result = sim.save_state(name="test_checkpoint") + assert result["status"] == "success" + + # Change state + sim._world._data.qpos[7] = -1.0 + mj.mj_forward(sim._world._model, sim._world._data) + assert sim._world._data.qpos[7] == pytest.approx(-1.0) + + # Restore + result = sim.load_state(name="test_checkpoint") + assert result["status"] == "success" + assert sim._world._data.qpos[7] == pytest.approx(1.0) + + def test_load_nonexistent_checkpoint(self, sim): + result = sim.load_state(name="doesnt_exist") + assert result["status"] == "error" + + +class TestInverseDynamics: + def test_inverse_dynamics(self, sim): + mj.mj_forward(sim._world._model, sim._world._data) + result = sim.inverse_dynamics() + assert result["status"] == "success" + forces = json.loads(result["content"][1]["text"])["qfrc_inverse"] + assert "shoulder" in forces or "elbow" in forces + + +class TestBodyState: + def test_get_body_state(self, sim): + result = sim.get_body_state(body_name="box1") + assert result["status"] == "success" + state = json.loads(result["content"][1]["text"]) + assert "position" in state + assert "quaternion" in state + assert "linear_velocity" 
in state + assert "angular_velocity" in state + assert "mass" in state + assert len(state["position"]) == 3 + assert len(state["quaternion"]) == 4 + assert state["mass"] == pytest.approx(1.0) + + def test_body_state_invalid(self, sim): + result = sim.get_body_state(body_name="nonexistent") + assert result["status"] == "error" + + +class TestDirectJointControl: + def test_set_joint_positions(self, sim): + result = sim.set_joint_positions(positions={"shoulder": 0.5, "elbow": -0.3}) + assert result["status"] == "success" + assert "2/2" in result["content"][0]["text"] + + # Verify positions were set + model, data = sim._world._model, sim._world._data + shoulder_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, "shoulder") + qpos_adr = model.jnt_qposadr[shoulder_id] + assert data.qpos[qpos_adr] == pytest.approx(0.5) + + def test_set_joint_velocities(self, sim): + result = sim.set_joint_velocities(velocities={"shoulder": 1.0}) + assert result["status"] == "success" + + +class TestSensors: + def test_get_all_sensors(self, sim): + result = sim.get_sensor_data() + assert result["status"] == "success" + sensors = json.loads(result["content"][1]["text"])["sensors"] + assert "shoulder_pos" in sensors + assert "elbow_pos" in sensors + + def test_get_specific_sensor(self, sim): + result = sim.get_sensor_data(sensor_name="shoulder_pos") + assert result["status"] == "success" + sensors = json.loads(result["content"][1]["text"])["sensors"] + assert len(sensors) == 1 + assert "shoulder_pos" in sensors + + def test_sensor_values_change(self, sim): + # Set shoulder position + sim.set_joint_positions(positions={"shoulder": 1.0}) + result = sim.get_sensor_data(sensor_name="shoulder_pos") + val = json.loads(result["content"][1]["text"])["sensors"]["shoulder_pos"]["values"] + assert abs(val - 1.0) < 0.01 + + +class TestRuntimeModification: + def test_set_body_mass(self, sim): + result = sim.set_body_properties(body_name="box1", mass=5.0) + assert result["status"] == "success" + body_id = 
mj.mj_name2id(sim._world._model, mj.mjtObj.mjOBJ_BODY, "box1") + assert sim._world._model.body_mass[body_id] == pytest.approx(5.0) + + def test_set_geom_color(self, sim): + result = sim.set_geom_properties(geom_name="box_geom", color=[0, 1, 0, 1]) + assert result["status"] == "success" + geom_id = mj.mj_name2id(sim._world._model, mj.mjtObj.mjOBJ_GEOM, "box_geom") + assert sim._world._model.geom_rgba[geom_id][1] == pytest.approx(1.0) + + def test_set_geom_friction(self, sim): + result = sim.set_geom_properties(geom_name="box_geom", friction=[0.5, 0.01, 0.001]) + assert result["status"] == "success" + + def test_invalid_geom(self, sim): + result = sim.set_geom_properties(geom_name="nonexistent", color=[1, 0, 0, 1]) + assert result["status"] == "error" + + +class TestContactForces: + def test_get_contact_forces_after_settling(self, sim): + # Let box fall and settle + for _ in range(500): + mj.mj_step(sim._world._model, sim._world._data) + result = sim.get_contact_forces() + assert result["status"] == "success" + # Box should be in contact with ground + contacts = json.loads(result["content"][1]["text"])["contacts"] + assert len(contacts) > 0 + assert contacts[0]["normal_force"] != 0 + + +class TestForwardKinematics: + def test_forward_kinematics(self, sim): + result = sim.forward_kinematics() + assert result["status"] == "success" + bodies = json.loads(result["content"][1]["text"])["bodies"] + assert "box1" in bodies + assert "link1" in bodies + assert len(bodies["box1"]["position"]) == 3 + + +class TestTotalMass: + def test_get_total_mass(self, sim): + result = sim.get_total_mass() + assert result["status"] == "success" + data = json.loads(result["content"][1]["text"]) + assert data["total_mass"] > 0 + assert "box1" in data["bodies"] + assert data["bodies"]["box1"] == pytest.approx(1.0) + + +class TestExportXML: + def test_export_xml_string(self, sim): + result = sim.export_xml() + assert result["status"] == "success" + text = result["content"][0]["text"] + assert 
"mujoco" in text.lower() or "Model XML" in text + + def test_export_xml_file(self, sim, tmp_path): + path = str(tmp_path / "exported.xml") + result = sim.export_xml(output_path=path) + assert result["status"] == "success" + assert os.path.exists(path) + with open(path) as f: + content = f.read() + assert " 0 + names = [r["name"] for r in robots] + assert "so100" in names + assert "panda" in names + + def test_list_sim(self): + robots = list_robots("sim") + for r in robots: + assert r["has_sim"] is True + + def test_list_real(self): + robots = list_robots("real") + for r in robots: + assert r["has_real"] is True + + def test_list_both(self): + robots = list_robots("both") + for r in robots: + assert r["has_sim"] is True + assert r["has_real"] is True + + def test_robot_has_fields(self): + robots = list_robots() + for r in robots: + assert "name" in r + assert "description" in r + assert "has_sim" in r + assert "has_real" in r + + +class TestRobotRegistry: + def test_so100_exists(self): + info = get_robot("so100") + assert info is not None + assert "asset" in info + assert info["asset"]["dir"] == "trs_so_arm100" + + def test_all_aliases_point_to_valid_robots(self): + aliases = list_aliases() + for alias, canonical in aliases.items(): + info = get_robot(canonical) + assert info is not None, f"Alias '{alias}' points to unknown robot '{canonical}'" + + def test_robot_count(self): + """Ensure we have a reasonable number of robots.""" + robots = list_robots() + assert len(robots) >= 30 + + def test_all_robots_have_description(self): + robots = list_robots() + for r in robots: + assert "description" in r, f"Robot '{r['name']}' missing description" + assert len(r["description"]) > 0 + + +class TestAutoDetectMode: + def test_defaults_to_sim(self): + """No hardware plugged in → sim.""" + assert _auto_detect_mode("so100") == "sim" + + def test_env_override_real(self): + os.environ["STRANDS_ROBOT_MODE"] = "real" + try: + assert _auto_detect_mode("so100") == "real" + finally: + 
del os.environ["STRANDS_ROBOT_MODE"] + + def test_env_override_sim(self): + os.environ["STRANDS_ROBOT_MODE"] = "sim" + try: + assert _auto_detect_mode("so100") == "sim" + finally: + del os.environ["STRANDS_ROBOT_MODE"] + + def test_env_override_case_insensitive(self): + os.environ["STRANDS_ROBOT_MODE"] = "REAL" + try: + mode = _auto_detect_mode("so100") + assert mode == "real" + finally: + del os.environ["STRANDS_ROBOT_MODE"] + + +class TestRobotFactory: + def test_robot_is_callable(self): + """Robot is a factory function, not a class.""" + import inspect + + assert callable(Robot) + assert not inspect.isclass(Robot) + + def test_default_mode_is_sim(self): + """Robot() defaults to sim mode — never accidentally sends to hardware.""" + import inspect + + sig = inspect.signature(Robot) + assert sig.parameters["mode"].default == "sim" + + def test_unknown_backend_raises(self): + with pytest.raises(NotImplementedError, match="not yet implemented"): + Robot("so100", mode="sim", backend="isaac") + + def test_newton_not_implemented(self): + with pytest.raises(NotImplementedError, match="not yet implemented"): + Robot("so100", mode="sim", backend="newton") + + def test_invalid_mode_raises(self): + with pytest.raises(ValueError, match="Invalid mode"): + Robot("so100", mode="invalid") + + def test_sim_with_urdf_path(self): + """Robot() with explicit urdf_path should work (if file exists).""" + pytest.importorskip("mujoco") + with pytest.raises(RuntimeError): + Robot("test_bot", mode="sim", urdf_path="/nonexistent/robot.xml") + + def test_sim_happy_path_mujoco(self, tmp_path): + """Happy-path: create a MuJoCo sim, step physics, destroy. + + Uses a minimal inline MJCF so the test works without downloaded assets. 
+ """ + mujoco = pytest.importorskip("mujoco") + + mjcf_xml = """ + + + + + + + + + + + + + + + + + """ + mjcf_path = tmp_path / "test_arm.xml" + mjcf_path.write_text(mjcf_xml) + + sim = Robot("so100", mode="sim", backend="mujoco", urdf_path=str(mjcf_path)) + try: + assert sim._world is not None + assert sim._world._model is not None + assert sim._world._data is not None + mujoco.mj_step(sim._world._model, sim._world._data) + assert sim._world._data.time > 0 + finally: + sim.destroy() + + def test_import_from_top_level(self): + """Robot and list_robots importable from strands_robots.""" + from strands_robots import Robot as R + from strands_robots import list_robots as lr + + assert R is Robot + assert callable(lr)